Spark修炼之道(高级篇)——Spark源码阅读:第八节 Task执行

来源:互联网 发布:php文件管理系统代码 编辑:程序博客网 时间:2024/05/21 12:49

Task执行

在上一节中,我们提到在Driver端CoarseGrainedSchedulerBackend中的launchTasks方法向Worker节点中的Executor发送启动任务命令,该命令的接收者是CoarseGrainedExecutorBackend(Standalone模式),类定义源码如下:

private[spark] class CoarseGrainedExecutorBackend(    override val rpcEnv: RpcEnv,    driverUrl: String,    executorId: String,    hostPort: String,    cores: Int,    userClassPath: Seq[URL],    env: SparkEnv)  extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {

可以看到它继承自ThreadSafeRpcEndpoint,并对ThreadSafeRpcEndpoint中的receive方法进行了实现,具体源代码如下:

/**
 * Handles the messages delivered to this executor backend endpoint: the outcome of the
 * registration with the driver, task launch/kill commands, and the shutdown command.
 */
override def receive: PartialFunction[Any, Unit] = {
  // Registration accepted by the driver: create the Executor that will run tasks.
  case RegisteredExecutor =>
    logInfo("Successfully registered with driver")
    val (hostname, _) = Utils.parseHostPort(hostPort)
    executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)

  // Registration rejected by the driver: log the reason and exit the process.
  case RegisterExecutorFailed(message) =>
    logError("Slave registration failed: " + message)
    System.exit(1)

  // LaunchTask command sent by the driver.
  case LaunchTask(data) =>
    if (executor == null) {
      logError("Received LaunchTask command but executor was null")
      System.exit(1)
    } else {
      // Deserialize the task description carried by the message.
      val taskDesc = ser.deserialize[TaskDescription](data.value)
      logInfo("Got assigned task " + taskDesc.taskId)
      // Ask the Executor to start running the task.
      executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
        taskDesc.name, taskDesc.serializedTask)
    }

  // KillTask command sent by the driver.
  case KillTask(taskId, _, interruptThread) =>
    if (executor == null) {
      logError("Received KillTask command but executor was null")
      System.exit(1)
    } else {
      executor.killTask(taskId, interruptThread)
    }

  // Shutdown command: stop the executor, this endpoint, and the RPC environment.
  case StopExecutor =>
    logInfo("Driver commanded a shutdown")
    // NOTE(review): unlike LaunchTask/KillTask there is no null check here, so a StopExecutor
    // arriving before RegisteredExecutor would throw a NullPointerException — confirm whether
    // that message ordering can occur.
    executor.stop()
    stop()
    rpcEnv.shutdown()
}

从前面的代码可以看到,收到LaunchTask消息后,会调用executor.launchTask方法在Worker节点上启动Task的运行,其源码如下:

// launchTask method of the Executor class
/**
 * Starts running a task on this executor.
 *
 * Wraps the serialized task in a [[TaskRunner]], records the runner in `runningTasks` under its
 * task id, and submits it to the executor's thread pool, which invokes `TaskRunner.run`.
 *
 * @param context        backend the runner uses to send status updates to the driver
 * @param taskId         id of the task attempt
 * @param attemptNumber  attempt number of this task
 * @param taskName       task name, used in log messages
 * @param serializedTask serialized task together with its file/jar dependencies
 */
def launchTask(
    context: ExecutorBackend,
    taskId: Long,
    attemptNumber: Int,
    taskName: String,
    serializedTask: ByteBuffer): Unit = {
  // TaskRunner is a Runnable (not a thread) that holds everything needed to run this task.
  val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
    serializedTask)
  // Track the running task; TaskRunner.run removes this entry in its finally block.
  runningTasks.put(taskId, tr)
  // Submit the runner to the thread pool; its run() method performs the actual task execution.
  threadPool.execute(tr)
}

