Spark 2.0 Source Code, Part 2: TorrentBroadcast

This series reads the Spark core sources through the comments in the source itself. The listing below is org.apache.spark.broadcast.TorrentBroadcast from Spark 2.0.
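Before reading the implementation, it helps to see the user-facing API that TorrentBroadcast sits behind. The driver program below is a minimal sketch of my own (the app name, master URL, and sample data are illustrative, not part of the Spark source): sc.broadcast(...) constructs a TorrentBroadcast whose constructor chunks and stores the value via writeBlocks(), and the first bc.value access on an executor reassembles it via readBroadcastBlock().

import org.apache.spark.{SparkConf, SparkContext}

object BroadcastExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("broadcast-demo").setMaster("local[2]"))

    // sc.broadcast(...) builds a TorrentBroadcast on the driver; its constructor
    // calls writeBlocks(obj), chunking the serialized map into 4 MB pieces.
    val lookup = Map("a" -> 1, "b" -> 2)
    val bc = sc.broadcast(lookup)

    // The first bc.value on each executor runs readBroadcastBlock(), fetching
    // pieces locally or remotely and caching the reassembled object.
    val total = sc.parallelize(Seq("a", "b", "a"))
      .map(w => bc.value.getOrElse(w, 0))
      .sum()
    println(total)  // 4.0

    bc.destroy()  // ends up in TorrentBroadcast.unpersist(id, removeFromDriver = true, ...)
    sc.stop()
  }
}

With that in mind, here is the full source of the class: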
package org.apache.spark.broadcast

import java.io._
import java.nio.ByteBuffer
import java.util.zip.Adler32

import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import scala.util.Random

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BlockId, BroadcastBlockId, StorageLevel}
import org.apache.spark.util.{ByteBufferInputStream, Utils}
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

/**
 * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
 *
 * The mechanism is as follows:
 *
 * The driver divides the serialized object into small chunks and
 * stores those chunks in the BlockManager of the driver.
 *
 * On each executor, the executor first attempts to fetch the object from its BlockManager. If
 * it does not exist, it then uses remote fetches to fetch the small chunks from the driver and/or
 * other executors if available. Once it gets the chunks, it puts the chunks in its own
 * BlockManager, ready for other executors to fetch from.
 *
 * This prevents the driver from being the bottleneck in sending out multiple copies of the
 * broadcast data (one per executor).
 *
 * When initialized, TorrentBroadcast objects read SparkEnv.get.conf.
 *
 * @param obj object to broadcast
 * @param id A unique identifier for the broadcast variable.
 */
private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
  extends Broadcast[T](id) with Logging with Serializable {

  /**
   * Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]],
   * which builds this value by reading blocks from the driver and/or other executors.
   *
   * On the driver, if the value is required, it is read lazily from the block manager.
   */
  @transient private lazy val _value: T = readBroadcastBlock()

  /** The compression codec to use, or None if compression is disabled */
  @transient private var compressionCodec: Option[CompressionCodec] = _

  /** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */
  @transient private var blockSize: Int = _

  private def setConf(conf: SparkConf) {
    compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) {
      Some(CompressionCodec.createCodec(conf))
    } else {
      None
    }
    // Note: use getSizeAsKb (not bytes) to maintain compatibility if no units are provided
    blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024
    checksumEnabled = conf.getBoolean("spark.broadcast.checksum", true)
  }
  setConf(SparkEnv.get.conf)
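
  // Tuning note added for this walk-through (not a comment from the Spark
  // source): both keys read in setConf() are real Spark settings and can be
  // overridden on the SparkConf used to build the context, e.g.
  //
  //   new SparkConf()
  //     .set("spark.broadcast.blockSize", "8m")    // larger pieces, fewer fetches
  //     .set("spark.broadcast.compress", "false")  // skip codec wrapping
  //
  // Because the size is parsed with getSizeAsKb, a unitless value is
  // interpreted as kilobytes.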

  private val broadcastId = BroadcastBlockId(id)

  /** Total number of blocks this broadcast variable contains. */
  private val numBlocks: Int = writeBlocks(obj)

  /** Whether to generate checksum for blocks or not. */
  private var checksumEnabled: Boolean = false

  /** The checksum for all the blocks. */
  private var checksums: Array[Int] = _

  override protected def getValue() = {
    _value
  }

  private def calcChecksum(block: ByteBuffer): Int = {
    val adler = new Adler32()
    if (block.hasArray) {
      adler.update(block.array, block.arrayOffset + block.position, block.limit - block.position)
    } else {
      val bytes = new Array[Byte](block.remaining())
      block.duplicate.get(bytes)
      adler.update(bytes)
    }
    adler.getValue.toInt
  }

  /**
   * Divide the object into multiple blocks and put those blocks in the block manager.
   *
   * @param value the object to divide
   * @return number of blocks this broadcast variable is divided into
   */
  private def writeBlocks(value: T): Int = {
    import StorageLevel._
    // Store a copy of the broadcast variable in the driver so that tasks run on the driver
    // do not create a duplicate copy of the broadcast variable's value.
    val blockManager = SparkEnv.get.blockManager
    if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
      throw new SparkException(s"Failed to store $broadcastId in BlockManager")
    }
    val blocks =
      TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
    if (checksumEnabled) {
      checksums = new Array[Int](blocks.length)
    }
    blocks.zipWithIndex.foreach { case (block, i) =>
      if (checksumEnabled) {
        checksums(i) = calcChecksum(block)
      }
      val pieceId = BroadcastBlockId(id, "piece" + i)
      val bytes = new ChunkedByteBuffer(block.duplicate())
      if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
        throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
      }
    }
    blocks.length
  }
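
  // Worked example added for this walk-through (the numbers are illustrative,
  // not from the Spark source): a value that serializes to 10 MB with the
  // default 4 MB blockSize yields ceil(10 / 4) = 3 pieces, stored as
  //   broadcast_<id>_piece0, broadcast_<id>_piece1, broadcast_<id>_piece2
  // at MEMORY_AND_DISK_SER. Note the asymmetry in tellMaster: the pieces are
  // reported to the driver (tellMaster = true) so the BlockManagerMaster can
  // point other executors at every cached copy, while the driver's own full
  // copy stays private (tellMaster = false) because it is never fetched as a
  // piece.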

  /** Fetch torrent blocks from the driver and/or other executors. */
  private def readBlocks(): Array[ChunkedByteBuffer] = {
    // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
    // to the driver, so other executors can pull these chunks from this executor as well.
    val blocks = new Array[ChunkedByteBuffer](numBlocks)
    val bm = SparkEnv.get.blockManager

    for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
      val pieceId = BroadcastBlockId(id, "piece" + pid)
      logDebug(s"Reading piece $pieceId of $broadcastId")
      // First try getLocalBytes because there is a chance that previous attempts to fetch the
      // broadcast blocks have already fetched some of the blocks. In that case, some blocks
      // would be available locally (on this executor).
      bm.getLocalBytes(pieceId) match {
        case Some(block) =>
          blocks(pid) = block
          releaseLock(pieceId)
        case None =>
          bm.getRemoteBytes(pieceId) match {
            case Some(b) =>
              if (checksumEnabled) {
                val sum = calcChecksum(b.chunks(0))
                if (sum != checksums(pid)) {
                  throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" +
                    s" $sum != ${checksums(pid)}")
                }
              }
              // We found the block from remote executors/driver's BlockManager, so put the block
              // in this executor's BlockManager.
              if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
                throw new SparkException(
                  s"Failed to store $pieceId of $broadcastId in local BlockManager")
              }
              blocks(pid) = b
            case None =>
              throw new SparkException(s"Failed to get $pieceId of $broadcastId")
          }
      }
    }
    blocks
  }
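
  // Note added for this walk-through: the Random.shuffle above is what makes
  // this "torrent-like". Each executor visits the piece indices in a different
  // order, so concurrent fetchers pull distinct pieces first and can then serve
  // those pieces to one another (every fetched piece is re-stored with
  // tellMaster = true), instead of all executors queueing on the driver for
  // piece 0, then piece 1, and so on. This reasoning is implied by the class
  // scaladoc rather than stated at this line in the Spark source.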

  /**
   * Remove all persisted state associated with this Torrent broadcast on the executors.
   */
  override protected def doUnpersist(blocking: Boolean) {
    TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking)
  }

  /**
   * Remove all persisted state associated with this Torrent broadcast on the executors
   * and driver.
   */
  override protected def doDestroy(blocking: Boolean) {
    TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
  }

  /** Used by the JVM when serializing this object. */
  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    assertValid()
    out.defaultWriteObject()
  }

  private def readBroadcastBlock(): T = Utils.tryOrIOException {
    TorrentBroadcast.synchronized {
      setConf(SparkEnv.get.conf)
      val blockManager = SparkEnv.get.blockManager
      blockManager.getLocalValues(broadcastId).map(_.data.next()) match {
        case Some(x) =>
          releaseLock(broadcastId)
          x.asInstanceOf[T]

        case None =>
          logInfo("Started reading broadcast variable " + id)
          val startTimeMs = System.currentTimeMillis()
          val blocks = readBlocks().flatMap(_.getChunks())
          logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))

          val obj = TorrentBroadcast.unBlockifyObject[T](
            blocks, SparkEnv.get.serializer, compressionCodec)
          // Store the merged copy in BlockManager so other tasks on this executor don't
          // need to re-fetch it.
          val storageLevel = StorageLevel.MEMORY_AND_DISK
          if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
            throw new SparkException(s"Failed to store $broadcastId in BlockManager")
          }
          obj
      }
    }
  }

  /**
   * If running in a task, register the given block's locks for release upon task completion.
   * Otherwise, if not running in a task then immediately release the lock.
   */
  private def releaseLock(blockId: BlockId): Unit = {
    val blockManager = SparkEnv.get.blockManager
    Option(TaskContext.get()) match {
      case Some(taskContext) =>
        taskContext.addTaskCompletionListener(_ => blockManager.releaseLock(blockId))
      case None =>
        // This should only happen on the driver, where broadcast variables may be accessed
        // outside of running tasks (e.g. when computing rdd.partitions()). In order to allow
        // broadcast variables to be garbage collected we need to free the reference here
        // which is slightly unsafe but is technically okay because broadcast variables aren't
        // stored off-heap.
        blockManager.releaseLock(blockId)
    }
  }
}

private object TorrentBroadcast extends Logging {

  def blockifyObject[T: ClassTag](
      obj: T,
      blockSize: Int,
      serializer: Serializer,
      compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = {
    val cbbos = new ChunkedByteBufferOutputStream(blockSize, ByteBuffer.allocate)
    val out = compressionCodec.map(c => c.compressedOutputStream(cbbos)).getOrElse(cbbos)
    val ser = serializer.newInstance()
    val serOut = ser.serializeStream(out)
    Utils.tryWithSafeFinally {
      serOut.writeObject[T](obj)
    } {
      serOut.close()
    }
    cbbos.toChunkedByteBuffer.getChunks()
  }

  def unBlockifyObject[T: ClassTag](
      blocks: Array[ByteBuffer],
      serializer: Serializer,
      compressionCodec: Option[CompressionCodec]): T = {
    require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
    val is = new SequenceInputStream(
      blocks.iterator.map(new ByteBufferInputStream(_)).asJavaEnumeration)
    val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
    val ser = serializer.newInstance()
    val serIn = ser.deserializeStream(in)
    val obj = Utils.tryWithSafeFinally {
      serIn.readObject[T]()
    } {
      serIn.close()
    }
    obj
  }

  /**
   * Remove all persisted blocks associated with this torrent broadcast on the executors.
   * If removeFromDriver is true, also remove these persisted blocks on the driver.
   */
  def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = {
    logDebug(s"Unpersisting TorrentBroadcast $id")
    SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
  }
}
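The companion object's blockifyObject / unBlockifyObject pair is just a chunked serialize-then-reassemble round trip: serialize into fixed-size buffers on the way out, stitch the buffers back into one stream with a SequenceInputStream on the way in. The self-contained sketch below (plain Scala with JDK serialization; all names are mine, and it skips Spark's ChunkedByteBuffer and compression) mirrors that round trip:

import java.io._
import java.nio.ByteBuffer
import scala.collection.JavaConverters._

object BlockifySketch {
  // Serialize with plain JDK serialization, then slice the bytes into
  // fixed-size chunks, mirroring blockifyObject's chunked output stream.
  def blockify(obj: AnyRef, blockSize: Int): Array[ByteBuffer] = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    try oos.writeObject(obj) finally oos.close()
    bos.toByteArray.grouped(blockSize).map(a => ByteBuffer.wrap(a)).toArray
  }

  // Stitch the chunks back together with a SequenceInputStream, the same
  // trick unBlockifyObject uses, and deserialize the result.
  def unBlockify(blocks: Array[ByteBuffer]): AnyRef = {
    val streams = blocks.iterator.map { b =>
      val bytes = new Array[Byte](b.remaining())
      b.duplicate().get(bytes)
      new ByteArrayInputStream(bytes): InputStream
    }
    val in = new ObjectInputStream(new SequenceInputStream(streams.asJavaEnumeration))
    try in.readObject() finally in.close()
  }

  def main(args: Array[String]): Unit = {
    val chunks = blockify(Vector.fill(100000)("x"), blockSize = 4096)
    println(s"${chunks.length} chunks")  // the serialized vector split into 4 KB pieces
    println(unBlockify(chunks).asInstanceOf[Vector[String]].length)  // 100000
  }
}

Unlike this sketch, the real blockifyObject never materializes one contiguous byte array: it serializes directly into a ChunkedByteBufferOutputStream that allocates blockSize-sized buffers as it goes, which matters when broadcast values run to hundreds of megabytes.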