1 /*
2  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
3  *
4  * This code is free software; you can redistribute it and/or modify it
5  * under the terms of the GNU General Public License version 2 only, as
6  * published by the Free Software Foundation.
7  *
8  * This code is distributed in the hope that it will be useful, but WITHOUT
9  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
10  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
11  * version 2 for more details (a copy is included in the LICENSE file that
12  * accompanied this code).
13  *
14  * You should have received a copy of the GNU General Public License version
15  * 2 along with this work; if not, write to the Free Software Foundation,
16  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
17  *
18  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
19  * or visit www.oracle.com if you need additional information or have any
20  * questions.
21  */
22 
23 /*
24  * This file is available under and governed by the GNU General Public
25  * License version 2 only, as published by the Free Software Foundation.
26  * However, the following notice accompanied the original version of this
27  * file:
28  *
29  * Written by Martin Buchholz with assistance from members of JCP
30  * JSR-166 Expert Group and released to the public domain, as
31  * explained at http://creativecommons.org/publicdomain/zero/1.0/
32  */
33 
34 /*
35  * @test
36  * @modules java.base/java.util.concurrent:open
37  * @run testng WhiteBox
38  * @summary White box tests of implementation details
39  */
40 
41 import static org.testng.Assert.*;
42 import org.testng.annotations.DataProvider;
43 import org.testng.annotations.Test;
44 
45 import java.io.ByteArrayInputStream;
46 import java.io.ByteArrayOutputStream;
47 import java.io.ObjectInputStream;
48 import java.io.ObjectOutputStream;
49 import java.lang.invoke.MethodHandles;
50 import java.lang.invoke.VarHandle;
51 import java.util.ArrayList;
52 import java.util.Iterator;
53 import java.util.List;
54 import java.util.concurrent.LinkedTransferQueue;
55 import java.util.concurrent.ThreadLocalRandom;
56 import java.util.concurrent.TimeUnit;
57 import static java.util.stream.Collectors.toList;
58 import java.util.function.Consumer;
59 
60 @Test
61 public class WhiteBox {
62     final ThreadLocalRandom rnd = ThreadLocalRandom.current();
63     final VarHandle HEAD, TAIL, ITEM, NEXT;
64     final int SWEEP_THRESHOLD;
65 
WhiteBox()66     public WhiteBox() throws ReflectiveOperationException {
67         Class<?> qClass = LinkedTransferQueue.class;
68         Class<?> nodeClass = Class.forName(qClass.getName() + "$Node");
69         MethodHandles.Lookup lookup
70             = MethodHandles.privateLookupIn(qClass, MethodHandles.lookup());
71         HEAD = lookup.findVarHandle(qClass, "head", nodeClass);
72         TAIL = lookup.findVarHandle(qClass, "tail", nodeClass);
73         NEXT = lookup.findVarHandle(nodeClass, "next", nodeClass);
74         ITEM = lookup.findVarHandle(nodeClass, "item", Object.class);
75         SWEEP_THRESHOLD = (int)
76             lookup.findStaticVarHandle(qClass, "SWEEP_THRESHOLD", int.class)
77             .get();
78     }
79 
head(LinkedTransferQueue q)80     Object head(LinkedTransferQueue q) { return HEAD.getVolatile(q); }
tail(LinkedTransferQueue q)81     Object tail(LinkedTransferQueue q) { return TAIL.getVolatile(q); }
item(Object node)82     Object item(Object node)           { return ITEM.getVolatile(node); }
next(Object node)83     Object next(Object node)           { return NEXT.getVolatile(node); }
84 
nodeCount(LinkedTransferQueue q)85     int nodeCount(LinkedTransferQueue q) {
86         int i = 0;
87         for (Object p = head(q); p != null; ) {
88             i++;
89             if (p == (p = next(p))) p = head(q);
90         }
91         return i;
92     }
93 
tailCount(LinkedTransferQueue q)94     int tailCount(LinkedTransferQueue q) {
95         int i = 0;
96         for (Object p = tail(q); p != null; ) {
97             i++;
98             if (p == (p = next(p))) p = head(q);
99         }
100         return i;
101     }
102 
findNode(LinkedTransferQueue q, Object e)103     Object findNode(LinkedTransferQueue q, Object e) {
104         for (Object p = head(q); p != null; ) {
105             if (item(p) != null && e.equals(item(p)))
106                 return p;
107             if (p == (p = next(p))) p = head(q);
108         }
109         throw new AssertionError("not found");
110     }
111 
iteratorAt(LinkedTransferQueue q, Object e)112     Iterator iteratorAt(LinkedTransferQueue q, Object e) {
113         for (Iterator it = q.iterator(); it.hasNext(); )
114             if (it.next().equals(e))
115                 return it;
116         throw new AssertionError("not found");
117     }
118 
assertIsSelfLinked(Object node)119     void assertIsSelfLinked(Object node) {
120         assertSame(next(node), node);
121         assertNull(item(node));
122     }
assertIsNotSelfLinked(Object node)123     void assertIsNotSelfLinked(Object node) {
124         assertNotSame(node, next(node));
125     }
126 
127     @Test
addRemove()128     public void addRemove() {
129         LinkedTransferQueue q = new LinkedTransferQueue();
130         assertInvariants(q);
131         assertNull(next(head(q)));
132         assertNull(item(head(q)));
133         q.add(1);
134         assertEquals(nodeCount(q), 2);
135         assertInvariants(q);
136         q.remove(1);
137         assertEquals(nodeCount(q), 1);
138         assertInvariants(q);
139     }
140 
141     /**
142      * Traversal actions that visit every node and do nothing, but
143      * have side effect of squeezing out dead nodes.
144      */
145     @DataProvider
traversalActions()146     public Object[][] traversalActions() {
147         return List.<Consumer<LinkedTransferQueue>>of(
148             q -> q.forEach(e -> {}),
149             q -> assertFalse(q.contains(new Object())),
150             q -> assertFalse(q.remove(new Object())),
151             q -> q.spliterator().forEachRemaining(e -> {}),
152             q -> q.stream().collect(toList()),
153             q -> assertFalse(q.removeIf(e -> false)),
154             q -> assertFalse(q.removeAll(List.of())))
155             .stream().map(x -> new Object[]{ x }).toArray(Object[][]::new);
156     }
157 
158     @Test(dataProvider = "traversalActions")
traversalOperationsCollapseLeadingNodes( Consumer<LinkedTransferQueue> traversalAction)159     public void traversalOperationsCollapseLeadingNodes(
160         Consumer<LinkedTransferQueue> traversalAction) {
161         LinkedTransferQueue q = new LinkedTransferQueue();
162         Object oldHead;
163         int n = 1 + rnd.nextInt(5);
164         for (int i = 0; i < n; i++) q.add(i);
165         assertEquals(nodeCount(q), n + 1);
166         oldHead = head(q);
167         traversalAction.accept(q);
168         assertInvariants(q);
169         assertEquals(nodeCount(q), n);
170         assertIsSelfLinked(oldHead);
171     }
172 
173     @Test(dataProvider = "traversalActions")
traversalOperationsCollapseInteriorNodes( Consumer<LinkedTransferQueue> traversalAction)174     public void traversalOperationsCollapseInteriorNodes(
175         Consumer<LinkedTransferQueue> traversalAction) {
176         LinkedTransferQueue q = new LinkedTransferQueue();
177         int n = 6;
178         for (int i = 0; i < n; i++) q.add(i);
179 
180         // We must be quite devious to reliably create an interior dead node
181         Object p0 = findNode(q, 0);
182         Object p1 = findNode(q, 1);
183         Object p2 = findNode(q, 2);
184         Object p3 = findNode(q, 3);
185         Object p4 = findNode(q, 4);
186         Object p5 = findNode(q, 5);
187 
188         Iterator it1 = iteratorAt(q, 1);
189         Iterator it2 = iteratorAt(q, 2);
190 
191         it2.remove(); // causes it2's ancestor to advance to 1
192         assertSame(next(p1), p3);
193         assertSame(next(p2), p3);
194         assertNull(item(p2));
195         it1.remove(); // removes it2's ancestor
196         assertSame(next(p0), p3);
197         assertSame(next(p1), p3);
198         assertSame(next(p2), p3);
199         assertNull(item(p1));
200         assertEquals(it2.next(), 3);
201         it2.remove(); // it2's ancestor can't unlink
202 
203         assertSame(next(p0), p3); // p3 is now interior dead node
204         assertSame(next(p1), p4); // it2 uselessly CASed p1.next
205         assertSame(next(p2), p3);
206         assertSame(next(p3), p4);
207         assertInvariants(q);
208 
209         int c = nodeCount(q);
210         traversalAction.accept(q);
211         assertEquals(nodeCount(q), c - 1);
212 
213         assertSame(next(p0), p4);
214         assertSame(next(p1), p4);
215         assertSame(next(p2), p3);
216         assertSame(next(p3), p4);
217         assertInvariants(q);
218 
219         // trailing nodes are not unlinked
220         Iterator it5 = iteratorAt(q, 5); it5.remove();
221         traversalAction.accept(q);
222         assertSame(next(p4), p5);
223         assertNull(next(p5));
224         assertEquals(nodeCount(q), c - 1);
225     }
226 
227     /**
228      * Checks that traversal operations collapse a random pattern of
229      * dead nodes as could normally only occur with a race.
230      */
231     @Test(dataProvider = "traversalActions")
traversalOperationsCollapseRandomNodes( Consumer<LinkedTransferQueue> traversalAction)232     public void traversalOperationsCollapseRandomNodes(
233         Consumer<LinkedTransferQueue> traversalAction) {
234         LinkedTransferQueue q = new LinkedTransferQueue();
235         int n = rnd.nextInt(6);
236         for (int i = 0; i < n; i++) q.add(i);
237         ArrayList nulledOut = new ArrayList();
238         for (Object p = head(q); p != null; p = next(p))
239             if (rnd.nextBoolean()) {
240                 nulledOut.add(item(p));
241                 ITEM.setVolatile(p, null);
242             }
243         traversalAction.accept(q);
244         int c = nodeCount(q);
245         assertEquals(q.size(), c - (q.contains(n - 1) ? 0 : 1));
246         for (int i = 0; i < n; i++)
247             assertTrue(nulledOut.contains(i) ^ q.contains(i));
248     }
249 
250     /**
251      * Traversal actions that remove every element, and are also
252      * expected to squeeze out dead nodes.
253      */
254     @DataProvider
bulkRemovalActions()255     public Object[][] bulkRemovalActions() {
256         return List.<Consumer<LinkedTransferQueue>>of(
257             q -> q.clear(),
258             q -> assertTrue(q.removeIf(e -> true)),
259             q -> assertTrue(q.retainAll(List.of())))
260             .stream().map(x -> new Object[]{ x }).toArray(Object[][]::new);
261     }
262 
263     @Test(dataProvider = "bulkRemovalActions")
bulkRemovalOperationsCollapseNodes( Consumer<LinkedTransferQueue> bulkRemovalAction)264     public void bulkRemovalOperationsCollapseNodes(
265         Consumer<LinkedTransferQueue> bulkRemovalAction) {
266         LinkedTransferQueue q = new LinkedTransferQueue();
267         int n = 1 + rnd.nextInt(5);
268         for (int i = 0; i < n; i++) q.add(i);
269         bulkRemovalAction.accept(q);
270         assertEquals(nodeCount(q), 1);
271         assertInvariants(q);
272     }
273 
274     /**
275      * Actions that remove the first element, and are expected to
276      * leave at most one slack dead node at head.
277      */
278     @DataProvider
pollActions()279     public Object[][] pollActions() {
280         return List.<Consumer<LinkedTransferQueue>>of(
281             q -> assertNotNull(q.poll()),
282             q -> { try { assertNotNull(q.poll(1L, TimeUnit.DAYS)); }
283                 catch (Throwable x) { throw new AssertionError(x); }},
284             q -> { try { assertNotNull(q.take()); }
285                 catch (Throwable x) { throw new AssertionError(x); }},
286             q -> assertNotNull(q.remove()))
287             .stream().map(x -> new Object[]{ x }).toArray(Object[][]::new);
288     }
289 
290     @Test(dataProvider = "pollActions")
291     public void pollActionsOneNodeSlack(
292         Consumer<LinkedTransferQueue> pollAction) {
293         LinkedTransferQueue q = new LinkedTransferQueue();
294         int n = 1 + rnd.nextInt(5);
295         for (int i = 0; i < n; i++) q.add(i);
296         assertEquals(nodeCount(q), n + 1);
297         for (int i = 0; i < n; i++) {
298             int c = nodeCount(q);
299             boolean slack = item(head(q)) == null;
300             if (slack) assertNotNull(item(next(head(q))));
301             pollAction.accept(q);
302             assertEquals(nodeCount(q), q.isEmpty() ? 1 : c - (slack ? 2 : 0));
303         }
304         assertInvariants(q);
305     }
306 
307     /**
308      * Actions that append an element, and are expected to
309      * leave at most one slack node at tail.
310      */
311     @DataProvider
312     public Object[][] addActions() {
313         return List.<Consumer<LinkedTransferQueue>>of(
314             q -> q.add(1),
315             q -> q.offer(1))
316             .stream().map(x -> new Object[]{ x }).toArray(Object[][]::new);
317     }
318 
319     @Test(dataProvider = "addActions")
320     public void addActionsOneNodeSlack(
321         Consumer<LinkedTransferQueue> addAction) {
322         LinkedTransferQueue q = new LinkedTransferQueue();
323         int n = 1 + rnd.nextInt(9);
324         for (int i = 0; i < n; i++) {
325             boolean slack = next(tail(q)) != null;
326             addAction.accept(q);
327             if (slack)
328                 assertNull(next(tail(q)));
329             else {
330                 assertNotNull(next(tail(q)));
331                 assertNull(next(next(tail(q))));
332             }
333             assertInvariants(q);
334         }
335     }
336 
337     byte[] serialBytes(Object o) {
338         try {
339             ByteArrayOutputStream bos = new ByteArrayOutputStream();
340             ObjectOutputStream oos = new ObjectOutputStream(bos);
341             oos.writeObject(o);
342             oos.flush();
343             oos.close();
344             return bos.toByteArray();
345         } catch (Exception fail) {
346             throw new AssertionError(fail);
347         }
348     }
349 
350     @SuppressWarnings("unchecked")
351     <T> T serialClone(T o) {
352         try {
353             ObjectInputStream ois = new ObjectInputStream
354                 (new ByteArrayInputStream(serialBytes(o)));
355             T clone = (T) ois.readObject();
356             assertNotSame(o, clone);
357             assertSame(o.getClass(), clone.getClass());
358             return clone;
359         } catch (Exception fail) {
360             throw new AssertionError(fail);
361         }
362     }
363 
364     @Test
365     public void testSerialization() {
366         LinkedTransferQueue q = serialClone(new LinkedTransferQueue());
367         assertInvariants(q);
368     }
369 
370     public void cancelledNodeSweeping() throws Throwable {
371         assertEquals(SWEEP_THRESHOLD & (SWEEP_THRESHOLD - 1), 0);
372         LinkedTransferQueue q = new LinkedTransferQueue();
373         Thread blockHead = null;
374         if (rnd.nextBoolean()) {
375             blockHead = new Thread(
376                 () -> { try { q.take(); } catch (InterruptedException ok) {}});
377             blockHead.start();
378             while (nodeCount(q) != 2) { Thread.yield(); }
379             assertTrue(q.hasWaitingConsumer());
380             assertEquals(q.getWaitingConsumerCount(), 1);
381         }
382         int initialNodeCount = nodeCount(q);
383 
384         // Some dead nodes do in fact accumulate ...
385         if (blockHead != null)
386             while (nodeCount(q) < initialNodeCount + SWEEP_THRESHOLD / 2)
387                 q.poll(1L, TimeUnit.MICROSECONDS);
388 
389         // ... but no more than SWEEP_THRESHOLD nodes accumulate
390         for (int i = rnd.nextInt(SWEEP_THRESHOLD * 10); i-->0; )
391             q.poll(1L, TimeUnit.MICROSECONDS);
392         assertTrue(nodeCount(q) <= initialNodeCount + SWEEP_THRESHOLD);
393 
394         if (blockHead != null) {
395             blockHead.interrupt();
396             blockHead.join();
397         }
398     }
399 
400     /** Checks conditions which should always be true. */
401     void assertInvariants(LinkedTransferQueue q) {
402         assertNotNull(head(q));
403         assertNotNull(tail(q));
404         // head is never self-linked (but tail may!)
405         for (Object h; next(h = head(q)) == h; )
406             assertNotSame(h, head(q)); // must be update race
407     }
408 }
409