/*
 * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

import javax.net.ServerSocketFactory;
import javax.net.ssl.SSLServerSocketFactory;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.ServerSocket;
import java.net.SocketAddress;
import java.net.SocketOption;
import java.net.StandardSocketOptions;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.channels.ClosedByInterruptException;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.CharacterCodingException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiFunction;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static java.lang.String.format;
import static java.lang.System.err;
import static java.nio.charset.StandardCharsets.ISO_8859_1;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Arrays.asList;
import static java.util.Objects.requireNonNull;

/**
 * Dummy WebSocket Server, which supports TLS.
 * By default the dummy server uses a plain TCP connection,
 * but it can use a TLS connection if {@link #secure()} is called before
 * {@link #open()}. In that case it uses the default SSL server socket
 * factory ({@link SSLServerSocketFactory#getDefault()}).
 *
 * <p>Performs a simplified version of the WebSocket Opening Handshake over
 * HTTP (i.e. no proxying, cookies, etc.). Supports sequential connections,
 * one at a time, i.e. in order for a client to connect to the server the
 * previous client must disconnect first.
 *
 * <p>Expected client request (the default mapping rejects requests that
 * carry {@code Sec-WebSocket-Protocol} or {@code Sec-WebSocket-Extensions}):
 *
 * <pre>
 * GET /chat HTTP/1.1
 * Host: server.example.com
 * Upgrade: websocket
 * Connection: Upgrade
 * Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
 * Origin: http://example.com
 * Sec-WebSocket-Version: 13
 * </pre>
 *
 * <p>This server's response:
 *
 * <pre>
 * HTTP/1.1 101 Switching Protocols
 * Connection: Upgrade
 * Upgrade: websocket
 * Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
 * </pre>
 *
 * <p>When the server is created with a user name, a request that lacks
 * matching HTTP Basic {@code Authorization} credentials is answered with
 * {@code 401 Unauthorized}, and the server then waits for another request
 * on the same connection.
 */
public class DummySecureWebSocketServer implements Closeable {

    /** The blank line that terminates the HTTP request headers. */
    private static final Pattern END_OF_HEADERS = Pattern.compile("\r\n\r\n");

    /** GUID appended to Sec-WebSocket-Key before hashing (RFC 6455). */
    private static final String WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    /**
     * Emulates some of the SocketChannel APIs over a Socket
     * instance. Every operation is delegated to one of the small
     * functional interfaces supplied at construction time.
     */
    public static class WebSocketChannel implements AutoCloseable {
        interface Reader {
            int read(ByteBuffer buf) throws IOException;
        }
        interface Writer {
            void write(ByteBuffer buf) throws IOException;
        }
        interface Config {
            <T> void setOption(SocketOption<T> option, T value) throws IOException;
        }
        interface Closer {
            void close() throws IOException;
        }
        final AutoCloseable channel;
        final Reader reader;
        final Writer writer;
        final Config config;
        final Closer closer;

        WebSocketChannel(AutoCloseable channel, Reader reader, Writer writer, Config config, Closer closer) {
            this.channel = channel;
            this.reader = reader;
            this.writer = writer;
            this.config = config;
            this.closer = closer;
        }

        @Override
        public void close() throws IOException {
            closer.close();
        }

        @Override
        public String toString() {
            return channel.toString();
        }

        public int read(ByteBuffer bb) throws IOException {
            return reader.read(bb);
        }

        public void write(ByteBuffer bb) throws IOException {
            writer.write(bb);
        }

        public <T> void setOption(SocketOption<T> option, T value) throws IOException {
            config.setOption(option, value);
        }

        /** Wraps {@code s}, reading and writing through its blocking streams. */
        public static WebSocketChannel of(Socket s) {
            Reader reader = (bb) -> DummySecureWebSocketServer.read(s.getInputStream(), bb);
            Writer writer = (bb) -> DummySecureWebSocketServer.write(s.getOutputStream(), bb);
            return new WebSocketChannel(s, reader, writer, s::setOption, s::close);
        }
    }

