Spark源码阅读笔记之Broadcast(二)

来源:互联网 发布:网络流行词汇 编辑:程序博客网 时间:2024/05/22 11:49

Broadcast的Http传输机制是通过HttpBroadcastFactory和HttpBroadcast来实现的。

HttpBroadcastFactory代码:

/**
 * [[BroadcastFactory]] backed by HTTP transport: it creates [[HttpBroadcast]] instances and
 * forwards every lifecycle call (initialize / stop / unbroadcast) to the `HttpBroadcast`
 * companion object.
 */
class HttpBroadcastFactory extends BroadcastFactory {

  /** Forwards to `HttpBroadcast.initialize` with the same arguments. */
  override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit =
    HttpBroadcast.initialize(isDriver, conf, securityMgr)

  /** Wraps `value_` in a new [[HttpBroadcast]] identified by `id`. */
  override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) =
    new HttpBroadcast[T](value_, isLocal, id)

  /** Forwards to `HttpBroadcast.stop`. */
  override def stop(): Unit = HttpBroadcast.stop()

  /**
   * Drops every piece of persisted state tied to the HTTP broadcast `id`.
   *
   * @param id the broadcast to remove
   * @param removeFromDriver whether the driver-side state is removed as well
   * @param blocking whether to wait until the removal has completed
   */
  override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit =
    HttpBroadcast.unpersist(id, removeFromDriver, blocking)
}

HttpBroadcastFactory的initialize函数调用HttpBroadcast.initialize函数,主要完成的工作是读取相关配置;在Driver端还会根据配置创建根目录,并以该根目录启动Http服务。HttpBroadcast.initialize函数代码:

/**
 * One-time setup of the HTTP broadcast machinery; calls after the first are no-ops.
 *
 * Reads the buffer-size and compression settings and stores the security manager. On the
 * driver it also starts the HTTP server and records the server URI in `conf` under
 * "spark.httpBroadcast.uri". Every caller then reads `serverUri` back from that key
 * (presumably executors receive it through their SparkConf — confirm). Finally it creates the
 * metadata cleaner and the compression codec.
 */
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit = synchronized {
  if (!initialized) {
    compress = conf.getBoolean("spark.broadcast.compress", true)
    bufferSize = conf.getInt("spark.buffer.size", 65536)
    // Must be set before createServer, which hands it to the HttpServer.
    securityManager = securityMgr
    if (isDriver) {
      createServer(conf)
      conf.set("spark.httpBroadcast.uri", serverUri)
    }
    serverUri = conf.get("spark.httpBroadcast.uri")
    cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf)
    compressionCodec = CompressionCodec.createCodec(conf)
    initialized = true
  }
}

HttpBroadcast.initialize调用的HttpBroadcast.createServer函数代码:

/**
 * Creates a temporary "broadcast" directory under the local dir and starts an HTTP server
 * rooted at it.
 *
 * The port comes from "spark.broadcast.port" (0 when unset). The resulting server, directory
 * and URI are stored in `server`, `broadcastDir` and `serverUri`.
 */
private def createServer(conf: SparkConf): Unit = {
  broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf), "broadcast")
  val broadcastPort = conf.getInt("spark.broadcast.port", 0)
  val httpServer =
    new HttpServer(conf, broadcastDir, securityManager, broadcastPort, "HTTP broadcast server")
  server = httpServer
  httpServer.start()
  serverUri = httpServer.uri
  logInfo(s"Broadcast server started at $serverUri")
}

其中HttpServer类是对Jetty server的封装。

HttpBroadcastFactory的unbroadcast函数调用HttpBroadcast.unpersist函数,主要完成的逻辑是删除各个节点存储的Broadcast,并根据removeFromDriver参数判断是否要删除Http服务根目录下存储的对应文件。HttpBroadcast.unpersist函数代码:

/**
 * Asks the block manager master to remove broadcast `id` (driver included only when
 * `removeFromDriver` is true).
 *
 * When `removeFromDriver` is set, the served broadcast file is also dropped from `files`
 * and deleted from disk.
 *
 * @param id the broadcast to remove
 * @param removeFromDriver whether the driver's copy and the served file are removed too
 * @param blocking whether the master call waits for completion
 */
def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = synchronized {
  val master = SparkEnv.get.blockManager.master
  master.removeBroadcast(id, removeFromDriver, blocking)
  if (removeFromDriver) {
    val broadcastFile = getFile(id)
    files.remove(broadcastFile)
    deleteBroadcastFile(broadcastFile)
  }
}

