Package pyspark :: Module context
[frames | no frames]

Source Code for Module pyspark.context

  1  # 
  2  # Licensed to the Apache Software Foundation (ASF) under one or more 
  3  # contributor license agreements.  See the NOTICE file distributed with 
  4  # this work for additional information regarding copyright ownership. 
  5  # The ASF licenses this file to You under the Apache License, Version 2.0 
  6  # (the "License"); you may not use this file except in compliance with 
  7  # the License.  You may obtain a copy of the License at 
  8  # 
  9  #    http://www.apache.org/licenses/LICENSE-2.0 
 10  # 
 11  # Unless required by applicable law or agreed to in writing, software 
 12  # distributed under the License is distributed on an "AS IS" BASIS, 
 13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
 14  # See the License for the specific language governing permissions and 
 15  # limitations under the License. 
 16  # 
 17   
 18  import os 
 19  import shutil 
 20  import sys 
 21  from threading import Lock 
 22  from tempfile import NamedTemporaryFile 
 23  from collections import namedtuple 
 24   
 25  from pyspark import accumulators 
 26  from pyspark.accumulators import Accumulator 
 27  from pyspark.broadcast import Broadcast 
 28  from pyspark.conf import SparkConf 
 29  from pyspark.files import SparkFiles 
 30  from pyspark.java_gateway import launch_gateway 
 31  from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ 
 32          PairDeserializer 
 33  from pyspark.storagelevel import StorageLevel 
 34  from pyspark import rdd 
 35  from pyspark.rdd import RDD 
 36   
 37  from py4j.java_collections import ListConverter 
