/*
 * Copyright (c) 2016, 2021, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

import java.io.Closeable;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.StandardSocketOptions;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.channels.ClosedByInterruptException;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.CharacterCodingException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiFunction;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static java.lang.String.format;
import static java.lang.System.err;
import static java.nio.charset.StandardCharsets.ISO_8859_1;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Arrays.asList;
import static java.util.Objects.requireNonNull;

/**
 * Dummy WebSocket Server.
 *
 * Performs a simplified version of the WebSocket Opening Handshake over HTTP
 * (i.e. no proxying, cookies, etc.) Supports sequential connections, one at a
 * time, i.e. in order for a client to connect to the server the previous
 * client must disconnect first.
 *
 * Expected client request (the default mapping rejects requests that carry a
 * Sec-WebSocket-Protocol or Sec-WebSocket-Extensions header):
 *
 *     GET /chat HTTP/1.1
 *     Host: server.example.com
 *     Upgrade: websocket
 *     Connection: Upgrade
 *     Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
 *     Origin: http://example.com
 *     Sec-WebSocket-Version: 13
 *
 * This server's response:
 *
 *     HTTP/1.1 101 Switching Protocols
 *     Connection: Upgrade
 *     Upgrade: websocket
 *     Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
 *
 * If the server was created with a username, a request without matching
 * Basic credentials is answered with "401 Unauthorized", and the server then
 * waits for another request on the same connection.
 */
public class DummyWebSocketServer implements Closeable {

    /** Marks the end of the header section of an HTTP request. */
    private static final Pattern END_OF_HEADERS = Pattern.compile("\r\n\r\n");

    /** GUID concatenated with Sec-WebSocket-Key to derive Sec-WebSocket-Accept (RFC 6455). */
    private static final String WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    private final AtomicBoolean started = new AtomicBoolean();
    private final Thread thread;
    private volatile ServerSocketChannel ssc;
    private volatile InetSocketAddress address;
    // Bytes collected by the default read(SocketChannel); replaced (grown) by
    // the reader thread and handed to read() callers after readReady fires.
    private ByteBuffer read = ByteBuffer.allocate(16384);
    // Counted down when a connection finishes (see the accept loop's finally).
    private final CountDownLatch readReady = new CountDownLatch(1);
    private volatile int receiveBufferSize;

    private static class Credentials {
        private final String name;
        private final String password;
        private Credentials(String name, String password) {
            this.name = name;
            this.password = password;
        }
        public String name() { return name; }
        public String password() { return password; }
    }

    /** Creates a server using the default handshake mapping and no authentication. */
    public DummyWebSocketServer() {
        this(defaultMapping(), null, null);
    }

    /**
     * Creates a server using the default handshake mapping that requires Basic
     * authentication with the given credentials.
     */
    public DummyWebSocketServer(String username, String password) {
        this(defaultMapping(), username, password);
    }

    /**
     * Creates a server. The server does not listen until {@link #open()} is called.
     *
     * @param mapping  maps the request lines (status line first, then header
     *                 lines) and the configured credentials ({@code null} when
     *                 no username was given) to the response lines
     * @param username required user name, or {@code null} for no authentication
     * @param password required password (used only when {@code username} is non-null)
     */
    public DummyWebSocketServer(BiFunction<List<String>,Credentials,List<String>> mapping,
                                String username,
                                String password) {
        requireNonNull(mapping);
        Credentials credentials = username != null ?
                new Credentials(username, password) : null;

        thread = new Thread(() -> {
            try {
                while (!Thread.currentThread().isInterrupted()) {
                    err.println("Accepting next connection at: " + ssc);
                    SocketChannel channel = ssc.accept();
                    err.println("Accepted: " + channel);
                    try {
                        channel.setOption(StandardSocketOptions.TCP_NODELAY, true);
                        channel.configureBlocking(true);
                        while (true) {
                            StringBuilder request = new StringBuilder();
                            if (!readRequest(channel, request)) {
                                throw new IOException("Bad request:[" + request + "]");
                            }
                            List<String> strings = asList(request.toString().split("\r\n"));
                            List<String> response = mapping.apply(strings, credentials);
                            writeResponse(channel, response);

                            if (response.get(0).startsWith("HTTP/1.1 401")) {
                                // Keep the connection: the client may retry with credentials
                                err.println("Sent 401 Authentication response " + channel);
                                continue;
                            } else {
                                serve(channel);
                                break;
                            }
                        }
                    } catch (IOException e) {
                        err.println("Error in connection: " + channel + ", " + e);
                    } finally {
                        err.println("Closed: " + channel);
                        closeChannel(channel);
                        readReady.countDown();
                    }
                }
            } catch (ClosedByInterruptException ignored) {
                // close() interrupts this thread while it is blocked in accept()
            } catch (Exception e) {
                e.printStackTrace(err);
            } finally {
                close(ssc);
                err.println("Stopped at: " + getURI());
            }
        });
        thread.setName("DummyWebSocketServer");
        thread.setDaemon(false);
    }

