import os
import shutil
import sys
from threading import Lock
from tempfile import NamedTemporaryFile

from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import dump_pickle, write_with_length, batched
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD

from py4j.java_collections import ListConverter


class SparkContext(object):
    """
    Main entry point for Spark functionality. A SparkContext represents the
    connection to a Spark cluster, and can be used to create L{RDD}s and
    broadcast variables on that cluster.
    """

    _gateway = None
    _jvm = None
    _writeIteratorToPickleFile = None
    _takePartition = None
    _next_accum_id = 0
    _active_spark_context = None
    _lock = Lock()
    _python_includes = None

    def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
        environment=None, batchSize=1024):
        """
        Create a new SparkContext.

        @param master: Cluster URL to connect to
               (e.g. mesos://host:port, spark://host:port, local[4]).
        @param jobName: A name for your job, to display on the cluster web UI.
        @param sparkHome: Location where Spark is installed on cluster nodes.
        @param pyFiles: Collection of .zip or .py files to send to the cluster
               and add to PYTHONPATH. These can be paths on the local file
               system or HDFS, HTTP, HTTPS, or FTP URLs.
        @param environment: A dictionary of environment variables to set on
               worker nodes.
        @param batchSize: The number of Python objects represented as a single
               Java object. Set to 1 to disable batching or -1 to use an
               unlimited batch size.

        >>> from pyspark.context import SparkContext
        >>> sc = SparkContext('local', 'test')

        >>> sc2 = SparkContext('local', 'test2') # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...
        """
        SparkContext._ensure_initialized(self)

        self.master = master
        self.jobName = jobName
        self.sparkHome = sparkHome or None
        self.environment = environment or {}
        self.batchSize = batchSize

        # Create the Java SparkContext through Py4J
        empty_string_array = self._gateway.new_array(self._jvm.String, 0)
        self._jsc = self._jvm.JavaSparkContext(master, jobName, sparkHome,
                                               empty_string_array)

        # Start a TCP server to receive accumulator updates from the workers,
        # and register a single Java accumulator that forwards updates to it
        self._accumulatorServer = accumulators._start_update_server()
        (host, port) = self._accumulatorServer.server_address
        self._javaAccumulator = self._jsc.accumulator(
                self._jvm.java.util.ArrayList(),
                self._jvm.PythonAccumulatorParam(host, port))

        self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')

        # Broadcast variables that have been pickled into task closures are
        # tracked here so the JVM can be told which broadcasts a job needs
        self._pickled_broadcast_vars = set()

        SparkFiles._sc = self
        root_dir = SparkFiles.getRootDirectory()
        sys.path.append(root_dir)

        # Deploy any code dependencies supplied through the pyFiles argument
        self._python_includes = list()
        for path in (pyFiles or []):
            self.addPyFile(path)

        # Create a temporary directory (inside Spark's local dir) that
        # parallelize() uses when serializing data
        local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir()
        self._temp_dir = \
            self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()

    @classmethod
    def _ensure_initialized(cls, instance=None):
        with SparkContext._lock:
            if not SparkContext._gateway:
                SparkContext._gateway = launch_gateway()
                SparkContext._jvm = SparkContext._gateway.jvm
                SparkContext._writeIteratorToPickleFile = \
                    SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
                SparkContext._takePartition = \
                    SparkContext._jvm.PythonRDD.takePartition

            if instance:
                if SparkContext._active_spark_context and SparkContext._active_spark_context != instance:
                    raise ValueError("Cannot run multiple SparkContexts at once")
                else:
                    SparkContext._active_spark_context = instance

    @classmethod
    def setSystemProperty(cls, key, value):
        """
        Set a system property, such as spark.executor.memory. This must be
        invoked before instantiating SparkContext.
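
        For example (an illustrative sketch; the property value and application
        name are arbitrary)::

            SparkContext.setSystemProperty("spark.executor.memory", "2g")
            sc = SparkContext("local", "Test app")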
        """
        SparkContext._ensure_initialized()
        SparkContext._jvm.java.lang.System.setProperty(key, value)

    @property
    def defaultParallelism(self):
        """
        Default level of parallelism to use when not given by user (e.g. for
        reduce tasks)
        """
        return self._jsc.sc().defaultParallelism()

    def stop(self):
        """
        Shut down the SparkContext.
        """
        if self._jsc:
            self._jsc.stop()
            self._jsc = None
        if self._accumulatorServer:
            self._accumulatorServer.shutdown()
            self._accumulatorServer = None
        with SparkContext._lock:
            SparkContext._active_spark_context = None

    def parallelize(self, c, numSlices=None):
        """
        Distribute a local Python collection to form an RDD.

        >>> sc.parallelize(range(5), 5).glom().collect()
        [[0], [1], [2], [3], [4]]
        """
        numSlices = numSlices or self.defaultParallelism
        # Calling the Java parallelize() method directly would send every
        # element through Py4J individually, so instead the objects are pickled
        # to a temp file and that file is read on the JVM side as an RDD.
        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
        # Cap the batch size at len(c) // numSlices so that the data is still
        # spread across all of the requested slices
        if "__len__" not in dir(c):
            c = list(c)    # Make it a list so we can compute its length
        batchSize = min(len(c) // numSlices, self.batchSize)
        if batchSize > 1:
            c = batched(c, batchSize)
        for x in c:
            write_with_length(dump_pickle(x), tempFile)
        tempFile.close()
        readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
        jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
        return RDD(jrdd, self)

    def textFile(self, name, minSplits=None):
        """
        Read a text file from HDFS, a local file system (available on all
        nodes), or any Hadoop-supported file system URI, and return it as an
        RDD of Strings.
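
        A minimal doctest-style sketch (the file name is arbitrary):

        >>> path = os.path.join(tempdir, "sample-text.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("100")
        >>> sc.textFile(path).map(lambda line: int(line)).collect()
        [100]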
        """
        minSplits = minSplits or min(self.defaultParallelism, 2)
        jrdd = self._jsc.textFile(name, minSplits)
        return RDD(jrdd, self)

    def _checkpointFile(self, name):
        jrdd = self._jsc.checkpointFile(name)
        return RDD(jrdd, self)

    def union(self, rdds):
        """
        Build the union of a list of RDDs.
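
        A minimal doctest-style sketch:

        >>> rdd = sc.parallelize([1, 2, 3])
        >>> sc.union([rdd, rdd]).collect()
        [1, 2, 3, 1, 2, 3]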
        """
        first = rdds[0]._jrdd
        rest = [x._jrdd for x in rdds[1:]]
        rest = ListConverter().convert(rest, self._gateway._gateway_client)
        return RDD(self._jsc.union(first, rest), self)

    def broadcast(self, value):
        """
        Broadcast a read-only variable to the cluster, returning a C{Broadcast}
        object for reading it in distributed functions. The variable will be
        sent to each node only once.
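
        A minimal doctest-style sketch:

        >>> b = sc.broadcast([1, 2, 3, 4, 5])
        >>> b.value
        [1, 2, 3, 4, 5]
        >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
        [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]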
        """
        jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
        return Broadcast(jbroadcast.id(), value, jbroadcast,
                         self._pickled_broadcast_vars)

    def accumulator(self, value, accum_param=None):
        """
        Create an L{Accumulator} with the given initial value, using a given
        L{AccumulatorParam} helper object to define how to add values of the
        data type if provided. Default AccumulatorParams are used for integers
        and floating-point numbers if you do not provide one. For other types,
        a custom AccumulatorParam can be used.
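
        A minimal doctest-style sketch:

        >>> a = sc.accumulator(0)
        >>> def f(x):
        ...    global a
        ...    a += x
        >>> sc.parallelize([1, 2, 3, 4]).foreach(f)
        >>> a.value
        10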
        """
        if accum_param is None:
            if isinstance(value, int):
                accum_param = accumulators.INT_ACCUMULATOR_PARAM
            elif isinstance(value, float):
                accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM
            elif isinstance(value, complex):
                accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
            else:
                raise Exception("No default accumulator param for type %s" % type(value))
        SparkContext._next_accum_id += 1
        return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)

    def addFile(self, path):
        """
        Add a file to be downloaded with this Spark job on every node.
        The C{path} passed can be either a local file, a file in HDFS
        (or other Hadoop-supported filesystems), or an HTTP, HTTPS or
        FTP URI.

        To access the file in Spark jobs, use
        L{SparkFiles.get(path)<pyspark.files.SparkFiles.get>} to find its
        download location.

        >>> from pyspark import SparkFiles
        >>> path = os.path.join(tempdir, "test.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("100")
        >>> sc.addFile(path)
        >>> def func(iterator):
        ...    with open(SparkFiles.get("test.txt")) as testFile:
        ...        fileVal = int(testFile.readline())
        ...        return [x * fileVal for x in iterator]
        >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
        [100, 200, 300, 400]
        """
        self._jsc.sc().addFile(path)

    def clearFiles(self):
        """
        Clear the job's list of files added by L{addFile} or L{addPyFile} so
        that they do not get downloaded to any new nodes.
        """
        self._jsc.sc().clearFiles()

    def addPyFile(self, path):
        """
        Add a .py or .zip dependency for all tasks to be executed on this
        SparkContext in the future. The C{path} passed can be either a local
        file, a file in HDFS (or other Hadoop-supported filesystems), or an
        HTTP, HTTPS or FTP URI.
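
        For example (an illustrative sketch; C{deps.zip} is a hypothetical
        archive of Python modules)::

            sc.addPyFile("/path/to/deps.zip")
            # modules bundled in deps.zip can now be imported inside tasks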
        """
        self.addFile(path)
        (dirname, filename) = os.path.split(path)

        if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'):
            self._python_includes.append(filename)
            sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename))

    def setCheckpointDir(self, dirName, useExisting=False):
        """
        Set the directory under which RDDs are going to be checkpointed. The
        directory must be an HDFS path if running on a cluster.

        If the directory does not exist, it will be created. If the directory
        exists and C{useExisting} is set to true, then the existing directory
        will be used. Otherwise an exception will be thrown to prevent
        accidental overriding of checkpoint files in the existing directory.
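
        For example (an illustrative sketch; the directory is arbitrary, and
        C{rdd.checkpoint()} is assumed to be how an RDD is marked for
        checkpointing)::

            sc.setCheckpointDir("/tmp/spark-checkpoints", useExisting=True)
            rdd = sc.parallelize([1, 2, 3])
            rdd.checkpoint()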
        """
        self._jsc.sc().setCheckpointDir(dirName, useExisting)

    def _getJavaStorageLevel(self, storageLevel):
        """
        Returns a Java StorageLevel based on a pyspark.StorageLevel.
        """
        if not isinstance(storageLevel, StorageLevel):
            raise Exception("storageLevel must be of type pyspark.StorageLevel")

        newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
        return newStorageLevel(storageLevel.useDisk, storageLevel.useMemory,
                               storageLevel.deserialized, storageLevel.replication)

def _test():
    import atexit
    import doctest
    import tempfile
    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    globs['tempdir'] = tempfile.mkdtemp()
    atexit.register(lambda: shutil.rmtree(globs['tempdir']))
    (failure_count, test_count) = doctest.testmod(globs=globs)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()