您的位置:首页 > 其它

ThreadLocal源码分析

2014-05-11 17:00 381 查看
这一篇之所以讲ThreadLocal,是因为之前在读Handler,Looper的源码过程(见http://maosidiaoxian.iteye.com/blog/1927735)中,看到了这个类,引起了我的兴趣。而后来发现JAVA1.6中的TheadLocal类,和我在android源码看到的这个ThreadLocal类代码是不一样的。所以这篇先讲一下Java的ThreadLocal。

Java中ThreadLocal在Java1.2就已经提出了,后来重构过,所以我在想android中的这个类的实现是不是重构过之前的版本优化过的,由于未找到早前版本该类的代码,所以只能进行猜测。

ThreadLocal用于提供线程局部变量,即通过这里的get()和set()方法访问某个变量的线程都有自己的局部变量。但在这里需要注意的是,它是通过各个线程自己创建对象然后调用这里的set方法设置进去,而不是由ThreadLocal自己来创建变量的副本的。就像上面链接的那一篇文章里面提到的android的Looper类所用的(虽然Android里ThreadLocal实现代码不一样,但用法是一样的),代码如下:

Java代码


static final ThreadLocal<Looper> sThreadLocal = new ThreadLocal<Looper>();

private static void prepare(boolean quitAllowed) {

if (sThreadLocal.get() != null) {

throw new RuntimeException("Only one Looper may be created per thread");

}

sThreadLocal.set(new Looper(quitAllowed));

}

可以看到,线程中的局部变量,是在线程中创建,然后调用ThreadLocal.set()方法设置进去的。

下面详细读一下ThreadLocal这个类,看看它是怎么实现的(jdk1.6的源码)。



从上图可以看到ThreadLocal有三个变(常)量,代码如下:

Java代码


/**

* ThreadLocals rely on per-thread linear-probe hash maps attached

* to each thread (Thread.threadLocals and

* inheritableThreadLocals). The ThreadLocal objects act as keys,

* searched via threadLocalHashCode. This is a custom hash code

* (useful only within ThreadLocalMaps) that eliminates collisions

* in the common case where consecutively constructed ThreadLocals

* are used by the same threads, while remaining well-behaved in

* less common cases.

*/

private final int threadLocalHashCode = nextHashCode();

/**

* The next hash code to be given out. Updated atomically. Starts at

* zero.

*/

private static AtomicInteger nextHashCode =

new AtomicInteger();

/**

* The difference between successively generated hash codes - turns

* implicit sequential thread-local IDs into near-optimally spread

* multiplicative hash values for power-of-two-sized tables.

*/

private static final int HASH_INCREMENT = 0x61c88647;

其中HASH_INCREMENT 是一个常量,它表示连续的两个ThreadLocal实例的threadLocalHashCode值的增量。至于为什么用0x61c88647常量,我也没明白,只百度出一堆跟加密算法有关的内容。nextHashCode变量是AtomicInteger类型,AtomicInteger是一个提供原子操作的Integer类,简单而言就是不用Synchronized关键字也可以进行线程安全的加减操作。

当创建一个ThreadLocal对象时,会进行以下操作:

Java代码


public ThreadLocal() {

}

private final int threadLocalHashCode = nextHashCode();

我们再来看nextHashCode()这个方法,代码如下:

Java代码


/**

* Returns the next hash code.

*/

private static int nextHashCode() {

return nextHashCode.getAndAdd(HASH_INCREMENT);

}

它是返回一个hash code,同时对nextHashCode加上增量HASH_INCREMENT。也就是在初始化ThreadLocal的实例过程中,做的仅仅是将一个哈希值赋给实例的threadLocalHashCode,并生成下一个哈希值。而前面可以看到threadLocalHashCode是final的,在个类中,它用来区分不同的ThreadLocal对象。

在前面的图当中,我们可以看到在ThreadLocal里还有一个ThreadLocalMap的内部类,从类名来看应该是被设计来保存线程局部变量的Map,但是在ThreadLocal类或它的实例当中,并没有其他变量,那么通过ThreadLocal.set()放进去的值又是怎么保存的呢?

我们继续看ThreadLocal的set()和其他方法:

Java代码


/**

* Sets the current thread's copy of this thread-local variable

* to the specified value. Most subclasses will have no need to

* override this method, relying solely on the {@link #initialValue}

* method to set the values of thread-locals.

*

* @param value the value to be stored in the current thread's copy of

* this thread-local.

*/

public void set(T value) {

Thread t = Thread.currentThread();

ThreadLocalMap map = getMap(t);

if (map != null)

map.set(this, value);

else

createMap(t, value);

}

/**

* Get the map associated with a ThreadLocal. Overridden in

* InheritableThreadLocal.

*

* @param t the current thread

* @return the map

*/

ThreadLocalMap getMap(Thread t) {

return t.threadLocals;

}

/**

* Create the map associated with a ThreadLocal. Overridden in

* InheritableThreadLocal.

*

* @param t the current thread

* @param firstValue value for the initial entry of the map

* @param map the map to store.

*/

void createMap(Thread t, T firstValue) {

t.threadLocals = new ThreadLocalMap(this, firstValue);

}

从上面的代码可以看到,set()方法是调用ThreadLocalMap.set()方法保存到ThreadLocalMap实例当中,但是ThreadLocalMap实例却不是保存在Thread中,它通过getMap()方法去取,当取不到的时候就去创建,但是创建之后却是保存在Thread实例中。可以看到Thread.java的代码有如下声明:

Java代码


ThreadLocal.ThreadLocalMap threadLocals = null;

这种设计很巧妙,线程中有各自的map,而把ThreadLocal实例作为key,这样除了当线程销毁时相关的线程局部变量被销毁之外,还让性能提升了很多。

看完set()方法,再来看get()方法,它的代码如下:

Java代码


/**

* Returns the value in the current thread's copy of this

* thread-local variable. If the variable has no value for the

* current thread, it is first initialized to the value returned

* by an invocation of the {@link #initialValue} method.

*

* @return the current thread's value of this thread-local

*/

public T get() {

Thread t = Thread.currentThread();

ThreadLocalMap map = getMap(t);

if (map != null) {

ThreadLocalMap.Entry e = map.getEntry(this);

if (e != null)

return (T)e.value;

}

return setInitialValue();

}

/**

* Variant of set() to establish initialValue. Used instead

* of set() in case user has overridden the set() method.

*

* @return the initial value

*/

private T setInitialValue() {

T value = initialValue();

Thread t = Thread.currentThread();

ThreadLocalMap map = getMap(t);

if (map != null)

map.set(this, value);

else

createMap(t, value);

return value;

}

/**

* Returns the current thread's "initial value" for this

* thread-local variable. This method will be invoked the first

* time a thread accesses the variable with the {@link #get}

* method, unless the thread previously invoked the {@link #set}

* method, in which case the <tt>initialValue</tt> method will not

* be invoked for the thread. Normally, this method is invoked at

* most once per thread, but it may be invoked again in case of

* subsequent invocations of {@link #remove} followed by {@link #get}.

*

* <p>This implementation simply returns <tt>null</tt>; if the

* programmer desires thread-local variables to have an initial

* value other than <tt>null</tt>, <tt>ThreadLocal</tt> must be

* subclassed, and this method overridden. Typically, an

* anonymous inner class will be used.

*

* @return the initial value for this thread-local

*/

protected T initialValue() {

return null;

}

可以看到在get()方法中,是通过当前线程获取map,再从map当中取得对象。如果取不到(如map为null或取到的值为null),则调用setInitialValue()方法对其设置初始值。setInitialValue()方法会调用initialValue()方法得到一个初始值(默认为null),然后当Thread中的map不为空时,把初始值设置进去,否则为它创建一个ThreadLocalMap对象(但不会调用map的set方法保存这个初始值),最后返回这个初始值。

最后看remove()方法,它提供了移除此线程局部变量在当前进程的值。代码如下:

Java代码


/**

* Removes the current thread's value for this thread-local

* variable. If this thread-local variable is subsequently

* {@linkplain #get read} by the current thread, its value will be

* reinitialized by invoking its {@link #initialValue} method,

* unless its value is {@linkplain #set set} by the current thread

* in the interim. This may result in multiple invocations of the

* <tt>initialValue</tt> method in the current thread.

*

* @since 1.5

*/

public void remove() {

ThreadLocalMap m = getMap(Thread.currentThread());

if (m != null)

m.remove(this);

}

关于ThreadLocal就读到这里,有时间再写下它的内部类ThreadLocalMap。

上篇讲到了ThreadLocal类(http://maosidiaoxian.iteye.com/blog/1939142),这篇继续讲ThreadLocal中的ThreadLocalMap内部类。

下面先通过一张图,看一下这个内部类的结构:



可以看到在ThreadLocalMap类中,有一个常量,三个成员变量,代码如下:

Java代码


/**

* The initial capacity -- MUST be a power of two.

*/

private static final int INITIAL_CAPACITY = 16;

/**

* The table, resized as necessary.

* table.length MUST always be a power of two.

*/

private Entry[] table;

/**

* The number of entries in the table.

*/

private int size = 0;

/**

* The next size value at which to resize.

*/

private int threshold; // Default to 0

一个map对象是有一个容量的,INITIAL_CAPACITY 常量表示这个Map的默认的初始化容量。

数组table是一个实体表,保存设置进去的对象,长度必须为2的n次方的值。它是一个Entry类型,Entry类是ThreadLocalMap的一个内部类,继承自弱引用类型WeakReference,定义如下:

Java代码


static class Entry extends WeakReference<ThreadLocal> {

/** The value associated with this ThreadLocal. */

Object value;

Entry(ThreadLocal k, Object v) {

super(k);

value = v;

}

}

size变量表示实体表里实体的大小,初始值为0;threshold表示表的阈值,默认为0,当实体表的保存的实体大于这个阈值时,就需要对实体表table调整大小了。

现在来看一下ThreadLocalMap实例化的时候做了些什么。

ThreadLcoalMap共有两个构造方法,代码如下:

Java代码


ThreadLocalMap(ThreadLocal firstKey, Object firstValue) {

table = new Entry[INITIAL_CAPACITY];

int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);

table[i] = new Entry(firstKey, firstValue);

size = 1;

setThreshold(INITIAL_CAPACITY);

}

/**

* Construct a new map including all Inheritable ThreadLocals

* from given parent map. Called only by createInheritedMap.

*

* @param parentMap the map associated with parent thread.

*/

private ThreadLocalMap(ThreadLocalMap parentMap) {

Entry[] parentTable = parentMap.table;

int len = parentTable.length;

setThreshold(len);

table = new Entry[len];

for (int j = 0; j < len; j++) {

Entry e = parentTable[j];

if (e != null) {

ThreadLocal key = e.get();

if (key != null) {

Object value = key.childValue(e.value);

Entry c = new Entry(key, value);

int h = key.threadLocalHashCode & (len - 1);

while (table[h] != null)

h = nextIndex(h, len);

table[h] = c;

size++;

}

}

}

}

这两个构造方法,一个为default,一个为private。在上一篇文章提到的,在ThreadLocal里的set方法获取不到这个map,需要创建的时候,它调用ThreadLcoal里的createMap()方法,而createMap()方法则调用default的这个构造方法。而另一个构造方法是在ThreadLocal里的createInheritedMap()中调用的,这个方法会由Thread里的构造方法来调用(在Thread类里的init()方法来调用,而init()方法只被Thread里的构造方法调用),我们在这里就不看它了。

分析一下ThreadLocalMap里带两个参数的这个构造方法,这两个参数为创建这个ThreadLocalMap对象之后的第一对键值对。在构造方法里,先对table进行初始化,默认大小为INITIAL_CAPACITY。然后计算第一对键值对要存放的位置,即在table中的下标,它的代码如下:

Java代码


int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);

