Java实现临界区:经典并发控制回顾

来源:互联网 发布:网络表情用英文怎么说 编辑:程序博客网 时间:2024/06/07 22:46

只要有一定的内存order保证,不通过比较并交换(CAS)那些需要硬件支持的原子操作,能不能实现一个互斥的临界区?答案是:能。

计算机先驱 Edsger Wybe Dijkstra,50多年前的这篇经典paper中就提出了解决方案。并且自这以后开启了如何通过一般编程语言实现并发控制的 研究。


这里的假设我们有N个线程,代表序号分别为1-N,一个公共变量k用于辅助指示当前占有临界区的线程。临界区是critical section,并且内存模型是先执行的操作对后面可见,对同一个内存位置的访问是一个接着另一个。

初始数组b[N],c[N]完全都为true。k的初始值任意(1-N)。这里的i变量代表当前的执行逻辑单元(线程)。

对于每个线程i,b[i]和c[i]都代表了线程i的参与竞争临界区的意愿,b[i]==false代表线程i将要参与临界区的争夺,线程c[i]==false代表线程i正在争竞临界区。线程退出临界区时,会而将b[i]、c[i]都置为true。从而其他线程能够通过查看当前的b[k]和c[k]来判断线程是否仍然占据临界区,这里的判断是一个大概的判断,由于各个线程执行顺序的不确定。

存在多个线程查看b[k],从而将k设置为自身的id,从而都进入了临界区前的位置,但即使这样,由于进临界区前先要查看其他线程的c[j]值,所以这里至多只有一个线程进入临界区,其他线程都退回到Li1的逻辑。存在这种情况,这里一个线程都无法获取临界区,从而全部回到Li1,下一次继续竞争。

注意:paper中的Li2,c[i] := true这一句会导致许多重复的无意义操作(因为c[i]本来就是true),这里针对的情况仅仅是从Li4里面goto Li1的时候,所以我们将c[i]:=true放到goto Li1之前就能保持程序语义,并且减少了无用功。

我们用JAVA来实现一遍这个方案试试,并且用10个线程,每个进入临界区1千万次,每次+1来验证它,可执行代码如下:

