CountDownLatch使用、源码浅析

来源:互联网 发布:淘宝买东西怎么收货 编辑:程序博客网 时间:2024/06/15 09:36

CountDownLatch, 允许一组线程都完成后再执行某些事情,

例如A 线程需要等待 B C D线程 都执行完成后再执行

public static void main(String[] args) throws InterruptedException {        CountDownLatch latch = new CountDownLatch(2);        new Thread(new Manager(latch)).start();        Thread.sleep(2000);        new Thread(new Manager(latch)).start();        Thread.sleep(10000);        new Thread(new Work(latch)).start();        new Thread(new Work(latch)).start();        System.out.println("main thread");}


class Work implements Runnable {    private CountDownLatch latch;    public Work(CountDownLatch latch) {        this.latch = latch;    }    @Override    public void run() {        try {            Thread.sleep(5000);            System.out.println("我完成了");            latch.countDown();        } catch (InterruptedException e) {            e.printStackTrace();        }    }}

class Manager implements Runnable {    private CountDownLatch latch;    public Manager(CountDownLatch latch) {        this.latch = latch;    }    @Override    public void run() {        try {            latch.await();            System.out.println("你们完成了,我要开始工作了");        } catch (InterruptedException e) {            e.printStackTrace();        }    }}


运行结果:

main thread
我完成了
我完成了
你们完成了,我要开始工作了


下面是源码片段

public class CountDownLatch {    /**     * Synchronization control For CountDownLatch.     * Uses AQS state to represent count.     */    private static final class Sync extends AbstractQueuedSynchronizer {        private static final long serialVersionUID = 4982264981922014374L;        Sync(int count) {            setState(count);        }        int getCount() {            return getState();        }        protected int tryAcquireShared(int acquires) {            return (getState() == 0) ? 1 : -1;        }        protected boolean tryReleaseShared(int releases) {            // Decrement count; signal when transition to zero            for (;;) {                int c = getState();                if (c == 0)                    return false;                int nextc = c-1;                if (compareAndSetState(c, nextc))                    return nextc == 0;            }        }    }


注意这个Sync类,继承自 AbstractQueuedSynchronizer ,简称AQS,这个类有点像是java实现并发的一个基础类,像是 ReentrantLock内部也是用AQS实现的,

AQS提供一种基于 FIFO的队列,每个要获取锁的线程会依次放入AQS的队列之中,获取不到锁的线程会被 java提供的Unsafe类的park 方法阻塞,内部细节还是挺复杂的,

AQS有几个重要属性:

 private transient volatile Node head;  //第一个元素
 private transient volatile Node tail;   // 最后一个元素

这里能够看出,其实AQS的FIFO队列是由链表实现的

 private volatile int state;   //同步状态

刚刚讲到,等待获取锁的线程会放入队列中,AQS把线程包装成一个Node对象放入 队列中,所以head tail变量都是Node类型,

static final class Node {       static final Node SHARED = new Node();        static final Node EXCLUSIVE = null;        waitStatus 有几个值,        //代表线程已经被取消了,被取消的线程不需要获取锁        static final int CANCELLED =  1;          //代表线程需要被唤醒        static final int SIGNAL    = -1;        //代表线程正在等待某个condition        static final int CONDITION = -2;        //这个我也没理解        static final int PROPAGATE = -3;                   volatile int waitStatus;        volatile Node prev;        volatile Node next;        volatile Thread thread;        Node nextWaiter;}



这里可能一下子看的很迷糊,其实AQS是一个公共的类,并不是专门为CountDownLatch准备的,所以其实CountDownLatch的源码里面并不是都用到了每一个属性,

这里只是带大家过个概念,大家只要知道 有head tail元素,以及 waitStatus  status这几个就行了,因为下面的源码也就只涉及到了这些而已


首先看一下CountDownLatch的  await方法

 

public void await() throws InterruptedException {        sync.acquireSharedInterruptibly(1);    } public final void acquireSharedInterruptibly(int arg)            throws InterruptedException {        if (Thread.interrupted())            throw new InterruptedException();        if (tryAcquireShared(arg) < 0)              doAcquireSharedInterruptibly(arg);    } protected int tryAcquireShared(int acquires) {            return (getState() == 0) ? 1 : -1;  }

tryAcquireShared方法判断state状态是否为0, 是的话返回1 否则返回 -1,

 那么这个state是什么时候设置进去的呢? 是在构造函数里面