我们知道INITIAL_CAPACITY 是一个2的n次幂的值,在这里它为16,即2的4次方,那么INITIAL_CAPACITY - 1 用16进度表示也就是0xF。firstKey.threadLocalHashCode与它作位的与运算,也就是取了threadLocal的哈希值的低4位(当table扩大时,取的值的范围也会跟着增加,但肯定是不大于table的长度的,这一点在ThreadLocalMap里的set方法可以体现出来,而table的长度必须为2的n次也是这种计算方法的前提条件),并将其作为在表中的下标(table的长度也是)。然后为这个键值对在数组相应的下标下创建一个Entry对象,最后设置阈值。在这里它的阈值为table长度的2/3。

在计算下标的时候,从前面的计算方法中我们可以想到一种情况,即两个ThreadLocal对象虽然他们哈希值不同,但是通过与运算计算出的下标正好相同。对这种情况,它会计算下一个坐标,而下一个坐标则通过当前坐标+1的方式取得,在nextIndex()方法可以看到,代码如下(顺便贴一下prevIndex()方法):

Java代码


/**

* Increment i modulo len.

*/

private static int nextIndex(int i, int len) {

return ((i + 1 < len) ? i + 1 : 0);

}

/**

* Decrement i modulo len.

*/

private static int prevIndex(int i, int len) {

return ((i - 1 >= 0) ? i - 1 : len - 1);

}