package com.psly.testatomic;import sun.misc.Unsafe;public class TestVolatile {//用于内存保证:putXXVolatile/getXXVolatileprivate static final Unsafe _unsafe = UtilUnsafe.getUnsafe();private static final int _Obase  = _unsafe.arrayBaseOffset(int[].class);private static final int _Oscale = _unsafe.arrayIndexScale(int[].class);//N:线程数,TIMES每个线程需要进入临界区的次数。private final static int N = 10;private final static int TIMES = 10000000;private final static int[] B = new int[N+1];private final static int[] C = new int[N+1];//每个线程进入临界区++count,最终count == N * TIMESprivate static long count;//countObj:获取count字段所属于的对象(其实就是地址),private final static Object countObj;//countOffset:获取count字段处于所在对象地址的偏移量private final static long countOffset;//k与上面的count字段类似private static int k = 1;private final static Object kObj; private final static long kOffset;static{for(int i = 1; i <= N; ++i){B[i] = 1;C[i] = 1;}try {countObj = _unsafe.staticFieldBase(TestVolatile.class.getDeclaredField("count"));countOffset = _unsafe.staticFieldOffset(TestVolatile.class.getDeclaredField("count"));kObj = _unsafe.staticFieldBase(TestVolatile.class.getDeclaredField("k"));kOffset = _unsafe.staticFieldOffset(TestVolatile.class.getDeclaredField("k"));} catch (Exception e) {throw new Error(e);} }final static void dijkstrasConcurMethod(int pM){    int times = TIMES;        int i = pM;    L0: for(;;){    B[i] = 0;    L1: for(;;){    if( k != i ) {    //C[i] = 1;    if(B[_unsafe.getIntVolatile(kObj, kOffset)] == 1)    _unsafe.putIntVolatile(kObj, kOffset, i);//k = i;    continue L1;    } else{           _unsafe.putIntVolatile(C, _Obase + i * _Oscale, 0);//C[i] = 0;    //这里必定会看到更新的C[i],从而根本上保证了互斥,临界区最多一个线程。    for(int j = 1; j <= N; ++j )     if(j != i &&  _unsafe.getIntVolatile(C, _Obase + j * _Oscale) == 0){    //将C[i]的值更新回去,写这里效率更高    _unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);    continue L1;    }    }    break L1;        }        //临界区开始        long val = _unsafe.getLongVolatile(countObj, countOffset);        _unsafe.putLongVolatile(countObj, countOffset, val + 1);        //临界区结束                _unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);        B[i]=1;        if( --times != 0){        continue L0; //goto L0;        }        return;        }}public static void main(String[] args) throws InterruptedException{//开始时间long start = System.currentTimeMillis();//打印累加器初始值    System.out.println( count + " initial\n");    Thread handle[] = new Thread[N+1];        //创建线程    for (int i = 1; i <= N; ++i){    int j = i;    handle[i] = new Thread(new Runnable(){    @Override    public void run(){    dijkstrasConcurMethod(j);    }    });    }    //线程开始执行    for (int i = 1; i <= N; ++i)        handle[i].start();    //主线程等待子线程结束    for (int i = 1; i <= N; ++i)        handle[i].join();    //打印累加值,== N * TIMES    System.out.println(_unsafe.getLongVolatile(countObj, countOffset));  //打印程序执行时间    System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds"); }}

执行一遍,输出为:

0 initial10000000012.936 seconds

10个线程,每个进入临界区1千万次,总共累加为1亿。费时12.936秒。所以这个示例,至少看起来是正确的

我们接着,

重点关注dijkstrasConcurMethod这个方法:

final static void dijkstrasConcurMethod(int pM){    int times = TIMES;        int i = pM;    L0: for(;;){    B[i] = 0;    L1: for(;;){    if( k != i ) {    //C[i] = 1;    if(B[_unsafe.getIntVolatile(kObj, kOffset)] == 1)    _unsafe.putIntVolatile(kObj, kOffset, i);//k = i;    continue L1;    } else{           _unsafe.putIntVolatile(C, _Obase + i * _Oscale, 0);//C[i] = 0;    //这里必定会看到更新的C[i],从而根本上保证了互斥,临界区最多一个线程。    for(int j = 1; j <= N; ++j )     if(j != i &&  _unsafe.getIntVolatile(C, _Obase + j * _Oscale) == 0){    //将C[i]的值更新回去,写这里效率更高    _unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);    continue L1;    }    }    break L1;        }        //临界区开始        long val = _unsafe.getLongVolatile(countObj, countOffset);        _unsafe.putLongVolatile(countObj, countOffset, val + 1);        //临界区结束                _unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);        B[i]=1;        if( --times != 0){        continue L0; //goto L0;        }        return;        }}

我们将paper中的ture/false用1/0来代替。由于JAVA中没有goto语句,所以我们有了一个带表情的循环for(;;)来实现一样的功能。这里的 pM代表了线程本身的下标,TIMES为需要执行临界区的次数。

其实从严格意义上来说这里的程序并不完全等同于Dijkstra上面paper中的示例,paper中的共享内存要求是强一致的,也就是说任何的一个写入操作B[i],C[i],k立刻能够被其他线程看到。


paper发表时是1965年,那个时候对于内存模型以及硬件能力的设想可能是这样的。但是随着现代的计算机体系结构的发展,为了提高程序执行的熟读,尤其是多层缓存以及指令乱序执行的引入,使得大部分程序设计语言的模型已经不符合上面的假设了。

然而尽管如此,我们的JAVA程序加入volatile语义的操作之后,我们这个程序依然是对的。因为保证了两点


  1. _unsafe.putIntVolatile(C, _Obase + i * _Oscale, 0);//C[i] = 0;    //这里必定会看到更新的C[i],从而根本上保证了互斥,临界区最多一个线程。    for(int j = 1; j <= N; ++j )     if(j != i &&  _unsafe.getIntVolatile(C, _Obase + j * _Oscale) == 0){    //将C[i]的值更新回去,写这里效率更高    _unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);    continue L1;    }
    保证C上面更新的值在开始探测整个C数组之前被看到。
  2.     //临界区开始        long val = _unsafe.getLongVolatile(countObj, countOffset);        _unsafe.putLongVolatile(countObj, countOffset, val + 1);        //临界区结束                _unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);
    保证离开临界区之后才将C[i]更新回1,从而防止这个1过早泄露出来,从而导致前面循环探测的失误。
我们接着来看第二篇paper,由于篇幅短,可以直接贴出来:

只是将原来的N个执行单元简化成了2个,从而更好理解。这篇paper的算法是错误的,可以自行推导下。

我们接着来看第三篇paper,也是出自另一位图灵奖得住、著名计算机科学家Donald Ervin Knuth。
算法如下:

他的想法是,只采用一个control(初始化为0)的环形线程id数组,一个k用于指示临界区id。思想是:
  • 首先从k开始遍历到自己的id(i),假如发现一个control(j)!=0,说明前面已经有线程在竞争了,所以我们goto返回。否则从k到前一个id的control都为0,那么我们就进入第二步。
  • 第二步首先将contrl值设置为2,说明已经进一步竞争了,此时依然可能有多个线程到达此处,所以接下来,我们采用与Dijkstra类似的探测排除方法,最多可以得到一个进入下一步的线程。
  • 第三步,将k的值设置为当前id,进入临界区。
  • 第四部,从临界区出来之后,将k值设置为当前id右边→_→的一个id,如此一来很可能形成环形的执行顺序。最后将control[i]设置为0。
  • 最后返回。 注意, 这里的k设置是没有竞争的 k:=if i = 1 then N else i -1;是为了尽量让右边一个线程执行,但是极端情况下依然可能被其他线程获取锁。所以还是得有L3: k := i; 这一行。
我们也一样采用JAVA来完成这个算法,
可执行代码如下:
package com.psly.testatomic;import java.util.Random;import com.psly.locksupprot.LockSupport;import sun.misc.Unsafe;public class TestVolatileKnuthMethod {private final static Random random = new Random();//用于内存保证:putXXVolatile/getXXVolatileprivate static final Unsafe _unsafe = UtilUnsafe.getUnsafe();private static final int _Obase  = _unsafe.arrayBaseOffset(int[].class);private static final int _Oscale = _unsafe.arrayIndexScale(int[].class);//N:线程数,TIMES每个线程需要进入临界区的次数。private final static int N = 5;private final static int TIMES = 1000;private final static int[] B = new int[N+1];private final static int[] C = new int[N+1];//knuth's methodprivate final static int[] control = new int[N+1];//每个线程进入临界区++count,最终count == N * TIMESprivate static long count;//countObj:获取count字段所属于的对象(其实就是地址),private final static Object countObj;//countOffset:获取count字段处于所在对象地址的偏移量private final static long countOffset;//k与上面的count字段类似private static int k = 1;private final static Object kObj; private final static long kOffset;static{for(int i = 1; i <= N; ++i){B[i] = 1;C[i] = 1;}try {countObj = _unsafe.staticFieldBase(TestVolatileKnuthMethod.class.getDeclaredField("count"));countOffset = _unsafe.staticFieldOffset(TestVolatileKnuthMethod.class.getDeclaredField("count"));kObj = _unsafe.staticFieldBase(TestVolatileKnuthMethod.class.getDeclaredField("k"));kOffset = _unsafe.staticFieldOffset(TestVolatileKnuthMethod.class.getDeclaredField("k"));} catch (Exception e) {throw new Error(e);} }private static Object obj = new Object();final static void knuthConcurMethod(int pM){    int times = TIMES;        int i = pM;            L0: for(;;){    _unsafe.putIntVolatile(control, _Obase + i * _Oscale, 1);    L1:for(;;){    for(int j = _unsafe.getIntVolatile(kObj, kOffset); j >= 1; --j){    if(j == i)    break L1;    if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){     continue L1;    }    }    for(int j = N; j >= 1; --j){    if(j == i)    break L1;    if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){     continue L1;    }    }    }    _unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;    for(int j = N; j >= 1; --j){    if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){    continue L0;    }    }    _unsafe.putIntVolatile(kObj, kOffset, i);    long val = _unsafe.getLongVolatile(countObj, countOffset);        _unsafe.putLongVolatile(countObj, countOffset, val + 1);    _unsafe.putIntVolatile(kObj, kOffset, (i == 1)? N : i -1);    _unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;                if( --times != 0)    continue L0;    return ;    }}private static Thread[] handle = new Thread[N+1];public static void main(String[] args) throws InterruptedException{//开始时间long start = System.currentTimeMillis();//打印累加器初始值    System.out.println( count + " initial\n");        //创建线程    for (int i = 1; i <= N; ++i){    int j = i;    handle[i] = new Thread(new Runnable(){    @Override    public void run(){    knuthConcurMethod(j);    }    });    }    //线程开始执行    for (int i = 1; i <= N; ++i)        handle[i].start();    //主线程等待子线程结束    for (int i = 1; i <= N; ++i)        handle[i].join();    //打印累加值,== N * TIMES    System.out.println(_unsafe.getLongVolatile(countObj, countOffset));  //打印程序执行时间    System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds"); }}

输出如下:
0 initial50007.464 seconds

5个线程,每个1000次,7.46秒。 可以看出,尽管公平性得到了保证,但是这样的效率较低,因为环形数组中多余的线程一直在占有CPU资源。knuth的paper中也说要采用queue之类的方式提升效率,


我们这里采用另外的办法,想个办法让它休眠,然后等到必要的时候唤醒他。刚好java本身提供了park/unpark接口,并且我们这里的线程数组是固定的。所以可以直接采用。

在上面的示例中,添加如下代码:
唤醒:
    _unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;                int j = (i == 1)? N : i -1;            for(int m = 0; m < N - 1; ++m){            if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)            break;            if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){            LockSupport.unpark(handle[j]);            break;            }            j = (j == 1)? N : j -1;            }                if( --times != 0)

休眠:
    if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){     LockSupport.park(obj);    continue L1;    }

然后得到我的阻塞版本,可执行代码为:
package com.psly.testatomic;import com.psly.locksupprot.LockSupport;import sun.misc.Unsafe;public class TestVolatileKnuthMethod {//用于内存保证:putXXVolatile/getXXVolatileprivate static final Unsafe _unsafe = UtilUnsafe.getUnsafe();private static final int _Obase  = _unsafe.arrayBaseOffset(int[].class);private static final int _Oscale = _unsafe.arrayIndexScale(int[].class);//N:线程数,TIMES每个线程需要进入临界区的次数。private final static int N = 5;private final static int TIMES = 1000;private final static int[] B = new int[N+1];private final static int[] C = new int[N+1];//knuth's methodprivate final static int[] control = new int[N+1];//每个线程进入临界区++count,最终count == N * TIMESprivate static long count;//countObj:获取count字段所属于的对象(其实就是地址),private final static Object countObj;//countOffset:获取count字段处于所在对象地址的偏移量private final static long countOffset;//k与上面的count字段类似private static int k = 1;private final static Object kObj; private final static long kOffset;static{for(int i = 1; i <= N; ++i){B[i] = 1;C[i] = 1;}try {countObj = _unsafe.staticFieldBase(TestVolatileKnuthMethod.class.getDeclaredField("count"));countOffset = _unsafe.staticFieldOffset(TestVolatileKnuthMethod.class.getDeclaredField("count"));kObj = _unsafe.staticFieldBase(TestVolatileKnuthMethod.class.getDeclaredField("k"));kOffset = _unsafe.staticFieldOffset(TestVolatileKnuthMethod.class.getDeclaredField("k"));} catch (Exception e) {throw new Error(e);} }private static Object obj = new Object();final static void knuthConcurMethod(int pM){    int times = TIMES;        int i = pM;            L0: for(;;){    _unsafe.putIntVolatile(control, _Obase + i * _Oscale, 1);    L1:for(;;){    for(int j = _unsafe.getIntVolatile(kObj, kOffset); j >= 1; --j){    if(j == i)    break L1;    if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){     LockSupport.park(obj);    continue L1;    }    }    for(int j = N; j >= 1; --j){    if(j == i)    break L1;    if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){     LockSupport.park(obj);    continue L1;    }    }    }    _unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;    for(int j = N; j >= 1; --j){    if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){    continue L0;    }    }                //临界区开始    _unsafe.putIntVolatile(kObj, kOffset, i);                   long val = _unsafe.getLongVolatile(countObj, countOffset);        _unsafe.putLongVolatile(countObj, countOffset, val + 1);    _unsafe.putIntVolatile(kObj, kOffset, (i == 1)? N : i -1); //临界区结束    _unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;                int j = (i == 1)? N : i -1;            for(int m = 0; m < N - 1; ++m){            if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)            break;            if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){            LockSupport.unpark(handle[j]);            break;            }            j = (j == 1)? N : j -1;            }                if( --times != 0)    continue L0;    return ;    }}private static Thread[] handle = new Thread[N+1];public static void main(String[] args) throws InterruptedException{//开始时间long start = System.currentTimeMillis();//打印累加器初始值    System.out.println( count + " initial\n");        //创建线程    for (int i = 1; i <= N; ++i){    int j = i;    handle[i] = new Thread(new Runnable(){    @Override    public void run(){    knuthConcurMethod(j);    }    });    }    //线程开始执行    for (int i = 1; i <= N; ++i)        handle[i].start();    //主线程等待子线程结束    for (int i = 1; i <= N; ++i)        handle[i].join();    //打印累加值,== N * TIMES    System.out.println(_unsafe.getLongVolatile(countObj, countOffset));  //打印程序执行时间    System.out.println((System.currentTimeMillis() - start) / 1000.0 + " milliseconds"); }}
 
输出:
0 initial50000.043 milliseconds


5个线程,每个1000次,0.043秒。

我们再尝试下更多的操作,N=100,TIMES=5000.
输出:
0 initial5000002.938 seconds

100个线程,每个进入临界区5000次,总共2.938秒,这比轮询的版本好多啦。

再看下我们的Java代码主要逻辑:
final static void knuthConcurMethod(int pM){    int times = TIMES;        int i = pM;            L0: for(;;){    _unsafe.putIntVolatile(control, _Obase + i * _Oscale, 1);    L1:for(;;){    for(int j = _unsafe.getIntVolatile(kObj, kOffset); j >= 1; --j){    if(j == i)    break L1;    if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){     LockSupport.park(obj);    continue L1;    }    }    for(int j = N; j >= 1; --j){    if(j == i)    break L1;    if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){     LockSupport.park(obj);    continue L1;    }    }    }    _unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;    for(int j = N; j >= 1; --j){    if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){    continue L0;    }    }    _unsafe.putIntVolatile(kObj, kOffset, i);    long val = _unsafe.getLongVolatile(countObj, countOffset);        _unsafe.putLongVolatile(countObj, countOffset, val + 1);    _unsafe.putIntVolatile(kObj, kOffset, (i == 1)? N : i -1);    _unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;                int j = (i == 1)? N : i -1;            for(int m = 0; m < N - 1; ++m){            if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)            break;            if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){            LockSupport.unpark(handle[j]);            break;            }            j = (j == 1)? N : j -1;            }                if( --times != 0)    continue L0;    return ;    }}