    /**
     * Reads post-handshake data from {@code ch} until end-of-stream, appending
     * it to the buffer returned by {@link #read()}. Subclasses may override.
     * Runs on a dedicated reader thread started by {@link #serve}.
     */
    protected void read(SocketChannel ch) throws IOException {
        // Read until the thread is interrupted or an error occurred
        // or the input is shutdown
        ByteBuffer b = ByteBuffer.allocate(65536);
        while (ch.read(b) != -1) {
            b.flip();
            if (read.remaining() < b.remaining()) {
                // Grow to the next power of two that holds the existing and the new bytes
                int required = read.capacity() - read.remaining() + b.remaining();
                int log2required = 32 - Integer.numberOfLeadingZeros(required - 1);
                ByteBuffer newBuffer = ByteBuffer.allocate(1 << log2required);
                newBuffer.put(read.flip());
                read = newBuffer;
            }
            read.put(b);
            b.clear();
        }
    }

    /** Closes a client channel once it has been served; subclasses may override. */
    protected void closeChannel(SocketChannel channel) {
        close(channel);
    }

    /**
     * Writes data to the client after the handshake, on a dedicated writer
     * thread started by {@link #serve}. The default does nothing.
     */
    protected void write(SocketChannel ch) throws IOException { }

    /**
     * Runs {@link #read(SocketChannel)} and {@link #write(SocketChannel)} on two
     * new threads and returns when the reader has finished and the writer has
     * been joined.
     */
    protected final void serve(SocketChannel channel)
            throws InterruptedException
    {
        Thread reader = new Thread(() -> {
            try {
                read(channel);
            } catch (IOException ignored) { }
        });
        Thread writer = new Thread(() -> {
            try {
                write(channel);
            } catch (IOException ignored) { }
        });
        reader.start();
        writer.start();
        try {
            reader.join();
        } finally {
            reader.interrupt();
            try {
                writer.join();
            } finally {
                writer.interrupt();
            }
        }
    }

    /**
     * Waits until the first connection has finished, then returns a read-only
     * view of the bytes accumulated so far by the default
     * {@link #read(SocketChannel)} implementation.
     */
    public ByteBuffer read() throws InterruptedException {
        readReady.await();
        return read.duplicate().asReadOnlyBuffer().flip();
    }

    /** Sets SO_RCVBUF for the listening socket; must be called before {@link #open()}. */
    public void setReceiveBufferSize(int bufsize) {
        assert ssc == null : "Must configure before calling open()";
        this.receiveBufferSize = bufsize;
    }

