Spark Transformation —— mapPartitions

来源:互联网 发布:尼古丁的好处 知乎 编辑:程序博客网 时间:2024/05/21 11:10

原理

def mapPartitions[U](f: (Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false)(implicit arg0: ClassTag[U]): RDD[U]

该函数和map函数类似,只不过映射函数的参数由RDD中的每一个元素变成了RDD中每一个分区的迭代器。如果在映射的过程中需要频繁创建额外的对象,使用mapPartitions要比map高效的过。

比如,将RDD中的所有数据通过JDBC连接写入数据库,如果使用map函数,可能要为每一个元素都创建一个connection,这样开销很大,如果使用mapPartitions,那么只需要针对每一个分区建立一个connection。

参数preservesPartitioning表示是否保留父RDD的partitioner分区信息。

var rdd1 = sc.makeRDD(1 to 5,2)//rdd1有两个分区scala> var rdd3 = rdd1.mapPartitions{ x => {     | var result = List[Int]()     | var i = 0     | while(x.hasNext){     |   i += x.next()     |  }     | result.::(i).iterator     | }}rdd3: org.apache.spark.rdd.RDD[Int] = MapPartitionsRDD[84] at mapPartitions at :23//rdd3将rdd1中每个分区中的数值累加scala> rdd3.collectres65: Array[Int] = Array(3, 12)//1+2 = 3,3+4+5 =12 函数对每个分区的迭代器使用scala> rdd3.partitions.sizeres66: Int = 2

原理图

这里写图片描述

mapPartitions函数获取到每个分区的迭代器,在函数中通过这个分区整体的迭代器对整个分区的元素进行操作。 内部实现是生成MapPartitionsRDD。

图中,用户通过函数f(iter) => iter.filter(_>=3)对分区中的所有数据进行过滤,>=3的数据保留。一个方块代表一个RDD分区,含有1、 2、 3的分区过滤只剩下元素3。

源码实现

/** * Return a new RDD by applying a function to each partition of this RDD. * * `preservesPartitioning` indicates whether the input function preserves the partitioner, which * should be `false` unless this is a pair RDD and the input function doesn't modify the keys. */def mapPartitions[U: ClassTag](    f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {  val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter)  new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)}S
0 0
原创粉丝点击