    /**
     * Emulates some of the ServerSocketChannel APIs over a ServerSocket
     * instance. Every operation is delegated to one of the small
     * functional interfaces supplied at construction time.
     */
    public static class WebServerSocketChannel implements AutoCloseable {
        interface Accepter {
            WebSocketChannel accept() throws IOException;
        }
        interface Binder {
            void bind(SocketAddress address) throws IOException;
        }
        interface Config {
            <T> void setOption(SocketOption<T> option, T value) throws IOException;
        }
        interface Closer {
            void close() throws IOException;
        }
        interface Addressable {
            SocketAddress getLocalAddress() throws IOException;
        }
        final AutoCloseable server;
        final Accepter accepter;
        final Binder binder;
        final Addressable address;
        final Config config;
        final Closer closer;

        WebServerSocketChannel(AutoCloseable server,
                               Accepter accepter,
                               Binder binder,
                               Addressable address,
                               Config config,
                               Closer closer) {
            this.server = server;
            this.accepter = accepter;
            this.binder = binder;
            this.address = address;
            this.config = config;
            this.closer = closer;
        }

        @Override
        public void close() throws IOException {
            closer.close();
        }

        @Override
        public String toString() {
            return server.toString();
        }

        public WebSocketChannel accept() throws IOException {
            return accepter.accept();
        }

        public void bind(SocketAddress address) throws IOException {
            binder.bind(address);
        }

        public <T> void setOption(SocketOption<T> option, T value) throws IOException {
            config.setOption(option, value);
        }

        public SocketAddress getLocalAddress() throws IOException {
            return address.getLocalAddress();
        }

        /** Wraps {@code ss}; accepted sockets are wrapped with {@link WebSocketChannel#of}. */
        public static WebServerSocketChannel of(ServerSocket ss) {
            Accepter a = () -> WebSocketChannel.of(ss.accept());
            return new WebServerSocketChannel(ss, a, ss::bind, ss::getLocalSocketAddress, ss::setOption, ss::close);
        }
    }

    // Creates a secure WebServerSocketChannel
    static WebServerSocketChannel openWSS() throws IOException {
        return WebServerSocketChannel.of(SSLServerSocketFactory.getDefault().createServerSocket());
    }

    // Creates a plain WebServerSocketChannel
    static WebServerSocketChannel openWS() throws IOException {
        return WebServerSocketChannel.of(ServerSocketFactory.getDefault().createServerSocket());
    }

    /**
     * Performs a single read of at most {@code min(buffer.remaining(), 1024)}
     * bytes from {@code str} into {@code buffer}.
     *
     * @return the number of bytes transferred, {@code 0} if {@code buffer}
     *         has no space left, or {@code -1} at end of stream
     */
    static int read(InputStream str, ByteBuffer buffer) throws IOException {
        int len = Math.min(buffer.remaining(), 1024);
        if (len <= 0) return 0;
        byte[] bytes = new byte[len];
        int n = str.read(bytes, 0, len);
        if (n > 0) {
            buffer.put(bytes, 0, n);
        }
        return n;
    }

    /**
     * Writes all remaining bytes of {@code buffer} to {@code str}, in chunks
     * of at most 1024 bytes. On return {@code buffer} has no bytes remaining.
     */
    static void write(OutputStream str, ByteBuffer buffer) throws IOException {
        int len = Math.min(buffer.remaining(), 1024);
        if (len <= 0) return;
        byte[] bytes = new byte[len];
        while (buffer.hasRemaining()) {
            int chunk = Math.min(len, buffer.remaining());
            buffer.get(bytes, 0, chunk);
            str.write(bytes, 0, chunk);
        }
    }

    private final AtomicBoolean started = new AtomicBoolean();
    private final Thread thread;
    private volatile WebServerSocketChannel ss;
    private volatile InetSocketAddress address;
    private volatile boolean secure;
    // Accumulates every byte received by read(WebSocketChannel), across connections.
    private ByteBuffer read = ByteBuffer.allocate(16384);
    // Released when the first served connection has been closed.
    private final CountDownLatch readReady = new CountDownLatch(1);
    private volatile boolean done;

    private static class Credentials {
        private final String name;
        private final String password;
        private Credentials(String name, String password) {
            this.name = name;
            this.password = password;
        }
        public String name() { return name; }
        public String password() { return password; }
    }

    public DummySecureWebSocketServer() {
        this(defaultMapping(), null, null);
    }

    public DummySecureWebSocketServer(String username, String password) {
        this(defaultMapping(), username, password);
    }

