Package pyspark :: Module accumulators
[frames] | no frames]

Source Code for Module pyspark.accumulators

  1  # 
  2  # Licensed to the Apache Software Foundation (ASF) under one or more 
  3  # contributor license agreements.  See the NOTICE file distributed with 
  4  # this work for additional information regarding copyright ownership. 
  5  # The ASF licenses this file to You under the Apache License, Version 2.0 
  6  # (the "License"); you may not use this file except in compliance with 
  7  # the License.  You may obtain a copy of the License at 
  8  # 
  9  #    http://www.apache.org/licenses/LICENSE-2.0 
 10  # 
 11  # Unless required by applicable law or agreed to in writing, software 
 12  # distributed under the License is distributed on an "AS IS" BASIS, 
 13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
 14  # See the License for the specific language governing permissions and 
 15  # limitations under the License. 
 16  # 
 17   
 18  """ 
 19  >>> from pyspark.context import SparkContext 
 20  >>> sc = SparkContext('local', 'test') 
 21  >>> a = sc.accumulator(1) 
 22  >>> a.value 
 23  1 
 24  >>> a.value = 2 
 25  >>> a.value 
 26  2 
 27  >>> a += 5 
 28  >>> a.value 
 29  7 
 30   
 31  >>> sc.accumulator(1.0).value 
 32  1.0 
 33   
 34  >>> sc.accumulator(1j).value 
 35  1j 
 36   
 37  >>> rdd = sc.parallelize([1,2,3]) 
 38  >>> def f(x): 
 39  ...     global a 
 40  ...     a += x 
 41  >>> rdd.foreach(f) 
 42  >>> a.value 
 43  13 
 44   
 45  >>> b = sc.accumulator(0) 
 46  >>> def g(x): 
 47  ...     b.add(x) 
 48  >>> rdd.foreach(g) 
 49  >>> b.value 
 50  6 
 51   
 52  >>> from pyspark.accumulators import AccumulatorParam 
 53  >>> class VectorAccumulatorParam(AccumulatorParam): 
 54  ...     def zero(self, value): 
 55  ...         return [0.0] * len(value) 
 56  ...     def addInPlace(self, val1, val2): 
 57  ...         for i in xrange(len(val1)): 
 58  ...              val1[i] += val2[i] 
 59  ...         return val1 
 60  >>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam()) 
 61  >>> va.value 
 62  [1.0, 2.0, 3.0] 
 63  >>> def g(x): 
 64  ...     global va 
 65  ...     va += [x] * 3 
 66  >>> rdd.foreach(g) 
 67  >>> va.value 
 68  [7.0, 8.0, 9.0] 
 69   
 70  >>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL 
 71  Traceback (most recent call last): 
 72      ... 
 73  Py4JJavaError:... 
 74   
 75  >>> def h(x): 
 76  ...     global a 
 77  ...     a.value = 7 
 78  >>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL 
 79  Traceback (most recent call last): 
 80      ... 
 81  Py4JJavaError:... 
 82   
 83  >>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL 
 84  Traceback (most recent call last): 
 85      ... 
 86  Exception:... 
 87  """ 
 88   
 89  import select 
 90  import struct 
 91  import SocketServer 
 92  import threading 
 93  from pyspark.cloudpickle import CloudPickler 
 94  from pyspark.serializers import read_int, PickleSerializer 
 95   
 96   
 97  pickleSer = PickleSerializer() 
 98   
 99  # Holds accumulators registered on the current machine, keyed by ID. This is then used to send 
100  # the local accumulator updates back to the driver program at the end of a task. 
101  _accumulatorRegistry = {} 
def _deserialize_accumulator(aid, zero_value, accum_param):
    """
    Unpickling hook used by C{Accumulator.__reduce__}.

    Rebuilds an L{Accumulator} with ID C{aid}, starting from C{zero_value}, marks it
    as deserialized (so reading or writing its C{value} raises), records it in the
    accumulator registry under C{aid}, and returns it.
    """
    # Imported through the package path rather than using the module global directly;
    # presumably so the canonical pyspark.accumulators registry is used even when this
    # file is executed as __main__ (e.g. for its doctests) -- TODO confirm.
    from pyspark.accumulators import _accumulatorRegistry
    rebuilt = Accumulator(aid, zero_value, accum_param)
    rebuilt._deserialized = True
    _accumulatorRegistry[aid] = rebuilt
    return rebuilt
110
class Accumulator(object):

    """
    A shared variable that can be accumulated, i.e., has a commutative and associative "add"
    operation. Worker tasks on a Spark cluster can add values to an Accumulator with the C{+=}
    operator, but only the driver program is allowed to access its value, using C{value}.
    Updates from the workers get propagated automatically to the driver program.

    While C{SparkContext} supports accumulators for primitive data types like C{int} and
    C{float}, users can also define accumulators for custom types by providing a custom
    L{AccumulatorParam} object. Refer to the doctest of this module for an example.
    """

    def __init__(self, aid, value, accum_param):
        """
        Create a new Accumulator with a given initial value and AccumulatorParam object.

        @param aid: identifier under which this accumulator is registered.
        @param value: the initial value.
        @param accum_param: object providing C{zero} and C{addInPlace} for the value type.
        """
        from pyspark.accumulators import _accumulatorRegistry
        self.aid, self.accum_param = aid, accum_param
        self._value = value
        # Set to True by _deserialize_accumulator for copies unpickled inside tasks;
        # while True, both reading and assigning `value` raise.
        self._deserialized = False
        # Each new instance records itself in the registry under its ID.
        _accumulatorRegistry[aid] = self

    def __reduce__(self):
        """
        Custom pickling: ship the ID, the param object and a zero value obtained from
        C{accum_param.zero} -- never the current value -- so the copy rebuilt by
        C{_deserialize_accumulator} starts from zero.
        """
        accum_param = self.accum_param
        return (_deserialize_accumulator,
                (self.aid, accum_param.zero(self._value), accum_param))

    def _ensure_not_deserialized(self):
        """Raise if this is a deserialized (in-task) copy, whose value is off-limits."""
        if self._deserialized:
            raise Exception("Accumulator.value cannot be accessed inside tasks")

    @property
    def value(self):
        """Get the accumulator's value; only usable in driver program"""
        self._ensure_not_deserialized()
        return self._value

    @value.setter
    def value(self, value):
        """Sets the accumulator's value; only usable in driver program"""
        self._ensure_not_deserialized()
        self._value = value

    def add(self, term):
        """Merge C{term} into the current value via C{accum_param.addInPlace}."""
        combined = self.accum_param.addInPlace(self._value, term)
        self._value = combined

    def __iadd__(self, term):
        """The C{+=} operator: delegate to L{add} and keep this same object bound."""
        self.add(term)
        return self

    def __str__(self):
        return str(self._value)

    def __repr__(self):
        return "Accumulator<id=%i, value=%s>" % (self.aid, self._value)
