Java7 Fork-Join 框架:任务切分,并行处理

来源:互联网 发布:网络广告发布软件 编辑:程序博客网 时间:2024/05/21 22:55

概要

现代的计算机已经向多CPU方向发展,即使是普通的PC,甚至现在的智能手机、多核处理器已被广泛应用。在未来,处理器的核心数将会发展的越来越多。
虽然硬件上的多核CPU已经十分成熟,但是很多应用程序并未这种多核CPU做好准备,因此并不能很好地利用多核CPU的性能优势。
为了充分利用多CPU、多核CPU的性能优势,级软基软件系统应该可以充分“挖掘”每个CPU的计算能力,决不能让某个CPU处于“空闲”状态。为此,可以考虑把一个任务拆分成多个“小任务”,把多个"小任务"放到多个处理器核心上并行执行。当多个“小任务”执行完成之后,再将这些执行结果合并起来即可。


Java在JDK7之后加入了并行计算的框架Fork/Join,可以解决我们系统中大数据计算的性能问题。Fork/Join采用的是分治法,Fork是将一个大任务拆分成若干个子任务,子任务分别去计算,而Join是获取到子任务的计算结果,然后合并,这个是递归的过程。子任务被分配到不同的核上执行时,效率最高。伪代码如下:

[java] view plain copy
  1. Result solve(Problem problem) {  
  2.     if (problem is small)  
  3.         directly solve problem  
  4.     else {  
  5.         split problem into independent parts  
  6.         fork new subtasks to solve each part  
  7.         join all subtasks  
  8.         compose result from subresults  
  9.     }  
  10. }  

Fork/Join框架的核心类是ForkJoinPool,它能够接收一个ForkJoinTask,并得到计算结果。ForkJoinTask有两个子类,RecursiveTask(有返回值)和RecursiveAction(无返回结果),我们自己定义任务时,只需选择这两个类继承即可

示例代码

[java] view plain copy
  1. package forkJoin;  
  2.   
  3. import java.util.concurrent.RecursiveTask;  
  4.   
  5. public class SumTask extends RecursiveTask<Integer> {  
  6.     private static final int THRESHOLD = 20;  
  7.   
  8.     private int[] array;  
  9.     private int low;  
  10.     private int high;  
  11.   
  12.     public SumTask(int[] array, int low, int high) {  
  13.         this.array = array;  
  14.         this.low = low;  
  15.         this.high = high;  
  16.     }  
  17.   
  18.     @Override  
  19.     protected Integer compute() {  
  20.         int sum = 0;  
  21.         if (high - low + 1 <= THRESHOLD) {  
  22.             System.out.println(low + " - " + high + "  计算");  
  23. //            测试并行的个数,统计输出过程中的文字,看看有多少线程停止在这里就知道有多少并行计算  
  24. //            参考 ForkJoinPool 初始化设置的并行数  
  25. //            try {  
  26. //                Thread.sleep(11111111);  
  27. //            } catch (InterruptedException e) {  
  28. //                e.printStackTrace();  
  29. //            }  
  30.             // 小于阈值则直接计算  
  31.             for (int i = low; i <= high; i++) {  
  32.                 sum += array[i];  
  33.             }  
  34.         } else {  
  35.             System.out.println(low + " - " + high + "  切分");  
  36.             // 1. 一个大任务分割成两个子任务  
  37.             int mid = (low + high) / 2;  
  38.             SumTask left = new SumTask(array, low, mid);  
  39.             SumTask right = new SumTask(array, mid + 1, high);  
  40.   
  41.             // 2. 分别并行计算  
  42.             invokeAll(left, right);  
  43.   
  44.             // 3. 合并结果  
  45.             sum = left.join() + right.join();  
  46.   
  47.             // 另一种方式  
  48.             try {  
  49.                 sum = left.get() + right.get();  
  50.             } catch (Throwable e) {  
  51.                 System.out.println(e.getMessage());  
  52.             }  
  53.         }  
  54.         return sum;  
  55.     }  
  56. }  
