
Source Code for Module pyspark.sql

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.rdd import RDD, PipelinedRDD
from pyspark.serializers import BatchedSerializer, PickleSerializer

from py4j.protocol import Py4JError

__all__ = ["SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]


class SQLContext:
    """Main entry point for SparkSQL functionality.

    A SQLContext can be used to create L{SchemaRDD}s, register L{SchemaRDD}s as
    tables, execute SQL over tables, cache tables, and read parquet files.
    """

    def __init__(self, sparkContext, sqlContext=None):
        """Create a new SQLContext.

        @param sparkContext: The SparkContext to wrap.
        @param sqlContext: An optional JVM Scala SQLContext. If set, we do not
            instantiate a new SQLContext in the JVM; instead all calls are made
            to this existing object.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...

        >>> bad_rdd = sc.parallelize([1,2,3])
        >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...

        >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
        ... "boolean" : True}])
        >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
        ... x.boolean))
        >>> srdd.collect()[0]
        (1, u'string', 1.0, 1, True)
        """
        self._sc = sparkContext
        self._jsc = self._sc._jsc
        self._jvm = self._sc._jvm
        self._pythonToJavaMap = self._jvm.PythonRDD.pythonToJavaMap

        if sqlContext:
            self._scala_SQLContext = sqlContext

    @property
    def _ssql_ctx(self):
        """Accessor for the JVM SparkSQL context.

        Subclasses can override this property to provide their own
        JVM Contexts.
        """
        if not hasattr(self, '_scala_SQLContext'):
            self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
        return self._scala_SQLContext

    def inferSchema(self, rdd):
        """Infer and apply a schema to an RDD of L{dict}s.

        We peek at the first row of the RDD to determine the field names
        and types, and then use that to extract all the dictionaries. Nested
        collections are supported, which can include array, dict, list, set,
        and tuple.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
        ...                    {"field1" : 3, "field2": "row3"}]
        True

        >>> from array import array
        >>> srdd = sqlCtx.inferSchema(nestedRdd1)
        >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
        ...                    {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
        True

        >>> srdd = sqlCtx.inferSchema(nestedRdd2)
        >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
        ...                    {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
        True
        """
        if (rdd.__class__ is SchemaRDD):
            raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
        elif not isinstance(rdd.first(), dict):
            raise ValueError("Only RDDs with dictionaries can be converted to %s: %s" %
                             (SchemaRDD.__name__, rdd.first()))

        jrdd = self._pythonToJavaMap(rdd._jrdd)
        srdd = self._ssql_ctx.inferSchema(jrdd.rdd())
        return SchemaRDD(srdd, self)

    def registerRDDAsTable(self, rdd, tableName):
        """Registers the given RDD as a temporary table in the catalog.

        Temporary tables exist only during the lifetime of this instance of
        SQLContext.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        """
        if (rdd.__class__ is SchemaRDD):
            jschema_rdd = rdd._jschema_rdd
            self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName)
        else:
            raise ValueError("Can only register SchemaRDD as table")

    def parquetFile(self, path):
        """Loads a Parquet file, returning the result as a L{SchemaRDD}.

        >>> import tempfile, shutil
        >>> parquetFile = tempfile.mkdtemp()
        >>> shutil.rmtree(parquetFile)
        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.saveAsParquetFile(parquetFile)
        >>> srdd2 = sqlCtx.parquetFile(parquetFile)
        >>> sorted(srdd.collect()) == sorted(srdd2.collect())
        True
        """
        jschema_rdd = self._ssql_ctx.parquetFile(path)
        return SchemaRDD(jschema_rdd, self)

    def jsonFile(self, path):
        """Loads a text file storing one JSON object per line,
        returning the result as a L{SchemaRDD}.
        It goes through the entire dataset once to determine the schema.

        >>> import tempfile, shutil
        >>> jsonFile = tempfile.mkdtemp()
        >>> shutil.rmtree(jsonFile)
        >>> ofn = open(jsonFile, 'w')
        >>> for json in jsonStrings:
        ...   print>>ofn, json
        >>> ofn.close()
        >>> srdd = sqlCtx.jsonFile(jsonFile)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.sql(
        ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1")
        >>> srdd2.collect() == [
        ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None},
        ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
        ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
        True
        """
        jschema_rdd = self._ssql_ctx.jsonFile(path)
        return SchemaRDD(jschema_rdd, self)

    def jsonRDD(self, rdd):
        """Loads an RDD storing one JSON object per string, returning the result as a L{SchemaRDD}.
        It goes through the entire dataset once to determine the schema.

        >>> srdd = sqlCtx.jsonRDD(json)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.sql(
        ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1")
        >>> srdd2.collect() == [
        ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None},
        ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
        ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
        True
        """
        def func(split, iterator):
            for x in iterator:
                if not isinstance(x, basestring):
                    x = unicode(x)
                yield x.encode("utf-8")
        keyed = PipelinedRDD(rdd, func)
        keyed._bypass_serializer = True
        jrdd = keyed._jrdd.map(self._jvm.BytesToString())
        jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
        return SchemaRDD(jschema_rdd, self)

    def sql(self, sqlQuery):
        """Return a L{SchemaRDD} representing the result of the given query.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
        >>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"},
        ...                     {"f1" : 3, "f2": "row3"}]
        True
        """
        return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)

    def table(self, tableName):
        """Returns the specified table as a L{SchemaRDD}.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.table("table1")
        >>> sorted(srdd.collect()) == sorted(srdd2.collect())
        True
        """
        return SchemaRDD(self._ssql_ctx.table(tableName), self)

    def cacheTable(self, tableName):
        """Caches the specified table in memory."""
        self._ssql_ctx.cacheTable(tableName)

    def uncacheTable(self, tableName):
        """Removes the specified table from the in-memory cache."""
        self._ssql_ctx.uncacheTable(tableName)
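
# Editor's note: an illustrative sketch, not part of the original module. It shows how the
# caching methods above might be combined with inferSchema/registerRDDAsTable/sql; the table
# name "people" and this helper's name are assumptions made for the example only.
def _example_cached_query(sqlCtx, rdd):
    """Register an RDD of dicts as a table, cache it, query it, then release it."""
    srdd = sqlCtx.inferSchema(rdd)              # infer a schema from the first row
    sqlCtx.registerRDDAsTable(srdd, "people")   # make it visible to SQL as "people"
    sqlCtx.cacheTable("people")                 # later scans of "people" read from memory
    count = sqlCtx.sql("SELECT * FROM people").count()
    sqlCtx.uncacheTable("people")               # drop the in-memory copy when finished
    return count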


class HiveContext(SQLContext):
    """A variant of Spark SQL that integrates with data stored in Hive.

    Configuration for Hive is read from hive-site.xml on the classpath.
    It supports running both SQL and HiveQL commands.
    """

    @property
    def _ssql_ctx(self):
        try:
            if not hasattr(self, '_scala_HiveContext'):
                self._scala_HiveContext = self._get_hive_ctx()
            return self._scala_HiveContext
        except Py4JError as e:
            raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run "
                            "sbt/sbt assembly", e)

    def _get_hive_ctx(self):
        return self._jvm.HiveContext(self._jsc.sc())

    def hiveql(self, hqlQuery):
        """
        Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.
        """
        return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self)

    def hql(self, hqlQuery):
        """
        Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.
        """
        return self.hiveql(hqlQuery)
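
# Editor's note: an illustrative sketch, not part of the original module. hql/hiveql return a
# L{SchemaRDD}, so results can be collected or transformed like any RDD; the table name "src"
# matches the LocalHiveContext doctest below and is otherwise an assumption.
def _example_hive_query(hiveCtx):
    """Run a HiveQL query and pull a few rows back to the driver."""
    srdd = hiveCtx.hql("SELECT key, value FROM src LIMIT 10")   # SchemaRDD of Row objects
    return [(r.key, r.value) for r in srdd.collect()]           # Rows expose fields as attributes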


class LocalHiveContext(HiveContext):
    """Starts up an instance of Hive where metadata is stored locally.

    An in-process metastore is created with its data stored in ./metadata,
    and warehouse data is stored in ./warehouse.

    >>> import os
    >>> hiveCtx = LocalHiveContext(sc)
    >>> try:
    ...     suppress = hiveCtx.hql("DROP TABLE src")
    ... except Exception:
    ...     pass
    >>> kv1 = os.path.join(os.environ["SPARK_HOME"], 'examples/src/main/resources/kv1.txt')
    >>> suppress = hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
    >>> suppress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" % kv1)
    >>> results = hiveCtx.hql("FROM src SELECT value").map(lambda r: int(r.value.split('_')[1]))
    >>> num = results.count()
    >>> reduce_sum = results.reduce(lambda x, y: x + y)
    >>> num
    500
    >>> reduce_sum
    130091
    """

    def _get_hive_ctx(self):
        return self._jvm.LocalHiveContext(self._jsc.sc())


class TestHiveContext(HiveContext):

    def _get_hive_ctx(self):
        return self._jvm.TestHiveContext(self._jsc.sc())


# TODO: Investigate if it is more efficient to use a namedtuple. One problem is that named
# tuples are custom classes that must be generated per Schema.
class Row(dict):
    """A row in L{SchemaRDD}.

    An extended L{dict} that takes a L{dict} in its constructor, and
    exposes those items as fields.

    >>> r = Row({"hello" : "world", "foo" : "bar"})
    >>> r.hello
    'world'
    >>> r.foo
    'bar'
    """

    def __init__(self, d):
        d.update(self.__dict__)
        self.__dict__ = d
        dict.__init__(self, d)


class SchemaRDD(RDD):
    """An RDD of L{Row} objects that has an associated schema.

    The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can
    utilize the relational query API exposed by SparkSQL.

    For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the
    L{SchemaRDD} is not operated on directly, as its underlying
    implementation is an RDD composed of Java objects. Instead it is
    converted to a PythonRDD in the JVM, on which Python operations can
    be done.
    """

    def __init__(self, jschema_rdd, sql_ctx):
        self.sql_ctx = sql_ctx
        self._sc = sql_ctx._sc
        self._jschema_rdd = jschema_rdd

        self.is_cached = False
        self.is_checkpointed = False
        self.ctx = self.sql_ctx._sc
        self._jrdd_deserializer = self.ctx.serializer

    @property
    def _jrdd(self):
        """Lazy evaluation of PythonRDD object.

        Only done when a user calls methods defined by the
        L{pyspark.rdd.RDD} super class (map, filter, etc.).
        """
        if not hasattr(self, '_lazy_jrdd'):
            self._lazy_jrdd = self._toPython()._jrdd
        return self._lazy_jrdd

    @property
    def _id(self):
        return self._jrdd.id()

    def saveAsParquetFile(self, path):
        """Save the contents as a Parquet file, preserving the schema.

        Files that are written out using this method can be read back in as
        a SchemaRDD using the L{SQLContext.parquetFile} method.

        >>> import tempfile, shutil
        >>> parquetFile = tempfile.mkdtemp()
        >>> shutil.rmtree(parquetFile)
        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.saveAsParquetFile(parquetFile)
        >>> srdd2 = sqlCtx.parquetFile(parquetFile)
        >>> sorted(srdd2.collect()) == sorted(srdd.collect())
        True
        """
        self._jschema_rdd.saveAsParquetFile(path)

    def registerAsTable(self, name):
        """Registers this RDD as a temporary table using the given name.

        The lifetime of this temporary table is tied to the L{SQLContext}
        that was used to create this SchemaRDD.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.registerAsTable("test")
        >>> srdd2 = sqlCtx.sql("select * from test")
        >>> sorted(srdd.collect()) == sorted(srdd2.collect())
        True
        """
        self._jschema_rdd.registerAsTable(name)

    def insertInto(self, tableName, overwrite=False):
        """Inserts the contents of this SchemaRDD into the specified table,
        optionally overwriting any existing data.
        """
        self._jschema_rdd.insertInto(tableName, overwrite)

    def saveAsTable(self, tableName):
        """Creates a new table with the contents of this SchemaRDD."""
        self._jschema_rdd.saveAsTable(tableName)
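
    # Editor's note: an illustrative sketch, not part of the original source. saveAsTable
    # creates a new table from this SchemaRDD, while insertInto appends to an existing one
    # (or replaces its contents when overwrite=True); both assume a catalog that supports
    # persistent tables, e.g. a HiveContext. The table name "people" is an assumption.
    #
    #   >>> srdd.saveAsTable("people")                   # doctest: +SKIP
    #   >>> srdd2.insertInto("people")                   # doctest: +SKIP
    #   >>> srdd2.insertInto("people", overwrite=True)   # doctest: +SKIP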

    def schemaString(self):
        """Returns the output schema in the tree format."""
        return self._jschema_rdd.schemaString()

    def printSchema(self):
        """Prints out the schema in the tree format."""
        print self.schemaString()
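
    # Editor's note: an illustrative sketch, not part of the original source. For the doctest
    # RDD used throughout this file the printed tree looks roughly like the lines below; the
    # exact type names and nullability annotations depend on the Spark version.
    #
    #   >>> sqlCtx.inferSchema(rdd).printSchema()   # doctest: +SKIP
    #   root
    #    |-- field1: IntegerType
    #    |-- field2: StringType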

    def count(self):
        """Return the number of elements in this RDD.

        Unlike the base RDD implementation of count, this implementation
        leverages the query optimizer to compute the count on the SchemaRDD,
        which supports features such as filter pushdown.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.count()
        3L
        >>> srdd.count() == srdd.map(lambda x: x).count()
        True
        """
        return self._jschema_rdd.count()

    def _toPython(self):
        # We have to import the Row class explicitly, so that the reference the Pickler holds
        # is pyspark.sql.Row instead of __main__.Row.
        from pyspark.sql import Row
        jrdd = self._jschema_rdd.javaToPython()
        # TODO: This is inefficient, we should construct the Python Row object
        # in Java land in the javaToPython function. May require a custom
        # pickle serializer in Pyrolite.
        return RDD(jrdd, self._sc, BatchedSerializer(
            PickleSerializer())).map(lambda d: Row(d))

    # We override the default cache/persist/checkpoint behavior as we want to cache the
    # underlying SchemaRDD object in the JVM, not the PythonRDD checkpointed by the super class.
    def cache(self):
        self.is_cached = True
        self._jschema_rdd.cache()
        return self

    def persist(self, storageLevel):
        self.is_cached = True
        javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
        self._jschema_rdd.persist(javaStorageLevel)
        return self

    def unpersist(self):
        self.is_cached = False
        self._jschema_rdd.unpersist()
        return self

    def checkpoint(self):
        self.is_checkpointed = True
        self._jschema_rdd.checkpoint()

    def isCheckpointed(self):
        return self._jschema_rdd.isCheckpointed()

    def getCheckpointFile(self):
        checkpointFile = self._jschema_rdd.getCheckpointFile()
        if checkpointFile.isDefined():
            return checkpointFile.get()
        else:
            return None
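
    # Editor's note: an illustrative sketch, not part of the original source. Because the
    # overrides above cache/checkpoint the JVM-side SchemaRDD rather than a PythonRDD, both
    # relational queries and RDD-style operations reuse the same cached data:
    #
    #   >>> srdd = sqlCtx.inferSchema(rdd).cache()       # doctest: +SKIP
    #   >>> srdd.count()                                 # materializes the cache
    #   >>> srdd.map(lambda r: r.field1).collect()       # reuses the cached SchemaRDD
    #   >>> srdd.unpersist()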

    def coalesce(self, numPartitions, shuffle=False):
        rdd = self._jschema_rdd.coalesce(numPartitions, shuffle)
        return SchemaRDD(rdd, self.sql_ctx)

    def distinct(self):
        rdd = self._jschema_rdd.distinct()
        return SchemaRDD(rdd, self.sql_ctx)

    def intersection(self, other):
        if (other.__class__ is SchemaRDD):
            rdd = self._jschema_rdd.intersection(other._jschema_rdd)
            return SchemaRDD(rdd, self.sql_ctx)
        else:
            raise ValueError("Can only intersect with another SchemaRDD")

    def repartition(self, numPartitions):
        rdd = self._jschema_rdd.repartition(numPartitions)
        return SchemaRDD(rdd, self.sql_ctx)

    def subtract(self, other, numPartitions=None):
        if (other.__class__ is SchemaRDD):
            if numPartitions is None:
                rdd = self._jschema_rdd.subtract(other._jschema_rdd)
            else:
                rdd = self._jschema_rdd.subtract(other._jschema_rdd, numPartitions)
            return SchemaRDD(rdd, self.sql_ctx)
        else:
            raise ValueError("Can only subtract another SchemaRDD")
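
# Editor's note: an illustrative sketch, not part of the original module. The set-like
# wrappers above (distinct, intersection, subtract) stay on the JVM side and return new
# SchemaRDDs; this helper's name and two-argument shape are assumptions for the example.
def _example_set_operations(srdd, other):
    """Combine two SchemaRDDs with the schema-preserving set-like operators."""
    unique = srdd.distinct()              # drop duplicate rows
    common = srdd.intersection(other)     # rows present in both inputs
    remaining = srdd.subtract(other)      # rows in srdd but not in other
    return unique, common, remaining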


def _test():
    import doctest
    from array import array
    from pyspark.context import SparkContext
    globs = globals().copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    sc = SparkContext('local[4]', 'PythonTest', batchSize=2)
    globs['sc'] = sc
    globs['sqlCtx'] = SQLContext(sc)
    globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
                                   {"field1" : 2, "field2": "row2"},
                                   {"field1" : 3, "field2": "row3"}])
    jsonStrings = ['{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
                   '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]}, "field6":[{"field7": "row2"}]}',
                   '{"field1" : null, "field2": "row3", "field3":{"field4":33, "field5": []}}']
    globs['jsonStrings'] = jsonStrings
    globs['json'] = sc.parallelize(jsonStrings)
    globs['nestedRdd1'] = sc.parallelize([
        {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
        {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}])
    globs['nestedRdd2'] = sc.parallelize([
        {"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
        {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}])
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()