
Source Code for Module pyspark.accumulators

  1  """ 
  2  >>> from pyspark.context import SparkContext 
  3  >>> sc = SparkContext('local', 'test') 
  4  >>> a = sc.accumulator(1) 
  5  >>> a.value 
  6  1 
  7  >>> a.value = 2 
  8  >>> a.value 
  9  2 
 10  >>> a += 5 
 11  >>> a.value 
 12  7 
 13   
 14  >>> sc.accumulator(1.0).value 
 15  1.0 
 16   
 17  >>> sc.accumulator(1j).value 
 18  1j 
 19   
 20  >>> rdd = sc.parallelize([1,2,3]) 
 21  >>> def f(x): 
 22  ...     global a 
 23  ...     a += x 
 24  >>> rdd.foreach(f) 
 25  >>> a.value 
 26  13 
 27   
 28  >>> from pyspark.accumulators import AccumulatorParam 
 29  >>> class VectorAccumulatorParam(AccumulatorParam): 
 30  ...     def zero(self, value): 
 31  ...         return [0.0] * len(value) 
 32  ...     def addInPlace(self, val1, val2): 
 33  ...         for i in xrange(len(val1)): 
 34  ...              val1[i] += val2[i] 
 35  ...         return val1 
 36  >>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam()) 
 37  >>> va.value 
 38  [1.0, 2.0, 3.0] 
 39  >>> def g(x): 
 40  ...     global va 
 41  ...     va += [x] * 3 
 42  >>> rdd.foreach(g) 
 43  >>> va.value 
 44  [7.0, 8.0, 9.0] 
 45   
 46  >>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL 
 47  Traceback (most recent call last): 
 48      ... 
 49  Py4JJavaError:... 
 50   
 51  >>> def h(x): 
 52  ...     global a 
 53  ...     a.value = 7 
 54  >>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL 
 55  Traceback (most recent call last): 
 56      ... 
 57  Py4JJavaError:... 
 58   
 59  >>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL 
 60  Traceback (most recent call last): 
 61      ... 
 62  Exception:... 
 63  """ 
 64   
 65  import struct 
 66  import SocketServer 
 67  import threading 
 68  from pyspark.cloudpickle import CloudPickler 
 69  from pyspark.serializers import read_int, read_with_length, load_pickle 
 70   
 71   
 72  # Holds accumulators registered on the current machine, keyed by ID. This is then used to send 
 73  # the local accumulator updates back to the driver program at the end of a task. 
 74  _accumulatorRegistry = {} 


def _deserialize_accumulator(aid, zero_value, accum_param):
    from pyspark.accumulators import _accumulatorRegistry
    accum = Accumulator(aid, zero_value, accum_param)
    accum._deserialized = True
    _accumulatorRegistry[aid] = accum
    return accum


class Accumulator(object):
    """
    A shared variable that can be accumulated, i.e., has a commutative and associative "add"
    operation. Worker tasks on a Spark cluster can add values to an Accumulator with the C{+=}
    operator, but only the driver program is allowed to access its value, using C{value}.
    Updates from the workers get propagated automatically to the driver program.

    While C{SparkContext} supports accumulators for primitive data types like C{int} and
    C{float}, users can also define accumulators for custom types by providing a custom
    L{AccumulatorParam} object. Refer to the doctest of this module for an example.
    """

    def __init__(self, aid, value, accum_param):
        """Create a new Accumulator with a given initial value and AccumulatorParam object"""
        from pyspark.accumulators import _accumulatorRegistry
        self.aid = aid
        self.accum_param = accum_param
        self._value = value
        self._deserialized = False
        _accumulatorRegistry[aid] = self

    def __reduce__(self):
        """Custom serialization; saves the zero value from our AccumulatorParam"""
        param = self.accum_param
        return (_deserialize_accumulator, (self.aid, param.zero(self._value), param))

    @property
    def value(self):
        """Get the accumulator's value; only usable in the driver program"""
        if self._deserialized:
            raise Exception("Accumulator.value cannot be accessed inside tasks")
        return self._value

    @value.setter
    def value(self, value):
        """Set the accumulator's value; only usable in the driver program"""
        if self._deserialized:
            raise Exception("Accumulator.value cannot be accessed inside tasks")
        self._value = value

    def __iadd__(self, term):
        """The += operator; adds a term to this accumulator's value"""
        self._value = self.accum_param.addInPlace(self._value, term)
        return self

    def __str__(self):
        return str(self._value)

    def __repr__(self):
        return "Accumulator<id=%i, value=%s>" % (self.aid, self._value)
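
# Illustrative sketch (not part of this module): how __reduce__ above behaves when an
# Accumulator is shipped to a worker. Pickling captures only (aid, zero value, param),
# so the deserialized copy starts from the zero value, registers itself for the
# end-of-task update, and refuses driver-only access to C{value}. The id 7 below is
# made up, and INT_ACCUMULATOR_PARAM is the singleton defined later in this file.
#
#     >>> import cPickle
#     >>> acc = Accumulator(7, 10, INT_ACCUMULATOR_PARAM)
#     >>> copy = cPickle.loads(cPickle.dumps(acc))
#     >>> copy._deserialized, copy._value
#     (True, 0)
#     >>> copy.value                        # doctest: +IGNORE_EXCEPTION_DETAIL
#     Traceback (most recent call last):
#         ...
#     Exception: Accumulator.value cannot be accessed inside tasks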


class AccumulatorParam(object):
    """
    Helper object that defines how to accumulate values of a given type.
    """

    def zero(self, value):
        """
        Provide a "zero value" for the type, compatible in dimensions with the
        provided C{value} (e.g., a zero vector)
        """
        raise NotImplementedError

    def addInPlace(self, value1, value2):
        """
        Add two values of the accumulator's data type, returning a new value;
        for efficiency, can also update C{value1} in place and return it.
        """
        raise NotImplementedError
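
# Illustrative sketch (not part of this module): any type can back an accumulator as
# long as its AccumulatorParam supplies a zero() and a commutative, associative
# addInPlace(), per the contract above. The name SetAccumulatorParam is made up for
# illustration; it accumulates values into a set by union.

class SetAccumulatorParam(AccumulatorParam):
    def zero(self, value):
        # The empty set is the identity element for set union.
        return set()

    def addInPlace(self, val1, val2):
        # Update val1 in place and return it, as the contract above allows.
        val1 |= val2
        return val1

# Exercised without a SparkContext (the accumulator id 100 is made up):
#   >>> acc = Accumulator(100, set(), SetAccumulatorParam())
#   >>> acc += set([1, 2]); acc += set([2, 3])
#   >>> sorted(acc.value)
#   [1, 2, 3]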


class AddingAccumulatorParam(AccumulatorParam):
    """
    An AccumulatorParam that uses the + operator to add values. Designed for simple types
    such as integers, floats, and lists. Requires the zero value for the underlying type
    as a parameter.
    """

    def __init__(self, zero_value):
        self.zero_value = zero_value

    def zero(self, value):
        return self.zero_value

    def addInPlace(self, value1, value2):
        value1 += value2
        return value1


# Singleton accumulator params for some standard types
INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0)
FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
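
# These singletons presumably back SparkContext.accumulator() when no explicit
# AccumulatorParam is given, which is why the doctests at the top can create int,
# float, and complex accumulators directly but need VectorAccumulatorParam for a
# list. A list-valued AddingAccumulatorParam also works if constructed by hand:
#   >>> LIST_PARAM = AddingAccumulatorParam([])   # illustrative, not defined here
#   >>> LIST_PARAM.addInPlace([1.0], [2.0, 3.0])
#   [1.0, 2.0, 3.0]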


class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
    def handle(self):
        from pyspark.accumulators import _accumulatorRegistry
        num_updates = read_int(self.rfile)
        for _ in range(num_updates):
            (aid, update) = load_pickle(read_with_length(self.rfile))
            _accumulatorRegistry[aid] += update
        # Write a byte in acknowledgement
        self.wfile.write(struct.pack("!b", 1))


def _start_update_server():
    """Start a TCP server to receive accumulator updates in a daemon thread, and return it"""
    server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler)
    thread = threading.Thread(target=server.serve_forever)
    thread.daemon = True
    thread.start()
    return server
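

# Illustrative sketch (not part of this module): exercising the update server in a
# single process, end to end, to show the wire format _UpdateRequestHandler expects
# from a worker: an update count, then each (aid, update) pair as a length-prefixed
# pickle, answered by a one-byte acknowledgement. It assumes read_int and
# read_with_length use a 4-byte big-endian ("!i") integer / length prefix and that
# load_pickle is an ordinary pickle.loads, as in pyspark.serializers of this version;
# adjust the framing if yours differs. The function is never called at import time.
def _example_send_update():
    import cPickle
    import socket

    # A driver-side accumulator with an arbitrary, made-up id of 42.
    acc = Accumulator(42, 0, INT_ACCUMULATOR_PARAM)
    server = _start_update_server()

    # Pretend to be a worker reporting a single update of +10 for accumulator 42.
    sock = socket.create_connection(server.server_address)
    payload = cPickle.dumps((42, 10), 2)
    sock.sendall(struct.pack("!i", 1))             # number of updates
    sock.sendall(struct.pack("!i", len(payload)))  # length prefix
    sock.sendall(payload)                          # pickled (aid, update) pair
    ack = sock.recv(1)                             # one-byte acknowledgement
    sock.close()

    server.shutdown()
    assert ack == struct.pack("!b", 1) and acc.value == 10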