How to correctly cancel a Spark job from the driver


1. SparkContext provides an API for cancelling a job

class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationClient {
  /** Cancel a given job if it's scheduled or running */
  private[spark] def cancelJob(jobId: Int) {
    dagScheduler.cancelJob(jobId)
  }
}

2. So how do we obtain the jobId?

Spark provides a trait called SparkListener, which offers hooks for listening to Spark events:

trait SparkListener {
  /**
   * Called when a job starts
   */
  def onJobStart(jobStart: SparkListenerJobStart) { }

  /**
   * Called when a job ends
   */
  def onJobEnd(jobEnd: SparkListenerJobEnd) { }
}

So we define a custom class that implements SparkListener:

public class DHSparkListener implements SparkListener {
    private static Logger logger = Logger.getLogger(DHSparkListener.class);
    // Maps the submitting thread's local property ("thread") to the id of the job it submitted
    private static ConcurrentHashMap<String, Integer> jobInfoMap;

    public DHSparkListener() {
        jobInfoMap = new ConcurrentHashMap<String, Integer>();
    }

    @Override
    public void onJobEnd(SparkListenerJobEnd jobEnd) {
        logger.info("DHSparkListener Job End:" + jobEnd.jobResult().getClass() + ",Id:" + jobEnd.jobId());
        for (String key : jobInfoMap.keySet()) {
            if (jobInfoMap.get(key) == jobEnd.jobId()) {
                jobInfoMap.remove(key);
                logger.info(key + " request has been returned. because " + jobEnd.jobResult().getClass());
            }
        }
    }

    @Override
    public void onJobStart(SparkListenerJobStart jobStart) {
        logger.info("DHSparkListener Job Start: JobId->" + jobStart.jobId());
        // The "thread" local property tells us which thread submitted this job
        logger.info("DHSparkListener Job Start: Thread->" + jobStart.properties().getProperty("thread", "default"));
        jobInfoMap.put(jobStart.properties().getProperty("thread", "default"), jobStart.jobId());
    }
    // …… (other SparkListener callbacks omitted)
}
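The class above records the mapping but never exposes it to the rest of the driver, which will need to look a jobId up before it can cancel anything. A minimal sketch of such an accessor, added to DHSparkListener (the name getJobIdForThread is illustrative, not part of the original code):

    // Hypothetical accessor inside DHSparkListener (not in the original class):
    // returns the id of the job submitted under the given "thread" tag,
    // or null if no such job is currently recorded.
    public static Integer getJobIdForThread(String threadTag) {
        return jobInfoMap.get(threadTag);
    }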

So how does the user know which thread submitted a given job? A thread-local property has to be set when the job is submitted:

SparkConf conf = new SparkConf().setAppName("SparkListenerTest application in Java");
String sparkMaster = Configure.instance.get("SparkMaster");
String sparkExecutorMemory = "16g";
String sparkCoresMax = "4";
String sparkJarAddress = "/tmp/cuckoo-core-1.0-SNAPSHOT-allinone.jar";
conf.setMaster(sparkMaster);
conf.set("spark.executor.memory", sparkExecutorMemory);
conf.set("spark.cores.max", sparkCoresMax);
JavaSparkContext jsc = new JavaSparkContext(conf);
jsc.addJar(sparkJarAddress);
DHSparkListener dHSparkListener = new DHSparkListener();
jsc.sc().addSparkListener(dHSparkListener);
List<Integer> listData = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9);
JavaRDD<Integer> rdd1 = jsc.parallelize(listData, 1);
JavaRDD<Integer> rdd2 = rdd1.map(new Function<Integer, Integer>() {
    public Integer call(Integer v1) throws Exception {
        // do something with v1, then return the result
        return v1;
    }
});
// Set the submitting thread's local property before triggering the action,
// so the SparkListener can read it from the job's properties
jsc.setLocalProperty("thread", "client");
rdd2.count();

With this in place, jobInfoMap records which submitter each job belongs to, and when a job is found to have been running for too long, SparkContext's cancelJob can be called to cancel it. But is that all there is to it? Read on: on the executor side, cancelling a job ultimately calls the task's kill method:

def kill(interruptThread: Boolean) {
  _killed = true
  if (context != null) {
    context.markInterrupted()
  }
  if (interruptThread && taskThread != null) {
    taskThread.interrupt()
  }
}

This ends up calling Thread.interrupt, which merely sets the interrupt flag on the thread running the task. A long-running task therefore has to check the interrupt flag itself and, once it is set, exit and clean up after itself, i.e. it needs code along these lines:

if (Thread.interrupted()) {
    // …… the thread has been interrupted: release resources and stop
}

Alternatively, blocking calls such as sleep() or wait() will throw InterruptedException when the thread is interrupted; catch the exception and handle the cancellation there.
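For instance, the map function from the earlier example could cooperate with cancellation roughly as follows. This is a sketch, assuming rdd1 from the code above; the counting loop merely stands in for real long-running work:

// A map task that checks the interrupt flag set by the executor's kill().
JavaRDD<Integer> rdd2 = rdd1.map(new Function<Integer, Integer>() {
    public Integer call(Integer v1) throws Exception {
        long sum = 0;
        for (long i = 0; i < 1000000000L; i++) {   // stands in for a long computation
            if (Thread.interrupted()) {
                // The task thread was interrupted: clean up and bail out.
                // Throwing lets Spark record the task as killed instead of hanging on.
                throw new InterruptedException("task for element " + v1 + " was cancelled");
            }
            sum += i;
        }
        // Blocking calls (Thread.sleep, Object.wait, much of IO) throw
        // InterruptedException directly; catch it, clean up, then rethrow.
        return (int) (sum % Integer.MAX_VALUE);
    }
});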


3. The last step: configure what happens when the job is killed

Besides the steps above, you also have to configure, per job, what a kill should do: set the spark.job.interruptOnCancel local property to true.

// Set the submitting thread's local property before triggering the action,
// so the SparkListener can read it from the job's properties
jsc.setLocalProperty("thread", "client");
// Configure what happens when this job is killed:
// the task threads will receive an interrupt signal
jsc.setLocalProperty("spark.job.interruptOnCancel", "true");
rdd2.count();
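Putting the pieces together, the original post never shows the cancellation call itself, so here is a hedged sketch. It assumes the hypothetical DHSparkListener.getJobIdForThread accessor sketched earlier, and a Spark version in which SparkContext.cancelJob is callable from user code (in the source quoted at the top it is still private[spark]; it became public in later releases):

// Hypothetical watchdog thread: cancels the job submitted under the "client" tag
// if it has not finished within 60 seconds. Start it before calling rdd2.count();
// jsc must be final or effectively final to be captured here.
Thread watchdog = new Thread(new Runnable() {
    public void run() {
        try {
            Thread.sleep(60 * 1000);
            Integer jobId = DHSparkListener.getJobIdForThread("client");
            if (jobId != null) {               // still in jobInfoMap -> still running
                jsc.sc().cancelJob(jobId);     // tasks get interrupted thanks to interruptOnCancel
            }
        } catch (InterruptedException e) {
            // the watchdog itself was interrupted; nothing to clean up
        }
    }
});
watchdog.setDaemon(true);
watchdog.start();

If the watchdog does fire, the listener's onJobEnd should be invoked for the cancelled job and the corresponding jobInfoMap entry is removed, so a finished or cancelled job will not be cancelled twice.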