[java] view plain copy
  1. package forkJoin;  
  2.   
  3. import java.util.Random;  
  4. import java.util.concurrent.ExecutionException;  
  5. import java.util.concurrent.ForkJoinPool;  
  6. import java.util.concurrent.RecursiveAction;  
  7.   
  8. public class Main {  
  9.   
  10.     /*static class MyTaskTest extends RecursiveTask<Integer> { 
  11.         final int n; 
  12.  
  13.         MyTaskTest(int n) { 
  14.             this.n = n; 
  15.         } 
  16.  
  17.         @Override 
  18.         protected Integer compute() { 
  19.             if (n <= 1) return n; 
  20.             MyTaskTest f1 = new MyTaskTest(n - 1); 
  21.             f1.fork(); 
  22.             MyTaskTest f2 = new MyTaskTest(n - 2); 
  23.             return f2.compute() + f1.join(); 
  24.         } 
  25.     }*/  
  26.   
  27.     /*class SortTask extends RecursiveAction { 
  28.         static final int THRESHOLD = 2; 
  29.         final long[] array; 
  30.         final int lo; 
  31.         final int hi; 
  32.  
  33.         SortTask(long[] array, int lo, int hi) { 
  34.             this.array = array; 
  35.             this.lo = lo; 
  36.             this.hi = hi; 
  37.         } 
  38.  
  39.         protected void compute() { 
  40.             if (hi - lo < THRESHOLD) 
  41.                 sequentiallySort(array, lo, hi); 
  42.             else { 
  43.                 int mid = (lo + hi) >>> 1; 
  44.                 invokeAll(new SortTask(array, lo, mid), 
  45.                         new SortTask(array, mid, hi)); 
  46.                 merge(array, lo, hi); 
  47.             } 
  48.         } 
  49.     }*/  
  50.   
  51.     private static int[] genArray() {  
  52.         int[] array = new int[100];  
  53.         for (int i = 0; i < array.length; i++) {  
  54.             array[i] = new Random().nextInt(500);  
  55.         }  
  56.         return array;  
  57.     }  
  58.   
  59.     public static void main(String[] args) throws ExecutionException, InterruptedException {  
  60.         /** 
  61.          * 下面以一个有返回值的大任务为例,介绍一下RecursiveTask的用法。 
  62.          大任务是:计算随机的100个数字的和。 
  63.          小任务是:每次只能20个数值的和。 
  64.          */  
  65.         int[] array = genArray();  
  66.   
  67. //        System.out.println(Arrays.toString(array));  
  68.         int total = 0;  
  69.         for (int i = 0; i < array.length; i++) {  
  70.             total += array[i];  
  71.         }  
  72.         System.out.println("目标和:" + total);  
  73.   
  74.         // 1. 创建任务  
  75.         SumTask sumTask = new SumTask(array, 0, array.length - 1);  
  76.   
  77.         // 2. 创建线程池  
  78.         // 设置并行计算的个数  
  79.         int processors = Runtime.getRuntime().availableProcessors();  
  80.         ForkJoinPool forkJoinPool = new ForkJoinPool(processors * 2);  
  81.   
  82.         // 3. 提交任务到线程池  
  83.         forkJoinPool.submit(sumTask);  
  84. //        forkJoinPool.shutdown();  
  85.   
  86.         long begin = System.currentTimeMillis();  
  87.         // 4. 获取结果  
  88.         Integer result = sumTask.get();// wait for  
  89.         long end = System.currentTimeMillis();  
  90.         System.out.println(String.format("结果 %s ,耗时 %sms", result, end - begin));  
  91.   
  92.         if (result == total) {  
  93.             System.out.println("测试成功");  
  94.         } else {  
  95.             System.out.println("fork join 使用失败!!!!");  
  96.         }  
  97.     }  
  98. }  

上面的代码是一个100个整数累加的任务,切分到小于20个数的时候直接进行累加,不再切分。
我们通过调整阈值(THRESHOLD),可以发现耗时是不一样的。实际应用中,如果需要分割的任务大小是固定的,可以经过测试,得到最佳阈值;如果大小不是固定的,就需要设计一个可伸缩的算法,来动态计算出阈值。如果子任务很多,效率并不一定会高。 
PS:类似的这种“分而治之”的需求场景,往往带有递归性,实际中,我们可以考虑任务是否具有“递归性”来决定是否使用“Fork-Join”框架。
原创粉丝点击