TaskRunner是定义在org.apache.spark.executor.Executor类中的一个内部类,它实现了Runnable接口(本身并不是线程),由Executor的线程池负责执行,具体源码如下:

 /**
  * Runnable that executes a single task attempt on this executor and reports its status and
  * result back to the driver through `execBackend`.
  */
 class TaskRunner(
     execBackend: ExecutorBackend,
     val taskId: Long,
     val attemptNumber: Int,
     taskName: String,
     serializedTask: ByteBuffer)
   extends Runnable {

   /** Whether this task has been killed. */
   @volatile private var killed = false

   /** How much the JVM process has spent in GC when the task starts to run. */
   @volatile var startGCTime: Long = _

   /**
    * The task to run. This will be set in run() by deserializing the task binary coming
    * from the driver. Once it is set, it will never be changed.
    */
   @volatile var task: Task[Any] = _

   /**
    * Marks this runner as killed and, if the task has already been deserialized, forwards the
    * kill request to it.
    */
   def kill(interruptThread: Boolean): Unit = {
     logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
     killed = true
     if (task != null) {
       task.kill(interruptThread)
     }
   }

   /**
    * Deserializes and runs the task, then reports the outcome to the driver via
    * `execBackend.statusUpdate`: RUNNING at the start, then FINISHED with the serialized result,
    * or FAILED / KILLED with a serialized reason.
    *
    * Handling of large results:
    *  - If `maxResultSize` is positive and the result exceeds it, the result is dropped and only
    *    its size is reported.
    *  - If the result is too large for a direct RPC frame, it is stored in the BlockManager and
    *    sent by reference.
    */
   override def run(): Unit = {
     val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
     val deserializeStartTime = System.currentTimeMillis()
     Thread.currentThread.setContextClassLoader(replClassLoader)
     val ser = env.closureSerializer.newInstance()
     logInfo(s"Running $taskName (TID $taskId)")
     // Send a RUNNING status update to the driver.
     execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
     var taskStart: Long = 0
     startGCTime = computeTotalGcTime()

     try {
       val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
       updateDependencies(taskFiles, taskJars)
       task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
       task.setTaskMemoryManager(taskMemoryManager)

       // If this task has been killed before we deserialized it, let's quit now. Otherwise,
       // continue executing the task.
       if (killed) {
         // Throw an exception rather than returning, because returning within a try{} block
         // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
         // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
         // for the task.
         throw new TaskKilledException
       }

       logDebug("Task " + taskId + "'s epoch is " + task.epoch)
       env.mapOutputTracker.updateEpoch(task.epoch)

       // Run the actual task and measure its runtime.
       taskStart = System.currentTimeMillis()
       var threwException = true
       val (value, accumUpdates) = try {
         // Task.run delegates to the subclass-specific runTask; ShuffleMapTask and ResultTask,
         // for example, each have their own implementation.
         val res = task.run(
           taskAttemptId = taskId,
           attemptNumber = attemptNumber,
           metricsSystem = env.metricsSystem)
         threwException = false
         res
       } finally {
         val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
         if (freedMemory > 0) {
           val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
           if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false) && !threwException) {
             throw new SparkException(errMsg)
           } else {
             logError(errMsg)
           }
         }
       }
       val taskFinish = System.currentTimeMillis()

       // If the task has been killed, let's fail it.
       if (task.killed) {
         throw new TaskKilledException
       }

       val resultSer = env.serializer.newInstance()
       val beforeSerialization = System.currentTimeMillis()
       val valueBytes = resultSer.serialize(value)
       val afterSerialization = System.currentTimeMillis()

       for (m <- task.metrics) {
         // Deserialization happens in two parts: first, we deserialize a Task object, which
         // includes the Partition. Second, Task.run() deserializes the RDD and function to be run.
         m.setExecutorDeserializeTime(
           (taskStart - deserializeStartTime) + task.executorDeserializeTime)
         // We need to subtract Task.run()'s deserialization time to avoid double-counting
         m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
         m.setJvmGCTime(computeTotalGcTime() - startGCTime)
         m.setResultSerializationTime(afterSerialization - beforeSerialization)
         m.updateAccumulators()
       }

       val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
       val serializedDirectResult = ser.serialize(directResult)
       val resultSize = serializedDirectResult.limit

       // directSend = sending directly back to the driver
       val serializedResult: ByteBuffer = {
         if (maxResultSize > 0 && resultSize > maxResultSize) {
           logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
             s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
             s"dropping it.")
           ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
         } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
           val blockId = TaskResultBlockId(taskId)
           env.blockManager.putBytes(
             blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
           logInfo(
             s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
           ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
         } else {
           logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
           serializedDirectResult
         }
       }

       // Task completed: send the FINISHED status and the (direct or indirect) result to the driver.
       execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

     } catch {
       case ffe: FetchFailedException =>
         val reason = ffe.toTaskEndReason
         execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

       // NOTE(review): if kill() ran before the task was deserialized, task.kill was never
       // called, so task.killed is false here. The TaskKilledException thrown above then falls
       // through to the Throwable case and is reported as FAILED rather than KILLED. Likewise,
       // an InterruptedException raised while `task` is still null makes this guard throw a
       // NullPointerException. Later Spark releases appear to split these cases — verify.
       case _: TaskKilledException | _: InterruptedException if task.killed =>
         logInfo(s"Executor killed $taskName (TID $taskId)")
         execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))

       case cDE: CommitDeniedException =>
         val reason = cDE.toTaskEndReason
         execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

       case t: Throwable =>
         // Attempt to exit cleanly by informing the driver of our failure.
         // If anything goes wrong (or this was a fatal exception), we will delegate to
         // the default uncaught exception handler, which will terminate the Executor.
         logError(s"Exception in $taskName (TID $taskId)", t)

         val metrics: Option[TaskMetrics] = Option(task).flatMap { task =>
           task.metrics.map { m =>
             m.setExecutorRunTime(System.currentTimeMillis() - taskStart)
             m.setJvmGCTime(computeTotalGcTime() - startGCTime)
             m.updateAccumulators()
             m
           }
         }
         val serializedTaskEndReason = {
           try {
             ser.serialize(new ExceptionFailure(t, metrics))
           } catch {
             case _: NotSerializableException =>
               // t is not serializable so just send the stacktrace
               ser.serialize(new ExceptionFailure(t, metrics, false))
           }
         }
         // Report the failure to the driver as well; presumably this lets the scheduler re-run
         // the task (the scheduler side is not shown here).
         execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)

         // Don't forcibly exit unless the exception was inherently fatal, to avoid
         // stopping other tasks unnecessarily.
         if (Utils.isFatalError(t)) {
           SparkUncaughtExceptionHandler.uncaughtException(t)
         }

     } finally {
       // Whatever the outcome, remove this task from the running-task registry.
       runningTasks.remove(taskId)
     }
   }
 }

