Package pyspark :: Module rdd
[frames] | no frames]

Source Code for Module pyspark.rdd

  1  from base64 import standard_b64encode as b64enc 
  2  import copy 
  3  from collections import defaultdict 
  4  from itertools import chain, ifilter, imap, product 
  5  import operator 
  6  import os 
  7  import shlex 
  8  from subprocess import Popen, PIPE 
  9  from tempfile import NamedTemporaryFile 
 10  from threading import Thread 
 11   
 12  from pyspark import cloudpickle 
 13  from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \ 
 14      read_from_pickle_file 
 15  from pyspark.join import python_join, python_left_outer_join, \ 
 16      python_right_outer_join, python_cogroup 
 17   
 18  from py4j.java_collections import ListConverter, MapConverter 
 19   
 20   
 21  __all__ = ["RDD"] 
class RDD(object):
    """
    A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
    Represents an immutable, partitioned collection of elements that can be
    operated on in parallel.
    """

    def __init__(self, jrdd, ctx):
        """
        Wrap the Java RDD handle C{jrdd} created through the context C{ctx}.
        """
        self._jrdd = jrdd
        self.is_cached = False
        self.is_checkpointed = False
        self.ctx = ctx
        # Set by partitionBy() to keep the partitioning function alive.
        self._partitionFunc = None

    @property
    def context(self):
        """
        The L{SparkContext} that this RDD was created on.
        """
        return self.ctx

    def cache(self):
        """
        Persist this RDD with the default storage level (C{MEMORY_ONLY}).
        """
        self.is_cached = True
        self._jrdd.cache()
        return self

    def checkpoint(self):
        """
        Mark this RDD for checkpointing. It will be saved to a file inside the
        checkpoint directory set with L{SparkContext.setCheckpointDir()} and
        all references to its parent RDDs will be removed. This function must
        be called before any job has been executed on this RDD. It is strongly
        recommended that this RDD is persisted in memory, otherwise saving it
        on a file will require recomputation.
        """
        self.is_checkpointed = True
        self._jrdd.rdd().checkpoint()

    def isCheckpointed(self):
        """
        Return whether this RDD has been checkpointed or not.
        """
        return self._jrdd.rdd().isCheckpointed()

    def getCheckpointFile(self):
        """
        Get the name of the file to which this RDD was checkpointed, or
        C{None} if it has not been checkpointed.
        """
        checkpointFile = self._jrdd.rdd().getCheckpointFile()
        if checkpointFile.isDefined():
            return checkpointFile.get()
        return None

    # TODO persist(self, storageLevel)

    def map(self, f, preservesPartitioning=False):
        """
        Return a new RDD by applying a function to each element of this RDD.

        >>> rdd = sc.parallelize(["b", "a", "c"])
        >>> sorted(rdd.map(lambda x: (x, 1)).collect())
        [('a', 1), ('b', 1), ('c', 1)]
        """
        def func(split, iterator):
            return imap(f, iterator)
        return PipelinedRDD(self, func, preservesPartitioning)

    def flatMap(self, f, preservesPartitioning=False):
        """
        Return a new RDD by first applying a function to all elements of this
        RDD, and then flattening the results.

        >>> rdd = sc.parallelize([2, 3, 4])
        >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect())
        [1, 1, 1, 2, 2, 3]
        >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect())
        [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
        """
        def func(s, iterator):
            return chain.from_iterable(imap(f, iterator))
        return self.mapPartitionsWithSplit(func, preservesPartitioning)

    def mapPartitions(self, f, preservesPartitioning=False):
        """
        Return a new RDD by applying a function to each partition of this RDD.

        >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
        >>> def f(iterator): yield sum(iterator)
        >>> rdd.mapPartitions(f).collect()
        [3, 7]
        """
        def func(s, iterator):
            return f(iterator)
        # Forward the flag; it used to be dropped, so the result never
        # reported preserved partitioning even when the caller asked for it.
        return self.mapPartitionsWithSplit(func, preservesPartitioning)

    def mapPartitionsWithSplit(self, f, preservesPartitioning=False):
        """
        Return a new RDD by applying a function to each partition of this RDD,
        while tracking the index of the original partition.

        >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
        >>> def f(splitIndex, iterator): yield splitIndex
        >>> rdd.mapPartitionsWithSplit(f).sum()
        6
        """
        return PipelinedRDD(self, f, preservesPartitioning)

    def filter(self, f):
        """
        Return a new RDD containing only the elements that satisfy a predicate.

        >>> rdd = sc.parallelize([1, 2, 3, 4, 5])
        >>> rdd.filter(lambda x: x % 2 == 0).collect()
        [2, 4]
        """
        def func(iterator):
            return ifilter(f, iterator)
        return self.mapPartitions(func)

    def distinct(self):
        """
        Return a new RDD containing the distinct elements in this RDD.

        >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect())
        [1, 2, 3]
        """
        # Key every element by itself, collapse duplicate keys, then drop
        # the placeholder value again.
        return self.map(lambda x: (x, "")) \
                   .reduceByKey(lambda x, _: x) \
                   .map(lambda pair: pair[0])

    # TODO: sampling needs to be re-implemented due to Batch
    #def sample(self, withReplacement, fraction, seed):
    #    jrdd = self._jrdd.sample(withReplacement, fraction, seed)
    #    return RDD(jrdd, self.ctx)

    #def takeSample(self, withReplacement, num, seed):
    #    vals = self._jrdd.takeSample(withReplacement, num, seed)
    #    return [load_pickle(bytes(x)) for x in vals]

    def union(self, other):
        """
        Return the union of this RDD and another one.

        >>> rdd = sc.parallelize([1, 1, 2, 3])
        >>> rdd.union(rdd).collect()
        [1, 1, 2, 3, 1, 1, 2, 3]
        """
        return RDD(self._jrdd.union(other._jrdd), self.ctx)

    def __add__(self, other):
        """
        Return the union of this RDD and another one.

        Raises C{TypeError} if C{other} is not an RDD.

        >>> rdd = sc.parallelize([1, 1, 2, 3])
        >>> (rdd + rdd).collect()
        [1, 1, 2, 3, 1, 1, 2, 3]
        """
        if not isinstance(other, RDD):
            raise TypeError
        return self.union(other)

    # TODO: sort

    def glom(self):
        """
        Return an RDD created by coalescing all elements within each partition
        into a list.

        >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
        >>> sorted(rdd.glom().collect())
        [[1, 2], [3, 4]]
        """
        def func(iterator):
            yield list(iterator)
        return self.mapPartitions(func)

    def cartesian(self, other):
        """
        Return the Cartesian product of this RDD and another one, that is, the
        RDD of all pairs of elements C{(a, b)} where C{a} is in C{self} and
        C{b} is in C{other}.

        >>> rdd = sc.parallelize([1, 2])
        >>> sorted(rdd.cartesian(rdd).collect())
        [(1, 1), (1, 2), (2, 1), (2, 2)]
        """
        # The Java cartesian method pairs up whole batches rather than the
        # individual elements, so its result is unpacked in Python below.
        java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)

        def unpack_batches(pair):
            (x, y) = pair
            x_is_batch = isinstance(x, Batch)
            y_is_batch = isinstance(y, Batch)
            if x_is_batch or y_is_batch:
                xs = x.items if x_is_batch else [x]
                ys = y.items if y_is_batch else [y]
                for combo in product(xs, ys):
                    yield combo
            else:
                yield pair
        return java_cartesian.flatMap(unpack_batches)

    def groupBy(self, f, numPartitions=None):
        """
        Return an RDD of grouped items.

        >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8])
        >>> result = rdd.groupBy(lambda x: x % 2).collect()
        >>> sorted([(x, sorted(y)) for (x, y) in result])
        [(0, [2, 8]), (1, [1, 1, 3, 5])]
        """
        return self.map(lambda x: (f(x), x)).groupByKey(numPartitions)

    def pipe(self, command, env=None):
        """
        Return an RDD created by piping elements to a forked external process.

        C{env} is the environment mapping handed to the child process; when
        omitted, the child is started with an empty environment.

        >>> sc.parallelize([1, 2, 3]).pipe('cat').collect()
        ['1', '2', '3']
        """
        # Avoid a shared mutable default while keeping the historical
        # behaviour of an empty environment when none is supplied.
        if env is None:
            env = {}

        def func(iterator):
            pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE)

            def pipe_objs(out):
                for obj in iterator:
                    out.write(str(obj).rstrip('\n') + '\n')
                out.close()
            # Feed stdin from a separate thread while this one reads stdout.
            Thread(target=pipe_objs, args=[pipe.stdin]).start()
            return (x.rstrip('\n') for x in pipe.stdout)
        return self.mapPartitions(func)

    def foreach(self, f):
        """
        Applies a function to all elements of this RDD.

        >>> def f(x): print x
        >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
        """
        def processPartition(iterator):
            for x in iterator:
                f(x)
            # Emit a single placeholder per partition so that the return
            # values of f are not shipped back to the driver.
            yield None
        self.mapPartitions(processPartition).collect()  # Force evaluation

    def collect(self):
        """
        Return a list that contains all of the elements in this RDD.
        """
        picklesInJava = self._jrdd.collect().iterator()
        return list(self._collect_iterator_through_file(picklesInJava))

    def _collect_iterator_through_file(self, iterator):
        """
        Yield the deserialized items of a Java-side C{iterator} of pickles.

        The temporary file used for the transfer is removed once the
        generator finishes, fails, or is closed/garbage-collected early.
        """
        # Transferring lots of data through Py4J can be slow because
        # socket.readline() is inefficient.  Instead, we'll dump the data to a
        # file and read it back.
        tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
        tempFile.close()
        try:
            self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
            # Read the data into Python and deserialize it:
            with open(tempFile.name, 'rb') as pickleFile:
                for item in read_from_pickle_file(pickleFile):
                    yield item
        finally:
            os.unlink(tempFile.name)

    def reduce(self, f):
        """
        Reduces the elements of this RDD using the specified commutative and
        associative binary operator.

        Raises C{TypeError} (from the builtin C{reduce}) if the RDD is empty.

        >>> from operator import add
        >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add)
        15
        >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add)
        10
        """
        def func(iterator):
            iterator = iter(iterator)
            # Seed the accumulator with the first element instead of using
            # None as a "no value yet" marker, which mishandled None elements.
            try:
                acc = next(iterator)
            except StopIteration:
                return  # empty partition contributes nothing
            for obj in iterator:
                acc = f(acc, obj)
            yield acc
        vals = self.mapPartitions(func).collect()
        return reduce(f, vals)

    def fold(self, zeroValue, op):
        """
        Aggregate the elements of each partition, and then the results for all
        the partitions, using a given associative function and a neutral "zero
        value."

        The function C{op(t1, t2)} is allowed to modify C{t1} and return it
        as its result value to avoid object allocation; however, it should not
        modify C{t2}.

        >>> from operator import add
        >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
        15
        """
        def func(iterator):
            acc = zeroValue
            for obj in iterator:
                # The accumulator is t1 and the element is t2, matching the
                # contract above and the driver-side reduce() below.
                acc = op(acc, obj)
            yield acc
        vals = self.mapPartitions(func).collect()
        return reduce(op, vals, zeroValue)

    # TODO: aggregate

    def sum(self):
        """
        Add up the elements in this RDD.

        >>> sc.parallelize([1.0, 2.0, 3.0]).sum()
        6.0
        """
        return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add)

    def count(self):
        """
        Return the number of elements in this RDD.

        >>> sc.parallelize([2, 3, 4]).count()
        3
        """
        return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()

    def countByValue(self):
        """
        Return the count of each unique value in this RDD as a dictionary of
        (value, count) pairs.

        >>> sorted(sc.parallelize([1, 2, 1, 2, 2], 2).countByValue().items())
        [(1, 2), (2, 3)]
        """
        def countPartition(iterator):
            counts = defaultdict(int)
            for obj in iterator:
                counts[obj] += 1
            yield counts

        def mergeMaps(m1, m2):
            for (k, v) in m2.iteritems():
                m1[k] += v
            return m1
        return self.mapPartitions(countPartition).reduce(mergeMaps)

    def take(self, num):
        """
        Take the first num elements of the RDD.

        This currently scans the partitions *one by one*, so it will be slow if
        a lot of partitions are required. In that case, use L{collect} to get
        the whole RDD instead.

        >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2)
        [2, 3]
        >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
        [2, 3, 4, 5, 6]
        """
        items = []
        for partition in range(self._jrdd.splits().size()):
            iterator = self.ctx._takePartition(self._jrdd.rdd(), partition)
            # Each item in the iterator is a string, a Python object, or a
            # batch of Python objects.  Regardless, it is sufficient to take
            # `num` of these objects in order to collect `num` Python objects:
            iterator = iterator.take(num)
            items.extend(self._collect_iterator_through_file(iterator))
            if len(items) >= num:
                break
        return items[:num]

    def first(self):
        """
        Return the first element in this RDD.

        Raises C{IndexError} if the RDD is empty.

        >>> sc.parallelize([2, 3, 4]).first()
        2
        """
        return self.take(1)[0]

    def saveAsTextFile(self, path):
        """
        Save this RDD as a text file, using string representations of elements.

        >>> tempFile = NamedTemporaryFile(delete=True)
        >>> tempFile.close()
        >>> sc.parallelize(range(10)).saveAsTextFile(tempFile.name)
        >>> from fileinput import input
        >>> from glob import glob
        >>> ''.join(sorted(input(glob(tempFile.name + "/part-0000*"))))
        '0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n'
        """
        def func(split, iterator):
            for x in iterator:
                if isinstance(x, unicode):
                    # str() on non-ASCII unicode raises UnicodeEncodeError,
                    # so encode unicode text directly.
                    yield x.encode("utf-8")
                else:
                    yield str(x).encode("utf-8")
        keyed = PipelinedRDD(self, func)
        keyed._bypass_serializer = True
        keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path)

    # Pair functions

    def collectAsMap(self):
        """
        Return the key-value pairs in this RDD to the master as a dictionary.

        >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap()
        >>> m[1]
        2
        >>> m[3]
        4
        """
        return dict(self.collect())

    def reduceByKey(self, func, numPartitions=None):
        """
        Merge the values for each key using an associative reduce function.

        This will also perform the merging locally on each mapper before
        sending results to a reducer, similarly to a "combiner" in MapReduce.

        Output will be hash-partitioned with C{numPartitions} partitions, or
        the default parallelism level if C{numPartitions} is not specified.

        >>> from operator import add
        >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> sorted(rdd.reduceByKey(add).collect())
        [('a', 2), ('b', 1)]
        """
        return self.combineByKey(lambda x: x, func, func, numPartitions)

    def reduceByKeyLocally(self, func):
        """
        Merge the values for each key using an associative reduce function, but
        return the results immediately to the master as a dictionary.

        This will also perform the merging locally on each mapper before
        sending results to a reducer, similarly to a "combiner" in MapReduce.

        >>> from operator import add
        >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> sorted(rdd.reduceByKeyLocally(add).items())
        [('a', 2), ('b', 1)]
        """
        def reducePartition(iterator):
            m = {}
            for (k, v) in iterator:
                m[k] = v if k not in m else func(m[k], v)
            yield m

        def mergeMaps(m1, m2):
            for (k, v) in m2.iteritems():
                m1[k] = v if k not in m1 else func(m1[k], v)
            return m1
        return self.mapPartitions(reducePartition).reduce(mergeMaps)

    def countByKey(self):
        """
        Count the number of elements for each key, and return the result to the
        master as a dictionary.

        >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> sorted(rdd.countByKey().items())
        [('a', 2), ('b', 1)]
        """
        return self.map(lambda x: x[0]).countByValue()

    def join(self, other, numPartitions=None):
        """
        Return an RDD containing all pairs of elements with matching keys in
        C{self} and C{other}.

        Each pair of elements will be returned as a (k, (v1, v2)) tuple, where
        (k, v1) is in C{self} and (k, v2) is in C{other}.

        Performs a hash join across the cluster.

        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2), ("a", 3)])
        >>> sorted(x.join(y).collect())
        [('a', (1, 2)), ('a', (1, 3))]
        """
        return python_join(self, other, numPartitions)

    def leftOuterJoin(self, other, numPartitions=None):
        """
        Perform a left outer join of C{self} and C{other}.

        For each element (k, v) in C{self}, the resulting RDD will either
        contain all pairs (k, (v, w)) for w in C{other}, or the pair
        (k, (v, None)) if no elements in other have key k.

        Hash-partitions the resulting RDD into the given number of partitions.

        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2)])
        >>> sorted(x.leftOuterJoin(y).collect())
        [('a', (1, 2)), ('b', (4, None))]
        """
        return python_left_outer_join(self, other, numPartitions)

    def rightOuterJoin(self, other, numPartitions=None):
        """
        Perform a right outer join of C{self} and C{other}.

        For each element (k, w) in C{other}, the resulting RDD will either
        contain all pairs (k, (v, w)) for v in this, or the pair (k, (None, w))
        if no elements in C{self} have key k.

        Hash-partitions the resulting RDD into the given number of partitions.

        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2)])
        >>> sorted(y.rightOuterJoin(x).collect())
        [('a', (2, 1)), ('b', (None, 4))]
        """
        return python_right_outer_join(self, other, numPartitions)

    # TODO: add option to control map-side combining
    def partitionBy(self, numPartitions, partitionFunc=hash):
        """
        Return a copy of the RDD partitioned using the specified partitioner.

        If C{numPartitions} is C{None}, the context's default parallelism
        is used.

        >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x))
        >>> sets = pairs.partitionBy(2).glom().collect()
        >>> set(sets[0]).intersection(set(sets[1]))
        set([])
        """
        if numPartitions is None:
            numPartitions = self.ctx.defaultParallelism
        # Transferring O(n) objects to Java is too expensive.  Instead, we'll
        # form the hash buckets in Python, transferring O(numPartitions) objects
        # to Java.  Each object is a (splitNumber, [objects]) pair.

        def add_shuffle_key(split, iterator):
            buckets = defaultdict(list)
            for (k, v) in iterator:
                buckets[partitionFunc(k) % numPartitions].append((k, v))
            # Emit the target split number followed by its pickled batch.
            for (bucket, items) in buckets.iteritems():
                yield str(bucket)
                yield dump_pickle(Batch(items))
        keyed = PipelinedRDD(self, add_shuffle_key)
        keyed._bypass_serializer = True
        pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
        partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
                                                      id(partitionFunc))
        jrdd = pairRDD.partitionBy(partitioner).values()
        rdd = RDD(jrdd, self.ctx)
        # This is required so that id(partitionFunc) remains unique, even if
        # partitionFunc is a lambda:
        rdd._partitionFunc = partitionFunc
        return rdd

    # TODO: add control over map-side aggregation
    def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
                     numPartitions=None):
        """
        Generic function to combine the elements for each key using a custom
        set of aggregation functions.

        Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined
        type" C.  Note that V and C can be different -- for example, one might
        group an RDD of type (Int, Int) into an RDD of type (Int, List[Int]).

        Users provide three functions:

            - C{createCombiner}, which turns a V into a C (e.g., creates
              a one-element list)
            - C{mergeValue}, to merge a V into a C (e.g., adds it to the end of
              a list)
            - C{mergeCombiners}, to combine two C's into a single one.

        In addition, users can control the partitioning of the output RDD.

        >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> def f(x): return x
        >>> def add(a, b): return a + str(b)
        >>> sorted(x.combineByKey(str, add, add).collect())
        [('a', '11'), ('b', '1')]
        """
        if numPartitions is None:
            numPartitions = self.ctx.defaultParallelism

        def combineLocally(iterator):
            combiners = {}
            for (k, v) in iterator:
                if k not in combiners:
                    combiners[k] = createCombiner(v)
                else:
                    combiners[k] = mergeValue(combiners[k], v)
            return combiners.iteritems()
        locally_combined = self.mapPartitions(combineLocally)
        shuffled = locally_combined.partitionBy(numPartitions)

        def _mergeCombiners(iterator):
            combiners = {}
            for (k, v) in iterator:
                if k not in combiners:
                    combiners[k] = v
                else:
                    combiners[k] = mergeCombiners(combiners[k], v)
            return combiners.iteritems()
        return shuffled.mapPartitions(_mergeCombiners)

    # TODO: support variant with custom partitioner
    def groupByKey(self, numPartitions=None):
        """
        Group the values for each key in the RDD into a single sequence.
        Hash-partitions the resulting RDD into numPartitions partitions.

        >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> sorted(x.groupByKey().collect())
        [('a', [1, 1]), ('b', [1])]
        """

        def createCombiner(x):
            return [x]

        def mergeValue(xs, x):
            xs.append(x)
            return xs

        def mergeCombiners(a, b):
            return a + b

        return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
                                 numPartitions)

    # TODO: add tests
    def flatMapValues(self, f):
        """
        Pass each value in the key-value pair RDD through a flatMap function
        without changing the keys; this also retains the original RDD's
        partitioning.
        """
        def flat_map_fn(pair):
            (k, v) = pair
            return ((k, x) for x in f(v))
        return self.flatMap(flat_map_fn, preservesPartitioning=True)

    def mapValues(self, f):
        """
        Pass each value in the key-value pair RDD through a map function
        without changing the keys; this also retains the original RDD's
        partitioning.
        """
        def map_values_fn(pair):
            (k, v) = pair
            return (k, f(v))
        return self.map(map_values_fn, preservesPartitioning=True)

    # TODO: support varargs cogroup of several RDDs.
    def groupWith(self, other):
        """
        Alias for cogroup.
        """
        return self.cogroup(other)

    # TODO: add variant with custom partitioner
    def cogroup(self, other, numPartitions=None):
        """
        For each key k in C{self} or C{other}, return a resulting RDD that
        contains a tuple with the list of values for that key in C{self} as well
        as C{other}.

        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2)])
        >>> sorted(x.cogroup(y).collect())
        [('a', ([1], [2])), ('b', ([4], []))]
        """
        return python_cogroup(self, other, numPartitions)
