[Spark Core] A Brief Look at the Task Execution Mechanism and Task Source Code (Part 2)

1. The Executor's launchTask function


    case LaunchTask(data) =>
      if (executor == null) {
        logError("Received LaunchTask command but executor was null")
        System.exit(1)
      } else {
        val ser = env.closureSerializer.newInstance()
        val taskDesc = ser.deserialize[TaskDescription](data.value)
        logInfo("Got assigned task " + taskDesc.taskId)
        executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
          taskDesc.name, taskDesc.serializedTask)
      }


    def launchTask(
        context: ExecutorBackend,
        taskId: Long,
        attemptNumber: Int,
        taskName: String,
        serializedTask: ByteBuffer) {
      val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
        serializedTask)
      runningTasks.put(taskId, tr)
      threadPool.execute(tr)
    }
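The launch pattern above (wrap the task in a Runnable, record it in a map of running tasks, hand it to a thread pool) can be sketched without Spark. This is only a minimal illustration: `LaunchTaskSketch`, `SimpleTaskRunner` and `shutdownAndWait` are hypothetical names, and the real TaskRunner does far more than run a closure.

```scala
import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit}

object LaunchTaskSketch {
  // Hypothetical stand-in for Spark's TaskRunner: wraps the work of one task.
  final class SimpleTaskRunner(val taskId: Long, body: () => Unit) extends Runnable {
    override def run(): Unit =
      try body()
      finally runningTasks.remove(taskId) // the task leaves the running list when it ends
  }

  // Tasks that have been launched and have not finished yet, keyed by task id.
  val runningTasks = new ConcurrentHashMap[Long, SimpleTaskRunner]()

  private val threadPool = Executors.newCachedThreadPool()

  // Register the runner first, then submit it, in the same order as launchTask above.
  def launchTask(taskId: Long)(body: => Unit): Unit = {
    val tr = new SimpleTaskRunner(taskId, () => body)
    runningTasks.put(taskId, tr)
    threadPool.execute(tr)
  }

  // Stop accepting tasks and wait for the submitted ones; true if all finished in time.
  def shutdownAndWait(): Boolean = {
    threadPool.shutdown()
    threadPool.awaitTermination(10, TimeUnit.SECONDS)
  }
}
```

Registering the runner before submitting it means the entry is always present by the time the task's own cleanup tries to remove it.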


2. The TaskRunner's run method

Inside the run method, the call val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber) is what actually executes the task's work.


    try {
      val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
      updateDependencies(taskFiles, taskJars)
      // Deserialize the Task
      task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)

      // If this task has been killed before we deserialized it, let's quit now. Otherwise,
      // continue executing the task.
      if (killed) {
        // Throw an exception rather than returning, because returning within a try{} block
        // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
        // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
        // for the task.
        throw new TaskKilledException
      }

      attemptedTask = Some(task)
      logDebug("Task " + taskId + "'s epoch is " + task.epoch)
      env.mapOutputTracker.updateEpoch(task.epoch)

      // Run the actual task and measure its runtime.
      // For the details, see ResultTask and ShuffleMapTask.
      taskStart = System.currentTimeMillis()
      val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
      val taskFinish = System.currentTimeMillis()

      // If the task has been killed, let's fail it.
      if (task.killed) {
        throw new TaskKilledException
      }

      // Serialize the result
      val resultSer = env.serializer.newInstance()
      val beforeSerialization = System.currentTimeMillis()
      val valueBytes = resultSer.serialize(value)
      val afterSerialization = System.currentTimeMillis()

      // Update the task's metrics, which are shown on the monitoring page
      for (m <- task.metrics) {
        m.setExecutorDeserializeTime(taskStart - deserializeStartTime)
        m.setExecutorRunTime(taskFinish - taskStart)
        m.setJvmGCTime(gcTime - startGCTime)
        m.setResultSerializationTime(afterSerialization - beforeSerialization)
      }

      val accumUpdates = Accumulators.values

      // Wrap the result once more, then serialize the wrapper
      val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
      val serializedDirectResult = ser.serialize(directResult)
      val resultSize = serializedDirectResult.limit

      // directSend = sending directly back to the driver
      val serializedResult = {
        if (maxResultSize > 0 && resultSize > maxResultSize) {
          logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
            s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
            s"dropping it.")
          ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
        } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
          // If the result is too large for one message of spark.akka.frameSize (10 MB by default),
          // store it in the BlockManager as MEMORY_AND_DISK_SER instead, so that whatever
          // does not fit in memory is kept on disk
          val blockId = TaskResultBlockId(taskId)
          env.blockManager.putBytes(
            blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
          logInfo(
            s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
          ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
        } else {
          logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
          serializedDirectResult
        }
      }

      // Report the task's completion and its result to the driver through statusUpdate
      execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
    } catch {
      // Exception-handling code omitted...
    } finally {
      // Release the shuffle memory registered for the task, then remove the task
      // from the list of running tasks
      // Release memory used by this thread for shuffles
      env.shuffleMemoryManager.releaseMemoryForThisThread()
      // Release memory used by this thread for unrolling blocks
      env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
      // Release memory used by this thread for accumulators
      Accumulators.clear()
      runningTasks.remove(taskId)
    }
  }
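The three-way decision on how a serialized result travels back to the driver can be isolated as a pure function. The sketch below only mirrors the branch conditions in the code above; `ResultRoutingSketch`, `Route` and its three cases are hypothetical names, and the sizes are passed in as plain parameters rather than read from Spark's configuration.

```scala
object ResultRoutingSketch {
  sealed trait Route
  case object Dropped extends Route         // over maxResultSize: only a placeholder is sent
  case object ViaBlockManager extends Route // too big for one message: stored as a block
  case object Direct extends Route          // small enough: sent inline with the status update

  // Same conditions, in the same order, as the serializedResult expression above.
  def route(resultSize: Long, maxResultSize: Long, frameSize: Long, reservedBytes: Long): Route =
    if (maxResultSize > 0 && resultSize > maxResultSize) Dropped
    else if (resultSize >= frameSize - reservedBytes) ViaBlockManager
    else Direct
}
```

For example, with an assumed 10 MB frame size and 200 KB reserved, a 1 KB result goes `Direct`, while a result of exactly 10 MB goes `ViaBlockManager`. A `maxResultSize` of 0 disables the first check entirely.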

3. Task execution flow

TaskRunner.run –> Task.run –> Task.runTask –> RDD.iterator –> RDD.computeOrReadCheckpoint –> RDD.compute


    /**
     * Called by Executor to run this task.
     *
     * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
     * @param attemptNumber how many times this task has been attempted (0 for the first attempt)
     * @return the result of the task
     */
    final def run(taskAttemptId: Long, attemptNumber: Int): T = {
      context = new TaskContextImpl(stageId = stageId, partitionId = partitionId,
        taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false)
      TaskContextHelper.setTaskContext(context)
      context.taskMetrics.setHostname(Utils.localHostName())
      taskThread = Thread.currentThread()
      if (_killed) {
        kill(interruptThread = false)
      }
      try {
        runTask(context)
      } finally {
        context.markTaskCompleted()
        TaskContextHelper.unset()
      }
    }



    // ShuffleMapTask.runTask
    override def runTask(context: TaskContext): MapStatus = {
      // Deserialize the RDD using the broadcast variable.
      val ser = SparkEnv.get.closureSerializer.newInstance()
      val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
        ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
      // taskBinary here comes from the broadcast variable that holds the task serialized in
      // org.apache.spark.scheduler.DAGScheduler#submitMissingTasks

      metrics = Some(context.taskMetrics)
      var writer: ShuffleWriter[Any, Any] = null
      try {
        val manager = SparkEnv.get.shuffleManager
        writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
        // Write the results computed by the RDD to memory or disk
        writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
        return writer.stop(success = true).get
      } catch {
        case e: Exception =>
          try {
            if (writer != null) {
              writer.stop(success = false)
            }
          } catch {
            case e: Exception =>
              log.debug("Could not stop writer", e)
          }
          throw e
      }
    }


    // ResultTask.runTask
    override def runTask(context: TaskContext): U = {
      // Deserialize the RDD and the func using the broadcast variables.
      val ser = SparkEnv.get.closureSerializer.newInstance()
      val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
        ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

      metrics = Some(context.taskMetrics)
      func(context, rdd.iterator(partition, context))
    }
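The relationship between run and the two runTask implementations is a template method: the final run sets up and tears down the context, and each subclass supplies only runTask. Here is a rough, Spark-free sketch of that shape. Every name in it (`TaskTemplateSketch`, `SimpleContext`, `SimpleTask`, `SimpleResultTask`, `SimpleShuffleMapTask`) is hypothetical, and the shuffle variant merely counts records per bucket instead of writing shuffle files.

```scala
object TaskTemplateSketch {
  // Hypothetical, trimmed-down context; Spark's TaskContextImpl carries much more.
  final class SimpleContext(val partitionId: Int) {
    private var done = false
    def markTaskCompleted(): Unit = done = true
    def isCompleted: Boolean = done
  }

  abstract class SimpleTask[T](val partitionId: Int) {
    private var lastContext: Option[SimpleContext] = None

    // Fixed skeleton: create the context, delegate to runTask, always mark completion.
    final def run(): T = {
      val context = new SimpleContext(partitionId)
      lastContext = Some(context)
      try runTask(context)
      finally context.markTaskCompleted()
    }

    def completed: Boolean = lastContext.exists(_.isCompleted)

    protected def runTask(context: SimpleContext): T
  }

  // Plays the role of ResultTask: applies a function to the partition's iterator.
  final class SimpleResultTask[A, U](pid: Int, data: Seq[A], func: Iterator[A] => U)
      extends SimpleTask[U](pid) {
    protected def runTask(context: SimpleContext): U = func(data.iterator)
  }

  // Plays the role of ShuffleMapTask: buckets records by key hash, returns bucket sizes.
  final class SimpleShuffleMapTask[K, V](pid: Int, data: Seq[(K, V)], numReducers: Int)
      extends SimpleTask[Map[Int, Int]](pid) {
    protected def runTask(context: SimpleContext): Map[Int, Int] =
      data
        .groupBy { case (k, _) => java.lang.Math.floorMod(k.hashCode, numReducers) }
        .map { case (bucket, records) => bucket -> records.size }
  }
}
```

Because run is final and wraps runTask in try/finally, the completion hook fires even when a subclass throws, which is the same guarantee the real Task.run gives for context.markTaskCompleted().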

4. Task status updates


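  1. Before the Task runs, the Executor tells the Driver that the Task's state is TaskState.RUNNING.
  2. After the Task has run, it tells the Driver that the Task's state is TaskState.FINISHED and returns the computed result.
  3. If an error occurs while the Task is running, it tells the Driver that the Task's state is TaskState.FAILED and returns the cause of the error.
  4. If the Task is killed midway, it tells the Driver that the Task's state is TaskState.KILLED.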
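The reporting sequence can be sketched as a small wrapper that announces RUNNING, runs the body, and then reports exactly one terminal state. This is an illustration rather than Spark's code: `StatusUpdateSketch`, `KilledException` and `runAndReport` are hypothetical names, and the `report` callback stands in for the status message sent to the Driver. Spark's Executor reports a killed task as TaskState.KILLED, and the sketch follows that.

```scala
import scala.util.control.NonFatal

object StatusUpdateSketch {
  sealed trait State
  case object Running extends State
  case object Finished extends State
  case object Failed extends State
  case object Killed extends State

  // Hypothetical marker exception, standing in for Spark's TaskKilledException.
  final class KilledException extends RuntimeException("task killed")

  // Runs `body` and hands every state transition, plus optional payload, to `report`.
  def runAndReport[T](report: (State, Option[Any]) => Unit)(body: => T): Unit = {
    report(Running, None)
    try {
      val value = body
      report(Finished, Some(value)) // success: the result travels with the update
    } catch {
      case _: KilledException => report(Killed, None)
      case NonFatal(e)        => report(Failed, Some(e.getMessage)) // failure: send the reason
    }
  }
}
```

The kill case is matched before the general failure case, since a kill exception would otherwise be reported as an ordinary failure.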

5. Task completion



    case StatusUpdate(executorId, taskId, state, data) =>
      // statusUpdate handles work such as removing finished tasks from the task set
      scheduler.statusUpdate(taskId, state, data.value)
      if (TaskState.isFinished(state)) {
        executorDataMap.get(executorId) match {
          case Some(executorInfo) =>
            executorInfo.freeCores += scheduler.CPUS_PER_TASK
            makeOffers(executorId)
          case None =>
            // Ignoring the update since we don't know about the executor.
            logWarning(s"Ignored task status update ($taskId state $state) " +
              "from unknown executor $sender with ID $executorId")
        }
      }
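The core bookkeeping in this handler is that a task reaching a terminal state gives its cores back to the executor, which makes the executor eligible for new offers. A minimal sketch, assuming one core per task and modelling states as plain strings; `FreeCoresSketch`, `ExecutorData` and `onStatusUpdate` are hypothetical names:

```scala
import scala.collection.mutable

object FreeCoresSketch {
  final class ExecutorData(var freeCores: Int)

  val cpusPerTask = 1 // assumption: one core per task

  // States after which a task no longer occupies cores.
  private val terminalStates = Set("FINISHED", "FAILED", "KILLED", "LOST")

  // Returns true when cores were released, i.e. when a new offer could be made.
  def onStatusUpdate(
      executors: mutable.Map[String, ExecutorData],
      executorId: String,
      state: String): Boolean =
    if (!terminalStates(state)) false
    else executors.get(executorId) match {
      case Some(info) =>
        info.freeCores += cpusPerTask
        true
      case None =>
        false // unknown executor: ignore the update
    }
}
```

A RUNNING update leaves the core count untouched, and an update from an unknown executor is ignored, matching the two non-releasing paths in the handler above.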


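  1. The TaskScheduler uses the TaskId to find the TaskSetManager (the class that manages a batch of Tasks) responsible for this Task, removes the Task from that TaskSetManager, and inserts the Task into the success queue of the TaskResultGetter (the class that fetches Task results);
  2. Once the TaskResultGetter has fetched the result, it calls the TaskScheduler's handleSuccessfulTask method to hand the result back;
  3. The TaskScheduler calls the TaskSetManager's handleSuccessfulTask method to process the successful Task;
  4. The TaskSetManager calls the DAGScheduler's taskEnded method to tell the DAGScheduler that this Task has finished. If all Tasks have succeeded by this point, the TaskSetManager is finished as well.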


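  1. Call the Stage's addOutputLoc method to add the result to the Stage's outputLocs list
  2. If the Stage has no more pending Tasks, mark the Stage as finished
  3. Register the Stage's outputLocs with the MapOutputTracker, for the next Stage to use
  4. If the Stage's outputLocs is empty, its computation has failed, so resubmit the Stage
  5. Find the next waiting Stage that has no unfinished parent Stages and submit it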
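The five steps above can be walked through on a toy model. This is a loose sketch and not the DAGScheduler's actual code: `StageCompletionSketch`, `SimpleStage`, `Action` and `onMapTaskSuccess` are hypothetical, output locations are plain host names, and `registered` stands in for the MapOutputTracker. The sketch reads "outputLocs is empty" as "some partition has no output location", which is an assumption about step 4.

```scala
import scala.collection.mutable

object StageCompletionSketch {
  // Hypothetical, trimmed-down Stage: one list of output locations per partition.
  final class SimpleStage(val id: Int, numPartitions: Int, val parents: Seq[SimpleStage]) {
    val outputLocs: Array[List[String]] = Array.fill(numPartitions)(List.empty[String])
    val pendingTasks: mutable.Set[Int] = mutable.Set.empty[Int] ++= (0 until numPartitions)

    def addOutputLoc(partition: Int, host: String): Unit =
      outputLocs(partition) = host :: outputLocs(partition)

    // Simulates losing an executor: its map outputs are no longer usable.
    def removeOutputsOnHost(host: String): Unit =
      outputLocs.indices.foreach(i => outputLocs(i) = outputLocs(i).filterNot(_ == host))

    def isAvailable: Boolean = outputLocs.forall(_.nonEmpty)
  }

  sealed trait Action
  case object StillRunning extends Action
  final case class Resubmit(stageId: Int) extends Action
  final case class SubmitNext(stageIds: List[Int]) extends Action

  def onMapTaskSuccess(
      stage: SimpleStage,
      partition: Int,
      host: String,
      waiting: mutable.Set[SimpleStage],
      registered: mutable.Map[Int, Vector[List[String]]]): Action = {
    stage.addOutputLoc(partition, host)                 // step 1
    stage.pendingTasks -= partition
    if (stage.pendingTasks.nonEmpty) StillRunning
    else {
      // step 2: no pending tasks left, so the stage counts as finished
      registered(stage.id) = stage.outputLocs.toVector  // step 3
      if (!stage.isAvailable) Resubmit(stage.id)        // step 4
      else {
        // step 5: waiting stages whose parents all have their outputs
        val ready = waiting.filter(_.parents.forall(_.isAvailable)).toList
        waiting --= ready
        SubmitNext(ready.map(_.id).sorted)
      }
    }
  }
}
```

In this model, a parent stage whose last map task succeeds releases its waiting child, whereas a stage that lost one partition's output to a dead host is resubmitted instead.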

When reposting, please credit the author, Jason Ding, and cite the source.

