    本篇文章以RDD.aggregateByKey引起的Shuffle(map端由SortShuffleWriter写出数据)为例,说明Shuffle reduce端的原理和实现。为了便于说明问题,这里的所有执行流程都是默认执行流程。

    为了便于说明问题,本文中Stage1是shuffle map操作所在的Stage,Stage2是shuffle reduce操作所在的Stage;spark.shuffle.blockTransferService取默认值netty。

   Shuffle map操作结束之后,Stage1结束,Spark的调度系统会启动Stage2,Stage2最首要的任务是根据ShuffledRDD和MapOutputTrackerMaster获取Stage2 partition的信息。具体流程可参见下面的时序图




/**
 * Computes one Stage2 partition from the shuffle files that Stage1 wrote to disk.
 *
 * The shuffle manager hands out a reader restricted to this single partition
 * (start = split.index, end = split.index + 1). The records it reads become the content of the
 * Stage2 partition, exposed to the caller as an Iterator of (key, combined value) pairs.
 */
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
  val shuffleDep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
  val reader = SparkEnv.get.shuffleManager.getReader(
    shuffleDep.shuffleHandle, split.index, split.index + 1, context)
  val records = reader.read()
  records.asInstanceOf[Iterator[(K, C)]]
}



/**
 * Reads the records of reduce partition `startPartition` of this shuffle.
 *
 * Steps:
 *  1. fetch every Stage1 shuffle block destined for this partition as an InputStream;
 *  2. wrap each stream for decompression and deserialize it into key/value records;
 *  3. merge values/combiners with the same key across Stage1 partitions if an aggregator is set;
 *  4. sort by key with an ExternalSorter if the dependency defines a key ordering.
 *
 * @return an iterator over the (possibly aggregated and sorted) records of this partition
 * @throws IllegalArgumentException if map-side combine is requested without an aggregator
 */
override def read(): Iterator[Product2[K, C]] = {
  // Turn every block holding data of this Stage2 partition into an InputStream.
  val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
    handle.shuffleId, startPartition, context, blockManager, mapOutputTracker)

  // Wrap the streams for compression based on configuration
  val wrappedStreams = blockStreams.map { case (blockId, inputStream) =>
    blockManager.wrapForCompression(blockId, inputStream)
  }

  val ser = Serializer.getSerializer(dep.serializer)
  val serializerInstance = ser.newInstance()

  // Create a key/value iterator for each stream. The Stage1 map tasks serialized the records
  // when writing them to disk, so they are deserialized here.
  val recordIter = wrappedStreams.flatMap { wrappedStream =>
    // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
    // NextIterator. The NextIterator makes sure that close() is called on the
    // underlying InputStream when all records have been read.
    serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
  }

  // Update the context task metrics for each record read.
  val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
  val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
    recordIter.map(record => {
      readMetrics.incRecordsRead(1)
      record
    }),
    context.taskMetrics().updateShuffleReadMetrics())

  // An interruptible iterator must be used here in order to support task cancellation.
  // This is the iterator over the raw records of the Stage2 partition.
  val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

  val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
    if (dep.mapSideCombine) {
      // We are reading values that are already combined: the map side merged values with the
      // same key inside each Stage1 partition, so only the combiners of different Stage1
      // partitions still have to be merged here.
      val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
      dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
    } else {
      // We don't know the value type, but also don't care -- the dependency *should*
      // have made sure its compatible w/ this aggregator, which will convert the value
      // type to the combined type C.
      // No map-side combine happened, so all values with the same key are merged here.
      val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
      dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
    }
  } else {
    require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
    interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
  }

  // Sort the output if there is a sort ordering defined.
  dep.keyOrdering match {
    case Some(keyOrd: Ordering[K]) =>
      // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
      // the ExternalSorter won't spill to disk.
      val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
      sorter.insertAll(aggregatedIter)
      context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled)
      context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled)
      sorter.iterator
    case None =>
      aggregatedIter
  }
}

HashShuffleReader.read首先调用BlockStoreShuffleFetcher.fetchBlockStreams方法,通过RPC从远程Shuffle Transfer Service读取数据块(本地数据块则直接读取),并为每个数据块创建InputStream;然后对InputStream中的数据进行解压缩和反序列化(deserialize),将流中的数据转化成Iterator;之后在这个Iterator上合并来自Stage1不同分区的数据;如果定义了key的排序规则(keyOrdering),最后还会用ExternalSorter对结果排序。

