import os
import shutil
import sys
from threading import Lock
from tempfile import NamedTemporaryFile

from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import dump_pickle, write_with_length, batched
from pyspark.rdd import RDD

from py4j.java_collections import ListConverter


class SparkContext(object):
    """
    Main entry point for Spark functionality. A SparkContext represents the
    connection to a Spark cluster, and can be used to create L{RDD}s and
    broadcast variables on that cluster.
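
    Example (a minimal sketch, not run as a doctest because only one
    SparkContext may be active per process; the master URL and job name
    below are placeholders):

        sc = SparkContext("local", "My App")
        rdd = sc.parallelize([1, 2, 3])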
23 """
24
25 _gateway = None
26 _jvm = None
27 _writeIteratorToPickleFile = None
28 _takePartition = None
29 _next_accum_id = 0
30 _active_spark_context = None
31 _lock = Lock()
32
33 - def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
34 environment=None, batchSize=1024):
35 """
36 Create a new SparkContext.
37
38 @param master: Cluster URL to connect to
39 (e.g. mesos://host:port, spark://host:port, local[4]).
40 @param jobName: A name for your job, to display on the cluster web UI
41 @param sparkHome: Location where Spark is installed on cluster nodes.
42 @param pyFiles: Collection of .zip or .py files to send to the cluster
43 and add to PYTHONPATH. These can be paths on the local file
44 system or HDFS, HTTP, HTTPS, or FTP URLs.
45 @param environment: A dictionary of environment variables to set on
46 worker nodes.
47 @param batchSize: The number of Python objects represented as a single
48 Java object. Set 1 to disable batching or -1 to use an
49 unlimited batch size.
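
        Example (a minimal sketch, not run as a doctest; the master URL,
        job name, and file name below are placeholders):

            sc = SparkContext("spark://host:7077", "MyJob",
                              pyFiles=["helpers.py"], batchSize=10)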
50 """
51 with SparkContext._lock:
52 if SparkContext._active_spark_context:
53 raise ValueError("Cannot run multiple SparkContexts at once")
54 else:
55 SparkContext._active_spark_context = self
56 if not SparkContext._gateway:
57 SparkContext._gateway = launch_gateway()
58 SparkContext._jvm = SparkContext._gateway.jvm
59 SparkContext._writeIteratorToPickleFile = \
60 SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
61 SparkContext._takePartition = \
62 SparkContext._jvm.PythonRDD.takePartition
63 self.master = master
64 self.jobName = jobName
65 self.sparkHome = sparkHome or None
66 self.environment = environment or {}
67 self.batchSize = batchSize
68
69
70 empty_string_array = self._gateway.new_array(self._jvm.String, 0)
71 self._jsc = self._jvm.JavaSparkContext(master, jobName, sparkHome,
72 empty_string_array)
73
74
75
76 self._accumulatorServer = accumulators._start_update_server()
77 (host, port) = self._accumulatorServer.server_address
78 self._javaAccumulator = self._jsc.accumulator(
79 self._jvm.java.util.ArrayList(),
80 self._jvm.PythonAccumulatorParam(host, port))
81
82 self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
83
84
85
86
87 self._pickled_broadcast_vars = set()
88
89
90 for path in (pyFiles or []):
91 self.addPyFile(path)
92 SparkFiles._sc = self
93 sys.path.append(SparkFiles.getRootDirectory())
94
95
96 local_dir = self._jvm.spark.Utils.getLocalDir()
97 self._temp_dir = \
98 self._jvm.spark.Utils.createTempDir(local_dir).getAbsolutePath()
99
    @property
    def defaultParallelism(self):
        """
        Default level of parallelism to use when not given by user (e.g. for
        reduce tasks).
        """
        return self._jsc.sc().defaultParallelism()

    def __del__(self):
        self.stop()

    def stop(self):
        """
        Shut down the SparkContext.
        """
        if self._jsc:
            self._jsc.stop()
            self._jsc = None
        if self._accumulatorServer:
            self._accumulatorServer.shutdown()
            self._accumulatorServer = None
        with SparkContext._lock:
            SparkContext._active_spark_context = None
    def parallelize(self, c, numSlices=None):
        """
        Distribute a local Python collection to form an RDD.
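
        A short example (the element values and slice count are arbitrary):

        >>> sc.parallelize([1, 2, 3, 4, 5], 2).collect()
        [1, 2, 3, 4, 5]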
127 """
128 numSlices = numSlices or self.defaultParallelism
129
130
131
132 tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
133 if self.batchSize != 1:
134 c = batched(c, self.batchSize)
135 for x in c:
136 write_with_length(dump_pickle(x), tempFile)
137 tempFile.close()
138 readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
139 jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
140 return RDD(jrdd, self)
141
142 - def textFile(self, name, minSplits=None):
143 """
144 Read a text file from HDFS, a local file system (available on all
145 nodes), or any Hadoop-supported file system URI, and return it as an
146 RDD of Strings.
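
        A short example using a temporary file (the file name is arbitrary;
        tempdir is supplied by this module's doctest harness):

        >>> path = os.path.join(tempdir, "sample.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("Hello world!")
        >>> sc.textFile(path).collect()
        [u'Hello world!']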
147 """
148 minSplits = minSplits or min(self.defaultParallelism, 2)
149 jrdd = self._jsc.textFile(name, minSplits)
150 return RDD(jrdd, self)
151
152 - def _checkpointFile(self, name):
153 jrdd = self._jsc.checkpointFile(name)
154 return RDD(jrdd, self)
155
156 - def union(self, rdds):
157 """
158 Build the union of a list of RDDs.
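
        A short example (the element values are arbitrary):

        >>> rdd = sc.parallelize([1, 2])
        >>> sc.union([rdd, rdd]).collect()
        [1, 2, 1, 2]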
159 """
160 first = rdds[0]._jrdd
161 rest = [x._jrdd for x in rdds[1:]]
162 rest = ListConverter().convert(rest, self.gateway._gateway_client)
163 return RDD(self._jsc.union(first, rest), self)
164
165 - def broadcast(self, value):
166 """
167 Broadcast a read-only variable to the cluster, returning a C{Broadcast}
168 object for reading it in distributed functions. The variable will be
169 sent to each cluster only once.
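
        A short example (the list contents are arbitrary):

        >>> b = sc.broadcast([1, 2, 3, 4, 5])
        >>> b.value
        [1, 2, 3, 4, 5]
        >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
        [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]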
170 """
171 jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
172 return Broadcast(jbroadcast.id(), value, jbroadcast,
173 self._pickled_broadcast_vars)
174
175 - def accumulator(self, value, accum_param=None):
176 """
177 Create an L{Accumulator} with the given initial value, using a given
178 L{AccumulatorParam} helper object to define how to add values of the
179 data type if provided. Default AccumulatorParams are used for integers
180 and floating-point numbers if you do not provide one. For other types,
181 a custom AccumulatorParam can be used.
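
        A short example using the default integer AccumulatorParam:

        >>> a = sc.accumulator(1)
        >>> a += 5
        >>> a.value
        6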
182 """
183 if accum_param == None:
184 if isinstance(value, int):
185 accum_param = accumulators.INT_ACCUMULATOR_PARAM
186 elif isinstance(value, float):
187 accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM
188 elif isinstance(value, complex):
189 accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
190 else:
191 raise Exception("No default accumulator param for type %s" % type(value))
192 SparkContext._next_accum_id += 1
193 return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)

    def addFile(self, path):
        """
        Add a file to be downloaded with this Spark job on every node.
        The C{path} passed can be either a local file, a file in HDFS
        (or other Hadoop-supported filesystems), or an HTTP, HTTPS or
        FTP URI.

        To access the file in Spark jobs, use
        L{SparkFiles.get(path)<pyspark.files.SparkFiles.get>} to find its
        download location.

        >>> from pyspark import SparkFiles
        >>> path = os.path.join(tempdir, "test.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("100")
        >>> sc.addFile(path)
        >>> def func(iterator):
        ...    with open(SparkFiles.get("test.txt")) as testFile:
        ...        fileVal = int(testFile.readline())
        ...        return [x * fileVal for x in iterator]
        >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
        [100, 200, 300, 400]
        """
        self._jsc.sc().addFile(path)

    def clearFiles(self):
        """
        Clear the job's list of files added by L{addFile} or L{addPyFile} so
        that they do not get downloaded to any new nodes.
        """
        self._jsc.sc().clearFiles()

    def addPyFile(self, path):
        """
        Add a .py or .zip dependency for all tasks to be executed on this
        SparkContext in the future. The C{path} passed can be either a local
        file, a file in HDFS (or other Hadoop-supported filesystems), or an
        HTTP, HTTPS or FTP URI.
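
        Example (illustrative, not run as a doctest; the file name is a
        placeholder):

            sc.addPyFile("/path/to/helpers.py")
            # Functions shipped to workers can now "import helpers".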
234 """
235 self.addFile(path)
236 filename = path.split("/")[-1]
237
238 - def setCheckpointDir(self, dirName, useExisting=False):
239 """
240 Set the directory under which RDDs are going to be checkpointed. The
241 directory must be a HDFS path if running on a cluster.
242
243 If the directory does not exist, it will be created. If the directory
244 exists and C{useExisting} is set to true, then the exisiting directory
245 will be used. Otherwise an exception will be thrown to prevent
246 accidental overriding of checkpoint files in the existing directory.
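
        Example (illustrative, not run as a doctest; the directory is a
        placeholder):

            sc.setCheckpointDir("/tmp/spark-checkpoints", useExisting=True)
            rdd = sc.parallelize(range(100))
            rdd.checkpoint()  # saved under the directory above once computed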
247 """
248 self._jsc.sc().setCheckpointDir(dirName, useExisting)
249
252 import atexit
253 import doctest
254 import tempfile
255 globs = globals().copy()
256 globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
257 globs['tempdir'] = tempfile.mkdtemp()
258 atexit.register(lambda: shutil.rmtree(globs['tempdir']))
259 (failure_count, test_count) = doctest.testmod(globs=globs)
260 globs['sc'].stop()
261 if failure_count:
262 exit(-1)
263
264
265 if __name__ == "__main__":
266 _test()
267