
How ThreadLocal Works

xw · About 9 min · JAVA

Concept

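ThreadLocal provides thread-local variables. Each thread that accesses a ThreadLocal works with its own independent copy of the value, so values are isolated between threads.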

Basic Usage

Example:

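import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class ThreadLocalTest
{
    static class Foo
    {
        //total number of instances created
        static final AtomicInteger AMOUNT = new AtomicInteger(0);
        //the object's sequence number
        int index = 0;
        //the object's payload
        int bar = 10;

        public Foo()
        {
            index = AMOUNT.incrementAndGet(); //increment the total and number this object
        }

        public int getBar()
        {
            return bar;
        }

        public void setBar(int bar)
        {
            this.bar = bar;
        }

        @Override
        public String toString()
        {
            return index + "@Foo{bar=" + bar + '}';
        }
    }

    //define the thread-local variable
    private static final ThreadLocal<Foo> LOCAL_FOO = new ThreadLocal<Foo>();

    public static void main(String[] args) throws InterruptedException
    {
        //a fixed-size thread pool with 10 threads
        ExecutorService threadPool = Executors.newFixedThreadPool(10);

        //submit 5 tasks; each one starts a new pool thread (5 threads in total)
        for (int i = 0; i < 5; i++)
        {
            threadPool.execute(new Runnable()
            {
                @Override
                public void run()
                {
                    try
                    {
                        //get the value bound to the current thread
                        if (LOCAL_FOO.get() == null)
                        {
                            //bind a value to the current thread
                            LOCAL_FOO.set(new Foo());
                        }
                        System.out.println("Initial local value: " + LOCAL_FOO.get());
                        //each thread increments 10 times
                        for (int i = 0; i < 10; i++)
                        {
                            Foo foo = LOCAL_FOO.get();
                            foo.setBar(foo.getBar() + 1);  //increment by 1
                            Thread.sleep(1);
                        }
                        System.out.println("Local value after 10 increments: " + LOCAL_FOO.get());
                    }
                    catch (InterruptedException e)
                    {
                        Thread.currentThread().interrupt();
                    }
                    finally
                    {
                        //remove the value bound to the current thread;
                        //this matters especially for pooled threads, which are reused
                        LOCAL_FOO.remove();
                    }
                }
            });
        }
        //stop accepting tasks and wait for the submitted ones to finish
        threadPool.shutdown();
        threadPool.awaitTermination(1, TimeUnit.MINUTES);
    }
}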

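When run, every thread works on its own Foo object. Each thread prints a different Foo (numbered 1 to 5, in nondeterministic order) that starts at bar=10 and ends at bar=20. No thread ever sees another thread's increments.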

Source Code Analysis

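In JDK 8, every Thread carries its own map: a ThreadLocalMap stored in the thread's threadLocals field. A thread may use several ThreadLocal instances to store local data. In that case its ThreadLocalMap holds several key-value pairs, each with a ThreadLocal instance as the key and that thread's local data as the value.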
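Here is a toy model of this layout. Everything in it (MiniThread, MiniThreadLocal, the locals field) is hypothetical and deliberately simplified. It uses a plain HashMap with strong keys, whereas the real ThreadLocalMap is a custom open-addressing table with weakly referenced keys, as the source below shows. What it keeps from the real design is that the map belongs to the thread and the thread-local object only serves as a key.

```java
import java.util.HashMap;
import java.util.Map;

// Toy model of the JDK 8 layout (all names here are hypothetical, not JDK code):
// the map lives inside the thread, and the thread-local object is only a key.
public class MiniThreadLocalDemo {

    // Stand-in for Thread, carrying its own map like Thread.threadLocals.
    static class MiniThread extends Thread {
        final Map<MiniThreadLocal<?>, Object> locals = new HashMap<>();

        MiniThread(Runnable task) {
            super(task);
        }
    }

    // Stand-in for ThreadLocal: it holds no values itself.
    static class MiniThreadLocal<T> {
        // Plays the role of getMap(t): the map owned by the calling thread.
        private static Map<MiniThreadLocal<?>, Object> currentMap() {
            return ((MiniThread) Thread.currentThread()).locals;
        }

        void set(T value) {
            currentMap().put(this, value);   // this object is the key
        }

        @SuppressWarnings("unchecked")
        T get() {
            return (T) currentMap().get(this);
        }
    }

    static final MiniThreadLocal<String> NAME = new MiniThreadLocal<>();
    static final MiniThreadLocal<Integer> COUNT = new MiniThreadLocal<>();

    // Starts two MiniThreads; each reports "name:count:entriesInItsOwnMap".
    static String[] run() {
        String[] results = new String[2];
        MiniThread[] threads = new MiniThread[2];
        for (int i = 0; i < threads.length; i++) {
            final int slot = i;
            threads[i] = new MiniThread(() -> {
                NAME.set("worker-" + slot);   // two thread-local objects ...
                COUNT.set(slot * 10);         // ... give two entries in this thread's map
                int entries = ((MiniThread) Thread.currentThread()).locals.size();
                results[slot] = NAME.get() + ":" + COUNT.get() + ":" + entries;
            });
            threads[i].start();
        }
        try {
            for (MiniThread t : threads) {
                t.join();   // join() also makes the workers' writes visible here
            }
        } catch (InterruptedException e) {
            throw new IllegalStateException(e);
        }
        return results;
    }

    public static void main(String[] args) {
        for (String r : run()) {
            System.out.println(r);   // worker-0:0:2, then worker-1:10:2
        }
    }
}
```

Both threads use the same two key objects (NAME and COUNT), yet each thread sees only the values in its own map.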

  • set

     public void set(T value) {
         //get the current thread
         Thread t = Thread.currentThread();
         //get that thread's ThreadLocalMap
         ThreadLocalMap map = getMap(t);
         if (map != null)
             map.set(this, value);
         else
             //no map yet: create one holding this first entry
             createMap(t, value);
     }
    
  • get

     public T get() {
         Thread t = Thread.currentThread();
         ThreadLocalMap map = getMap(t);
         if (map != null) {
             //the key is this ThreadLocal object
             ThreadLocalMap.Entry e = map.getEntry(this);
             if (e != null) {
                 @SuppressWarnings("unchecked")
                 T result = (T)e.value;
                 return result;
             }
         }
         //no map or no entry yet: fall back to initialValue() and store the result
         return setInitialValue();
     }
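The two methods above can be checked from the outside with the public API only. The sketch below (class and field names are illustrative) shows three things. A value set in one thread is invisible to another, where get() returns null. A ThreadLocal built with ThreadLocal.withInitial(...), which overrides initialValue(), runs its supplier lazily, once per thread, on the first get() that finds no entry. The supplier is not called again once the entry exists.

```java
import java.util.concurrent.atomic.AtomicInteger;

// Demonstrates set()/get() isolation between threads, plus the fallback that
// get() takes via setInitialValue() when the current thread has no entry yet.
public class ThreadLocalGetSetDemo {

    // No initial value supplied: get() returns null until set() in that thread.
    static final ThreadLocal<String> PLAIN = new ThreadLocal<>();

    // Counts how often the initial-value supplier runs.
    static final AtomicInteger INIT_CALLS = new AtomicInteger();

    // withInitial() overrides initialValue(); the supplier runs lazily, once per
    // thread, on that thread's first get() (unless set() was called first).
    static final ThreadLocal<StringBuilder> BUFFER = ThreadLocal.withInitial(() -> {
        INIT_CALLS.incrementAndGet();
        return new StringBuilder();
    });

    // Returns "mainPlain|mainBuffer|workerPlain|[workerBuffer]|supplierCalls".
    static String run() {
        INIT_CALLS.set(0);
        try {
            PLAIN.set("main-value");
            BUFFER.get().append("main");   // first get() in main: supplier runs
            BUFFER.get().append("!");      // entry now exists: supplier not called again

            String[] seen = new String[2];
            Thread worker = new Thread(() -> {
                seen[0] = PLAIN.get();              // null: never set in this thread
                seen[1] = BUFFER.get().toString();  // fresh, empty builder for this thread
            });
            worker.start();
            try {
                worker.join();
            } catch (InterruptedException e) {
                throw new IllegalStateException(e);
            }

            return PLAIN.get() + "|" + BUFFER.get() + "|" + seen[0]
                    + "|[" + seen[1] + "]|" + INIT_CALLS.get();
        } finally {
            PLAIN.remove();    // leave the main thread clean for the next call
            BUFFER.remove();
        }
    }

    public static void main(String[] args) {
        System.out.println(run());   // main-value|main!|null|[]|2
    }
}
```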
    

    ThreadLocal is therefore only an entry point for operating on the current thread's ThreadLocalMap. The source of ThreadLocalMap is as follows:

     static class ThreadLocalMap {
    
            /**
             * The entries in this hash map extend WeakReference, using
             * its main ref field as the key (which is always a
             * ThreadLocal object).  Note that null keys (i.e. entry.get()
             * == null) mean that the key is no longer referenced, so the
             * entry can be expunged from table.  Such entries are referred to
             * as "stale entries" in the code that follows.
             */
            static class Entry extends WeakReference<ThreadLocal<?>> {
                /** The value associated with this ThreadLocal. */
                Object value;
    
                Entry(ThreadLocal<?> k, Object v) {
                    super(k);
                    value = v;
                }
            }
    
            /**
             * The initial capacity -- MUST be a power of two.
             */
            private static final int INITIAL_CAPACITY = 16;
    
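            // The entry array, used as the hash table (its length is always a power of two)
            private Entry[] table;
    
            // Number of entries in the table
            private int size = 0;
    
            // Resize threshold: the size at which to rehash (setThreshold makes it 2/3 of the length)
            private int threshold; // Default to 0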
    
            /**
             * Set the resize threshold to maintain at worst a 2/3 load factor.
             */
            private void setThreshold(int len) {
                threshold = len * 2 / 3;
            }
    
            /**
             * Increment i modulo len.
             */
            private static int nextIndex(int i, int len) {
                return ((i + 1 < len) ? i + 1 : 0);
            }
    
            /**
             * Decrement i modulo len.
             */
            private static int prevIndex(int i, int len) {
                return ((i - 1 >= 0) ? i - 1 : len - 1);
            }
    
            /**
             * Construct a new map initially containing (firstKey, firstValue).
             * ThreadLocalMaps are constructed lazily, so we only create
             * one when we have at least one entry to put in it.
             */
            ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
                table = new Entry[INITIAL_CAPACITY];
                int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
                table[i] = new Entry(firstKey, firstValue);
                size = 1;
                setThreshold(INITIAL_CAPACITY);
            }
    
            /**
             * Construct a new map including all Inheritable ThreadLocals
             * from given parent map. Called only by createInheritedMap.
             *
             * @param parentMap the map associated with parent thread.
             */
            private ThreadLocalMap(ThreadLocalMap parentMap) {
                Entry[] parentTable = parentMap.table;
                int len = parentTable.length;
                setThreshold(len);
                table = new Entry[len];
    
                for (int j = 0; j < len; j++) {
                    Entry e = parentTable[j];
                    if (e != null) {
                        @SuppressWarnings("unchecked")
                        ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                        if (key != null) {
                            Object value = key.childValue(e.value);
                            Entry c = new Entry(key, value);
                            int h = key.threadLocalHashCode & (len - 1);
                            while (table[h] != null)
                                h = nextIndex(h, len);
                            table[h] = c;
                            size++;
                        }
                    }
                }
            }
    
            /**
             * Get the entry associated with key.  This method
             * itself handles only the fast path: a direct hit of existing
             * key. It otherwise relays to getEntryAfterMiss.  This is
             * designed to maximize performance for direct hits, in part
             * by making this method readily inlinable.
             *
             * @param  key the thread local object
             * @return the entry associated with key, or null if no such
             */
            private Entry getEntry(ThreadLocal<?> key) {
                int i = key.threadLocalHashCode & (table.length - 1);
                Entry e = table[i];
                if (e != null && e.get() == key)
                    return e;
                else
                    return getEntryAfterMiss(key, i, e);
            }
    
            /**
             * Version of getEntry method for use when key is not found in
             * its direct hash slot.
             *
             * @param  key the thread local object
             * @param  i the table index for key's hash code
             * @param  e the entry at table[i]
             * @return the entry associated with key, or null if no such
             */
            private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
                Entry[] tab = table;
                int len = tab.length;
    
                while (e != null) {
                    ThreadLocal<?> k = e.get();
                    if (k == key)
                        return e;
                    if (k == null)
                        expungeStaleEntry(i);
                    else
                        i = nextIndex(i, len);
                    e = tab[i];
                }
                return null;
            }
    
            /**
             * Set the value associated with key.
             *
             * @param key the thread local object
             * @param value the value to be set
             */
            private void set(ThreadLocal<?> key, Object value) {
    
                // We don't use a fast path as with get() because it is at
                // least as common to use set() to create new entries as
                // it is to replace existing ones, in which case, a fast
                // path would fail more often than not.
    
                Entry[] tab = table;
                int len = tab.length;
                int i = key.threadLocalHashCode & (len-1);
    
                for (Entry e = tab[i];
                     e != null;
                     e = tab[i = nextIndex(i, len)]) {
                    ThreadLocal<?> k = e.get();
    
                    if (k == key) {
                        e.value = value;
                        return;
                    }
    
                    if (k == null) {
                        replaceStaleEntry(key, value, i);
                        return;
                    }
                }
    
                tab[i] = new Entry(key, value);
                int sz = ++size;
                if (!cleanSomeSlots(i, sz) && sz >= threshold)
                    rehash();
            }
    
            /**
             * Remove the entry for key.
             */
            private void remove(ThreadLocal<?> key) {
                Entry[] tab = table;
                int len = tab.length;
                int i = key.threadLocalHashCode & (len-1);
                for (Entry e = tab[i];
                     e != null;
                     e = tab[i = nextIndex(i, len)]) {
                    if (e.get() == key) {
                        e.clear();
                        expungeStaleEntry(i);
                        return;
                    }
                }
            }
    
            /**
             * Replace a stale entry encountered during a set operation
             * with an entry for the specified key.  The value passed in
             * the value parameter is stored in the entry, whether or not
             * an entry already exists for the specified key.
             *
             * As a side effect, this method expunges all stale entries in the
             * "run" containing the stale entry.  (A run is a sequence of entries
             * between two null slots.)
             *
             * @param  key the key
             * @param  value the value to be associated with key
             * @param  staleSlot index of the first stale entry encountered while
             *         searching for key.
             */
            private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                           int staleSlot) {
                Entry[] tab = table;
                int len = tab.length;
                Entry e;
    
                // Back up to check for prior stale entry in current run.
                // We clean out whole runs at a time to avoid continual
                // incremental rehashing due to garbage collector freeing
                // up refs in bunches (i.e., whenever the collector runs).
                int slotToExpunge = staleSlot;
                for (int i = prevIndex(staleSlot, len);
                     (e = tab[i]) != null;
                     i = prevIndex(i, len))
                    if (e.get() == null)
                        slotToExpunge = i;
    
                // Find either the key or trailing null slot of run, whichever
                // occurs first
                for (int i = nextIndex(staleSlot, len);
                     (e = tab[i]) != null;
                     i = nextIndex(i, len)) {
                    ThreadLocal<?> k = e.get();
    
                    // If we find key, then we need to swap it
                    // with the stale entry to maintain hash table order.
                    // The newly stale slot, or any other stale slot
                    // encountered above it, can then be sent to expungeStaleEntry
                    // to remove or rehash all of the other entries in run.
                    if (k == key) {
                        e.value = value;
    
                        tab[i] = tab[staleSlot];
                        tab[staleSlot] = e;
    
                        // Start expunge at preceding stale entry if it exists
                        if (slotToExpunge == staleSlot)
                            slotToExpunge = i;
                        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                        return;
                    }
    
                    // If we didn't find stale entry on backward scan, the
                    // first stale entry seen while scanning for key is the
                    // first still present in the run.
                    if (k == null && slotToExpunge == staleSlot)
                        slotToExpunge = i;
                }
    
                // If key not found, put new entry in stale slot
                tab[staleSlot].value = null;
                tab[staleSlot] = new Entry(key, value);
    
                // If there are any other stale entries in run, expunge them
                if (slotToExpunge != staleSlot)
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            }
    
            /**
             * Expunge a stale entry by rehashing any possibly colliding entries
             * lying between staleSlot and the next null slot.  This also expunges
             * any other stale entries encountered before the trailing null.  See
             * Knuth, Section 6.4
             *
             * @param staleSlot index of slot known to have null key
             * @return the index of the next null slot after staleSlot
             * (all between staleSlot and this slot will have been checked
             * for expunging).
             */
            private int expungeStaleEntry(int staleSlot) {
                Entry[] tab = table;
                int len = tab.length;
    
                // expunge entry at staleSlot
                tab[staleSlot].value = null;
                tab[staleSlot] = null;
                size--;
    
                // Rehash until we encounter null
                Entry e;
                int i;
                for (i = nextIndex(staleSlot, len);
                     (e = tab[i]) != null;
                     i = nextIndex(i, len)) {
                    ThreadLocal<?> k = e.get();
                    if (k == null) {
                        e.value = null;
                        tab[i] = null;
                        size--;
                    } else {
                        int h = k.threadLocalHashCode & (len - 1);
                        if (h != i) {
                            tab[i] = null;
    
                            // Unlike Knuth 6.4 Algorithm R, we must scan until
                            // null because multiple entries could have been stale.
                            while (tab[h] != null)
                                h = nextIndex(h, len);
                            tab[h] = e;
                        }
                    }
                }
                return i;
            }
    
            /**
             * Heuristically scan some cells looking for stale entries.
             * This is invoked when either a new element is added, or
             * another stale one has been expunged. It performs a
             * logarithmic number of scans, as a balance between no
             * scanning (fast but retains garbage) and a number of scans
             * proportional to number of elements, that would find all
             * garbage but would cause some insertions to take O(n) time.
             *
             * @param i a position known NOT to hold a stale entry. The
             * scan starts at the element after i.
             *
             * @param n scan control: {@code log2(n)} cells are scanned,
             * unless a stale entry is found, in which case
             * {@code log2(table.length)-1} additional cells are scanned.
             * When called from insertions, this parameter is the number
             * of elements, but when from replaceStaleEntry, it is the
             * table length. (Note: all this could be changed to be either
             * more or less aggressive by weighting n instead of just
             * using straight log n. But this version is simple, fast, and
             * seems to work well.)
             *
             * @return true if any stale entries have been removed.
             */
            private boolean cleanSomeSlots(int i, int n) {
                boolean removed = false;
                Entry[] tab = table;
                int len = tab.length;
                do {
                    i = nextIndex(i, len);
                    Entry e = tab[i];
                    if (e != null && e.get() == null) {
                        n = len;
                        removed = true;
                        i = expungeStaleEntry(i);
                    }
                } while ( (n >>>= 1) != 0);
                return removed;
            }
    
            /**
             * Re-pack and/or re-size the table. First scan the entire
             * table removing stale entries. If this doesn't sufficiently
             * shrink the size of the table, double the table size.
             */
            private void rehash() {
                expungeStaleEntries();
    
                // Use lower threshold for doubling to avoid hysteresis
                if (size >= threshold - threshold / 4)
                    resize();
            }
    
            /**
             * Double the capacity of the table.
             */
            private void resize() {
                Entry[] oldTab = table;
                int oldLen = oldTab.length;
                int newLen = oldLen * 2;
                Entry[] newTab = new Entry[newLen];
                int count = 0;
    
                for (int j = 0; j < oldLen; ++j) {
                    Entry e = oldTab[j];
                    if (e != null) {
                        ThreadLocal<?> k = e.get();
                        if (k == null) {
                            e.value = null; // Help the GC
                        } else {
                            int h = k.threadLocalHashCode & (newLen - 1);
                            while (newTab[h] != null)
                                h = nextIndex(h, newLen);
                            newTab[h] = e;
                            count++;
                        }
                    }
                }
    
                setThreshold(newLen);
                size = count;
                table = newTab;
            }
    
            /**
             * Expunge all stale entries in the table.
             */
            private void expungeStaleEntries() {
                Entry[] tab = table;
                int len = tab.length;
                for (int j = 0; j < len; j++) {
                    Entry e = tab[j];
                    if (e != null && e.get() == null)
                        expungeStaleEntry(j);
                }
            }
        }
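The listing above relies on threadLocalHashCode, which the excerpt does not show. In the JDK 8 ThreadLocal source, each new instance takes its hash from a shared AtomicInteger advanced by HASH_INCREMENT = 0x61c88647. The JDK comment says this spreads consecutive hashes near-optimally over power-of-two tables.

The sketch below reproduces that arithmetic together with the linear probing that set() performs through nextIndex(), using plain arrays. The class and the insert helper are hypothetical. Starting the counter at 0 is a simplification, since in a real JVM other ThreadLocals have already advanced it.

```java
// Sketch of how ThreadLocal hash codes map to ThreadLocalMap slots. Assumption: the
// counter starts at 0 here; in a real JVM other ThreadLocals have advanced it already.
public class ThreadLocalHashDemo {
    // JDK 8 ThreadLocal advances a shared counter by this constant per new instance.
    static final int HASH_INCREMENT = 0x61c88647;

    // Same logic as ThreadLocalMap.nextIndex: step forward, wrapping to 0.
    static int nextIndex(int i, int len) {
        return (i + 1 < len) ? i + 1 : 0;
    }

    // Slot for each of the first n hash codes in a table of length len (a power of two),
    // computed like getEntry(): hash & (len - 1).
    static int[] slots(int n, int len) {
        int[] result = new int[n];
        int hash = 0;
        for (int k = 0; k < n; k++) {
            result[k] = hash & (len - 1);
            hash += HASH_INCREMENT;   // int overflow wraps around, which is harmless here
        }
        return result;
    }

    // Linear probing as in ThreadLocalMap.set(): start at hash & (len - 1) and move on
    // with nextIndex() until a free slot appears. Assumes the table is never full
    // (the real map guarantees that by resizing at its threshold).
    static int insert(String[] table, int hash, String key) {
        int len = table.length;
        int i = hash & (len - 1);
        while (table[i] != null) {
            i = nextIndex(i, len);
        }
        table[i] = key;
        return i;
    }

    public static void main(String[] args) {
        // 16 consecutive hash codes land in 16 different slots of a 16-slot table:
        // [0, 7, 14, 5, 12, 3, 10, 1, 8, 15, 6, 13, 4, 11, 2, 9]
        System.out.println(java.util.Arrays.toString(slots(16, 16)));

        String[] table = new String[16];
        System.out.println(insert(table, 15, "a"));   // 15: home slot is free
        System.out.println(insert(table, 31, "b"));   // 0: home slot 15 taken, wraps around
    }
}
```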
    

Weak Reference Mechanism

Entry stores one key-value pair of the ThreadLocalMap. Its key is not the ThreadLocal instance itself but a weak reference (WeakReference) wrapping that instance:

          static class Entry extends WeakReference<ThreadLocal<?>> {
              /** The value associated with this ThreadLocal. */
              Object value;
  
              Entry(ThreadLocal<?> k, Object v) {
                  super(k);
                  value = v;
              }
          }
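Only the key is weak here; value is an ordinary strong field. The sketch below shows what that means for a "stale entry". It copies the Entry shape into a standalone class, which is hypothetical because the real Entry is private to ThreadLocalMap. It then calls WeakReference.clear() to stand in for the garbage collector, so the result is deterministic instead of depending on when a GC happens.

```java
import java.lang.ref.WeakReference;

// Hypothetical copy of the Entry shape above (not the JDK's private class), used to
// show a "stale entry": the weak key has been cleared but the value is still held.
public class StaleEntryDemo {

    static class Entry extends WeakReference<ThreadLocal<?>> {
        Object value;   // ordinary (strong) field, just like the real Entry

        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }

    static String run() {
        ThreadLocal<byte[]> key = new ThreadLocal<>();
        Entry e = new Entry(key, new byte[1024]);
        boolean keyPresent = e.get() != null;

        // Simulate what the GC does once the ThreadLocal is only weakly reachable.
        // clear() is a real WeakReference method; it keeps this demo deterministic.
        e.clear();
        boolean stale = e.get() == null;          // key gone: a stale entry
        boolean valueHeld = e.value != null;      // the value was NOT released

        // This is what expungeStaleEntry() does to such a slot.
        e.value = null;
        boolean released = e.value == null;

        return "keyPresent=" + keyPresent + " stale=" + stale
                + " valueHeld=" + valueHeld + " released=" + released;
    }

    public static void main(String[] args) {
        // keyPresent=true stale=true valueHeld=true released=true
        System.out.println(run());
    }
}
```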

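Weak reference: an object reachable only through weak references (Weak Reference) survives only until the next garbage collection. In other words, when a GC runs, an object that is only weakly reachable is reclaimed whether or not memory is short. An object that is still strongly reachable is not reclaimed.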
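The difference can be observed directly, with one caveat. System.gc() is only a request to the JVM, and on a JVM started with -XX:+DisableExplicitGC it does nothing. The sketch below (class and field names are illustrative) therefore retries and also allocates throwaway arrays to give the collector a reason to run. On a typical HotSpot JVM the weak-only object is gone after the first call. The strongly held object, kept reachable from a static field, can never be cleared while that field refers to it.

```java
import java.lang.ref.WeakReference;

// Contrasts an object reachable only through a WeakReference with one that is also
// strongly reachable (from a static field) across garbage collections.
public class WeakReferenceDemo {
    static Object strongHolder;   // a strong reference that keeps its object alive
    static byte[] sink;           // receives throwaway allocations

    // True if the weakly-only-reachable object gets collected. System.gc() is only a
    // request, so we retry and also allocate to give the collector a reason to run.
    static boolean weakOnlyIsCollected() {
        WeakReference<Object> weakOnly = new WeakReference<>(new Object());
        for (int attempt = 0; attempt < 100 && weakOnly.get() != null; attempt++) {
            System.gc();
            sink = new byte[1 << 20];
        }
        return weakOnly.get() == null;
    }

    // True if an object that is also strongly referenced survives the same treatment.
    static boolean stronglyHeldSurvives() {
        strongHolder = new Object();
        WeakReference<Object> weak = new WeakReference<>(strongHolder);
        for (int attempt = 0; attempt < 10; attempt++) {
            System.gc();
            sink = new byte[1 << 20];
        }
        boolean survived = weak.get() != null;
        strongHolder = null;   // drop the strong reference afterwards
        return survived;
    }

    public static void main(String[] args) {
        System.out.println("weak-only object collected: " + weakOnlyIsCollected());
        System.out.println("strongly held object survived: " + stronglyHeldSurvives());
    }
}
```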

Why wrap the ThreadLocal instance in a weak reference instead of using it directly as the key? Consider a simple case:

public void funcA()
{
    //create a thread-local variable
    ThreadLocal<Integer> local = new ThreadLocal<>();
    //set a value
    local.set(100);
    //get the value
    local.get();
    //end of the method
}

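When funcA() returns, the local variable goes out of scope and the ThreadLocal object it referred to should become garbage. Suppose ThreadLocalMap used the ThreadLocal instance itself as a strong key. The current thread's map would then keep that object reachable for as long as the thread lives, which for a pooled thread may be the lifetime of the application. It could never be reclaimed, and that is a memory leak.

With a weak key, the ThreadLocal object can be collected after funcA() returns, and its Entry is left with a null key (a "stale entry"). The value, however, is still strongly referenced through Entry.value. It is released only when one of two things happens:
  • a later set(), get() or remove() on the same thread's map runs into that slot and expunges it (expungeStaleEntry sets value to null);
  • the thread terminates.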

Summary

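  • Prefer declaring ThreadLocal instances as private static final. private and final keep others from modifying or reassigning the ThreadLocal reference. static ensures that the ThreadLocal instance is globally unique rather than one per instance of the enclosing class.
  • Always call remove() once you are done with a ThreadLocal, ideally in a finally block. It is the simplest and most effective way to avoid ThreadLocal-related memory leaks. With thread pools, it also prevents a value from leaking into the next task that reuses the same thread.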
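The remove() advice matters most with thread pools, whose threads outlive individual tasks. The sketch below checks it with a small experiment; the class name and CURRENT_USER are illustrative. It uses Executors.newSingleThreadExecutor(), so both tasks are guaranteed to run on the same worker thread. Without remove(), the second task inherits the value left behind by the first.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Shows why remove() matters with thread pools: a single-thread pool runs both
// tasks on the same worker, so a value left behind by task 1 is seen by task 2.
public class ThreadLocalPoolReuseDemo {
    private static final ThreadLocal<String> CURRENT_USER = new ThreadLocal<>();

    // Runs task 1 (which sets CURRENT_USER) and then task 2 on one pooled thread,
    // returning what task 2 reads ("null" if nothing was left behind).
    static String secondTaskSees(boolean removeAfterUse) {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            pool.submit(() -> {
                CURRENT_USER.set("alice");
                try {
                    System.out.println("task 1 works as " + CURRENT_USER.get());
                } finally {
                    if (removeAfterUse) {
                        CURRENT_USER.remove();   // clean up before the thread is reused
                    }
                }
            }).get();   // wait for task 1 to finish
            return pool.submit(() -> String.valueOf(CURRENT_USER.get())).get();
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println("without remove(): " + secondTaskSees(false));   // alice
        System.out.println("with remove():    " + secondTaskSees(true));    // null
    }
}
```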