BlockStoreShuffleFetcher.fetchBlockStreams方法首先调用MapOutputTrackerWorker.getServerStatuses方法获取Stage1 shuffle map操作产生的分区信息。但是这些信息在ShuffleMapTask执行完毕之后只返回给了Driver端的MapOutputTrackerMaster,MapOutputTrackerWorker中并没有,因此需要调用MapOutputTrackerWorker.askTracker从MapOutputTrackerMaster远程读取Stage1 shuffle map操作产生的分区信息。MapOutputTrackerMaster返回这些信息后,MapOutputTracker.convertMapStatuses对其进一步处理,得到Stage1每个map分区的输出数据所在节点的BlockManagerId(包含节点IP地址和端口),以及该输出中属于当前Stage2分区的数据大小;再结合Shuffle id,就可以确定要读取的Shuffle数据块。经过这一步处理之后,Stage2知道了获取数据的目的节点,也知道了在每个目的节点上需要读取多少数据,从而知道了如何到Shuffle Transfer Service服务读取数据。

然后创建用于从Shuffle Transfer Service读取数据的对象ShuffleBlockFetcherIterator:


/**
 * Fetches every shuffle block that the Stage1 map tasks produced for one Stage2 partition.
 *
 * The map output locations are obtained from `mapOutputTracker`, grouped by the BlockManager
 * holding them, and handed to a ShuffleBlockFetcherIterator that performs the actual reads.
 *
 * @param shuffleId        id of the shuffle being read
 * @param reduceId         partition index of the ShuffledRDD (Stage2) being computed
 * @param context          task context of the reduce task
 * @param blockManager     local BlockManager (also provides the shuffle client)
 * @param mapOutputTracker tracker used to look up where each map output lives
 * @return an iterator of (block id, input stream) pairs, one per fetched block
 * @throws FetchFailedException when a shuffle block cannot be fetched, so that the scheduler
 *                              can react to the lost map output
 * @throws SparkException       when a fetched block that is not a shuffle block fails
 */
def fetchBlockStreams(
    shuffleId: Int,
    reduceId: Int, // partition index of the ShuffledRDD
    context: TaskContext,
    blockManager: BlockManager,
    mapOutputTracker: MapOutputTracker)
  : Iterator[(BlockId, InputStream)] = {
  logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
  val startTime = System.currentTimeMillis
  val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId)
  logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
    shuffleId, reduceId, System.currentTimeMillis - startTime))

  // Group the map outputs by the BlockManager that holds them. `index` is the Stage1
  // partition index, `size` the amount of its output destined for this reduce partition.
  val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
  for (((address, size), index) <- statuses.zipWithIndex) {
    splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
  }

  // Describe the shuffle blocks to fetch, including the address of the node holding them.
  // A block is identified by (shuffle id, Stage1 partition index, Stage2 partition index).
  // Each element is (block address, Seq((shuffle block id, block size))), where the block size
  // is the amount of data one Stage1 partition shuffled to this Stage2 partition.
  val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
    case (address, splits) =>
      (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
  }

  val blockFetcherItr = new ShuffleBlockFetcherIterator(
    context,
    blockManager.shuffleClient,
    blockManager,
    blocksByAddress,
    // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
    SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)

  // Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler
  blockFetcherItr.map { blockPair =>
    val blockId = blockPair._1
    val blockOption = blockPair._2
    blockOption match {
      case Success(inputStream) =>
        (blockId, inputStream)
      case Failure(e) =>
        blockId match {
          case ShuffleBlockId(shufId, mapId, _) =>
            val address = statuses(mapId.toInt)._1
            throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
          case _ =>
            throw new SparkException(
              "Failed to get block " + blockId + ", which is not a shuffle block", e)
        }
    }
  }
}

