C# CancellationTokenSource和CancellationToken的实现

来源:互联网 发布:sql unique key 编辑:程序博客网 时间:2024/06/05 19:11

微软关于CancellationTokenSource的介绍很简单,其实CancellationTokenSource的使用也很简单,但是实现就不是那么简单了,我们首先来看看CancellationTokenSource的实现:

/// <summary>
/// Signals cancellation to the tokens it hands out and runs the callbacks registered on them.
/// </summary>
/// <remarks>
/// NOTE(review): this is an excerpt. It references members that are not defined here
/// (IsDisposed, ThrowIfDisposed, CanBeCanceled, ThreadIDExecutingCallbacks, m_kernelEvent,
/// m_disposed, m_linkingRegistrations) and omits the IDisposable.Dispose implementation,
/// so it does not compile on its own.
/// </remarks>
public class CancellationTokenSource : IDisposable
{
    // Values stored in m_state. IsCancellationRequested is "m_state >= NOTIFYING" and
    // IsCancellationCompleted is "m_state == NOTIFYINGCOMPLETE".
    // CANNOT_BE_CANCELED is not assigned anywhere in this excerpt.
    private const int CANNOT_BE_CANCELED = 0;
    private const int NOT_CANCELED = 1;
    private const int NOTIFYING = 2;
    private const int NOTIFYINGCOMPLETE = 3;

    private volatile int m_state;

    // Cached delegate that CreateLinkedTokenSource registers on every source token.
    private static readonly Action<object> s_LinkedTokenCancelDelegate = new Action<object>(LinkedTokenCancelDelegate);

    // Number of callback lists: the processor count, capped at 24.
    // InternalRegister picks a list with ManagedThreadId % s_nLists.
    private static readonly int s_nLists = (PlatformHelper.ProcessorCount > 24) ? 24 : PlatformHelper.ProcessorCount;

    // The callback ExecuteCallbackHandlers is currently invoking; reset to null when it finishes.
    private volatile CancellationCallbackInfo m_executingCallback;

    // Lazily allocated array of s_nLists callback lists; each list is itself allocated on demand.
    private volatile SparselyPopulatedArray<CancellationCallbackInfo>[] m_registeredCallbacksLists;

    private static readonly TimerCallback s_timerCallback = new TimerCallback(TimerCallbackLogic);

    // Timer used by the delayed constructor and CancelAfter; disposed by NotifyCancellation.
    private volatile Timer m_timer;

    /// <summary>Creates a source that has not been canceled.</summary>
    public CancellationTokenSource()
    {
        m_state = NOT_CANCELED;
    }

    //Constructs a CancellationTokenSource that will be canceled after a specified time span.
    // -1 means "never fire"; anything below -1 is rejected.
    public CancellationTokenSource(Int32 millisecondsDelay)
    {
        if (millisecondsDelay < -1)
        {
            throw new ArgumentOutOfRangeException("millisecondsDelay");
        }
        InitializeWithTimer(millisecondsDelay);
    }

    // Starts a one-shot timer (period -1) that calls TimerCallbackLogic after the delay.
    private void InitializeWithTimer(Int32 millisecondsDelay)
    {
        m_state = NOT_CANCELED;
        m_timer = new Timer(s_timerCallback, this, millisecondsDelay, -1);
    }

    // Timer callback: cancels the source unless it has been disposed. An ObjectDisposedException
    // is swallowed only if the source was disposed concurrently with the timer firing.
    private static void TimerCallbackLogic(object obj)
    {
        CancellationTokenSource cts = (CancellationTokenSource)obj;
        if (!cts.IsDisposed)
        {
            try
            {
                cts.Cancel(); // will take care of disposing of m_timer
            }
            catch (ObjectDisposedException)
            {
                if (!cts.IsDisposed) throw;
            }
        }
    }

    /// <summary>Requests cancellation; callback exceptions are aggregated.</summary>
    public void Cancel()
    {
        Cancel(false);
    }

    /// <summary>
    /// Requests cancellation. When <paramref name="throwOnFirstException"/> is true, the first
    /// callback exception is rethrown immediately; otherwise all callback exceptions are
    /// collected and thrown together as an AggregateException.
    /// </summary>
    public void Cancel(bool throwOnFirstException)
    {
        ThrowIfDisposed();
        NotifyCancellation(throwOnFirstException);
    }

    /// <summary>
    /// Schedules cancellation after <paramref name="millisecondsDelay"/> ms (-1 disables the timer).
    /// Does nothing if cancellation has already been requested.
    /// </summary>
    public void CancelAfter(Int32 millisecondsDelay)
    {
        ThrowIfDisposed();
        if (millisecondsDelay < -1)
        {
            throw new ArgumentOutOfRangeException("millisecondsDelay");
        }
        if (IsCancellationRequested) return;
        if (m_timer == null)
        {
            // Create a dormant timer and publish it with a CAS; if another thread won the race,
            // dispose ours and use theirs.
            Timer newTimer = new Timer(s_timerCallback, this, -1, -1);
            if (Interlocked.CompareExchange(ref m_timer, newTimer, null) != null)
            {
                newTimer.Dispose();
            }
        }

        // It is possible that m_timer has already been disposed, so we must do
        // the following in a try/catch block.
        try
        {
            m_timer.Change(millisecondsDelay, -1);
        }
        catch (ObjectDisposedException)
        {
            // The timer was disposed (NotifyCancellation disposes it), so there is nothing to reschedule.
        }
    }

    // Moves m_state from NOT_CANCELED to NOTIFYING. Only the thread that wins that transition
    // disposes the timer, sets the kernel event and runs the callbacks.
    private void NotifyCancellation(bool throwOnFirstException)
    {
        if (IsCancellationRequested)
            return;

        // If we're the first to signal cancellation, do the main extra work.
        if (Interlocked.CompareExchange(ref m_state, NOTIFYING, NOT_CANCELED) == NOT_CANCELED)
        {
            Timer timer = m_timer;
            if(timer != null) timer.Dispose();

            //record the threadID being used for running the callbacks.
            ThreadIDExecutingCallbacks = Thread.CurrentThread.ManagedThreadId;

            //If the kernel event is null at this point, it will be set during lazy construction.
            if (m_kernelEvent != null)
                m_kernelEvent.Set(); // update the MRE value.

            ExecuteCallbackHandlers(throwOnFirstException);
            Contract.Assert(IsCancellationCompleted, "Expected cancellation to have finished");
        }
    }

    /// <summary>
    /// Invokes the registered callbacks synchronously. Each list is walked from its tail
    /// fragment backwards (via Prev), and each fragment from its last slot to its first.
    /// The callbacks of one list therefore run newest-fragment first.
    /// Always leaves m_state == NOTIFYINGCOMPLETE, even if a callback throws.
    /// </summary>
    private void ExecuteCallbackHandlers(bool throwOnFirstException)
    {
        Contract.Assert(IsCancellationRequested, "ExecuteCallbackHandlers should only be called after setting IsCancellationRequested->true");
        Contract.Assert(ThreadIDExecutingCallbacks != -1, "ThreadIDExecutingCallbacks should have been set.");

        List<Exception> exceptionList = null;
        SparselyPopulatedArray<CancellationCallbackInfo>[] callbackLists = m_registeredCallbacksLists;

        // Nothing was ever registered: just mark notification as complete.
        if (callbackLists == null)
        {
            Interlocked.Exchange(ref m_state, NOTIFYINGCOMPLETE);
            return;
        }

        try
        {
            for (int index = 0; index < callbackLists.Length; index++)
            {
                SparselyPopulatedArray<CancellationCallbackInfo> list = Volatile.Read<SparselyPopulatedArray<CancellationCallbackInfo>>(ref callbackLists[index]);
                if (list != null)
                {
                    SparselyPopulatedArrayFragment<CancellationCallbackInfo> currArrayFragment = list.Tail;
                    while (currArrayFragment != null)
                    {
                        for (int i = currArrayFragment.Length - 1; i >= 0; i--)
                        {
                            // Publish the slot's callback so ExecutingCallback observers can see it;
                            // empty (null) slots are skipped.
                            m_executingCallback = currArrayFragment[i];
                            if (m_executingCallback != null)
                            {
                                CancellationCallbackCoreWorkArguments args = new CancellationCallbackCoreWorkArguments(currArrayFragment, i);
                                try
                                {
                                    if (m_executingCallback.TargetSyncContext != null)
                                    {
                                        m_executingCallback.TargetSyncContext.Send(CancellationCallbackCoreWork_OnSyncContext, args);
                                        // CancellationCallbackCoreWork may have overwritten ThreadIDExecutingCallbacks
                                        // while running under the sync context; restore it to this thread.
                                        ThreadIDExecutingCallbacks = Thread.CurrentThread.ManagedThreadId;
                                    }
                                    else
                                    {
                                        CancellationCallbackCoreWork(args);
                                    }
                                }
                                catch(Exception ex)
                                {
                                    if (throwOnFirstException)
                                        throw;

                                    // Otherwise keep going and report every failure at the end.
                                    if(exceptionList == null)
                                        exceptionList = new List<Exception>();
                                    exceptionList.Add(ex);
                                }
                            }
                        }
                        currArrayFragment = currArrayFragment.Prev;
                    }
                }
            }
        }
        finally
        {
            m_state = NOTIFYINGCOMPLETE;
            m_executingCallback = null;
            Thread.MemoryBarrier(); // for safety, prevent reorderings crossing this point and seeing inconsistent state.
        }

        if (exceptionList != null)
        {
            Contract.Assert(exceptionList.Count > 0, "Expected exception count > 0");
            throw new AggregateException(exceptionList);
        }
    }

    // Adapter with the SendOrPostCallback shape, used for SynchronizationContext.Send.
    private void CancellationCallbackCoreWork_OnSyncContext(object obj)
    {
        CancellationCallbackCoreWork((CancellationCallbackCoreWorkArguments)obj);
    }

    // Atomically removes the callback from its slot and runs it only if this thread performed the
    // removal. If the slot no longer holds m_executingCallback (for example, it was deregistered
    // concurrently), the callback is not run.
    private void CancellationCallbackCoreWork(CancellationCallbackCoreWorkArguments args)
    {
        CancellationCallbackInfo callback = args.m_currArrayFragment.SafeAtomicRemove(args.m_currArrayIndex, m_executingCallback);
        if (callback == m_executingCallback)
        {
            if (callback.TargetExecutionContext != null)
            {
                callback.CancellationTokenSource.ThreadIDExecutingCallbacks = Thread.CurrentThread.ManagedThreadId;
            }
            callback.ExecuteCallback();
        }
    }

    /// <summary>
    /// Creates a source that is canceled when either token is canceled. Only tokens that can be
    /// canceled get a registration, so m_linkingRegistrations holds one or two entries, or stays
    /// null if neither token can be canceled.
    /// </summary>
    public static CancellationTokenSource CreateLinkedTokenSource(CancellationToken token1, CancellationToken token2)
    {
        CancellationTokenSource linkedTokenSource = new CancellationTokenSource();

        bool token2CanBeCanceled = token2.CanBeCanceled;

        if( token1.CanBeCanceled )
        {
            linkedTokenSource.m_linkingRegistrations = new CancellationTokenRegistration[token2CanBeCanceled ? 2 : 1]; // there will be at least 1 and at most 2 linkings
            linkedTokenSource.m_linkingRegistrations[0] = token1.InternalRegisterWithoutEC(s_LinkedTokenCancelDelegate, linkedTokenSource);
        }

        if( token2CanBeCanceled )
        {
            int index = 1;
            if( linkedTokenSource.m_linkingRegistrations == null )
            {
                linkedTokenSource.m_linkingRegistrations = new CancellationTokenRegistration[1]; // this will be the only linking
                index = 0;
            }
            linkedTokenSource.m_linkingRegistrations[index] = token2.InternalRegisterWithoutEC(s_LinkedTokenCancelDelegate, linkedTokenSource);
        }

        return linkedTokenSource;
    }

    /// <summary>
    /// Creates a source that is canceled when any of <paramref name="tokens"/> is canceled.
    /// Slot i of m_linkingRegistrations stays default for tokens that cannot be canceled.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="tokens"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="tokens"/> is empty.</exception>
    public static CancellationTokenSource CreateLinkedTokenSource(params CancellationToken[] tokens)
    {
        if (tokens == null)
            throw new ArgumentNullException("tokens");

        if (tokens.Length == 0)
            throw new ArgumentException(Environment.GetResourceString("CancellationToken_CreateLinkedToken_TokensIsEmpty"));

        Contract.EndContractBlock();

        CancellationTokenSource linkedTokenSource = new CancellationTokenSource();
        linkedTokenSource.m_linkingRegistrations = new CancellationTokenRegistration[tokens.Length];

        for (int i = 0; i < tokens.Length; i++)
        {
            if (tokens[i].CanBeCanceled)
            {
                linkedTokenSource.m_linkingRegistrations[i] = tokens[i].InternalRegisterWithoutEC(s_LinkedTokenCancelDelegate, linkedTokenSource);
            }
        }

        return linkedTokenSource;
    }

    /// <summary>
    /// Registers <paramref name="callback"/>. If cancellation has already been requested, the
    /// callback runs synchronously on this thread and an empty registration is returned. The
    /// same happens if cancellation races with the add and TryDeregister succeeds. If the
    /// source is disposed and the AppContext switch is off, an empty registration is returned
    /// without running the callback.
    /// </summary>
    internal CancellationTokenRegistration InternalRegister(Action<object> callback, object stateForCallback, SynchronizationContext targetSyncContext, ExecutionContext executionContext)
    {
        if (AppContextSwitches.ThrowExceptionIfDisposedCancellationTokenSource)
        {
            ThrowIfDisposed();
        }

        Contract.Assert(CanBeCanceled, "Cannot register for uncancelable token src");

        if (!IsCancellationRequested)
        {
            if (m_disposed && !AppContextSwitches.ThrowExceptionIfDisposedCancellationTokenSource)
                return new CancellationTokenRegistration();

            // Spread registrations across lists by thread id.
            int myIndex = Thread.CurrentThread.ManagedThreadId % s_nLists;

            CancellationCallbackInfo callbackInfo = new CancellationCallbackInfo(callback, stateForCallback, targetSyncContext, executionContext, this);

            //allocate the callback list array
            var registeredCallbacksLists = m_registeredCallbacksLists;
            if (registeredCallbacksLists == null)
            {
                // Publish with a CAS; if another thread won, use its array instead of ours.
                SparselyPopulatedArray<CancellationCallbackInfo>[] list = new SparselyPopulatedArray<CancellationCallbackInfo>[s_nLists];
                registeredCallbacksLists = Interlocked.CompareExchange(ref m_registeredCallbacksLists, list, null);
                if (registeredCallbacksLists == null) registeredCallbacksLists = list;
            }

            //allocate the actual lists on-demand to save mem in low-use situations, and to avoid false-sharing.
            var callbacks = Volatile.Read<SparselyPopulatedArray<CancellationCallbackInfo>>(ref registeredCallbacksLists[myIndex]);
            if (callbacks == null)
            {
                SparselyPopulatedArray<CancellationCallbackInfo> callBackArray = new SparselyPopulatedArray<CancellationCallbackInfo>(4);
                Interlocked.CompareExchange(ref (registeredCallbacksLists[myIndex]), callBackArray, null);
                callbacks = registeredCallbacksLists[myIndex];
            }

            // Now add the registration to the list.
            SparselyPopulatedArrayAddInfo<CancellationCallbackInfo> addInfo = callbacks.Add(callbackInfo);
            CancellationTokenRegistration registration = new CancellationTokenRegistration(callbackInfo, addInfo);

            // Re-check: cancellation may have started while we were adding.
            if (!IsCancellationRequested)
                return registration;

            // Try to take the registration back. If that fails, the callback is no longer in its
            // slot, so return the registration instead of running the callback here.
            bool deregisterOccurred = registration.TryDeregister();

            if (!deregisterOccurred)
            {
                return registration;
            }
        }

        // If cancellation already occurred, we run the callback on this thread and return an empty registration.
        callback(stateForCallback);
        return new CancellationTokenRegistration();
    }

    /// <summary>True once cancellation has started (state NOTIFYING or NOTIFYINGCOMPLETE).</summary>
    public bool IsCancellationRequested
    {
        get { return m_state >= NOTIFYING; }
    }

    /// <summary>True once ExecuteCallbackHandlers has finished running callbacks.</summary>
    internal bool IsCancellationCompleted
    {
        get { return m_state == NOTIFYINGCOMPLETE; }
    }

    /// <summary>Returns a token bound to this source; calls ThrowIfDisposed first.</summary>
    public CancellationToken Token
    {
        get
        {
            ThrowIfDisposed();
            return new CancellationToken(this);
        }
    }

    /// <summary>The callback currently being invoked during cancellation, or null.</summary>
    internal CancellationCallbackInfo ExecutingCallback
    {
        get { return m_executingCallback; }
    }

    // Registered on every source token of a linked source; cancels the linked source.
    private static void LinkedTokenCancelDelegate(object source)
    {
        CancellationTokenSource cts = source as CancellationTokenSource;
        // Assert on the cast result. Checking "source != null" let a non-null object of the wrong
        // type through, which then failed with a NullReferenceException on the next line.
        Contract.Assert(cts != null, "Expected a CancellationTokenSource");
        cts.Cancel();
    }
}

CancellationTokenSource的实现相对比较复杂,我们首先看看CancellationTokenSource的构造函数,默认构造函数将会设置【m_state = NOT_CANCELED】,我们也可以构造一个特定时间后就自动Cancel的CancellationTokenSource,自动Cancel是依赖一个Timer实例,在Timer到指定时间后调用CancellationTokenSource的Cancel方法【这里是在TimerCallbackLogic里面调用Cancel方法】,CancelAfter方法的实现也是依赖这个Timer实例和TimerCallbackLogic方法

现在我们来看看CancellationTokenSource最主要的方法Cancel。Cancel方法调用NotifyCancellation方法。NotifyCancellation先通过Interlocked.CompareExchange把m_state从NOT_CANCELED改为NOTIFYING(保证只有第一个调用者真正执行取消逻辑),然后主要调用ExecuteCallbackHandlers【从这个方法的名称可以猜到,它的主要工作是调用回调方法】。在ExecuteCallbackHandlers方法里用到一个变量m_registeredCallbacksLists,它是SparselyPopulatedArray<CancellationCallbackInfo>[]结构【可以理解为一个链表的数组:数组的每个元素是一个链表,链表的每个节点是一个数组片段SparselyPopulatedArrayFragment,节点之间通过m_prev/m_next相连】。我们遍历这个数组中的每一个链表:从尾部片段(Tail)开始沿Prev向前,在每个片段内从后往前检查每个槽位是否有值,即m_executingCallback != null,有值就调用回调方法。如果回调的TargetSyncContext不为空,就通过TargetSyncContext.Send调用CancellationCallbackCoreWork_OnSyncContext方法,否则直接调用CancellationCallbackCoreWork方法【CancellationCallbackCoreWork_OnSyncContext里面也是调用它】。CancellationCallbackCoreWork先用SafeAtomicRemove原子地把回调从槽位中移除,只有移除成功时才调用CancellationCallbackInfo的ExecuteCallback。

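CancellationTokenSource有两个CreateLinkedTokenSource方法【可以理解为创建一个与给定的若干CancellationToken相关联的新CancellationTokenSource:任一源token被取消时,新的CancellationTokenSource也会被取消】。其主要实现是对每个可以被取消的源token调用InternalRegisterWithoutEC(CancellationToken内部的注册方法),注册的回调s_LinkedTokenCancelDelegate会调用新CancellationTokenSource的Cancel方法。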

/// <summary>
/// Lightweight handle that observes a CancellationTokenSource. A default token
/// (m_source == null) can never be canceled.
/// </summary>
public struct CancellationToken
{
    // Adapter that lets a parameterless Action flow through the Action<Object> registration
    // path: the Action itself travels as the callback state and is unwrapped here.
    private static readonly Action<Object> s_ActionToActionObjShunt = new Action<Object>(ActionToActionObjShunt);

    private static void ActionToActionObjShunt(object obj)
    {
        Action wrappedAction = obj as Action;
        Contract.Assert(wrappedAction != null, "Expected an Action here");
        wrappedAction();
    }

    // The source this token observes; null for None / default tokens.
    private CancellationTokenSource m_source;

    internal CancellationToken(CancellationTokenSource source)
    {
        m_source = source;
    }

    /// <summary>
    /// Creates a token. With <paramref name="canceled"/> == false this is equivalent to default;
    /// with true it is bound to the source returned by CancellationTokenSource.InternalGetStaticSource.
    /// </summary>
    public CancellationToken(bool canceled) : this()
    {
        if (!canceled)
        {
            return;
        }

        m_source = CancellationTokenSource.InternalGetStaticSource(canceled);
    }

    /// <summary>A token with no source, which can never be canceled.</summary>
    public static CancellationToken None
    {
        get { return new CancellationToken(); }
    }

    /// <summary>True when the token has a source and that source has requested cancellation.</summary>
    public bool IsCancellationRequested
    {
        get
        {
            if (m_source == null)
                return false;

            return m_source.IsCancellationRequested;
        }
    }

    /// <summary>True when the token has a source and that source reports it can be canceled.</summary>
    public bool CanBeCanceled
    {
        get
        {
            if (m_source == null)
                return false;

            return m_source.CanBeCanceled;
        }
    }

    /// <summary>Throws OperationCanceledException if cancellation has been requested.</summary>
    public void ThrowIfCancellationRequested()
    {
        if (!IsCancellationRequested)
            return;

        ThrowOperationCanceledException();
    }

    private void ThrowOperationCanceledException()
    {
        throw new OperationCanceledException(Environment.GetResourceString("OperationCanceled"), this);
    }

    /// <summary>Registers a parameterless callback; the execution context is captured.</summary>
    public CancellationTokenRegistration Register(Action callback)
    {
        if (callback == null)
        {
            throw new ArgumentNullException("callback");
        }

        return Register(s_ActionToActionObjShunt, callback, useSynchronizationContext: false, useExecutionContext: true);
    }

    /// <summary>
    /// Registers a parameterless callback, optionally capturing the current SynchronizationContext.
    /// </summary>
    public CancellationTokenRegistration Register(Action callback, bool useSynchronizationContext)
    {
        if (callback == null)
        {
            throw new ArgumentNullException("callback");
        }

        return Register(s_ActionToActionObjShunt, callback, useSynchronizationContext: useSynchronizationContext, useExecutionContext: true);
    }

    /// <summary>Registers a callback that receives <paramref name="state"/>.</summary>
    public CancellationTokenRegistration Register(Action<Object> callback, Object state)
    {
        if (callback == null)
        {
            throw new ArgumentNullException("callback");
        }

        return Register(callback, state, useSynchronizationContext: false, useExecutionContext: true);
    }

    /// <summary>
    /// Registers a callback that receives <paramref name="state"/>, optionally capturing the
    /// current SynchronizationContext. The null check happens in the core overload.
    /// </summary>
    public CancellationTokenRegistration Register(Action<Object> callback, Object state, bool useSynchronizationContext)
    {
        return Register(callback, state, useSynchronizationContext: useSynchronizationContext, useExecutionContext: true);
    }

    // Core registration. Contexts are captured only while cancellation has not yet been
    // requested; otherwise InternalRegister runs the callback right away on this thread.
    private CancellationTokenRegistration Register(Action<Object> callback, Object state, bool useSynchronizationContext, bool useExecutionContext)
    {
        StackCrawlMark crawlMark = StackCrawlMark.LookForMyCaller;

        if (callback == null)
        {
            throw new ArgumentNullException("callback");
        }

        // A token that can never be canceled will never invoke the callback: hand back a dummy registration.
        if (!CanBeCanceled)
        {
            return new CancellationTokenRegistration();
        }

        SynchronizationContext syncContextToUse = null;
        ExecutionContext executionContextToUse = null;

        if (!IsCancellationRequested)
        {
            syncContextToUse = useSynchronizationContext ? SynchronizationContext.Current : null;

            if (useExecutionContext)
            {
                executionContextToUse = ExecutionContext.Capture(ref crawlMark, ExecutionContext.CaptureOptions.OptimizeDefaultCase);
            }
        }

        return m_source.InternalRegister(callback, state, syncContextToUse, executionContextToUse);
    }
}

CancellationToken的很多属性都来源于CancellationTokenSource的属性,CancellationToken的主要方法Register也是调用CancellationTokenSource的InternalRegister方法。InternalRegister方法检查当前是否已经发起了Cancel【IsCancellationRequested】:如果是,就直接在当前线程调用回调方法callback(stateForCallback);否则把回调方法包装成CancellationCallbackInfo实例,添加到m_registeredCallbacksLists对象中,然后再返回CancellationTokenRegistration实例。注意添加之后还会再检查一次IsCancellationRequested:如果在添加期间恰好发生了Cancel,并且TryDeregister成功撤销了这次注册,也会直接在当前线程调用回调,避免回调被漏掉。

    // One registered cancellation callback: the delegate, its state, the captured
    // SynchronizationContext / ExecutionContext (either may be null) and the owning source.
    // All fields are readonly and assigned once in the constructor.
    internal class CancellationCallbackInfo
    {
        internal readonly Action<object> Callback;
        internal readonly object StateForCallback;
        internal readonly SynchronizationContext TargetSyncContext;
        internal readonly ExecutionContext TargetExecutionContext;
        internal readonly CancellationTokenSource CancellationTokenSource;

        internal CancellationCallbackInfo(Action<object> callback, object stateForCallback, SynchronizationContext targetSyncContext, ExecutionContext targetExecutionContext,CancellationTokenSource cancellationTokenSource)
        {
            Callback = callback;
            StateForCallback = stateForCallback;
            TargetSyncContext = targetSyncContext;
            TargetExecutionContext = targetExecutionContext;
            CancellationTokenSource = cancellationTokenSource;
        }

        // Lazily created cache of the ContextCallback passed to ExecutionContext.Run. The
        // initialization is unsynchronized; a race at worst allocates an extra, equivalent delegate.
        private static ContextCallback s_executionContextCallback;

        // Invokes Callback(StateForCallback). If an ExecutionContext was captured, the callback
        // runs inside it via ExecutionContext.Run; otherwise it runs directly on the current context.
        internal void ExecuteCallback()
        {
            if (TargetExecutionContext != null)
            {
                var callback = s_executionContextCallback;
                if (callback == null) s_executionContextCallback = callback = new ContextCallback(ExecutionContextCallback);

                ExecutionContext.Run(TargetExecutionContext, callback, this);
            }
            else
            {
                ExecutionContextCallback(this);
            }
        }

        // ContextCallback-shaped trampoline: obj is the CancellationCallbackInfo to invoke.
        private static void ExecutionContextCallback(object obj)
        {
            CancellationCallbackInfo callbackInfo = obj as CancellationCallbackInfo;
            Contract.Assert(callbackInfo != null);
            callbackInfo.Callback(callbackInfo.StateForCallback);
        }
    }

    // A chain of fixed-size fragments, doubly linked via m_prev/m_next. Slots are filled and
    // cleared with Interlocked.CompareExchange, and a null slot means "free". New fragments are
    // appended at the tail; each is twice the size of the previous one, capped at 4096.
    internal class SparselyPopulatedArray<T> where T : class
    {
        // Assigned in the constructor but not read anywhere in this excerpt.
        private readonly SparselyPopulatedArrayFragment<T> m_head;
        // Last fragment known to this array; Add advances it lazily when m_next is already set.
        private volatile SparselyPopulatedArrayFragment<T> m_tail;

        internal SparselyPopulatedArray(int initialSize)
        {
            m_head = m_tail = new SparselyPopulatedArrayFragment<T>(initialSize);
        }

        internal SparselyPopulatedArrayFragment<T> Tail
        {
            get { return m_tail; }
        }

        // Stores element in a free slot and returns its location (fragment + index). Retries until
        // it succeeds:
        //   (1) advance m_tail to the real last fragment;
        //   (2) scan fragments from the tail backwards via m_prev for a null slot to CAS into;
        //   (3) if none is found, try to append a larger fragment and loop again.
        internal SparselyPopulatedArrayAddInfo<T> Add(T element)
        {
            while (true)
            {
                // Get the tail, and ensure it's up to date.
                SparselyPopulatedArrayFragment<T> tail = m_tail;
                while (tail.m_next != null)
                    m_tail = (tail = tail.m_next);

                // Search for a free index, starting from the tail.
                SparselyPopulatedArrayFragment<T> curr = tail;
                while (curr != null)
                {
                    // m_freeCount is only a hint. While it says "full" (< 1), each visit pushes it
                    // further negative. Once it drops below the threshold, the fragment is scanned
                    // anyway, so slots freed without an accurate hint are eventually reused.
                    const int RE_SEARCH_THRESHOLD = -10; // Every 10 skips, force a search.
                    if (curr.m_freeCount < 1)
                        --curr.m_freeCount;

                    if (curr.m_freeCount > 0 || curr.m_freeCount < RE_SEARCH_THRESHOLD)
                    {
                        int c = curr.Length;

                        // Starting offset derived from the hint. If m_freeCount exceeds c, the dividend
                        // is negative and C#'s % can return a negative value, so it is clamped to 0.
                        int start = ((c - curr.m_freeCount) % c);
                        if (start < 0)
                        {
                            start = 0;
                            curr.m_freeCount--; // Too many free elements; fix up.
                        }
                        Contract.Assert(start >= 0 && start < c, "start is outside of bounds");

                        // Now walk the array until we find a free slot (or reach the end).
                        for (int i = 0; i < c; i++)
                        {
                            // If the slot is null, try to CAS our element into it.
                            int tryIndex = (start + i) % c;
                            Contract.Assert(tryIndex >= 0 && tryIndex < curr.m_elements.Length, "tryIndex is outside of bounds");

                            // The plain null read is a cheap pre-check; the CAS is what actually claims the slot.
                            if (curr.m_elements[tryIndex] == null && Interlocked.CompareExchange(ref curr.m_elements[tryIndex], element, null) == null)
                            {
                                // Decrement the hint without letting it go below 0.
                                int newFreeCount = curr.m_freeCount - 1;
                                curr.m_freeCount = newFreeCount > 0 ? newFreeCount : 0;
                                return new SparselyPopulatedArrayAddInfo<T>(curr, tryIndex);
                            }
                        }
                    }

                    curr = curr.m_prev;
                }

                // If we got here, we need to add a new chunk to the tail and try again.
                // Only the thread whose CAS links newTail publishes it as m_tail. Losers loop, and
                // step (1) moves them onto the winner's fragment via m_next.
                SparselyPopulatedArrayFragment<T> newTail = new SparselyPopulatedArrayFragment<T>(
                    tail.m_elements.Length == 4096 ? 4096 : tail.m_elements.Length * 2, tail);
                if (Interlocked.CompareExchange(ref tail.m_next, newTail, null) == null)
                {
                    m_tail = newTail;
                }
            }
        }
    }

    // Location of an element stored by SparselyPopulatedArray.Add: the fragment it landed in
    // and its index there. The CancellationTokenSource excerpt passes it to the
    // CancellationTokenRegistration constructor.
    internal struct SparselyPopulatedArrayAddInfo<T> where T : class
    {
        private SparselyPopulatedArrayFragment<T> m_source;
        private int m_index;

        internal SparselyPopulatedArrayAddInfo(SparselyPopulatedArrayFragment<T> source, int index)
        {
            Contract.Assert(source != null);
            Contract.Assert(index >= 0 && index < source.Length);
            m_source = source;
            m_index = index;
        }

        internal SparselyPopulatedArrayFragment<T> Source
        {
            get { return m_source; }
        }

        internal int Index
        {
            get { return m_index; }
        }
    }

    // One fixed-size segment of a SparselyPopulatedArray.
    internal class SparselyPopulatedArrayFragment<T> where T : class
    {
        internal readonly T[] m_elements; // The contents, sparsely populated (with nulls).
        internal volatile int m_freeCount; // A hint of the number of free elements.
        internal volatile SparselyPopulatedArrayFragment<T> m_next; // The next fragment in the chain.
        internal volatile SparselyPopulatedArrayFragment<T> m_prev; // The previous fragment in the chain.

        internal SparselyPopulatedArrayFragment(int size) : this(size, null)
        {
        }

        internal SparselyPopulatedArrayFragment(int size, SparselyPopulatedArrayFragment<T> prev)
        {
            m_elements = new T[size];
            m_freeCount = size;
            m_prev = prev;
        }

        // Reads a slot with Volatile.Read so concurrent CAS updates are observed.
        internal T this[int index]
        {
            get { return Volatile.Read<T>(ref m_elements[index]); }
        }

        internal int Length
        {
            get { return m_elements.Length; }
        }

        internal SparselyPopulatedArrayFragment<T> Prev
        {
            get { return m_prev; }
        }

        // Atomically clears slot `index` if it still holds expectedElement. Returns the value that
        // was in the slot: expectedElement on success, null if the slot was already empty, or
        // another element if the slot had been reused.
        // NOTE(review): m_freeCount is bumped whenever the prevailing value is non-null. That
        // includes the failed-CAS case where the slot held a different element, and the ++ on a
        // volatile field is not atomic. Both only skew the hint, which Add tolerates.
        internal T SafeAtomicRemove(int index, T expectedElement)
        {
            T prevailingValue = Interlocked.CompareExchange(ref m_elements[index], null, expectedElement);
            if (prevailingValue != null)
                ++m_freeCount;
            return prevailingValue;
        }
    }

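回头看CancellationCallbackInfo的实现也很简单:ExecuteCallback在TargetExecutionContext不为空时,通过ExecutionContext.Run在捕获的执行上下文中调用回调;否则直接调用Callback(StateForCallback)。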

原创粉丝点击