这里重点是唤醒的逻辑:
向右遍历过程中,唤醒必要的一个就行,甚至于不需要唤醒。我们尽量少做事情。
       int j = (i == 1)? N : i -1;             for(int m = 0; m < N - 1; ++m){              if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)                  break;              if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){                  LockSupport.unpark(handle[j]);                  break;              }              j = (j == 1)? N : j -1;             }  


好了,我们最后来看最篇与之相关的paper:



一个叫N. G. de Bruijn再次稍微改了Knuth的方法。

优势是能够更清楚得看清执行的逻辑,一点细微的改变是k的值并不随着线程进入临界区而设置的。据说理论上一个线程究竟需要多少次才能轮到执行,这个次数的上界减少了,只不过没看懂他假设的前提是什么。
我们对应的java改变如下:




这种方法据说能够给每个线程提供一个下次执行临界区前的最大上限数量turn。

可执行代码如下(可阻塞版本):
package com.psly.testatomic;import java.text.SimpleDateFormat;import java.util.Date;import com.psly.locksupprot.LockSupport;import sun.misc.Unsafe;public class TestVolatileBruijnMethod {//用于内存保证:putXXVolatile/getXXVolatileprivate static final Unsafe _unsafe = UtilUnsafe.getUnsafe();private static final int _Obase  = _unsafe.arrayBaseOffset(int[].class);private static final int _Oscale = _unsafe.arrayIndexScale(int[].class);//N:线程数,TIMES每个线程需要进入临界区的次数。private final static int N = 100;private final static int TIMES = 5000;private final static int[] B = new int[N+1];private final static int[] C = new int[N+1];//knuth's methodprivate final static int[] control = new int[N+1];//每个线程进入临界区++count,最终count == N * TIMESprivate static long count;//countObj:获取count字段所属于的对象(其实就是地址),private final static Object countObj;//countOffset:获取count字段处于所在对象地址的偏移量private final static long countOffset;//k与上面的count字段类似private static int k = 1;private final static Object kObj; private final static long kOffset;static{for(int i = 1; i <= N; ++i){B[i] = 1;C[i] = 1;}try {countObj = _unsafe.staticFieldBase(TestVolatileBruijnMethod.class.getDeclaredField("count"));countOffset = _unsafe.staticFieldOffset(TestVolatileBruijnMethod.class.getDeclaredField("count"));kObj = _unsafe.staticFieldBase(TestVolatileBruijnMethod.class.getDeclaredField("k"));kOffset = _unsafe.staticFieldOffset(TestVolatileBruijnMethod.class.getDeclaredField("k"));} catch (Exception e) {throw new Error(e);} }private static Object obj = new Object();final static void knuthConcurMethod(int pM){    int times = TIMES;        int i = pM;            L0: for(;;){    _unsafe.putIntVolatile(control, _Obase + i * _Oscale, 1);    L1:for(;;){    for(int j = _unsafe.getIntVolatile(kObj, kOffset); j >= 1; --j){    if(j == i)    break L1;    if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){     LockSupport.park(obj);    continue L1;    }    }    for(int j = N; j >= 1; --j){    if(j == i)    break L1;    if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){     LockSupport.park(obj);    continue L1;    }    }    }    _unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;    for(int j = N; j >= 1; --j){    if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){    continue L0;    }    }//    _unsafe.putIntVolatile(kObj, kOffset, i);    int kLocal = _unsafe.getIntVolatile(kObj, kOffset);    int kNew = kLocal;    long val = _unsafe.getLongVolatile(countObj, countOffset);        _unsafe.putLongVolatile(countObj, countOffset, val + 1);        if(_unsafe.getIntVolatile(control, _Obase + kLocal * _Oscale) == 0 || kLocal == i)    _unsafe.putIntVolatile(kObj, kOffset, kNew = ((kLocal == 1)? N: kLocal - 1));    _unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;                int j = kNew;            for(int m = 0; m < N; ++m){            if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)            break;            if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){            LockSupport.unpark(handle[j]);            break;            }            j = (j == 1)? N : j -1;            }                if( --times != 0)    continue L0;    return ;    }}private static Thread[] handle = new Thread[N+1];public static void main(String[] args) throws InterruptedException{System.out.println(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date()));//开始时间long start = System.currentTimeMillis();//打印累加器初始值    System.out.println( count + " initial\n");        //创建线程    for (int i = 1; i <= N; ++i){    int j = i;    handle[i] = new Thread(new Runnable(){    @Override    public void run(){    knuthConcurMethod(j);    }    });    }    //线程开始执行    for (int i = 1; i <= N; ++i)        handle[i].start();    //主线程等待子线程结束    for (int i = 1; i <= N; ++i)        handle[i].join();    //打印累加值,== N * TIMES    System.out.println(_unsafe.getLongVolatile(countObj, countOffset));  //打印程序执行时间    System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds"); }}



接着再后来的1972年,The MITRE Corporation的Murray A. Eisenberg and Michael R. McGuire又提出了据说协调性更好的算法,如下:


你会看到这个算法中的goto语句更多了,也更复杂了。

我们同样也给出可执行的JAVA代码:
package com.psly.testatomic;import java.text.SimpleDateFormat;import java.util.Date;import com.psly.locksupprot.LockSupport;import sun.misc.Unsafe;public class TestVolatileEisenbergMethod {//用于内存保证:putXXVolatile/getXXVolatileprivate static final Unsafe _unsafe = UtilUnsafe.getUnsafe();private static final int _Obase  = _unsafe.arrayBaseOffset(int[].class);private static final int _Oscale = _unsafe.arrayIndexScale(int[].class);//N:线程数,TIMES每个线程需要进入临界区的次数。private final static int N = 100;private final static int TIMES = 5000;private final static int[] B = new int[N+1];private final static int[] C = new int[N+1];//knuth's methodprivate final static int[] control = new int[N+1];//每个线程进入临界区++count,最终count == N * TIMESprivate static long count;//countObj:获取count字段所属于的对象(其实就是地址),private final static Object countObj;//countOffset:获取count字段处于所在对象地址的偏移量private final static long countOffset;//k与上面的count字段类似private static int k = 1;private final static Object kObj; private final static long kOffset;static{for(int i = 1; i <= N; ++i){B[i] = 1;C[i] = 1;}try {countObj = _unsafe.staticFieldBase(TestVolatileEisenbergMethod.class.getDeclaredField("count"));countOffset = _unsafe.staticFieldOffset(TestVolatileEisenbergMethod.class.getDeclaredField("count"));kObj = _unsafe.staticFieldBase(TestVolatileEisenbergMethod.class.getDeclaredField("k"));kOffset = _unsafe.staticFieldOffset(TestVolatileEisenbergMethod.class.getDeclaredField("k"));} catch (Exception e) {throw new Error(e);} }private static Object obj = new Object();final static void EisenbergConcurMethod(int pM){    int times = TIMES;        int i = pM;            L0: for(;;){    _unsafe.putIntVolatile(control, _Obase + i * _Oscale, 1);    L1:for(;;){    int kLocal;    for(int j = (kLocal = _unsafe.getIntVolatile(kObj, kOffset)); j <= N; ++j){    if(j == i)    break L1;    if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){     LockSupport.park(obj);    continue L1;    }    }    for(int j = 1; j <= kLocal - 1; ++j){    if(j == i)    break L1;    if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){     LockSupport.park(obj);    continue L1;    }    }    }    _unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;    for(int j = 1; j <= N; ++j){    if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){    continue L0;    }    }    int kLocal;    if(_unsafe.getIntVolatile(control, _Obase + (kLocal = _unsafe.getIntVolatile(kObj, kOffset)) *_Oscale ) != 0     && kLocal != i)    continue L0;        _unsafe.putIntVolatile(kObj, kOffset, i);    long val = _unsafe.getLongVolatile(countObj, countOffset);        _unsafe.putLongVolatile(countObj, countOffset, val + 1); //       System.out.println(Thread.currentThread().getName());                int kNew = i;     L2:  for(;;){        for(int j = i + 1; j <= N; ++j){        if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){        _unsafe.putIntVolatile(kObj, kOffset, j);        //LockSupport.unpark(handle[j]);        kNew = j;        break L2;        }        }    for(int j = 1; j <= i - 1; ++j){        if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){        _unsafe.putIntVolatile(kObj, kOffset, j);        //LockSupport.unpark(handle[j]);        kNew = j;        break L2;        }    }    break;        }        _unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;                int j = kNew;            for(int m = 0; m < N; ++m){            if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)            break;            if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){            LockSupport.unpark(handle[j]);            break;            }            j = (j == N)? 1 : j + 1;            }                if( --times != 0)    continue L0;    return ;    }}private static Thread[] handle = new Thread[N+1];public static void main(String[] args) throws InterruptedException{System.out.println(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date()));//开始时间long start = System.currentTimeMillis();//打印累加器初始值    System.out.println( count + " initial\n");        //创建线程    for (int i = 1; i <= N; ++i){    int j = i;    handle[i] = new Thread(new Runnable(){    @Override    public void run(){    EisenbergConcurMethod(j);    }    });    }    //线程开始执行    for (int i = 1; i <= N; ++i)        handle[i].start();    //主线程等待子线程结束    for (int i = 1; i <= N; ++i)        handle[i].join();    //打印累加值,== N * TIMES    System.out.println(_unsafe.getLongVolatile(countObj, countOffset));  //打印程序执行时间    System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds"); }}
同样,他们的执行效率并没有很大的差别。
但是采用的想法却都各不相同。

最后我们来看一遍, 为什么后面三个算法能够实现临界区:

    L1:for(;;){ //以下两个循环的代码判断当前线程是否适合竞争临界区    for(int j = _unsafe.getIntVolatile(kObj, kOffset); j >= 1; --j){    if(j == i)    break L1;    if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){    // LockSupport.park(obj);    continue L1;    }    }    for(int j = N; j >= 1; --j){    if(j == i)    break L1;    if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){    // LockSupport.park(obj);    continue L1;    }    } //以上两个循环的代码判断当前线程是否适合竞争临界区    }    //以下代码保证最多一个线程进去临界区    _unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;    for(int j = N; j >= 1; --j){    if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){    continue L0;    }    }        //以上代码保证最多一个线程进入临界区    _unsafe.putIntVolatile(kObj, kOffset, i);    //临界区start    long val = _unsafe.getLongVolatile(countObj, countOffset);        _unsafe.putLongVolatile(countObj, countOffset, val + 1);    _unsafe.putIntVolatile(kObj, kOffset, (i == 1)? N : i - 1);    //临界区end    _unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
  • 先通过两个循环来判断当前线程是否适合竞争锁,适合跳出L1,否则继续循环
  • 接着第二个循环通过探测其他线程的control值,假如发现都不为0则结束循环,获得锁,否则跳回L0,继续前面的循环判断。注意这里的语义确保最多只有一个线程进入临界区,存在全部线程都无法获得锁,跳回L0的极端情况。
  • 临界区结尾处将0给control[i],替换掉了它的2值,从而之后,让其他线程有机会获得锁(根据竞争判断的语义,假如一个线程看到其他的某个为2是无法获取锁的)。
Over

附上:

package com.psly.testatomic;import java.lang.reflect.Field;import sun.misc.Unsafe;public class UtilUnsafe {  private UtilUnsafe() { } // dummy private constructor  /** Fetch the Unsafe.  Use With Caution. */  public static Unsafe getUnsafe() {    // Not on bootclasspath    if( UtilUnsafe.class.getClassLoader() == null )      return Unsafe.getUnsafe();    try {      final Field fld = Unsafe.class.getDeclaredField("theUnsafe");      fld.setAccessible(true);      return (Unsafe) fld.get(UtilUnsafe.class);    } catch (Exception e) {      throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e);    }  }}
package com.psly.locksupprot;import com.psly.testatomic.UtilUnsafe;public class LockSupport {    private LockSupport() {} // Cannot be instantiated.    private static void setBlocker(Thread t, Object arg) {        // Even though volatile, hotspot doesn't need a write barrier here.        UNSAFE.putObject(t, parkBlockerOffset, arg);    }    /**     * Makes available the permit for the given thread, if it     * was not already available.  If the thread was blocked on     * {@code park} then it will unblock.  Otherwise, its next call     * to {@code park} is guaranteed not to block. This operation     * is not guaranteed to have any effect at all if the given     * thread has not been started.     *     * @param thread the thread to unpark, or {@code null}, in which case     *        this operation has no effect     */    public static void unpark(Thread thread) {        if (thread != null)            UNSAFE.unpark(thread);    }    /**     * Disables the current thread for thread scheduling purposes unless the     * permit is available.     *     * <p>If the permit is available then it is consumed and the call returns     * immediately; otherwise     * the current thread becomes disabled for thread scheduling     * purposes and lies dormant until one of three things happens:     *     * <ul>     * <li>Some other thread invokes {@link #unpark unpark} with the     * current thread as the target; or     *     * <li>Some other thread {@linkplain Thread#interrupt interrupts}     * the current thread; or     *     * <li>The call spuriously (that is, for no reason) returns.     * </ul>     *     * <p>This method does <em>not</em> report which of these caused the     * method to return. Callers should re-check the conditions which caused     * the thread to park in the first place. Callers may also determine,     * for example, the interrupt status of the thread upon return.     *     * @param blocker the synchronization object responsible for this     *        thread parking     * @since 1.6     */    public static void park(Object blocker) {        Thread t = Thread.currentThread();        setBlocker(t, blocker);        UNSAFE.park(false, 0L);        setBlocker(t, null);    }    /**     * Disables the current thread for thread scheduling purposes, for up to     * the specified waiting time, unless the permit is available.     *     * <p>If the permit is available then it is consumed and the call     * returns immediately; otherwise the current thread becomes disabled     * for thread scheduling purposes and lies dormant until one of four     * things happens:     *     * <ul>     * <li>Some other thread invokes {@link #unpark unpark} with the     * current thread as the target; or     *     * <li>Some other thread {@linkplain Thread#interrupt interrupts}     * the current thread; or     *     * <li>The specified waiting time elapses; or     *     * <li>The call spuriously (that is, for no reason) returns.     * </ul>     *     * <p>This method does <em>not</em> report which of these caused the     * method to return. Callers should re-check the conditions which caused     * the thread to park in the first place. Callers may also determine,     * for example, the interrupt status of the thread, or the elapsed time     * upon return.     *     * @param blocker the synchronization object responsible for this     *        thread parking     * @param nanos the maximum number of nanoseconds to wait     * @since 1.6     */    public static void parkNanos(Object blocker, long nanos) {        if (nanos > 0) {            Thread t = Thread.currentThread();            setBlocker(t, blocker);            UNSAFE.park(false, nanos);            setBlocker(t, null);        }    }    /**     * Disables the current thread for thread scheduling purposes, until     * the specified deadline, unless the permit is available.     *     * <p>If the permit is available then it is consumed and the call     * returns immediately; otherwise the current thread becomes disabled     * for thread scheduling purposes and lies dormant until one of four     * things happens:     *     * <ul>     * <li>Some other thread invokes {@link #unpark unpark} with the     * current thread as the target; or     *     * <li>Some other thread {@linkplain Thread#interrupt interrupts} the     * current thread; or     *     * <li>The specified deadline passes; or     *     * <li>The call spuriously (that is, for no reason) returns.     * </ul>     *     * <p>This method does <em>not</em> report which of these caused the     * method to return. Callers should re-check the conditions which caused     * the thread to park in the first place. Callers may also determine,     * for example, the interrupt status of the thread, or the current time     * upon return.     *     * @param blocker the synchronization object responsible for this     *        thread parking     * @param deadline the absolute time, in milliseconds from the Epoch,     *        to wait until     * @since 1.6     */    public static void parkUntil(Object blocker, long deadline) {        Thread t = Thread.currentThread();        setBlocker(t, blocker);        UNSAFE.park(true, deadline);        setBlocker(t, null);    }    /**     * Returns the blocker object supplied to the most recent     * invocation of a park method that has not yet unblocked, or null     * if not blocked.  The value returned is just a momentary     * snapshot -- the thread may have since unblocked or blocked on a     * different blocker object.     *     * @param t the thread     * @return the blocker     * @throws NullPointerException if argument is null     * @since 1.6     */    public static Object getBlocker(Thread t) {        if (t == null)            throw new NullPointerException();        return UNSAFE.getObjectVolatile(t, parkBlockerOffset);    }    /**     * Disables the current thread for thread scheduling purposes unless the     * permit is available.     *     * <p>If the permit is available then it is consumed and the call     * returns immediately; otherwise the current thread becomes disabled     * for thread scheduling purposes and lies dormant until one of three     * things happens:     *     * <ul>     *     * <li>Some other thread invokes {@link #unpark unpark} with the     * current thread as the target; or     *     * <li>Some other thread {@linkplain Thread#interrupt interrupts}     * the current thread; or     *     * <li>The call spuriously (that is, for no reason) returns.     * </ul>     *     * <p>This method does <em>not</em> report which of these caused the     * method to return. Callers should re-check the conditions which caused     * the thread to park in the first place. Callers may also determine,     * for example, the interrupt status of the thread upon return.     */    public static void park() {        UNSAFE.park(false, 0L);    }    /**     * Disables the current thread for thread scheduling purposes, for up to     * the specified waiting time, unless the permit is available.     *     * <p>If the permit is available then it is consumed and the call     * returns immediately; otherwise the current thread becomes disabled     * for thread scheduling purposes and lies dormant until one of four     * things happens:     *     * <ul>     * <li>Some other thread invokes {@link #unpark unpark} with the     * current thread as the target; or     *     * <li>Some other thread {@linkplain Thread#interrupt interrupts}     * the current thread; or     *     * <li>The specified waiting time elapses; or     *     * <li>The call spuriously (that is, for no reason) returns.     * </ul>     *     * <p>This method does <em>not</em> report which of these caused the     * method to return. Callers should re-check the conditions which caused     * the thread to park in the first place. Callers may also determine,     * for example, the interrupt status of the thread, or the elapsed time     * upon return.     *     * @param nanos the maximum number of nanoseconds to wait     */    public static void parkNanos(long nanos) {        if (nanos > 0)            UNSAFE.park(false, nanos);    }    /**     * Disables the current thread for thread scheduling purposes, until     * the specified deadline, unless the permit is available.     *     * <p>If the permit is available then it is consumed and the call     * returns immediately; otherwise the current thread becomes disabled     * for thread scheduling purposes and lies dormant until one of four     * things happens:     *     * <ul>     * <li>Some other thread invokes {@link #unpark unpark} with the     * current thread as the target; or     *     * <li>Some other thread {@linkplain Thread#interrupt interrupts}     * the current thread; or     *     * <li>The specified deadline passes; or     *     * <li>The call spuriously (that is, for no reason) returns.     * </ul>     *     * <p>This method does <em>not</em> report which of these caused the     * method to return. Callers should re-check the conditions which caused     * the thread to park in the first place. Callers may also determine,     * for example, the interrupt status of the thread, or the current time     * upon return.     *     * @param deadline the absolute time, in milliseconds from the Epoch,     *        to wait until     */    public static void parkUntil(long deadline) {        UNSAFE.park(true, deadline);    }    /**     * Returns the pseudo-randomly initialized or updated secondary seed.     * Copied from ThreadLocalRandom due to package access restrictions.     */    static final int nextSecondarySeed() {        int r;        Thread t = Thread.currentThread();        if ((r = UNSAFE.getInt(t, SECONDARY)) != 0) {            r ^= r << 13;   // xorshift            r ^= r >>> 17;            r ^= r << 5;        }        else if ((r = java.util.concurrent.ThreadLocalRandom.current().nextInt()) == 0)            r = 1; // avoid zero        UNSAFE.putInt(t, SECONDARY, r);        return r;    }    // Hotspot implementation via intrinsics API    private static final sun.misc.Unsafe UNSAFE;    private static final long parkBlockerOffset;    private static final long SEED;    private static final long PROBE;    private static final long SECONDARY;    static {        try {            UNSAFE = UtilUnsafe.getUnsafe();            Class<?> tk = Thread.class;            parkBlockerOffset = UNSAFE.objectFieldOffset                (tk.getDeclaredField("parkBlocker"));            SEED = UNSAFE.objectFieldOffset                (tk.getDeclaredField("threadLocalRandomSeed"));            PROBE = UNSAFE.objectFieldOffset                (tk.getDeclaredField("threadLocalRandomProbe"));            SECONDARY = UNSAFE.objectFieldOffset                (tk.getDeclaredField("threadLocalRandomSecondarySeed"));        } catch (Exception ex) { throw new Error(ex); }    }}


最近发现,针对这个独占互斥区的并发控制,2013年图灵奖得主Leslie Lamport在1974年也提出过另一种算法,paper截图如下:


证明过程:



这个算法的特点是,没有中心控制。

我们用JAVA代码实现下:

package com.psly.testatomic;import sun.misc.Unsafe;public class TestVolatile {//用于内存保证:putXXVolatile/getXXVolatileprivate static final Unsafe _unsafe = UtilUnsafe.getUnsafe();private static final int _Obase  = _unsafe.arrayBaseOffset(long[].class);private static final int _Oscale = _unsafe.arrayIndexScale(long[].class);//N:线程数,TIMES每个线程需要进入临界区的次数。private final static int N = 2000;private final static int TIMES = 1000;private final static long[] choosing = new long[N+1];private final static long[] number = new long[N+1];//每个线程进入临界区++count,最终count == N * TIMESprivate static long count;//countObj:获取count字段所属于的对象(其实就是地址),private final static Object mainObj;//countOffset:获取count字段处于所在对象地址的偏移量private final static long countOffset;private static Object obj = new Object(); //private static Queue<Thread> queues = new ConcurrentLinkedQueue();static{for(int i = 1; i <= N; ++i){choosing[i] = 0;number[i] = 0;}try {mainObj = _unsafe.staticFieldBase(TestVolatile.class.getDeclaredField("count"));countOffset = _unsafe.staticFieldOffset(TestVolatile.class.getDeclaredField("count"));    //        waitersOffset = _unsafe.staticFieldOffset(TestVolatile.class.getDeclaredField("waiters"));} catch (Exception e) {throw new Error(e);} }    final static void dijkstrasConcurMethod(int pM){    int times = TIMES;        int i = pM;    L0: for(;;){    _unsafe.putLongVolatile(choosing, _Obase + i * _Oscale, 1);    //获取最大的number并+1。    long maxNum = _unsafe.getLongVolatile(number, _Obase + _Oscale), midNum;    for(int j = 2; j <= N; ++j)    if(maxNum < (midNum = _unsafe.getLongVolatile(number, _Obase + j * _Oscale)))    maxNum = midNum;    _unsafe.putLongVolatile(number, _Obase + i * _Oscale, 1 + maxNum);    _unsafe.putLongVolatile(choosing, _Obase + i * _Oscale, 0);   /* for(int j = 1; j <i; ++j)    LockSupport.unpark(handle[j]);    for(int j = i+1; j <= N; ++j)    LockSupport.unpark(handle[j]);*/    long jNumber, iNumber;    for(int j = 1; j <= N; ++j){    L1:for(;;){    for(int k = 0 ; k < 100; ++k)    if(!(_unsafe.getLongVolatile(choosing, _Obase + j * _Oscale) != 0))     break L1;    //LockSupport.park(obj);    }    L2:for(;;){    for(int k = 0; k < 1000; ++k)    if(!(_unsafe.getLongVolatile(number, _Obase + j * _Oscale) != 0     && ((jNumber=_unsafe.getLongVolatile(number, _Obase + j * _Oscale))     < (iNumber=_unsafe.getLongVolatile(number, _Obase + i * _Oscale))     || (jNumber == iNumber && j < i))))    break L2;    LockSupport.park(obj);    }    }    //critical section     //临界区开始              long val = _unsafe.getLongVolatile(mainObj, countOffset);              _unsafe.putLongVolatile(mainObj, countOffset, val + 1);             //临界区结束                        //设置标识            _unsafe.putLongVolatile(number, _Obase + i * _Oscale, 0);            //唤醒需要的线程            Thread target = handle[i];            long numMax = Long.MAX_VALUE, arg;                  for(int j = 1; j <i; ++j)      if((arg = _unsafe.getLongVolatile(number, _Obase + j * _Oscale)) != 0 && arg < numMax)      { target = handle[j]; numMax = arg;}    for(int j = i+1; j <= N; ++j)    if((arg = _unsafe.getLongVolatile(number, _Obase + j * _Oscale)) != 0 && arg < numMax)  { target = handle[j]; numMax = arg;}    LockSupport.unpark(target);            /*for(int j = 1; j <= N; ++j)            LockSupport.unpark(handle[j]);*/    //计算次数        if( --times != 0){        continue L0; //goto L0;        }        return;        }}private static Thread[] handle = new Thread[N+1]; public static void main(String[] args) throws InterruptedException{//开始时间long start = System.currentTimeMillis();//打印累加器初始值    System.out.println( count + " initial\n");  //  Thread handle[] = new Thread[N+1];        //创建线程    for (int i = 1; i <= N; ++i){    int j = i;    handle[i] = new Thread(new Runnable(){    @Override    public void run(){    dijkstrasConcurMethod(j);    }    });    }    //线程开始执行    for (int i = 1; i <= N; ++i)        handle[i].start();    //主线程等待子线程结束    for (int i = 1; i <= N; ++i)        handle[i].join();    //打印累加值,== N * TIMES    System.out.println(_unsafe.getLongVolatile(mainObj, countOffset));  //打印程序执行时间    System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds"); }}


原创粉丝点击