/**
 * Returns the map output information needed by reduce partition `reduceId` of shuffle
 * `shuffleId`: one (BlockManagerId, size) pair per map status, as produced by
 * MapOutputTracker.convertMapStatuses.
 *
 * Statuses are cached in `mapStatuses`. On a cache miss they are requested from the driver's
 * MapOutputTrackerMaster with askTracker(GetMapOutputStatuses(shuffleId)). The `fetching` set
 * makes sure only one thread issues that request per shuffle; other threads wait on `fetching`
 * until it finishes and then read the cached result.
 *
 * @throws MetadataFetchFailedException if no statuses are available for the shuffle
 */
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
  val statuses = mapStatuses.get(shuffleId).orNull
  // When a Stage2 reduce task starts running, statuses is still null; the shuffle output
  // information produced by Stage1 must then be requested from the driver's
  // MapOutputTrackerMaster.
  if (statuses == null) {
    logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
    var fetchedStatuses: Array[MapStatus] = null
    fetching.synchronized {
      // Someone else is fetching it; wait for them to be done
      while (fetching.contains(shuffleId)) {
        try {
          fetching.wait()
        } catch {
          // NOTE(review): the interrupt is swallowed and the thread's interrupt flag is not
          // restored — confirm this is intended.
          case e: InterruptedException =>
        }
      }

      // Either while we waited the fetch happened successfully, or
      // someone fetched it in between the get and the fetching.synchronized.
      fetchedStatuses = mapStatuses.get(shuffleId).orNull
      if (fetchedStatuses == null) {
        // We have to do the fetch, get others to wait for us.
        fetching += shuffleId
      }
    }

    if (fetchedStatuses == null) {
      // We won the race to fetch the output locs; do so
      logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
      // This try-finally prevents hangs due to timeouts:
      try {
        // Ask the driver's MapOutputTrackerMaster for the shuffle output information
        // produced by Stage1.
        val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
        fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
        logInfo("Got the output locations")
        mapStatuses.put(shuffleId, fetchedStatuses)
      } finally {
        // Always release the claim on this shuffle and wake up waiting threads, even if the
        // request above threw.
        fetching.synchronized {
          fetching -= shuffleId
          fetching.notifyAll()
        }
      }
    }
    if (fetchedStatuses != null) {
      fetchedStatuses.synchronized {
        return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
      }
    } else {
      logError("Missing all output locations for shuffle " + shuffleId)
      throw new MetadataFetchFailedException(
        shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId)
    }
  } else {
    // Cache hit: convert the locally cached statuses directly.
    statuses.synchronized {
      return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
    }
  }
}


/**
 * Converts the map statuses of a shuffle into what a reduce task needs to fetch its input.
 *
 * `statuses` holds one element per Stage1 (shuffle map) partition. For each element the result
 * contains `status.location` (the BlockManagerId holding that map output) and
 * `status.getSizeForBlock(reduceId)` (the amount of that output destined for reduce partition
 * `reduceId`).
 *
 * @param shuffleId id of the shuffle (used for error reporting)
 * @param reduceId  index of the reduce (Stage2) partition
 * @param statuses  map statuses of the shuffle; must not be null
 * @return one (location, size) pair per map status, in the same order as `statuses`
 * @throws MetadataFetchFailedException if any individual map status is null
 */
private def convertMapStatuses(
    shuffleId: Int,
    reduceId: Int,
    statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
  assert (statuses != null)
  statuses.map { status =>
    if (status == null) {
      logError("Missing an output location for shuffle " + shuffleId)
      throw new MetadataFetchFailedException(
        shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId)
    } else {
      (status.location, status.getSizeForBlock(reduceId))
    }
  }
}


/**
 * Starts fetching the blocks of this iterator.
 *
 * Remote fetch requests are shuffled randomly and sent while they fit under `maxBytesInFlight`.
 * Each request obtains InputStreams for remote blocks. After that the local blocks are read.
 * Cleanup is registered first so it runs on both task success and failure.
 */