167
class AccumulatorParam(object):

    """
    Helper object that defines how to accumulate values of a given type.

    Subclasses must override both L{zero} and L{addInPlace}; the versions defined
    here only raise C{NotImplementedError}.
    """

    def zero(self, value):
        """
        Provide a "zero value" for the type, compatible in dimensions with the
        provided C{value} (e.g., a zero vector).

        @param value: a sample value whose shape/dimensions the zero must match.
        @raise NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def addInPlace(self, value1, value2):
        """
        Add two values of the accumulator's data type, returning a new value;
        for efficiency, can also update C{value1} in place and return it.

        @raise NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
188
class AddingAccumulatorParam(AccumulatorParam):

    """
    An AccumulatorParam that uses the + operators to add values. Designed for simple types
    such as integers, floats, and lists. Requires the zero value for the underlying type
    as a parameter.
    """

    def __init__(self, zero_value):
        self.zero_value = zero_value

    def zero(self, value):
        """
        Return a fresh (shallow) copy of the configured zero value; C{value} is ignored.

        The stored C{self.zero_value} is not handed out directly because
        L{addInPlace} mutates its first argument for mutable types such as lists.
        Sharing that one object would let every accumulator started from it (and
        this param's own zero) change each other's contents. This includes several
        accumulators pickled together, since pickle preserves shared references.
        For immutable zeros such as C{0}, C{0.0} and C{0j}, C{copy.copy} returns the
        very same object, so behavior for those types is unchanged.
        """
        import copy
        return copy.copy(self.zero_value)

    def addInPlace(self, value1, value2):
        """
        Add C{value2} to C{value1} using C{+=} and return the result. Mutable types
        (e.g. lists) are updated in place; immutable types yield a new object.
        """
        value1 += value2
        return value1


# Singleton accumulator params for some standard types
INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0)
FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
class _UpdateRequestHandler(SocketServer.StreamRequestHandler):

    """
    This handler will keep polling updates from the same socket until the
    server is shutdown or the sending side closes the connection.

    Each batch is an int count followed by that many length-prefixed pickled
    (aid, update) pairs. Each update is added to the registered accumulator with
    that ID, and the batch is acknowledged with a single byte.
    """

    def handle(self):
        from pyspark.accumulators import _accumulatorRegistry
        while not self.server.server_shutdown:
            # Poll every 1 second for new data -- don't block in case of shutdown.
            r, _, _ = select.select([self.rfile], [], [], 1)
            if self.rfile not in r:
                continue
            try:
                num_updates = read_int(self.rfile)
            except EOFError:
                # A closed connection is reported as readable but has no data left.
                # read_int is expected to raise EOFError in that case (TODO confirm
                # against pyspark.serializers). No more updates can arrive on this
                # socket, so end the handler quietly. Otherwise the error would reach
                # the server's handle_error and print a spurious traceback.
                return
            for _ in range(num_updates):
                # EOF in the middle of a batch is a genuine error and still propagates.
                (aid, update) = pickleSer._read_with_length(self.rfile)
                _accumulatorRegistry[aid] += update
            # Write a byte in acknowledgement
            self.wfile.write(struct.pack("!b", 1))
234
class AccumulatorServer(SocketServer.TCPServer):

    """
    A simple TCP server that intercepts shutdown() in order to interrupt
    our continuous polling on the handler.
    """
    # Polled by _UpdateRequestHandler.handle between reads; set to True by shutdown().
    server_shutdown = False

    def shutdown(self):
        """
        Stop serving and release the listening socket.

        C{server_shutdown} is set first so a handler looping on it exits. Then
        C{TCPServer.shutdown} stops the C{serve_forever} loop and waits for it to
        finish. That call does not close the listening socket, so C{server_close}
        is invoked afterwards to avoid leaking it.
        """
        self.server_shutdown = True
        SocketServer.TCPServer.shutdown(self)
        self.server_close()
247
def _start_update_server():
    """
    Start a TCP server to receive accumulator updates in a daemon thread, and return it.

    The server binds to an OS-chosen port on localhost (port 0). Its C{serve_forever}
    loop runs on a daemon thread, so it does not keep the interpreter alive. The
    caller receives the server object itself (e.g. to find the bound port through
    C{server_address}, or to call C{shutdown()} later).
    """
    update_server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler)
    serving_thread = threading.Thread(target=update_server.serve_forever)
    serving_thread.daemon = True  # must not block interpreter exit
    serving_thread.start()
    return update_server
256