
Source Code for Module pyspark.serializers

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
PySpark supports custom serializers for transferring data; this can improve
performance.

By default, PySpark uses L{PickleSerializer} to serialize objects using Python's
C{cPickle} serializer, which can serialize nearly any Python object.
Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be
faster.

The serializer is chosen when creating L{SparkContext}:

>>> from pyspark.context import SparkContext
>>> from pyspark.serializers import MarshalSerializer
>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer())
>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
>>> sc.stop()

By default, PySpark serializes objects in batches; the batch size can be
controlled through SparkContext's C{batchSize} parameter
(the default size is 1024 objects):

>>> sc = SparkContext('local', 'test', batchSize=2)
>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)

Behind the scenes, this creates a JavaRDD with four partitions, each of
which contains two batches of two objects:

>>> rdd.glom().collect()
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
>>> rdd._jrdd.count()
8L
>>> sc.stop()

A batch size of -1 uses an unlimited batch size, and a size of 1 disables
batching:

>>> sc = SparkContext('local', 'test', batchSize=1)
>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
>>> rdd.glom().collect()
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
>>> rdd._jrdd.count()
16L
"""

import cPickle
from itertools import chain, izip, product
import marshal
import struct
import sys
import types
import collections
import zlib

from pyspark import cloudpickle


__all__ = ["PickleSerializer", "MarshalSerializer"]

class SpecialLengths(object):
    END_OF_DATA_SECTION = -1
    PYTHON_EXCEPTION_THROWN = -2
    TIMING_DATA = -3

class Serializer(object):

    def dump_stream(self, iterator, stream):
        """
        Serialize an iterator of objects to the output stream.
        """
        raise NotImplementedError

    def load_stream(self, stream):
        """
        Return an iterator of deserialized objects from the input stream.
        """
        raise NotImplementedError

    def _load_stream_without_unbatching(self, stream):
        return self.load_stream(stream)

    # Note: our notion of "equality" is that output generated by
    # equal serializers can be deserialized using the same serializer.

    # This default implementation handles the simple cases;
    # subclasses should override __eq__ as appropriate.

    def __eq__(self, other):
        return isinstance(other, self.__class__)

    def __ne__(self, other):
        return not self.__eq__(other)

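# Under the default __eq__ above, equal serializers are guaranteed to be able
# to read each other's output; a minimal sketch with the concrete serializers
# defined later in this module:
#
# >>> PickleSerializer() == PickleSerializer()
# True
# >>> PickleSerializer() == MarshalSerializer()
# False
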
class FramedSerializer(Serializer):

    """
    Serializer that writes objects as a stream of (length, data) pairs,
    where C{length} is a 32-bit integer and data is C{length} bytes.
    """

    def __init__(self):
        # On Python 2.6, we can't write bytearrays to streams, so we need to convert them
        # to strings first. Check if the version number is that old.
        self._only_write_strings = sys.version_info[0:2] <= (2, 6)

    def dump_stream(self, iterator, stream):
        for obj in iterator:
            self._write_with_length(obj, stream)

    def load_stream(self, stream):
        while True:
            try:
                yield self._read_with_length(stream)
            except EOFError:
                return

    def _write_with_length(self, obj, stream):
        serialized = self.dumps(obj)
        write_int(len(serialized), stream)
        if self._only_write_strings:
            stream.write(str(serialized))
        else:
            stream.write(serialized)

    def _read_with_length(self, stream):
        length = read_int(stream)
        obj = stream.read(length)
        if obj == "":
            raise EOFError
        return self.loads(obj)

    def dumps(self, obj):
        """
        Serialize an object into a byte array.
        When batching is used, this will be called with an array of objects.
        """
        raise NotImplementedError

    def loads(self, obj):
        """
        Deserialize an object from a byte array.
        """
        raise NotImplementedError

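# Each frame written by FramedSerializer is a big-endian 4-byte length
# followed by that many payload bytes; a minimal sketch of the layout built
# by hand with the struct module:
#
# >>> import struct
# >>> payload = PickleSerializer().dumps([1, 2, 3])
# >>> frame = struct.pack("!i", len(payload)) + payload
# >>> struct.unpack("!i", frame[:4])[0] == len(payload)
# True
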
class BatchedSerializer(Serializer):

    """
    Serializes a stream of objects in batches by calling its wrapped
    Serializer with streams of objects.
    """

    UNLIMITED_BATCH_SIZE = -1

    def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
        self.serializer = serializer
        self.batchSize = batchSize

    def _batched(self, iterator):
        if self.batchSize == self.UNLIMITED_BATCH_SIZE:
            yield list(iterator)
        else:
            items = []
            count = 0
            for item in iterator:
                items.append(item)
                count += 1
                if count == self.batchSize:
                    yield items
                    items = []
                    count = 0
            if items:
                yield items

    def dump_stream(self, iterator, stream):
        self.serializer.dump_stream(self._batched(iterator), stream)

    def load_stream(self, stream):
        return chain.from_iterable(self._load_stream_without_unbatching(stream))

    def _load_stream_without_unbatching(self, stream):
        return self.serializer.load_stream(stream)

    def __eq__(self, other):
        return (isinstance(other, BatchedSerializer) and
                other.serializer == self.serializer)

    def __str__(self):
        return "BatchedSerializer<%s>" % str(self.serializer)

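# A quick round trip through BatchedSerializer, assuming an in-memory
# StringIO buffer stands in for the worker socket (sketch only):
#
# >>> from StringIO import StringIO
# >>> ser = BatchedSerializer(PickleSerializer(), 2)
# >>> buf = StringIO()
# >>> ser.dump_stream(range(5), buf)
# >>> buf.seek(0)
# >>> list(ser.load_stream(buf))
# [0, 1, 2, 3, 4]
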
class CartesianDeserializer(FramedSerializer):

    """
    Deserializes the JavaRDD cartesian() of two PythonRDDs.
    """

    def __init__(self, key_ser, val_ser):
        self.key_ser = key_ser
        self.val_ser = val_ser

    def prepare_keys_values(self, stream):
        key_stream = self.key_ser._load_stream_without_unbatching(stream)
        val_stream = self.val_ser._load_stream_without_unbatching(stream)
        key_is_batched = isinstance(self.key_ser, BatchedSerializer)
        val_is_batched = isinstance(self.val_ser, BatchedSerializer)
        for (keys, vals) in izip(key_stream, val_stream):
            keys = keys if key_is_batched else [keys]
            vals = vals if val_is_batched else [vals]
            yield (keys, vals)

    def load_stream(self, stream):
        for (keys, vals) in self.prepare_keys_values(stream):
            for pair in product(keys, vals):
                yield pair

    def __eq__(self, other):
        return (isinstance(other, CartesianDeserializer) and
                self.key_ser == other.key_ser and self.val_ser == other.val_ser)

    def __str__(self):
        return "CartesianDeserializer<%s, %s>" % \
            (str(self.key_ser), str(self.val_ser))

class PairDeserializer(CartesianDeserializer):

    """
    Deserializes the JavaRDD zip() of two PythonRDDs.
    """

    def __init__(self, key_ser, val_ser):
        self.key_ser = key_ser
        self.val_ser = val_ser

    def load_stream(self, stream):
        for (keys, vals) in self.prepare_keys_values(stream):
            if len(keys) != len(vals):
                raise ValueError("Can not deserialize RDD with different number of items"
                                 " in pair: (%d, %d)" % (len(keys), len(vals)))
            for pair in izip(keys, vals):
                yield pair

    def __eq__(self, other):
        return (isinstance(other, PairDeserializer) and
                self.key_ser == other.key_ser and self.val_ser == other.val_ser)

    def __str__(self):
        return "PairDeserializer<%s, %s>" % (str(self.key_ser), str(self.val_ser))

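# The two deserializers above differ only in how they pair up each batch of
# keys with each batch of values: cartesian() pairs every key with every
# value, while zip() pairs them positionally, mirroring itertools (sketch):
#
# >>> from itertools import product, izip
# >>> list(product([1, 2], ['a', 'b']))
# [(1, 'a'), (1, 'b'), (2, 'a'), (2, 'b')]
# >>> list(izip([1, 2], ['a', 'b']))
# [(1, 'a'), (2, 'b')]
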
class NoOpSerializer(FramedSerializer):

    def loads(self, obj):
        return obj

    def dumps(self, obj):
        return obj


# Hook namedtuple, make it picklable

__cls = {}

def _restore(name, fields, value):
    """ Restore a namedtuple object """
    k = (name, fields)
    cls = __cls.get(k)
    if cls is None:
        cls = collections.namedtuple(name, fields)
        __cls[k] = cls
    return cls(*value)


def _hack_namedtuple(cls):
    """ Make a class generated by namedtuple picklable """
    name = cls.__name__
    fields = cls._fields

    def __reduce__(self):
        return (_restore, (name, fields, tuple(self)))
    cls.__reduce__ = __reduce__
    return cls


def _hijack_namedtuple():
    """ Hack namedtuple() to make it picklable """
    # only hijack once
    if hasattr(collections.namedtuple, "__hijack"):
        return

    global _old_namedtuple  # or it will be put in the closure

    def _copy_func(f):
        return types.FunctionType(f.func_code, f.func_globals, f.func_name,
                                  f.func_defaults, f.func_closure)

    _old_namedtuple = _copy_func(collections.namedtuple)

    def namedtuple(*args, **kwargs):
        cls = _old_namedtuple(*args, **kwargs)
        return _hack_namedtuple(cls)

    # replace namedtuple with the new one
    collections.namedtuple.func_globals["_old_namedtuple"] = _old_namedtuple
    collections.namedtuple.func_globals["_hack_namedtuple"] = _hack_namedtuple
    collections.namedtuple.func_code = namedtuple.func_code
    collections.namedtuple.__hijack = 1

    # hack the classes already generated by namedtuple;
    # those created in other modules can be pickled as normal,
    # so only hack those in the __main__ module
    for n, o in sys.modules["__main__"].__dict__.iteritems():
        if (type(o) is type and o.__base__ is tuple
                and hasattr(o, "_fields")
                and "__reduce__" not in o.__dict__):
            _hack_namedtuple(o)  # hack in place


_hijack_namedtuple()

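# With the hook above installed, namedtuple classes defined in __main__
# (e.g. in the PySpark shell) survive a pickle round trip; a minimal sketch:
#
# >>> from collections import namedtuple
# >>> import cPickle
# >>> Point = namedtuple("Point", "x y")
# >>> cPickle.loads(cPickle.dumps(Point(1, 2)))
# Point(x=1, y=2)
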
class PickleSerializer(FramedSerializer):

    """
    Serializes objects using Python's cPickle serializer:

    http://docs.python.org/2/library/pickle.html

    This serializer supports nearly any Python object, but may
    not be as fast as more specialized serializers.
    """

    def dumps(self, obj):
        return cPickle.dumps(obj, 2)

    loads = cPickle.loads

class CloudPickleSerializer(PickleSerializer):

    def dumps(self, obj):
        return cloudpickle.dumps(obj, 2)

class MarshalSerializer(FramedSerializer):

    """
    Serializes objects using Python's Marshal serializer:

    http://docs.python.org/2/library/marshal.html

    This serializer is faster than PickleSerializer but supports fewer datatypes.
    """

    dumps = marshal.dumps
    loads = marshal.loads

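# marshal handles only built-in types (numbers, strings, lists, dicts, ...),
# which is why MarshalSerializer is faster but less general than
# PickleSerializer; a round trip on plain data works fine (sketch), while
# user-defined objects and functions are rejected with a ValueError:
#
# >>> MarshalSerializer().loads(MarshalSerializer().dumps({'a': [1, 2.0]}))
# {'a': [1, 2.0]}
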
class AutoSerializer(FramedSerializer):

    """
    Choose marshal or cPickle as the serialization protocol automatically
    """

    def __init__(self):
        FramedSerializer.__init__(self)
        self._type = None

    def dumps(self, obj):
        if self._type is not None:
            return 'P' + cPickle.dumps(obj, -1)
        try:
            return 'M' + marshal.dumps(obj)
        except Exception:
            self._type = 'P'
            return 'P' + cPickle.dumps(obj, -1)

    def loads(self, obj):
        _type = obj[0]
        if _type == 'M':
            return marshal.loads(obj[1:])
        elif _type == 'P':
            return cPickle.loads(obj[1:])
        else:
            raise ValueError("invalid serialization type: %s" % _type)

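# AutoSerializer prefixes every record with a one-byte tag ('M' for marshal
# data, 'P' for cPickle data) and falls back to cPickle permanently once
# marshal fails; a minimal sketch of the tagging:
#
# >>> ser = AutoSerializer()
# >>> data = ser.dumps([1, 2, 3])
# >>> data[0]
# 'M'
# >>> ser.loads(data)
# [1, 2, 3]
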
class CompressedSerializer(FramedSerializer):
    """
    Compresses the serialized data of another serializer with zlib
    """

    def __init__(self, serializer):
        FramedSerializer.__init__(self)
        self.serializer = serializer

    def dumps(self, obj):
        return zlib.compress(self.serializer.dumps(obj), 1)

    def loads(self, obj):
        return self.serializer.loads(zlib.decompress(obj))

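# CompressedSerializer is transparent to its callers: it only wraps the
# wrapped serializer's bytes in zlib (compression level 1), so a round trip
# returns the original object (sketch):
#
# >>> ser = CompressedSerializer(PickleSerializer())
# >>> ser.loads(ser.dumps(range(10)))
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
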
class UTF8Deserializer(Serializer):

    """
    Deserializes streams written by String.getBytes.
    """

    def loads(self, stream):
        length = read_int(stream)
        return stream.read(length).decode('utf8')

    def load_stream(self, stream):
        while True:
            try:
                yield self.loads(stream)
            except struct.error:
                return
            except EOFError:
                return

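# The input consumed by UTF8Deserializer is the usual 4-byte big-endian
# length prefix followed by UTF-8 encoded bytes; a minimal sketch using an
# in-memory buffer:
#
# >>> import struct
# >>> from StringIO import StringIO
# >>> s = u'h\xe9llo'.encode('utf8')
# >>> buf = StringIO(struct.pack("!i", len(s)) + s)
# >>> UTF8Deserializer().loads(buf)
# u'h\xe9llo'
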
def read_long(stream):
    length = stream.read(8)
    if length == "":
        raise EOFError
    return struct.unpack("!q", length)[0]


def write_long(value, stream):
    stream.write(struct.pack("!q", value))


def pack_long(value):
    return struct.pack("!q", value)


def read_int(stream):
    length = stream.read(4)
    if length == "":
        raise EOFError
    return struct.unpack("!i", length)[0]


def write_int(value, stream):
    stream.write(struct.pack("!i", value))


def write_with_length(obj, stream):
    write_int(len(obj), stream)
    stream.write(obj)
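# All of the helpers above use network (big-endian) byte order (the "!" in
# the struct format), matching the ints and longs written on the JVM side of
# PySpark; a quick round trip through an in-memory buffer (sketch):
#
# >>> from StringIO import StringIO
# >>> buf = StringIO()
# >>> write_int(2 ** 20, buf)
# >>> buf.seek(0)
# >>> read_int(buf)
# 1048576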