    /**
     * Creates a server that answers each handshake request with
     * {@code mapping.apply(requestLines, credentials)}.
     *
     * @param mapping  maps the request lines (split on CRLF) and the expected
     *                 credentials (or {@code null}) to the response lines
     * @param username expected user name, or {@code null} to disable the
     *                 credentials check
     * @param password expected password
     */
    public DummySecureWebSocketServer(BiFunction<List<String>,Credentials,List<String>> mapping,
                                      String username,
                                      String password) {
        requireNonNull(mapping);
        Credentials credentials = username != null ?
                new Credentials(username, password) : null;

        thread = new Thread(() -> acceptLoop(mapping, credentials));
        thread.setName("DummySecureWebSocketServer");
        thread.setDaemon(false);
    }

    /**
     * Body of the server thread: accepts connections one at a time, performs
     * the handshake and serves each connection until the server is closed.
     */
    private void acceptLoop(BiFunction<List<String>,Credentials,List<String>> mapping,
                            Credentials credentials) {
        try {
            while (!Thread.currentThread().isInterrupted() && !done) {
                err.println("Accepting next connection at: " + ss);
                WebSocketChannel channel = ss.accept();
                err.println("Accepted: " + channel);
                try {
                    channel.setOption(StandardSocketOptions.TCP_NODELAY, true);
                    while (!done) {
                        StringBuilder request = new StringBuilder();
                        if (!readRequest(channel, request)) {
                            throw new IOException("Bad request:[" + request + "]");
                        }
                        List<String> strings = asList(request.toString().split("\r\n"));
                        List<String> response = mapping.apply(strings, credentials);
                        writeResponse(channel, response);

                        if (response.get(0).startsWith("HTTP/1.1 401")) {
                            // the client may retry with credentials on this connection
                            err.println("Sent 401 Authentication response " + channel);
                        } else {
                            serve(channel);
                            break;
                        }
                    }
                } catch (IOException e) {
                    if (!done) {
                        err.println("Error in connection: " + channel + ", " + e);
                    }
                } finally {
                    err.println("Closed: " + channel);
                    close(channel);
                    readReady.countDown();
                }
            }
        } catch (ClosedByInterruptException ignored) {
            // interrupted by close(): normal shutdown
        } catch (Throwable e) {
            if (!done) {
                e.printStackTrace(err);
            }
        } finally {
            done = true;
            close(ss);
            err.println("Stopped at: " + getURI());
        }
    }

    // must be called before open()
    public DummySecureWebSocketServer secure() {
        secure = true;
        return this;
    }

    /**
     * Appends everything read from {@code ch} to the internal buffer
     * (returned by {@link #read()}), growing it to the next power of two
     * when needed, until end of stream or an I/O error.
     */
    protected void read(WebSocketChannel ch) throws IOException {
        ByteBuffer b = ByteBuffer.allocate(65536);
        while (ch.read(b) != -1) {
            b.flip();
            if (read.remaining() < b.remaining()) {
                // position() bytes already stored + the new ones
                int required = read.capacity() - read.remaining() + b.remaining();
                int log2required = 32 - Integer.numberOfLeadingZeros(required - 1);
                ByteBuffer newBuffer = ByteBuffer.allocate(1 << log2required);
                newBuffer.put(read.flip());
                read = newBuffer;
            }
            read.put(b);
            b.clear();
        }
    }

    /** Hook for subclasses that want to send data; does nothing by default. */
    protected void write(WebSocketChannel ch) throws IOException { }

    /**
     * Runs {@link #read(WebSocketChannel)} and {@link #write(WebSocketChannel)}
     * on two new threads and waits until the reader has terminated (the peer
     * closed the connection or an error occurred) and then the writer has
     * terminated, or until the server is closed.
     */
    protected final void serve(WebSocketChannel channel)
            throws InterruptedException
    {
        Thread reader = new Thread(() -> {
            try {
                read(channel);
            } catch (IOException ignored) {
                // connection closed or reset: nothing more to read
            }
        });
        Thread writer = new Thread(() -> {
            try {
                write(channel);
            } catch (IOException ignored) {
                // connection closed or reset: nothing more to write
            }
        });
        reader.start();
        writer.start();
        try {
            // join() returns immediately on a dead thread, so without the
            // isAlive() check this loop would spin until close() is called,
            // and the next connection would never be accepted.
            while (!done && reader.isAlive()) {
                try {
                    reader.join(500);
                } catch (InterruptedException x) {
                    if (done) {
                        close(channel);
                        break;
                    }
                }
            }
        } finally {
            reader.interrupt();
            try {
                while (!done && writer.isAlive()) {
                    try {
                        writer.join(500);
                    } catch (InterruptedException x) {
                        if (done) break;
                    }
                }
            } finally {
                writer.interrupt();
            }
        }
    }