Task run方法负责Task的执行,其源码如下:

 /**
  * Called by [[Executor]] to run this task.
  *
  * Sets up a [[TaskContextImpl]] for this attempt and delegates to the subclass's `runTask`.
  * Afterwards it always marks the task completed, releases per-task shuffle and unroll memory,
  * and unsets the thread's TaskContext.
  *
  * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
  * @param attemptNumber how many times this task has been attempted (0 for the first attempt)
  * @param metricsSystem metrics system handed to the TaskContextImpl created for this run
  * @return the result of the task along with updates of Accumulators.
  */
 final def run(
     taskAttemptId: Long,
     attemptNumber: Int,
     metricsSystem: MetricsSystem)
   : (T, AccumulatorUpdates) = {
   // Build the runtime context for this task attempt and register it for the current thread.
   context = new TaskContextImpl(
     stageId,
     partitionId,
     taskAttemptId,
     attemptNumber,
     taskMemoryManager,
     metricsSystem,
     internalAccumulators,
     runningLocally = false)
   TaskContext.setTaskContext(context)
   context.taskMetrics.setHostname(Utils.localHostName())
   context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators)
   // Record the thread running this task (presumably used by kill() to interrupt it — not shown).
   taskThread = Thread.currentThread()
   // If the task was already flagged as killed, propagate the kill now (without interrupting).
   if (_killed) {
     kill(interruptThread = false)
   }
   try {
     // Delegate to runTask; each subclass (e.g. ShuffleMapTask, ResultTask) implements it
     // differently.
     (runTask(context), context.collectAccumulators())
   } finally {
     context.markTaskCompleted()
     try {
       Utils.tryLogNonFatalError {
         // Release memory used by this thread for shuffles
         SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask()
       }
       Utils.tryLogNonFatalError {
         // Release memory used by this thread for unrolling blocks
         SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask()
       }
     } finally {
       TaskContext.unset()
     }
   }
 }

以ResultTask为例,其runTask方法源码如下:

// runTask method of ResultTask
/**
 * Runs this result task. It deserializes the RDD and the user function from `taskBinary`,
 * records the deserialization time, and applies the function to the iterator over this task's
 * partition.
 *
 * @param context the task context for this attempt
 * @return the value produced by the user function for this partition
 */
override def runTask(context: TaskContext): U = {
  // Deserialize the RDD and the func using the broadcast variables.
  val deserializeStartTime = System.currentTimeMillis()
  val ser = SparkEnv.get.closureSerializer.newInstance()
  // Deserialize the RDD together with the function to execute on it.
  val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime

  metrics = Some(context.taskMetrics)
  // Apply func to the partition's record iterator obtained from rdd.iterator.
  func(context, rdd.iterator(partition, context))
}

总结一下Task的执行过程:
1 调用Driver端org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend中的launchTasks
2 Worker端的org.apache.spark.executor.CoarseGrainedExecutorBackend在receive方法中收到LaunchTask消息,调用org.apache.spark.executor.Executor.launchTask
3 Executor的线程池执行内部类org.apache.spark.executor.Executor.TaskRunner的run方法
4 调用org.apache.spark.scheduler.Task.run方法
5 调用org.apache.spark.scheduler.ResultTask.runTask方法
6 调用org.apache.spark.rdd.RDD.iterator方法

1 0
原创粉丝点击