1 /*
2  * Copyright (c) 2005, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 /*
25  * @test
26  * @bug 6332435
27  * @summary Basic tests for CountDownLatch
28  * @author Seetharam Avadhanam, Martin Buchholz
29  */
30 
31 import java.util.concurrent.CountDownLatch;
32 import java.util.concurrent.TimeUnit;
33 import java.util.concurrent.atomic.AtomicInteger;
34 
35 interface AwaiterFactory {
getAwaiter()36     Awaiter getAwaiter();
37 }
38 
/**
 * A thread that can record one {@link Throwable} as its result so that
 * another thread can read it back later.
 *
 * <p>Subclasses supply {@code run()} and call {@link #result(Throwable)}
 * to store a throwable; {@link #result()} returns whatever was stored last.
 */
abstract class Awaiter extends Thread {

    /**
     * Most recently recorded throwable. Declared volatile so that a write
     * made by one thread is visible to a different thread that reads it.
     */
    private volatile Throwable outcome;

    /**
     * Stores the given throwable as this awaiter's result.
     *
     * @param result the throwable to remember
     */
    protected void result(Throwable result) {
        outcome = result;
    }

    /**
     * Reads back the stored result.
     *
     * @return the throwable last passed to {@link #result(Throwable)},
     *         or {@code null} if nothing has been stored
     */
    public Throwable result() {
        return outcome;
    }
}
44 
45 public class Basic {
46 
toTheStartingGate(CountDownLatch gate)47     private void toTheStartingGate(CountDownLatch gate) {
48         try {
49             gate.await();
50         }
51         catch (Throwable t) { fail(t); }
52     }
53 
awaiter(final CountDownLatch latch, final CountDownLatch gate)54     private Awaiter awaiter(final CountDownLatch latch,
55                             final CountDownLatch gate) {
56         return new Awaiter() { public void run() {
57             System.out.println("without millis: " + latch.toString());
58             gate.countDown();
59 
60             try {
61                 latch.await();
62                 System.out.println("without millis - ComingOut");
63             }
64             catch (Throwable result) { result(result); }}};
65     }
66 
67     private Awaiter awaiter(final CountDownLatch latch,
68                             final CountDownLatch gate,
69                             final long millis) {
70         return new Awaiter() { public void run() {
71             System.out.println("with millis: "+latch.toString());
72             gate.countDown();
73 
74             try {
75                 latch.await(millis, TimeUnit.MILLISECONDS);
76                 System.out.println("with millis - ComingOut");
77             }
78             catch (Throwable result) { result(result); }}};
79     }
80 
81     private AwaiterFactory awaiterFactories(final CountDownLatch latch,
82                                             final CountDownLatch gate,
83                                             final int i) {
84         if (i == 1)
85             return new AwaiterFactory() { public Awaiter getAwaiter() {
86                 return awaiter(latch, gate); }};
87 
88         return new AwaiterFactory() { public Awaiter getAwaiter() {
89             return awaiter(latch, gate, 10000); }};
90     }
91 
92     //----------------------------------------------------------------
93     // Normal use
94     //----------------------------------------------------------------
95     public static void normalUse() throws Throwable {
96         int count = 0;
97         Basic test = new Basic();
98         CountDownLatch latch = new CountDownLatch(3);
99         Awaiter[] a = new Awaiter[12];
100 
101         for (int i = 0; i < 3; i++) {
102             CountDownLatch gate = new CountDownLatch(4);
103             AwaiterFactory factory1 = test.awaiterFactories(latch, gate, 1);
104             AwaiterFactory factory2 = test.awaiterFactories(latch, gate, 0);
105             a[count] = factory1.getAwaiter(); a[count++].start();
106             a[count] = factory1.getAwaiter(); a[count++].start();
107             a[count] = factory2.getAwaiter(); a[count++].start();
108             a[count] = factory2.getAwaiter(); a[count++].start();
109             test.toTheStartingGate(gate);
110             System.out.println("Main Thread: " + latch.toString());
111             latch.countDown();
112             checkCount(latch, 2-i);
113         }
114         for (int i = 0; i < 12; i++)
115             a[i].join();
116 
117         for (int i = 0; i < 12; i++)
118             checkResult(a[i], null);
119     }
120 
121     //----------------------------------------------------------------
122     // One thread interrupted
123     //----------------------------------------------------------------
124     public static void threadInterrupted() throws Throwable {
125         int count = 0;
126         Basic test = new Basic();
127         CountDownLatch latch = new CountDownLatch(3);
128         Awaiter[] a = new Awaiter[12];
129 
130         for (int i = 0; i < 3; i++) {
131             CountDownLatch gate = new CountDownLatch(4);
132             AwaiterFactory factory1 = test.awaiterFactories(latch, gate, 1);
133             AwaiterFactory factory2 = test.awaiterFactories(latch, gate, 0);
134             a[count] = factory1.getAwaiter(); a[count++].start();
135             a[count] = factory1.getAwaiter(); a[count++].start();
136             a[count] = factory2.getAwaiter(); a[count++].start();
137             a[count] = factory2.getAwaiter(); a[count++].start();
138             a[count-1].interrupt();
139             test.toTheStartingGate(gate);
140             System.out.println("Main Thread: " + latch.toString());
141             latch.countDown();
142             checkCount(latch, 2-i);
143         }
144         for (int i = 0; i < 12; i++)
145             a[i].join();
146 
147         for (int i = 0; i < 12; i++)
148             checkResult(a[i],
149                         (i % 4) == 3 ? InterruptedException.class : null);
150     }
151 
152     //----------------------------------------------------------------
153     // One thread timed out
154     //----------------------------------------------------------------
155     public static void timeOut() throws Throwable {
156         int count =0;
157         Basic test = new Basic();
158         CountDownLatch latch = new CountDownLatch(3);
159         Awaiter[] a = new Awaiter[12];
160 
161         long[] timeout = { 0L, 5L, 10L };
162 
163         for (int i = 0; i < 3; i++) {
164             CountDownLatch gate = new CountDownLatch(4);
165             AwaiterFactory factory1 = test.awaiterFactories(latch, gate, 1);
166             AwaiterFactory factory2 = test.awaiterFactories(latch, gate, 0);
167             a[count] = test.awaiter(latch, gate, timeout[i]); a[count++].start();
168             a[count] = factory1.getAwaiter(); a[count++].start();
169             a[count] = factory2.getAwaiter(); a[count++].start();
170             a[count] = factory2.getAwaiter(); a[count++].start();
171             test.toTheStartingGate(gate);
172             System.out.println("Main Thread: " + latch.toString());
173             latch.countDown();
174             checkCount(latch, 2-i);
175         }
176         for (int i = 0; i < 12; i++)
177             a[i].join();
178 
179         for (int i = 0; i < 12; i++)
180             checkResult(a[i], null);
181     }
182 
183     public static void main(String[] args) throws Throwable {
184         normalUse();
185         threadInterrupted();
186         timeOut();
187         if (failures.get() > 0L)
188             throw new AssertionError(failures.get() + " failures");
189     }
190 
191     private static final AtomicInteger failures = new AtomicInteger(0);
192 
193     private static void fail(String msg) {
194         fail(new AssertionError(msg));
195     }
196 
197     private static void fail(Throwable t) {
198         t.printStackTrace();
199         failures.getAndIncrement();
200     }
201 
202     private static void checkCount(CountDownLatch b, int expected) {
203         if (b.getCount() != expected)
204             fail("Count = " + b.getCount() +
205                  ", expected = " + expected);
206     }
207 
208     private static void checkResult(Awaiter a, Class c) {
209         Throwable t = a.result();
210         if (! ((t == null && c == null) || c.isInstance(t))) {
211             System.out.println("Mismatch: " + t + ", " + c.getName());
212             failures.getAndIncrement();
213         }
214     }
215 }
216