18 """
19 >>> from pyspark.context import SparkContext
20 >>> sc = SparkContext('local', 'test')
21 >>> a = sc.accumulator(1)
22 >>> a.value
23 1
24 >>> a.value = 2
25 >>> a.value
26 2
27 >>> a += 5
28 >>> a.value
29 7
30
31 >>> sc.accumulator(1.0).value
32 1.0
33
34 >>> sc.accumulator(1j).value
35 1j
36
37 >>> rdd = sc.parallelize([1,2,3])
38 >>> def f(x):
39 ... global a
40 ... a += x
41 >>> rdd.foreach(f)
42 >>> a.value
43 13
44
45 >>> from pyspark.accumulators import AccumulatorParam
46 >>> class VectorAccumulatorParam(AccumulatorParam):
47 ... def zero(self, value):
48 ... return [0.0] * len(value)
49 ... def addInPlace(self, val1, val2):
50 ... for i in xrange(len(val1)):
51 ... val1[i] += val2[i]
52 ... return val1
53 >>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())
54 >>> va.value
55 [1.0, 2.0, 3.0]
56 >>> def g(x):
57 ... global va
58 ... va += [x] * 3
59 >>> rdd.foreach(g)
60 >>> va.value
61 [7.0, 8.0, 9.0]
62
63 >>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL
64 Traceback (most recent call last):
65 ...
66 Py4JJavaError:...
67
68 >>> def h(x):
69 ... global a
70 ... a.value = 7
71 >>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL
72 Traceback (most recent call last):
73 ...
74 Py4JJavaError:...
75
76 >>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL
77 Traceback (most recent call last):
78 ...
79 Exception:...
80 """

import struct
import SocketServer
import threading
from pyspark.cloudpickle import CloudPickler
from pyspark.serializers import read_int, read_with_length, load_pickle


# Holds accumulators registered on the current machine, keyed by ID. This is
# used to send the local accumulator updates back to the driver program at
# the end of a task.
_accumulatorRegistry = {}


def _deserialize_accumulator(aid, zero_value, accum_param):
    """Reconstruct an Accumulator on a worker; the copy starts from the zero value."""
    from pyspark.accumulators import _accumulatorRegistry
    accum = Accumulator(aid, zero_value, accum_param)
    accum._deserialized = True
    _accumulatorRegistry[aid] = accum
    return accum


class Accumulator(object):
    """
    A shared variable that can be accumulated, i.e., has a commutative and associative "add"
    operation. Worker tasks on a Spark cluster can add values to an Accumulator with the C{+=}
    operator, but only the driver program is allowed to access its value, using C{value}.
    Updates from the workers get propagated automatically to the driver program.

    While C{SparkContext} supports accumulators for primitive data types like C{int} and
    C{float}, users can also define accumulators for custom types by providing a custom
    L{AccumulatorParam} object. Refer to the doctest of this module for an example.
    """

    def __init__(self, aid, value, accum_param):
        """Create a new Accumulator with a given initial value and AccumulatorParam object"""
        from pyspark.accumulators import _accumulatorRegistry
        self.aid = aid
        self.accum_param = accum_param
        self._value = value
        self._deserialized = False
        _accumulatorRegistry[aid] = self

124 """Custom serialization; saves the zero value from our AccumulatorParam"""
125 param = self.accum_param
126 return (_deserialize_accumulator, (self.aid, param.zero(self._value), param))

    @property
    def value(self):
        """Get the accumulator's value; only usable in driver program"""
        if self._deserialized:
            raise Exception("Accumulator.value cannot be accessed inside tasks")
        return self._value

    @value.setter
    def value(self, value):
        """Set the accumulator's value; only usable in driver program"""
        if self._deserialized:
            raise Exception("Accumulator.value cannot be accessed inside tasks")
        self._value = value

143 """The += operator; adds a term to this accumulator's value"""
144 self._value = self.accum_param.addInPlace(self._value, term)
145 return self
146
    def __str__(self):
        return str(self._value)

    def __repr__(self):
        return "Accumulator<id=%i, value=%s>" % (self.aid, self._value)

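# Illustrative sketch, not part of the original module: pickling an Accumulator
# (as happens when a task closure is shipped to workers) goes through
# __reduce__ above, so the copy arrives zeroed and flagged as deserialized:
#
#   import pickle
#   acc = Accumulator(0, 5, INT_ACCUMULATOR_PARAM)
#   copy = pickle.loads(pickle.dumps(acc))
#   copy._value          # 0 -- reset to the param's zero value
#   copy._deserialized   # True -- so accessing .value now raises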
155 """
156 Helper object that defines how to accumulate values of a given type.
157 """

    def zero(self, value):
        """
        Provide a "zero value" for the type, compatible in dimensions with the
        provided C{value} (e.g., a zero vector)
        """
        raise NotImplementedError

167 """
168 Add two values of the accumulator's data type, returning a new value;
169 for efficiency, can also update C{value1} in place and return it.
170 """
171 raise NotImplementedError
172
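# Illustrative sketch, not part of the original module: a hypothetical
# AccumulatorParam that accumulates sets by union, showing the
# zero()/addInPlace() contract documented above.
#
#   class SetAccumulatorParam(AccumulatorParam):
#       def zero(self, value):
#           return set()                # an empty set plays the role of zero
#       def addInPlace(self, val1, val2):
#           val1 |= val2                # in-place union; return the left value
#           return val1
#
#   distinct = sc.accumulator(set(), SetAccumulatorParam())
#   def collect(x):
#       global distinct
#       distinct += set([x])
#   sc.parallelize([1, 2, 2, 3]).foreach(collect)
#   distinct.value                      # set([1, 2, 3])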
175 """
176 An AccumulatorParam that uses the + operators to add values. Designed for simple types
177 such as integers, floats, and lists. Requires the zero value for the underlying type
178 as a parameter.
179 """

    def __init__(self, zero_value):
        self.zero_value = zero_value

    def zero(self, value):
        return self.zero_value

    def addInPlace(self, value1, value2):
        value1 += value2
        return value1


# Singleton accumulator params for some standard types
INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0)
FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
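
# Illustrative note, not part of the original module: these params behave like
# built-in addition, e.g.
#
#   INT_ACCUMULATOR_PARAM.addInPlace(1, 2)    # 3
#   COMPLEX_ACCUMULATOR_PARAM.zero(1j)        # 0j
#
# Presumably SparkContext.accumulator falls back on them when no custom
# AccumulatorParam is supplied, which is why the doctest's
# sc.accumulator([1.0, 2.0, 3.0]) call raises without one.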


class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
    """Applies a batch of accumulator updates sent by a worker to the local registry."""
    def handle(self):
        from pyspark.accumulators import _accumulatorRegistry
        num_updates = read_int(self.rfile)
        for _ in range(num_updates):
            # Each update is a length-prefixed, pickled (aid, value) pair
            (aid, update) = load_pickle(read_with_length(self.rfile))
            _accumulatorRegistry[aid] += update
        # Write a byte in acknowledgement
        self.wfile.write(struct.pack("!b", 1))

210 """Start a TCP server to receive accumulator updates in a daemon thread, and returns it"""
211 server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler)
212 thread = threading.Thread(target=server.serve_forever)
213 thread.daemon = True
214 thread.start()
215 return server
216
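

# Illustrative sketch, not part of the original module: a minimal client for
# the update protocol handled by _UpdateRequestHandler above, assuming read_int
# and read_with_length use 4-byte big-endian length prefixes. The real sender
# lives on the worker side; _send_updates is a hypothetical name.
#
#   import socket, struct, cPickle
#
#   def _send_updates(port, updates):   # updates: list of (aid, value) pairs
#       sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
#       sock.connect(("localhost", port))
#       sock.sendall(struct.pack("!i", len(updates)))     # number of updates
#       for aid, value in updates:
#           data = cPickle.dumps((aid, value), 2)
#           sock.sendall(struct.pack("!i", len(data)))    # length prefix
#           sock.sendall(data)                            # pickled (aid, update)
#       ack = sock.recv(1)                                # one-byte acknowledgement
#       sock.close()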