调用的HttpBroadcast.deleteBroadcastFile函数代码:

/**
 * Best-effort deletion of a broadcast file.
 *
 * A file that does not exist is silently skipped. A failed delete is logged as a warning.
 * Any `Exception` thrown along the way is logged as an error instead of being propagated.
 */
private def deleteBroadcastFile(file: File): Unit = {
  try {
    if (file.exists) {
      val removed = file.delete()
      if (removed) logInfo("Deleted broadcast file: %s".format(file))
      else logWarning("Could not delete broadcast file: %s".format(file))
    }
  } catch {
    case e: Exception =>
      logError("Exception while deleting broadcast file: %s".format(file), e)
  }
}

分析HttpBroadcast时需要注意两点:1、缓存机制;2、序列化和反序列化机制。先来看HttpBroadcast的代码:

/**
 * A broadcast variable whose value is shipped to other nodes over HTTP.
 *
 * Construction (presumably on the driver) does two things. It caches the value in the local
 * BlockManager without telling the master. Unless `isLocal`, it also serializes the value to
 * a file under HttpBroadcast's HTTP server directory via `HttpBroadcast.write`.
 *
 * `value_` is `@transient`, so Java serialization of this object does not carry the value.
 * `readObject` restores it, either from the local BlockManager or by downloading it with
 * `HttpBroadcast.read`.
 *
 * @param value_ the broadcast value; transient, so not part of the serialized form
 * @param isLocal when true, no HTTP file is written for this broadcast
 * @param id broadcast id; determines the BlockManager block id and the served file name
 */
private[spark] class HttpBroadcast[T: ClassTag](
    @transient var value_ : T, isLocal: Boolean, id: Long)
  extends Broadcast[T](id) with Logging with Serializable {

  override protected def getValue() = value_

  private val blockId = BroadcastBlockId(id)

  /*
   * Broadcasted data is also stored in the BlockManager of the node creating it. The
   * BlockManagerMaster is not told about this block (tellMaster = false): other nodes fetch
   * the value over HTTP, never through the BlockManager.
   */
  HttpBroadcast.synchronized {
    SparkEnv.get.blockManager.putSingle(
      blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
  }

  // Publish the value as a file on the HTTP server so remote readers can download it.
  if (!isLocal) {
    HttpBroadcast.write(id, value_)
  }

  /**
   * Remove all persisted state associated with this HTTP broadcast on the executors.
   * The driver's copy and the served file are kept (removeFromDriver = false).
   */
  override protected def doUnpersist(blocking: Boolean) {
    HttpBroadcast.unpersist(id, removeFromDriver = false, blocking)
  }

  /**
   * Remove all persisted state associated with this HTTP broadcast on the executors and the
   * driver, including the served broadcast file (removeFromDriver = true).
   */
  override protected def doDestroy(blocking: Boolean) {
    HttpBroadcast.unpersist(id, removeFromDriver = true, blocking)
  }

  /**
   * Used by the JVM when serializing this object. `assertValid()` (defined in Broadcast, not
   * shown here) presumably rejects a destroyed broadcast. Then only the non-transient fields
   * are written; `value_` is skipped.
   */
  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    assertValid()
    out.defaultWriteObject()
  }

  /**
   * Used by the JVM when deserializing this object. It restores `value_`, which default
   * deserialization leaves unset because the field is transient.
   *
   * The lookup, fetch and cache all run while holding the HttpBroadcast lock. Concurrent
   * deserializations in the same JVM are therefore serialized, and later ones find the value
   * already cached.
   */
  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    HttpBroadcast.synchronized {
      SparkEnv.get.blockManager.getSingle(blockId) match {
        // Cache hit: the value is already in this node's BlockManager.
        case Some(x) => value_ = x.asInstanceOf[T]
        // Cache miss: download from the HTTP server, then cache locally.
        case None => {
          logInfo("Started reading broadcast variable " + id)
          val start = System.nanoTime
          value_ = HttpBroadcast.read[T](id)
          /*
           * We cache broadcast data in the BlockManager so that subsequent tasks using it
           * do not need to re-fetch. This data is only used locally and no other node
           * needs to fetch this block, so we don't notify the master.
           */
          SparkEnv.get.blockManager.putSingle(
            blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
          val time = (System.nanoTime - start) / 1e9
          logInfo("Reading broadcast variable " + id + " took " + time + " s")
        }
      }
    }
  }
}

