ConcurrentHashMap原理详解

来源:互联网 发布:哈尔滨上牌数据 编辑:程序博客网 时间:2024/06/18 15:49

ConcurrentHashMap是既高效又线程安全的HashMap

HashTable和ConcurrentHashMap区别

HashTable容器使用synchronized来保证线程安全,但在线程竞争激烈的情况下HashTable的效率非常低下。因为当一个线程访问HashTable的同步方法,其他线程也访问HashTable的同步方法时,会进入阻塞或轮询状态,因为HashTable只有一把锁。
ConcurrentHashMap所使用的锁分段技术。首先将数据分成一段一段地存储,然后给每一段数据配一把锁,当一个线程占用锁访问其中一个段数据的时候,其他段的数据也能被其他线程访问

ConcurrentHashMap的结构

ConcurrentHashMap是由Segment数组结构和HashEntry数组结构组成。Segment是一种可重入锁(ReentrantLock),在ConcurrentHashMap里扮演锁的角色;HashEntry则用于存储键值对数据。一个ConcurrentHashMap里包含一个Segment数组。Segment的结构和HashMap类似,是一种数组和链表结构。一个Segment里包含一个HashEntry数组,每个HashEntry是一个链表结构的元素,每个Segment守护着一个HashEntry数组里的元素,当对HashEntry数组的数据进行修改时,必须首先获得与它对应的Segment锁,他的结构图如下所示
这里写图片描述
#ConcurrentHashMap的初始化
我们来看看ConcurrentHashMap的构造方法

    /**     * Creates a new, empty map with the specified initial     * capacity, load factor and concurrency level.     *     * @param initialCapacity the initial capacity. The implementation     * performs internal sizing to accommodate this many elements.     * @param loadFactor  the load factor threshold, used to control resizing.     * Resizing may be performed when the average number of elements per     * bin exceeds this threshold.     * @param concurrencyLevel the estimated number of concurrently     * updating threads. The implementation performs internal sizing     * to try to accommodate this many threads.     * @throws IllegalArgumentException if the initial capacity is     * negative or the load factor or concurrencyLevel are     * nonpositive.     *      * 实现原理:     *     ConcurrentHashMap使用分段锁技术,将数据分成一段一段的存储,然后给每一段数据配一把锁,当一个线程占用锁访问其中一个段数据的时候,     *      其他段的数据也能被其他线程访问,能够实现真正的并发访问。如下图是ConcurrentHashMap的内部结构图:     *      * 1.initialCapacity 表示新建的这个ConcurrentHashMap的初始容量,也就是上线结构图中的Entry数量。     * 默认值为static final int DEFAULT_INITIAL_CAPACITY = 16;     *      * 2.loadFactor表示负载因子,就是当ConcurrentHashMap中的元素个数大于loadFactor * 最大容量时候就需要rehash和扩容。     * 默认值为static final float DEFAULT_LOAD_FACTOR = 0.75f;     *      * 3.concurrencyLevel表示并发级别,这个值用来确定segment的个数,segment的个数大于等于concurrencyLevel的第一个2的n次方的数。     * 比如,如果concurrencyLevel为12,13,14,15,16,则Segment的数目为16(2的4次方)。     *      * 4.理想情况下ConcurrentHashMap真正的访问量能够达到concurrencyLevel,因为有concurrencyLevel个Segment,     * 假如有concurrencyLevel个线程要访问Map,并且需要访问的数据都恰好分别落在不同的segment中,则这些线程能够无竞     * 争的自由访问(因为不需要竞争同一把锁)达到同时访问的效果。这也是这个concurrencyLevel参数为什么起名为“并发级别”的原因。     *      *      */    @SuppressWarnings("unchecked")    public ConcurrentHashMapSourceCode(int initialCapacity,                             float loadFactor, int concurrencyLevel) {        //1.验证参数的合法性,如果不合法,直接抛出异常        if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)            throw new IllegalArgumentException();        //2.concurrencyLevel也就是Segment的个数不能超过最大Segment的个数,最大个数MAX_SEGMENTS默认值为 2 << 16,如果超过这个值,设置这个值。        if (concurrencyLevel > MAX_SEGMENTS)            concurrencyLevel = MAX_SEGMENTS;        // Find power-of-two sizes best matching arguments        //比如concurrencyLevel=16默认值,则ssize也会等于16(2的4次方,sshift=4),如果concurrencyLevel=18,则ssize=32(也就是2的5次方,sshift=5),        //3.这段代码的使用循环找到>=concurrencyLevel的第一个2的n次方的数ssize,这个数ssize就是Segment数组的大小;并记录一共向左按位移动的次数sshift。        int sshift = 0;        int ssize = 1;        while (ssize < concurrencyLevel) {            //sshift记录ssize向左移动的次数            ++sshift;            //ssize就是Segment数组的大小            ssize <<= 1;        }        //segmentShift 默认的情况下为28        this.segmentShift = 32 - sshift;        //segmentMask 默认情况下为15,segmentMask的各个二进制位都为1,目的是之后可以通过key的hash值与这个值做&运算确定Segment的索引。        this.segmentMask = ssize - 1;        //4 检查给的容量值是否大于允许的最大容量,如果大于MAXIMUM_CAPACITY,就设置为该值。initialCapacity默认值也为16。static final int MAXIMUM_CAPACITY = 1 << 30;        if (initialCapacity > MAXIMUM_CAPACITY)            initialCapacity = MAXIMUM_CAPACITY;        //5 计算每个Segment平均应该放置多少元素,这个值c是向上取整的值。比如初始容量initialCapacity=15,Segment数组的大小为16,Segment的个数为4,则每个Segment平均需要放置4个元素。        int c = initialCapacity / ssize;        if (c * ssize < initialCapacity)            ++c;        int cap = MIN_SEGMENT_TABLE_CAPACITY;        while (cap < c)            cap <<= 1;        //6 创建一个Segment的实例,将其当做Segment数组的第一个元素。        // create segments and segments[0],cap * loadFactor = 1.5,cap=2        Segment<K,V> s0 =            new Segment<K,V>(loadFactor, (int)(cap * loadFactor),                             (HashEntry<K,V>[])new HashEntry[cap]);        // ssize默认=16,表示Segment数组的大小        Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];        UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]        this.segments = ss;    }

segments数组的长度ssize是通过concurrencyLevel计算得出的。为了能通过按位与的散列算法来定位segments数组的索引,必须保证segments数组的长度是2的N次方(power-of-two size),所以必须计算出一个大于或等于concurrencyLevel的最小的2的N次方值来作为segments数组的长度。假如concurrencyLevel等于14、15或16,ssize都会等于16,即容器里锁的个数也是16。concurrencyLevel的最大值是65535,这意味着segments数组的长度最大为65536,对应的二进制是16位。
segmentShift和segmentMask这两个全局变量需要在定位segment时的散列算法里使用,sshift等于ssize从1向左移位的次数,在默认情况下concurrencyLevel等于16,1需要向左移位移动4次,所以sshift等于4。segmentShift用于定位参与散列运算的位数,segmentShift等于32减sshift,所以等于28,这里之所以用32是因为ConcurrentHashMap里的hash()方法输出的最大数是32位的,后面的测试中我们可以看到这点。segmentMask是散列运算的掩码,等于ssize减1,即15,掩码的二进制各个位的值都是1。因为ssize的最大长度是65536,所以segmentShift最大值是16,segmentMask最大值是65535,对应的二进制是16位,每个位都是1。
上面代码中的变量cap就是segment里HashEntry数组的长度,它等于initialCapacity除以ssize的倍数c,如果c大于1,就会取大于等于c的2的N次方值,所以cap不是1,就是2的N次方。segment的容量threshold=(int)cap*loadFactor,默认情况下initialCapacity等于16,loadfactor等于
0.75,通过运算cap等于1,threshold等于零.

定位segment

既然ConcurrentHashMap使用分段锁Segment来保护不同段的数据,那么在插入和获取元素的时候,必须先通过散列算法定位到Segment。下面让我们来看看ConcurrentHashMap的hash算法

  /**     * Applies a supplemental hash function to a given hashCode, which     * defends against poor quality hash functions.  This is critical     * because ConcurrentHashMap uses power-of-two length hash tables,     * that otherwise encounter collisions for hashCodes that do not     * differ in lower or upper bits.     */    private int hash(Object k) {        int h = hashSeed;        if ((0 != h) && (k instanceof String)) {            return sun.misc.Hashing.stringHash32((String) k);        }        h ^= k.hashCode();        // Spread bits to regularize both segment and index locations,        // using variant of single-word Wang/Jenkins hash.        h += (h <<  15) ^ 0xffffcd7d;        h ^= (h >>> 10);        h += (h <<   3);        h ^= (h >>>  6);        h += (h <<   2) + (h << 14);        return h ^ (h >>> 16);    }

之所以进行再散列,目的是减少散列冲突,使元素能够均匀地分布在不同的Segment上,从而提高容器的存取效率。通过上的hash算法已经得到哈希值,那么我们进一步看下如何通过这个hash值来定位segment

  /**     * Get the segment for the given hash     */    @SuppressWarnings("unchecked")    private Segment<K,V> segmentForHash(int h) {        long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;        return (Segment<K,V>) UNSAFE.getObjectVolatile(segments, u);    }

默认情况下segmentShift为28,segmentMask为15,再散列后的数最大是32位二进制数据,向右无符号移动28位,意思是让高4位参与到散列运算中,(hash>>>segmentShift)&segmentMask的运算结果分别是4、15、7和8,可以看到散列值没有发生冲突。

ConcurrentHashMap的操作

get

Segment的get操作实现非常简单和高效。先经过一次再散列,然后使用这个散列值通过散列运算定位到Segment,再通过散列算法定位到元素,代码如下

   /**     * Returns the value to which the specified key is mapped,     * or {@code null} if this map contains no mapping for the key.     *     * <p>More formally, if this map contains a mapping from a key     * {@code k} to a value {@code v} such that {@code key.equals(k)},     * then this method returns {@code v}; otherwise it returns     * {@code null}.  (There can be at most one such mapping.)     *     * @throws NullPointerException if the specified key is null     */    public V get(Object key) {        Segment<K,V> s; // manually integrate access methods to reduce overhead        HashEntry<K,V>[] tab;        //1 和put操作一样,先通过key进行两次hash确定取哪个segment中的数据        int h = hash(key);        //2 使用UNSAFE方法获取对应的Segment,然后再进行一次&运算得到HashEntry链表的位置,然后从链表头开始遍历整个链表。        //(由于hash会碰撞,所以用一个链表保存),如果找到对应的key,则返回对应的value值,如果链表遍历完都没有找到对应的key,        // 则说明map中不包含该key,返回null        long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;        if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&            (tab = s.table) != null) {            for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile                     (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);                 e != null; e = e.next) {                K k;                // 如果发生哈希碰撞,判断key是否相同                if ((k = e.key) == key || (e.hash == h && key.equals(k)))                    return e.value;            }        }        return null;    }

get操作的高效之处在于整个get过程不需要加锁,除非读到的值是空才会加锁重读。我们知道HashTable容器的get方法是需要加锁的,那么ConcurrentHashMap的get操作是如何做到不加锁的呢?原因是它的get方法里将要使用的共享变量都定义成volatile类型,如用于统计当前Segement大小的count字段和用于存储值的HashEntry的value。定义成volatile的变量,能够在线程之间保持可见性,能够被多线程同时读,并且保证不会读到过期的值,但是只能被单线程写(有一种情况可以被多线程写,就是写入的值不依赖于原值),在get操作里只需要读不需要写共享变量count和value,所以可以不用加锁。之所以不会读到过期的值,是因为根据Java内存模型的happen before原则,对volatile字段的写入操作先于读操作,即使两个线程同时修改和获取volatile变量,get操作也能拿到最新的值,这是用volatile替换锁的经典应用场景。

put

由于put方法里需要对共享变量进行写入操作,所以为了线程安全,在操作共享变量时必须加锁。put方法首先定位到Segment,然后在Segment里进行插入操作。插入操作需要经历两个步骤,第一步判断是否需要对Segment里的HashEntry数组进行扩容,第二步定位添加元素的位置,然后将其放在HashEntry数组里。
(1)是否需要扩容
在插入元素前会先判断Segment里的HashEntry数组是否超过容量(threshold),如果超过阈值,则对数组进行扩容。值得一提的是,Segment的扩容判断比HashMap更恰当,因为HashMap是在插入元素后判断元素是否已经到达容量的,如果到达了就进行扩容,但是很有可能扩容之后没有新元素插入,这时HashMap就进行了一次无效的扩容。
(2)如何扩容
在扩容的时候,首先会创建一个容量是原来容量两倍的数组,然后将原数组里的元素进行再散列后插入到新的数组里。为了高效,ConcurrentHashMap不会对整个容器进行扩容,而只对某个segment进行扩容。
我们看下源码

  /**     * Maps the specified key to the specified value in this table.     * Neither the key nor the value can be null.     *     * <p> The value can be retrieved by calling the <tt>get</tt> method     * with a key that is equal to the original key.     *     * @param key key with which the specified value is to be associated     * @param value value to be associated with the specified key     * @return the previous value associated with <tt>key</tt>, or     *         <tt>null</tt> if there was no mapping for <tt>key</tt>     * @throws NullPointerException if the specified key or value is null     */    @SuppressWarnings("unchecked")    public V put(K key, V value) {        Segment<K,V> s;        //1.value值不能为空        if (value == null)            throw new NullPointerException();        //2.key通过一次hash运算得到一个hash值。(这个hash运算下文详说)        int hash = hash(key);        //3.将得到的hash值向右按位移动segmentShift位,然后再与segmentMask做&运算得到Segment的索引        //在初始化的时候,segmentShift的值等于32-sshift,例如concurrencyLevel等于16,则sshift等于4,那么segmentShift为28。        //hash值是一个32位的整数,将其向右移动28就变成这个样子:0000 0000 0000 0000 0000 0000 0000 XXXX,然后再用这个值与segmentMask        //做&运算,也就是说取最后四位的值。这个值确定Segment的索引。        int j = (hash >>> segmentShift) & segmentMask;        //4.使用UNSAFE的方式从Segment数组中获取该索引对应的Segment对象        if ((s = (Segment<K,V>)UNSAFE.getObject          // nonvolatile; recheck             (segments, (j << SSHIFT) + SBASE)) == null) //  in ensureSegment            s = ensureSegment(j);        //5.向这个Segment对象中put值,这个put操作也是一样的步骤        return s.put(key, hash, value, false);    }

下面是扩容的代码

  /**     * Returns the segment for the given index, creating it and     * recording in segment table (via CAS) if not already present.     *     * @param k the index     * @return the segment     */    @SuppressWarnings("unchecked")    private Segment<K,V> ensureSegment(int k) {        final Segment<K,V>[] ss = this.segments;        long u = (k << SSHIFT) + SBASE; // raw offset        Segment<K,V> seg;        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {            Segment<K,V> proto = ss[0]; // use segment 0 as prototype            int cap = proto.table.length;            float lf = proto.loadFactor;            int threshold = (int)(cap * lf);            HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];            if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))                == null) { // recheck                Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);                while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))                       == null) {                    if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))                        break;                }            }        }        return seg;    }

size

如果要统计整个ConcurrentHashMap里元素的大小,就必须统计所有Segment里元素的大小后求和。Segment里的全局变量count是一个volatile变量,那么在多线程场景下,是不是直接把所有Segment的count相加就可以得到整个ConcurrentHashMap大小了呢?不是的,虽然相加可以获取每个Segment的count的最新值,但是可能累加前使用的count发生了变化,那么统计结果就不准了。所以,最安全的做法是在统计size的时候把所有Segment的put、remove和clean方法全部锁住,但是这种做法显然非常低效。因为在累加count操作过程中,之前累加过的count发生变化的几率非常小,所以ConcurrentHashMap的做法是先尝试2次通过不锁住Segment的方式来统计各个Segment大小,如果统计的过程中,容器的count发生了变化,则再采用加锁的方式来统计所有Segment的大小。那么ConcurrentHashMap是如何判断在统计的时候容器是否发生了变化呢?使用modCount变量,在put、remove和clean方法里操作元素前都会将变量modCount进行加1,那么在统计size前后比较modCount是否发生变化,从而得知容器的大小是否发生变化。
源码如下

 /**     * Returns the number of key-value mappings in this map.  If the     * map contains more than <tt>Integer.MAX_VALUE</tt> elements, returns     * <tt>Integer.MAX_VALUE</tt>.     *     * @return the number of key-value mappings in this map     *      *1. size 操作和put与get的区别在于,size操作需要遍历所有的segment才能算出整个map的大小,而put和get操作只需要关心一个segment;     *2. 假设我们当前遍历的Segment为SA,那么在遍历SA过程中,其他的Segment比如SB可能会被修改,那么这一次计算出来的size值并不是Map的当前真正大小。     *     所以一个比较简单的办法是就是计算Map大小的时候所有的segment都lock住,不能更新数据(put 和 remove,计算完之后unlock;     *     *3. 作者Doug Lea 想出一个更好的idea:先给3次机会(retries初始化为-1,一直重试到RETRIES_BEFORE_LOCK值为2 ,不锁定lock所有Segment;     *   遍历所有的segment,累加各个segment的大小得到整个Map的大小。     *        *4.如果某相邻的2次计算获取的所有Segment的所有更新次数(每个Segment都有一个变量modCount变量,这个变量在Segment的Entry被修改的时候会加1     * 通过这个值可以得到每个Segment的更新操作的次数)是一样的,说明在计算的过程中没有更新操作,直接结束循环,返回当前的size;     *      * 5. 如果重试3次计算的结果中,Map的更新次数和前一次不一致,则之后的计算先对所有的Segment加锁,遍历所有segment计算map的大小,最后当重试计算>3次后再解锁所有的     * segment。         *      * 6.例子:     *      * 假如一个Map有4个segment,标记S1,S2,S3,S4,现在我们要获取Map的Size;     * 计算过程是这样的:     *                 第一次计算不对segment S1,S2,S3,S4加锁,遍历所有的segment,假设这次每个segment的大小变成了1,2,3,4;更新次数分别为2,2,3,1;则这次计算可以得到Map的总大小为1+2+3+4=10,总更新次数modCount=2+2+3+1=8;     *                 第二次计算,不对S1,S2,S3,S4加锁,遍历所有的Segment,假设这次每个segment的大小变成了2,2,3,4;更新次数变为了3,2,3,1; 则Map的size=2+2+3+4=11;modCount=9     *                 那么第一次和第二次计算得到的更新次数不一致,第一次是8,第二次是9;则可以判定这段时间Map的数据被更新;因此必须进行第3次重试计算;     *                 第三次计算,不对S1,S2,S3,S4加锁,遍历所有的Segment,假设每个Segment的更新次数还是为3,2,3,1;则因为第2次计算和第3次计算的得到的Map的modCount次数是一致的,则说明这段时间内第2次和第3次这段时间内Map的数据没有被更新     *                 此时可以返回第3次计算的Map大小;最坏的情况:第3次计算得到的计算结果和第2次不一致,则只能先对所有的Segment加锁再计算,最后解锁。     */    public int size() {        // Try a few times to get accurate count. On failure due to        // continuous async changes in table, resort to locking.        final Segment<K,V>[] segments = this.segments;        int size;        boolean overflow; // true if size overflows 32 bits        long sum;         // sum of modCounts        long last = 0L;   // previous sum        int retries = -1; // first iteration isn't retry        try {            for (;;) {                // 如果重试次数为3次,锁定segment                if (retries++ == RETRIES_BEFORE_LOCK) {                    for (int j = 0; j < segments.length; ++j)                        ensureSegment(j).lock(); // force creation                }                sum = 0L;                size = 0;                overflow = false;                for (int j = 0; j < segments.length; ++j) {                    //遍历所有的Segment                    Segment<K,V> seg = segmentAt(segments, j);                    if (seg != null) {                        //累加修改的次数                        sum += seg.modCount;                        //c代表segment的                        int c = seg.count;                        if (c < 0 || (size += c) < 0)                            overflow = true;                    }                }                //如果和前一次计算的Map的size一致,结束循环,返回最终的size值                if (sum == last)                    break;                last = sum;            }        } finally {            // 如果重试次数>3次则,释放segment锁            if (retries > RETRIES_BEFORE_LOCK) {                for (int j = 0; j < segments.length; ++j)                    segmentAt(segments, j).unlock();            }        }        return overflow ? Integer.MAX_VALUE : size;    }

containsValue

 /**     * Returns <tt>true</tt> if this map maps one or more keys to the     * specified value. Note: This method requires a full internal     * traversal of the hash table, and so is much slower than     * method <tt>containsKey</tt>.     *     * @param value value whose presence in this map is to be tested     * @return <tt>true</tt> if this map maps one or more keys to the     *         specified value     * @throws NullPointerException if the specified value is null     */    public boolean containsValue(Object value) {        // Same idea as size()        if (value == null)            throw new NullPointerException();        final Segment<K,V>[] segments = this.segments;        boolean found = false;        long last = 0;        int retries = -1;        try {            outer: for (;;) {                //重试3次,计算size后才给所有segment加锁,计算Map的size                if (retries++ == RETRIES_BEFORE_LOCK) {                    for (int j = 0; j < segments.length; ++j)                        ensureSegment(j).lock(); // force creation                }                long hashSum = 0L;                int sum = 0;                for (int j = 0; j < segments.length; ++j) {                    HashEntry<K,V>[] tab;                    //遍历所有的Segment                    Segment<K,V> seg = segmentAt(segments, j);                    if (seg != null && (tab = seg.table) != null) {                        //遍历每个Segment里面的HashEntry                        for (int i = 0 ; i < tab.length; i++) {                            HashEntry<K,V> e;                            for (e = entryAt(tab, i); e != null; e = e.next) {                                //获取value值,并且与入参value进行比较                                V v = e.value;                                //相同返回,found=true,退出循环                                if (v != null && value.equals(v)) {                                    found = true;                                    break outer;                                }                            }                        }                        //累加各个segment的更新次数                        sum += seg.modCount;                    }                }                //前一次计算的更新次数modCount和当前计算的segment的更新次数进行比较,相同,退出循环,返回found = true                if (retries > 0 && sum == last)                    break;                last = sum;            }        } finally {            //重试计算次数>3次后,释放segment锁            if (retries > RETRIES_BEFORE_LOCK) {                for (int j = 0; j < segments.length; ++j)                    segmentAt(segments, j).unlock();            }        }        return found;    }

参考文章:
http://www.infoq.com/cn/articles/ConcurrentHashMap
http://www.importnew.com/21781.html
参考书籍:
《Java并发编程》、JDK7源码