1 /*
2  * Copyright (c) 2009, 2010, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 /* @test
25  * @bug 4927640
26  * @summary Tests the SCTP protocol implementation
27  * @author chegar
28  */
29 
30 import java.net.InetSocketAddress;
31 import java.net.SocketAddress;
32 import java.io.IOException;
33 import java.util.Set;
34 import java.util.Iterator;
35 import java.util.concurrent.CountDownLatch;
36 import java.util.concurrent.TimeUnit;
37 import java.nio.ByteBuffer;
38 import com.sun.nio.sctp.Association;
39 import com.sun.nio.sctp.InvalidStreamException;
40 import com.sun.nio.sctp.MessageInfo;
41 import com.sun.nio.sctp.SctpMultiChannel;
42 import static java.lang.System.out;
43 import static java.lang.System.err;
44 
45 public class Send {
46     /* Latches used to synchronize between the client and server so that
47      * connections without any IO may not be closed without being accepted */
48     final CountDownLatch clientFinishedLatch = new CountDownLatch(1);
49     final CountDownLatch serverFinishedLatch = new CountDownLatch(1);
50 
test(String[] args)51     void test(String[] args) {
52         SocketAddress address = null;
53         Server server = null;
54 
55         if (!Util.isSCTPSupported()) {
56             out.println("SCTP protocol is not supported");
57             out.println("Test cannot be run");
58             return;
59         }
60 
61         if (args.length == 2) {
62             /* requested to connecct to a specific address */
63             try {
64                 int port = Integer.valueOf(args[1]);
65                 address = new InetSocketAddress(args[0], port);
66             } catch (NumberFormatException nfe) {
67                 err.println(nfe);
68             }
69         } else {
70             /* start server on local machine, default */
71             try {
72                 server = new Server();
73                 server.start();
74                 address = server.address();
75                 debug("Server started and listening on " + address);
76             } catch (IOException ioe) {
77                 ioe.printStackTrace();
78                 return;
79             }
80         }
81 
82         doTest(address);
83     }
84 
doTest(SocketAddress peerAddress)85     void doTest(SocketAddress peerAddress) {
86         SctpMultiChannel channel = null;
87         ByteBuffer buffer = ByteBuffer.allocate(Util.LARGE_BUFFER);
88         MessageInfo info = MessageInfo.createOutgoing(null, 0);
89 
90         try {
91             channel = SctpMultiChannel.open();
92 
93             /* TEST 1: send small message */
94             int streamNumber = 0;
95             debug("sending to " + peerAddress + " on stream number: " + streamNumber);
96             info = MessageInfo.createOutgoing(peerAddress, streamNumber);
97             buffer.put(Util.SMALL_MESSAGE.getBytes("ISO-8859-1"));
98             buffer.flip();
99             int position = buffer.position();
100             int remaining = buffer.remaining();
101 
102             debug("sending small message: " + buffer);
103             int sent = channel.send(buffer, info);
104 
105             check(sent == remaining, "sent should be equal to remaining");
106             check(buffer.position() == (position + sent),
107                     "buffers position should have been incremented by sent");
108 
109             /* TEST 2: receive the echoed message */
110             buffer.clear();
111             info = channel.receive(buffer, null, null);
112             buffer.flip();
113             check(info != null, "info is null");
114             check(info.streamNumber() == streamNumber,
115                     "message not sent on the correct stream");
116             check(info.bytes() == Util.SMALL_MESSAGE.getBytes("ISO-8859-1").
117                   length, "bytes received not equal to message length");
118             check(info.bytes() == buffer.remaining(), "bytes != remaining");
119             check(Util.compare(buffer, Util.SMALL_MESSAGE),
120               "received message not the same as sent message");
121 
122 
123             /* TEST 3: send large message */
124             Set<Association> assocs = channel.associations();
125             check(assocs.size() == 1, "there should be only one association");
126             Iterator<Association> it = assocs.iterator();
127             check(it.hasNext());
128             Association assoc = it.next();
129             streamNumber = assoc.maxOutboundStreams() - 1;
130 
131             debug("sending on stream number: " + streamNumber);
132             info = MessageInfo.createOutgoing(assoc, null, streamNumber);
133             buffer.clear();
134             buffer.put(Util.LARGE_MESSAGE.getBytes("ISO-8859-1"));
135             buffer.flip();
136             position = buffer.position();
137             remaining = buffer.remaining();
138 
139             debug("sending large message: " + buffer);
140             sent = channel.send(buffer, info);
141 
142             check(sent == remaining, "sent should be equal to remaining");
143             check(buffer.position() == (position + sent),
144                     "buffers position should have been incremented by sent");
145 
146             /* TEST 4: receive the echoed message */
147             buffer.clear();
148             info = channel.receive(buffer, null, null);
149             buffer.flip();
150             check(info != null, "info is null");
151             check(info.streamNumber() == streamNumber,
152                     "message not sent on the correct stream");
153             check(info.bytes() == Util.LARGE_MESSAGE.getBytes("ISO-8859-1").
154                   length, "bytes received not equal to message length");
155             check(info.bytes() == buffer.remaining(), "bytes != remaining");
156             check(Util.compare(buffer, Util.LARGE_MESSAGE),
157               "received message not the same as sent message");
158 
159 
160             /* TEST 5: InvalidStreamExcepton */
161             streamNumber = assoc.maxOutboundStreams() + 1;
162             info = MessageInfo.createOutgoing(assoc, null, streamNumber);
163             buffer.clear();
164             buffer.put(Util.SMALL_MESSAGE.getBytes("ISO-8859-1"));
165             buffer.flip();
166             position = buffer.position();
167             remaining = buffer.remaining();
168 
169             debug("sending on stream number: " + streamNumber);
170             debug("sending small message: " + buffer);
171             try {
172                 sent = channel.send(buffer, info);
173                 fail("should have thrown InvalidStreamExcepton");
174             } catch (InvalidStreamException ise){
175                 pass();
176             } catch (IOException ioe) {
177                 unexpected(ioe);
178             }
179             check(buffer.remaining() == remaining,
180                     "remaining should not be changed");
181             check(buffer.position() == position,
182                     "buffers position should not be changed");
183 
184 
185             /* TEST 5: getRemoteAddresses(Association) */
186             channel.getRemoteAddresses(assoc);
187 
188             /* TEST 6: Send from heap buffer to force implementation to
189              * substitute with a native buffer, then check that its position
190              * is updated correctly */
191             info = MessageInfo.createOutgoing(assoc, null, 0);
192             buffer.clear();
193             buffer.put(Util.SMALL_MESSAGE.getBytes("ISO-8859-1"));
194             buffer.flip();
195             final int offset = 1;
196             buffer.position(offset);
197             remaining = buffer.remaining();
198 
199             try {
200                 sent = channel.send(buffer, info);
201 
202                 check(sent == remaining, "sent should be equal to remaining");
203                 check(buffer.position() == (offset + sent),
204                         "buffers position should have been incremented by sent");
205             } catch (IllegalArgumentException iae) {
206                 fail(iae + ", Error updating buffers position");
207             }
208 
209         } catch (IOException ioe) {
210             unexpected(ioe);
211         } finally {
212             clientFinishedLatch.countDown();
213             try { serverFinishedLatch.await(10L, TimeUnit.SECONDS); }
214             catch (InterruptedException ie) { unexpected(ie); }
215             if (channel != null) {
216                 try { channel.close(); }
217                 catch (IOException e) { unexpected (e);}
218             }
219         }
220     }
221 
222     class Server implements Runnable
223     {
224         final InetSocketAddress serverAddr;
225         private SctpMultiChannel serverChannel;
226 
Server()227         public Server() throws IOException {
228             serverChannel = SctpMultiChannel.open().bind(null);
229             java.util.Set<SocketAddress> addrs = serverChannel.getAllLocalAddresses();
230             if (addrs.isEmpty())
231                 debug("addrs should not be empty");
232 
233             serverAddr = (InetSocketAddress) addrs.iterator().next();
234         }
235 
start()236         public void start() {
237             (new Thread(this, "Server-"  + serverAddr.getPort())).start();
238         }
239 
address()240         public InetSocketAddress address() {
241             return serverAddr;
242         }
243 
244         @Override
run()245         public void run() {
246             ByteBuffer buffer = ByteBuffer.allocateDirect(Util.LARGE_BUFFER);
247             try {
248                 MessageInfo info;
249 
250                 /* receive a small message */
251                 do {
252                     info = serverChannel.receive(buffer, null, null);
253                     if (info == null) {
254                         fail("Server: unexpected null from receive");
255                             return;
256                     }
257                 } while (!info.isComplete());
258 
259                 buffer.flip();
260                 check(info != null, "info is null");
261                 check(info.streamNumber() == 0,
262                         "message not sent on the correct stream");
263                 check(info.bytes() == Util.SMALL_MESSAGE.getBytes("ISO-8859-1").
264                       length, "bytes received not equal to message length");
265                 check(info.bytes() == buffer.remaining(), "bytes != remaining");
266                 check(Util.compare(buffer, Util.SMALL_MESSAGE),
267                   "received message not the same as sent message");
268 
269                 check(info != null, "info is null");
270                 Set<Association> assocs = serverChannel.associations();
271                 check(assocs.size() == 1, "there should be only one association");
272                 Iterator<Association> it = assocs.iterator();
273                 check(it.hasNext());
274                 Association assoc = it.next();
275 
276                 /* echo the message */
277                 debug("Server: echoing first message");
278                 buffer.flip();
279                 int bytes = serverChannel.send(buffer, info);
280                 debug("Server: sent " + bytes + "bytes");
281 
282                 /* receive a large message */
283                 buffer.clear();
284                 do {
285                     info = serverChannel.receive(buffer, null, null);
286                     if (info == null) {
287                         fail("Server: unexpected null from receive");
288                             return;
289                     }
290                 } while (!info.isComplete());
291 
292                 buffer.flip();
293 
294                 check(info.streamNumber() == assoc.maxInboundStreams() - 1,
295                         "message not sent on the correct stream");
296                 check(info.bytes() == Util.LARGE_MESSAGE.getBytes("ISO-8859-1").
297                       length, "bytes received not equal to message length");
298                 check(info.bytes() == buffer.remaining(), "bytes != remaining");
299                 check(Util.compare(buffer, Util.LARGE_MESSAGE),
300                   "received message not the same as sent message");
301 
302                 /* echo the message */
303                 debug("Server: echoing second message");
304                 buffer.flip();
305                 bytes = serverChannel.send(buffer, info);
306                 debug("Server: sent " + bytes + "bytes");
307 
308                 /* TEST 6 */
309                 ByteBuffer expected = ByteBuffer.allocate(Util.SMALL_BUFFER);
310                 expected.put(Util.SMALL_MESSAGE.getBytes("ISO-8859-1"));
311                 expected.flip();
312                 final int offset = 1;
313                 expected.position(offset);
314                 buffer.clear();
315                 do {
316                     info = serverChannel.receive(buffer, null, null);
317                     if (info == null) {
318                         fail("Server: unexpected null from receive");
319                         return;
320                     }
321                 } while (!info.isComplete());
322 
323                 buffer.flip();
324                 check(info != null, "info is null");
325                 check(info.streamNumber() == 0, "message not sent on the correct stream");
326                 check(info.bytes() == expected.remaining(),
327                     "bytes received not equal to message length");
328                 check(info.bytes() == buffer.remaining(), "bytes != remaining");
329                 check(expected.equals(buffer),
330                     "received message not the same as sent message");
331 
332                 clientFinishedLatch.await(10L, TimeUnit.SECONDS);
333                 serverFinishedLatch.countDown();
334             } catch (IOException ioe) {
335                 unexpected(ioe);
336             } catch (InterruptedException ie) {
337                 unexpected(ie);
338             } finally {
339                 try { if (serverChannel != null) serverChannel.close(); }
340                 catch (IOException  unused) {}
341             }
342         }
343     }
344 
345         //--------------------- Infrastructure ---------------------------
346     boolean debug = true;
347     volatile int passed = 0, failed = 0;
pass()348     void pass() {passed++;}
fail()349     void fail() {failed++; Thread.dumpStack();}
fail(String msg)350     void fail(String msg) {System.err.println(msg); fail();}
unexpected(Throwable t)351     void unexpected(Throwable t) {failed++; t.printStackTrace();}
check(boolean cond)352     void check(boolean cond) {if (cond) pass(); else fail();}
check(boolean cond, String failMessage)353     void check(boolean cond, String failMessage) {if (cond) pass(); else fail(failMessage);}
debug(String message)354     void debug(String message) {if(debug) { System.out.println(message); }  }
main(String[] args)355     public static void main(String[] args) throws Throwable {
356         Class<?> k = new Object(){}.getClass().getEnclosingClass();
357         try {k.getMethod("instanceMain",String[].class)
358                 .invoke( k.newInstance(), (Object) args);}
359         catch (Throwable e) {throw e.getCause();}}
instanceMain(String[] args)360     public void instanceMain(String[] args) throws Throwable {
361         try {test(args);} catch (Throwable t) {unexpected(t);}
362         System.out.printf("%nPassed = %d, failed = %d%n%n", passed, failed);
363         if (failed > 0) throw new AssertionError("Some tests failed");}
364 
365 }
366