private[this] def initialize(): Unit = {
  context.addTaskCompletionListener(_ => cleanup())

  // Separate the local blocks from the remote ones; the remote ones are queued in random order.
  val remoteRequests = splitLocalRemoteBlocks()
  fetchRequests ++= Utils.randomize(remoteRequests)

  // Another queued request may go out if nothing is in flight yet (then its size does not
  // matter) or if it still fits within maxBytesInFlight.
  def canSendNextRequest: Boolean =
    fetchRequests.nonEmpty &&
      (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)

  while (canSendNextRequest) {
    sendRequest(fetchRequests.dequeue())
  }

  val numFetches = remoteRequests.size - fetchRequests.size
  logInfo(s"Started $numFetches remote fetches in${Utils.getUsedTimeMs(startTime)}")

  // Local blocks are read only after the initial remote requests have been sent.
  fetchLocalBlocks()
  logDebug(s"Got local blocks in ${Utils.getUsedTimeMs(startTime)}")
}


NettyBlockTransferService.fetchBlocks方法通过创建OneForOneBlockFetcher对象并调用OneForOneBlockFetcher.start方法,从远程Shuffle Transfer Service读取数据块。创建OneForOneBlockFetcher对象的时候,这个对象的openMessage成员被设置成OpenBlocks类型,这样远程Shuffle Transfer Service会接收到OpenBlocks消息。代码如下:

/**
 * Fetches a list of blocks from a remote block server in two steps.
 *
 * First, {@link #start()} sends an OpenBlocks RPC built from appId, execId and blockIds; the
 * response must decode to a {@link StreamHandle}. Then one fetchChunk request is issued per
 * chunk of that stream. Chunk i corresponds to blockIds[i]. Each result is reported to the
 * {@link BlockFetchingListener} as a per-block success or failure.
 */
public class OneForOneBlockFetcher {
  private final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class);

  private final TransportClient client;
  private final OpenBlocks openMessage;
  private final String[] blockIds;
  private final BlockFetchingListener listener;
  private final ChunkReceivedCallback chunkCallback;

  // Set once the OpenBlocks RPC succeeds; identifies the stream whose chunks are fetched.
  private StreamHandle streamHandle = null;

  public OneForOneBlockFetcher(
      TransportClient client,
      String appId,
      String execId,
      String[] blockIds,
      BlockFetchingListener listener) {
    this.client = client;
    // The RPC message is of type OpenBlocks, so the remote side receives an OpenBlocks request.
    this.openMessage = new OpenBlocks(appId, execId, blockIds);
    this.blockIds = blockIds;
    this.listener = listener;
    this.chunkCallback = new ChunkCallback();
  }

  /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
  private class ChunkCallback implements ChunkReceivedCallback {
    @Override
    public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
      // On receipt of a chunk, pass it upwards as a block.
      listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
    }

    @Override
    public void onFailure(int chunkIndex, Throwable e) {
      // On receipt of a failure, fail every block from chunkIndex onwards.
      String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length);
      failRemainingBlocks(remainingBlockIds, e);
    }
  }

  /**
   * Begins the fetching process, calling the listener with every block fetched.
   * The given message will be serialized with the Java serializer, and the RPC must return a
   * {@link StreamHandle}. We will send all fetch requests immediately, without throttling.
   *
   * @throws IllegalArgumentException if blockIds is empty
   */
  public void start() {
    if (blockIds.length == 0) {
      throw new IllegalArgumentException("Zero-sized blockIds array");
    }

    // Send the OpenBlocks message to the remote Shuffle Transfer Service.
    client.sendRpc(openMessage.toByteArray(), new RpcResponseCallback() {
      @Override
      public void onSuccess(byte[] response) {
        try {
          streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response);
          logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle);

          // Immediately request all chunks -- we expect that the total size of the request is
          // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
          for (int i = 0; i < streamHandle.numChunks; i++) {
            client.fetchChunk(streamHandle.streamId, i, chunkCallback);
          }
        } catch (Exception e) {
          logger.error("Failed while starting block fetches after success", e);
          failRemainingBlocks(blockIds, e);
        }
      }

      @Override
      public void onFailure(Throwable e) {
        // The OpenBlocks RPC itself failed, so every requested block is reported as failed.
        logger.error("Failed while starting block fetches", e);
        failRemainingBlocks(blockIds, e);
      }
    });
  }

  /** Invokes the "onBlockFetchFailure" callback for every listed block id. */
  private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
    for (String blockId : failedBlockIds) {
      try {
        listener.onBlockFetchFailure(blockId, e);
      } catch (Exception e2) {
        // A failing listener must not stop the remaining blocks from being reported.
        logger.error("Error in block fetch failure callback", e2);
      }
    }
  }
}


 /**
  * Sets up the Netty block transfer service: the RPC handler, the client factory and the
  * block RPC server. SASL bootstraps are installed on both sides only when authentication is
  * enabled.
  */
 override def init(blockDataManager: BlockDataManager): Unit = {
   val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)

   val (serverBootstrap, clientBootstrap): (Option[TransportServerBootstrap], Option[TransportClientBootstrap]) =
     if (authEnabled) {
       (Some(new SaslServerBootstrap(transportConf, securityManager)),
         Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager,
           securityManager.isSaslEncryptionEnabled())))
     } else {
       (None, None)
     }

   transportContext = new TransportContext(transportConf, rpcHandler)
   clientFactory = transportContext.createClientFactory(clientBootstrap.toList)
   // Create the RPC server. createServer goes through TransportContext.createServer, which makes
   // the rpcHandler created above (a NettyBlockRpcServer) the handler of incoming RPCs.
   server = createServer(serverBootstrap.toList)
   appId = conf.getAppId
   logInfo("Server created on " + server.getPort)
 }


