
Source Code for Module pyspark.context

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import shutil
import sys
from threading import Lock
from tempfile import NamedTemporaryFile

from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import dump_pickle, write_with_length, batched
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD

from py4j.java_collections import ListConverter


class SparkContext(object):
    """
    Main entry point for Spark functionality. A SparkContext represents the
    connection to a Spark cluster, and can be used to create L{RDD}s and
    broadcast variables on that cluster.
    """

    _gateway = None
    _jvm = None
    _writeIteratorToPickleFile = None
    _takePartition = None
    _next_accum_id = 0
    _active_spark_context = None
    _lock = Lock()
    _python_includes = None  # zip and egg files that need to be added to PYTHONPATH

    def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
                 environment=None, batchSize=1024):
        """
        Create a new SparkContext.

        @param master: Cluster URL to connect to
               (e.g. mesos://host:port, spark://host:port, local[4]).
        @param jobName: A name for your job, to display on the cluster web UI.
        @param sparkHome: Location where Spark is installed on cluster nodes.
        @param pyFiles: Collection of .zip or .py files to send to the cluster
               and add to PYTHONPATH. These can be paths on the local file
               system or HDFS, HTTP, HTTPS, or FTP URLs.
        @param environment: A dictionary of environment variables to set on
               worker nodes.
        @param batchSize: The number of Python objects represented as a single
               Java object. Set 1 to disable batching or -1 to use an
               unlimited batch size.


        >>> from pyspark.context import SparkContext
        >>> sc = SparkContext('local', 'test')

        >>> sc2 = SparkContext('local', 'test2') # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...
        """
        SparkContext._ensure_initialized(self)

        self.master = master
        self.jobName = jobName
        self.sparkHome = sparkHome or None  # None becomes null in Py4J
        self.environment = environment or {}
        self.batchSize = batchSize  # -1 represents an unlimited batch size

        # Create the Java SparkContext through Py4J
        empty_string_array = self._gateway.new_array(self._jvm.String, 0)
        self._jsc = self._jvm.JavaSparkContext(master, jobName, sparkHome,
                                               empty_string_array)

        # Create a single Accumulator in Java that we'll send all our updates through;
        # they will be passed back to us through a TCP server
        self._accumulatorServer = accumulators._start_update_server()
        (host, port) = self._accumulatorServer.server_address
        self._javaAccumulator = self._jsc.accumulator(
            self._jvm.java.util.ArrayList(),
            self._jvm.PythonAccumulatorParam(host, port))

        self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
        # Broadcast's __reduce__ method stores Broadcast instances here.
        # This allows other code to determine which Broadcast instances have
        # been pickled, so it can determine which Java broadcast objects to
        # send.
        self._pickled_broadcast_vars = set()

        SparkFiles._sc = self
        root_dir = SparkFiles.getRootDirectory()
        sys.path.append(root_dir)

        # Deploy any code dependencies specified in the constructor
        self._python_includes = list()
        for path in (pyFiles or []):
            self.addPyFile(path)

        # Create a temporary directory inside spark.local.dir:
        local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir()
        self._temp_dir = \
            self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()

    @classmethod
    def _ensure_initialized(cls, instance=None):
        with SparkContext._lock:
            if not SparkContext._gateway:
                SparkContext._gateway = launch_gateway()
                SparkContext._jvm = SparkContext._gateway.jvm
                SparkContext._writeIteratorToPickleFile = \
                    SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
                SparkContext._takePartition = \
                    SparkContext._jvm.PythonRDD.takePartition

            if instance:
                if SparkContext._active_spark_context and SparkContext._active_spark_context != instance:
                    raise ValueError("Cannot run multiple SparkContexts at once")
                else:
                    SparkContext._active_spark_context = instance

    @classmethod
    def setSystemProperty(cls, key, value):
        """
        Set a system property, such as spark.executor.memory. This must be
        invoked before instantiating SparkContext.
        """
        SparkContext._ensure_initialized()
        SparkContext._jvm.java.lang.System.setProperty(key, value)

    @property
    def defaultParallelism(self):
        """
        Default level of parallelism to use when not given by user (e.g. for
        reduce tasks)
        """
        return self._jsc.sc().defaultParallelism()

    def __del__(self):
        self.stop()

    def stop(self):
        """
        Shut down the SparkContext.
        """
        if self._jsc:
            self._jsc.stop()
            self._jsc = None
        if self._accumulatorServer:
            self._accumulatorServer.shutdown()
            self._accumulatorServer = None
        with SparkContext._lock:
            SparkContext._active_spark_context = None

    def parallelize(self, c, numSlices=None):
        """
        Distribute a local Python collection to form an RDD.

        >>> sc.parallelize(range(5), 5).glom().collect()
        [[0], [1], [2], [3], [4]]
        """
        numSlices = numSlices or self.defaultParallelism
        # Calling the Java parallelize() method with an ArrayList is too slow,
        # because it sends O(n) Py4J commands. As an alternative, serialized
        # objects are written to a file and loaded through textFile().
        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
        # Make sure we distribute data evenly if it's smaller than self.batchSize
        if "__len__" not in dir(c):
            c = list(c)    # Make it a list so we can compute its length
        batchSize = min(len(c) // numSlices, self.batchSize)
        if batchSize > 1:
            c = batched(c, batchSize)
        for x in c:
            write_with_length(dump_pickle(x), tempFile)
        tempFile.close()
        readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
        jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
        return RDD(jrdd, self)

    def textFile(self, name, minSplits=None):
        """
        Read a text file from HDFS, a local file system (available on all
        nodes), or any Hadoop-supported file system URI, and return it as an
        RDD of Strings.
        """
        minSplits = minSplits or min(self.defaultParallelism, 2)
        jrdd = self._jsc.textFile(name, minSplits)
        return RDD(jrdd, self)

    def _checkpointFile(self, name):
        jrdd = self._jsc.checkpointFile(name)
        return RDD(jrdd, self)

    def union(self, rdds):
        """
        Build the union of a list of RDDs.
        """
        first = rdds[0]._jrdd
        rest = [x._jrdd for x in rdds[1:]]
        rest = ListConverter().convert(rest, self._gateway._gateway_client)
        return RDD(self._jsc.union(first, rest), self)

    def broadcast(self, value):
        """
        Broadcast a read-only variable to the cluster, returning a C{Broadcast}
        object for reading it in distributed functions. The variable will be
        sent to each node only once.
        """
        jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
        return Broadcast(jbroadcast.id(), value, jbroadcast,
                         self._pickled_broadcast_vars)

    def accumulator(self, value, accum_param=None):
        """
        Create an L{Accumulator} with the given initial value, using a given
        L{AccumulatorParam} helper object to define how to add values of the
        data type if provided. Default AccumulatorParams are used for integers
        and floating-point numbers if you do not provide one. For other types,
        a custom AccumulatorParam can be used.
        """
        if accum_param is None:
            if isinstance(value, int):
                accum_param = accumulators.INT_ACCUMULATOR_PARAM
            elif isinstance(value, float):
                accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM
            elif isinstance(value, complex):
                accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
            else:
                raise Exception("No default accumulator param for type %s" % type(value))
        SparkContext._next_accum_id += 1
        return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)

    def addFile(self, path):
        """
        Add a file to be downloaded with this Spark job on every node.
        The C{path} passed can be either a local file, a file in HDFS
        (or other Hadoop-supported filesystems), or an HTTP, HTTPS or
        FTP URI.

        To access the file in Spark jobs, use
        L{SparkFiles.get(path)<pyspark.files.SparkFiles.get>} to find its
        download location.

        >>> from pyspark import SparkFiles
        >>> path = os.path.join(tempdir, "test.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("100")
        >>> sc.addFile(path)
        >>> def func(iterator):
        ...    with open(SparkFiles.get("test.txt")) as testFile:
        ...        fileVal = int(testFile.readline())
        ...        return [x * 100 for x in iterator]
        >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
        [100, 200, 300, 400]
        """
        self._jsc.sc().addFile(path)

    def clearFiles(self):
        """
        Clear the job's list of files added by L{addFile} or L{addPyFile} so
        that they do not get downloaded to any new nodes.
        """
        # TODO: remove added .py or .zip files from the PYTHONPATH?
        self._jsc.sc().clearFiles()

    def addPyFile(self, path):
        """
        Add a .py or .zip dependency for all tasks to be executed on this
        SparkContext in the future. The C{path} passed can be either a local
        file, a file in HDFS (or other Hadoop-supported filesystems), or an
        HTTP, HTTPS or FTP URI.
        """
        self.addFile(path)
        (dirname, filename) = os.path.split(path)  # dirname may be directory or HDFS/S3 prefix

        if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'):
            self._python_includes.append(filename)
            sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename))  # for tests in local mode

    def setCheckpointDir(self, dirName, useExisting=False):
        """
        Set the directory under which RDDs are going to be checkpointed. The
        directory must be an HDFS path if running on a cluster.

        If the directory does not exist, it will be created. If the directory
        exists and C{useExisting} is set to true, then the existing directory
        will be used. Otherwise an exception will be thrown to prevent
        accidental overriding of checkpoint files in the existing directory.
        """
        self._jsc.sc().setCheckpointDir(dirName, useExisting)

    def _getJavaStorageLevel(self, storageLevel):
        """
        Returns a Java StorageLevel based on a pyspark.StorageLevel.
        """
        if not isinstance(storageLevel, StorageLevel):
            raise Exception("storageLevel must be of type pyspark.StorageLevel")

        newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
        return newStorageLevel(storageLevel.useDisk, storageLevel.useMemory,
                               storageLevel.deserialized, storageLevel.replication)

def _test():
    import atexit
    import doctest
    import tempfile
    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    globs['tempdir'] = tempfile.mkdtemp()
    atexit.register(lambda: shutil.rmtree(globs['tempdir']))
    (failure_count, test_count) = doctest.testmod(globs=globs)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()