1 /*
2  * Copyright (c) 2005, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 /*
25  * @test
26  * @bug 6332435
27  * @summary Basic tests for CountDownLatch
28  * @library /test/lib
29  * @author Seetharam Avadhanam, Martin Buchholz
30  */
31 
32 import java.util.concurrent.CountDownLatch;
33 import java.util.concurrent.TimeUnit;
34 import java.util.concurrent.atomic.AtomicInteger;
35 import jdk.test.lib.Utils;
36 
37 public class Basic {
38     static final long LONG_DELAY_MS = Utils.adjustTimeout(10_000);
39 
40     interface AwaiterFactory {
getAwaiter()41         Awaiter getAwaiter();
42     }
43 
44     abstract static class Awaiter extends Thread {
45         private volatile Throwable result = null;
result(Throwable result)46         protected void result(Throwable result) { this.result = result; }
result()47         public Throwable result() { return this.result; }
48     }
49 
toTheStartingGate(CountDownLatch gate)50     private void toTheStartingGate(CountDownLatch gate) {
51         try {
52             gate.await();
53         }
54         catch (Throwable t) { fail(t); }
55     }
56 
awaiter(final CountDownLatch latch, final CountDownLatch gate)57     private Awaiter awaiter(final CountDownLatch latch,
58                             final CountDownLatch gate) {
59         return new Awaiter() { public void run() {
60             System.out.println("without millis: " + latch.toString());
61             gate.countDown();
62 
63             try {
64                 latch.await();
65                 System.out.println("without millis - ComingOut");
66             }
67             catch (Throwable result) { result(result); }}};
68     }
69 
70     private Awaiter awaiter(final CountDownLatch latch,
71                             final CountDownLatch gate,
72                             final long millis) {
73         return new Awaiter() { public void run() {
74             System.out.println("with millis: "+latch.toString());
75             gate.countDown();
76 
77             try {
78                 latch.await(millis, TimeUnit.MILLISECONDS);
79                 System.out.println("with millis - ComingOut");
80             }
81             catch (Throwable result) { result(result); }}};
82     }
83 
84     AwaiterFactory awaiterFactory(CountDownLatch latch, CountDownLatch gate) {
85         return () -> awaiter(latch, gate);
86     }
87 
88     AwaiterFactory timedAwaiterFactory(CountDownLatch latch, CountDownLatch gate) {
89         return () -> awaiter(latch, gate, LONG_DELAY_MS);
90     }
91 
92     //----------------------------------------------------------------
93     // Normal use
94     //----------------------------------------------------------------
95     public static void normalUse() throws Throwable {
96         int count = 0;
97         Basic test = new Basic();
98         CountDownLatch latch = new CountDownLatch(3);
99         Awaiter[] a = new Awaiter[12];
100 
101         for (int i = 0; i < 3; i++) {
102             CountDownLatch gate = new CountDownLatch(4);
103             AwaiterFactory factory1 = test.awaiterFactory(latch, gate);
104             AwaiterFactory factory2 = test.timedAwaiterFactory(latch, gate);
105             a[count] = factory1.getAwaiter(); a[count++].start();
106             a[count] = factory1.getAwaiter(); a[count++].start();
107             a[count] = factory2.getAwaiter(); a[count++].start();
108             a[count] = factory2.getAwaiter(); a[count++].start();
109             test.toTheStartingGate(gate);
110             System.out.println("Main Thread: " + latch.toString());
111             latch.countDown();
112             checkCount(latch, 2-i);
113         }
114         for (int i = 0; i < 12; i++)
115             a[i].join();
116 
117         for (int i = 0; i < 12; i++)
118             checkResult(a[i], null);
119     }
120 
121     //----------------------------------------------------------------
122     // One thread interrupted
123     //----------------------------------------------------------------
124     public static void threadInterrupted() throws Throwable {
125         int count = 0;
126         Basic test = new Basic();
127         CountDownLatch latch = new CountDownLatch(3);
128         Awaiter[] a = new Awaiter[12];
129 
130         for (int i = 0; i < 3; i++) {
131             CountDownLatch gate = new CountDownLatch(4);
132             AwaiterFactory factory1 = test.awaiterFactory(latch, gate);
133             AwaiterFactory factory2 = test.timedAwaiterFactory(latch, gate);
134             a[count] = factory1.getAwaiter(); a[count++].start();
135             a[count] = factory1.getAwaiter(); a[count++].start();
136             a[count] = factory2.getAwaiter(); a[count++].start();
137             a[count] = factory2.getAwaiter(); a[count++].start();
138             a[count-1].interrupt();
139             test.toTheStartingGate(gate);
140             System.out.println("Main Thread: " + latch.toString());
141             latch.countDown();
142             checkCount(latch, 2-i);
143         }
144         for (int i = 0; i < 12; i++)
145             a[i].join();
146 
147         for (int i = 0; i < 12; i++)
148             checkResult(a[i],
149                         (i % 4) == 3 ? InterruptedException.class : null);
150     }
151 
152     //----------------------------------------------------------------
153     // One thread timed out
154     //----------------------------------------------------------------
155     public static void timeOut() throws Throwable {
156         int count =0;
157         Basic test = new Basic();
158         CountDownLatch latch = new CountDownLatch(3);
159         Awaiter[] a = new Awaiter[12];
160 
161         long[] timeout = { 0L, 5L, 10L };
162 
163         for (int i = 0; i < 3; i++) {
164             CountDownLatch gate = new CountDownLatch(4);
165             AwaiterFactory factory1 = test.awaiterFactory(latch, gate);
166             AwaiterFactory factory2 = test.timedAwaiterFactory(latch, gate);
167             a[count] = test.awaiter(latch, gate, timeout[i]); a[count++].start();
168             a[count] = factory1.getAwaiter(); a[count++].start();
169             a[count] = factory2.getAwaiter(); a[count++].start();
170             a[count] = factory2.getAwaiter(); a[count++].start();
171             test.toTheStartingGate(gate);
172             System.out.println("Main Thread: " + latch.toString());
173             latch.countDown();
174             checkCount(latch, 2-i);
175         }
176         for (int i = 0; i < 12; i++)
177             a[i].join();
178 
179         for (int i = 0; i < 12; i++)
180             checkResult(a[i], null);
181     }
182 
183     public static void main(String[] args) throws Throwable {
184         normalUse();
185         threadInterrupted();
186         timeOut();
187         if (failures.get() > 0L)
188             throw new AssertionError(failures.get() + " failures");
189     }
190 
191     private static final AtomicInteger failures = new AtomicInteger(0);
192 
193     private static void fail(String msg) {
194         fail(new AssertionError(msg));
195     }
196 
197     private static void fail(Throwable t) {
198         t.printStackTrace();
199         failures.getAndIncrement();
200     }
201 
202     private static void checkCount(CountDownLatch b, int expected) {
203         if (b.getCount() != expected)
204             fail("Count = " + b.getCount() +
205                  ", expected = " + expected);
206     }
207 
208     private static void checkResult(Awaiter a, Class c) {
209         Throwable t = a.result();
210         if (! ((t == null && c == null) || c.isInstance(t))) {
211             System.out.println("Mismatch: " + t + ", " + c.getName());
212             failures.getAndIncrement();
213         }
214     }
215 }
216