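Spark's resource management and scheduling follow an approach similar to Hadoop YARN. At the top level sits a resource scheduler, which allocates resources to, and schedules, all applications registered with Spark; Spark can use Mesos, YARN or a similar system as this resource-scheduling framework. Inside each application, Spark implements its own task scheduler, which schedules and coordinates tasks, much as MapReduce does. In essence, the outer resource scheduling and the inner task scheduling are independent of each other, and each has its own responsibilities. This source-code analysis concentrates on the inner layer: how Spark's task scheduler is implemented.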



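  1. TaskSchedulerListener: the TaskSchedulerListener listens for jobs submitted by the user, breaks each job into stages of different types together with their tasks, and submits those tasks to the TaskScheduler.
  2. TaskScheduler: the TaskScheduler receives the submitted tasks and executes them. Depending on the deployment mode, TaskScheduler is further divided into three sub-modules: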
    • ClusterScheduler
    • LocalScheduler
    • MesosScheduler



DAGScheduler class chart

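  • After a user-submitted job is scheduled by the DAGScheduler, it is wrapped in an ActiveJob, and a JobWaiter is started that blocks until the job finishes.
  • At the same time, the DAGScheduler looks at the dependencies of the RDDs in the job and at the type of each dependency (NarrowDependency or ShuffleDependency). From these it builds a DAG of stages (result stages and shuffle map stages) that follows the order of the dependencies.
  • Within each stage, the DAGScheduler creates the stage's tasks, either ResultTasks or ShuffleMapTasks, as a group matching the number and placement of the RDD's partitions. The group is wrapped into a TaskSet and submitted to the TaskScheduler.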


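In Spark, an RDD represents a dataset in a particular state, and that state may have been derived from an earlier one. In other words, an RDD may depend on one or more previous RDDs. These dependencies fall into two types: Narrow Dependency and Wide Dependency.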

  • Narrow Dependency: each partition of the child RDD depends on only a fixed number of partitions of the parent RDD(s).
  • Wide Dependency: each partition of the child RDD depends on all partitions of the parent RDD(s).
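To make the two definitions concrete, here is a minimal Scala sketch of partition-level dependencies. `Dep`, `OneToOneDep` and `ShuffleDep` are hypothetical simplifications. They loosely follow the idea behind Spark's `NarrowDependency` (whose `getParents` method returns the parent partitions a child partition reads) and `ShuffleDependency`, but they are not Spark's classes.

```scala
// Hypothetical model: which parent partitions does child partition `p` read?
sealed trait Dep {
  def parentsOf(p: Int): Seq[Int]
}

// Narrow (e.g. map, filter): child partition p reads only parent partition p,
// so each child partition touches a fixed number of parents (here exactly one).
case object OneToOneDep extends Dep {
  def parentsOf(p: Int): Seq[Int] = Seq(p)
}

// Wide (shuffle): every child partition may need records from every
// parent partition, so it depends on all of them.
final case class ShuffleDep(numParentPartitions: Int) extends Dep {
  def parentsOf(p: Int): Seq[Int] = 0 until numParentPartitions
}

val narrowParents = OneToOneDep.parentsOf(2)
val wideParents = ShuffleDep(numParentPartitions = 4).parentsOf(2)
```

With four parent partitions, child partition 2 reads one parent partition under the narrow dependency but all four under the shuffle dependency. That is why the wide case forces data to move between nodes.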


RDD dependencies

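Spark also uses these dependencies to split each job into stages, and the dependencies between stages form a DAG. Across narrow dependencies, Spark pipelines as many RDD transformations as possible into the same stage. A wide dependency usually implies a shuffle, so Spark makes the stage that produces the shuffle data a ShuffleMapStage, which lets the shuffle output be registered with the MapOutputTracker. The figure below shows how stages are divided: in general, Spark treats shuffle operations as stage boundaries.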

different stage boundary
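The stage-boundary rule can be sketched on a toy lineage. `ToyRDD`, `NarrowEdge`, `ShuffleEdge`, `ToyStage` and `buildStage` are all hypothetical simplifications, similar in spirit to the recursive traversal Spark performs when it looks for parent stages. Narrow edges are followed within the current stage, and each shuffle edge starts a new parent stage. Real Spark builds more RDDs for `reduceByKey` than the single node used here.

```scala
import scala.collection.mutable

// Toy lineage (hypothetical classes): each RDD lists its parents, and each
// edge is either narrow (pipelined) or a shuffle (stage boundary).
sealed trait Edge { def parent: ToyRDD }
final case class NarrowEdge(parent: ToyRDD) extends Edge
final case class ShuffleEdge(parent: ToyRDD) extends Edge
final case class ToyRDD(name: String, deps: List[Edge] = Nil)

// A stage is identified by its last RDD plus the stages it must wait for.
final case class ToyStage(lastRdd: ToyRDD, parents: List[ToyStage])

// Walk narrow edges inside the current stage; every shuffle edge starts a
// new parent stage that ends at the shuffle's parent RDD.
def buildStage(rdd: ToyRDD): ToyStage = {
  val parents = mutable.ListBuffer[ToyStage]()
  val visited = mutable.HashSet[ToyRDD]()
  def visit(r: ToyRDD): Unit =
    if (visited.add(r)) r.deps.foreach {
      case NarrowEdge(p)  => visit(p)
      case ShuffleEdge(p) => parents += buildStage(p)
    }
  visit(rdd)
  ToyStage(rdd, parents.toList)
}

def countStages(s: ToyStage): Int = 1 + s.parents.map(countStages).sum

// Job 1: textFile -> filter -> count (no shuffle)
val lines = ToyRDD("textFile")
val filtered = ToyRDD("filter", List(NarrowEdge(lines)))
// Job 2: textFile -> flatMap -> map -> reduceByKey (one shuffle)
val words = ToyRDD("flatMap", List(NarrowEdge(lines)))
val pairs = ToyRDD("map", List(NarrowEdge(words)))
val counts = ToyRDD("reduceByKey", List(ShuffleEdge(pairs)))

val job1Stages = countStages(buildStage(filtered))
val job2Stages = countStages(buildStage(counts))
```

Job 1 stays in a single stage. Job 2 is cut at the shuffle into a map-side stage ending at `map` and a final stage.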



```scala
private var taskScheduler: TaskScheduler = {
  // ...
}
taskScheduler.start()

private var dagScheduler = new DAGScheduler(taskScheduler)
dagScheduler.start()
```

Starting the DAGScheduler creates a daemon thread internally. This thread calls run(), which takes events from a blocking queue and processes them.

```scala
private def run() {
  SparkEnv.set(env)

  while (true) {
    val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
    if (event != null) {
      logDebug("Got event of type " + event.getClass.getName)
    }

    if (event != null) {
      if (processEvent(event)) {
        return
      }
    }

    val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
    if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
      resubmitFailedStages()
    } else {
      submitWaitingStages()
    }
  }
}
```



  • JobSubmitted
  • CompletionEvent
  • ExecutorLost
  • TaskFailed
  • StopDAGScheduler


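In essence, the DAGScheduler follows a producer-consumer model. Users and the TaskScheduler produce events and put them into the blocking queue, and the daemon thread consumes those events and handles each one.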
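A stripped-down version of this producer-consumer loop can be written with `java.util.concurrent.LinkedBlockingQueue`. The event types and the `MiniEventLoop` class are hypothetical. The sketch only reproduces the pattern in `run()` above:

- producers `put` events;
- a daemon thread `poll`s with a timeout;
- the loop exits when the handler returns `true`, just as `run()` returns when `processEvent(event)` is true.

```scala
import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
import scala.collection.mutable.ArrayBuffer

// Hypothetical, simplified event types (not Spark's actual classes).
sealed trait SchedulerEvent
final case class JobSubmitted(jobId: Int) extends SchedulerEvent
final case class TaskCompleted(jobId: Int, partition: Int) extends SchedulerEvent
case object StopScheduler extends SchedulerEvent

class MiniEventLoop {
  private val eventQueue = new LinkedBlockingQueue[SchedulerEvent]()
  val handled = ArrayBuffer[String]() // log of handled events, read after the loop stops

  // Producer side: may be called from any thread (users, task scheduler).
  def post(event: SchedulerEvent): Unit = eventQueue.put(event)

  // Returns true when the loop should exit, like processEvent() above.
  private def processEvent(event: SchedulerEvent): Boolean = event match {
    case JobSubmitted(id) =>
      handled += s"submit $id"
      false
    case TaskCompleted(id, p) =>
      handled += s"done $id/$p"
      false
    case StopScheduler =>
      handled += "stop"
      true
  }

  // Consumer side: a daemon thread that polls with a timeout, leaving room for
  // periodic work between events (where Spark resubmits failed or waiting stages).
  val thread = new Thread("mini-event-loop") {
    setDaemon(true)
    override def run(): Unit = {
      var stopped = false
      while (!stopped) {
        val event = eventQueue.poll(10, TimeUnit.MILLISECONDS)
        if (event != null) stopped = processEvent(event)
      }
    }
  }

  def start(): Unit = thread.start()
}

val loop = new MiniEventLoop
loop.start()
loop.post(JobSubmitted(1))
loop.post(TaskCompleted(1, 0))
loop.post(StopScheduler)
loop.thread.join(1000)
```

Because a single consumer thread handles every event, the handler itself needs no locking. Producers only touch the thread-safe queue.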



  1. A job without shuffle or reduce:

    ```scala
    val textFile = sc.textFile("")
    textFile.filter(line => line.contains("Spark")).count()
    ```

  2. A job with shuffle and reduce:

    ```scala
    val textFile = sc.textFile("")
    textFile.flatMap(line => line.split(" ")).map(word => (word, 1)).reduceByKey((a, b) => a + b)
    ```


```scala
def runJob[T, U: ClassManifest](
    finalRdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: String,
    allowLocal: Boolean,
    resultHandler: (Int, U) => Unit)
{
  if (partitions.size == 0) {
    return
  }
  val (toSubmit, waiter) = prepareJob(
    finalRdd, func, partitions, callSite, allowLocal, resultHandler)
  eventQueue.put(toSubmit)
  waiter.awaitResult() match {
    case JobSucceeded => {}
    case JobFailed(exception: Exception) =>
      logInfo("Failed to run " + callSite)
      throw exception
  }
}
```
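The blocking behavior of `waiter.awaitResult()` can be illustrated with a plain monitor (`synchronized`, `wait`, `notifyAll`). `SketchJobWaiter`, `Succeeded`, `Failed` and `WaiterDemo` are hypothetical names, and this is a minimal sketch of the idea rather than Spark's JobWaiter. The calling thread blocks until every task has reported a result or the job has failed. Results arrive on another thread and go to a handler that plays the role of `resultHandler` in `runJob()`.

```scala
// Hypothetical stand-ins for the JobSucceeded / JobFailed results above.
sealed trait WaitResult
case object Succeeded extends WaitResult
final case class Failed(e: Exception) extends WaitResult

class SketchJobWaiter[U](totalTasks: Int, resultHandler: (Int, U) => Unit) {
  private var finishedTasks = 0
  private var result: Option[WaitResult] =
    if (totalTasks == 0) Some(Succeeded) else None

  // Called from the scheduler side each time one task finishes.
  def taskSucceeded(index: Int, value: U): Unit = synchronized {
    if (result.isEmpty) {
      resultHandler(index, value)
      finishedTasks += 1
      if (finishedTasks == totalTasks) {
        result = Some(Succeeded)
        notifyAll() // wake the thread blocked in awaitResult()
      }
    }
  }

  def jobFailed(e: Exception): Unit = synchronized {
    if (result.isEmpty) {
      result = Some(Failed(e))
      notifyAll()
    }
  }

  // Blocks the calling (user) thread until the job succeeds or fails.
  def awaitResult(): WaitResult = synchronized {
    while (result.isEmpty) wait()
    result.get
  }
}

object WaiterDemo {
  // A 3-task "job" whose task results are reported from a separate thread.
  def run(): (WaitResult, Seq[Int]) = {
    val results = new Array[Int](3)
    val waiter = new SketchJobWaiter[Int](3, (i, v) => results(i) = v)
    val scheduler = new Thread(() => (0 until 3).foreach(i => waiter.taskSucceeded(i, i * 10)))
    scheduler.start()
    val outcome = waiter.awaitResult()
    scheduler.join()
    (outcome, results.toSeq)
  }
}

val (outcome, values) = WaiterDemo.run()
```

Waiting inside a `while` loop, rather than a single `if`, guards against spurious wake-ups.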



```scala
case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
  val runId = nextRunId.getAndIncrement()
  val finalStage = newStage(finalRDD, None, runId)
  val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
  clearCacheLocs()
  if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
    runLocally(job)
  } else {
    activeJobs += job
    resultStageToJob(finalStage) = job
    submitStage(finalStage)
  }
```

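First, every job produces a finalStage, from which its tasks are generated and submitted. Second, some simple jobs have no parent stages and only one partition. When allowLocal is set, such a job runs on a local thread instead of being submitted to the TaskScheduler.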

Once finalStage has been created, submitStage() is called. It derives the stage DAG from the dependencies between stages and processes the stages in dependency order:

```scala
private def submitStage(stage: Stage) {
  if (!waiting(stage) && !running(stage) && !failed(stage)) {
    val missing = getMissingParentStages(stage).sortBy(...)
    if (missing == Nil) {
      submitMissingTasks(stage)
      running += stage
    } else {
      for (parent <- missing) {
        submitStage(parent)
      }
      waiting += stage
    }
  }
}
```
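The effect of this recursion can be traced on a three-stage graph. `SimStage` and `SubmitSim` are hypothetical, and the sketch keeps only the waiting/running bookkeeping of `submitStage()`. A `finished` set replaces the `failed` check and stands in for shuffle-output availability. `stageFinished` re-examines waiting stages after a parent completes, loosely mirroring `submitWaitingStages()` in the event loop.

```scala
import scala.collection.mutable

// Hypothetical stage graph: each stage lists the ids of its parent stages.
final case class SimStage(id: Int, parents: List[Int])

class SubmitSim(stages: Map[Int, SimStage]) {
  val waiting = mutable.LinkedHashSet[Int]()
  val running = mutable.LinkedHashSet[Int]()
  val finished = mutable.Set[Int]()
  val launchOrder = mutable.ArrayBuffer[Int]() // order in which stages start running

  // Parents whose output is not available yet, sorted by id.
  private def missingParents(id: Int): List[Int] =
    stages(id).parents.filterNot(p => finished(p)).sorted

  // Same shape as submitStage(): run the stage if nothing is missing,
  // otherwise submit the missing parents first and park the stage in `waiting`.
  def submitStage(id: Int): Unit =
    if (!waiting(id) && !running(id) && !finished(id)) {
      val missing = missingParents(id)
      if (missing.isEmpty) {
        running += id
        launchOrder += id
      } else {
        missing.foreach(submitStage)
        waiting += id
      }
    }

  // When a stage completes, give every waiting stage another chance.
  def stageFinished(id: Int): Unit = {
    running -= id
    finished += id
    val retry = waiting.toList
    waiting.clear()
    retry.foreach(submitStage)
  }
}

// Stages 0 and 1 are map stages; final stage 2 needs both of them.
val sim = new SubmitSim(Map(0 -> SimStage(0, Nil), 1 -> SimStage(1, Nil), 2 -> SimStage(2, List(0, 1))))
sim.submitStage(2)
val launchedFirst = sim.launchOrder.toList // List(0, 1): parents start first, stage 2 waits
sim.stageFinished(0)
sim.stageFinished(1)
```

After stage 0 finishes, stage 2 still has a missing parent and goes back to `waiting`. It starts only after stage 1 also finishes.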

For a newly submitted job, the parent stages of finalStage have not been determined yet, so submitStage calls getMissingParentStages() to find them:

```scala
private def getMissingParentStages(stage: Stage): List[Stage] = {
  val missing = new HashSet[Stage]
  val visited = new HashSet[RDD[_]]
  def visit(rdd: RDD[_]) {
    if (!visited(rdd)) {
      visited += rdd
      if (getCacheLocs(rdd).contains(Nil)) {
        for (dep <- rdd.dependencies) {
          dep match {
            case shufDep: ShuffleDependency[_,_] =>
              val mapStage = getShuffleMapStage(shufDep, stage.priority)
              if (!mapStage.isAvailable) {
                missing += mapStage
              }
            case narrowDep: NarrowDependency[_] =>
              visit(narrowDep.rdd)
          }
        }
      }
    }
  }
  visit(stage.rdd)
  missing.toList
}
```

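Parent stages are found by recursively traversing the RDD dependencies. For a wide dependency, i.e. a ShuffleDependency, Spark creates a new mapStage as a parent of finalStage and records it as missing if its output is not yet available. For a narrow dependency, Spark does not create a new stage but keeps walking up to the parent RDD. Stages are thus divided according to the rule described earlier. Of the two example jobs introduced above, the first produces only a finalStage, while the second produces both a finalStage and a mapStage.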

Once the stage DAG has been built, tasks must be generated for each stage, which is why submitMissingTasks() is called next:

```scala
private def submitMissingTasks(stage: Stage) {
  val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
  myPending.clear()
  var tasks = ArrayBuffer[Task[_]]()
  if (stage.isShuffleMap) {
    // One ShuffleMapTask per partition whose map output is not yet available
    for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
      val locs = getPreferredLocs(stage.rdd, p)
      tasks += new ShuffleMapTask(..., stage.rdd, stage.shuffleDep.get, p, locs)
    }
  } else {
    // One ResultTask per output partition of the job that has not finished yet
    val job = resultStageToJob(stage)
    for (id <- 0 until job.numPartitions if (!job.finished(id))) {
      val partition = job.partitions(id)
      val locs = getPreferredLocs(stage.rdd, partition)
      tasks += new ResultTask(..., stage.rdd, job.func, partition, locs, id)
    }
  }
  if (tasks.size > 0) {
    myPending ++= tasks
    taskSched.submitTasks(
      new TaskSet(tasks.toArray, ..., stage.newAttemptId(), stage.priority))
    if (!stage.submissionTime.isDefined) {
      stage.submissionTime = Some(System.currentTimeMillis())
    }
  } else {
    running -= stage
  }
}
```
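The key property of `submitMissingTasks()` is that it only creates tasks for work that is still missing. A stage that is resubmitted after a partial failure therefore reruns only the lost partitions. The sketch below models just that selection step. `SimTask`, `SimTaskSet` and the two helper functions are hypothetical, and an empty location list plays the role of `stage.outputLocs(p) == Nil`.

```scala
// Hypothetical, simplified task types (not Spark's classes).
sealed trait SimTask { def stageId: Int; def partition: Int }
final case class SimShuffleMapTask(stageId: Int, partition: Int) extends SimTask
final case class SimResultTask(stageId: Int, partition: Int, outputId: Int) extends SimTask
final case class SimTaskSet(stageId: Int, attempt: Int, tasks: Seq[SimTask])

// Shuffle map stage: one task per partition whose map output is still missing.
def missingShuffleMapTasks(stageId: Int, outputLocs: IndexedSeq[List[String]]): Seq[SimTask] =
  outputLocs.indices
    .filter(p => outputLocs(p).isEmpty)
    .map(p => SimShuffleMapTask(stageId, p))

// Result stage: one task per unfinished output; `partitions(id)` is the RDD
// partition computed for output `id`, as with `job.partitions(id)` above.
def missingResultTasks(stageId: Int, partitions: IndexedSeq[Int], finished: IndexedSeq[Boolean]): Seq[SimTask] =
  partitions.indices
    .filter(id => !finished(id))
    .map(id => SimResultTask(stageId, partitions(id), id))

// Map outputs for partitions 0 and 2 already exist, so only 1 and 3 are (re)run.
val outputLocs = IndexedSeq(List("host-a"), Nil, List("host-b"), Nil)
val taskSet = SimTaskSet(stageId = 3, attempt = 0, tasks = missingShuffleMapTasks(3, outputLocs))
val resultTasks = missingResultTasks(5, IndexedSeq(0, 2, 4), IndexedSeq(true, false, false))
```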



```scala
private def handleTaskCompletion(event: CompletionEvent) {
  val task = event.task
  val stage = idToStage(task.stageId)

  def markStageAsFinished(stage: Stage) = {
    val serviceTime = stage.submissionTime match {
      case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0)
      case _ => "Unknown"
    }
    logInfo("%s (%s) finished in %s s".format(stage, stage.origin, serviceTime))
    running -= stage
  }

  event.reason match {
    case Success =>
      ...
      task match {
        case rt: ResultTask[_, _] =>
          ...
        case smt: ShuffleMapTask =>
          ...
      }
    case Resubmitted =>
      ...
    case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
      ...
    case other =>
      abortStage(idToStage(task.stageId), task + " failed: " + other)
  }
}
```




  • ResultTask

    ```scala
    override def run(attemptId: Long): U = {
      val context = new TaskContext(stageId, partition, attemptId)
      try {
        func(context, rdd.iterator(split, context))
      } finally {
        context.executeOnCompleteCallbacks()
      }
    }
    ```

  • ShuffleMapTask

    ```scala
    override def run(attemptId: Long): MapStatus = {
      val numOutputSplits = dep.partitioner.numPartitions
      val taskContext = new TaskContext(stageId, partition, attemptId)
      try {
        val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
        for (elem <- rdd.iterator(split, taskContext)) {
          val pair = elem.asInstanceOf[(Any, Any)]
          val bucketId = dep.partitioner.getPartition(pair._1)
          buckets(bucketId) += pair
        }
        val compressedSizes = new Array[Byte](numOutputSplits)
        val blockManager = SparkEnv.get.blockManager
        for (i <- 0 until numOutputSplits) {
          val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
          val iter: Iterator[(Any, Any)] = buckets(i).iterator
          val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
          compressedSizes(i) = MapOutputTracker.compressSize(size)
        }
        return new MapStatus(blockManager.blockManagerId, compressedSizes)
      } finally {
        taskContext.executeOnCompleteCallbacks()
      }
    }
    ```
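The partition-and-bucket step of `ShuffleMapTask.run()` can be sketched without a BlockManager. In this hypothetical sketch, `getPartition` assumes hash partitioning with a non-negative modulo; the excerpt itself only calls `dep.partitioner.getPartition`. Buckets are kept in memory and returned as a map from block id to records rather than written to disk. The block ids follow the excerpt's `shuffle_<shuffleId>_<mapId>_<reduceId>` naming.

```scala
// Non-negative modulo of the key's hash code (hashCode can be negative).
def getPartition(key: Any, numPartitions: Int): Int = {
  val mod = key.hashCode % numPartitions
  if (mod < 0) mod + numPartitions else mod
}

// Split one map task's (key, value) output into one bucket per reduce partition,
// keyed by block id in the style shuffle_<shuffleId>_<mapId>_<reduceId>.
def bucketize(shuffleId: Int, mapId: Int, records: Seq[(Any, Any)], numReducers: Int): Map[String, Vector[(Any, Any)]] = {
  val buckets = Array.fill(numReducers)(Vector.newBuilder[(Any, Any)])
  for (pair <- records) buckets(getPartition(pair._1, numReducers)) += pair
  buckets.indices.map(i => s"shuffle_${shuffleId}_${mapId}_$i" -> buckets(i).result()).toMap
}

val records: Seq[(Any, Any)] = Seq("spark" -> 1, "hadoop" -> 1, "spark" -> 1)
val blocks = bucketize(shuffleId = 0, mapId = 2, records = records, numReducers = 2)
```

Because the bucket is chosen from the key alone, every record with the same key lands in the same bucket. The reducer for that bucket then sees all values for the key.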



```scala
final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
  if (storageLevel != StorageLevel.NONE) {
    SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
  } else {
    computeOrReadCheckpoint(split, context)
  }
}

private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = {
  if (isCheckpointed) {
    firstParent[T].iterator(split, context)
  } else {
    compute(split, context)
  }
}
```
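The `storageLevel` branch in `iterator()` is what lets a cached RDD skip recomputation. Below is a minimal sketch with a hypothetical `ToyCacheManager`, an in-memory map keyed by `(rddId, partition)`, standing in for Spark's cache manager. The checkpoint branch is left out.

```scala
import scala.collection.mutable

// Hypothetical cache keyed by (rddId, partition), standing in for the cache manager.
class ToyCacheManager {
  private val cache = mutable.Map[(Int, Int), Vector[Int]]()
  var computeCalls = 0

  // Return the cached partition if present; otherwise compute it once,
  // materialize it, store it, and iterate over the stored copy.
  def getOrCompute(rddId: Int, partition: Int)(compute: => Iterator[Int]): Iterator[Int] =
    cache.getOrElseUpdate((rddId, partition), {
      computeCalls += 1
      compute.toVector
    }).iterator
}

// Mirrors the branch in RDD.iterator(): cached RDDs go through the cache
// manager, uncached ones are recomputed from their lineage on every call.
def readPartition(cm: ToyCacheManager, rddId: Int, partition: Int, cached: Boolean)(
    compute: => Iterator[Int]): Iterator[Int] =
  if (cached) cm.getOrCompute(rddId, partition)(compute) else compute

var lineageEvaluations = 0
def lineage: Iterator[Int] = { lineageEvaluations += 1; Iterator(1, 2, 3) }

val cm = new ToyCacheManager
val first = readPartition(cm, rddId = 7, partition = 0, cached = true)(lineage).toList
val second = readPartition(cm, rddId = 7, partition = 0, cached = true)(lineage).toList
readPartition(cm, rddId = 8, partition = 0, cached = false)(lineage).toList
readPartition(cm, rddId = 8, partition = 0, cached = false)(lineage).toList
```

The cached RDD evaluates its lineage once for two reads. The uncached one evaluates it on every read, so the lineage runs three times in total.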







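Starting ClusterScheduler also starts SparkDeploySchedulerBackend, which splits itself into two roles. The first is the driver, a locally running actor that communicates with the remote executors, submits tasks to them and controls them. The second is StandaloneExecutorBackend: Spark starts one StandaloneExecutorBackend process on each slave node, which executes tasks and returns their results.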



```scala
  master match {
    ...
    case SPARK_REGEX(sparkUrl) =>
      val scheduler = new ClusterScheduler(this)
      val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName)
      scheduler.initialize(backend)
      scheduler

    case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
      ...

    case _ =>
      ...
  }
}
taskScheduler.start()
```

Starting ClusterScheduler starts SparkDeploySchedulerBackend. When spark.speculation is set to true, it also starts a daemon thread that repeatedly checks for speculatable tasks:

```scala
override def start() {
  backend.start()

  if (System.getProperty("spark.speculation", "false") == "true") {
    new Thread("ClusterScheduler speculation check") {
      setDaemon(true)

      override def run() {
        while (true) {
          try {
            ...
          } catch {
            case e: InterruptedException => {}
          }
          checkSpeculatableTasks()
        }
      }
    }.start()
  }
}
```
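The excerpt does not show what `checkSpeculatableTasks()` decides on each pass. One common rule, roughly the one described by Spark's `spark.speculation.quantile` and `spark.speculation.multiplier` settings, works as follows. Once a large enough fraction of a task set has finished, any still-running task that has taken much longer than the median finished task becomes a candidate for a speculative copy. The function below illustrates that rule under stated assumptions and is not Spark's implementation. The values 0.75 and 1.5 are simply the sketch's example parameters, and the 100 ms floor is an arbitrary choice.

```scala
// Hypothetical straggler check. Returns the ids of running tasks that look slow
// enough to deserve a speculative copy.
def speculatableTasks(
    totalTasks: Int,
    finishedDurationsMs: Seq[Long],
    runningElapsedMs: Map[Int, Long],
    quantile: Double = 0.75,
    multiplier: Double = 1.5): Set[Int] = {
  // Do nothing until at least `quantile` of the tasks have finished.
  val minFinished = math.floor(quantile * totalTasks).toInt
  if (finishedDurationsMs.isEmpty || finishedDurationsMs.size < minFinished) Set.empty
  else {
    val sorted = finishedDurationsMs.sorted
    val median = sorted(sorted.size / 2)
    // A task is "slow" if it has run `multiplier` times longer than the median
    // (never less than 100 ms, an assumption of this sketch).
    val threshold = math.max(multiplier * median, 100.0)
    runningElapsedMs.collect { case (taskId, elapsed) if elapsed > threshold => taskId }.toSet
  }
}

val tooEarly = speculatableTasks(4, Seq(1000L, 1200L), Map(3 -> 5000L))
val stragglers = speculatableTasks(4, Seq(1000L, 1200L, 1100L), Map(3 -> 2000L))
val onTime = speculatableTasks(4, Seq(1000L, 1200L, 1100L), Map(3 -> 1500L))
```

With 4 tasks and 3 finished (median 1100 ms), the threshold is 1650 ms. A task running for 2000 ms is flagged, and one running for 1500 ms is not. With only 2 finished, nothing is flagged yet.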


```scala
override def start() {
  super.start()

  val driverUrl = "akka://spark@%s:%s/user/%s".format(
    System.getProperty(""), System.getProperty("spark.driver.port"),
    StandaloneSchedulerBackend.ACTOR_NAME)
  val args = Seq(driverUrl, "", "", "")
  val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs)
  val sparkHome = sc.getSparkHome().getOrElse(
    throw new IllegalArgumentException("must supply spark home for spark standalone"))
  val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome)

  client = new Client(sc.env.actorSystem, master, appDesc, this)
  client.start()
}
```


```scala
override def start() {
  val properties = new ArrayBuffer[(String, String)]
  val iterator = System.getProperties.entrySet.iterator
  while (iterator.hasNext) {
    val entry = ...
    val (key, value) = (entry.getKey.toString, entry.getValue.toString)
    if (key.startsWith("spark.")) {
      properties += ((key, value))
    }
  }
  driverActor = actorSystem.actorOf(
    Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME)
}
```


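At this point, the ClusterScheduler has started, the local driver has been created, and the environment for the remote executors has been launched. The ClusterScheduler now waits for the DAGScheduler to submit tasks.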



```scala
override def submitTasks(taskSet: TaskSet) {
  val tasks = taskSet.tasks
  logInfo("Adding task set " + ... + " with " + tasks.length + " tasks")
  this.synchronized {
    val manager = new TaskSetManager(this, taskSet)
    activeTaskSets(...) = manager
    activeTaskSetsQueue += manager
    taskSetTaskIds(...) = new HashSet[Long]()

    if (hasReceivedTask == false) {
      starvationTimer.scheduleAtFixedRate(new TimerTask() {
        override def run() {
          if (!hasLaunchedTask) {
            logWarning("Initial job has not accepted any resources; " +
              "check your cluster UI to ensure that workers are registered")
          } else {
            this.cancel()
          }
        }
      }, ...)
    }
    hasReceivedTask = true
  }
  backend.reviveOffers()
}
```


```scala
// Make fake resource offers on just one executor
def makeOffers(executorId: String) {
  launchTasks(scheduler.resourceOffers(
    Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId)))))
}

// Launch tasks returned by a set of resource offers
def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
  for (task <- tasks.flatten) {
    freeCores(task.executorId) -= 1
    executorActor(task.executorId) ! LaunchTask(task)
  }
}
```
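`makeOffers()` turns an executor's free cores into an offer, and `launchTasks()` charges one core for every launched task. That bookkeeping can be sketched as follows. `Offer`, `Launch` and `assignTasks` are hypothetical. The sketch cycles over the offers so that tasks spread across executors, whereas Spark's `resourceOffers()` also weighs data locality, which is ignored here.

```scala
import scala.collection.mutable

// Hypothetical, simplified types (not Spark's WorkerOffer/TaskDescription).
final case class Offer(executorId: String, freeCores: Int)
final case class Launch(taskId: Int, executorId: String)

// Hand out pending tasks one core at a time, cycling over the offers;
// stop when either the tasks or the free cores run out.
def assignTasks(offers: Seq[Offer], pendingTasks: Seq[Int]): List[Launch] = {
  val freeCores = mutable.Map.from(offers.map(o => o.executorId -> o.freeCores))
  val queue = mutable.Queue.from(pendingTasks)
  val launched = mutable.ListBuffer[Launch]()
  var progress = true
  while (queue.nonEmpty && progress) {
    progress = false
    for (o <- offers if queue.nonEmpty && freeCores(o.executorId) > 0) {
      launched += Launch(queue.dequeue(), o.executorId)
      freeCores(o.executorId) -= 1 // same bookkeeping as freeCores(task.executorId) -= 1
      progress = true
    }
  }
  launched.toList
}

val launches = assignTasks(Seq(Offer("exec-1", 2), Offer("exec-2", 1)), pendingTasks = 0 until 5)
```

Only three of the five tasks fit into the three free cores. In this sketch, tasks 3 and 4 would stay pending until cores are freed and new offers are made.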



```scala
override def receive = {
  case RegisteredExecutor(sparkProperties) =>
    ...
  case RegisterExecutorFailed(message) =>
    ...
  case LaunchTask(taskDesc) =>
    logInfo("Got assigned task " + taskDesc.taskId)
    executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
  case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
    ...
}

def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
  threadPool.execute(new TaskRunner(context, taskId, serializedTask))
}
```


```scala
class TaskRunner(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
  extends Runnable {

  override def run() {
    SparkEnv.set(env)
    Thread.currentThread.setContextClassLoader(urlClassLoader)
    val ser = SparkEnv.get.closureSerializer.newInstance()
    logInfo("Running task ID " + taskId)
    context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
    try {
      SparkEnv.set(env)
      Accumulators.clear()
      val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
      updateDependencies(taskFiles, taskJars)
      val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
      logInfo("Its generation is " + task.generation)
      env.mapOutputTracker.updateGeneration(task.generation)
      val value = ...
      val accumUpdates = Accumulators.values
      val result = new TaskResult(value, accumUpdates)
      val serializedResult = ser.serialize(result)
      logInfo("Serialized size of result for " + taskId + " is " + serializedResult.limit)
      context.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
      logInfo("Finished task ID " + taskId)
    } catch {
      case ffe: FetchFailedException => {
        val reason = ffe.toTaskEndReason
        context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
      }

      case t: Throwable => {
        val reason = ExceptionFailure(t)
        context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

        // TODO: Should we exit the whole executor here? On the one hand, the failed task may
        // have left some weird state around depending on when the exception was thrown, but on
        // the other hand, maybe we could detect that when future tasks fail and exit then.
        logError("Exception in task ID " + taskId, t)
        //System.exit(1)
      }
    }
  }
}
```
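On the executor side, the control flow comes down to three steps. Wrap each task in a `Runnable`, hand it to a thread pool, and report state changes back to the driver. The sketch below keeps that flow but replaces deserialization, accumulators and the actor messaging with a hypothetical `RecordingBackend` that just records `(taskId, state, payload)` tuples. `MiniExecutor`, `Status` and the other names are illustrative and not Spark's API.

```scala
import java.util.concurrent.{ConcurrentLinkedQueue, Executors, TimeUnit}
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

// Hypothetical task states, mirroring RUNNING / FINISHED / FAILED above.
sealed trait Status
case object Running extends Status
case object Finished extends Status
case object Failed extends Status

// Stand-in for ExecutorBackend.statusUpdate: records every update in arrival order.
class RecordingBackend {
  val updates = new ConcurrentLinkedQueue[(Long, Status, String)]()
  def statusUpdate(taskId: Long, state: Status, payload: String): Unit =
    updates.add((taskId, state, payload))
}

// Each launched task becomes a Runnable on a thread pool, like TaskRunner:
// report RUNNING, run the body, then report FINISHED with the result or FAILED.
class MiniExecutor(backend: RecordingBackend) {
  private val threadPool = Executors.newFixedThreadPool(2)

  def launchTask(taskId: Long, body: () => Any): Unit =
    threadPool.execute(new Runnable {
      override def run(): Unit = {
        backend.statusUpdate(taskId, Running, "")
        try {
          val value = body()
          backend.statusUpdate(taskId, Finished, value.toString)
        } catch {
          case NonFatal(e) => backend.statusUpdate(taskId, Failed, e.toString)
        }
      }
    })

  def shutdown(): Unit = {
    threadPool.shutdown()
    threadPool.awaitTermination(5, TimeUnit.SECONDS)
  }
}

// Small driver: launch one succeeding and one failing task, then collect updates.
object ExecutorDemo {
  def run(): List[(Long, Status, String)] = {
    val backend = new RecordingBackend
    val executor = new MiniExecutor(backend)
    executor.launchTask(1L, () => 21 * 2)
    executor.launchTask(2L, () => throw new IllegalStateException("boom"))
    executor.shutdown()
    backend.updates.asScala.toList
  }
}

val updates = ExecutorDemo.run()
```

Unlike the original `TaskRunner`, which catches every `Throwable`, the sketch catches only non-fatal exceptions, so errors such as `OutOfMemoryError` still propagate. Updates from different tasks may interleave, but each task's own updates stay in order.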




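This concludes a step-by-step tour of the main path through Spark's scheduler module. As one of Spark's core modules, the scheduler clearly shows how Spark differs from MapReduce. It also reflects the ingenuity of Spark's DAG-based approach and the elegance of its design.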