# TODO: `lookup` is disabled because we can't make direct comparisons based
# on the key; we need to compare the hash of the key to the hash of the
# keys in the pairs.  This could be an expensive operation, since those
# hashes aren't retained.


class PipelinedRDD(RDD):
    """
    An RDD defined by a Python function applied to the partitions of a
    previous RDD.  Consecutive transformations are composed into a single
    function as long as the previous stage is neither cached nor
    checkpointed.

    Pipelined maps:
    >>> rdd = sc.parallelize([1, 2, 3, 4])
    >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect()
    [4, 8, 12, 16]
    >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect()
    [4, 8, 12, 16]

    Pipelined reduces:
    >>> from operator import add
    >>> rdd.map(lambda x: 2 * x).reduce(add)
    20
    >>> rdd.flatMap(lambda x: [x, x]).reduce(add)
    20
    """

    def __init__(self, prev, func, preservesPartitioning=False):
        """
        Create an RDD that applies C{func(split, iterator)} to each partition
        of C{prev}.
        """
        if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
            # Fuse with the previous stage: run its function first and feed
            # the result into ours, on top of the same underlying Java RDD.
            prev_func = prev.func

            def pipeline_func(split, iterator):
                return func(split, prev_func(split, iterator))
            self.func = pipeline_func
            self.preservesPartitioning = \
                prev.preservesPartitioning and preservesPartitioning
            self._prev_jrdd = prev._prev_jrdd
        else:
            self.func = func
            self.preservesPartitioning = preservesPartitioning
            self._prev_jrdd = prev._jrdd
        self.is_cached = False
        self.is_checkpointed = False
        self.ctx = prev.ctx
        self.prev = prev
        # RDD.__init__ is not called here, so mirror the attribute it sets
        # to keep every RDD instance uniform.
        self._partitionFunc = None
        self._jrdd_val = None
        self._bypass_serializer = False

    @property
    def _jrdd(self):
        """
        The Java RDD backing this RDD, built on first access and then cached.
        """
        if self._jrdd_val is not None:
            return self._jrdd_val
        func = self.func
        if not self._bypass_serializer and self.ctx.batchSize != 1:
            # Bind plain locals so the closure below references them rather
            # than self -- presumably to keep this RDD object out of the
            # pickled function; confirm before simplifying.
            oldfunc = self.func
            batchSize = self.ctx.batchSize

            def batched_func(split, iterator):
                return batched(oldfunc(split, iterator), batchSize)
            func = batched_func
        cmds = [func, self._bypass_serializer]
        pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
        broadcast_vars = ListConverter().convert(
            [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
            self.ctx._gateway._gateway_client)
        self.ctx._pickled_broadcast_vars.clear()
        class_manifest = self._prev_jrdd.classManifest()
        env = copy.copy(self.ctx.environment)
        env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "")
        env = MapConverter().convert(env, self.ctx._gateway._gateway_client)
        python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
            pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
            broadcast_vars, self.ctx._javaAccumulator, class_manifest)
        self._jrdd_val = python_rdd.asJavaRDD()
        return self._jrdd_val

    def _is_pipelinable(self):
        """
        Return whether a following transformation may be fused with this one,
        i.e. this RDD is neither cached nor checkpointed.
        """
        return not (self.is_cached or self.is_checkpointed)
746
def _test():
    """
    Run this module's doctests against a local SparkContext and exit with a
    non-zero status if any of them fail.
    """
    import doctest
    import sys
    from pyspark.context import SparkContext
    globs = globals().copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    try:
        (failure_count, test_count) = doctest.testmod(globs=globs)
    finally:
        # Always shut the context down, even if the doctest run itself raises.
        globs['sc'].stop()
    if failure_count:
        # sys.exit is used because the bare exit() helper is injected by
        # the site module and is not guaranteed to exist.
        sys.exit(-1)


# Run the doctests when this module is executed as a script.
if __name__ == "__main__":
    _test()