spark join shuffle 数据读取的过程

来源:互联网 发布:甲骨文云计算大会 编辑:程序博客网 时间:2024/06/05 03:52

spark join shuffle 数据读取的过程

在spark中,当数据要shuffle时,这个拉取过程RDD是怎么和ShuffleMapTask 关联起来的。
在CoGroupedRDD通过调用如下函数去读取指定分区的数据

 SparkEnv.get.shuffleManager      .getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context)      .read()

通过上面的方法,就可以知道调用那个依赖的RDD,读取那个分片数据。
然后创建BlockStoreShuffleReader读取对象。在该类中执行下面的方法

// 下面就是对这个shuffler中的分片数据进行读取并进行相关的aggregate操作了val blockFetcherItr = new ShuffleBlockFetcherIterator(  context,  blockManager.shuffleClient,  blockManager,  mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),  // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility  SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)

可以看到首先要通过mapOutputTracker去拉取该分区的地址信息

 def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)  : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")// 拉取这些状态数据回来了val statuses = getStatuses(shuffleId)// Synchronize on the returned array because, on the driver, it gets mutated in placestatuses.synchronized {  return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)}}

然后在 getStatuses函数中,发起远程调用,读取这个shuffle的结果地址数据

 try {      // 拉取这个shuffle的状态数据      val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))      // 这个status是那些数据分片的地址来的      fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)      logInfo("Got the output locations")      mapStatuses.put(shuffleId, fetchedStatuses)    } 

在MapOutputTrackerMaster中的MapOutputTrackerMasterEndpoint 接收线程中,接收到相关的消息

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {case GetMapOutputStatuses(shuffleId: Int) =>  // 问这个shuffler的地址信息  val hostPort = context.senderAddress.hostPort  logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)  // 去这个tracker里面去拉取了  val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)  val serializedSize = mapOutputStatuses.length  if (serializedSize > maxAkkaFrameSize) {    val msg = s"Map output statuses were $serializedSize bytes which " +      s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)."    /* For SPARK-1244 we'll opt for just logging an error and then sending it to the sender.     * A bigger refactoring (SPARK-1239) will ultimately remove this entire code path. */    val exception = new SparkException(msg)    logError(msg, exception)    context.sendFailure(exception)  } else {    context.reply(mapOutputStatuses)  }

在tracker 保存着shuffle的执行结果。这些数据是通过DAGScheduler 在调用ShuffleMapTask 的时候,运行的结果存放的

 case smt: ShuffleMapTask =>        val shuffleStage = stage.asInstanceOf[ShuffleMapStage]        updateAccumulators(event)        val status = event.result.asInstanceOf[MapStatus]        val execId = status.location.executorId        logDebug("ShuffleMapTask finished on " + execId)        if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {          // 这是一个失败的任务          logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")        } else {          // 记录这个分区的运行结果          shuffleStage.addOutputLoc(smt.partitionId, status)        }        if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) {          markStageAsFinished(shuffleStage)          logInfo("looking for newly runnable stages")          logInfo("running: " + runningStages)          logInfo("waiting: " + waitingStages)          logInfo("failed: " + failedStages)          // We supply true to increment the epoch number here in case this is a          // recomputation of the map outputs. In that case, some nodes may have cached          // locations with holes (from when we detected the error) and will need the          // epoch incremented to refetch them.          // TODO: Only increment the epoch number if this is not the first time          //       we registered these map outputs.          // 把当前shuffler的执行结果存放在这里了          mapOutputTracker.registerMapOutputs(            shuffleStage.shuffleDep.shuffleId,            shuffleStage.outputLocInMapOutputTrackerFormat(),            changeEpoch = true)

详情可查看 《SPARK TASK 任务状态管理》 ,在DAGScheduler 中当ShuffleMapTask 完成任务时,把对应的shuffleid
的计算结果路径写到mapOutputTracker中,然后在其它地方就可以请求到这个数据了。

private def convertMapStatuses(  shuffleId: Int,  startPartition: Int,  endPartition: Int,  statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {assert (statuses != null)val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]]//  statuses是包含着所有的地址信息for ((status, mapId) <- statuses.zipWithIndex) {  if (status == null) {    val errorMessage = s"Missing an output location for shuffle $shuffleId"    logError(errorMessage)    throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)  } else {    for (part <- startPartition until endPartition) {      // 就是拿这个地址中这个分片里面的数据      splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=        ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part)))    }  }}splitsByAddress.toSeq}

然后就可以拉取指定分片里面的数据了,通过ShuffleBlockFetcherIterator 类的功能,

private[this] def sendRequest(req: FetchRequest) {logDebug("Sending request for %d blocks (%s) from %s".format(  req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))bytesInFlight += req.size// so we can look up the size of each blockIDval sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMapval blockIds = req.blocks.map(_._1.toString)// 去这些地址拉取数据了,同时注意block对象是 ShuffleBlockId 里面包含着当前请求的是那个分片数据//  在拉取的时候,还要对block块数据进行分片val address = req.addressshuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,  new BlockFetchingListener {    override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {    }    override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {      logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)      results.put(new FailureFetchResult(BlockId(blockId), address, e))    }  })}

然后就可以通过shuffleClient (NettyBlockTransferService )进行远程的拉取了。

override def fetchBlocks(  host: String,  port: Int,  execId: String,  blockIds: Array[String],  listener: BlockFetchingListener): Unit = {logTrace(s"Fetch blocks from $host:$port (executor id $execId)")// 通过netty的方式去拉取block文件try {  val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {    override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {      // 这里就是创建netty客户端进行拉取数据了      val client = clientFactory.createClient(host, port)      new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start()    }  }  // 如果有重试器,则就创建一个包装对象进行重试  val maxRetries = transportConf.maxIORetries()  if (maxRetries > 0) {    // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's    // a bug in this code. We should remove the if statement once we're sure of the stability.    new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()  } else {    blockFetchStarter.createAndStart(blockIds, listener)  }} catch {  case e: Exception =>    logError("Exception while beginning fetchBlocks", e)    blockIds.foreach(listener.onBlockFetchFailure(_, e))}}

当拉取回来后,就可以对这个iterator数据进行后缀的处理了。然后回到BlockStoreShuffleReader类中

override def read(): Iterator[Product2[K, C]] = {// 读取这个分片的数据了,生成iteracor对象// 下面就是对这个shuffler中的分片数据进行读取并进行相关的aggregate操作了val blockFetcherItr = new ShuffleBlockFetcherIterator(  context,  blockManager.shuffleClient,  blockManager,  mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),  // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility  SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)// Wrap the streams for compression based on configurationval wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>  blockManager.wrapForCompression(blockId, inputStream)}val ser = Serializer.getSerializer(dep.serializer)val serializerInstance = ser.newInstance()// Create a key/value iterator for each streamval recordIter = wrappedStreams.flatMap { wrappedStream =>  // Note: the asKeyValueIterator below wraps a key/value iterator inside of a  // NextIterator. The NextIterator makes sure that close() is called on the  // underlying InputStream when all records have been read.  serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator}// Update the context task metrics for each record read.// metrics记录数据量val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](  recordIter.map(record => {    readMetrics.incRecordsRead(1)    record  }),  context.taskMetrics().updateShuffleReadMetrics())// An interruptible iterator must be used here in order to support task cancellationval interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {  if (dep.mapSideCombine) {    // We are reading values that are already combined    // 在这执行聚合方法了,在map端已经进行合并的了,这个数据先分组再count操作    val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]    dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)  } else {    // We don't know the value type, but also don't care -- the dependency *should*    // have made sure its compatible w/ this aggregator, which will convert the value    // type to the combined type C    // 在shullfer阶段已经执行这些方法的了,这个在最后所有数据进行count操作    val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]    dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)  }} else {  require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")  interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]}// Sort the output if there is a sort ordering defined.dep.keyOrdering match {  case Some(keyOrd: Ordering[K]) =>    // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,    // the ExternalSorter won't spill to disk.    // 对这数据进行排序    val sorter =      new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))    sorter.insertAll(aggregatedIter)    context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)    context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)    context.internalMetricsToAccumulators(      InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)    CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())  case None =>    aggregatedIter}}

当数据拉取回来后生成iteraror ,然后判断是否有aggregator聚合函数,如果有就执行,
所以这样就可以在shuffle分片数据的过程就可以提前执行聚合函数,减少传输到后面的数据量。
所以像一些count 或者 sum的操作,其实可以直接在这里进行执行,这样在这个过程中就只要把这个
结果传到后面就可以了,数据量就大大减少了。

同时如果要对Shuffle的分片数据进行排序的需求keyOrdering,就直接在这里创建一个ExternalSorter
对象,对上面的数据进行排序返回,所以就在这个shffle的分片阶段中可以实现aggregate聚合函数和keyordering
对字段排序的功能。

总结

  1. 请求读取指定的分片数据split
  2. 去MapOutputTrackerMaster拉取该shuffleid的分片地址信息
  3. 通过netty到相关的地址拉取指定Partition的数据
  4. 去拉取回来的数据执行聚合函数操作
  5. 去执行后的iterator数据执行 keyorder排序数据,然后最后返回
原创粉丝点击