    /**
     * Waits until the first served connection has been closed, then returns
     * a read-only view of all bytes received so far.
     */
    public ByteBuffer read() throws InterruptedException {
        readReady.await();
        return read.duplicate().asReadOnlyBuffer().flip();
    }

    /**
     * Binds the server to an ephemeral loopback port and starts the server
     * thread.
     *
     * @throws IllegalStateException if called more than once
     */
    public void open() throws IOException {
        err.println("Starting");
        if (!started.compareAndSet(false, true)) {
            throw new IllegalStateException("Already started");
        }
        ss = secure ? openWSS() : openWS();
        try {
            ss.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0));
            address = (InetSocketAddress) ss.getLocalAddress();
            thread.start();
        } catch (IOException e) {
            done = true;
            close(ss);
            throw e;
        }
        err.println("Started at: " + getURI());
    }

    /**
     * Stops the server. Safe to call even if {@link #open()} was never
     * called or failed before the socket was bound.
     */
    @Override
    public void close() {
        // address is set only once open() has bound the socket; getURI()
        // would throw before that, so it must not be used unconditionally.
        err.println("Stopping: " + (address == null ? "(not bound)" : getURI()));
        done = true;
        thread.interrupt();
        close(ss);
    }

    /**
     * Returns the {@code ws://} or {@code wss://} URI of this server.
     *
     * @throws IllegalStateException if {@link #open()} has not been called
     */
    URI getURI() {
        if (!started.get()) {
            throw new IllegalStateException("Not yet started");
        }
        if (!secure) {
            return URI.create("ws://localhost:" + address.getPort());
        } else {
            return URI.create("wss://localhost:" + address.getPort());
        }
    }

    /**
     * Reads from {@code channel}, appending ISO-8859-1 text to
     * {@code request}, until the headers' terminating blank line is seen.
     *
     * @return {@code true} if the blank line was found, {@code false} if the
     *         stream ended first
     */
    private boolean readRequest(WebSocketChannel channel, StringBuilder request)
            throws IOException
    {
        ByteBuffer buffer = ByteBuffer.allocate(512);
        while (channel.read(buffer) != -1) {
            // read the complete HTTP request headers, there should be no body
            CharBuffer decoded;
            buffer.flip();
            try {
                decoded = ISO_8859_1.newDecoder().decode(buffer);
            } catch (CharacterCodingException e) {
                throw new UncheckedIOException(e);
            }
            request.append(decoded);
            if (END_OF_HEADERS.matcher(request).find())
                return true;
            buffer.clear();
        }
        return false;
    }

    /** Writes the response lines, CRLF-separated and followed by a blank line. */
    private void writeResponse(WebSocketChannel channel, List<String> response)
            throws IOException
    {
        String s = response.stream().collect(Collectors.joining("\r\n"))
                + "\r\n\r\n";
        ByteBuffer encoded;
        try {
            encoded = ISO_8859_1.newEncoder().encode(CharBuffer.wrap(s));
        } catch (CharacterCodingException e) {
            throw new UncheckedIOException(e);
        }
        while (encoded.hasRemaining()) {
            channel.write(encoded);
        }
    }

    /**
     * The default handshake mapping. It validates the request line and the
     * Connection/Upgrade/Sec-WebSocket-Version/Sec-WebSocket-Key headers. It
     * rejects subprotocols and extensions, computes Sec-WebSocket-Accept, and
     * answers 401 when credentials are required but do not match.
     * Violations are reported as {@link IllegalStateException}.
     */
    private static BiFunction<List<String>,Credentials,List<String>> defaultMapping() {
        return (request, credentials) -> {
            List<String> response = new LinkedList<>();
            Iterator<String> iterator = request.iterator();
            if (!iterator.hasNext()) {
                throw new IllegalStateException("The request is empty");
            }
            String statusLine = iterator.next();
            if (!(statusLine.startsWith("GET /") && statusLine.endsWith(" HTTP/1.1"))) {
                throw new IllegalStateException
                        ("Unexpected status line: " + request.get(0));
            }
            response.add("HTTP/1.1 101 Switching Protocols");
            Map<String, List<String>> requestHeaders = new HashMap<>();
            while (iterator.hasNext()) {
                String header = iterator.next();
                String[] split = header.split(": ");
                if (split.length != 2) {
                    throw new IllegalStateException
                            ("Unexpected header: " + header
                                    + ", split=" + Arrays.toString(split));
                }
                requestHeaders.computeIfAbsent(split[0], k -> new ArrayList<>()).add(split[1]);
            }
            if (requestHeaders.containsKey("Sec-WebSocket-Protocol")) {
                throw new IllegalStateException("Subprotocols are not expected");
            }
            if (requestHeaders.containsKey("Sec-WebSocket-Extensions")) {
                throw new IllegalStateException("Extensions are not expected");
            }
            expectHeader(requestHeaders, "Connection", "Upgrade");
            response.add("Connection: Upgrade");
            expectHeader(requestHeaders, "Upgrade", "websocket");
            response.add("Upgrade: websocket");
            expectHeader(requestHeaders, "Sec-WebSocket-Version", "13");
            List<String> key = requestHeaders.get("Sec-WebSocket-Key");
            if (key == null || key.isEmpty()) {
                throw new IllegalStateException("Sec-WebSocket-Key is missing");
            }
            if (key.size() != 1) {
                throw new IllegalStateException("Sec-WebSocket-Key has too many values : " + key);
            }
            MessageDigest sha1;
            try {
                sha1 = MessageDigest.getInstance("SHA-1");
            } catch (NoSuchAlgorithmException e) {
                throw new InternalError(e);
            }
            String x = key.get(0) + WEBSOCKET_GUID;
            sha1.update(x.getBytes(ISO_8859_1));
            String v = Base64.getEncoder().encodeToString(sha1.digest());
            response.add("Sec-WebSocket-Accept: " + v);

            // check authorization credentials, if required by the server
            if (credentials != null && !authorized(credentials, requestHeaders)) {
                response.clear();
                response.add("HTTP/1.1 401 Unauthorized");
                response.add("Content-Length: 0");
                response.add("WWW-Authenticate: Basic realm=\"dummy server realm\"");
            }

            return response;
        };
    }

    /**
     * Checks the HTTP Basic credentials in the request against those
     * allowed by the server.
     *
     * @return {@code false} if there is no Authorization header or the
     *         name/password do not match
     * @throws IllegalStateException if the header is malformed or repeated
     */
    private static boolean authorized(Credentials credentials,
                                      Map<String,List<String>> requestHeaders) {
        List<String> authorization = requestHeaders.get("Authorization");
        if (authorization == null)
            return false;

        if (authorization.size() != 1) {
            throw new IllegalStateException("Authorization unexpected count:" + authorization);
        }
        String header = authorization.get(0);
        if (!header.startsWith("Basic "))
            throw new IllegalStateException("Authorization not Basic: " + header);

        header = header.substring("Basic ".length());
        String values = new String(Base64.getDecoder().decode(header), UTF_8);
        int sep = values.indexOf(':');
        if (sep < 1) {
            throw new IllegalStateException("Authorization not colon: " + values);
        }
        String name = values.substring(0, sep);
        String password = values.substring(sep + 1);

        return name.equals(credentials.name()) && password.equals(credentials.password());
    }

    /**
     * Asserts that header {@code name} is present and has {@code value} among
     * its values.
     *
     * @return {@code value}
     * @throws IllegalStateException if the header is absent or lacks the value
     */
    protected static String expectHeader(Map<String, List<String>> headers,
                                         String name,
                                         String value) {
        List<String> v = headers.get(name);
        if (v == null) {
            throw new IllegalStateException(
                    format("Expected '%s' header, not present in %s",
                           name, headers));
        }
        if (!v.contains(value)) {
            throw new IllegalStateException(
                    format("Expected '%s: %s', actual: '%s: %s'",
                           name, value, name, v)
            );
        }
        return value;
    }

    /** Best-effort close of each non-null argument; exceptions are ignored. */
    private static void close(AutoCloseable... acs) {
        for (AutoCloseable ac : acs) {
            if (ac == null) {
                continue; // e.g. close() before open() created the socket
            }
            try {
                ac.close();
            } catch (Exception ignored) {
                // best-effort cleanup
            }
        }
    }
}