/**
 * Handles a block-transfer RPC sent by a remote client.
 *
 *  - OpenBlocks: resolves every requested block id through blockManager.getBlockData, registers
 *    the resulting buffers as one stream and replies with a StreamHandle carrying the stream id
 *    and the number of buffers. The client then fetches the buffers chunk by chunk.
 *  - UploadBlock: stores the uploaded bytes as a block, using the StorageLevel sent as metadata,
 *    and replies with an empty payload.
 *
 * @param client          connection the request arrived on
 * @param messageBytes    serialized BlockTransferMessage
 * @param responseContext callback used to send the reply
 */
override def receive(
    client: TransportClient,
    messageBytes: Array[Byte],
    responseContext: RpcResponseCallback): Unit = {
  val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes)
  logTrace(s"Received request: $message")

  message match {
    case openBlocks: OpenBlocks =>
      // Read each requested block with BlockManager.getBlockData; for shuffle blocks this yields
      // a buffer over the block's segment of the shuffle data file.
      val blocks: Seq[ManagedBuffer] =
        openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
      val streamId = streamManager.registerStream(blocks.iterator)
      logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
      responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray)

    case uploadBlock: UploadBlock =>
      // StorageLevel is serialized as bytes using our JavaSerializer.
      val level: StorageLevel =
        serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata))
      val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
      blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level)
      responseContext.onSuccess(new Array[Byte](0))
  }
}
BlockManager.getBlockData方法读取Shuffle数据文件时,会调用IndexShuffleBlockResolver.getBlockData。该方法首先根据Stage2分区的id(reduceId)从index文件中读出要读取的数据块在shuffle数据文件中的起始偏移量,再读出紧随其后的下一项,即该数据块在shuffle数据文件中的结束偏移量。然后根据上面获得的起始偏移量和结束偏移量,创建读取Shuffle数据文件对应片段的FileSegmentManagedBuffer对象。IndexShuffleBlockResolver.getBlockData代码如下:

/**
 * Returns a buffer over the bytes of one shuffle block.
 *
 * All output of a map task lives in a single data file. The matching index file is a sequence
 * of 8-byte offsets: entry `reduceId` is where the block for that reduce partition starts in
 * the data file, and entry `reduceId + 1` is where it ends.
 */
override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
  val shuffleId = blockId.shuffleId
  val mapId = blockId.mapId
  val indexStream = new DataInputStream(new FileInputStream(getIndexFile(shuffleId, mapId)))
  try {
    ByteStreams.skipFully(indexStream, blockId.reduceId * 8)
    val segmentStart = indexStream.readLong()
    val segmentEnd = indexStream.readLong()
    val dataFile = getDataFile(shuffleId, mapId)
    new FileSegmentManagedBuffer(transportConf, dataFile, segmentStart, segmentEnd - segmentStart)
  } finally {
    indexStream.close()
  }
}

至此,Stage2 Shuffle reduce操作分区的生成分析完成。Stage2生成了它的分区之后,一个崭新的Stage开始执行。


