1 /*
2  * Copyright (c) 2012, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 import java.util.concurrent.atomic.AtomicInteger;
25 import java.util.function.Supplier;
26 import org.testng.annotations.Test;
27 import static org.testng.Assert.*;
28 
29 /**
30  * @test
31  * @run testng ThreadLocalSupplierTest
32  * @summary tests ThreadLocal.withInitial(<Supplier>).
33  * Adapted from java.lang.Basic functional test of ThreadLocal
34  *
35  * @author Jim Gish <jim.gish@oracle.com>
36  */
37 @Test
38 public class ThreadLocalSupplierTest {
39 
40     static final class IntegerSupplier implements Supplier<Integer> {
41 
42         private final AtomicInteger supply = new AtomicInteger(0);
43 
44         @Override
get()45         public Integer get() {
46             return supply.getAndIncrement();
47         }
48 
numCalls()49         public int numCalls() {
50             return supply.intValue();
51         }
52     }
53 
54     static IntegerSupplier theSupply = new IntegerSupplier();
55 
56     static final class MyThreadLocal extends ThreadLocal<Integer> {
57 
58         private final ThreadLocal<Integer> delegate;
59 
60         public volatile boolean everCalled;
61 
MyThreadLocal(Supplier<Integer> supplier)62         public MyThreadLocal(Supplier<Integer> supplier) {
63             delegate = ThreadLocal.<Integer>withInitial(supplier);
64         }
65 
66         @Override
get()67         public Integer get() {
68             return delegate.get();
69         }
70 
71         @Override
initialValue()72         protected synchronized Integer initialValue() {
73             // this should never be called since we are using the factory instead
74             everCalled = true;
75             return null;
76         }
77     }
78 
79     /**
80      * Our one and only ThreadLocal from which we get thread ids using a
81      * supplier which simply increments a counter on each call of get().
82      */
83     static MyThreadLocal threadLocal = new MyThreadLocal(theSupply);
84 
testMultiThread()85     public void testMultiThread() throws Exception {
86         final int threadCount = 500;
87         final Thread th[] = new Thread[threadCount];
88         final boolean visited[] = new boolean[threadCount];
89 
90         // Create and start the threads
91         for (int i = 0; i < threadCount; i++) {
92             th[i] = new Thread() {
93                 @Override
94                 public void run() {
95                     final int threadId = threadLocal.get();
96                     assertFalse(visited[threadId], "visited[" + threadId + "]=" + visited[threadId]);
97                     visited[threadId] = true;
98                     // check the get() again
99                     final int secondCheckThreadId = threadLocal.get();
100                     assertEquals(secondCheckThreadId, threadId);
101                 }
102             };
103             th[i].start();
104         }
105 
106         // Wait for the threads to finish
107         for (int i = 0; i < threadCount; i++) {
108             th[i].join();
109         }
110 
111         assertEquals(theSupply.numCalls(), threadCount);
112         // make sure the provided initialValue() has not been called
113         assertFalse(threadLocal.everCalled);
114         // Check results
115         for (int i = 0; i < threadCount; i++) {
116             assertTrue(visited[i], "visited[" + i + "]=" + visited[i]);
117         }
118     }
119 
testSimple()120     public void testSimple() {
121         final String expected = "OneWithEverything";
122         final ThreadLocal<String> threadLocal = ThreadLocal.<String>withInitial(() -> expected);
123         assertEquals(expected, threadLocal.get());
124     }
125 }
126