HttpBroadcast在Driver中初始化时(调用SparkContext的broadcast函数),若isLocal为false,会调用HttpBroadcast.write函数将Broadcast中的数据写入到Http服务根目录下,供其他的Executor下载。HttpBroadcast.write函数代码:

/**
 * Serializes `value` into the broadcast file for `id` (see `getFile`) so the HTTP server can
 * serve it.
 *
 * The stream is wrapped with the compression codec when `compress` is set, otherwise with a
 * `BufferedOutputStream` of `bufferSize` bytes. The file is added to `files` only after it
 * has been written successfully.
 */
private def write(id: Long, value: Any) {
  val file = getFile(id)
  val fileOutputStream = new FileOutputStream(file)
  try {
    val out: OutputStream = {
      if (compress) {
        compressionCodec.compressedOutputStream(fileOutputStream)
      } else {
        new BufferedOutputStream(fileOutputStream, bufferSize)
      }
    }
    val ser = SparkEnv.get.serializer.newInstance()
    val serOut = ser.serializeStream(out)
    // Close the serialization stream (and the codec/buffer wrapper beneath it) even when
    // writeObject fails. Otherwise the wrapper's resources leak; the outer finally only
    // closes the raw file stream.
    try {
      serOut.writeObject(value)
    } finally {
      serOut.close()
    }
    files += file
  } finally {
    // Closing an already-closed FileOutputStream is a no-op, so this is safe after serOut.close().
    fileOutputStream.close()
  }
}

/** Returns the file under `broadcastDir` that holds the serialized value of broadcast `id`. */
def getFile(id: Long): File = new File(broadcastDir, BroadcastBlockId(id).name)

HttpBroadcast序列化时不会序列化需要传输的value(value_字段被标记为@transient),而是序列化该Broadcast的Id。反序列化时,readObject根据Id先调用BlockManager的getSingle函数,从BlockManager中读取该Broadcast的值,若没有则通过Http服务下载该Broadcast对应的文件,然后读取到内存中。每次读取到value后都会调用BlockManager的putSingle函数将该Broadcast缓存到BlockManager中。注意缓存时会设置tellMaster参数为false,即不通知Master,这样Master不知道该节点存储了该Broadcast,从而其他的Executor无法通过BlockManager来获取Broadcast的值,只能通过Http服务来获取。这样就通过BlockManager实现了本地缓存,并用Http服务实现了Broadcast的远程传输。

readObject函数调用HttpBroadcast.read函数,HttpBroadcast.read函数根据Broadcast的Id和Http服务的uri生成该Broadcast对应的url,然后下载文件并读取,代码如下:

/**
 * Downloads and deserializes broadcast `id` from `<serverUri>/<BroadcastBlockId(id).name>`.
 *
 * When authentication is enabled, the URI is rewritten by
 * `Utils.constructURIForAuthentication` and user interaction is disabled on the connection.
 * Both the connect timeout and the read timeout are `httpReadTimeout`. The response is
 * decompressed when `compress` is set, otherwise buffered with `bufferSize`.
 *
 * @return the deserialized broadcast value
 */
private def read[T: ClassTag](id: Long): T = {
  logDebug("broadcast read server: " +  serverUri + " id: broadcast-" + id)
  val url = serverUri + "/" + BroadcastBlockId(id).name
  val uc: URLConnection =
    if (securityManager.isAuthenticationEnabled()) {
      logDebug("broadcast security enabled")
      val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager)
      val conn = newuri.toURL.openConnection()
      conn.setAllowUserInteraction(false)
      conn
    } else {
      logDebug("broadcast not using security")
      new URL(url).openConnection()
    }
  // Both setters act before the connection is opened, so hoisting this out of the branches
  // is equivalent.
  uc.setConnectTimeout(httpReadTimeout)
  Utils.setupSecureURLConnection(uc, securityManager)
  uc.setReadTimeout(httpReadTimeout)
  val in = {
    val inputStream = uc.getInputStream
    if (compress) {
      compressionCodec.compressedInputStream(inputStream)
    } else {
      new BufferedInputStream(inputStream, bufferSize)
    }
  }
  val ser = SparkEnv.get.serializer.newInstance()
  val serIn = ser.deserializeStream(in)
  // Always release the HTTP stream and decompressor, even if deserialization throws.
  try {
    serIn.readObject[T]()
  } finally {
    serIn.close()
  }
}
0 0
原创粉丝点击