下面我们再来看一下set()方法和get()方法。

先看set()方法,相关代码如下:

Java代码


/**

* Set the value associated with key.

*

* @param key the thread local object

* @param value the value to be set

*/

private void set(ThreadLocal key, Object value) {

// We don't use a fast path as with get() because it is at

// least as common to use set() to create new entries as

// it is to replace existing ones, in which case, a fast

// path would fail more often than not.

Entry[] tab = table;

int len = tab.length;

int i = key.threadLocalHashCode & (len-1);

for (Entry e = tab[i];

e != null;

e = tab[i = nextIndex(i, len)]) {

ThreadLocal k = e.get();

if (k == key) {

e.value = value;

return;

}

if (k == null) {

replaceStaleEntry(key, value, i);

return;

}

}

tab[i] = new Entry(key, value);

int sz = ++size;

if (!cleanSomeSlots(i, sz) && sz >= threshold)

rehash();

}

/**

* Replace a stale entry encountered during a set operation

* with an entry for the specified key. The value passed in

* the value parameter is stored in the entry, whether or not

* an entry already exists for the specified key.

*

* As a side effect, this method expunges all stale entries in the

* "run" containing the stale entry. (A run is a sequence of entries

* between two null slots.)

*

* @param key the key

* @param value the value to be associated with key

* @param staleSlot index of the first stale entry encountered while

* searching for key.

*/

