ThreadLocal源码分析

来源:互联网 发布:源码在线编辑 编辑:程序博客网 时间:2024/06/06 11:49

一、概念

ThreadLocal提供线程本地变量。这些变量与普通变量不同,每个线程有自己的一份拷贝。ThreadLocal实例典型用法是在类中作为私有的静态域,用与线程绑定状态(比如,用户ID或者事务ID)。
简而言之,每个线程第一次获取该值后,之后在线程内部就可以随意操作该变量,但是这并不影响其他线程,因为每个线程一旦拥有了该变量后,就会有一份拷贝。

public class ThreadId {     // Atomic integer containing the next thread ID to be assigned     private static final AtomicInteger nextId = new AtomicInteger(0);     // Thread local variable containing each thread's ID     private static final ThreadLocal<Integer> threadId =         new ThreadLocal<Integer>() {             @Override protected Integer initialValue() {                 return nextId.getAndIncrement();         }     };     // Returns the current thread's unique ID, assigning it if necessary     public static int get() {         return threadId.get();     } }

这是ThreadLocal文档中的一个例子。

二、 主要API

ThreadLocal有get、set、remove用于设置变量值,还有受保护的initialValue用于返回当前线程初值,如果开发者不覆写该方法,那么默认返回null。initialValue一般只会调用一次,在调用get或set方法时调用,但是如果在调用remove之后再调用get就会又触发initialValue。

三、 源码分析

get()方法

ThreadLocal的get()方法用于获取本线程存储在ThreadLocal中的变量,其实现如下:

public T get() {        Thread t = Thread.currentThread();        //得到当前线程的Map        ThreadLocalMap map = getMap(t);        //如果map不为null        if (map != null) {            //得到存储在map中的值            ThreadLocalMap.Entry e = map.getEntry(this);            //如果Entry不为null,那么返回Entry的值            if (e != null) {                @SuppressWarnings("unchecked")                T result = (T)e.value;                return result;            }        }        //调用setInitialValue()设置初始值        return setInitialValue();    }

从上面的get()方法可以看出大体流程:
1. 根据当前线程取得ThreadLocalMap,如果不存在map,那么调用setInitialValue()创建Map并设置初始值;
2. 如果map不为null,那么从map中得到Entry,如果Entry为null,调用setInitialValue()方法设置初始值;
3. 如果Entry不为null,那么就返回值即可。

当Map不存在或者Entry不存在时,将会调用setInitialValue()方法,下面是setInitialValue()方法的实现:

private T setInitialValue() {        T value = initialValue();        Thread t = Thread.currentThread();        ThreadLocalMap map = getMap(t);        //如果Map不为null,那么就设置初值        if (map != null)            map.set(this, value);        //如果Map为null,那么就创建表        else            createMap(t, value);        return value;    }

可以看到setInitialValue()方法在Map不存在的时候会创建表,而在Map存在的时候,就会设置初始值。
那么下面先看一下是如何创建Map的,并且这个Map到底是什么样子的。

void createMap(Thread t, T firstValue) {        t.threadLocals = new ThreadLocalMap(this, firstValue);    }

可以看到createMap()方法创建了一个ThreadLocalMap对象,并将ThreadLocalMap对象赋给了Thread的threadLocals变量,所以可以得出每一个线程都有一个ThreadLocalMap对象,用于存储每一个线程的本地变量,Map只有一个,但是却可以存储多个本地变量。
既然createMap()就是给Thread的threadLocals赋值,那么可以猜测从线程中得到Map就是获取这个变量,下面是getMap()的实现,

ThreadLocalMap getMap(Thread t) {        return t.threadLocals;    }

可以看到,该方法果然如我们所想,就是返回Thread的threadLocals变量。
至此,我们分析完了ThreadLocal的get()方法,其中的关键就是ThreadLocalMap,这个类后面会具体分析,看它是如何实现的。

set(T value)方法

set()方法用于设置当前线程的本地变量值,其实现如下:

 public void set(T value) {        Thread t = Thread.currentThread();        ThreadLocalMap map = getMap(t);        //如果map不为null,那么直接设置值        if (map != null)            map.set(this, value);        //如果Map为null,那么需要创建表        else            createMap(t, value);    }

经过了上面的get()方法,可以看出set()方法的流程是如出一辙的:
1. 如果Thread的threadLocals变量不为null,那么就直接将值设置;
2. 如果Thread的threadLocals变量为null,那么创建ThreadLocalMap并赋值给threadLocals变量。

remove()

remove()方法用于删除当前线程的本地变量,其实现如下:

 public void remove() {         ThreadLocalMap m = getMap(Thread.currentThread());         if (m != null)             m.remove(this);     }

可以看到也是调用了ThreadLocalMap的remove()方法。
ThreadLocal的get()、set()和remove()方法最终都是委托给了ThreadLocalMap的相应方法,下面我们就着重分析一下ThreadLocalMap是如何实现的。

ThreadLocalMap源码分析

ThreadLocalMap是一个自定义的HashMap,用于存储线程本地变量。

构造方法

createMap()中调用ThreadLocalMap的构造方法,其实现如下:

 ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {            //初始化一个Entry的数组,默认容量16            table = new Entry[INITIAL_CAPACITY];            //计算桶的索引            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);            //创建Entry并赋值            table[i] = new Entry(firstKey, firstValue);            size = 1;            //设置阈值            setThreshold(INITIAL_CAPACITY);        }

可以看到ThreadLocalMap的初始化完成的是Entry数组的建立,默认初始容量16,但是阈值的设置与HashMap有些不同,HashMap采用的加载因子默认是0.75,而setThreshold()方法如下:

private void setThreshold(int len) {            threshold = len * 2 / 3;        }

所以可以看到ThreadLocalMap的加载因子是2/3。
下面看一下Entry对象的定义:

static class Entry extends WeakReference<ThreadLocal<?>> {            /** The value associated with this ThreadLocal. */            Object value;            Entry(ThreadLocal<?> k, Object v) {                super(k);                value = v;            }        }

可以看到Entry继承自WeakReference,其key为ThreadLocal,值为传入的值。
既然键是ThreadLocal,那么键的hash值是什么呢?在构造方法中可以看到threadLocalHashCode变量,其定义如下:

private final int threadLocalHashCode = nextHashCode();private static AtomicInteger nextHashCode =        new AtomicInteger();private static final int HASH_INCREMENT = 0x61c88647;private static int nextHashCode() {        return nextHashCode.getAndAdd(HASH_INCREMENT);    }

可以看到,一旦一个ThreadLocal创建了,那么其threadLocalHashCode就是确定的,而由于nextHashCode是静态的,所以这会导致每一个ThreadLocal的threadLocalHashCode是不相同的,所以可以得出结论:ThreadLocalMap中的键不是根据Thread的ID进行hash的,而是根据其关联的ThreadLocal的threadLocalHashCode值确定的。
因此,如果有多个ThreadLocal对象,一个线程同时在这几个ThreadLocal对象中存储本地变量,那么因为ThreadLocal的threadLocalHashCode不同,将会被放进不同的Entry桶中。
根据Entry的定义,可以发现其并不是一个HashMap中常见的链表节点,所以可以得出结论:ThreadLocalMap中的Entry数组每一个桶中最多只会存放一个Entry。

ThreadLocalMap#getEntry()方法

 private Entry getEntry(ThreadLocal<?> key) {            //计算桶处的索引            int i = key.threadLocalHashCode & (table.length - 1);            //得到桶处的Entry            Entry e = table[i];            //如果Entry不为null并且键值相同,则返回            if (e != null && e.get() == key)                return e;            //否则            else                return getEntryAfterMiss(key, i, e);        }

从上面可以看到,根据ThreadLocal的threadLocalHashCode计算桶的索引,然后尝试得到Entry,而一旦Entry不为null并且保存的键值也相等,那么返回;否则调用getEntryAfterMiss()方法。
因为Entry继承自WeakReference,所以就存在其get()方法返回null的情况,所以需要处理为null的情况,下面看一下getEntryAfterMiss()是如何实现的:

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {            Entry[] tab = table;            int len = tab.length;            //当Entry不为null            while (e != null) {                //得到与之关联的ThreadLocal                ThreadLocal<?> k = e.get();                //如果键值相等,直接返回                if (k == key)                    return e;                //如果键为null                if (k == null)                    expungeStaleEntry(i);                //不为null但是不相等                else                    i = nextIndex(i, len);                e = tab[i];            }            return null;        }

从上面的代码可以得出,如果出现了Hash冲突,那么会使用线性探查法查找下一个索引,nextIndex()方法的实现如下:

  private static int nextIndex(int i, int len) {            return ((i + 1 < len) ? i + 1 : 0);        }

如果遇到了Entry不为null,但是其保存的ThreadLocal已经为null了,那么会调用expungeStaleEntry()方法,该方法用于删除一个过时的Entry,其实现如下:

private int expungeStaleEntry(int staleSlot) {            Entry[] tab = table;            int len = tab.length;            // 回收Entry            tab[staleSlot].value = null;            tab[staleSlot] = null;            size--;            // rehash知道遇到一个null            Entry e;            int i;            //从下一个索引开始,只要该索引出Entry不为null,那么继续循环            for (i = nextIndex(staleSlot, len);                 (e = tab[i]) != null;                 i = nextIndex(i, len)) {                ThreadLocal<?> k = e.get();                //如果键值不存在,那么回收Entry                if (k == null) {                    e.value = null;                    tab[i] = null;                    size--;                }                //如果键值存在                else {                    //重新计算索引                    int h = k.threadLocalHashCode & (len - 1);                    //如果不匹配之前的索引,那么需要将该处的Entry向前移                    if (h != i) {                        tab[i] = null;                        //从最初的索引开始往后循环                        while (tab[h] != null)                            h = nextIndex(h, len);                        tab[h] = e;                    }                }            }            return i;        }

经过上面的注释和线性探查法坚决hash冲突的问题,那么需要首先回收过时的Entry;接下来需要继续往下回收,一旦遇到了一个不需要回收的Entry,由于前面的Entry被回收了,而该Entry可能之前由于hash冲突被移到了后面,那么现在要做的就是将该Entry往前移,往前移的原则就是从计算得到的索引出往后查找一个空位置,然后插入。
经过上面的分析,可以得出结论:
1. ThreadLocalMap的hash规则是使用的ThreadLocal的threadLocalHashCode变量的;
2. ThreadLocalMap中的数组每一个槽只能存放一个Entry,解决hash冲突使用的线性探查法;
3. Entry继承自WeakReference,所以存在引用被GC回收的情况,需要清理过时Entry;
4. 清理过时Entrt时需要考虑将原先由于线性探测法放到后面的移到前面的索引。

分析完了get()方法后,基本知道了ThreadLocalMap的原理,下面再分析set()方法,可以看到rehash规则。

ThreadLocalMap#set()

private void set(ThreadLocal<?> key, Object value) {            Entry[] tab = table;            int len = tab.length;            //得到索引            int i = key.threadLocalHashCode & (len-1);            //循环查找一个空位置            for (Entry e = tab[i];                 e != null;                 e = tab[i = nextIndex(i, len)]) {                ThreadLocal<?> k = e.get();                //如果键值相同,那么直接更新                if (k == key) {                    e.value = value;                    return;                }                //如果键值为null,需要替换过时的值,直接返回                if (k == null) {                    replaceStaleEntry(key, value, i);                    return;                }            }            //查找到了一个空槽位,放置Entry            tab[i] = new Entry(key, value);            //尺寸+1            int sz = ++size;            //如果超过了阈值,那么rehash            if (!cleanSomeSlots(i, sz) && sz >= threshold)                rehash();        }

从上面可以基本得出set()方法的流程,主要就是要找到一个空槽放置Entry,所以使用了for循环。
1. for循环中,如果键值相同,那么替换旧值并返回;
2. for循环中,如果键值为null,那么替换过时的值并返回;
3. 如果for循环中没有退出方法,那么意味着找到了一个空槽位,放置Entry然后判断是否需要rehash。

下面主要关心一下rehash()方法,其实现如下:

ThreadLocalMap#rehash()

private void rehash() {            //清除过时Entry            expungeStaleEntries();            // 如果尺寸仍然大于长度的一般,那么进行resize()            if (size >= threshold - threshold / 4)                resize();        }

可以看到rehash()的流程有两步:
1. 清除过时的Entry
2. 如果清除完,size仍然超过了容量的一般,那么进行resize()方法。

至于这个阈值是如何计算的,其公式如下:

threshold-threshold/4=len*2/3-len*2/3/4=len/2

下面看一下resize()是如何扩容以及rehash的,其实现如下:

ThreadLocalMap#resize()

 private void resize() {            Entry[] oldTab = table;            int oldLen = oldTab.length;            //扩容策略为扩大2倍            int newLen = oldLen * 2;            Entry[] newTab = new Entry[newLen];            int count = 0;            //移除值            for (int j = 0; j < oldLen; ++j) {                Entry e = oldTab[j];                if (e != null) {                    ThreadLocal<?> k = e.get();                    if (k == null) {                        e.value = null; // Help the GC                    } else {                        //与新的容量做与                        int h = k.threadLocalHashCode & (newLen - 1);                        while (newTab[h] != null)                            h = nextIndex(h, newLen);                        newTab[h] = e;                        count++;                    }                }            }            setThreshold(newLen);            size = count;            table = newTab;        }

可以看到resize()中扩容是扩大两倍,计算新的索引值是与新的容量做与。
ThreadLocalMap中主要还有一个remove()方法,这里就不再看了。

总结

ThreadLocal用于保存线程的本地变量,其实现原理是将变量保存在每一个Thread的threadLocals变量中,而该变量是一个ThreadLocalMap用于保存多个ThreadLocal存在Thread中的值。
ThreadLocalMap使用的hash规则是ThreadLocal的threadLocalHashCode,每一个ThreadLocal对象均不同;使用的hash冲突解决方法是线性探查法。基于这两点就可以很容易地分析ThreadLocalMap的各个方法的实现。

0 0