PySpark RDD: partitionBy with a custom partitionFunc


partitionBy(self, numPartitions, partitionFunc=portable_hash): the method takes two main parameters. The first, numPartitions, is simply the number of partitions the resulting RDD will have.

The second is partitionFunc, the partitioning function, which defaults to a hash function (portable_hash). We can also supply our own:

data = sc.parallelize(['1', '2', '3']).map(lambda x: (x, x))
wp = data.partitionBy(data.count(), lambda k: int(k))
print(wp.map(lambda t: t[0]).glom().collect())

The custom function here is the simplest possible one, lambda k: int(k): each key is placed into the partition indexed by its own integer value. More sophisticated partition functions can be defined in the same way, as sketched below.
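
For example, a partitionFunc can route keys into partitions by range instead of by hash. The following is only a minimal sketch: it assumes an existing SparkContext named sc, and range_partitioner is a name made up for this illustration, not part of the PySpark API.

# Minimal sketch of a range-style partitionFunc.
# Assumptions: `sc` is an existing SparkContext; `range_partitioner`
# is just a name chosen for this example, not a PySpark API.
def range_partitioner(key):
    k = int(key)
    if k < 10:        # keys 0-9 go to partition 0
        return 0
    elif k < 100:     # keys 10-99 go to partition 1
        return 1
    else:             # everything else goes to partition 2
        return 2

pairs = sc.parallelize(['5', '42', '7', '300', '99']).map(lambda x: (x, x))
parted = pairs.partitionBy(3, range_partitioner)
print(parted.glom().collect())
# With 3 partitions the layout should be (order inside a partition may vary):
# [[('5', '5'), ('7', '7')], [('42', '42'), ('99', '99')], [('300', '300')]]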

Here is the source code of partitionBy:
def partitionBy(self, numPartitions, partitionFunc=portable_hash):
    """
    Return a copy of the RDD partitioned using the specified partitioner.

    >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x))
    >>> sets = pairs.partitionBy(2).glom().collect()
    >>> set(sets[0]).intersection(set(sets[1]))
    set([])
    """
    if numPartitions is None:
        numPartitions = self._defaultReducePartitions()

    # Transferring O(n) objects to Java is too expensive.
    # Instead, we'll form the hash buckets in Python,
    # transferring O(numPartitions) objects to Java.
    # Each object is a (splitNumber, [objects]) pair.
    # In order to avoid too huge objects, the objects are
    # grouped into chunks.
    outputSerializer = self.ctx._unbatched_serializer

    limit = (_parse_memory(self.ctx._conf.get(
        "spark.python.worker.memory", "512m")) / 2)

    def add_shuffle_key(split, iterator):

        buckets = defaultdict(list)
        c, batch = 0, min(10 * numPartitions, 1000)

        for (k, v) in iterator:
            buckets[partitionFunc(k) % numPartitions].append((k, v))
            c += 1

            # check used memory and avg size of chunk of objects
            if (c % 1000 == 0 and get_used_memory() > limit
                    or c > batch):
                n, size = len(buckets), 0
                for split in buckets.keys():
                    yield pack_long(split)
                    d = outputSerializer.dumps(buckets[split])
                    del buckets[split]
                    yield d
                    size += len(d)

                avg = (size / n) >> 20
                # let 1M < avg < 10M
                if avg < 1:
                    batch *= 1.5
                elif avg > 10:
                    batch = max(batch / 1.5, 1)
                c = 0

        for (split, items) in buckets.iteritems():
            yield pack_long(split)
            yield outputSerializer.dumps(items)

    keyed = self.mapPartitionsWithIndex(add_shuffle_key)
    keyed._bypass_serializer = True
    with _JavaStackTrace(self.context) as st:
        pairRDD = self.ctx._jvm.PairwiseRDD(
            keyed._jrdd.rdd()).asJavaPairRDD()
        partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
                                                      id(partitionFunc))
    jrdd = pairRDD.partitionBy(partitioner).values()
    rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
    # This is required so that id(partitionFunc) remains unique,
    # even if partitionFunc is a lambda:
    rdd._partitionFunc = partitionFunc
    return rdd
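
The core of add_shuffle_key is the line buckets[partitionFunc(k) % numPartitions].append((k, v)): every record is bucketed by partitionFunc(k) % numPartitions before being handed off to the JVM side. This can be checked from the driver with a small sketch (again just an illustration, assuming an existing SparkContext sc):

# Minimal check that every key ends up in partition partitionFunc(k) % numPartitions.
# Assumption: `sc` is an existing SparkContext.
num_parts = 4
part_func = lambda k: int(k)   # same style of custom function as above

pairs = sc.parallelize(range(20)).map(lambda x: (str(x), x))
parted = pairs.partitionBy(num_parts, part_func)

for idx, part in enumerate(parted.glom().collect()):
    # every (k, v) in partition `idx` must satisfy part_func(k) % num_parts == idx
    assert all(part_func(k) % num_parts == idx for k, _ in part)
print("bucketing matches partitionFunc(k) % numPartitions")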