private void replaceStaleEntry(ThreadLocal key, Object value,

int staleSlot) {

Entry[] tab = table;

int len = tab.length;

Entry e;

// Back up to check for prior stale entry in current run.

// We clean out whole runs at a time to avoid continual

// incremental rehashing due to garbage collector freeing

// up refs in bunches (i.e., whenever the collector runs).

int slotToExpunge = staleSlot;

for (int i = prevIndex(staleSlot, len);

(e = tab[i]) != null;

i = prevIndex(i, len))

if (e.get() == null)

slotToExpunge = i;

// Find either the key or trailing null slot of run, whichever

// occurs first

for (int i = nextIndex(staleSlot, len);

(e = tab[i]) != null;

i = nextIndex(i, len)) {

ThreadLocal k = e.get();

// If we find key, then we need to swap it

// with the stale entry to maintain hash table order.

// The newly stale slot, or any other stale slot

// encountered above it, can then be sent to expungeStaleEntry

// to remove or rehash all of the other entries in run.

if (k == key) {

e.value = value;

tab[i] = tab[staleSlot];

tab[staleSlot] = e;

// Start expunge at preceding stale entry if it exists

if (slotToExpunge == staleSlot)

slotToExpunge = i;

cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);

return;

}

// If we didn't find stale entry on backward scan, the

// first stale entry seen while scanning for key is the

// first still present in the run.

if (k == null && slotToExpunge == staleSlot)

slotToExpunge = i;

}

