
Source Code for Module pyspark.context

import os
import shutil
import sys
from threading import Lock
from tempfile import NamedTemporaryFile

from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import dump_pickle, write_with_length, batched
from pyspark.rdd import RDD

from py4j.java_collections import ListConverter


class SparkContext(object):
    """
    Main entry point for Spark functionality. A SparkContext represents the
    connection to a Spark cluster, and can be used to create L{RDD}s and
    broadcast variables on that cluster.
    """

    _gateway = None
    _jvm = None
    _writeIteratorToPickleFile = None
    _takePartition = None
    _next_accum_id = 0
    _active_spark_context = None
    _lock = Lock()

    def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
                 environment=None, batchSize=1024):
        """
        Create a new SparkContext.

        @param master: Cluster URL to connect to
               (e.g. mesos://host:port, spark://host:port, local[4]).
        @param jobName: A name for your job, to display on the cluster web UI
        @param sparkHome: Location where Spark is installed on cluster nodes.
        @param pyFiles: Collection of .zip or .py files to send to the cluster
               and add to PYTHONPATH.  These can be paths on the local file
               system or HDFS, HTTP, HTTPS, or FTP URLs.
        @param environment: A dictionary of environment variables to set on
               worker nodes.
        @param batchSize: The number of Python objects represented as a single
               Java object.  Set 1 to disable batching or -1 to use an
               unlimited batch size.
        """
        with SparkContext._lock:
            if SparkContext._active_spark_context:
                raise ValueError("Cannot run multiple SparkContexts at once")
            else:
                SparkContext._active_spark_context = self
                if not SparkContext._gateway:
                    SparkContext._gateway = launch_gateway()
                    SparkContext._jvm = SparkContext._gateway.jvm
                    SparkContext._writeIteratorToPickleFile = \
                        SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
                    SparkContext._takePartition = \
                        SparkContext._jvm.PythonRDD.takePartition
        self.master = master
        self.jobName = jobName
        self.sparkHome = sparkHome or None  # None becomes null in Py4J
        self.environment = environment or {}
        self.batchSize = batchSize  # -1 represents an unlimited batch size

        # Create the Java SparkContext through Py4J
        empty_string_array = self._gateway.new_array(self._jvm.String, 0)
        self._jsc = self._jvm.JavaSparkContext(master, jobName, sparkHome,
                                               empty_string_array)

        # Create a single Accumulator in Java that we'll send all our updates
        # through; they will be passed back to us through a TCP server
        self._accumulatorServer = accumulators._start_update_server()
        (host, port) = self._accumulatorServer.server_address
        self._javaAccumulator = self._jsc.accumulator(
            self._jvm.java.util.ArrayList(),
            self._jvm.PythonAccumulatorParam(host, port))

        self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
        # Broadcast's __reduce__ method stores Broadcast instances here.
        # This allows other code to determine which Broadcast instances have
        # been pickled, so it can determine which Java broadcast objects to
        # send.
        self._pickled_broadcast_vars = set()

        # Deploy any code dependencies specified in the constructor
        for path in (pyFiles or []):
            self.addPyFile(path)
        SparkFiles._sc = self
        sys.path.append(SparkFiles.getRootDirectory())

        # Create a temporary directory inside spark.local.dir:
        local_dir = self._jvm.spark.Utils.getLocalDir()
        self._temp_dir = \
            self._jvm.spark.Utils.createTempDir(local_dir).getAbsolutePath()
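
    # Usage sketch (editor's illustration, not part of the original module):
    # constructing a context against the local scheduler.  The master URL,
    # job name, and batchSize are arbitrary example values.  Only one
    # SparkContext may be active per process, so stop() must be called before
    # creating another.
    #
    #   >>> sc = SparkContext("local[4]", "ExampleJob", batchSize=2)
    #   >>> sc.master
    #   'local[4]'
    #   >>> sc.stop()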

    @property
    def defaultParallelism(self):
        """
        Default level of parallelism to use when not given by user (e.g. for
        reduce tasks)
        """
        return self._jsc.sc().defaultParallelism()

    def __del__(self):
        self.stop()

    def stop(self):
        """
        Shut down the SparkContext.
        """
        if self._jsc:
            self._jsc.stop()
            self._jsc = None
        if self._accumulatorServer:
            self._accumulatorServer.shutdown()
            self._accumulatorServer = None
        with SparkContext._lock:
            SparkContext._active_spark_context = None

    def parallelize(self, c, numSlices=None):
        """
        Distribute a local Python collection to form an RDD.
        """
        numSlices = numSlices or self.defaultParallelism
        # Calling the Java parallelize() method with an ArrayList is too slow,
        # because it sends O(n) Py4J commands.  As an alternative, serialized
        # objects are written to a file and loaded through textFile().
        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
        if self.batchSize != 1:
            c = batched(c, self.batchSize)
        for x in c:
            write_with_length(dump_pickle(x), tempFile)
        tempFile.close()
        readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
        jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
        return RDD(jrdd, self)
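
    # Usage sketch (editor's illustration): parallelize() pickles the local
    # collection to a temp file and hands it to the JVM, so the resulting RDD
    # holds the same elements.  Assumes an active SparkContext `sc`, as in
    # _test() below.
    #
    #   >>> rdd = sc.parallelize([1, 2, 3, 4], numSlices=2)
    #   >>> rdd.collect()
    #   [1, 2, 3, 4]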

    def textFile(self, name, minSplits=None):
        """
        Read a text file from HDFS, a local file system (available on all
        nodes), or any Hadoop-supported file system URI, and return it as an
        RDD of Strings.
        """
        minSplits = minSplits or min(self.defaultParallelism, 2)
        jrdd = self._jsc.textFile(name, minSplits)
        return RDD(jrdd, self)
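
    # Usage sketch (editor's illustration): reading a text file into an RDD of
    # lines.  The path below is hypothetical; any local, HDFS, or other
    # Hadoop-supported URI should work.  Assumes an active SparkContext `sc`.
    #
    #   >>> lines = sc.textFile("/path/to/data.txt", minSplits=4)
    #   >>> lines.count()        # number of lines in the file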

    def _checkpointFile(self, name):
        jrdd = self._jsc.checkpointFile(name)
        return RDD(jrdd, self)

    def union(self, rdds):
        """
        Build the union of a list of RDDs.
        """
        first = rdds[0]._jrdd
        rest = [x._jrdd for x in rdds[1:]]
        rest = ListConverter().convert(rest, self._gateway._gateway_client)
        return RDD(self._jsc.union(first, rest), self)
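
    # Usage sketch (editor's illustration): union() concatenates the
    # partitions of several RDDs into a single RDD.  Assumes an active
    # SparkContext `sc`.
    #
    #   >>> rdd1 = sc.parallelize([1, 2])
    #   >>> rdd2 = sc.parallelize([3, 4])
    #   >>> sc.union([rdd1, rdd2]).collect()
    #   [1, 2, 3, 4]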

    def broadcast(self, value):
        """
        Broadcast a read-only variable to the cluster, returning a C{Broadcast}
        object for reading it in distributed functions. The variable will be
        sent to each node only once.
        """
        jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
        return Broadcast(jbroadcast.id(), value, jbroadcast,
                         self._pickled_broadcast_vars)
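
    # Usage sketch (editor's illustration): a broadcast variable is pickled
    # once and read back through its .value attribute inside tasks.  Assumes
    # an active SparkContext `sc`.
    #
    #   >>> lookup = sc.broadcast({"a": 1, "b": 2})
    #   >>> sc.parallelize(["a", "b", "a"]).map(lambda k: lookup.value[k]).collect()
    #   [1, 2, 1]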

    def accumulator(self, value, accum_param=None):
        """
        Create an L{Accumulator} with the given initial value, using a given
        L{AccumulatorParam} helper object to define how to add values of the
        data type if provided. Default AccumulatorParams are used for integers
        and floating-point numbers if you do not provide one. For other types,
        a custom AccumulatorParam can be used.
        """
        if accum_param is None:
            if isinstance(value, int):
                accum_param = accumulators.INT_ACCUMULATOR_PARAM
            elif isinstance(value, float):
                accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM
            elif isinstance(value, complex):
                accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
            else:
                raise Exception("No default accumulator param for type %s" % type(value))
        SparkContext._next_accum_id += 1
        return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)
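
    # Usage sketch (editor's illustration): tasks add to an accumulator with
    # +=, while the driver reads the total through .value.  Assumes an active
    # SparkContext `sc` and that RDD.foreach() behaves as in the accompanying
    # rdd module.
    #
    #   >>> counter = sc.accumulator(0)
    #   >>> def add_one(x):
    #   ...     global counter
    #   ...     counter += 1
    #   >>> sc.parallelize([1, 2, 3, 4]).foreach(add_one)
    #   >>> counter.value
    #   4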

    def addFile(self, path):
        """
        Add a file to be downloaded with this Spark job on every node.
        The C{path} passed can be either a local file, a file in HDFS
        (or other Hadoop-supported filesystems), or an HTTP, HTTPS or
        FTP URI.

        To access the file in Spark jobs, use
        L{SparkFiles.get(path)<pyspark.files.SparkFiles.get>} to find its
        download location.

        >>> from pyspark import SparkFiles
        >>> path = os.path.join(tempdir, "test.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("100")
        >>> sc.addFile(path)
        >>> def func(iterator):
        ...    with open(SparkFiles.get("test.txt")) as testFile:
        ...        fileVal = int(testFile.readline())
        ...        return [x * 100 for x in iterator]
        >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
        [100, 200, 300, 400]
        """
        self._jsc.sc().addFile(path)

    def clearFiles(self):
        """
        Clear the job's list of files added by L{addFile} or L{addPyFile} so
        that they do not get downloaded to any new nodes.
        """
        # TODO: remove added .py or .zip files from the PYTHONPATH?
        self._jsc.sc().clearFiles()

    def addPyFile(self, path):
        """
        Add a .py or .zip dependency for all tasks to be executed on this
        SparkContext in the future.  The C{path} passed can be either a local
        file, a file in HDFS (or other Hadoop-supported filesystems), or an
        HTTP, HTTPS or FTP URI.
        """
        self.addFile(path)
        filename = path.split("/")[-1]
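
    # Usage sketch (editor's illustration): shipping a code dependency so it
    # can be imported inside tasks.  The .zip path, module name, and
    # transform() function below are hypothetical.
    #
    #   >>> sc.addPyFile("/path/to/mylib.zip")
    #   >>> def use_lib(x):
    #   ...     import mylib            # resolved from the shipped .zip
    #   ...     return mylib.transform(x)
    #   >>> rdd = sc.parallelize([1, 2, 3]).map(use_lib)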

    def setCheckpointDir(self, dirName, useExisting=False):
        """
        Set the directory under which RDDs are going to be checkpointed. The
        directory must be an HDFS path if running on a cluster.

        If the directory does not exist, it will be created. If the directory
        exists and C{useExisting} is set to true, then the existing directory
        will be used. Otherwise an exception will be thrown to prevent
        accidental overriding of checkpoint files in the existing directory.
        """
        self._jsc.sc().setCheckpointDir(dirName, useExisting)
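
    # Usage sketch (editor's illustration): pointing checkpointing at a
    # directory and then checkpointing an RDD.  The directory is hypothetical,
    # and RDD.checkpoint() is assumed to exist in the accompanying rdd module.
    #
    #   >>> sc.setCheckpointDir("/tmp/spark-checkpoints", useExisting=True)
    #   >>> rdd = sc.parallelize(range(100))
    #   >>> rdd.checkpoint()
    #   >>> rdd.count()                 # forces evaluation and checkpointing
    #   100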


def _test():
    import atexit
    import doctest
    import tempfile
    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    globs['tempdir'] = tempfile.mkdtemp()
    atexit.register(lambda: shutil.rmtree(globs['tempdir']))
    (failure_count, test_count) = doctest.testmod(globs=globs)
    globs['sc'].stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()