from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
from itertools import chain, ifilter, imap, product
import operator
import os
import sys
import shlex
from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
from threading import Thread

from pyspark import cloudpickle
from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
    read_from_pickle_file, pack_long
from pyspark.join import python_join, python_left_outer_join, \
    python_right_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
from pyspark.rddsampler import RDDSampler

from py4j.java_collections import ListConverter, MapConverter


__all__ = ["RDD"]

class RDD(object):
    """
    A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
    Represents an immutable, partitioned collection of elements that can be
    operated on in parallel.
    """

    def __init__(self, jrdd, ctx):
        self._jrdd = jrdd
        self.is_cached = False
        self.is_checkpointed = False
        self.ctx = ctx
        self._partitionFunc = None

    @property
    def context(self):
        """
        The L{SparkContext} that this RDD was created on.
        """
        return self.ctx

    def cache(self):
        """
        Persist this RDD with the default storage level (C{MEMORY_ONLY}).
        """
        self.is_cached = True
        self._jrdd.cache()
        return self

    def persist(self, storageLevel):
        """
        Set this RDD's storage level to persist its values across operations
        after the first time it is computed. This can only be used to assign
        a new storage level if the RDD does not have a storage level set yet.
        """
        self.is_cached = True
        javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
        self._jrdd.persist(javaStorageLevel)
        return self

    def unpersist(self):
        """
        Mark the RDD as non-persistent, and remove all blocks for it from
        memory and disk.
        """
        self.is_cached = False
        self._jrdd.unpersist()
        return self

    def checkpoint(self):
        """
        Mark this RDD for checkpointing. It will be saved to a file inside the
        checkpoint directory set with L{SparkContext.setCheckpointDir()} and
        all references to its parent RDDs will be removed. This function must
        be called before any job has been executed on this RDD. It is strongly
        recommended that this RDD is persisted in memory, otherwise saving it
        to a file will require recomputation.
        """
        self.is_checkpointed = True
        self._jrdd.rdd().checkpoint()

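    # Illustrative (untested) sketch of checkpoint() usage, assuming a
    # SparkContext `sc` and a hypothetical checkpoint directory:
    #
    #   >>> sc.setCheckpointDir("/tmp/spark-checkpoints")  # hypothetical path
    #   >>> rdd = sc.parallelize(range(100)).cache()
    #   >>> rdd.checkpoint()
    #   >>> rdd.count()  # the first job computes and checkpoints the RDD
    #   100
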
    def isCheckpointed(self):
        """
        Return whether this RDD has been checkpointed or not.
        """
        return self._jrdd.rdd().isCheckpointed()

    def getCheckpointFile(self):
        """
        Gets the name of the file to which this RDD was checkpointed.
        """
        checkpointFile = self._jrdd.rdd().getCheckpointFile()
        if checkpointFile.isDefined():
            return checkpointFile.get()
        else:
            return None

    def map(self, f, preservesPartitioning=False):
        """
        Return a new RDD by applying a function to each element of this RDD.
        """
        def func(split, iterator): return imap(f, iterator)
        return PipelinedRDD(self, func, preservesPartitioning)

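    # map() has no doctest above; a minimal illustrative example, assuming the
    # same `sc` used by the other doctests in this module:
    #
    #   >>> sorted(sc.parallelize(["b", "a", "c"]).map(lambda x: (x, 1)).collect())
    #   [('a', 1), ('b', 1), ('c', 1)]
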
    def flatMap(self, f, preservesPartitioning=False):
        """
        Return a new RDD by first applying a function to all elements of this
        RDD, and then flattening the results.

        >>> rdd = sc.parallelize([2, 3, 4])
        >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect())
        [1, 1, 1, 2, 2, 3]
        >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect())
        [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
        """
        def func(s, iterator): return chain.from_iterable(imap(f, iterator))
        return self.mapPartitionsWithSplit(func, preservesPartitioning)

    def mapPartitions(self, f):
        """
        Return a new RDD by applying a function to each partition of this RDD.

        >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
        >>> def f(iterator): yield sum(iterator)
        >>> rdd.mapPartitions(f).collect()
        [3, 7]
        """
        def func(s, iterator): return f(iterator)
        return self.mapPartitionsWithSplit(func)

    def mapPartitionsWithSplit(self, f, preservesPartitioning=False):
        """
        Return a new RDD by applying a function to each partition of this RDD,
        while tracking the index of the original partition.

        >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
        >>> def f(splitIndex, iterator): yield splitIndex
        >>> rdd.mapPartitionsWithSplit(f).sum()
        6
        """
        return PipelinedRDD(self, f, preservesPartitioning)

    def filter(self, f):
        """
        Return a new RDD containing only the elements that satisfy a predicate.

        >>> rdd = sc.parallelize([1, 2, 3, 4, 5])
        >>> rdd.filter(lambda x: x % 2 == 0).collect()
        [2, 4]
        """
        def func(iterator): return ifilter(f, iterator)
        return self.mapPartitions(func)

    def distinct(self):
        """
        Return a new RDD containing the distinct elements in this RDD.

        >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect())
        [1, 2, 3]
        """
        return self.map(lambda x: (x, None)) \
                   .reduceByKey(lambda x, _: x) \
                   .map(lambda (x, _): x)

    def sample(self, withReplacement, fraction, seed):
        """
        Return a sampled subset of this RDD (relies on numpy and falls back
        on default random generator if numpy is unavailable).

        >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
        [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
        """
        return self.mapPartitionsWithSplit(RDDSampler(withReplacement, fraction, seed).func, True)

    def takeSample(self, withReplacement, num, seed):
        """
        Return a fixed-size sampled subset of this RDD (currently requires numpy).

        >>> sc.parallelize(range(0, 10)).takeSample(True, 10, 1) #doctest: +SKIP
        [4, 2, 1, 8, 2, 7, 0, 4, 1, 4]
        """

        fraction = 0.0
        total = 0
        multiplier = 3.0
        initialCount = self.count()
        maxSelected = 0

        if num < 0:
            raise ValueError("num must be non-negative")

        if initialCount > sys.maxint - 1:
            maxSelected = sys.maxint - 1
        else:
            maxSelected = initialCount

        if num > initialCount and not withReplacement:
            total = maxSelected
            fraction = multiplier * (maxSelected + 1) / initialCount
        else:
            fraction = multiplier * (num + 1) / initialCount
            total = num

        samples = self.sample(withReplacement, fraction, seed).collect()

        # If the first sample didn't turn out large enough, keep trying to
        # take samples; this shouldn't happen often because we oversample
        # with a large multiplier for the initial size.
        while len(samples) < total:
            if seed > sys.maxint - 2:
                seed = -1
            seed += 1
            samples = self.sample(withReplacement, fraction, seed).collect()

        sampler = RDDSampler(withReplacement, fraction, seed + 1)
        sampler.shuffle(samples)
        return samples[0:total]

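    # A worked example of the oversampling above: requesting num=10 elements
    # from an RDD with initialCount=100 gives
    # fraction = 3.0 * (10 + 1) / 100 = 0.33, so roughly 33 candidates are
    # collected, shuffled, and truncated to the first 10.
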
    def union(self, other):
        """
        Return the union of this RDD and another one.

        >>> rdd = sc.parallelize([1, 1, 2, 3])
        >>> rdd.union(rdd).collect()
        [1, 1, 2, 3, 1, 1, 2, 3]
        """
        return RDD(self._jrdd.union(other._jrdd), self.ctx)

    def __add__(self, other):
        """
        Return the union of this RDD and another one.

        >>> rdd = sc.parallelize([1, 1, 2, 3])
        >>> (rdd + rdd).collect()
        [1, 1, 2, 3, 1, 1, 2, 3]
        """
        if not isinstance(other, RDD):
            raise TypeError("can only add another RDD to an RDD")
        return self.union(other)

    def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
        """
        Sorts this RDD, which is assumed to consist of (key, value) pairs.

        >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
        >>> sc.parallelize(tmp).sortByKey(True, 2).collect()
        [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
        >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)]
        >>> tmp2.extend([('whose', 6), ('fleece', 7), ('was', 8), ('white', 9)])
        >>> sc.parallelize(tmp2).sortByKey(True, 3, keyfunc=lambda k: k.lower()).collect()
        [('a', 3), ('fleece', 7), ('had', 2), ('lamb', 5), ('little', 4), ('Mary', 1), ('was', 8), ('white', 9), ('whose', 6)]
        """
        if numPartitions is None:
            numPartitions = self.ctx.defaultParallelism

        bounds = list()

        # First compute the boundary of each part via sampling: we want to
        # partition the key-space into bins such that each bin receives
        # roughly the same number of (key, value) pairs.
        if numPartitions > 1:
            rddSize = self.count()
            maxSampleSize = numPartitions * 20.0
            fraction = min(maxSampleSize / max(rddSize, 1), 1.0)

            samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
            samples = sorted(samples, reverse=(not ascending), key=keyfunc)

            # There are numPartitions parts, but the last one has an implicit
            # upper boundary, so only numPartitions - 1 bounds are needed.
            for i in range(0, numPartitions - 1):
                index = (len(samples) - 1) * (i + 1) / numPartitions
                bounds.append(samples[index])

        def rangePartitionFunc(k):
            p = 0
            while p < len(bounds) and keyfunc(k) > bounds[p]:
                p += 1
            if ascending:
                return p
            else:
                return numPartitions - 1 - p

        def mapFunc(iterator):
            yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))

        return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc)
                    .mapPartitions(mapFunc, preservesPartitioning=True)
                    .flatMap(lambda x: x, preservesPartitioning=True))

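    # A small worked example of rangePartitionFunc above: with numPartitions=3,
    # sampling might yield bounds == ['g', 'p']; an identity keyfunc then sends
    # key 'd' to partition 0 (not > 'g'), 'k' to partition 1, and 'z' to
    # partition 2 (indices are mirrored when ascending is False).
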
    def glom(self):
        """
        Return an RDD created by coalescing all elements within each partition
        into a list.

        >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
        >>> sorted(rdd.glom().collect())
        [[1, 2], [3, 4]]
        """
        def func(iterator): yield list(iterator)
        return self.mapPartitions(func)

    def cartesian(self, other):
        """
        Return the Cartesian product of this RDD and another one, that is, the
        RDD of all pairs of elements C{(a, b)} where C{a} is in C{self} and
        C{b} is in C{other}.

        >>> rdd = sc.parallelize([1, 2])
        >>> sorted(rdd.cartesian(rdd).collect())
        [(1, 1), (1, 2), (2, 1), (2, 2)]
        """
        # Because of batching, the Java-side cartesian() may pair whole Batch
        # objects rather than individual elements, so unpack any batches here.
        java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
        def unpack_batches(pair):
            (x, y) = pair
            if type(x) == Batch or type(y) == Batch:
                xs = x.items if type(x) == Batch else [x]
                ys = y.items if type(y) == Batch else [y]
                for pair in product(xs, ys):
                    yield pair
            else:
                yield pair
        return java_cartesian.flatMap(unpack_batches)

    def groupBy(self, f, numPartitions=None):
        """
        Return an RDD of grouped items.

        >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8])
        >>> result = rdd.groupBy(lambda x: x % 2).collect()
        >>> sorted([(x, sorted(y)) for (x, y) in result])
        [(0, [2, 8]), (1, [1, 1, 3, 5])]
        """
        return self.map(lambda x: (f(x), x)).groupByKey(numPartitions)

    def pipe(self, command, env={}):
        """
        Return an RDD created by piping elements to a forked external process.

        >>> sc.parallelize([1, 2, 3]).pipe('cat').collect()
        ['1', '2', '3']
        """
        def func(iterator):
            pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE)
            def pipe_objs(out):
                for obj in iterator:
                    out.write(str(obj).rstrip('\n') + '\n')
                out.close()
            Thread(target=pipe_objs, args=[pipe.stdin]).start()
            return (x.rstrip('\n') for x in pipe.stdout)
        return self.mapPartitions(func)

    def foreach(self, f):
        """
        Applies a function to all elements of this RDD.

        >>> def f(x): print x
        >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
        """
        def processPartition(iterator):
            for x in iterator:
                f(x)
            yield None
        self.mapPartitions(processPartition).collect()  # Force evaluation

    def collect(self):
        """
        Return a list that contains all of the elements in this RDD.
        """
        picklesInJava = self._jrdd.collect().iterator()
        return list(self._collect_iterator_through_file(picklesInJava))

    def _collect_iterator_through_file(self, iterator):
        # Transferring lots of data through Py4J can be slow because
        # socket.readline() is inefficient.  Instead, we dump the data to a
        # temporary file and read it back.
        tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
        tempFile.close()
        self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
        # Read the data into Python and deserialize it:
        with open(tempFile.name, 'rb') as tempFile:
            for item in read_from_pickle_file(tempFile):
                yield item
        os.unlink(tempFile.name)

    def reduce(self, f):
        """
        Reduces the elements of this RDD using the specified commutative and
        associative binary operator.

        >>> from operator import add
        >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add)
        15
        >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add)
        10
        """
        def func(iterator):
            acc = None
            for obj in iterator:
                if acc is None:
                    acc = obj
                else:
                    acc = f(obj, acc)
            if acc is not None:
                yield acc
        vals = self.mapPartitions(func).collect()
        return reduce(f, vals)

    def fold(self, zeroValue, op):
        """
        Aggregate the elements of each partition, and then the results for all
        the partitions, using a given associative function and a neutral "zero
        value."

        The function C{op(t1, t2)} is allowed to modify C{t1} and return it
        as its result value to avoid object allocation; however, it should not
        modify C{t2}.

        >>> from operator import add
        >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
        15
        """
        def func(iterator):
            acc = zeroValue
            for obj in iterator:
                acc = op(obj, acc)
            yield acc
        vals = self.mapPartitions(func).collect()
        return reduce(op, vals, zeroValue)

    def sum(self):
        """
        Add up the elements in this RDD.

        >>> sc.parallelize([1.0, 2.0, 3.0]).sum()
        6.0
        """
        return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add)

    def count(self):
        """
        Return the number of elements in this RDD.

        >>> sc.parallelize([2, 3, 4]).count()
        3
        """
        return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()

    def stats(self):
        """
        Return a L{StatCounter} object that captures the mean, variance
        and count of the RDD's elements in one operation.
        """
        def redFunc(left_counter, right_counter):
            return left_counter.mergeStats(right_counter)

        return self.mapPartitions(lambda i: [StatCounter(i)]).reduce(redFunc)

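    # Illustrative (untested) use of stats(), assuming `sc`; the helpers below
    # (mean, variance, stdev, ...) all delegate to the same StatCounter:
    #
    #   >>> s = sc.parallelize([1, 2, 3]).stats()
    #   >>> (s.count(), s.mean())
    #   (3, 2.0)
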
    def mean(self):
        """
        Compute the mean of this RDD's elements.

        >>> sc.parallelize([1, 2, 3]).mean()
        2.0
        """
        return self.stats().mean()

    def variance(self):
        """
        Compute the variance of this RDD's elements.

        >>> sc.parallelize([1, 2, 3]).variance()
        0.666...
        """
        return self.stats().variance()

    def stdev(self):
        """
        Compute the standard deviation of this RDD's elements.

        >>> sc.parallelize([1, 2, 3]).stdev()
        0.816...
        """
        return self.stats().stdev()

    def sampleStdev(self):
        """
        Compute the sample standard deviation of this RDD's elements (which
        corrects for bias in estimating the standard deviation by dividing by
        N-1 instead of N).

        >>> sc.parallelize([1, 2, 3]).sampleStdev()
        1.0
        """
        return self.stats().sampleStdev()

    def sampleVariance(self):
        """
        Compute the sample variance of this RDD's elements (which corrects for
        bias in estimating the variance by dividing by N-1 instead of N).

        >>> sc.parallelize([1, 2, 3]).sampleVariance()
        1.0
        """
        return self.stats().sampleVariance()

    def countByValue(self):
        """
        Return the count of each unique value in this RDD as a dictionary of
        (value, count) pairs.

        >>> sorted(sc.parallelize([1, 2, 1, 2, 2], 2).countByValue().items())
        [(1, 2), (2, 3)]
        """
        def countPartition(iterator):
            counts = defaultdict(int)
            for obj in iterator:
                counts[obj] += 1
            yield counts
        def mergeMaps(m1, m2):
            for (k, v) in m2.iteritems():
                m1[k] += v
            return m1
        return self.mapPartitions(countPartition).reduce(mergeMaps)

    def take(self, num):
        """
        Take the first num elements of the RDD.

        This currently scans the partitions *one by one*, so it will be slow if
        a lot of partitions are required. In that case, use L{collect} to get
        the whole RDD instead.

        >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2)
        [2, 3]
        >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
        [2, 3, 4, 5, 6]
        """
        def takeUpToNum(iterator):
            taken = 0
            while taken < num:
                yield next(iterator)
                taken += 1
        # Take only up to num elements from each partition we try
        mapped = self.mapPartitions(takeUpToNum)
        items = []
        for partition in range(mapped._jrdd.splits().size()):
            iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
            items.extend(self._collect_iterator_through_file(iterator))
            if len(items) >= num:
                break
        return items[:num]

    def first(self):
        """
        Return the first element in this RDD.

        >>> sc.parallelize([2, 3, 4]).first()
        2
        """
        return self.take(1)[0]

    def saveAsTextFile(self, path):
        """
        Save this RDD as a text file, using string representations of elements.

        >>> tempFile = NamedTemporaryFile(delete=True)
        >>> tempFile.close()
        >>> sc.parallelize(range(10)).saveAsTextFile(tempFile.name)
        >>> from fileinput import input
        >>> from glob import glob
        >>> ''.join(sorted(input(glob(tempFile.name + "/part-0000*"))))
        '0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n'
        """
        def func(split, iterator):
            for x in iterator:
                if not isinstance(x, basestring):
                    x = unicode(x)
                yield x.encode("utf-8")
        keyed = PipelinedRDD(self, func)
        keyed._bypass_serializer = True
        keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path)

    # Pair functions

    def collectAsMap(self):
        """
        Return the key-value pairs in this RDD to the master as a dictionary.

        >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap()
        >>> m[1]
        2
        >>> m[3]
        4
        """
        return dict(self.collect())

    def reduceByKey(self, func, numPartitions=None):
        """
        Merge the values for each key using an associative reduce function.

        This will also perform the merging locally on each mapper before
        sending results to a reducer, similarly to a "combiner" in MapReduce.

        Output will be hash-partitioned with C{numPartitions} partitions, or
        the default parallelism level if C{numPartitions} is not specified.

        >>> from operator import add
        >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> sorted(rdd.reduceByKey(add).collect())
        [('a', 2), ('b', 1)]
        """
        return self.combineByKey(lambda x: x, func, func, numPartitions)

    def reduceByKeyLocally(self, func):
        """
        Merge the values for each key using an associative reduce function, but
        return the results immediately to the master as a dictionary.

        This will also perform the merging locally on each mapper before
        sending results to a reducer, similarly to a "combiner" in MapReduce.

        >>> from operator import add
        >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> sorted(rdd.reduceByKeyLocally(add).items())
        [('a', 2), ('b', 1)]
        """
        def reducePartition(iterator):
            m = {}
            for (k, v) in iterator:
                m[k] = v if k not in m else func(m[k], v)
            yield m
        def mergeMaps(m1, m2):
            for (k, v) in m2.iteritems():
                m1[k] = v if k not in m1 else func(m1[k], v)
            return m1
        return self.mapPartitions(reducePartition).reduce(mergeMaps)

    def countByKey(self):
        """
        Count the number of elements for each key, and return the result to the
        master as a dictionary.

        >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> sorted(rdd.countByKey().items())
        [('a', 2), ('b', 1)]
        """
        return self.map(lambda x: x[0]).countByValue()

    def join(self, other, numPartitions=None):
        """
        Return an RDD containing all pairs of elements with matching keys in
        C{self} and C{other}.

        Each pair of elements will be returned as a (k, (v1, v2)) tuple, where
        (k, v1) is in C{self} and (k, v2) is in C{other}.

        Performs a hash join across the cluster.

        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2), ("a", 3)])
        >>> sorted(x.join(y).collect())
        [('a', (1, 2)), ('a', (1, 3))]
        """
        return python_join(self, other, numPartitions)

    def leftOuterJoin(self, other, numPartitions=None):
        """
        Perform a left outer join of C{self} and C{other}.

        For each element (k, v) in C{self}, the resulting RDD will either
        contain all pairs (k, (v, w)) for w in C{other}, or the pair
        (k, (v, None)) if no elements in C{other} have key k.

        Hash-partitions the resulting RDD into the given number of partitions.

        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2)])
        >>> sorted(x.leftOuterJoin(y).collect())
        [('a', (1, 2)), ('b', (4, None))]
        """
        return python_left_outer_join(self, other, numPartitions)

    def rightOuterJoin(self, other, numPartitions=None):
        """
        Perform a right outer join of C{self} and C{other}.

        For each element (k, w) in C{other}, the resulting RDD will either
        contain all pairs (k, (v, w)) for v in C{self}, or the pair
        (k, (None, w)) if no elements in C{self} have key k.

        Hash-partitions the resulting RDD into the given number of partitions.

        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2)])
        >>> sorted(y.rightOuterJoin(x).collect())
        [('a', (2, 1)), ('b', (None, 4))]
        """
        return python_right_outer_join(self, other, numPartitions)

    def partitionBy(self, numPartitions, partitionFunc=hash):
        """
        Return a copy of the RDD partitioned using the specified partitioner.

        >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x))
        >>> sets = pairs.partitionBy(2).glom().collect()
        >>> set(sets[0]).intersection(set(sets[1]))
        set([])
        """
        if numPartitions is None:
            numPartitions = self.ctx.defaultParallelism

        # Transferring O(n) objects to Java is too expensive.  Instead, we form
        # the hash buckets in Python, transferring only O(numPartitions)
        # objects to Java.  Each object is a (splitNumber, [objects]) pair.
        def add_shuffle_key(split, iterator):
            buckets = defaultdict(list)
            for (k, v) in iterator:
                buckets[partitionFunc(k) % numPartitions].append((k, v))
            for (split, items) in buckets.iteritems():
                yield pack_long(split)
                yield dump_pickle(Batch(items))
        keyed = PipelinedRDD(self, add_shuffle_key)
        keyed._bypass_serializer = True
        pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
        partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
                                                      id(partitionFunc))
        jrdd = pairRDD.partitionBy(partitioner).values()
        rdd = RDD(jrdd, self.ctx)
        # Keeping a reference here is required so that id(partitionFunc)
        # remains unique, even if partitionFunc is a lambda:
        rdd._partitionFunc = partitionFunc
        return rdd

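    # Note on the shuffle format used by add_shuffle_key above: it yields an
    # alternating stream of (packed partition id, pickled Batch of pairs) that
    # the JVM-side PairwiseRDD consumes.  For example, with numPartitions=2 and
    # partitionFunc=hash, the pairs (1, 1) and (3, 3) both fall into bucket
    # hash(1) % 2 == hash(3) % 2 == 1, so they are shipped to Java together as
    # one pickled Batch preceded by the packed id 1.
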
    def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
                     numPartitions=None):
        """
        Generic function to combine the elements for each key using a custom
        set of aggregation functions.

        Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined
        type" C.  Note that V and C can be different -- for example, one might
        group an RDD of type (Int, Int) into an RDD of type (Int, List[Int]).

        Users provide three functions:

            - C{createCombiner}, which turns a V into a C (e.g., creates
              a one-element list)
            - C{mergeValue}, to merge a V into a C (e.g., adds it to the end of
              a list)
            - C{mergeCombiners}, to combine two C's into a single one.

        In addition, users can control the partitioning of the output RDD.

        >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> def f(x): return x
        >>> def add(a, b): return a + str(b)
        >>> sorted(x.combineByKey(str, add, add).collect())
        [('a', '11'), ('b', '1')]
        """
        if numPartitions is None:
            numPartitions = self.ctx.defaultParallelism
        def combineLocally(iterator):
            combiners = {}
            for (k, v) in iterator:
                if k not in combiners:
                    combiners[k] = createCombiner(v)
                else:
                    combiners[k] = mergeValue(combiners[k], v)
            return combiners.iteritems()
        locally_combined = self.mapPartitions(combineLocally)
        shuffled = locally_combined.partitionBy(numPartitions)
        def _mergeCombiners(iterator):
            combiners = {}
            for (k, v) in iterator:
                if k not in combiners:
                    combiners[k] = v
                else:
                    combiners[k] = mergeCombiners(combiners[k], v)
            return combiners.iteritems()
        return shuffled.mapPartitions(_mergeCombiners)

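    # Tracing the combineByKey doctest above on [("a", 1), ("b", 1), ("a", 1)]:
    # the first occurrence of each key is turned into a combiner with str, so
    # "b" becomes '1'.  If both ("a", 1) pairs land in the same partition,
    # mergeValue gives '1' + str(1) == '11' locally; otherwise each partition
    # produces the combiner '1' and mergeCombiners yields '1' + str('1') == '11'
    # after the shuffle.  Either way the result is [('a', '11'), ('b', '1')].
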
    def groupByKey(self, numPartitions=None):
        """
        Group the values for each key in the RDD into a single sequence.
        Hash-partitions the resulting RDD into numPartitions partitions.

        >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> sorted(x.groupByKey().collect())
        [('a', [1, 1]), ('b', [1])]
        """

        def createCombiner(x):
            return [x]

        def mergeValue(xs, x):
            xs.append(x)
            return xs

        def mergeCombiners(a, b):
            return a + b

        return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
                                 numPartitions)

    def flatMapValues(self, f):
        """
        Pass each value in the key-value pair RDD through a flatMap function
        without changing the keys; this also retains the original RDD's
        partitioning.
        """
        flat_map_fn = lambda (k, v): ((k, x) for x in f(v))
        return self.flatMap(flat_map_fn, preservesPartitioning=True)

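    # flatMapValues() has no doctest; an illustrative example, assuming `sc`:
    #
    #   >>> x = sc.parallelize([("a", ["x", "y", "z"]), ("b", ["p", "r"])])
    #   >>> sorted(x.flatMapValues(lambda v: v).collect())
    #   [('a', 'x'), ('a', 'y'), ('a', 'z'), ('b', 'p'), ('b', 'r')]
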
    def mapValues(self, f):
        """
        Pass each value in the key-value pair RDD through a map function
        without changing the keys; this also retains the original RDD's
        partitioning.
        """
        map_values_fn = lambda (k, v): (k, f(v))
        return self.map(map_values_fn, preservesPartitioning=True)

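    # mapValues() has no doctest; an illustrative example, assuming `sc`:
    #
    #   >>> x = sc.parallelize([("a", ["apple", "banana", "lemon"]), ("b", ["grapes"])])
    #   >>> sorted(x.mapValues(len).collect())
    #   [('a', 3), ('b', 1)]
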
    def groupWith(self, other):
        """
        Alias for cogroup.
        """
        return self.cogroup(other)

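    # Since groupWith() is just an alias for cogroup(), the cogroup doctest
    # below applies to it verbatim, e.g. with
    # x = sc.parallelize([("a", 1), ("b", 4)]) and y = sc.parallelize([("a", 2)]):
    #
    #   >>> sorted(x.groupWith(y).collect())
    #   [('a', ([1], [2])), ('b', ([4], []))]
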
    def cogroup(self, other, numPartitions=None):
        """
        For each key k in C{self} or C{other}, return a resulting RDD that
        contains a tuple with the list of values for that key in C{self} as well
        as C{other}.

        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2)])
        >>> sorted(x.cogroup(y).collect())
        [('a', ([1], [2])), ('b', ([4], []))]
        """
        return python_cogroup(self, other, numPartitions)

    def subtractByKey(self, other, numPartitions=None):
        """
        Return each (key, value) pair in C{self} that has no pair with matching
        key in C{other}.

        >>> x = sc.parallelize([("a", 1), ("b", 4), ("b", 5), ("a", 2)])
        >>> y = sc.parallelize([("a", 3), ("c", None)])
        >>> sorted(x.subtractByKey(y).collect())
        [('b', 4), ('b', 5)]
        """
        filter_func = lambda (key, vals): len(vals[0]) > 0 and len(vals[1]) == 0
        map_func = lambda (key, vals): [(key, val) for val in vals[0]]
        return self.cogroup(other, numPartitions).filter(filter_func).flatMap(map_func)

    def subtract(self, other, numPartitions=None):
        """
        Return each value in C{self} that is not contained in C{other}.

        >>> x = sc.parallelize([("a", 1), ("b", 4), ("b", 5), ("a", 3)])
        >>> y = sc.parallelize([("a", 3), ("c", None)])
        >>> sorted(x.subtract(y).collect())
        [('a', 1), ('b', 4), ('b', 5)]
        """
        # The True values are placeholders so subtractByKey can be reused.
        rdd = other.map(lambda x: (x, True))
        return self.map(lambda x: (x, True)).subtractByKey(rdd, numPartitions).map(lambda tpl: tpl[0])

    def keyBy(self, f):
        """
        Creates tuples of the elements in this RDD by applying C{f}.

        >>> x = sc.parallelize(range(0,3)).keyBy(lambda x: x*x)
        >>> y = sc.parallelize(zip(range(0,5), range(0,5)))
        >>> sorted(x.cogroup(y).collect())
        [(0, ([0], [0])), (1, ([1], [1])), (2, ([], [2])), (3, ([], [3])), (4, ([2], [4]))]
        """
        return self.map(lambda x: (f(x), x))


class PipelinedRDD(RDD):
    """
    Pipelined maps:
    >>> rdd = sc.parallelize([1, 2, 3, 4])
    >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect()
    [4, 8, 12, 16]
    >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect()
    [4, 8, 12, 16]

    Pipelined reduces:
    >>> from operator import add
    >>> rdd.map(lambda x: 2 * x).reduce(add)
    20
    >>> rdd.flatMap(lambda x: [x, x]).reduce(add)
    20
    """
    def __init__(self, prev, func, preservesPartitioning=False):
        if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
            prev_func = prev.func
            def pipeline_func(split, iterator):
                return func(split, prev_func(split, iterator))
            self.func = pipeline_func
            self.preservesPartitioning = \
                prev.preservesPartitioning and preservesPartitioning
            self._prev_jrdd = prev._prev_jrdd
        else:
            self.func = func
            self.preservesPartitioning = preservesPartitioning
            self._prev_jrdd = prev._jrdd
        self.is_cached = False
        self.is_checkpointed = False
        self.ctx = prev.ctx
        self.prev = prev
        self._jrdd_val = None
        self._bypass_serializer = False

    @property
    def _jrdd(self):
        if self._jrdd_val:
            return self._jrdd_val
        func = self.func
        if not self._bypass_serializer and self.ctx.batchSize != 1:
            oldfunc = self.func
            batchSize = self.ctx.batchSize
            def batched_func(split, iterator):
                return batched(oldfunc(split, iterator), batchSize)
            func = batched_func
        cmds = [func, self._bypass_serializer]
        pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
        broadcast_vars = ListConverter().convert(
            [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
            self.ctx._gateway._gateway_client)
        self.ctx._pickled_broadcast_vars.clear()
        class_manifest = self._prev_jrdd.classManifest()
        env = MapConverter().convert(self.ctx.environment,
                                     self.ctx._gateway._gateway_client)
        includes = ListConverter().convert(self.ctx._python_includes,
                                           self.ctx._gateway._gateway_client)
        python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
            pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec,
            broadcast_vars, self.ctx._javaAccumulator, class_manifest)
        self._jrdd_val = python_rdd.asJavaRDD()
        return self._jrdd_val

    def _is_pipelinable(self):
        return not (self.is_cached or self.is_checkpointed)


def _test():
    import doctest
    from pyspark.context import SparkContext
    globs = globals().copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()