// If key not found, put new entry in stale slot

tab[staleSlot].value = null;

tab[staleSlot] = new Entry(key, value);

// If there are any other stale entries in run, expunge them

if (slotToExpunge != staleSlot)

cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);

}

在set()方法中,首先通过key(threadLocal)中的hashc值与table的总长度取模,然后根据取模后的值作为下标,找到table当中下标为此值的entry,判断该entry是否存在。如果存在,判断entry里的key与value如果与当前要保存的key与value相同的话就不保存直接返回。如果entry里的key为null的话,就替换为当前要保存的key与value。否则就是碰撞的情况了,这时就调用nextIndex()方法计算下一个坐标。

如果计算后的坐标获取到的entry为null,就new一个Entry对象并保存进去,然后调用cleanSomeSlots()对table进行清理,如果没有任何Entry被清理,并且表的size超过了阈值,就会调用rehash()方法。

cleanSomeSlots()会调用expungeStaleEntry清理陈旧过时的Entry。rehash则会调用expungeStaleEntries()方法清理所有的陈旧的Entry,然后在size大于阈值的3/4时调用resize()方法进行扩容。代码如下:

Java代码


/**

* Expunge a stale entry by rehashing any possibly colliding entries

* lying between staleSlot and the next null slot. This also expunges

* any other stale entries encountered before the trailing null. See

* Knuth, Section 6.4

*

* @param staleSlot index of slot known to have null key

* @return the index of the next null slot after staleSlot

* (all between staleSlot and this slot will have been checked

* for expunging).

*/

private int expungeStaleEntry(int staleSlot) {

Entry[] tab = table;

int len = tab.length;

// expunge entry at staleSlot

tab[staleSlot].value = null;

tab[staleSlot] = null;

size--;

// Rehash until we encounter null

Entry e;

int i;

for (i = nextIndex(staleSlot, len);

(e = tab[i]) != null;

i = nextIndex(i, len)) {

ThreadLocal k = e.get();

if (k == null) {

e.value = null;

tab[i] = null;

size--;

} else {

int h = k.threadLocalHashCode & (len - 1);

if (h != i) {

tab[i] = null;

// Unlike Knuth 6.4 Algorithm R, we must scan until

// null because multiple entries could have been stale.

while (tab[h] != null)

h = nextIndex(h, len);

tab[h] = e;

}

}

}

return i;

}

/**

* Heuristically scan some cells looking for stale entries.

* This is invoked when either a new element is added, or

* another stale one has been expunged. It performs a

* logarithmic number of scans, as a balance between no

* scanning (fast but retains garbage) and a number of scans

* proportional to number of elements, that would find all

* garbage but would cause some insertions to take O(n) time.

*

* @param i a position known NOT to hold a stale entry. The

* scan starts at the element after i.

*

* @param n scan control: <tt>log2(n)</tt> cells are scanned,

* unless a stale entry is found, in which case

* <tt>log2(table.length)-1</tt> additional cells are scanned.

* When called from insertions, this parameter is the number

* of elements, but when from replaceStaleEntry, it is the

* table length. (Note: all this could be changed to be either

* more or less aggressive by weighting n instead of just

* using straight log n. But this version is simple, fast, and

* seems to work well.)

*

* @return true if any stale entries have been removed.

*/