 public CountDownLatch(int count) {        if (count < 0) throw new IllegalArgumentException("count < 0");        this.sync = new Sync(count); }
Sync(int count) {            setState(count);}


也就是说创建CountDownLatch对象的时候,比如传入个2,state就是2了,

所以初始的时候 tryAcquireShared()方法的返回值肯定是小于0的,因为state是你构建CountDownLatch传进去的,传进去的值肯定大于0咯


再看看 doAcquireSharedInterruptibly 做了些什么

private void doAcquireSharedInterruptibly(int arg)        throws InterruptedException {        final Node node = addWaiter(Node.SHARED);        boolean failed = true;        try {            for (;;) {                //获取 node节点的prev节点                final Node p = node.predecessor();                if (p == head) {                    int r = tryAcquireShared(arg);                     //r若是等于0了,说明不需要等待了,若是大于0,可能是由于线程已经超时或者中断了,也就不需要等待了                    if (r >= 0) {                        setHeadAndPropagate(node, r);                        p.next = null; // help GC                        failed = false;                        return;                    }                }//判断是否需要阻塞,如果需要,  parkAndCheckInterrupt()方法会阻塞线程                if (shouldParkAfterFailedAcquire(p, node) &&                    parkAndCheckInterrupt())                    throw new InterruptedException();            }        } finally {            if (failed)                cancelAcquire(node);        }    }


先一步步来分析,看 addWaiter方法(添加一个等待者),传入了一个Node.SHARED,这个是啥呢,其实就是 new Node()

private Node addWaiter(Node mode) {        Node node = new Node(Thread.currentThread(), mode);        Node pred = tail;        if (pred != null) {            node.prev = pred;            if (compareAndSetTail(pred, node)) {                pred.next = node;                return node;            }        }        enq(node);        return node;}


首先呢,构建了一个Node对象,传入了当前的线程 ,判断tail元素是不是空,其实第一次进来的时候肯定是空,因为这时候的head和 tail都没有被赋值,

所以方法会进入 enq(node)

private Node enq(final Node node) {        for (;;) {            Node t = tail;            if (t == null) {                 if (compareAndSetHead(new Node()))                    tail = head;            } else {                node.prev = t;                if (compareAndSetTail(t, node)) {                    t.next = node;                    return t;                }            }        } }


调试后发现的,for循环里的代码会执行2次,第一次呢,因为tail为null,所以new Node()赋值给 head ,head再赋值给 tail元素

第二次就会进入 else代码块了,把node节点替换为 tail, 之前的tail设置成了node的 prev, 之前的tail的next设置为node,典型的链表写法啦



addWaiter方法返回的是 最后一个元素,队列中就两个元素,,一个head,是通过new Node()构造的, tail就是 调用await方法的当前线程的node了,

 node.predecessor()返回  node上上一个元素, 之前说过,addWaiter方法里面其实开始就 new Node()赋值给了 head 以及 tail,

node.predecessor的pre是之前的tail,也就是 head,

然后方法再次调用 tryAcquireShared 查询 state的状态,这时候如果其他线程没有调用countDown的话,

state肯定还是大于0,tryAcquireShared 就会返回 -1,

方法就会走到下面这里

if (shouldParkAfterFailedAcquire(p, node) &&                    parkAndCheckInterrupt())                    throw new InterruptedException();

//shouldParkAfterFailedAcquire 先判断当前节点的 上一个节点是不是出于 等待唤醒中,如果是,当前节点也要等待,

还记得 之前讲过的吗,>0 代表 线程状态是cancelled,被取消了,可能是由于超时或者 中断等原因,

如果当前节点的上一个节点状态>0  就递归

直到找到不是大于0的,也就是说不是cancelled状态的,将那个节点的next设置为node

如果上一个节点waitStatus不大于0,就把 当前节点状态改为 SIGNAL(等待唤醒)

private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {        int ws = pred.waitStatus;        if (ws == Node.SIGNAL)                  return true;        if (ws > 0) {                    do {                node.prev = pred = pred.prev;            } while (pred.waitStatus > 0);            pred.next = node;        } else {            compareAndSetWaitStatus(pred, ws, Node.SIGNAL);        }        return false;}



调试 shouldParkAfterFailedAcquire 会发现方法进去了两次,第一次将waitStatus设置为了 SIGNAL(等待唤醒)  但是由于整个方法返回false,又会再次进入循环

第二次调用shouldParkAfterFailedAcquire 就会返回true,然后线程被 park阻塞,当countDown方法被调用且 state变为0的时候await线程会被唤醒,

然后再次判断state是否为0,这时候state肯定是为0咯,线程就不会阻塞,就能够继续执行了,而且由于其他线程已经全部调用完了countDown()方法,

也就是说  await的确是等到了 其他所有线程都执行完了再执行的


看下countDown方法

 public void countDown() {        sync.releaseShared(1); }


//releaseShared尝试去释放锁

 public final boolean releaseShared(int arg) {        if (tryReleaseShared(arg)) {            doReleaseShared();            return true;        }        return false;    } protected boolean tryReleaseShared(int releases) {            for (;;) {                int c = getState();                if (c == 0)                    return false;                int nextc = c-1;                if (compareAndSetState(c, nextc))                    return nextc == 0;            } }



tryReleaseShared方法很容易理解,state减1, 再判断是否能够释放锁(不是真的有锁,只是判断是否需要唤醒其他线程),假如new countDownLatch(2),

默认 state是2,第一个线程调用countDown方法只是把 state 减1, state还等于1,就会返回false,所以就不会

释放锁,只有等到最后一个线程调用countDown方法,使 state等于0的时候 才会释放锁,

然后唤醒 head的下一个线程,也就是await线程了,

private void unparkSuccessor(Node node) {          int ws = node.waitStatus;        if (ws < 0)            compareAndSetWaitStatus(node, ws, 0);               Node s = node.next;        if (s == null || s.waitStatus > 0) {            s = null;            for (Node t = tail; t != null && t != node; t = t.prev)                if (t.waitStatus <= 0)                    s = t;        }        if (s != null)            LockSupport.unpark(s.thread); }


总结下来:其实CountDownLatch就是 开始设置一个计数器一样的东西(state),比如需要等待A B C D E 5个线程都执行完 , F线程才能执行,那么CountDownLatch构造函数就传入5, F线程由于 计数器没有变成0会被阻塞, 其他线程每个调用完成后会将 计数器 减1,当最后一个线程调用CountDown,因为state变为0了,会去唤醒await线程

ps: Unsafe类中提供了很多  原子操作的 方法(CAS  compare andd set),以及 阻塞 线程的方法, 都是native的,