Spark Shuffle Read过程(基于Spark 1.5.2源码)

1. ShuffledRDD的compute()方法

  /**
   * Computes one reduce partition of this RDD by reading its shuffle output.
   *
   * The first (and only used) dependency is treated as a ShuffleDependency. A reader covering
   * exactly the partition range [split.index, split.index + 1) is obtained from the shuffle
   * manager, and the records it reads are returned as (K, C) pairs.
   *
   * @param split   the partition to compute
   * @param context the running task's context, handed through to the shuffle reader
   * @return an iterator over this partition's (K, C) records
   */
  override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
    val shuffleDep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
    val partitionId = split.index
    val reader = SparkEnv.get.shuffleManager
      .getReader(shuffleDep.shuffleHandle, partitionId, partitionId + 1, context)
    reader.read().asInstanceOf[Iterator[(K, C)]]
  }


2. HashShuffleReader的read()方法

/**
 * Reads this reducer's shuffle data, then aggregates and sorts it as the ShuffleDependency
 * requires.
 *
 *  1. Fetch the map outputs for `startPartition` as a single iterator.
 *  2. If the dependency has an aggregator, merge records by key. When map-side combine was
 *     enabled, the incoming values are already combiners and are merged with
 *     `combineCombinersByKey`. Otherwise raw values are combined with `combineValuesByKey`.
 *     Without an aggregator, map-side combine is rejected via `require`.
 *  3. If the dependency defines a key ordering, sort with an ExternalSorter and add its
 *     spill metrics to the task metrics.
 *
 * @return the (possibly aggregated, possibly sorted) records of this reduce partition
 */
override def read(): Iterator[Product2[K, C]] = {
  val ser = Serializer.getSerializer(dep.serializer)
  val fetched = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)

  val aggregatedIter: Iterator[Product2[K, C]] = dep.aggregator match {
    case Some(agg) if dep.mapSideCombine =>
      // The map side already produced combiners: merge combiners that share a key.
      new InterruptibleIterator(context, agg.combineCombinersByKey(fetched, context))
    case Some(agg) =>
      // The map side emitted raw values: build combiners from them here.
      new InterruptibleIterator(context, agg.combineValuesByKey(fetched, context))
    case None =>
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      // Downstream RDDs currently expect tuples, so rebuild each Product2 as a pair.
      fetched.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
  }

  // With no key ordering the aggregated records are returned as-is; otherwise they are sorted.
  dep.keyOrdering.fold(aggregatedIter) { keyOrd =>
    // If spark.shuffle.spill is disabled, the ExternalSorter won't spill to disk.
    val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
    sorter.insertAll(aggregatedIter)
    context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled)
    context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled)
    sorter.iterator
  }
}

读取数据的主要流程:1. 通过BlockStoreShuffleFetcher.fetch()获取待拉取数据的iterator;2. 如果定义了aggregator,则使用AppendOnlyMap/ExternalAppendOnlyMap做combine(map端已经做过combine时合并各个combiner,否则直接合并value),这个过程和shuffle write一样;3. 如果需要对key排序,则使用ExternalSorter。下面主要讲述如何得到iterator。

3. BlockStoreShuffleFetcher的fetch()方法

/**
 * Fetches every map-output block of one reduce partition of a shuffle and exposes their
 * deserialized records as a single iterator.
 *
 * The map-output statuses are looked up through the MapOutputTracker. Each status index is
 * used as the map id of a ShuffleBlockId for `reduceId`, and the blocks are grouped by the
 * BlockManagerId that holds them. A ShuffleBlockFetcherIterator then fetches them, with at
 * most `spark.reducer.maxSizeInFlight` bytes (default "48m") requested at once.
 *
 * @param shuffleId  id of the shuffle being read
 * @param reduceId   reduce partition whose blocks are fetched
 * @param context    the running task's context
 * @param serializer serializer used to deserialize the fetched blocks
 * @return an interruptible iterator over the records of all fetched blocks. When the
 *         iterator is exhausted, shuffle-read metrics are updated.
 */
def fetch[T](
    shuffleId: Int,
    reduceId: Int,
    context: TaskContext,
    serializer: Serializer)
  : Iterator[T] = {
  logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
  val blockManager = SparkEnv.get.blockManager

  val startTime = System.currentTimeMillis
  // Array of (BlockManagerId, size) entries; an entry's index is used as its map id below.
  val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
  logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
    shuffleId, reduceId, System.currentTimeMillis - startTime))

  // Group (mapId, size) pairs by the block manager that stores them.
  val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
  for (((address, size), index) <- statuses.zipWithIndex) {
    splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
  }

  // Turn every (mapId, size) pair into (ShuffleBlockId for this reduce partition, size).
  val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
    case (address, splits) =>
      (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
  }

  val blockFetcherItr = new ShuffleBlockFetcherIterator(
    context,
    SparkEnv.get.blockManager.shuffleClient,
    blockManager,
    blocksByAddress,
    serializer,
    // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
    SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)

  // NOTE(review): `unpackBlock` is not shown in this excerpt. It is presumably a local helper
  // that unwraps each fetched block into its records (and reports fetch failures); confirm
  // against the full source.
  val itr = blockFetcherItr.flatMap(unpackBlock)

  val completionIter = CompletionIterator[T, Iterator[T]](itr, {
    context.taskMetrics.updateShuffleReadMetrics()
  })

  // The method must return an Iterator[T]. Wrapping the completion iterator lets task
  // cancellation interrupt the read. (The full upstream source also counts records read here.)
  new InterruptibleIterator[T](context, completionIter)
}


4. ShuffleBlockFetcherIterator的initialize()方法

/**
 * Starts fetching all blocks this iterator is responsible for.
 *
 * Steps, in order:
 *  1. register a task-completion listener that calls cleanup();
 *  2. split the blocks into local blocks and remote fetch requests;
 *  3. queue the remote requests in random order;
 *  4. send queued requests while they fit within maxBytesInFlight;
 *  5. read the local blocks.
 */
private[this] def initialize(): Unit = {
  // Add a task completion callback (called in both success case and failure case) to cleanup.
  context.addTaskCompletionListener(_ => cleanup())

  // Split local and remote blocks.
  val remoteRequests = splitLocalRemoteBlocks()
  // Add the remote requests into our queue in a random order
  fetchRequests ++= Utils.randomize(remoteRequests)

  // Send out initial requests for blocks, up to our maxBytesInFlight.
  // When nothing is in flight (bytesInFlight == 0), the next request is sent regardless of its
  // size, so a single request larger than maxBytesInFlight cannot block the loop forever.
  while (fetchRequests.nonEmpty &&
    (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
    sendRequest(fetchRequests.dequeue())
  }

  // Counts only the requests sent above. Anything left in fetchRequests is presumably sent
  // later as in-flight fetches complete (not shown in this excerpt).
  val numFetches = remoteRequests.size - fetchRequests.size
  logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))

  // Get Local Blocks.
  // Local blocks are read only after the remote requests have been issued, presumably so that
  // network transfers overlap with local reads.
  fetchLocalBlocks()
  logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}


5. ShuffleBlockFetcherIterator的sendRequest()方法:读取远程数据

/**
 * Sends one remote fetch request and routes its results into the `results` queue.
 *
 * Adds the request's size to `bytesInFlight`, then asks `shuffleClient` to fetch every block of
 * `req` from `req.address`.
 *  - Each successfully fetched buffer is retained and enqueued as a SuccessFetchResult, and
 *    the remote-read metrics are updated. This only happens if the iterator is not a zombie.
 *  - Each failure is logged and enqueued as a FailureFetchResult.
 *
 * @param req the blocks (with their sizes) to fetch from a single remote block manager
 */
private[this] def sendRequest(req: FetchRequest): Unit = {
  logDebug("Sending request for %d blocks (%s) from %s".format(
    req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
  bytesInFlight += req.size

  // So we can look up the size of each blockID. The listener callbacks receive block ids as
  // Strings, hence the toString keys.
  val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
  val blockIds = req.blocks.map(_._1.toString)

  val address = req.address
  shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
    new BlockFetchingListener {
      override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
        // Only add the buffer to results queue if the iterator is not zombie,
        // i.e. cleanup() has not been called yet.
        if (!isZombie) {
          // Increment the ref count because we need to pass this to a different thread.
          // This needs to be released after use.
          buf.retain()
          results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf))
          shuffleMetrics.incRemoteBytesRead(buf.size)
          shuffleMetrics.incRemoteBlocksFetched(1)
        }
        logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
      }

      override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
        logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
        results.put(new FailureFetchResult(BlockId(blockId), e))
      }
    }
  )
}

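shuffleClient通过fetchBlocks()方法读取远程数据。ShuffleClient有两个子类,分别是ExternalShuffleClient和BlockTransferService;BlockTransferService也有两个子类,分别是NettyBlockTransferService和NioBlockTransferService。启用external shuffle service(spark.shuffle.service.enabled=true)时,BlockManager使用ExternalShuffleClient;否则直接使用BlockTransferService。Spark 1.5.2中已经将NioBlockTransferService设置为deprecated,它将在后续版本中被移除。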

6. ExternalShuffleClient的fetchBlocks()方法

  /**
   * Fetches the given blocks of executor {@code execId} from the external shuffle server at
   * {@code host:port}.
   *
   * Each attempt creates a TransportClient and starts a OneForOneBlockFetcher. When
   * {@code conf.maxIORetries() > 0}, attempts are driven by a RetryingBlockFetcher; otherwise a
   * single attempt is made directly. If starting the fetch throws, every requested block is
   * reported to {@code listener} as failed with that exception.
   *
   * @param host     host of the remote shuffle server
   * @param port     port of the remote shuffle server
   * @param execId   id of the executor whose blocks are requested
   * @param blockIds ids of the blocks to fetch
   * @param listener receives one success or failure callback per block
   */
  public void fetchBlocks(
      final String host,
      final int port,
      final String execId,
      String[] blockIds,
      BlockFetchingListener listener) {
    // appId must have been set by init(); Java asserts only fire when enabled (-ea).
    assert appId != null : "Called before init()";
    logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
    try {
      // One fetch attempt: connect to host:port and start fetching the given blocks.
      // The retrying fetcher may invoke this with a subset of the original block ids.
      RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
        new RetryingBlockFetcher.BlockFetchStarter() {
          @Override
          public void createAndStart(String[] blockIds, BlockFetchingListener listener)
              throws IOException {
            TransportClient client = clientFactory.createClient(host, port);
            new OneForOneBlockFetcher(client, appId, execId, blockIds, listener).start();
          }
        };

      int maxRetries = conf.maxIORetries();
      if (maxRetries > 0) {
        // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
        // a bug in this code. We should remove the if statement once we're sure of the stability.
        new RetryingBlockFetcher(conf, blockFetchStarter, blockIds, listener).start();
      } else {
        blockFetchStarter.createAndStart(blockIds, listener);
      }
    } catch (Exception e) {
      logger.error("Exception while beginning fetchBlocks", e);
      // Report the startup failure for every requested block so the caller sees one
      // callback per block.
      for (String blockId : blockIds) {
        listener.onBlockFetchFailure(blockId, e);
      }
    }
  }

7. OneForOneBlockFetcher的start()方法

/**
 * Starts fetching the blocks in {@code blockIds}.
 *
 * Sends {@code openMessage} as an RPC. On a successful reply, the reply is decoded as a
 * StreamHandle, and every chunk of that stream (0 to numChunks - 1) is requested immediately,
 * with results delivered to {@code chunkCallback}. If the RPC fails, or the reply cannot be
 * decoded or processed, all blocks are failed via {@code failRemainingBlocks}.
 *
 * @throws IllegalArgumentException if {@code blockIds} is empty
 */
public void start() {
  if (blockIds.length == 0) {
    throw new IllegalArgumentException("Zero-sized blockIds array");
  }

  client.sendRpc(openMessage.toByteArray(), new RpcResponseCallback() {
    @Override
    public void onSuccess(byte[] response) {
      try {
        // The open-blocks reply identifies the stream and how many chunks it has.
        streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response);
        logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle);

        // Immediately request all chunks -- we expect that the total size of the request is
        // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
        for (int i = 0; i < streamHandle.numChunks; i++) {
          client.fetchChunk(streamHandle.streamId, i, chunkCallback);
        }
      } catch (Exception e) {
        logger.error("Failed while starting block fetches after success", e);
        failRemainingBlocks(blockIds, e);
      }
    }

    @Override
    public void onFailure(Throwable e) {
      logger.error("Failed while starting block fetches", e);
      failRemainingBlocks(blockIds, e);
    }
  });
}

