C# ConcurrentDictionary实现

来源:互联网 发布:中国移动java游戏 编辑:程序博客网 时间:2024/06/17 02:34

ConcurrentDictionary的源码我看了很多遍，今天抽点时间整理一下。它的实现比Dictionary复杂很多，但线程安全的思路其实比较简单，用的是lock（分段锁）的思想：写操作（添加、修改、删除）只锁住key所在哈希桶对应的那一个锁对象，读操作（TryGetValue）则完全不加锁。首先我们来看看它的源码（以下为节选，省略了部分字段和方法）。

 /// <summary>
 /// Thread-safe hash map (excerpt). Reads (<see cref="TryGetValue"/>) take no lock; every write
 /// locks only the one lock object that guards the key's bucket (lock striping: the lock number
 /// is bucket number % lock count, see <see cref="GetBucketAndLockNo"/>).
 /// </summary>
 /// <remarks>
 /// NOTE(excerpt): this is a trimmed excerpt and does not compile on its own. The following names
 /// are referenced below but are not declared in this listing: the fields m_tables, m_budget,
 /// m_growLockArray and m_keyRehashCount; the helpers ReleaseLocks, Assert, GetResource and
 /// IsValueWriteAtomic; PlatformHelper.ProcessorCount; and Array.MaxArrayLength
 /// (NOTE(review): not a public System.Array member - presumably only accessible inside the
 /// framework itself). The members required by IDictionary&lt;TKey, TValue&gt;, IDictionary and
 /// IReadOnlyDictionary&lt;TKey, TValue&gt; are omitted as well.
 /// </remarks>
 public class ConcurrentDictionary<TKey, TValue> : IDictionary<TKey, TValue>, IDictionary, IReadOnlyDictionary<TKey, TValue>
 {
     /// <summary>
     /// Tables that hold the internal state of the ConcurrentDictionary
     ///
     /// Wrapping the three tables in a single object allows us to atomically
     /// replace all tables at once.
     /// </summary>
     private class Tables
     {
         internal readonly Node[] m_buckets; // A singly-linked list for each bucket.
         internal readonly object[] m_locks; // A set of locks, each guarding a section of the table.
         internal volatile int[] m_countPerLock; // The number of elements guarded by each lock.
         internal readonly IEqualityComparer<TKey> m_comparer; // Key equality comparer

         internal Tables(Node[] buckets, object[] locks, int[] countPerLock, IEqualityComparer<TKey> comparer)
         {
             m_buckets = buckets;
             m_locks = locks;
             m_countPerLock = countPerLock;
             m_comparer = comparer;
         }
     }

     // Locks per processor used by the parameterless constructor (see DefaultConcurrencyLevel).
     private const int DEFAULT_CONCURRENCY_MULTIPLIER = 4;
     // Initial bucket count used by the parameterless constructor.
     private const int DEFAULT_CAPACITY = 31;
     // GrowTable stops doubling the lock array once it has reached this many locks.
     private const int MAX_LOCK_NUMBER = 1024;

     // Whether TValue is a type that can be written atomically (i.e., with no danger of torn reads)
     private static readonly bool s_isValueWriteAtomic = IsValueWriteAtomic();

     /// <summary>
     /// Creates an empty dictionary with DefaultConcurrencyLevel locks and DEFAULT_CAPACITY buckets.
     /// This is the only constructor shown here that lets the lock array grow on resize.
     /// </summary>
     public ConcurrentDictionary() : this(DefaultConcurrencyLevel, DEFAULT_CAPACITY, true, EqualityComparer<TKey>.Default) { }

     /// <summary>
     /// Creates an empty dictionary with a fixed number of locks (the lock array never grows).
     /// </summary>
     public ConcurrentDictionary(int concurrencyLevel, int capacity) : this(concurrencyLevel, capacity, false, EqualityComparer<TKey>.Default) { }

     /// <summary>
     /// Creates an empty dictionary with a fixed number of locks and a custom key comparer.
     /// </summary>
     public ConcurrentDictionary(int concurrencyLevel, int capacity, IEqualityComparer<TKey> comparer) : this(concurrencyLevel, capacity, false, comparer) { }

     /// <summary>
     /// Shared constructor: allocates <paramref name="concurrencyLevel"/> lock objects and at least
     /// that many buckets, then sets the per-lock resize budget.
     /// </summary>
     /// <param name="concurrencyLevel">Number of lock objects; must be at least 1.</param>
     /// <param name="capacity">Initial bucket count; raised to concurrencyLevel if smaller.</param>
     /// <param name="growLockArray">Whether GrowTable may also double the lock array.</param>
     /// <param name="comparer">Key comparer; must not be null.</param>
     /// <exception cref="ArgumentOutOfRangeException">concurrencyLevel &lt; 1 or capacity &lt; 0.</exception>
     /// <exception cref="ArgumentNullException">comparer is null.</exception>
     internal ConcurrentDictionary(int concurrencyLevel, int capacity, bool growLockArray, IEqualityComparer<TKey> comparer)
     {
         if (concurrencyLevel < 1)
         {
             throw new ArgumentOutOfRangeException("concurrencyLevel", GetResource("ConcurrentDictionary_ConcurrencyLevelMustBePositive"));
         }
         if (capacity < 0)
         {
             throw new ArgumentOutOfRangeException("capacity", GetResource("ConcurrentDictionary_CapacityMustNotBeNegative"));
         }
         if (comparer == null) throw new ArgumentNullException("comparer");

         // The capacity should be at least as large as the concurrency level. Otherwise, we would have locks that don't guard
         // any buckets.
         if (capacity < concurrencyLevel)
         {
             capacity = concurrencyLevel;
         }

         object[] locks = new object[concurrencyLevel];
         for (int i = 0; i < locks.Length; i++)
         {
             locks[i] = new object();
         }

         int[] countPerLock = new int[locks.Length];
         Node[] buckets = new Node[capacity];
         m_tables = new Tables(buckets, locks, countPerLock, comparer);

         m_growLockArray = growLockArray;
         // Maximum number of elements per lock before an insert requests a resize (>= 1 because capacity >= concurrencyLevel).
         m_budget = buckets.Length / locks.Length;
     }

     /// <summary>
     /// Gets or sets the value for <paramref name="key"/>.
     /// The getter throws <see cref="KeyNotFoundException"/> if the key is absent; the setter adds the
     /// pair, or overwrites the existing value, and throws <see cref="ArgumentNullException"/> for a null key.
     /// </summary>
     public TValue this[TKey key]
     {
         get
         {
             TValue value;
             if (!TryGetValue(key, out value))
             {
                 throw new KeyNotFoundException();
             }
             return value;
         }
         set
         {
             if (key == null) throw new ArgumentNullException("key");
             TValue dummy;
             // updateIfExists: true (overwrite), acquireLock: true.
             TryAddInternal(key, value, true, true, out dummy);
         }
     }

     /// <summary>
     /// Looks up <paramref name="key"/> without taking any lock.
     /// </summary>
     /// <returns>true and the stored value if found; otherwise false and default(TValue).</returns>
     /// <exception cref="ArgumentNullException">key is null.</exception>
     public bool TryGetValue(TKey key, out TValue value)
     {
         if (key == null) throw new ArgumentNullException("key");

         int bucketNo, lockNoUnused;

         // We must capture the m_tables field in a local variable. It is set to a new table on each table resize,
         // so this lookup must use one consistent snapshot of buckets, locks and comparer.
         Tables tables = m_tables;
         IEqualityComparer<TKey> comparer = tables.m_comparer;
         GetBucketAndLockNo(comparer.GetHashCode(key), out bucketNo, out lockNoUnused, tables.m_buckets.Length, tables.m_locks.Length);

         // Read the bucket head with a barrier; writers publish new heads with Volatile.Write.
         Node n = Volatile.Read<Node>(ref tables.m_buckets[bucketNo]);

         // Walk the bucket's singly linked chain.
         while (n != null)
         {
             if (comparer.Equals(n.m_key, key))
             {
                 value = n.m_value;
                 return true;
             }
             n = n.m_next;
         }

         value = default(TValue);
         return false;
     }

     /// <summary>
     /// Inserts the pair, or (when <paramref name="updateIfExists"/> is true) updates the existing value,
     /// while holding the lock that guards the key's bucket. A resize is requested after the lock
     /// is released if that lock's element count exceeds m_budget.
     /// </summary>
     /// <param name="acquireLock">When false the caller is expected to already hold the required lock(s).</param>
     /// <param name="resultingValue">The value now stored for the key: <paramref name="value"/> after an insert
     /// or update, or the pre-existing value when the key existed and updateIfExists is false.</param>
     /// <returns>true if a new pair was inserted; false if the key already existed.</returns>
     private bool TryAddInternal(TKey key, TValue value, bool updateIfExists, bool acquireLock, out TValue resultingValue)
     {
         while (true)
         {
             int bucketNo, lockNo;
             int hashcode;

             Tables tables = m_tables;
             IEqualityComparer<TKey> comparer = tables.m_comparer;
             hashcode = comparer.GetHashCode(key);
             GetBucketAndLockNo(hashcode, out bucketNo, out lockNo, tables.m_buckets.Length, tables.m_locks.Length);

             bool resizeDesired = false;
             bool lockTaken = false;
             try
             {
                 if (acquireLock)
                     Monitor.Enter(tables.m_locks[lockNo], ref lockTaken);

                 // If the table just got resized, we may not be holding the right lock, and must retry.
                 // This should be a rare occurrence. (The finally block below releases the lock before retrying.)
                 if (tables != m_tables)
                 {
                     continue;
                 }

                 // Try to find this key in the bucket
                 Node prev = null;
                 for (Node node = tables.m_buckets[bucketNo]; node != null; node = node.m_next)
                 {
                     Assert((prev == null && node == tables.m_buckets[bucketNo]) || prev.m_next == node);
                     if (comparer.Equals(node.m_key, key))
                     {
                         // The key was found in the dictionary. If updates are allowed, update the value for that key.
                         // We need to create a new node for the update, in order to support TValue types that cannot
                         // be written atomically, since lock-free reads may be happening concurrently.
                         if (updateIfExists)
                         {
                             if (s_isValueWriteAtomic)
                             {
                                 node.m_value = value;
                             }
                             else
                             {
                                 // Replace the node so readers see either the old node or the new one, never a torn value.
                                 Node newNode = new Node(node.m_key, value, hashcode, node.m_next);
                                 if (prev == null)
                                 {
                                     tables.m_buckets[bucketNo] = newNode;
                                 }
                                 else
                                 {
                                     prev.m_next = newNode;
                                 }
                             }
                             resultingValue = value;
                         }
                         else
                         {
                             resultingValue = node.m_value;
                         }
                         return false;
                     }
                     prev = node;
                 }

                 // The key was not found in the bucket. Insert the key-value pair at the head of the chain
                 // (the old head becomes the new node's m_next).
                 Volatile.Write<Node>(ref tables.m_buckets[bucketNo], new Node(key, value, hashcode, tables.m_buckets[bucketNo]));
                 checked
                 {
                     tables.m_countPerLock[lockNo]++;
                 }

                 // If this lock now guards more elements than the budget allows, request a resize.
                 // The resize itself runs only after the lock has been released (see below).
                 if (tables.m_countPerLock[lockNo] > m_budget)
                 {
                     resizeDesired = true;
                 }
             }
             finally
             {
                 if (lockTaken)
                     Monitor.Exit(tables.m_locks[lockNo]);
             }

             if (resizeDesired)
             {
                 GrowTable(tables, tables.m_comparer, false, m_keyRehashCount);
             }

             resultingValue = value;
             return true;
         }
     }

     /// <summary>
     /// Removes <paramref name="key"/> if present.
     /// </summary>
     /// <returns>true and the removed value if the key was found; otherwise false and default(TValue).</returns>
     /// <exception cref="ArgumentNullException">key is null.</exception>
     public bool TryRemove(TKey key, out TValue value)
     {
         if (key == null) throw new ArgumentNullException("key");
         return TryRemoveInternal(key, out value, false, default(TValue));
     }

     /// <summary>
     /// Unlinks <paramref name="key"/> from its bucket while holding the bucket's lock.
     /// When <paramref name="matchValue"/> is true, the node is removed only if its value equals
     /// <paramref name="oldValue"/> under EqualityComparer&lt;TValue&gt;.Default.
     /// </summary>
     /// <returns>true and the removed value on removal; otherwise false and default(TValue).</returns>
     private bool TryRemoveInternal(TKey key, out TValue value, bool matchValue, TValue oldValue)
     {
         while (true)
         {
             Tables tables = m_tables;

             IEqualityComparer<TKey> comparer = tables.m_comparer;

             int bucketNo, lockNo;
             GetBucketAndLockNo(comparer.GetHashCode(key), out bucketNo, out lockNo, tables.m_buckets.Length, tables.m_locks.Length);

             lock (tables.m_locks[lockNo])
             {
                 // If the table just got resized, we may not be holding the right lock, and must retry.
                 // This should be a rare occurrence.
                 if (tables != m_tables)
                 {
                     continue;
                 }

                 Node prev = null;
                 for (Node curr = tables.m_buckets[bucketNo]; curr != null; curr = curr.m_next)
                 {
                     Assert((prev == null && curr == tables.m_buckets[bucketNo]) || prev.m_next == curr);

                     if (comparer.Equals(curr.m_key, key))
                     {
                         if (matchValue)
                         {
                             bool valuesMatch = EqualityComparer<TValue>.Default.Equals(oldValue, curr.m_value);
                             if (!valuesMatch)
                             {
                                 value = default(TValue);
                                 return false;
                             }
                         }

                         // Unlink: replace the bucket head, or bypass the node from its predecessor.
                         if (prev == null)
                         {
                             Volatile.Write<Node>(ref tables.m_buckets[bucketNo], curr.m_next);
                         }
                         else
                         {
                             prev.m_next = curr.m_next;
                         }

                         value = curr.m_value;
                         tables.m_countPerLock[lockNo]--;
                         return true;
                     }
                     prev = curr;
                 }
             }

             value = default(TValue);
             return false;
         }
     }

     /// <summary>
     /// Replaces m_tables with a larger bucket array (and, when m_growLockArray is set, a larger lock array),
     /// or with the same size rehashed under <paramref name="newComparer"/> when
     /// <paramref name="regenerateHashKeys"/> is true. If the table is still less than a quarter full,
     /// only m_budget is doubled and no resize happens.
     /// </summary>
     /// <param name="tables">The table snapshot the caller observed; a stale snapshot makes this a no-op.</param>
     /// <param name="newComparer">Comparer stored in the new Tables instance.</param>
     /// <param name="regenerateHashKeys">Recompute every hash with newComparer instead of reusing Node.m_hashcode.</param>
     /// <param name="rehashCount">The m_keyRehashCount value the caller observed.</param>
     private void GrowTable(Tables tables, IEqualityComparer<TKey> newComparer, bool regenerateHashKeys, int rehashCount)
     {
         int locksAcquired = 0;
         try
         {
             // Take only lock 0 first; every GrowTable call starts here, so concurrent resizes are serialized on it.
             AcquireLocks(0, 1, ref locksAcquired);

             if (regenerateHashKeys && rehashCount == m_keyRehashCount)
             {
                 // No other rehash happened since the caller looked: work on the current tables.
                 tables = m_tables;
             }
             else
             {
                 // Another thread already replaced the tables: nothing to do.
                 if (tables != m_tables)
                 {
                     return;
                 }

                 long approxCount = 0;
                 for (int i = 0; i < tables.m_countPerLock.Length; i++)
                 {
                     approxCount += tables.m_countPerLock[i];
                 }

                 // The bucket array is still mostly empty: raise the per-lock budget instead of resizing.
                 if (approxCount < tables.m_buckets.Length / 4)
                 {
                     m_budget = 2 * m_budget;
                     // Guard against the doubling overflowing to a negative value.
                     if (m_budget < 0)
                     {
                         m_budget = int.MaxValue;
                     }
                     return;
                 }
             }

             // New length: 2n + 1, bumped by 2 until it is not divisible by 3, 5 or 7 (stays odd).
             int newLength = 0;
             bool maximizeTableSize = false;
             try
             {
                 checked
                 {
                     newLength = tables.m_buckets.Length * 2 + 1;
                     while (newLength % 3 == 0 || newLength % 5 == 0 || newLength % 7 == 0)
                     {
                         newLength += 2;
                     }
                     Assert(newLength % 2 != 0);
                     if (newLength > Array.MaxArrayLength)
                     {
                         maximizeTableSize = true;
                     }
                 }
             }
             catch (OverflowException)
             {
                 maximizeTableSize = true;
             }

             if (maximizeTableSize)
             {
                 newLength = Array.MaxArrayLength;
                 // The table cannot grow any further, so never request another resize.
                 m_budget = int.MaxValue;
             }

             // Now acquire all other locks for the table
             AcquireLocks(1, tables.m_locks.Length, ref locksAcquired);

             object[] newLocks = tables.m_locks;

             // Add more locks: double the lock array, keeping the existing lock objects at the same indices.
             if (m_growLockArray && tables.m_locks.Length < MAX_LOCK_NUMBER)
             {
                 newLocks = new object[tables.m_locks.Length * 2];
                 Array.Copy(tables.m_locks, newLocks, tables.m_locks.Length);
                 for (int i = tables.m_locks.Length; i < newLocks.Length; i++)
                 {
                     newLocks[i] = new object();
                 }
             }

             Node[] newBuckets = new Node[newLength];
             int[] newCountPerLock = new int[newLocks.Length];

             // Copy every node into the new bucket array. New Node objects are created rather than
             // relinking the old ones, so the old chains stay intact for any reader that is still
             // traversing the previous Tables instance (TryGetValue reads without a lock).
             for (int i = 0; i < tables.m_buckets.Length; i++)
             {
                 Node current = tables.m_buckets[i];
                 while (current != null)
                 {
                     Node next = current.m_next;
                     int newBucketNo, newLockNo;
                     int nodeHashCode = current.m_hashcode;

                     if (regenerateHashKeys)
                     {
                         // Recompute the hash from the key
                         nodeHashCode = newComparer.GetHashCode(current.m_key);
                     }

                     GetBucketAndLockNo(nodeHashCode, out newBucketNo, out newLockNo, newBuckets.Length, newLocks.Length);

                     newBuckets[newBucketNo] = new Node(current.m_key, current.m_value, nodeHashCode, newBuckets[newBucketNo]);

                     checked
                     {
                         newCountPerLock[newLockNo]++;
                     }

                     current = next;
                 }
             }

             // If this resize regenerated the hashkeys, increment the count
             if (regenerateHashKeys)
             {
                 // We use unchecked here because we don't want to throw an exception if
                 // an overflow happens
                 unchecked
                 {
                     m_keyRehashCount++;
                 }
             }

             // Adjust the budget
             m_budget = Math.Max(1, newBuckets.Length / newLocks.Length);

             // Replace tables with the new versions
             m_tables = new Tables(newBuckets, newLocks, newCountPerLock, newComparer);
         }
         finally
         {
             // Release all locks that we took earlier
             ReleaseLocks(0, locksAcquired);
         }
     }

     /// <summary>
     /// Acquires locks [fromInclusive, toExclusive) of the current m_tables.m_locks, in index order.
     /// <paramref name="locksAcquired"/> is incremented only for locks actually taken, so a caller's
     /// finally block knows exactly how many to release.
     /// </summary>
     private void AcquireLocks(int fromInclusive, int toExclusive, ref int locksAcquired)
     {
         Assert(fromInclusive <= toExclusive);
         object[] locks = m_tables.m_locks;

         for (int i = fromInclusive; i < toExclusive; i++)
         {
             bool lockTaken = false;
             try
             {
                 Monitor.Enter(locks[i], ref lockTaken);
             }
             finally
             {
                 if (lockTaken)
                 {
                     locksAcquired++;
                 }
             }
         }
     }

     /// <summary>
     /// Maps a hash code to a bucket index and the index of the lock that guards that bucket.
     /// The sign bit is cleared first so the modulo result is never negative.
     /// </summary>
     private void GetBucketAndLockNo(int hashcode, out int bucketNo, out int lockNo, int bucketCount, int lockCount)
     {
         bucketNo = (hashcode & 0x7fffffff) % bucketCount;
         lockNo = bucketNo % lockCount;

         Assert(bucketNo >= 0 && bucketNo < bucketCount);
         Assert(lockNo >= 0 && lockNo < lockCount);
     }

     /// <summary>
     /// Lock count used by the parameterless constructor: DEFAULT_CONCURRENCY_MULTIPLIER locks per processor.
     /// </summary>
     private static int DefaultConcurrencyLevel
     {
         get { return DEFAULT_CONCURRENCY_MULTIPLIER * PlatformHelper.ProcessorCount; }
     }

     /// <summary>
     /// One entry of a bucket's singly linked chain. m_hashcode caches the key's hash so GrowTable
     /// can redistribute nodes without calling the comparer (unless hash keys are regenerated).
     /// </summary>
     private class Node
     {
         internal TKey m_key;
         internal TValue m_value;
         internal volatile Node m_next;
         internal int m_hashcode;

         internal Node(TKey key, TValue value, int hashcode, Node next)
         {
             m_key = key;
             m_value = value;
             m_next = next;
             m_hashcode = hashcode;
         }
     }
 }

 /// <summary>
 /// Volatile read/write helpers used by the dictionary above, implemented with full memory barriers.
 /// </summary>
 public static class Volatile
 {
     /// <summary>
     /// Reads <paramref name="location"/>, then issues a full memory barrier, so memory accesses that
     /// follow the call cannot be reordered before this read.
     /// </summary>
     [ResourceExposure(ResourceScope.None)]
     [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)]
     [SecuritySafeCritical] //the intrinsic implementation of this method contains unverifiable code
     public static T Read<T>(ref T location) where T : class
     {
         var value = location;
         Thread.MemoryBarrier();
         return value;
     }

     /// <summary>
     /// Issues a full memory barrier, then stores <paramref name="value"/>, so memory accesses that
     /// precede the call (e.g. initializing a new Node) cannot be reordered after this store.
     /// </summary>
     [ResourceExposure(ResourceScope.None)]
     [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)]
     [SecuritySafeCritical] //the intrinsic implementation of this method contains unverifiable code
     public static void Write<T>(ref T location, T value) where T : class
     {
         Thread.MemoryBarrier();
         location = value;
     }
 }

 

ConcurrentDictionary的构造函数依然有int capacity参数，该参数控制初始哈希桶数组的大小【Node[] buckets = new Node[capacity] 和 m_tables = new Tables(buckets, locks, countPerLock, comparer);】，并且capacity至少会被提升到concurrencyLevel，保证每个锁至少守护一个哈希桶。同时构造函数中多了一个int concurrencyLevel参数，用来控制并发度，也就是锁对象的个数【object[] locks = new object[concurrencyLevel]; for (int i = 0; i < locks.Length; i++){ locks[i] = new object(); }】。在上面列出的构造函数中，只有无参构造函数传入growLockArray = true；显式指定了concurrencyLevel和capacity的构造函数传入的都是false【m_growLockArray = growLockArray;】，表示ConcurrentDictionary在扩容的时候，object[] locks这个锁对象数组不扩容。哈希桶变多而锁的个数不变，可以理解为锁的粒度变大了：比如先前4个哈希桶共用一个lock对象，扩容后可能8个哈希桶共用一个lock对象。m_budget = buckets.Length / locks.Length 在数值上等于每个锁平均守护的哈希桶数，它的作用是每个锁下允许的元素个数上限：当某个锁下的元素数（m_countPerLock[lockNo]）超过m_budget时，就会触发扩容。

现在我们来看看TryGetValue获取值。这个方法非常简单，因为读取时不需要加锁：首先根据key计算其哈希值，再找到对应的哈希桶，读取哈希桶的头节点【Node n = Volatile.Read<Node>(ref tables.m_buckets[bucketNo])】。一个哈希桶里的数据可能有多个【while (n != null){ if (comparer.Equals(n.m_key, key)){ value = n.m_value; return true; } n = n.m_next; }】，所以从这里可以看出来，每个哈希桶里面是一个Node单向链表。另外注意方法一开始先把m_tables读到局部变量tables中，因为每次扩容都会把m_tables替换成新的表，这样可以保证本次读取始终在同一份桶数组上进行。

接下来我们看看比较复杂的TryAddInternal方法。首先需要根据key来确定哈希桶和对应的锁编号。无论是添加还是修改都需要锁定对象，所以这里用的是Monitor.Enter(tables.m_locks[lockNo], ref lockTaken)，最后在finally中释放锁Monitor.Exit(tables.m_locks[lockNo])。拿到锁之后还要检查tables != m_tables：如果加锁之前表刚好被扩容替换了，手里的锁可能不是正确的那一把，需要continue重试。如果是添加元素，那么直接给对应的哈希桶赋值 Volatile.Write<Node>(ref tables.m_buckets[bucketNo], new Node(key, value, hashcode, tables.m_buckets[bucketNo]));注意Node的构造函数，原来的头节点tables.m_buckets[bucketNo]将是新节点的m_next值，也就是说添加的新节点永远是哈希桶链表的第一个节点。赋值后对应lock对象的计数器需要加1【tables.m_countPerLock[lockNo]++;】，如果该计数器超过阈值m_budget就需要扩容了【if (tables.m_countPerLock[lockNo] > m_budget){ resizeDesired = true;}】，而真正的扩容GrowTable是在释放锁之后才调用的。修改则首先要找到对应的node节点（即添加的key已经存在于所在哈希桶的链表中）。只有updateIfExists为true时才会修改，否则直接返回已有的值，方法返回false。如果TValue的写入是原子的，那么我们直接修改【if (s_isValueWriteAtomic) { node.m_value = value;}】；不是的话，我们就新建一个节点替换掉原先的节点【Node newNode = new Node(node.m_key, value, hashcode, node.m_next); if (prev == null){ tables.m_buckets[bucketNo] = newNode; } else{ prev.m_next = newNode;}】。这是因为读操作不加锁，如果直接改写一个不能原子写入的值，并发的读者可能读到只写了一半的值（torn read）。如果被替换的是桶的第一个节点，那么直接替换桶里的引用即可，否则就修改前一个节点的m_next属性。

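接下来我们来看看哈希桶的扩容GrowTable。这个方法比较复杂，我没有特别仔细地研读，这里只梳理主要流程。首先是多线程下需要考虑线程安全，说白了就是加锁：一开始只获取第0把锁AcquireLocks(0, 1, ref locksAcquired)，然后检查表是否已经被其他线程替换（如果是就直接返回），再粗略统计元素总数。如果元素数还不到桶数的1/4，说明桶数组还很空，此时并不扩容，只是把m_budget翻倍后返回。需要扩容时，哈希桶基本是按照2倍来扩容的【newLength = tables.m_buckets.Length * 2 + 1; while (newLength % 3 == 0 || newLength % 5 == 0 || newLength % 7 == 0){ newLength += 2; }】，即取2n+1，并跳过能被3、5、7整除的长度；如果超出数组最大长度或计算溢出，则直接取数组最大长度。在真正扩容前我们需要锁定其余所有锁对象【AcquireLocks(1, tables.m_locks.Length, ref locksAcquired);】。扩容时首先要扩容锁的对象数组，前提是m_growLockArray为true且锁的个数小于MAX_LOCK_NUMBER（1024）：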

 // Grow the lock array too, but only when growLockArray was requested at construction and
 // the lock count is still below MAX_LOCK_NUMBER: double its length, copy the existing lock
 // objects over unchanged (same indices), and create fresh objects for the new upper half.
 if (m_growLockArray && tables.m_locks.Length < MAX_LOCK_NUMBER)
 {
     newLocks = new object[tables.m_locks.Length * 2];
     Array.Copy(tables.m_locks, newLocks, tables.m_locks.Length);
     for (int i = tables.m_locks.Length; i < newLocks.Length; i++)
     {
         newLocks[i] = new object();
     }
 }

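然后是哈希桶扩容：遍历旧数组中每个桶的链表，用节点里保存的哈希值（如果regenerateHashKeys为true，则用新的comparer重新计算）按新的桶数和锁数重新计算桶号和锁号，再新建一个Node插入到新数组对应桶的链表头部【newBuckets[newBucketNo] = new Node(current.m_key, current.m_value, nodeHashCode, newBuckets[newBucketNo]);】，同时累加新的每锁计数。注意节点是被克隆到新数组中重新计算出的位置上，而不是旧的位置上。之所以新建节点而不是直接移动旧节点，是因为不加锁的读者可能还在遍历旧的表，旧链表必须保持不变。最后用新的数组一次性替换m_tables。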

 // Allocate the enlarged bucket array and fresh per-lock element counters.
 Node[] newBuckets = new Node[newLength];
 int[] newCountPerLock = new int[newLocks.Length];

 // Redistribute every node of the old table into the new bucket array.
 for (int i = 0; i < tables.m_buckets.Length; i++)
 {
     Node current = tables.m_buckets[i];
     while (current != null)
     {
         Node next = current.m_next;
         int newBucketNo, newLockNo;
         // Reuse the hash cached in the node unless the hash keys are being regenerated.
         int nodeHashCode = current.m_hashcode;

         if (regenerateHashKeys)
         {
             // Recompute the hash from the key
             nodeHashCode = newComparer.GetHashCode(current.m_key);
         }

         // Bucket/lock positions are recomputed against the NEW array and lock counts.
         GetBucketAndLockNo(nodeHashCode, out newBucketNo, out newLockNo, newBuckets.Length, newLocks.Length);

         // A new Node is prepended to the target bucket's chain instead of relinking `current`,
         // so the old chain is left untouched.
         newBuckets[newBucketNo] = new Node(current.m_key, current.m_value, nodeHashCode, newBuckets[newBucketNo]);

         checked
         {
             newCountPerLock[newLockNo]++;
         }

         current = next;
     }
 }

看完扩容，最后来看看移除元素。首先需要根据key来计算哈希桶和锁的位置【GetBucketAndLockNo(comparer.GetHashCode(key), out bucketNo, out lockNo, tables.m_buckets.Length, tables.m_locks.Length)】，然后锁住对应的对象【lock (tables.m_locks[lockNo])】，同样要检查表是否已被替换，若已替换则重试。接着遍历该哈希桶的链表查找对应的key：如果是桶的第一个节点，则直接写入 Volatile.Write<Node>(ref tables.m_buckets[bucketNo], curr.m_next)；否则修改链表 prev.m_next = curr.m_next。最后该lock对象的计数器需要减1【tables.m_countPerLock[lockNo]--】。

原创粉丝点击