    /**
     * Binds to an ephemeral port on the loopback address and starts the accept thread.
     *
     * @throws IllegalStateException if called more than once
     * @throws IOException if the listening channel cannot be opened or bound
     */
    public void open() throws IOException {
        err.println("Starting");
        if (!started.compareAndSet(false, true)) {
            throw new IllegalStateException("Already started");
        }
        ssc = ServerSocketChannel.open();
        try {
            ssc.configureBlocking(true);
            var bufsize = receiveBufferSize;
            if (bufsize > 0) {
                err.printf("Configuring receive buffer size to %d%n", bufsize);
                try {
                    ssc.setOption(StandardSocketOptions.SO_RCVBUF, bufsize);
                } catch (IOException x) {
                    err.printf("Failed to configure receive buffer size to %d%n", bufsize);
                }
            }
            ssc.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0));
            address = (InetSocketAddress) ssc.getLocalAddress();
            thread.start();
        } catch (IOException e) {
            close(ssc);
            throw e;
        }
        err.println("Started at: " + getURI());
    }

    /**
     * Stops the server by interrupting the accept thread and closing the
     * listening channel. Safe to call before {@link #open()}, after a failed
     * {@code open()}, and more than once.
     */
    @Override
    public void close() {
        err.println("Stopping: " + describe());
        thread.interrupt();
        close(ssc);
    }

    /**
     * Returns the URI of this server.
     *
     * @throws IllegalStateException if {@link #open()} has not been called
     */
    URI getURI() {
        if (!started.get()) {
            throw new IllegalStateException("Not yet started");
        }
        return URI.create("ws://localhost:" + address.getPort());
    }

    /** Describes the server for log messages without throwing when it is not bound. */
    private String describe() {
        // address is assigned only after 'started' is set, so getURI() cannot throw here
        return address == null ? "<not started>" : getURI().toString();
    }

    /**
     * Reads from {@code channel}, decoded as ISO-8859-1, into {@code request}
     * until the accumulated text contains an empty line ("\r\n\r\n").
     *
     * @return {@code true} if the header terminator was seen, {@code false} on
     *         end-of-stream before it
     */
    private boolean readRequest(SocketChannel channel, StringBuilder request)
            throws IOException
    {
        // NOTE: everything returned by a read is appended, so bytes the client
        // sends after the blank line in the same read end up in the request text.
        ByteBuffer buffer = ByteBuffer.allocate(512);
        while (channel.read(buffer) != -1) {
            // read the complete HTTP request headers, there should be no body
            buffer.flip();
            // ISO-8859-1 maps every byte to a char, so decoding cannot fail
            request.append(ISO_8859_1.decode(buffer));
            if (END_OF_HEADERS.matcher(request).find()) {
                return true;
            }
            buffer.clear();
        }
        return false;
    }

    /**
     * Writes the response lines joined by CRLF and terminated by an empty line.
     *
     * @throws UncheckedIOException if a line contains a char outside ISO-8859-1
     */
    private void writeResponse(SocketChannel channel, List<String> response)
            throws IOException
    {
        String s = response.stream().collect(Collectors.joining("\r\n"))
                + "\r\n\r\n";
        ByteBuffer encoded;
        try {
            encoded = ISO_8859_1.newEncoder().encode(CharBuffer.wrap(s));
        } catch (CharacterCodingException e) {
            throw new UncheckedIOException(e);
        }
        while (encoded.hasRemaining()) {
            channel.write(encoded);
        }
    }

    /**
     * Returns the default request-to-response mapping. It validates a GET
     * upgrade request and answers "101 Switching Protocols" with the computed
     * Sec-WebSocket-Accept. If credentials are configured and the request's
     * Basic Authorization does not match them, it answers "401 Unauthorized"
     * instead. Invalid requests cause an {@link IllegalStateException}.
     */
    private static BiFunction<List<String>,Credentials,List<String>> defaultMapping() {
        return (request, credentials) -> {
            List<String> response = new LinkedList<>();
            Iterator<String> iterator = request.iterator();
            if (!iterator.hasNext()) {
                throw new IllegalStateException("The request is empty");
            }
            String statusLine = iterator.next();
            if (!(statusLine.startsWith("GET /") && statusLine.endsWith(" HTTP/1.1"))) {
                throw new IllegalStateException
                        ("Unexpected status line: " + statusLine);
            }
            response.add("HTTP/1.1 101 Switching Protocols");
            Map<String, List<String>> requestHeaders = parseHeaders(iterator);
            if (requestHeaders.containsKey("Sec-WebSocket-Protocol")) {
                throw new IllegalStateException("Subprotocols are not expected");
            }
            if (requestHeaders.containsKey("Sec-WebSocket-Extensions")) {
                throw new IllegalStateException("Extensions are not expected");
            }
            expectHeader(requestHeaders, "Connection", "Upgrade");
            response.add("Connection: Upgrade");
            expectHeader(requestHeaders, "Upgrade", "websocket");
            response.add("Upgrade: websocket");
            expectHeader(requestHeaders, "Sec-WebSocket-Version", "13");
            List<String> key = requestHeaders.get("Sec-WebSocket-Key");
            if (key == null || key.isEmpty()) {
                throw new IllegalStateException("Sec-WebSocket-Key is missing");
            }
            if (key.size() != 1) {
                throw new IllegalStateException("Sec-WebSocket-Key has too many values : " + key);
            }
            response.add("Sec-WebSocket-Accept: " + acceptKey(key.get(0)));

            // check authorization credentials, if required by the server
            if (credentials != null && !authorized(credentials, requestHeaders)) {
                response.clear();
                response.add("HTTP/1.1 401 Unauthorized");
                response.add("Content-Length: 0");
                response.add("WWW-Authenticate: Basic realm=\"dummy server realm\"");
            }

            return response;
        };
    }

    /**
     * Parses the remaining request lines as "name: value" header fields.
     * Names are split at the first ':' and looked up case-insensitively
     * (RFC 7230, section 3.2). Values are trimmed and may be empty or contain ':'.
     *
     * @throws IllegalStateException if a line has no ':' or an empty name
     */
    private static Map<String, List<String>> parseHeaders(Iterator<String> lines) {
        Map<String, List<String>> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
        while (lines.hasNext()) {
            String header = lines.next();
            int colon = header.indexOf(':');
            if (colon < 1) {
                throw new IllegalStateException("Unexpected header: " + header);
            }
            String name = header.substring(0, colon);
            String value = header.substring(colon + 1).trim();
            headers.computeIfAbsent(name, k -> new ArrayList<>()).add(value);
        }
        return headers;
    }

    /** Computes Sec-WebSocket-Accept: base64(SHA-1(key + GUID)). */
    private static String acceptKey(String key) {
        MessageDigest sha1;
        try {
            sha1 = MessageDigest.getInstance("SHA-1");
        } catch (NoSuchAlgorithmException e) {
            throw new InternalError(e);
        }
        byte[] digest = sha1.digest((key + WEBSOCKET_GUID).getBytes(ISO_8859_1));
        return Base64.getEncoder().encodeToString(digest);
    }

    // Checks credentials in the request against those allowable by the server.
    // Returns false when there is no Authorization header; throws
    // IllegalStateException when the header is present but malformed.
    private static boolean authorized(Credentials credentials,
                                      Map<String,List<String>> requestHeaders) {
        List<String> authorization = requestHeaders.get("Authorization");
        if (authorization == null)
            return false;

        if (authorization.size() != 1) {
            throw new IllegalStateException("Authorization unexpected count:" + authorization);
        }
        String header = authorization.get(0);
        if (!header.startsWith("Basic "))
            throw new IllegalStateException("Authorization not Basic: " + header);

        header = header.substring("Basic ".length());
        String values = new String(Base64.getDecoder().decode(header), UTF_8);
        int sep = values.indexOf(':');
        if (sep < 1) {
            throw new IllegalStateException("Authorization not colon: " + values);
        }
        String name = values.substring(0, sep);
        String password = values.substring(sep + 1);

        return name.equals(credentials.name())
                && password.equals(credentials.password());
    }

    /**
     * Checks that header {@code name} is present in {@code headers} and has
     * {@code value} among its values.
     *
     * @return {@code value}
     * @throws IllegalStateException if the header is absent or lacks the value
     */
    protected static String expectHeader(Map<String, List<String>> headers,
                                         String name,
                                         String value) {
        List<String> v = headers.get(name);
        if (v == null) {
            throw new IllegalStateException(
                    format("Expected '%s' header, not present in %s",
                           name, headers));
        }
        if (!v.contains(value)) {
            throw new IllegalStateException(
                    format("Expected '%s: %s', actual: '%s: %s'",
                           name, value, name, v)
            );
        }
        return value;
    }

    /** Closes each non-null argument, ignoring any exception (best-effort cleanup). */
    private static void close(AutoCloseable... acs) {
        for (AutoCloseable ac : acs) {
            if (ac == null) {
                continue; // e.g. the listening channel before open()
            }
            try {
                ac.close();
            } catch (Exception ignored) {
                // best-effort: nothing useful to do on a failed close
            }
        }
    }
}