private boolean cleanSomeSlots(int i, int n) {

boolean removed = false;

Entry[] tab = table;

int len = tab.length;

do {

i = nextIndex(i, len);

Entry e = tab[i];

if (e != null && e.get() == null) {

n = len;

removed = true;

i = expungeStaleEntry(i);

}

} while ( (n >>>= 1) != 0);

return removed;

}

/**

* Re-pack and/or re-size the table. First scan the entire

* table removing stale entries. If this doesn't sufficiently

* shrink the size of the table, double the table size.

*/

private void rehash() {

expungeStaleEntries();

// Use lower threshold for doubling to avoid hysteresis

if (size >= threshold - threshold / 4)

resize();

}

/**

* Double the capacity of the table.

*/

private void resize() {

Entry[] oldTab = table;

int oldLen = oldTab.length;

int newLen = oldLen * 2;

Entry[] newTab = new Entry[newLen];

int count = 0;

for (int j = 0; j < oldLen; ++j) {

Entry e = oldTab[j];

if (e != null) {

ThreadLocal k = e.get();

if (k == null) {

e.value = null; // Help the GC

} else {

int h = k.threadLocalHashCode & (newLen - 1);

while (newTab[h] != null)

h = nextIndex(h, newLen);

newTab[h] = e;

count++;

}

}

}

setThreshold(newLen);

size = count;

table = newTab;

}

/**

* Expunge all stale entries in the table.

*/

private void expungeStaleEntries() {

Entry[] tab = table;

int len = tab.length;

for (int j = 0; j < len; j++) {

Entry e = tab[j];

if (e != null && e.get() == null)

expungeStaleEntry(j);

}

}

再下来是get()方法。

Java代码


/**

* Get the entry associated with key. This method

* itself handles only the fast path: a direct hit of existing

* key. It otherwise relays to getEntryAfterMiss. This is

* designed to maximize performance for direct hits, in part

* by making this method readily inlinable.

*

* @param key the thread local object

* @return the entry associated with key, or null if no such

*/

private Entry getEntry(ThreadLocal key) {

int i = key.threadLocalHashCode & (table.length - 1);

Entry e = table[i];

if (e != null && e.get() == key)

return e;

else

return getEntryAfterMiss(key, i, e);

}

/**

* Version of getEntry method for use when key is not found in

* its direct hash slot.

*

* @param key the thread local object

* @param i the table index for key's hash code

* @param e the entry at table[i]

* @return the entry associated with key, or null if no such

*/

private Entry getEntryAfterMiss(ThreadLocal key, int i, Entry e) {

Entry[] tab = table;

int len = tab.length;

while (e != null) {

ThreadLocal k = e.get();

if (k == key)

return e;

if (k == null)

expungeStaleEntry(i);

else

i = nextIndex(i, len);

e = tab[i];

}

return null;

}

getEntry()方法通过计算出的下标从table中取出entry,如果取得的entry为null或它的key值不相等,就调用getEntryAfterMiss()方法,否则返回。

而在getEntryAfterMiss()是当通过key与table的长度取模得到的下标取得entry后,entry里没有该key时所调用的。这时,如果获取的entry为null,即没有保存,就直接返回null,否则进入循环不,计算下一个坐标并获取对应的entry,并且当key相等时(表明找到了之前保存的值)返回entry,或是entry为null时退出循环,并返回null。

最后是remove()方法,这个就比较简单了,计算到下标后,如果取得的entry的key与ThreadLocal相同,就调用Entry的clear方法把弱引用设置为null,然后调用expungeStaleEntry对table进行清理。代码如下:

Java代码


/**

* Remove the entry for key.

*/

private void remove(ThreadLocal key) {

Entry[] tab = table;

int len = tab.length;

int i = key.threadLocalHashCode & (len-1);

for (Entry e = tab[i];

e != null;

e = tab[i = nextIndex(i, len)]) {

if (e.get() == key) {

e.clear();

expungeStaleEntry(i);

return;

}

}

}

总结 :



本文转自:http://maosidiaoxian.iteye.com/blog/1940093
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: