多核编程(二)-- java并行机制与fork/join框架

来源:互联网 发布:延边大米 知乎 编辑:程序博客网 时间:2024/05/21 09:51

并行思想

考虑如下问题:在一个数据量很大的一个数组中,如何快速有效的对所有数据进行累加?

最原始的想法:利用多线程思想,将数组等分成若干份(例如4份),对于每一份数据中创建一个多线程进行累加,最后将4份数据的和相加求得结果。

我们可以写出如下原始代码:

class SumThread extends java.lang.Thread {    int lo, hi;    int[] arr; // arguments    int ans = 0; // result        SumThread(int[] a, int l, int h) {         lo=l; hi=h; arr=a;    }        public void run() { //override must have this type        for(int i=lo; i < hi; i++)            ans += arr[i];    }}class MainThread{    public static void main(String[] args){        int[] arr = new int[100];                for(int i = 0; i<100;i++)            arr[i] = i;        System.out.println(sum(arr));    }    static int sum(int[] arr){// can be a static method        int len = arr.length;        int ans = 0;        SumThread[] ts = new SumThread[4];        for(int i=0; i < 4; i++){// do parallel computations            ts[i] = new SumThread(arr,i*len/4,(i+1)*len/4);            ts[i].start();          }        for(int i=0; i < 4; i++) { // combine results            try{            ts[i].join(); // wait for helper to finish!            ans += ts[i].ans;            }            catch(InterruptedException e)            {                System.out.println("11");            }        }        return ans;    }}
在fork/join模型中,我们不太关注到线程间的内存共享,但是在java这类编程语言中,确实存在着内存共享,例如,

lo, hi, arr在主线程中创建,可以在helper线程中读取;ans在helper线程中被创建,也可以在主线程中被读取。


对上述代码稍作改进,按每一个处理器运行一个线程,允许根据可用处理器个数设置线程个数:(以后代码中省略try-catch,但是在实际运行中必须加上,否则会报InterruptedException错误)

int sum(int[] arr, int numTs){    int ans = 0;    SumThread[] ts = new SumThread[numTs];    for(int i=0; i < numTs; i++){        ts[i] = new SumThread(arr,(i*arr.length)/numTs, ((i+1)*arr.length)/numTs);        ts[i].start();    }    for(int i=0; i < numTs; i++) {         ts[i].join();         ans += ts[i].ans;    }    return ans;}

如果可以分配有足够多的线程,是不是有更好的方法呢?考虑使用分治法,我们可以按二叉树的形式来分配线程

class SumThread extends java.lang.Thread {    int lo; int hi; int[] arr; // arguments    int ans = 0; // result    SumThread(int[] a, int l, int h) { … }        public void run(){ // override        if(hi – lo < SEQUENTIAL_CUTOFF)            for(int i=lo; i < hi; i++)            ans += arr[i];        else {            SumThread left = new SumThread(arr,lo,(hi+lo)/2);            SumThread right= new SumThread(arr,(hi+lo)/2,hi);            left.start();            right.start();            left.join(); // don’t move this up a line – why?            right.join();            ans = left.ans + right.ans;        }    }}int sum(int[] arr){     SumThread t = new SumThread(arr,0,arr.length);    t.run();    return t.ans;}
如果拥有足够多的处理器,那么时间复杂度将是O(logn),但考虑实际的处理器将不会有那么多,时间复杂度会是O(n/numProcessors  + log n)。

那么有没有更好的方法来利用有限的处理器呢?我们观察到,很多进程仅仅是用来分配两个子进程,之后在等待两个子进程完成后将两个结果相加,因此我们可以将该进程充当其中一个子进程继续向下执行,那么就会大大减少进程的等待时间。

// wasteful: don’tSumThread left  = …SumThread right = …left.start();right.start();left.join(); right.join();ans=left.ans+right.ans;
// better: doSumThread left  = …SumThread right = …left.start();right.run();left.join(); ans=left.ans+right.ans;

对比前两种的方法我们可以看到,之前一种方法需要15个处理器,而改进之后仅需要8个处理器,减少了近一半的处理器需求。


接下来,我们引入fork/join框架的内容对代码再一步优化,该框架正是针对这类分治问题而设计的。使用该框架注意一下原则:

不要继承自Thread继承自RecursiveTask<V>不要重载run方法重载compute方法不需要返回ans从compute中返回一个V不用调用start调用fork不要仅仅只是调用join使用join的返回值不用调用run来使用上面的优化调用compute来优化不要在最上层直接调用run创建一个pool然后使用invoke

运用fork/Join框架并优化后的最终代码如下所示:

class SumArray extends RecursiveTask<Integer> {    int lo; int hi; int[] arr; // arguments    SumArray(int[] a, int l, int h) { … }        protected Integer compute(){// return answer        if(hi – lo < SEQUENTIAL_CUTOFF) {            int ans = 0;            for(int i=lo; i < hi; i++)                ans += arr[i];            return ans;        } else {            SumArray left = new SumArray(arr,lo,(hi+lo)/2);            SumArray right= new SumArray(arr,(hi+lo)/2,hi);            left.fork();            int rightAns = right.compute();            int leftAns  = left.join();             return leftAns + rightAns;        }    }}static final ForkJoinPool fjPool = new ForkJoinPool();int sum(int[] arr){    return fjPool.invoke(new SumArray(arr,0,arr.length));}

0 0
原创粉丝点击