1 /*
2  * Copyright (c) 2009, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 /* @test
25  * @bug 4927640
26  * @summary Tests the SCTP protocol implementation
27  * @author chegar
28  */
29 
30 import java.net.InetSocketAddress;
31 import java.net.SocketAddress;
32 import java.io.IOException;
33 import java.util.Set;
34 import java.util.Iterator;
35 import java.util.concurrent.CountDownLatch;
36 import java.util.concurrent.TimeUnit;
37 import java.nio.ByteBuffer;
38 import com.sun.nio.sctp.AbstractNotificationHandler;
39 import com.sun.nio.sctp.Association;
40 import com.sun.nio.sctp.AssociationChangeNotification;
41 import com.sun.nio.sctp.AssociationChangeNotification.AssocChangeEvent;
42 import com.sun.nio.sctp.HandlerResult;
43 import com.sun.nio.sctp.InvalidStreamException;
44 import com.sun.nio.sctp.MessageInfo;
45 import com.sun.nio.sctp.SctpChannel;
46 import com.sun.nio.sctp.SctpMultiChannel;
47 import com.sun.nio.sctp.ShutdownNotification;
48 import static java.lang.System.out;
49 import static java.lang.System.err;
50 
51 public class Branch {
52     /* Latches used to synchronize between the client and server so that
53      * connections without any IO may not be closed without being accepted */
54     final CountDownLatch clientFinishedLatch = new CountDownLatch(1);
55     final CountDownLatch serverFinishedLatch = new CountDownLatch(1);
56 
test(String[] args)57     void test(String[] args) {
58         SocketAddress address = null;
59         Server server = null;
60 
61         if (!Util.isSCTPSupported()) {
62             out.println("SCTP protocol is not supported");
63             out.println("Test cannot be run");
64             return;
65         }
66 
67         if (args.length == 2) {
68             /* requested to connecct to a specific address */
69             try {
70                 int port = Integer.valueOf(args[1]);
71                 address = new InetSocketAddress(args[0], port);
72             } catch (NumberFormatException nfe) {
73                 err.println(nfe);
74             }
75         } else {
76             /* start server on local machine, default */
77             try {
78                 server = new Server();
79                 server.start();
80                 address = server.address();
81                 debug("Server started and listening on " + address);
82             } catch (IOException ioe) {
83                 ioe.printStackTrace();
84                 return;
85             }
86         }
87 
88         doTest(address);
89     }
90 
doTest(SocketAddress peerAddress)91     void doTest(SocketAddress peerAddress) {
92         SctpMultiChannel channel = null;
93         ByteBuffer buffer = ByteBuffer.allocate(Util.LARGE_BUFFER);
94         MessageInfo info = MessageInfo.createOutgoing(null, 0);
95 
96         try {
97             channel = SctpMultiChannel.open();
98 
99             /* setup an association implicitly by sending a small message */
100             int streamNumber = 0;
101             debug("sending to " + peerAddress + " on stream number: " + streamNumber);
102             info = MessageInfo.createOutgoing(peerAddress, streamNumber);
103             buffer.put(Util.SMALL_MESSAGE.getBytes("ISO-8859-1"));
104             buffer.flip();
105             int position = buffer.position();
106             int remaining = buffer.remaining();
107 
108             debug("sending small message: " + buffer);
109             int sent = channel.send(buffer, info);
110 
111             check(sent == remaining, "sent should be equal to remaining");
112             check(buffer.position() == (position + sent),
113                     "buffers position should have been incremented by sent");
114 
115             /* Receive the COMM_UP */
116             buffer.clear();
117             BranchNotificationHandler handler = new BranchNotificationHandler();
118             info = channel.receive(buffer, null, handler);
119             check(handler.receivedCommUp(), "COMM_UP no received");
120             Set<Association> associations = channel.associations();
121             check(!associations.isEmpty(),"There should be some associations");
122             Association bassoc = associations.iterator().next();
123 
124             /* TEST 1: branch */
125             SctpChannel bchannel = channel.branch(bassoc);
126 
127             check(!bchannel.getAllLocalAddresses().isEmpty(),
128                                    "branched channel should be bound");
129             check(!bchannel.getRemoteAddresses().isEmpty(),
130                                    "branched channel should be connected");
131             check(channel.associations().isEmpty(),
132                   "there should be no associations since the only one was branched off");
133 
134             buffer.clear();
135             info = bchannel.receive(buffer, null, null);
136             buffer.flip();
137             check(info != null, "info is null");
138             check(info.streamNumber() == streamNumber,
139                     "message not sent on the correct stream");
140             check(info.bytes() == Util.SMALL_MESSAGE.getBytes("ISO-8859-1").
141                   length, "bytes received not equal to message length");
142             check(info.bytes() == buffer.remaining(), "bytes != remaining");
143             check(Util.compare(buffer, Util.SMALL_MESSAGE),
144               "received message not the same as sent message");
145 
146         } catch (IOException ioe) {
147             unexpected(ioe);
148         } finally {
149             clientFinishedLatch.countDown();
150             try { serverFinishedLatch.await(10L, TimeUnit.SECONDS); }
151             catch (InterruptedException ie) { unexpected(ie); }
152             if (channel != null) {
153                 try { channel.close(); }
154                 catch (IOException e) { unexpected (e);}
155             }
156         }
157     }
158 
159     class Server implements Runnable
160     {
161         final InetSocketAddress serverAddr;
162         private SctpMultiChannel serverChannel;
163 
Server()164         public Server() throws IOException {
165             serverChannel = SctpMultiChannel.open().bind(null);
166             java.util.Set<SocketAddress> addrs = serverChannel.getAllLocalAddresses();
167             if (addrs.isEmpty())
168                 debug("addrs should not be empty");
169 
170             serverAddr = (InetSocketAddress) addrs.iterator().next();
171         }
172 
start()173         public void start() {
174             (new Thread(this, "Server-"  + serverAddr.getPort())).start();
175         }
176 
address()177         public InetSocketAddress address() {
178             return serverAddr;
179         }
180 
181         @Override
run()182         public void run() {
183             ByteBuffer buffer = ByteBuffer.allocateDirect(Util.LARGE_BUFFER);
184             try {
185                 MessageInfo info;
186 
187                 /* receive a small message */
188                 do {
189                     info = serverChannel.receive(buffer, null, null);
190                     if (info == null) {
191                         fail("Server: unexpected null from receive");
192                             return;
193                     }
194                 } while (!info.isComplete());
195 
196                 buffer.flip();
197                 check(info != null, "info is null");
198                 check(info.streamNumber() == 0,
199                         "message not sent on the correct stream");
200                 check(info.bytes() == Util.SMALL_MESSAGE.getBytes("ISO-8859-1").
201                       length, "bytes received not equal to message length");
202                 check(info.bytes() == buffer.remaining(), "bytes != remaining");
203                 check(Util.compare(buffer, Util.SMALL_MESSAGE),
204                   "received message not the same as sent message");
205 
206                 check(info != null, "info is null");
207                 Set<Association> assocs = serverChannel.associations();
208                 check(assocs.size() == 1, "there should be only one association");
209 
210                 /* echo the message */
211                 debug("Server: echoing first message");
212                 buffer.flip();
213                 int bytes = serverChannel.send(buffer, info);
214                 debug("Server: sent " + bytes + "bytes");
215 
216                 clientFinishedLatch.await(10L, TimeUnit.SECONDS);
217                 serverFinishedLatch.countDown();
218             } catch (IOException ioe) {
219                 unexpected(ioe);
220             } catch (InterruptedException ie) {
221                 unexpected(ie);
222             } finally {
223                 try { if (serverChannel != null) serverChannel.close(); }
224                 catch (IOException  unused) {}
225             }
226         }
227     }
228 
229     class BranchNotificationHandler extends AbstractNotificationHandler<Object>
230     {
231         boolean receivedCommUp;  // false
232 
receivedCommUp()233         boolean receivedCommUp() {
234             return receivedCommUp;
235         }
236 
237         @Override
handleNotification( AssociationChangeNotification notification, Object attachment)238         public HandlerResult handleNotification(
239                 AssociationChangeNotification notification, Object attachment) {
240             AssocChangeEvent event = notification.event();
241             debug("AssociationChangeNotification");
242             debug("  Association: " + notification.association());
243             debug("  Event: " + event);
244 
245             if (event.equals(AssocChangeEvent.COMM_UP))
246                 receivedCommUp = true;
247 
248             return HandlerResult.RETURN;
249         }
250 
251         /* A ShutdownNotification handler is provided to ensure that no
252          * shutdown notification are being handled since we don't expect
253          * to receive them. This is not part of branch testing, it just
254          * fits here to test another bug. */
255         @Override
handleNotification( ShutdownNotification notification, Object attachment)256         public HandlerResult handleNotification(
257                 ShutdownNotification notification, Object attachment) {
258             debug("ShutdownNotification");
259             debug("  Association: " + notification.association());
260 
261             fail("Shutdown should not be received");
262 
263             return HandlerResult.RETURN;
264         }
265 
266     }
267 
268         //--------------------- Infrastructure ---------------------------
269     boolean debug = true;
270     volatile int passed = 0, failed = 0;
pass()271     void pass() {passed++;}
fail()272     void fail() {failed++; Thread.dumpStack();}
fail(String msg)273     void fail(String msg) {System.err.println(msg); fail();}
unexpected(Throwable t)274     void unexpected(Throwable t) {failed++; t.printStackTrace();}
check(boolean cond)275     void check(boolean cond) {if (cond) pass(); else fail();}
check(boolean cond, String failMessage)276     void check(boolean cond, String failMessage) {if (cond) pass(); else fail(failMessage);}
debug(String message)277     void debug(String message) {if(debug) { System.out.println(message); }  }
main(String[] args)278     public static void main(String[] args) throws Throwable {
279         Class<?> k = new Object(){}.getClass().getEnclosingClass();
280         try {k.getMethod("instanceMain",String[].class)
281                 .invoke( k.newInstance(), (Object) args);}
282         catch (Throwable e) {throw e.getCause();}}
instanceMain(String[] args)283     public void instanceMain(String[] args) throws Throwable {
284         try {test(args);} catch (Throwable t) {unexpected(t);}
285         System.out.printf("%nPassed = %d, failed = %d%n%n", passed, failed);
286         if (failed > 0) throw new AssertionError("Some tests failed");}
287 
288 }
289