
Source Code for Module pyspark.sql

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.rdd import RDD

from py4j.protocol import Py4JError

__all__ = ["SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]


class SQLContext:
    """Main entry point for SparkSQL functionality.

    A SQLContext can be used to create L{SchemaRDD}s, register L{SchemaRDD}s as
    tables, execute SQL over tables, cache tables, and read parquet files.
    """
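
    # Example (illustrative sketch): a typical end-to-end flow, assuming `sc`
    # is an existing SparkContext and the RDD contains dictionaries.
    #
    #   >>> sqlCtx = SQLContext(sc)
    #   >>> rdd = sc.parallelize([{"field1": 1, "field2": "row1"}])
    #   >>> srdd = sqlCtx.inferSchema(rdd)
    #   >>> sqlCtx.registerRDDAsTable(srdd, "table1")
    #   >>> sqlCtx.sql("SELECT field1 FROM table1").collect()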

    def __init__(self, sparkContext, sqlContext=None):
        """Create a new SQLContext.

        @param sparkContext: The SparkContext to wrap.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...

        >>> bad_rdd = sc.parallelize([1,2,3])
        >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...

        >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
        ... "boolean" : True}])
        >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
        ... x.boolean))
        >>> srdd.collect()[0]
        (1, u'string', 1.0, 1, True)
        """
        self._sc = sparkContext
        self._jsc = self._sc._jsc
        self._jvm = self._sc._jvm
        self._pythonToJavaMap = self._jvm.PythonRDD.pythonToJavaMap

        if sqlContext:
            self._scala_SQLContext = sqlContext

    @property
    def _ssql_ctx(self):
        """Accessor for the JVM SparkSQL context.

        Subclasses can override this property to provide their own
        JVM Contexts.
        """
        if not hasattr(self, '_scala_SQLContext'):
            self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
        return self._scala_SQLContext

    def inferSchema(self, rdd):
        """Infer and apply a schema to an RDD of L{dict}s.

        We peek at the first row of the RDD to determine the field names
        and types, and then use that to extract all the dictionaries.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
        ... {"field1" : 3, "field2": "row3"}]
        True
        """
        if (rdd.__class__ is SchemaRDD):
            raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
        elif not isinstance(rdd.first(), dict):
            raise ValueError("Only RDDs with dictionaries can be converted to %s: %s" %
                             (SchemaRDD.__name__, rdd.first()))

        jrdd = self._pythonToJavaMap(rdd._jrdd)
        srdd = self._ssql_ctx.inferSchema(jrdd.rdd())
        return SchemaRDD(srdd, self)

    def registerRDDAsTable(self, rdd, tableName):
        """Registers the given RDD as a temporary table in the catalog.

        Temporary tables exist only during the lifetime of this instance of
        SQLContext.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        """
        if (rdd.__class__ is SchemaRDD):
            jschema_rdd = rdd._jschema_rdd
            self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName)
        else:
            raise ValueError("Can only register SchemaRDD as table")

    def parquetFile(self, path):
        """Loads a Parquet file, returning the result as a L{SchemaRDD}.

        >>> import tempfile, shutil
        >>> parquetFile = tempfile.mkdtemp()
        >>> shutil.rmtree(parquetFile)
        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.saveAsParquetFile(parquetFile)
        >>> srdd2 = sqlCtx.parquetFile(parquetFile)
        >>> srdd.collect() == srdd2.collect()
        True
        """
        jschema_rdd = self._ssql_ctx.parquetFile(path)
        return SchemaRDD(jschema_rdd, self)

    def sql(self, sqlQuery):
        """Return a L{SchemaRDD} representing the result of the given query.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
        >>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"},
        ... {"f1" : 3, "f2": "row3"}]
        True
        """
        return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)

    def table(self, tableName):
        """Returns the specified table as a L{SchemaRDD}.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.table("table1")
        >>> srdd.collect() == srdd2.collect()
        True
        """
        return SchemaRDD(self._ssql_ctx.table(tableName), self)

    def cacheTable(self, tableName):
        """Caches the specified table in-memory."""
        self._ssql_ctx.cacheTable(tableName)

    def uncacheTable(self, tableName):
        """Removes the specified table from the in-memory cache."""
        self._ssql_ctx.uncacheTable(tableName)
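
    # Example (illustrative sketch): caching a registered table keeps it in
    # memory across repeated queries; "table1" is assumed to have been
    # registered with registerRDDAsTable.
    #
    #   >>> sqlCtx.cacheTable("table1")
    #   >>> sqlCtx.sql("SELECT * FROM table1").count()  # first run fills the cache
    #   >>> sqlCtx.sql("SELECT * FROM table1").count()  # later runs read the cached table
    #   >>> sqlCtx.uncacheTable("table1")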


class HiveContext(SQLContext):
    """A variant of Spark SQL that integrates with data stored in Hive.

    Configuration for Hive is read from hive-site.xml on the classpath.
    It supports running both SQL and HiveQL commands.
    """

    @property
    def _ssql_ctx(self):
        try:
            if not hasattr(self, '_scala_HiveContext'):
                self._scala_HiveContext = self._get_hive_ctx()
            return self._scala_HiveContext
        except Py4JError as e:
            raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run "
                            "sbt/sbt assembly", e)

    def _get_hive_ctx(self):
        return self._jvm.HiveContext(self._jsc.sc())

    def hiveql(self, hqlQuery):
        """
        Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.
        """
        return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self)

    def hql(self, hqlQuery):
        """
        Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.
        """
        return self.hiveql(hqlQuery)
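
    # Example (illustrative sketch): running HiveQL against an existing Hive
    # table; a table named "src" with columns key/value is assumed, as in the
    # LocalHiveContext doctest below.
    #
    #   >>> hiveCtx = HiveContext(sc)
    #   >>> srdd = hiveCtx.hql("SELECT key, value FROM src WHERE key > 100")
    #   >>> srdd.registerAsTable("src_filtered")
    #   >>> hiveCtx.hql("SELECT COUNT(*) FROM src_filtered").collect()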


class LocalHiveContext(HiveContext):
    """Starts up an instance of Hive where metadata is stored locally.

    An in-process metastore is created with its data stored in ./metadata.
    Warehouse data is stored in ./warehouse.

    >>> import os
    >>> hiveCtx = LocalHiveContext(sc)
    >>> try:
    ...     suppress = hiveCtx.hql("DROP TABLE src")
    ... except Exception:
    ...     pass
    >>> kv1 = os.path.join(os.environ["SPARK_HOME"], 'examples/src/main/resources/kv1.txt')
    >>> suppress = hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
    >>> suppress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" % kv1)
    >>> results = hiveCtx.hql("FROM src SELECT value").map(lambda r: int(r.value.split('_')[1]))
    >>> num = results.count()
    >>> reduce_sum = results.reduce(lambda x, y: x + y)
    >>> num
    500
    >>> reduce_sum
    130091
    """

    def _get_hive_ctx(self):
        return self._jvm.LocalHiveContext(self._jsc.sc())


class TestHiveContext(HiveContext):

    def _get_hive_ctx(self):
        return self._jvm.TestHiveContext(self._jsc.sc())


# TODO: Investigate if it is more efficient to use a namedtuple. One problem is that named tuples
# are custom classes that must be generated per Schema.
class Row(dict):
    """A row in L{SchemaRDD}.

    An extended L{dict} that takes a L{dict} in its constructor, and
    exposes those items as fields.

    >>> r = Row({"hello" : "world", "foo" : "bar"})
    >>> r.hello
    'world'
    >>> r.foo
    'bar'
    """

    def __init__(self, d):
        d.update(self.__dict__)
        self.__dict__ = d
        dict.__init__(self, d)


class SchemaRDD(RDD):
    """An RDD of L{Row} objects that has an associated schema.

    The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can
    utilize the relational query API exposed by SparkSQL.

    For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the
    L{SchemaRDD} is not operated on directly, as its underlying
    implementation is an RDD composed of Java objects. Instead it is
    converted to a PythonRDD in the JVM, on which Python operations can
    be done.
    """
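
    # Example (illustrative sketch): relational and plain RDD operations can be
    # mixed; the first RDD-style call (map below) triggers the conversion to a
    # PythonRDD described above.
    #
    #   >>> srdd = sqlCtx.inferSchema(rdd)
    #   >>> srdd.registerAsTable("people")
    #   >>> sqlCtx.sql("SELECT field1 FROM people").map(lambda row: row.field1 * 2).collect()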

    def __init__(self, jschema_rdd, sql_ctx):
        self.sql_ctx = sql_ctx
        self._sc = sql_ctx._sc
        self._jschema_rdd = jschema_rdd

        self.is_cached = False
        self.is_checkpointed = False
        self.ctx = self.sql_ctx._sc
        self._jrdd_deserializer = self.ctx.serializer

    @property
    def _jrdd(self):
        """Lazy evaluation of PythonRDD object.

        Only done when a user calls methods defined by the
        L{pyspark.rdd.RDD} super class (map, filter, etc.).
        """
        if not hasattr(self, '_lazy_jrdd'):
            self._lazy_jrdd = self._toPython()._jrdd
        return self._lazy_jrdd

    @property
    def _id(self):
        return self._jrdd.id()

    def saveAsParquetFile(self, path):
        """Save the contents as a Parquet file, preserving the schema.

        Files that are written out using this method can be read back in as
        a SchemaRDD using the L{SQLContext.parquetFile} method.

        >>> import tempfile, shutil
        >>> parquetFile = tempfile.mkdtemp()
        >>> shutil.rmtree(parquetFile)
        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.saveAsParquetFile(parquetFile)
        >>> srdd2 = sqlCtx.parquetFile(parquetFile)
        >>> srdd2.collect() == srdd.collect()
        True
        """
        self._jschema_rdd.saveAsParquetFile(path)

    def registerAsTable(self, name):
        """Registers this RDD as a temporary table using the given name.

        The lifetime of this temporary table is tied to the L{SQLContext}
        that was used to create this SchemaRDD.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.registerAsTable("test")
        >>> srdd2 = sqlCtx.sql("select * from test")
        >>> srdd.collect() == srdd2.collect()
        True
        """
        self._jschema_rdd.registerAsTable(name)

    def insertInto(self, tableName, overwrite=False):
        """Inserts the contents of this SchemaRDD into the specified table,
        optionally overwriting any existing data.
        """
        self._jschema_rdd.insertInto(tableName, overwrite)

    def saveAsTable(self, tableName):
        """Creates a new table with the contents of this SchemaRDD."""
        self._jschema_rdd.saveAsTable(tableName)
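
    # Example (illustrative sketch): writing results to a table and appending
    # to it later; assumes a HiveContext-backed catalog and that `more_rows`
    # is another SchemaRDD with the same schema (hypothetical name).
    #
    #   >>> srdd.saveAsTable("people")                      # create the table from this RDD
    #   >>> more_rows.insertInto("people")                  # append another SchemaRDD
    #   >>> more_rows.insertInto("people", overwrite=True)  # or replace the existing rows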

    def count(self):
        """Return the number of elements in this RDD.

        Unlike the base RDD implementation of count, this implementation
        leverages the query optimizer to compute the count on the SchemaRDD,
        which supports features such as filter pushdown.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.count()
        3L
        >>> srdd.count() == srdd.map(lambda x: x).count()
        True
        """
        return self._jschema_rdd.count()

    def _toPython(self):
        # We have to import the Row class explicitly, so that the reference the
        # pickler holds is pyspark.sql.Row instead of __main__.Row
        from pyspark.sql import Row
        jrdd = self._jschema_rdd.javaToPython()
        # TODO: This is inefficient, we should construct the Python Row object
        # in Java land in the javaToPython function. May require a custom
        # pickle serializer in Pyrolite
        return RDD(jrdd, self._sc, self._sc.serializer).map(lambda d: Row(d))

    # We override the default cache/persist/checkpoint behavior as we want to cache the underlying
    # SchemaRDD object in the JVM, not the PythonRDD checkpointed by the super class
    def cache(self):
        self.is_cached = True
        self._jschema_rdd.cache()
        return self

    def persist(self, storageLevel):
        self.is_cached = True
        javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
        self._jschema_rdd.persist(javaStorageLevel)
        return self

    def unpersist(self):
        self.is_cached = False
        self._jschema_rdd.unpersist()
        return self
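
    # Example (illustrative sketch): because caching goes through the JVM
    # SchemaRDD, cache()/persist() keep the relational representation rather
    # than a Python-side copy; StorageLevel comes from pyspark.storagelevel.
    #
    #   >>> from pyspark.storagelevel import StorageLevel
    #   >>> srdd = sqlCtx.inferSchema(rdd)
    #   >>> srdd.persist(StorageLevel.MEMORY_AND_DISK)
    #   >>> srdd.count()    # materializes and caches the data
    #   >>> srdd.unpersist()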

    def checkpoint(self):
        self.is_checkpointed = True
        self._jschema_rdd.checkpoint()

    def isCheckpointed(self):
        return self._jschema_rdd.isCheckpointed()

    def getCheckpointFile(self):
        checkpointFile = self._jschema_rdd.getCheckpointFile()
        if checkpointFile.isDefined():
            return checkpointFile.get()
        else:
            return None
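
    # Example (illustrative sketch): checkpointing also goes through the JVM
    # SchemaRDD; a checkpoint directory must be set on the SparkContext first,
    # and the data is written once an action runs (the path below is hypothetical).
    #
    #   >>> sc.setCheckpointDir("/tmp/schema_rdd_checkpoints")
    #   >>> srdd.checkpoint()
    #   >>> srdd.count()                 # forces evaluation, writing the checkpoint
    #   >>> srdd.isCheckpointed()
    #   True
    #   >>> srdd.getCheckpointFile()     # the checkpoint path, or None if not checkpointed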

    def coalesce(self, numPartitions, shuffle=False):
        rdd = self._jschema_rdd.coalesce(numPartitions, shuffle)
        return SchemaRDD(rdd, self.sql_ctx)

    def distinct(self):
        rdd = self._jschema_rdd.distinct()
        return SchemaRDD(rdd, self.sql_ctx)

    def intersection(self, other):
        if (other.__class__ is SchemaRDD):
            rdd = self._jschema_rdd.intersection(other._jschema_rdd)
            return SchemaRDD(rdd, self.sql_ctx)
        else:
            raise ValueError("Can only intersect with another SchemaRDD")

    def repartition(self, numPartitions):
        rdd = self._jschema_rdd.repartition(numPartitions)
        return SchemaRDD(rdd, self.sql_ctx)

    def subtract(self, other, numPartitions=None):
        if (other.__class__ is SchemaRDD):
            if numPartitions is None:
                rdd = self._jschema_rdd.subtract(other._jschema_rdd)
            else:
                rdd = self._jschema_rdd.subtract(other._jschema_rdd, numPartitions)
            return SchemaRDD(rdd, self.sql_ctx)
        else:
            raise ValueError("Can only subtract another SchemaRDD")


def _test():
    import doctest
    from pyspark.context import SparkContext
    globs = globals().copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    sc = SparkContext('local[4]', 'PythonTest', batchSize=2)
    globs['sc'] = sc
    globs['sqlCtx'] = SQLContext(sc)
    globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
                                   {"field1" : 2, "field2": "row2"},
                                   {"field1" : 3, "field2": "row3"}])
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()