class SparkContext(object):
    """
    Main entry point for Spark functionality. A SparkContext represents the
    connection to a Spark cluster, and can be used to create L{RDD}s and
    broadcast variables on that cluster.
    """

    # Process-wide Py4J gateway / JVM handles, created lazily by _ensure_initialized.
    _gateway = None
    _jvm = None
    _writeToFile = None
    _next_accum_id = 0
    # The single SparkContext currently registered in this process (guarded by _lock).
    _active_spark_context = None
    _lock = Lock()
    _python_includes = None  # zip and egg files that need to be added to PYTHONPATH

    def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
                 environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None,
                 gateway=None):
        """
        Create a new SparkContext. At least the master and app name should be set,
        either through the named parameters here or through C{conf}.

        @param master: Cluster URL to connect to
               (e.g. mesos://host:port, spark://host:port, local[4]).
        @param appName: A name for your job, to display on the cluster web UI.
        @param sparkHome: Location where Spark is installed on cluster nodes.
        @param pyFiles: Collection of .zip or .py files to send to the cluster
               and add to PYTHONPATH.  These can be paths on the local file
               system or HDFS, HTTP, HTTPS, or FTP URLs.
        @param environment: A dictionary of environment variables to set on
               worker nodes.  The dictionary is copied, not modified.
        @param batchSize: The number of Python objects represented as a single
               Java object.  Set 1 to disable batching or -1 to use an
               unlimited batch size.
        @param serializer: The serializer for RDDs.
        @param conf: A L{SparkConf} object setting Spark properties.
        @param gateway: Use an existing gateway and JVM, otherwise a new JVM
               will be instantiated.
        @raise ValueError: if another SparkContext is already active.

        If construction fails after this context has been registered as the
        active one, it is stopped again so that a new SparkContext can be created.


        >>> from pyspark.context import SparkContext
        >>> sc = SparkContext('local', 'test')

        >>> sc2 = SparkContext('local', 'test2') # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...
        """
        # Remember where this context was created so a later conflicting
        # construction can report it (see _ensure_initialized).
        callsite = rdd._extract_concise_traceback()
        if callsite is None:
            Callsite = namedtuple("Callsite", "function file linenum")
            callsite = Callsite(function=None, file=None, linenum=None)
        self._callsite = callsite

        # Raises ValueError (without registering us) if another context is active.
        SparkContext._ensure_initialized(self, gateway=gateway)
        try:
            self._do_init(master, appName, sparkHome, pyFiles, environment,
                          batchSize, serializer, conf)
        except BaseException:
            # We are already registered as the active context; without this,
            # the class attribute would keep the half-built instance alive and
            # every later SparkContext() would fail with "Cannot run multiple
            # SparkContexts".
            self.stop()
            raise

    def _do_init(self, master, appName, sparkHome, pyFiles, environment,
                 batchSize, serializer, conf):
        """
        Configure this context once it has been registered as the active one.

        Kept separate from __init__ so that any failure here can be undone by
        stopping the context.
        """
        # Copy so that executorEnv entries read back below do not leak into
        # the caller's dictionary.
        self.environment = dict(environment or {})
        self._conf = conf or SparkConf(_jvm=self._jvm)
        self._batchSize = batchSize  # -1 represents an unlimited batch size
        self._unbatched_serializer = serializer
        if batchSize == 1:
            self.serializer = self._unbatched_serializer
        else:
            self.serializer = BatchedSerializer(self._unbatched_serializer,
                                                batchSize)

        # Set any parameters passed directly to us on the conf
        if master:
            self._conf.setMaster(master)
        if appName:
            self._conf.setAppName(appName)
        if sparkHome:
            self._conf.setSparkHome(sparkHome)
        if environment:
            for key, value in environment.iteritems():
                self._conf.setExecutorEnv(key, value)

        # Check that we have at least the required parameters
        if not self._conf.contains("spark.master"):
            raise Exception("A master URL must be set in your configuration")
        if not self._conf.contains("spark.app.name"):
            raise Exception("An application name must be set in your configuration")

        # Read back our properties from the conf in case we loaded some of them from
        # the classpath or an external config file
        self.master = self._conf.get("spark.master")
        self.appName = self._conf.get("spark.app.name")
        self.sparkHome = self._conf.get("spark.home", None)
        for (k, v) in self._conf.getAll():
            if k.startswith("spark.executorEnv."):
                varName = k[len("spark.executorEnv."):]
                self.environment[varName] = v

        # Create the Java SparkContext through Py4J
        self._jsc = self._initialize_context(self._conf._jconf)

        # Create a single Accumulator in Java that we'll send all our updates through;
        # they will be passed back to us through a TCP server
        self._accumulatorServer = accumulators._start_update_server()
        (host, port) = self._accumulatorServer.server_address
        self._javaAccumulator = self._jsc.accumulator(
            self._jvm.java.util.ArrayList(),
            self._jvm.PythonAccumulatorParam(host, port))

        self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')

        # Broadcast's __reduce__ method stores Broadcast instances here.
        # This allows other code to determine which Broadcast instances have
        # been pickled, so it can determine which Java broadcast objects to
        # send.
        self._pickled_broadcast_vars = set()

        SparkFiles._sc = self
        root_dir = SparkFiles.getRootDirectory()
        sys.path.append(root_dir)

        # Deploy any code dependencies specified in the constructor
        self._python_includes = list()
        for path in (pyFiles or []):
            self.addPyFile(path)

        # Deploy code dependencies set by spark-submit; these will already have been added
        # with SparkContext.addFile, so we just need to add them to the PYTHONPATH
        for path in self._conf.get("spark.submit.pyFiles", "").split(","):
            if path != "":
                (dirname, filename) = os.path.split(path)
                self._python_includes.append(filename)
                sys.path.append(path)
                if dirname not in sys.path:
                    sys.path.append(dirname)

        # Create a temporary directory inside spark.local.dir:
        local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
        self._temp_dir = \
            self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()

    def _initialize_context(self, jconf):
        """
        Initialize SparkContext in function to allow subclass specific initialization
        """
        return self._jvm.JavaSparkContext(jconf)

    @classmethod
    def _ensure_initialized(cls, instance=None, gateway=None):
        """
        Launch (or adopt) the Py4J gateway on first use and, when C{instance}
        is given, register it as the active SparkContext.

        @raise ValueError: if a different SparkContext is already active.
        """
        with SparkContext._lock:
            if not SparkContext._gateway:
                SparkContext._gateway = gateway or launch_gateway()
                SparkContext._jvm = SparkContext._gateway.jvm
                SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile

            if instance:
                active = SparkContext._active_spark_context
                if active and active is not instance:
                    callsite = active._callsite
                    # Raise error if there is already a running Spark context
                    raise ValueError(
                        "Cannot run multiple SparkContexts at once; existing SparkContext(app=%s, master=%s)"
                        " created by %s at %s:%s "
                        % (active.appName, active.master,
                           callsite.function, callsite.file, callsite.linenum))
                else:
                    SparkContext._active_spark_context = instance

    @classmethod
    def setSystemProperty(cls, key, value):
        """
        Set a Java system property, such as spark.executor.memory. This must
        be invoked before instantiating SparkContext.
        """
        SparkContext._ensure_initialized()
        SparkContext._jvm.java.lang.System.setProperty(key, value)

    @property
    def defaultParallelism(self):
        """
        Default level of parallelism to use when not given by user (e.g. for
        reduce tasks)
        """
        return self._jsc.sc().defaultParallelism()

    @property
    def defaultMinPartitions(self):
        """
        Default min number of partitions for Hadoop RDDs when not given by user
        """
        return self._jsc.sc().defaultMinPartitions()

    def __del__(self):
        # stop() tolerates partially constructed and already-stopped instances.
        self.stop()

    def stop(self):
        """
        Shut down the SparkContext.

        Safe to call more than once and on an instance whose construction
        failed part-way.
        """
        # getattr: __init__ may have failed before these attributes were set.
        if getattr(self, "_jsc", None):
            self._jsc.stop()
            self._jsc = None
        if getattr(self, "_accumulatorServer", None):
            self._accumulatorServer.shutdown()
            self._accumulatorServer = None
        with SparkContext._lock:
            # Only release the slot if it is ours: collecting an old, stopped
            # context (or one rejected because another was running) must not
            # unregister the context that is currently active.
            if SparkContext._active_spark_context is self:
                SparkContext._active_spark_context = None

    def parallelize(self, c, numSlices=None):
        """
        Distribute a local Python collection to form an RDD.

        >>> sc.parallelize(range(5), 5).glom().collect()
        [[0], [1], [2], [3], [4]]
        """
        numSlices = numSlices or self.defaultParallelism
        if not hasattr(c, "__len__"):
            c = list(c)  # Make it a list so we can compute its length
        # Make sure we distribute data evenly if it's smaller than self.batchSize
        perSlice = len(c) // numSlices
        if self._batchSize < 0:
            # -1 requests an unlimited batch size: one batch per slice.
            batchSize = perSlice
        else:
            batchSize = min(perSlice, self._batchSize)
        if batchSize > 1:
            serializer = BatchedSerializer(self._unbatched_serializer,
                                           batchSize)
        else:
            serializer = self._unbatched_serializer
        # Calling the Java parallelize() method with an ArrayList is too slow,
        # because it sends O(n) Py4J commands. As an alternative, serialized
        # objects are written to a file and loaded through textFile().
        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
        try:
            serializer.dump_stream(c, tempFile)
        finally:
            tempFile.close()
        readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
        jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
        return RDD(jrdd, self, serializer)

    def textFile(self, name, minPartitions=None):
        """
        Read a text file from HDFS, a local file system (available on all
        nodes), or any Hadoop-supported file system URI, and return it as an
        RDD of Strings.

        >>> path = os.path.join(tempdir, "sample-text.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("Hello world!")
        >>> textFile = sc.textFile(path)
        >>> textFile.collect()
        [u'Hello world!']
        """
        minPartitions = minPartitions or min(self.defaultParallelism, 2)
        return RDD(self._jsc.textFile(name, minPartitions), self,
                   UTF8Deserializer())

    def wholeTextFiles(self, path, minPartitions=None):
        """
        Read a directory of text files from HDFS, a local file system
        (available on all nodes), or any Hadoop-supported file system
        URI. Each file is read as a single record and returned in a
        key-value pair, where the key is the path of each file, the
        value is the content of each file.

        For example, if you have the following files::

          hdfs://a-hdfs-path/part-00000
          hdfs://a-hdfs-path/part-00001
          ...
          hdfs://a-hdfs-path/part-nnnnn

        Do C{rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")},
        then C{rdd} contains::

          (a-hdfs-path/part-00000, its content)
          (a-hdfs-path/part-00001, its content)
          ...
          (a-hdfs-path/part-nnnnn, its content)

        NOTE: Small files are preferred, as each file will be loaded
        fully in memory.

        >>> dirPath = os.path.join(tempdir, "files")
        >>> os.mkdir(dirPath)
        >>> with open(os.path.join(dirPath, "1.txt"), "w") as file1:
        ...    file1.write("1")
        >>> with open(os.path.join(dirPath, "2.txt"), "w") as file2:
        ...    file2.write("2")
        >>> textFiles = sc.wholeTextFiles(dirPath)
        >>> sorted(textFiles.collect())
        [(u'.../1.txt', u'1'), (u'.../2.txt', u'2')]
        """
        minPartitions = minPartitions or self.defaultMinPartitions
        return RDD(self._jsc.wholeTextFiles(path, minPartitions), self,
                   PairDeserializer(UTF8Deserializer(), UTF8Deserializer()))

    def _checkpointFile(self, name, input_deserializer):
        """Load a previously checkpointed RDD from C{name}."""
        jrdd = self._jsc.checkpointFile(name)
        return RDD(jrdd, self, input_deserializer)

    def union(self, rdds):
        """
        Build the union of a list of RDDs.

        This supports unions() of RDDs with different serialized formats,
        although this forces them to be reserialized using the default
        serializer:

        >>> path = os.path.join(tempdir, "union-text.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("Hello")
        >>> textFile = sc.textFile(path)
        >>> textFile.collect()
        [u'Hello']
        >>> parallelized = sc.parallelize(["World!"])
        >>> sorted(sc.union([textFile, parallelized]).collect())
        [u'Hello', 'World!']
        """
        first_jrdd_deserializer = rdds[0]._jrdd_deserializer
        if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
            rdds = [x._reserialize() for x in rdds]
        first = rdds[0]._jrdd
        rest = [x._jrdd for x in rdds[1:]]
        rest = ListConverter().convert(rest, self._gateway._gateway_client)
        return RDD(self._jsc.union(first, rest), self,
                   rdds[0]._jrdd_deserializer)

    def broadcast(self, value):
        """
        Broadcast a read-only variable to the cluster, returning a
        L{Broadcast<pyspark.broadcast.Broadcast>}
        object for reading it in distributed functions. The variable will be
        sent to each cluster only once.
        """
        pickleSer = PickleSerializer()
        pickled = pickleSer.dumps(value)
        jbroadcast = self._jsc.broadcast(bytearray(pickled))
        return Broadcast(jbroadcast.id(), value, jbroadcast,
                         self._pickled_broadcast_vars)

    def accumulator(self, value, accum_param=None):
        """
        Create an L{Accumulator} with the given initial value, using a given
        L{AccumulatorParam} helper object to define how to add values of the
        data type if provided. Default AccumulatorParams are used for integers
        and floating-point numbers if you do not provide one. For other types,
        a custom AccumulatorParam can be used.
        """
        if accum_param is None:
            if isinstance(value, int):
                accum_param = accumulators.INT_ACCUMULATOR_PARAM
            elif isinstance(value, float):
                accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM
            elif isinstance(value, complex):
                accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
            else:
                raise Exception("No default accumulator param for type %s" % type(value))
        SparkContext._next_accum_id += 1
        return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)

    def addFile(self, path):
        """
        Add a file to be downloaded with this Spark job on every node.
        The C{path} passed can be either a local file, a file in HDFS
        (or other Hadoop-supported filesystems), or an HTTP, HTTPS or
        FTP URI.

        To access the file in Spark jobs, use
        L{SparkFiles.get(path)<pyspark.files.SparkFiles.get>} to find its
        download location.

        >>> from pyspark import SparkFiles
        >>> path = os.path.join(tempdir, "test.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("100")
        >>> sc.addFile(path)
        >>> def func(iterator):
        ...    with open(SparkFiles.get("test.txt")) as testFile:
        ...        fileVal = int(testFile.readline())
        ...        return [x * 100 for x in iterator]
        >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
        [100, 200, 300, 400]
        """
        self._jsc.sc().addFile(path)

    def clearFiles(self):
        """
        Clear the job's list of files added by L{addFile} or L{addPyFile} so
        that they do not get downloaded to any new nodes.
        """
        # TODO: remove added .py or .zip files from the PYTHONPATH?
        self._jsc.sc().clearFiles()

    def addPyFile(self, path):
        """
        Add a .py or .zip dependency for all tasks to be executed on this
        SparkContext in the future.  The C{path} passed can be either a local
        file, a file in HDFS (or other Hadoop-supported filesystems), or an
        HTTP, HTTPS or FTP URI.
        """
        self.addFile(path)
        (dirname, filename) = os.path.split(path)  # dirname may be directory or HDFS/S3 prefix

        if filename.endswith(('.zip', '.ZIP', '.egg')):
            self._python_includes.append(filename)
            # for tests in local mode
            sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename))

    def setCheckpointDir(self, dirName):
        """
        Set the directory under which RDDs are going to be checkpointed. The
        directory must be a HDFS path if running on a cluster.
        """
        self._jsc.sc().setCheckpointDir(dirName)

    def _getJavaStorageLevel(self, storageLevel):
        """
        Returns a Java StorageLevel based on a pyspark.StorageLevel.
        """
        if not isinstance(storageLevel, StorageLevel):
            raise Exception("storageLevel must be of type pyspark.StorageLevel")

        newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
        return newStorageLevel(storageLevel.useDisk,
                               storageLevel.useMemory,
                               storageLevel.useOffHeap,
                               storageLevel.deserialized,
                               storageLevel.replication)

    def setJobGroup(self, groupId, description, interruptOnCancel=False):
        """
        Assigns a group ID to all the jobs started by this thread until the group ID is set to a
        different value or cleared.

        Often, a unit of execution in an application consists of multiple Spark actions or jobs.
        Application programmers can use this method to group all those jobs together and give a
        group description. Once set, the Spark web UI will associate such jobs with this group.

        The application can use L{SparkContext.cancelJobGroup} to cancel all
        running jobs in this group.

        >>> import thread, threading
        >>> from time import sleep
        >>> result = "Not Set"
        >>> lock = threading.Lock()
        >>> def map_func(x):
        ...     sleep(100)
        ...     raise Exception("Task should have been cancelled")
        >>> def start_job(x):
        ...     global result
        ...     try:
        ...         sc.setJobGroup("job_to_cancel", "some description")
        ...         result = sc.parallelize(range(x)).map(map_func).collect()
        ...     except Exception as e:
        ...         result = "Cancelled"
        ...     lock.release()
        >>> def stop_job():
        ...     sleep(5)
        ...     sc.cancelJobGroup("job_to_cancel")
        >>> supress = lock.acquire()
        >>> supress = thread.start_new_thread(start_job, (10,))
        >>> supress = thread.start_new_thread(stop_job, tuple())
        >>> supress = lock.acquire()
        >>> print result
        Cancelled

        If interruptOnCancel is set to true for the job group, then job cancellation will result
        in Thread.interrupt() being called on the job's executor threads. This is useful to help ensure
        that the tasks are actually stopped in a timely manner, but is off by default due to HDFS-1208,
        where HDFS may respond to Thread.interrupt() by marking nodes as dead.
        """
        self._jsc.setJobGroup(groupId, description, interruptOnCancel)

    def setLocalProperty(self, key, value):
        """
        Set a local property that affects jobs submitted from this thread, such as the
        Spark fair scheduler pool.
        """
        self._jsc.setLocalProperty(key, value)

    def getLocalProperty(self, key):
        """
        Get a local property set in this thread, or None if it is missing. See
        L{setLocalProperty}
        """
        return self._jsc.getLocalProperty(key)

    def sparkUser(self):
        """
        Get SPARK_USER for user who is running SparkContext.
        """
        return self._jsc.sc().sparkUser()

    def cancelJobGroup(self, groupId):
        """
        Cancel active jobs for the specified group. See L{SparkContext.setJobGroup}
        for more information.
        """
        self._jsc.sc().cancelJobGroup(groupId)

    def cancelAllJobs(self):
        """
        Cancel all jobs that have been scheduled or are running.
        """
        self._jsc.sc().cancelAllJobs()
def _test():
    """
    Run this module's doctests against a local SparkContext.

    Exits the process with status -1 if any doctest fails.
    """
    import atexit
    import doctest
    import tempfile
    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    globs['tempdir'] = tempfile.mkdtemp()
    atexit.register(lambda: shutil.rmtree(globs['tempdir']))
    try:
        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Always release the context, even if doctest itself blows up.
        globs['sc'].stop()
    if failure_count:
        # sys.exit rather than the site-provided exit(), which is absent under `python -S`.
        sys.exit(-1)


if __name__ == "__main__":
    _test()