Spark Source Code Study Notes 6 - RpcEnv (RPC Implementation Layer)

Source: Internet | Editor: 程序博客网 | Time: 2024/06/08 17:10

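Following Note 5 - RpcEnv (RPC Abstraction Layer), we now look at the implementation layer of the RPC framework.
In the previous note, RpcEnv's create function delegated to NettyRpcEnvFactory's create function.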

  • NettyRpcEnvFactory

    NettyRpcEnvFactory is defined in NettyRpcEnv.scala. Its create function is implemented as follows:

private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {

  def create(config: RpcEnvConfig): RpcEnv = {
    val sparkConf = config.conf
    // Use JavaSerializerInstance in multiple threads is safe. However, if we plan to support
    // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance
    val javaSerializerInstance =
      new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
    val nettyEnv =
      new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress,
        config.securityManager)
    if (!config.clientMode) {
      val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
        nettyEnv.startServer(config.bindAddress, actualPort)
        (nettyEnv, nettyEnv.address.port)
      }
      try {
        Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1
      } catch {
        case NonFatal(e) =>
          nettyEnv.shutdown()
          throw e
      }
    }
    nettyEnv
  }
}

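After NettyRpcEnvFactory creates the NettyRpcEnv, it checks clientMode. If clientMode is false, the env is on the server side (the Driver side of the RPC communication). In that case it wraps the new NettyRpcEnv's startServer call in a function value, startNettyRpcEnv, whose return value is (nettyEnv, nettyEnv.address.port). It then passes that function to Utils.startServiceOnPort, which starts the service on the Driver.
The source of Utils.startServiceOnPort shows why nettyEnv.startServer is wrapped and handed to a utility instead of being called directly. Starting a service on a given port does not always succeed on the first attempt. The utility therefore retries on failure, up to a maximum number of attempts, until the service starts, and then returns the port the service actually bound.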
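The retry idea can be illustrated with a minimal Java sketch. It assumes a simplified policy: try `startPort`, `startPort + 1`, and so on, for at most `maxRetries` extra attempts, retrying only on `java.net.BindException`. The class, interface and method names are hypothetical. Spark's real `Utils.startServiceOnPort` differs in details such as port wrap-around, treating port 0 as "any free port", and (as far as I know) reading its limit from `spark.port.maxRetries`.

```java
import java.net.BindException;

// Hypothetical helper (not Spark's API): a simplified version of the retry idea
// behind Utils.startServiceOnPort.
final class PortRetrySketch {

  /** A start function that returns the bound port, or throws BindException if the port is taken. */
  @FunctionalInterface
  interface PortStarter {
    int start(int port) throws BindException;
  }

  static int startServiceOnPort(int startPort, int maxRetries, PortStarter starter) {
    for (int offset = 0; offset <= maxRetries; offset++) {
      // Simplified: just try the next port; no wrap-around and no special case for port 0.
      int tryPort = startPort + offset;
      try {
        return starter.start(tryPort);
      } catch (BindException e) {
        // Port already in use: fall through and try the next one.
      }
    }
    throw new IllegalStateException(
        "Could not start service after " + (maxRetries + 1) + " attempts from port " + startPort);
  }
}
```

Because the start step is passed in as a function value, the retry loop stays generic. This is presumably why a single utility can serve several kinds of services rather than only the RPC server.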

  • NettyRpcEnv

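Next, the NettyRpcEnv class. It extends RpcEnv and has a companion object. The companion object holds only two values, currentEnv and currentClient. They are used when a NettyRpcEndpointRef is deserialized; I don't fully understand their purpose yet.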

private[netty] object NettyRpcEnv extends Logging {
  /**
   * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]].
   * Use `currentEnv` to wrap the deserialization codes. E.g.,
   *
   * {{{
   *   NettyRpcEnv.currentEnv.withValue(this) {
   *     your deserialization codes
   *   }
   * }}}
   */
  private[netty] val currentEnv = new DynamicVariable[NettyRpcEnv](null)

  /**
   * Similar to `currentEnv`, this variable references the client instance associated with an
   * RPC, in case it's needed to find out the remote address during deserialization.
   */
  private[netty] val currentClient = new DynamicVariable[TransportClient](null)

}
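The following sketch may help with currentEnv and currentClient. Java deserialization of an object such as NettyRpcEndpointRef offers no parameter through which the receiving NettyRpcEnv could be passed. The env doing the deserialization therefore binds itself to a scoped, thread-local variable for the duration of the call, and the deserialization code reads it from there. Scala's `scala.util.DynamicVariable` is backed by an `InheritableThreadLocal`. The minimal Java analogue below assumes that behavior. `DynamicVar`, `DemoEnv` and `DemoRef` are hypothetical names, not Spark classes.

```java
import java.util.function.Supplier;

// Hypothetical Java analogue of scala.util.DynamicVariable (the type of
// NettyRpcEnv.currentEnv / currentClient): a thread-local value that is
// bound only for the duration of a block and then restored.
final class DynamicVar<T> {
  private final InheritableThreadLocal<T> tl;

  DynamicVar(T init) {
    tl = new InheritableThreadLocal<T>() {
      @Override
      protected T initialValue() {
        return init;
      }
    };
  }

  T value() {
    return tl.get();
  }

  /** Runs body with the value bound to v on the current thread, then restores the old value. */
  <R> R withValue(T v, Supplier<R> body) {
    T old = tl.get();
    tl.set(v);
    try {
      return body.get();
    } finally {
      tl.set(old);
    }
  }
}

// Hypothetical stand-ins: DemoEnv plays NettyRpcEnv; DemoRef.resolveEnv plays the
// deserialization hook of NettyRpcEndpointRef, which cannot take extra parameters.
final class DemoEnv {
  static final DynamicVar<DemoEnv> CURRENT = new DynamicVar<>(null);
  final String name;

  DemoEnv(String name) {
    this.name = name;
  }
}

final class DemoRef {
  static String resolveEnv() {
    DemoEnv env = DemoEnv.CURRENT.value();
    return env == null ? "no env" : env.name;
  }
}
```

Wrapping a call as `DemoEnv.CURRENT.withValue(env, DemoRef::resolveEnv)` lets the hook see `env`. Outside the block, the variable falls back to its initial value `null`.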

Now the companion class NettyRpcEnv (to be continued tomorrow…)
The NettyRpcEnv constructor creates a number of private fields:

package org.apache.spark.rpc.netty
......
private[netty] class NettyRpcEnv(
    val conf: SparkConf,
    javaSerializerInstance: JavaSerializerInstance,
    host: String,
    securityManager: SecurityManager) extends RpcEnv(conf) with Logging {

  private[netty] val transportConf = SparkTransportConf.fromSparkConf(
    conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"),
    "rpc",
    conf.getInt("spark.rpc.io.threads", 0))

  private val dispatcher: Dispatcher = new Dispatcher(this)

  private val streamManager = new NettyStreamManager(this)

  private val transportContext = new TransportContext(transportConf,
    new NettyRpcHandler(dispatcher, this, streamManager))

  private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = {
    if (securityManager.isAuthenticationEnabled()) {
      java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager,
        securityManager.isSaslEncryptionEnabled()))
    } else {
      java.util.Collections.emptyList[TransportClientBootstrap]
    }
  }

  private val clientFactory = transportContext.createClientFactory(createClientBootstraps())

  /**
   * A separate client factory for file downloads. This avoids using the same RPC handler as
   * the main RPC context, so that events caused by these clients are kept isolated from the
   * main RPC traffic.
   *
   * It also allows for different configuration of certain properties, such as the number of
   * connections per peer.
   */
  @volatile private var fileDownloadFactory: TransportClientFactory = _

  val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout")

  // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool
  // to implement non-blocking send/ask.
  // TODO: a non-blocking TransportClientFactory.createClient in future
  private[netty] val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
    "netty-rpc-connection",
    conf.getInt("spark.rpc.connect.threads", 64))

  @volatile private var server: TransportServer = _

  private val stopped = new AtomicBoolean(false)

  /**
   * A map for [[RpcAddress]] and [[Outbox]]. When we are connecting to a remote [[RpcAddress]],
   * we just put messages to its [[Outbox]] to implement a non-blocking `send` method.
   */
  private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]()
  ......
  ......
}
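According to the comment on the outboxes field, a non-blocking `send` just places the message in the Outbox for the remote address. The "one queue per remote address" idea can be sketched in Java as follows. Every name here is hypothetical, and the real Outbox also manages its TransportClient and the connection task. My reading is that the connection is opened on clientConnectionExecutor, but that part is omitted here.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;

// Hypothetical sketch of the "one Outbox per remote address" idea behind
// NettyRpcEnv.outboxes: send() only enqueues, so the caller never blocks on
// connection setup. Names are illustrative, not Spark's API.
final class OutboxSketch {
  static final class Outbox {
    final String address;
    final ConcurrentLinkedQueue<String> pending = new ConcurrentLinkedQueue<>();

    Outbox(String address) {
      this.address = address;
    }
  }

  private final ConcurrentHashMap<String, Outbox> outboxes = new ConcurrentHashMap<>();

  /** Non-blocking send: find or atomically create the outbox for the address, then enqueue. */
  void send(String remoteAddress, String message) {
    Outbox box = outboxes.computeIfAbsent(remoteAddress, Outbox::new);
    box.pending.add(message);
    // A real implementation would now make sure a connection is being opened
    // asynchronously and drain the queue once the client is ready.
  }

  /** Drains what a connected client would write for one address, in FIFO order. */
  List<String> drain(String remoteAddress) {
    List<String> out = new ArrayList<>();
    Outbox box = outboxes.get(remoteAddress);
    if (box == null) {
      return out;
    }
    String m;
    while ((m = box.pending.poll()) != null) {
      out.add(m);
    }
    return out;
  }

  int outboxCount() {
    return outboxes.size();
  }
}
```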

First, let's look at the member variables dispatcher, streamManager, transportContext, clientFactory, fileDownloadFactory, clientConnectionExecutor and server.

  • dispatcher: Dispatcher
private val dispatcher: Dispatcher = new Dispatcher(this)

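Dispatcher is a message dispatcher responsible for routing RPC messages to the appropriate endpoint(s). Its inner class EndpointData bundles an endpoint, that endpoint's reference, and its Inbox. Dispatcher holds three private, endpoint-related collections: endpoints, endpointRefs and receivers. Its member function registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef registers an endpoint with the Dispatcher by adding it to those three collections, and returns the newly created endpoint reference. Note that the listening address inside the returned reference is not taken from the arguments. It is built as RpcEndpointAddress(nettyEnv.address, name), that is, the local RpcEnv's own address plus the endpoint name. Every reference returned here therefore points at the same local RpcEnv, and references are distinguished only by name. The send and ask functions of NettyRpcEnv show why: the dispatcher only handles messages destined for local endpoints, while messages for remote endpoints are put into an Outbox.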
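The registration step described above can be sketched in Java. This is a simplified model with hypothetical names: the endpoint reference is reduced to an address string. I assume the `spark://name@host:port` form, which as I understand it is how RpcEndpointAddress renders. The sketch shows three points. The address comes from the env rather than the caller, duplicate names are rejected with `putIfAbsent`, and the new entry is queued in `receivers` so that a message loop picks up its OnStart message.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;

// Hypothetical, simplified model of Dispatcher.registerRpcEndpoint (not Spark's API).
final class RegisterSketch {
  private final String envAddress; // this RpcEnv's own host:port
  private final ConcurrentHashMap<String, String> endpoints = new ConcurrentHashMap<>();
  final LinkedBlockingQueue<String> receivers = new LinkedBlockingQueue<>();
  private boolean stopped = false; // guarded by this

  RegisterSketch(String envAddress) {
    this.envAddress = envAddress;
  }

  /** Returns the new endpoint's address string; rejects duplicate names. */
  synchronized String register(String name) {
    if (stopped) {
      throw new IllegalStateException("RpcEnv has been stopped");
    }
    // The address is the env's address plus the name, never something the caller passes in.
    String ref = "spark://" + name + "@" + envAddress;
    if (endpoints.putIfAbsent(name, ref) != null) {
      throw new IllegalArgumentException("There is already an RpcEndpoint called " + name);
    }
    receivers.offer(name); // so that a message loop will process the OnStart message
    return ref;
  }

  synchronized void stop() {
    stopped = true;
  }
}
```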

Dispatcher also provides functions to get, remove and unregister RpcEndpointRefs.
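The private function postMessage sends a message to a specific endpoint. It puts the InboxMessage into that endpoint's Inbox and also puts the endpoint's EndpointData into receivers, which tracks endpoints that have pending messages.
postMessage is called by the public functions postToAll, postRemoteMessage, postLocalMessage and postOneWayMessage:
- postToAll sends the message to every registered endpoint.
- postRemoteMessage wraps the RequestMessage and an RpcResponseCallback into an RpcMessage with a RemoteNettyRpcCallContext, and puts it into the receiver's inbox.
- postLocalMessage is similar, except that its RpcCallContext is a LocalNettyRpcCallContext backed by a Promise.
- postOneWayMessage turns the RequestMessage into a OneWayMessage, which has no RpcCallContext, and puts it into the receiver's inbox.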
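The shared core of these post functions can be sketched in Java as follows. Under a lock, it decides whether the Dispatcher is stopped, whether the endpoint is missing, or whether the message can be enqueued. Any failure callback is invoked only after the lock is released. In Spark, each caller supplies its own callback: postRemoteMessage uses `callback.onFailure`, postLocalMessage uses `p.tryFailure`, postOneWayMessage rethrows, and postToAll logs a warning. The class name and exception types below are hypothetical stand-ins, and this simplified inbox is a plain deque guarded by the sketch's own lock rather than a separate Inbox object.

```java
import java.util.ArrayDeque;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.Consumer;

// Hypothetical sketch of Dispatcher.postMessage: decide the outcome under a
// lock, but invoke the failure callback outside it.
final class PostSketch {
  static final class EndpointData {
    final String name;
    final ArrayDeque<Object> inbox = new ArrayDeque<>(); // guarded by the PostSketch lock

    EndpointData(String name) {
      this.name = name;
    }
  }

  final Map<String, EndpointData> endpoints = new ConcurrentHashMap<>();
  final LinkedBlockingQueue<EndpointData> receivers = new LinkedBlockingQueue<>();
  private boolean stopped = false; // guarded by this

  void postMessage(String endpointName, Object message, Consumer<Exception> callbackIfStopped) {
    Optional<Exception> error;
    synchronized (this) {
      EndpointData data = endpoints.get(endpointName);
      if (stopped) {
        error = Optional.of(new IllegalStateException("RpcEnv already stopped"));
      } else if (data == null) {
        error = Optional.of(new IllegalArgumentException("Could not find " + endpointName + "."));
      } else {
        data.inbox.add(message); // 1. put the message into the endpoint's inbox
        receivers.offer(data);   // 2. tell the message loops this endpoint has work
        error = Optional.empty();
      }
    }
    error.ifPresent(callbackIfStopped); // the callback runs outside the lock
  }

  synchronized void stop() {
    stopped = true;
  }
}
```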

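Dispatcher also has an inner class, MessageLoop, that processes messages. It implements Runnable, the Java interface whose single abstract method run() is mainly used for code executed on a thread. The loop repeatedly takes an EndpointData with pending messages from receivers and asks its inbox to process them. It stops when it takes the PoisonPill, an EndpointData whose fields are all null. The PoisonPill is a marker meaning "exit the message loop", and Dispatcher's stop function is what enqueues it. After taking the PoisonPill, the loop puts it back so that the other message loops can exit as well.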

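Dispatcher maintains a thread pool, threadpool: ThreadPoolExecutor. It is created with the newDaemonFixedThreadPool function of the singleton object ThreadUtils and has a fixed number of threads: spark.rpc.netty.dispatcher.numThreads, defaulting to max(2, number of available processors). Each thread is given a new MessageLoop instance to run.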

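Dispatcher also has a public stop function. It first sets stopped to true. It then unregisters all endpoints and puts the PoisonPill into the receivers queue so that the MessageLoops exit. Finally, it calls the thread pool's shutdown function.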
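The message loop, fixed thread pool and poison-pill shutdown described above can be sketched together in Java. In this sketch, several workers block on one shared queue, and a single sentinel stops all of them because each worker puts it back before exiting. A few assumptions differ from Spark. Queue items are plain `Runnable`s standing in for `data.inbox.process(dispatcher)`. The pool comes from `Executors.newFixedThreadPool`, which creates non-daemon threads, whereas Spark's ThreadUtils helper creates daemon threads. The class name is hypothetical.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

// Hypothetical sketch of Dispatcher's MessageLoop / PoisonPill shutdown.
final class MessageLoopSketch {
  private static final Runnable POISON_PILL = () -> { }; // compared by identity
  private final LinkedBlockingQueue<Runnable> receivers = new LinkedBlockingQueue<>();
  private final ExecutorService pool;

  MessageLoopSketch(int numThreads) {
    pool = Executors.newFixedThreadPool(numThreads);
    for (int i = 0; i < numThreads; i++) {
      pool.execute(this::loop); // one message loop per pool thread
    }
  }

  private void loop() {
    try {
      while (true) {
        Runnable work = receivers.take();
        if (work == POISON_PILL) {
          receivers.offer(POISON_PILL); // put it back so the other loops see it too
          return;
        }
        try {
          work.run(); // stands in for data.inbox.process(dispatcher)
        } catch (RuntimeException e) {
          // Spark logs non-fatal errors and keeps looping; here we just keep looping.
        }
      }
    } catch (InterruptedException ie) {
      Thread.currentThread().interrupt(); // exit the loop
    }
  }

  void post(Runnable work) {
    receivers.offer(work);
  }

  /** Enqueues the poison pill, shuts the pool down and waits for the loops to exit. */
  boolean stop(long timeoutMs) {
    receivers.offer(POISON_PILL);
    pool.shutdown();
    try {
      return pool.awaitTermination(timeoutMs, TimeUnit.MILLISECONDS);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      return false;
    }
  }
}
```

Because the queue is FIFO, everything posted before `stop` is taken before the pill. Each worker finishes its current item before taking the next one, so by the time the pool terminates, all earlier work has run.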

package org.apache.spark.rpc.netty
......
/**
 * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s).
 */
private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {

  private class EndpointData(
      val name: String,
      val endpoint: RpcEndpoint,
      val ref: NettyRpcEndpointRef) {
    val inbox = new Inbox(ref, endpoint)
  }

  private val endpoints: ConcurrentMap[String, EndpointData] =
    new ConcurrentHashMap[String, EndpointData]
  private val endpointRefs: ConcurrentMap[RpcEndpoint, RpcEndpointRef] =
    new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]

  // Track the receivers whose inboxes may contain messages.
  private val receivers = new LinkedBlockingQueue[EndpointData]

  /**
   * True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced
   * immediately.
   */
  @GuardedBy("this")
  private var stopped = false

  def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
    val addr = RpcEndpointAddress(nettyEnv.address, name)
    val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
    synchronized {
      if (stopped) {
        throw new IllegalStateException("RpcEnv has been stopped")
      }
      if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) {
        throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
      }
      val data = endpoints.get(name)
      endpointRefs.put(data.endpoint, data.ref)
      receivers.offer(data)  // for the OnStart message
    }
    endpointRef
  }

  def getRpcEndpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointRefs.get(endpoint)

  def removeRpcEndpointRef(endpoint: RpcEndpoint): Unit = endpointRefs.remove(endpoint)

  // Should be idempotent
  private def unregisterRpcEndpoint(name: String): Unit = {
    val data = endpoints.remove(name)
    if (data != null) {
      data.inbox.stop()
      receivers.offer(data)  // for the OnStop message
    }
    // Don't clean `endpointRefs` here because it's possible that some messages are being processed
    // now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via
    // `removeRpcEndpointRef`.
  }

  def stop(rpcEndpointRef: RpcEndpointRef): Unit = {
    synchronized {
      if (stopped) {
        // This endpoint will be stopped by Dispatcher.stop() method.
        return
      }
      unregisterRpcEndpoint(rpcEndpointRef.name)
    }
  }

  /**
   * Send a message to all registered [[RpcEndpoint]]s in this process.
   *
   * This can be used to make network events known to all end points (e.g. "a new node connected").
   */
  def postToAll(message: InboxMessage): Unit = {
    val iter = endpoints.keySet().iterator()
    while (iter.hasNext) {
      val name = iter.next
      postMessage(name, message, (e) => logWarning(s"Message $message dropped. ${e.getMessage}"))
    }
  }

  /** Posts a message sent by a remote endpoint. */
  def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
    val rpcCallContext =
      new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress)
    val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
    postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e))
  }

  /** Posts a message sent by a local endpoint. */
  def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = {
    val rpcCallContext =
      new LocalNettyRpcCallContext(message.senderAddress, p)
    val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
    postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e))
  }

  /** Posts a one-way message. */
  def postOneWayMessage(message: RequestMessage): Unit = {
    postMessage(message.receiver.name, OneWayMessage(message.senderAddress, message.content),
      (e) => throw e)
  }

  /**
   * Posts a message to a specific endpoint.
   *
   * @param endpointName name of the endpoint.
   * @param message the message to post
   * @param callbackIfStopped callback function if the endpoint is stopped.
   */
  private def postMessage(
      endpointName: String,
      message: InboxMessage,
      callbackIfStopped: (Exception) => Unit): Unit = {
    val error = synchronized {
      val data = endpoints.get(endpointName)
      if (stopped) {
        Some(new RpcEnvStoppedException())
      } else if (data == null) {
        Some(new SparkException(s"Could not find $endpointName."))
      } else {
        data.inbox.post(message)
        receivers.offer(data)
        None
      }
    }
    // We don't need to call `onStop` in the `synchronized` block
    error.foreach(callbackIfStopped)
  }

  def stop(): Unit = {
    synchronized {
      if (stopped) {
        return
      }
      stopped = true
    }
    // Stop all endpoints. This will queue all endpoints for processing by the message loops.
    endpoints.keySet().asScala.foreach(unregisterRpcEndpoint)
    // Enqueue a message that tells the message loops to stop.
    receivers.offer(PoisonPill)
    threadpool.shutdown()
  }

  def awaitTermination(): Unit = {
    threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)
  }

  /**
   * Return if the endpoint exists
   */
  def verify(name: String): Boolean = {
    endpoints.containsKey(name)
  }

  /** Thread pool used for dispatching messages. */
  private val threadpool: ThreadPoolExecutor = {
    val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads",
      math.max(2, Runtime.getRuntime.availableProcessors()))
    val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
    for (i <- 0 until numThreads) {
      pool.execute(new MessageLoop)
    }
    pool
  }

  /** Message loop used for dispatching messages. */
  private class MessageLoop extends Runnable {
    override def run(): Unit = {
      try {
        while (true) {
          try {
            val data = receivers.take()
            if (data == PoisonPill) {
              // Put PoisonPill back so that other MessageLoops can see it.
              receivers.offer(PoisonPill)
              return
            }
            data.inbox.process(Dispatcher.this)
          } catch {
            case NonFatal(e) => logError(e.getMessage, e)
          }
        }
      } catch {
        case ie: InterruptedException => // exit
      }
    }
  }

  /** A poison endpoint that indicates MessageLoop should exit its message loop. */
  private val PoisonPill = new EndpointData(null, null, null)
}
  • Inbox

Next, the Inbox class, which is contained in the EndpointData used by Dispatcher. An Inbox stores messages and provides the function that processes them. It is defined in Inbox.scala.

Inbox.scala first defines the InboxMessage types, all of which extend sealed trait InboxMessage. They are the case classes OneWayMessage and RpcMessage, plus some special InboxMessages: OnStart, OnStop, RemoteProcessConnected, RemoteProcessDisconnected and RemoteProcessConnectionError.

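Inside the Inbox class, messages are stored in a LinkedList[InboxMessage]. The class also maintains several other fields:
- stopped indicates whether the Inbox has been stopped.
- enableConcurrent indicates whether messages may be processed concurrently. The Dispatcher processes messages on multiple threads, so several threads can call process on the same EndpointData's Inbox at once. enableConcurrent is set to true after OnStart has been handled, but only if the endpoint is not a ThreadSafeRpcEndpoint. It is set back to false when the Inbox stops.
- numActiveThreads counts the threads currently processing messages from this Inbox.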

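The main function, process(dispatcher: Dispatcher): Unit, processes messages:
- OnStart: this message is added to the LinkedList in the Inbox constructor. process calls the RpcEndpoint's onStart(). If the endpoint is not a ThreadSafeRpcEndpoint and the Inbox has not been stopped, it also sets the concurrency flag enableConcurrent to true.
- OnStop: this message is added to the LinkedList when the Inbox stops. Inbox.stop() sets enableConcurrent to false and the Inbox's own stopped flag to true before adding OnStop, and post() drops any message that arrives afterwards. When the Dispatcher stops, it also sets its own stopped flag first, so postMessage rejects new messages. OnStop is therefore the last message of the Inbox. The code asserts activeThreads == 1, meaning the thread handling OnStop is the only thread working on this Inbox. It then removes the endpoint's RpcEndpointRef from the Dispatcher and calls the RpcEndpoint's onStop().
- RemoteProcessConnected, RemoteProcessDisconnected and RemoteProcessConnectionError: each calls the corresponding RpcEndpoint function directly (onConnected, onDisconnected, onNetworkError).
- RpcMessage: calls the RpcEndpoint's receiveAndReply function; the behavior depends on the concrete RpcEndpoint.
- OneWayMessage: calls the RpcEndpoint's receive function; the behavior depends on the concrete RpcEndpoint.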
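The bookkeeping in process (OnStart first, enableConcurrent, numActiveThreads, OnStop last) can be sketched in Java. This simplified model uses string messages and a hypothetical class name. Calls into the "endpoint" are reduced to recording the message, and the real error handling around each call (safelyCall / onError) is omitted.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

// Hypothetical, simplified model of Inbox.process's concurrency bookkeeping.
final class InboxSketch {
  static final String ON_START = "OnStart";
  static final String ON_STOP = "OnStop";

  private final LinkedList<String> messages = new LinkedList<>(); // guarded by this
  private boolean stopped = false;
  private boolean enableConcurrent = false;
  private int numActiveThreads = 0;
  private final boolean threadSafeEndpoint;
  final List<String> handled = Collections.synchronizedList(new ArrayList<>());

  InboxSketch(boolean threadSafeEndpoint) {
    this.threadSafeEndpoint = threadSafeEndpoint;
    synchronized (this) {
      messages.add(ON_START); // OnStart is always the first message
    }
  }

  synchronized void post(String msg) {
    if (stopped) {
      return; // Spark calls onDrop(message) here
    }
    messages.add(msg);
  }

  synchronized void stop() {
    if (!stopped) {
      enableConcurrent = false; // onStop must run on a single thread
      stopped = true;
      messages.add(ON_STOP);    // nothing can be posted after this
    }
  }

  void process() {
    String message;
    synchronized (this) {
      if (!enableConcurrent && numActiveThreads != 0) {
        return; // another thread owns this inbox
      }
      message = messages.poll();
      if (message == null) {
        return;
      }
      numActiveThreads++;
    }
    while (true) {
      handle(message);
      synchronized (this) {
        if (!enableConcurrent && numActiveThreads != 1) {
          numActiveThreads--; // not the only worker any more: leave
          return;
        }
        message = messages.poll();
        if (message == null) {
          numActiveThreads--;
          return;
        }
      }
    }
  }

  private void handle(String message) {
    if (ON_STOP.equals(message)) {
      synchronized (this) {
        if (numActiveThreads != 1) {
          throw new IllegalStateException("OnStop must be handled by a single thread");
        }
      }
    }
    handled.add(message);
    if (ON_START.equals(message) && !threadSafeEndpoint) {
      synchronized (this) {
        if (!stopped) {
          enableConcurrent = true; // non-thread-safe endpoints allow concurrent processing
        }
      }
    }
  }

  synchronized boolean isConcurrent() {
    return enableConcurrent;
  }
}
```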

package org.apache.spark.rpc.netty
......
private[netty] sealed trait InboxMessage

private[netty] case class OneWayMessage(
    senderAddress: RpcAddress,
    content: Any) extends InboxMessage

private[netty] case class RpcMessage(
    senderAddress: RpcAddress,
    content: Any,
    context: NettyRpcCallContext) extends InboxMessage

private[netty] case object OnStart extends InboxMessage

private[netty] case object OnStop extends InboxMessage

/** A message to tell all endpoints that a remote process has connected. */
private[netty] case class RemoteProcessConnected(remoteAddress: RpcAddress) extends InboxMessage

/** A message to tell all endpoints that a remote process has disconnected. */
private[netty] case class RemoteProcessDisconnected(remoteAddress: RpcAddress) extends InboxMessage

/** A message to tell all endpoints that a network error has happened. */
private[netty] case class RemoteProcessConnectionError(cause: Throwable, remoteAddress: RpcAddress)
  extends InboxMessage

/**
 * An inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely.
 */
private[netty] class Inbox(
    val endpointRef: NettyRpcEndpointRef,
    val endpoint: RpcEndpoint)
  extends Logging {

  inbox =>  // Give this an alias so we can use it more clearly in closures.

  @GuardedBy("this")
  protected val messages = new java.util.LinkedList[InboxMessage]()

  /** True if the inbox (and its associated endpoint) is stopped. */
  @GuardedBy("this")
  private var stopped = false

  /** Allow multiple threads to process messages at the same time. */
  @GuardedBy("this")
  private var enableConcurrent = false

  /** The number of threads processing messages for this inbox. */
  @GuardedBy("this")
  private var numActiveThreads = 0

  // OnStart should be the first message to process
  inbox.synchronized {
    messages.add(OnStart)
  }

  /**
   * Process stored messages.
   */
  def process(dispatcher: Dispatcher): Unit = {
    var message: InboxMessage = null
    inbox.synchronized {
      if (!enableConcurrent && numActiveThreads != 0) {
        return
      }
      message = messages.poll()
      if (message != null) {
        numActiveThreads += 1
      } else {
        return
      }
    }
    while (true) {
      safelyCall(endpoint) {
        message match {
          case RpcMessage(_sender, content, context) =>
            try {
              endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>
                throw new SparkException(s"Unsupported message $message from ${_sender}")
              })
            } catch {
              case NonFatal(e) =>
                context.sendFailure(e)
                // Throw the exception -- this exception will be caught by the safelyCall function.
                // The endpoint's onError function will be called.
                throw e
            }

          case OneWayMessage(_sender, content) =>
            endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
              throw new SparkException(s"Unsupported message $message from ${_sender}")
            })

          case OnStart =>
            endpoint.onStart()
            if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
              inbox.synchronized {
                if (!stopped) {
                  enableConcurrent = true
                }
              }
            }

          case OnStop =>
            val activeThreads = inbox.synchronized { inbox.numActiveThreads }
            assert(activeThreads == 1,
              s"There should be only a single active thread but found $activeThreads threads.")
            dispatcher.removeRpcEndpointRef(endpoint)
            endpoint.onStop()
            assert(isEmpty, "OnStop should be the last message")

          case RemoteProcessConnected(remoteAddress) =>
            endpoint.onConnected(remoteAddress)

          case RemoteProcessDisconnected(remoteAddress) =>
            endpoint.onDisconnected(remoteAddress)

          case RemoteProcessConnectionError(cause, remoteAddress) =>
            endpoint.onNetworkError(cause, remoteAddress)
        }
      }

      inbox.synchronized {
        // "enableConcurrent" will be set to false after `onStop` is called, so we should check it
        // every time.
        if (!enableConcurrent && numActiveThreads != 1) {
          // If we are not the only one worker, exit
          numActiveThreads -= 1
          return
        }
        message = messages.poll()
        if (message == null) {
          numActiveThreads -= 1
          return
        }
      }
    }
  }

  def post(message: InboxMessage): Unit = inbox.synchronized {
    if (stopped) {
      // We already put "OnStop" into "messages", so we should drop further messages
      onDrop(message)
    } else {
      messages.add(message)
      false
    }
  }

  def stop(): Unit = inbox.synchronized {
    // The following codes should be in `synchronized` so that we can make sure "OnStop" is the last
    // message
    if (!stopped) {
      // We should disable concurrent here. Then when RpcEndpoint.onStop is called, it's the only
      // thread that is processing messages. So `RpcEndpoint.onStop` can release its resources
      // safely.
      enableConcurrent = false
      stopped = true
      messages.add(OnStop)
      // Note: The concurrent events in messages will be processed one by one.
    }
  }

  def isEmpty: Boolean = inbox.synchronized { messages.isEmpty }

  /**
   * Called when we are dropping a message. Test cases override this to test message dropping.
   * Exposed for testing.
   */
  protected def onDrop(message: InboxMessage): Unit = {
    logWarning(s"Drop $message because $endpointRef is stopped")
  }

  /**
   * Calls action closure, and calls the endpoint's onError function in the case of exceptions.
   */
  private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = {
    try action catch {
      case NonFatal(e) =>
        try endpoint.onError(e) catch {
          case NonFatal(ee) => logError(s"Ignoring error", ee)
        }
    }
  }
}
  • NettyStreamManager
    The NettyRpcEnv constructor also creates the private field streamManager:
private val streamManager = new NettyStreamManager(this)

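NettyStreamManager extends the abstract class StreamManager (the implementation side) and mixes in the trait RpcEnvFileServer (the functional side). It is mainly responsible for managing and serving files in the NettyRpcEnv environment. Its source file is NettyStreamManager.scala.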

package org.apache.spark.rpc.netty
......
/**
 * StreamManager implementation for serving files from a NettyRpcEnv.
 *
 * Three kinds of resources can be registered in this manager, all backed by actual files:
 *
 * - "/files": a flat list of files; used as the backend for [[SparkContext.addFile]].
 * - "/jars": a flat list of files; used as the backend for [[SparkContext.addJar]].
 * - arbitrary directories; all files under the directory become available through the manager,
 *   respecting the directory's hierarchy.
 *
 * Only streaming (openStream) is supported.
 */
private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv)
  extends StreamManager with RpcEnvFileServer {

  private val files = new ConcurrentHashMap[String, File]()
  private val jars = new ConcurrentHashMap[String, File]()
  private val dirs = new ConcurrentHashMap[String, File]()

  override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = {
    throw new UnsupportedOperationException()
  }

  override def openStream(streamId: String): ManagedBuffer = {
    val Array(ftype, fname) = streamId.stripPrefix("/").split("/", 2)
    val file = ftype match {
      case "files" => files.get(fname)
      case "jars" => jars.get(fname)
      case other =>
        val dir = dirs.get(ftype)
        require(dir != null, s"Invalid stream URI: $ftype not found.")
        new File(dir, fname)
    }

    if (file != null && file.isFile()) {
      new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length())
    } else {
      null
    }
  }

  override def addFile(file: File): String = {
    val existingPath = files.putIfAbsent(file.getName, file)
    ......
  }

  override def addJar(file: File): String = {
    val existingPath = jars.putIfAbsent(file.getName, file)
    ......
  }

  override def addDirectory(baseUri: String, path: File): String = {
    val fixedBaseUri = validateDirectoryUri(baseUri)
    require(dirs.putIfAbsent(fixedBaseUri.stripPrefix("/"), path) == null,
      ......
  }
}
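Most of openStream is the mapping from a stream id to a file: strip the leading "/", then split once on "/" into a type and a name. "files" and "jars" are flat lookups, and anything else is treated as a registered directory name, with the remainder kept as a relative sub-path. Below is a minimal Java sketch of that mapping with a hypothetical class name. One difference from the Scala code: a malformed id raises IllegalArgumentException here, whereas `val Array(ftype, fname) = ...` would fail with a MatchError.

```java
import java.io.File;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical sketch of how NettyStreamManager.openStream maps a stream id
// such as "/files/app.conf" or "/myDir/sub/a.txt" to a File.
final class StreamIdSketch {
  final Map<String, File> files = new ConcurrentHashMap<>();
  final Map<String, File> jars = new ConcurrentHashMap<>();
  final Map<String, File> dirs = new ConcurrentHashMap<>();

  File resolve(String streamId) {
    String stripped = streamId.startsWith("/") ? streamId.substring(1) : streamId;
    String[] parts = stripped.split("/", 2); // limit 2: the name keeps any sub-path
    if (parts.length != 2) {
      throw new IllegalArgumentException("Invalid stream URI: " + streamId);
    }
    String ftype = parts[0];
    String fname = parts[1];
    switch (ftype) {
      case "files":
        return files.get(fname); // flat namespace, may be null
      case "jars":
        return jars.get(fname);  // flat namespace, may be null
      default:
        File dir = dirs.get(ftype); // any other prefix must be a registered directory
        if (dir == null) {
          throw new IllegalArgumentException("Invalid stream URI: " + ftype + " not found.");
        }
        return new File(dir, fname); // keeps the directory hierarchy
    }
  }
}
```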

(To be continued tomorrow…)

  • TransportContext

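TransportContext is defined in TransportContext.java. It is responsible for transporting RPC messages and contains the concrete Netty-based communication setup; as far as I understand, it mainly creates the message transport channels for the server and the client. Its main functions are createServer, createClientFactory and initializePipeline. These involve the classes TransportClientFactory, TransportClientBootstrap, TransportServer, TransportServerBootstrap and TransportChannelHandler. Those classes are the concrete Netty communication implementation, written in Java, and I'll study them later.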

package org.apache.spark.network;
......
import io.netty.channel.Channel;
import io.netty.channel.socket.SocketChannel;
......
import org.apache.spark.network.client.TransportResponseHandler;
......
import org.apache.spark.network.server.TransportChannelHandler;
import org.apache.spark.network.server.TransportRequestHandler;
......
/**
 * Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to
 * setup Netty Channel pipelines with a
 * {@link org.apache.spark.network.server.TransportChannelHandler}.
 *
 * There are two communication protocols that the TransportClient provides, control-plane RPCs and
 * data-plane "chunk fetching". The handling of the RPCs is performed outside of the scope of the
 * TransportContext (i.e., by a user-provided handler), and it is responsible for setting up streams
 * which can be streamed through the data plane in chunks using zero-copy IO.
 *
 * The TransportServer and TransportClientFactory both create a TransportChannelHandler for each
 * channel. As each TransportChannelHandler contains a TransportClient, this enables server
 * processes to send messages back to the client on an existing channel.
 */
public class TransportContext {
  private static final Logger logger = LoggerFactory.getLogger(TransportContext.class);

  private final TransportConf conf;
  private final RpcHandler rpcHandler;
  private final boolean closeIdleConnections;

  private final MessageEncoder encoder;
  private final MessageDecoder decoder;

  public TransportContext(TransportConf conf, RpcHandler rpcHandler) {
    this(conf, rpcHandler, false);
  }

  public TransportContext(
      TransportConf conf,
      RpcHandler rpcHandler,
      boolean closeIdleConnections) {
    this.conf = conf;
    this.rpcHandler = rpcHandler;
    this.encoder = new MessageEncoder();
    this.decoder = new MessageDecoder();
    this.closeIdleConnections = closeIdleConnections;
  }

  /**
   * Initializes a ClientFactory which runs the given TransportClientBootstraps prior to returning
   * a new Client. Bootstraps will be executed synchronously, and must run successfully in order
   * to create a Client.
   */
  public TransportClientFactory createClientFactory(List<TransportClientBootstrap> bootstraps) {
    return new TransportClientFactory(this, bootstraps);
  }

  public TransportClientFactory createClientFactory() {
    return createClientFactory(Lists.<TransportClientBootstrap>newArrayList());
  }

  /** Create a server which will attempt to bind to a specific port. */
  public TransportServer createServer(int port, List<TransportServerBootstrap> bootstraps) {
    return new TransportServer(this, null, port, rpcHandler, bootstraps);
  }

  /** Create a server which will attempt to bind to a specific host and port. */
  public TransportServer createServer(
      String host, int port, List<TransportServerBootstrap> bootstraps) {
    return new TransportServer(this, host, port, rpcHandler, bootstraps);
  }

  /** Creates a new server, binding to any available ephemeral port. */
  public TransportServer createServer(List<TransportServerBootstrap> bootstraps) {
    return createServer(0, bootstraps);
  }

  public TransportServer createServer() {
    return createServer(0, Lists.<TransportServerBootstrap>newArrayList());
  }

  public TransportChannelHandler initializePipeline(SocketChannel channel) {
    return initializePipeline(channel, rpcHandler);
  }

  /**
   * Initializes a client or server Netty Channel Pipeline which encodes/decodes messages and
   * has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or
   * response messages.
   *
   * @param channel The channel to initialize.
   * @param channelRpcHandler The RPC handler to use for the channel.
   *
   * @return Returns the created TransportChannelHandler, which includes a TransportClient that can
   * be used to communicate on this channel. The TransportClient is directly associated with a
   * ChannelHandler to ensure all users of the same channel get the same TransportClient object.
   */
  public TransportChannelHandler initializePipeline(
      SocketChannel channel,
      RpcHandler channelRpcHandler) {
    try {
      TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
      channel.pipeline()
        .addLast("encoder", encoder)
        .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
        .addLast("decoder", decoder)
        .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
        // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
        // would require more logic to guarantee if this were not part of the same event loop.
        .addLast("handler", channelHandler);
      return channelHandler;
    } catch (RuntimeException e) {
      logger.error("Error while initializing Netty pipeline", e);
      throw e;
    }
  }

  /**
   * Creates the server- and client-side handler which is used to handle both RequestMessages and
   * ResponseMessages. The channel is expected to have been successfully created, though certain
   * properties (such as the remoteAddress()) may not be available yet.
   */
  private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler rpcHandler) {
    TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
    TransportClient client = new TransportClient(channel, responseHandler);
    TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
      rpcHandler);
    return new TransportChannelHandler(client, responseHandler, requestHandler,
      conf.connectionTimeoutMs(), closeIdleConnections);
  }

  public TransportConf getConf() { return conf; }
}
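In initializePipeline, the frame decoder sits in front of the message decoder. Its job is to turn an arbitrary TCP byte stream back into whole frames before MessageDecoder sees them. The following is a minimal, dependency-free Java sketch of length-prefixed framing. It assumes, as I read the Spark 2.x MessageEncoder/TransportFrameDecoder, that each frame begins with an 8-byte big-endian length that counts the length field itself. It uses plain `java.nio.ByteBuffer` instead of Netty's `ByteBuf`, and `FrameSketch` is a hypothetical name.

```java
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of length-prefixed framing, the job TransportFrameDecoder
// does in the pipeline before MessageDecoder. Assumption: an 8-byte big-endian
// frame length that includes the length field itself.
final class FrameSketch {
  static final int LENGTH_SIZE = 8;

  static ByteBuffer encode(byte[] body) {
    ByteBuffer buf = ByteBuffer.allocate(LENGTH_SIZE + body.length);
    buf.putLong(LENGTH_SIZE + body.length); // frame length counts the prefix too
    buf.put(body);
    buf.flip();
    return buf;
  }

  /** Accumulates bytes as they arrive and emits only complete frame bodies. */
  static final class Decoder {
    private ByteBuffer pending = ByteBuffer.allocate(0);

    List<byte[]> feed(ByteBuffer chunk) {
      ByteBuffer merged = ByteBuffer.allocate(pending.remaining() + chunk.remaining());
      merged.put(pending).put(chunk);
      merged.flip();
      List<byte[]> frames = new ArrayList<>();
      while (merged.remaining() >= LENGTH_SIZE) {
        long frameSize = merged.getLong(merged.position()); // peek without consuming
        if (merged.remaining() < frameSize) {
          break; // incomplete frame: wait for more bytes
        }
        merged.position(merged.position() + LENGTH_SIZE);
        byte[] body = new byte[(int) (frameSize - LENGTH_SIZE)];
        merged.get(body);
        frames.add(body);
      }
      pending = merged.slice(); // keep the unconsumed tail for the next call
      return frames;
    }
  }
}
```

Splitting one frame across two reads, which TCP is free to do, yields nothing on the first call and the complete frame on the second.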
  • NettyRpcHandler

NettyRpcHandler is instantiated inside NettyRpcEnv and passed as an argument to the TransportContext constructor:

private val transportContext = new TransportContext(transportConf,
  new NettyRpcHandler(dispatcher, this, streamManager))

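NettyRpcHandler is defined in NettyRpcEnv.scala and extends the abstract class RpcHandler. RpcHandler is defined in RpcHandler.java, in the package org.apache.spark.network.server, which suggests it is the class the server side uses to handle RPC messages. RpcHandler handles RPC messages sent by a TransportClient in its receive function. Its channelActive and channelInactive functions react to changes in the connection state of the channel to a client.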

package org.apache.spark.network.server;
......
/**
 * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s.
 */
public abstract class RpcHandler {

  private static final RpcResponseCallback ONE_WAY_CALLBACK = new OneWayRpcCallback();

  /**
   * Receive a single RPC message. Any exception thrown while in this method will be sent back to
   * the client in string form as a standard RPC failure.
   *
   * This method will not be called in parallel for a single TransportClient (i.e., channel).
   *
   * @param client A channel client which enables the handler to make requests back to the sender
   *               of this RPC. This will always be the exact same object for a particular channel.
   * @param message The serialized bytes of the RPC.
   * @param callback Callback which should be invoked exactly once upon success or failure of the
   *                 RPC.
   */
  public abstract void receive(
      TransportClient client,
      ByteBuffer message,
      RpcResponseCallback callback);

  /**
   * Returns the StreamManager which contains the state about which streams are currently being
   * fetched by a TransportClient.
   */
  public abstract StreamManager getStreamManager();

  /**
   * Receives an RPC message that does not expect a reply.
   ......
   */
  public void receive(TransportClient client, ByteBuffer message) {
    receive(client, message, ONE_WAY_CALLBACK);
  }

  /**
   * Invoked when the channel associated with the given client is active.
   */
  public void channelActive(TransportClient client) { }

  /**
   * Invoked when the channel associated with the given client is inactive.
   * No further requests will come from this client.
   */
  public void channelInactive(TransportClient client) { }

  public void exceptionCaught(Throwable cause, TransportClient client) { }

  private static class OneWayRpcCallback implements RpcResponseCallback {
    ......
  }

}
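RpcHandler follows a simple pattern: the only abstract entry point is the two-way receive(client, message, callback), and the one-way overload reuses it with a shared callback that expects no reply. A Java sketch of that structure follows, in which every type is a hypothetical stand-in: `ResponseCallback` for RpcResponseCallback, and a String for the TransportClient. The body of the real OneWayRpcCallback is elided above, so in this sketch the shared callback simply counts replies that nobody asked for.

```java
import java.nio.ByteBuffer;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical stand-in for RpcResponseCallback.
interface ResponseCallback {
  void onSuccess(ByteBuffer response);

  void onFailure(Throwable e);
}

// Hypothetical sketch mirroring RpcHandler's structure (not Spark's API).
abstract class HandlerSketch {
  static final AtomicInteger ignoredReplies = new AtomicInteger();

  // Shared callback for one-way messages: any reply is unexpected and is only counted here.
  private static final ResponseCallback ONE_WAY_CALLBACK = new ResponseCallback() {
    @Override
    public void onSuccess(ByteBuffer response) {
      ignoredReplies.incrementAndGet();
    }

    @Override
    public void onFailure(Throwable e) {
      ignoredReplies.incrementAndGet();
    }
  };

  /** Two-way entry point; implementations must invoke the callback exactly once. */
  abstract void receive(String client, ByteBuffer message, ResponseCallback callback);

  /** One-way messages reuse the two-way path with the shared no-reply callback. */
  void receive(String client, ByteBuffer message) {
    receive(client, message, ONE_WAY_CALLBACK);
  }
}

// Hypothetical handler that echoes each request back as its reply.
final class EchoHandler extends HandlerSketch {
  @Override
  void receive(String client, ByteBuffer message, ResponseCallback callback) {
    callback.onSuccess(message.duplicate()); // reply exactly once
  }
}
```

NettyRpcHandler does not rely on this default: it overrides both overloads, sending two-way messages to postRemoteMessage and one-way messages to postOneWayMessage.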

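NettyRpcHandler implements RpcHandler's interface. More than one client can send RPC messages to the server, so NettyRpcHandler keeps remoteAddresses: ConcurrentHashMap[RpcAddress, RpcAddress] to track the clients that have sent it messages. This map goes from each client connection's socket address to the address on which that client's RpcEnv listens.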

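In receive, the bytes are deserialized into a RequestMessage and posted to the dispatcher. A two-way call is posted as an RpcMessage via postRemoteMessage, and a one-way message as a OneWayMessage via postOneWayMessage. If the message carries a senderAddress (meaning the remote RpcEnv listens on a port) and this is the first such message on the connection, the mapping from the client's socket address to remoteEnvAddress is added to remoteAddresses, and RemoteProcessConnected(remoteEnvAddress) is sent to every Endpoint in the dispatcher. If senderAddress is null, the message is rebuilt with the client's socket address as the sender.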

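channelActive sends RemoteProcessConnected(clientAddr) to every Endpoint in the dispatcher. channelInactive first removes the client's outbox and sends RemoteProcessDisconnected(clientAddr) to every Endpoint in the dispatcher. It then removes the client from remoteAddresses. If that removal returns a non-null remoteEnvAddress, it also sends RemoteProcessDisconnected(remoteEnvAddress) to every Endpoint. The code comments suggest how the two addresses differ. clientAddr is the socket address of the connection itself, typically an ephemeral port on the peer. remoteEnvAddress is the address on which the remote RpcEnv listens, and it is only known when the remote side is not running in client mode.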

```scala
/**
 * Dispatches incoming RPCs to registered endpoints.
 *
 * The handler keeps track of all client instances that communicate with it, so that the RpcEnv
 * knows which `TransportClient` instance to use when sending RPCs to a client endpoint (i.e.,
 * one that is not listening for incoming connections, but rather needs to be contacted via the
 * client socket).
 *
 * Events are sent on a per-connection basis, so if a client opens multiple connections to the
 * RpcEnv, multiple connection / disconnection events will be created for that client (albeit
 * with different `RpcAddress` information).
 */
private[netty] class NettyRpcHandler(
    dispatcher: Dispatcher,
    nettyEnv: NettyRpcEnv,
    streamManager: StreamManager) extends RpcHandler with Logging {

  // A variable to track the remote RpcEnv addresses of all clients
  private val remoteAddresses = new ConcurrentHashMap[RpcAddress, RpcAddress]()

  override def receive(
      client: TransportClient,
      message: ByteBuffer,
      callback: RpcResponseCallback): Unit = {
    val messageToDispatch = internalReceive(client, message)
    dispatcher.postRemoteMessage(messageToDispatch, callback)
  }

  override def receive(
      client: TransportClient,
      message: ByteBuffer): Unit = {
    val messageToDispatch = internalReceive(client, message)
    dispatcher.postOneWayMessage(messageToDispatch)
  }

  private def internalReceive(client: TransportClient, message: ByteBuffer): RequestMessage = {
    val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
    assert(addr != null)
    val clientAddr = RpcAddress(addr.getHostString, addr.getPort)
    val requestMessage = nettyEnv.deserialize[RequestMessage](client, message)
    if (requestMessage.senderAddress == null) {
      // Create a new message with the socket address of the client as the sender.
      RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content)
    } else {
      // The remote RpcEnv listens to some port, we should also fire a RemoteProcessConnected for
      // the listening address
      val remoteEnvAddress = requestMessage.senderAddress
      if (remoteAddresses.putIfAbsent(clientAddr, remoteEnvAddress) == null) {
        dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress))
      }
      requestMessage
    }
  }

  override def getStreamManager: StreamManager = streamManager

  override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = {
    val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
    if (addr != null) {
      val clientAddr = RpcAddress(addr.getHostString, addr.getPort)
      dispatcher.postToAll(RemoteProcessConnectionError(cause, clientAddr))
      // If the remote RpcEnv listens to some address, we should also fire a
      // RemoteProcessConnectionError for the remote RpcEnv listening address
      val remoteEnvAddress = remoteAddresses.get(clientAddr)
      if (remoteEnvAddress != null) {
        dispatcher.postToAll(RemoteProcessConnectionError(cause, remoteEnvAddress))
      }
    } else {
      // If the channel is closed before connecting, its remoteAddress will be null.
      // See java.net.Socket.getRemoteSocketAddress
      // Because we cannot get a RpcAddress, just log it
      logError("Exception before connecting to the client", cause)
    }
  }

  override def channelActive(client: TransportClient): Unit = {
    val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
    assert(addr != null)
    val clientAddr = RpcAddress(addr.getHostString, addr.getPort)
    dispatcher.postToAll(RemoteProcessConnected(clientAddr))
  }

  override def channelInactive(client: TransportClient): Unit = {
    val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
    if (addr != null) {
      val clientAddr = RpcAddress(addr.getHostString, addr.getPort)
      nettyEnv.removeOutbox(clientAddr)
      dispatcher.postToAll(RemoteProcessDisconnected(clientAddr))
      val remoteEnvAddress = remoteAddresses.remove(clientAddr)
      // If the remote RpcEnv listens to some address, we should also fire a
      // RemoteProcessDisconnected for the remote RpcEnv listening address
      if (remoteEnvAddress != null) {
        dispatcher.postToAll(RemoteProcessDisconnected(remoteEnvAddress))
      }
    } else {
      // If the channel is closed before connecting, its remoteAddress will be null. In this case,
      // we can ignore it since we don't fire "Associated".
      // See java.net.Socket.getRemoteSocketAddress
    }
  }
}
```
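To make the difference between clientAddr and remoteEnvAddress more concrete, here is a minimal sketch that imitates the remoteAddresses bookkeeping in NettyRpcHandler. Addr, Connected, Disconnected and ConnectionTracker are hypothetical stand-ins, not Spark classes, and an in-memory event list replaces dispatcher.postToAll.

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-ins for RpcAddress and the connection events (not Spark's classes).
final case class Addr(host: String, port: Int)
sealed trait Event
final case class Connected(addr: Addr) extends Event
final case class Disconnected(addr: Addr) extends Event

// Sketch of the mapping: connection socket address -> remote RpcEnv listening address.
class ConnectionTracker {
  private val remoteAddresses = new ConcurrentHashMap[Addr, Addr]()
  val events = ArrayBuffer.empty[Event] // stands in for dispatcher.postToAll

  private def post(e: Event): Unit = synchronized { events += e }

  // Like channelActive: every new connection yields an event for its socket address.
  def onChannelActive(clientAddr: Addr): Unit = post(Connected(clientAddr))

  // Like internalReceive: the first message that carries a senderAddress on this
  // connection also yields an event for the remote env's listening address.
  def onMessage(clientAddr: Addr, senderAddress: Option[Addr]): Unit =
    senderAddress.foreach { remoteEnvAddress =>
      if (remoteAddresses.putIfAbsent(clientAddr, remoteEnvAddress) == null) {
        post(Connected(remoteEnvAddress))
      }
    }

  // Like channelInactive: always report the socket address, and the listening
  // address too if one was recorded for this connection.
  def onChannelInactive(clientAddr: Addr): Unit = {
    post(Disconnected(clientAddr))
    val remoteEnvAddress = remoteAddresses.remove(clientAddr)
    if (remoteEnvAddress != null) post(Disconnected(remoteEnvAddress))
  }
}
```

For a connection whose socket address is, say, 10.0.0.2:53124 and whose RpcEnv listens on 10.0.0.2:7077, the tracker emits Connected for both addresses (the second one only once, because of putIfAbsent) and Disconnected for both when the channel closes.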
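  • NettyRpcEnv also has a thread pool, clientConnectionExecutor. Its comment explains the purpose: TransportClientFactory.createClient is blocking, so connections are created in this pool to keep send/ask non-blocking. The comment and creation code are as follows: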
```scala
// Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool
// to implement non-blocking send/ask.
// TODO: a non-blocking TransportClientFactory.createClient in future
private[netty] val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
  "netty-rpc-connection",
  conf.getInt("spark.rpc.connect.threads", 64))
```
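The idea behind clientConnectionExecutor (run a blocking connect on a pool thread so the caller gets a Future back immediately) can be sketched with plain java.util.concurrent. This is an approximation, not Spark's ThreadUtils: ConnectPoolSketch and blockingConnect are hypothetical names, and unlike Spark's pool, which is capped by spark.rpc.connect.threads (default 64), this cached pool has no thread limit.

```scala
import java.util.concurrent.{Callable, ExecutorService, Executors, Future, ThreadFactory}
import java.util.concurrent.atomic.AtomicInteger

// Sketch only: a daemon thread pool that runs a blocking "connect" off the caller's thread.
object ConnectPoolSketch {
  private val counter = new AtomicInteger(0)

  // Daemon threads do not keep the JVM alive once the main thread finishes.
  private val daemonFactory: ThreadFactory = (r: Runnable) => {
    val t = new Thread(r, s"netty-rpc-connection-${counter.getAndIncrement()}")
    t.setDaemon(true)
    t
  }

  // Unbounded cached pool (Spark's version limits the number of threads).
  val clientConnectionExecutor: ExecutorService = Executors.newCachedThreadPool(daemonFactory)

  // Hypothetical blocking connect; a real client would open a socket here.
  def blockingConnect(host: String, port: Int): String = {
    Thread.sleep(100) // simulate connection latency
    s"client-to-$host:$port"
  }

  // Returns a Future immediately; the blocking work happens on a pool thread.
  def connectAsync(host: String, port: Int): Future[String] =
    clientConnectionExecutor.submit(new Callable[String] {
      override def call(): String = blockingConnect(host, port)
    })
}
```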
  • NettyRpcEnv also contains outboxes: ConcurrentHashMap[RpcAddress, Outbox], which keeps one Outbox per remote address so that sending is non-blocking. The code is as follows:
```scala
/**
 * A map for [[RpcAddress]] and [[Outbox]]. When we are connecting to a remote [[RpcAddress]],
 * we just put messages to its [[Outbox]] to implement a non-blocking `send` method.
 */
private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]()
```
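Keeping exactly one Outbox per remote address under concurrent access is usually done with a get-or-create on the ConcurrentHashMap. The sketch below assumes simplified, hypothetical types (RemoteAddr, SimpleOutbox, OutboxRegistry) in place of RpcAddress and Outbox; it only illustrates the map handling, not Spark's sending logic.

```scala
import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}

// Hypothetical simplified types, not Spark's RpcAddress / Outbox.
final case class RemoteAddr(host: String, port: Int)

final class SimpleOutbox(val address: RemoteAddr) {
  private val pending = new ConcurrentLinkedQueue[String]()
  def send(msg: String): Unit = pending.add(msg) // enqueue only; no network I/O here
  def size: Int = pending.size
}

// One outbox per remote address. putIfAbsent guarantees that two threads racing to
// create an outbox for the same address end up sharing a single instance.
class OutboxRegistry {
  private val outboxes = new ConcurrentHashMap[RemoteAddr, SimpleOutbox]()

  def outboxFor(addr: RemoteAddr): SimpleOutbox = {
    val existing = outboxes.get(addr)
    if (existing != null) existing
    else {
      val created = new SimpleOutbox(addr)
      val raced = outboxes.putIfAbsent(addr, created)
      if (raced == null) created else raced // keep whichever instance won the race
    }
  }

  def post(addr: RemoteAddr, msg: String): Unit = outboxFor(addr).send(msg)

  // Similar in spirit to removeOutbox: the next post creates a fresh outbox.
  def remove(addr: RemoteAddr): Unit = outboxes.remove(addr)
}
```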
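  • NettyRpcEnv's send and ask functions
    From the source we can see that send and ask both compare the receiver's address (remoteAddr) with the address of the local RpcEnv. If they are equal, the message is local, i.e. addressed to one of the (possibly many) Endpoints in this RpcEnv, so it is handed to the dispatcher, which delivers it to the specific Endpoint. Otherwise the message is for a remote endpoint and is placed into the corresponding Outbox via postToOutbox.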
```scala
private[netty] def send(message: RequestMessage): Unit = {
  val remoteAddr = message.receiver.address
  if (remoteAddr == address) {
    // Message to a local RPC endpoint.
    try {
      dispatcher.postOneWayMessage(message)
    } catch {
      case e: RpcEnvStoppedException => logWarning(e.getMessage)
    }
  } else {
    // Message to a remote RPC endpoint.
    postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message)))
  }
}

private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = {
  val promise = Promise[Any]()
  val remoteAddr = message.receiver.address

  def onFailure(e: Throwable): Unit = {
    if (!promise.tryFailure(e)) {
      logWarning(s"Ignored failure: $e")
    }
  }

  def onSuccess(reply: Any): Unit = reply match {
    case RpcFailure(e) => onFailure(e)
    case rpcReply =>
      if (!promise.trySuccess(rpcReply)) {
        logWarning(s"Ignored message: $reply")
      }
  }

  try {
    if (remoteAddr == address) {
      val p = Promise[Any]()
      p.future.onComplete {
        case Success(response) => onSuccess(response)
        case Failure(e) => onFailure(e)
      }(ThreadUtils.sameThread)
      dispatcher.postLocalMessage(message, p)
    } else {
      val rpcMessage = RpcOutboxMessage(serialize(message),
        onFailure,
        (client, response) => onSuccess(deserialize[Any](client, response)))
      postToOutbox(message.receiver, rpcMessage)
      promise.future.onFailure {
        case _: TimeoutException => rpcMessage.onTimeout()
        case _ =>
      }(ThreadUtils.sameThread)
    }

    val timeoutCancelable = timeoutScheduler.schedule(new Runnable {
      override def run(): Unit = {
        onFailure(new TimeoutException(s"Cannot receive any reply in ${timeout.duration}"))
      }
    }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
    promise.future.onComplete { v =>
      timeoutCancelable.cancel(true)
    }(ThreadUtils.sameThread)
  } catch {
    case NonFatal(e) =>
      onFailure(e)
  }
  promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
}
```
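The timeout handling in ask boils down to a race on a Promise: the reply and a scheduled TimeoutException both try to complete it, the first one wins, and the timer is cancelled once the Promise completes. The following is a minimal standalone sketch of that pattern, not Spark code; AskTimeoutSketch, deliver and the local sameThread context are hypothetical names, and the message transport is replaced by a callback.

```scala
import java.util.concurrent.{Executors, ScheduledExecutorService, ThreadFactory, TimeUnit, TimeoutException}
import scala.concurrent.{ExecutionContext, Future, Promise}
import scala.concurrent.duration._
import scala.reflect.ClassTag
import scala.util.control.NonFatal

// Sketch of the ask() timeout pattern; all names here are illustrative.
object AskTimeoutSketch {
  private val daemonFactory: ThreadFactory = (r: Runnable) => {
    val t = new Thread(r, "ask-timeout")
    t.setDaemon(true)
    t
  }
  private val timeoutScheduler: ScheduledExecutorService =
    Executors.newSingleThreadScheduledExecutor(daemonFactory)

  // Runs callbacks on the completing thread (similar in spirit to ThreadUtils.sameThread).
  private val sameThread: ExecutionContext = new ExecutionContext {
    def execute(runnable: Runnable): Unit = runnable.run()
    def reportFailure(cause: Throwable): Unit = cause.printStackTrace()
  }

  // `deliver` stands in for sending the request; it receives a callback for the reply.
  def ask[T: ClassTag](timeout: FiniteDuration)(deliver: (Any => Unit) => Unit): Future[T] = {
    val promise = Promise[Any]()
    // tryFailure / trySuccess: only the first completion wins, later ones are ignored.
    def onFailure(e: Throwable): Unit = promise.tryFailure(e)
    def onSuccess(reply: Any): Unit = promise.trySuccess(reply)

    try {
      deliver(onSuccess)
      // If no reply arrives in time, fail the promise with a TimeoutException.
      val timeoutCancelable = timeoutScheduler.schedule(new Runnable {
        override def run(): Unit =
          onFailure(new TimeoutException(s"Cannot receive any reply in $timeout"))
      }, timeout.toNanos, TimeUnit.NANOSECONDS)
      // Once the promise completes (by reply or by timeout), cancel the pending timer.
      promise.future.onComplete(_ => timeoutCancelable.cancel(true))(sameThread)
    } catch {
      case NonFatal(e) => onFailure(e)
    }
    promise.future.mapTo[T]
  }
}
```

Because every completion goes through tryFailure/trySuccess, a late reply after the timeout (or a late timeout after the reply) is simply ignored, which is the same reason the original code logs "Ignored failure" / "Ignored message" instead of throwing.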
  • Outbox

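Outbox is structurally similar to Inbox, but messages are sent differently. Outbox also keeps a message list, LinkedList[OutboxMessage]. When a message is posted to an Inbox, the Inbox does not process it itself; the thread pool in the dispatcher loops and takes the messages out for processing. In Outbox, by contrast, send adds the message to the list and then actively calls drainOutbox(), which loops over the list and sends every message. The sending is therefore done by the calling thread (or by the connection task once the connection is established), not by a dedicated thread pool. Outbox.send is called from NettyRpcEnv's postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit, which both send and ask use for remote messages. postToOutbox only goes through an Outbox when receiver.client is null (receiver.address must then be non-null); if receiver.client is not null, the message is sent directly over that TransportClient without an Outbox.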

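In Outbox's message-processing function drainOutbox(), if there is no connection to the remote yet (client is null and no connect task is running), launchConnectTask() submits a task to NettyRpcEnv's thread pool clientConnectionExecutor to establish the connection; once connected, that task calls drainOutbox() again so the queued messages get sent.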

```scala
package org.apache.spark.rpc.netty

......

private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {

  outbox => // Give this an alias so we can use it more clearly in closures.

  @GuardedBy("this")
  private val messages = new java.util.LinkedList[OutboxMessage]

  @GuardedBy("this")
  private var client: TransportClient = null

  /**
   * connectFuture points to the connect task. If there is no connect task, connectFuture will be
   * null.
   */
  @GuardedBy("this")
  private var connectFuture: java.util.concurrent.Future[Unit] = null

  @GuardedBy("this")
  private var stopped = false

  /**
   * If there is any thread draining the message queue
   */
  @GuardedBy("this")
  private var draining = false

  /**
   * Send a message. If there is no active connection, cache it and launch a new connection. If
   * [[Outbox]] is stopped, the sender will be notified with a [[SparkException]].
   */
  def send(message: OutboxMessage): Unit = {
    val dropped = synchronized {
      if (stopped) {
        true
      } else {
        messages.add(message)
        false
      }
    }
    if (dropped) {
      message.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
    } else {
      drainOutbox()
    }
  }

  /**
   * Drain the message queue. If there is other draining thread, just exit. If the connection has
   * not been established, launch a task in the `nettyEnv.clientConnectionExecutor` to setup the
   * connection.
   */
  private def drainOutbox(): Unit = {
    var message: OutboxMessage = null
    synchronized {
      if (stopped) {
        return
      }
      if (connectFuture != null) {
        // We are connecting to the remote address, so just exit
        return
      }
      if (client == null) {
        // There is no connect task but client is null, so we need to launch the connect task.
        launchConnectTask()
        return
      }
      if (draining) {
        // There is some thread draining, so just exit
        return
      }
      message = messages.poll()
      if (message == null) {
        return
      }
      draining = true
    }
    while (true) {
      try {
        val _client = synchronized { client }
        if (_client != null) {
          message.sendWith(_client)
        } else {
          assert(stopped == true)
        }
      } catch {
        case NonFatal(e) =>
          handleNetworkFailure(e)
          return
      }
      synchronized {
        if (stopped) {
          return
        }
        message = messages.poll()
        if (message == null) {
          draining = false
          return
        }
      }
    }
  }

  private def launchConnectTask(): Unit = {
    connectFuture = nettyEnv.clientConnectionExecutor.submit(new Callable[Unit] {

      override def call(): Unit = {
        try {
          val _client = nettyEnv.createClient(address)
          outbox.synchronized {
            client = _client
            if (stopped) {
              closeClient()
            }
          }
        } catch {
          case ie: InterruptedException =>
            // exit
            return
          case NonFatal(e) =>
            outbox.synchronized { connectFuture = null }
            handleNetworkFailure(e)
            return
        }
        outbox.synchronized { connectFuture = null }
        // It's possible that no thread is draining now. If we don't drain here, we cannot send the
        // messages until the next message arrives.
        drainOutbox()
      }
    })
  }

  /**
   * Stop [[Outbox]] and notify the waiting messages with the cause.
   */
  private def handleNetworkFailure(e: Throwable): Unit = {
    synchronized {
      assert(connectFuture == null)
      if (stopped) {
        return
      }
      stopped = true
      closeClient()
    }
    // Remove this Outbox from nettyEnv so that the further messages will create a new Outbox along
    // with a new connection
    nettyEnv.removeOutbox(address)

    // Notify the connection failure for the remaining messages
    //
    // We always check `stopped` before updating messages, so here we can make sure no thread will
    // update messages and it's safe to just drain the queue.
    var message = messages.poll()
    while (message != null) {
      message.onFailure(e)
      message = messages.poll()
    }
    assert(messages.isEmpty)
  }

  private def closeClient(): Unit = synchronized {
    // Just set client to null. Don't close it in order to reuse the connection.
    client = null
  }

  /**
   * Stop [[Outbox]]. The remaining messages in the [[Outbox]] will be notified with a
   * [[SparkException]].
   */
  def stop(): Unit = {
    synchronized {
      if (stopped) {
        return
      }
      stopped = true
      if (connectFuture != null) {
        connectFuture.cancel(true)
      }
      closeClient()
    }

    // We always check `stopped` before updating messages, so here we can make sure no thread will
    // update messages and it's safe to just drain the queue.
    var message = messages.poll()
    while (message != null) {
      message.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
      message = messages.poll()
    }
  }
}
```
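The core of drainOutbox is a "single drainer" rule: every sender enqueues, but only the thread that finds draining == false takes on the job of sending, and the actual I/O happens outside the lock. The sketch below isolates that rule under simplifying assumptions: DrainingOutbox is a hypothetical class, connect and write are stand-ins for nettyEnv.createClient and message.sendWith, the connection is opened inline rather than in clientConnectionExecutor, and stop/failure handling is left out.

```scala
// Sketch of the Outbox draining idea; not Spark code.
// `connect` and `write` are hypothetical stand-ins for createClient and sendWith.
class DrainingOutbox(connect: () => String, write: (String, String) => Unit) {
  private val messages = new java.util.LinkedList[String]() // guarded by `this`
  private var client: String = null                         // guarded by `this`
  private var draining = false                              // guarded by `this`

  def send(msg: String): Unit = {
    synchronized { messages.add(msg) }
    drain()
  }

  private def drain(): Unit = {
    // Under the lock, decide whether this thread becomes the (only) drainer.
    val start: Option[(String, String)] = synchronized {
      // Simplification: Spark connects asynchronously and returns; we connect inline.
      if (client == null) client = connect()
      if (draining || messages.isEmpty) None
      else {
        draining = true
        Some((client, messages.poll()))
      }
    }
    start.foreach { case (c, first) =>
      var msg = first
      while (msg != null) {
        write(c, msg) // the actual sending happens outside the lock
        msg = synchronized {
          val next = messages.poll()
          if (next == null) draining = false // queue empty: give up the drainer role
          next
        }
      }
    }
  }
}
```

If a second thread calls send while the first one is still inside the while loop, it only enqueues and returns, and the first thread picks up its message on the next poll; this is the same role the draining flag plays in Outbox.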

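This completes a rough pass over the main structure of NettyRpcEnv, the implementation class of RpcEnv, along with a brief look at its classes and their responsibilities. If time permits, I would like to draw a class diagram to show the relationships between these classes more clearly. The Netty networking framework that NettyRpcEnv builds on goes much deeper and is worth studying later.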
