1 /*
2  * Copyright (c) 2016, 2021, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
import java.io.Closeable;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.StandardSocketOptions;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.channels.ClosedByInterruptException;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.CharacterCodingException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiFunction;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static java.lang.String.format;
import static java.lang.System.err;
import static java.nio.charset.StandardCharsets.ISO_8859_1;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Arrays.asList;
import static java.util.Objects.requireNonNull;
59 
/**
 * Dummy WebSocket Server.
 *
 * <p>Performs a simplified version of the WebSocket Opening Handshake over
 * HTTP (i.e. no proxying, cookies, etc.). Supports sequential connections,
 * one at a time, i.e. in order for a client to connect to the server the
 * previous client must disconnect first.
 *
 * <p>Expected client request. The default mapping rejects requests that carry
 * {@code Sec-WebSocket-Protocol} or {@code Sec-WebSocket-Extensions} headers.
 * Header names are matched case-insensitively.
 *
 * <pre>
 *     GET /chat HTTP/1.1
 *     Host: server.example.com
 *     Upgrade: websocket
 *     Connection: Upgrade
 *     Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
 *     Origin: http://example.com
 *     Sec-WebSocket-Version: 13
 * </pre>
 *
 * <p>This server's response, as produced by the default mapping:
 *
 * <pre>
 *     HTTP/1.1 101 Switching Protocols
 *     Connection: Upgrade
 *     Upgrade: websocket
 *     Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
 * </pre>
 *
 * <p>When the server is created with a user name, a request that has no
 * {@code Authorization} header, or whose Basic credentials do not match, is
 * answered with {@code HTTP/1.1 401 Unauthorized}. The client may then retry
 * on the same connection.
 */
public class DummyWebSocketServer implements Closeable {

    private final AtomicBoolean started = new AtomicBoolean();
    // Accepts and serves connections one at a time; started by open(),
    // interrupted by close().
    private final Thread thread;
    private volatile ServerSocketChannel ssc;
    private volatile InetSocketAddress address;
    // Bytes received by read(SocketChannel), accumulated across connections and
    // reallocated with a larger capacity when full. Written by the reader thread;
    // made visible to read() through Thread.join() in serve() and readReady.
    private ByteBuffer read = ByteBuffer.allocate(16384);
    // Counted down when the first accepted connection is closed, successfully or not.
    private final CountDownLatch readReady = new CountDownLatch(1);
    private volatile int receiveBufferSize;

    /** User name and password the server expects in a Basic Authorization header. */
    private static class Credentials {
        private final String name;
        private final String password;
        private Credentials(String name, String password) {
            this.name = name;
            this.password = password;
        }
        public String name() { return name; }
        public String password() { return password; }
    }

    /** Creates a server using the default handshake mapping, with no authentication. */
    public DummyWebSocketServer() {
        this(defaultMapping(), null, null);
    }

    /**
     * Creates a server using the default handshake mapping that requires Basic
     * authentication with the given credentials.
     *
     * @param username the expected user name; {@code null} disables authentication
     * @param password the expected password
     */
    public DummyWebSocketServer(String username, String password) {
        this(defaultMapping(), username, password);
    }

    /**
     * Creates a server with a custom handshake mapping.
     *
     * <p>The mapping receives the request lines (request line first, then one
     * element per header line) and the expected credentials ({@code null} when
     * {@code username} is {@code null}). It returns the response lines (status
     * line first). If the response status line starts with
     * {@code "HTTP/1.1 401"}, the server waits for another request on the same
     * connection. Otherwise the connection is handed to {@link #serve}. An
     * unchecked exception thrown by the mapping stops the server.
     *
     * @param mapping  maps request lines and credentials to response lines
     * @param username the expected user name, or {@code null} for no credentials
     * @param password the expected password
     * @throws NullPointerException if {@code mapping} is {@code null}
     */
    public DummyWebSocketServer(BiFunction<List<String>,Credentials,List<String>> mapping,
                                String username,
                                String password) {
        requireNonNull(mapping);
        Credentials credentials = username != null ?
                new Credentials(username, password) : null;

        thread = new Thread(() -> {
            try {
                while (!Thread.currentThread().isInterrupted()) {
                    err.println("Accepting next connection at: " + ssc);
                    SocketChannel channel = ssc.accept();
                    err.println("Accepted: " + channel);
                    try {
                        channel.setOption(StandardSocketOptions.TCP_NODELAY, true);
                        channel.configureBlocking(true);
                        // Repeat the handshake while the client is refused with 401:
                        // it may retry with credentials on the same connection.
                        while (true) {
                            StringBuilder request = new StringBuilder();
                            if (!readRequest(channel, request)) {
                                throw new IOException("Bad request:[" + request + "]");
                            }
                            List<String> strings = asList(request.toString().split("\r\n"));
                            List<String> response = mapping.apply(strings, credentials);
                            writeResponse(channel, response);

                            if (response.get(0).startsWith("HTTP/1.1 401")) {
                                err.println("Sent 401 Authentication response " + channel);
                                continue;
                            }
                            serve(channel);
                            break;
                        }
                    } catch (IOException e) {
                        err.println("Error in connection: " + channel + ", " + e);
                    } finally {
                        err.println("Closed: " + channel);
                        closeChannel(channel);
                        readReady.countDown();
                    }
                }
            } catch (ClosedByInterruptException | InterruptedException ignored) {
                // close() interrupts this thread while it is blocked in accept()
                // or waiting in serve(); this is the normal way to stop.
            } catch (Exception e) {
                e.printStackTrace(err);
            } finally {
                close(ssc);
                err.println("Stopped at: " + getURI());
            }
        });
        thread.setName("DummyWebSocketServer");
        thread.setDaemon(false);
    }

    /**
     * Reads from the channel until end of input, an error, or an interrupt.
     * Everything read is appended to the buffer returned by {@link #read()}.
     *
     * @param ch the upgraded connection
     * @throws IOException if reading fails (including an interrupt-induced close)
     */
    protected void read(SocketChannel ch) throws IOException {
        // Read until the thread is interrupted or an error occurred
        // or the input is shutdown
        ByteBuffer b = ByteBuffer.allocate(65536);
        while (ch.read(b) != -1) {
            b.flip();
            if (read.remaining() < b.remaining()) {
                // Grow to the next power of two that fits what is already
                // stored plus the incoming chunk.
                int required = read.capacity() - read.remaining() + b.remaining();
                int log2required = 32 - Integer.numberOfLeadingZeros(required - 1);
                ByteBuffer newBuffer = ByteBuffer.allocate(1 << log2required);
                newBuffer.put(read.flip());
                read = newBuffer;
            }
            read.put(b);
            b.clear();
        }
    }

    /** Closes a connection once it is done; subclasses may override. */
    protected void closeChannel(SocketChannel channel) {
        close(channel);
    }

    /** Writes to an upgraded connection; does nothing by default. */
    protected void write(SocketChannel ch) throws IOException { }

    /**
     * Serves an upgraded connection by running {@link #read(SocketChannel)} and
     * {@link #write(SocketChannel)} on two new threads. Returns once the reader
     * has finished and then the writer has finished. IOExceptions thrown by
     * either are ignored.
     *
     * @param channel the upgraded connection
     * @throws InterruptedException if interrupted while waiting for the threads
     */
    protected final void serve(SocketChannel channel)
            throws InterruptedException
    {
        Thread reader = new Thread(() -> {
            try {
                read(channel);
            } catch (IOException ignored) { }
        });
        Thread writer = new Thread(() -> {
            try {
                write(channel);
            } catch (IOException ignored) { }
        });
        reader.start();
        writer.start();
        try {
            reader.join();
        } finally {
            reader.interrupt();
            try {
                writer.join();
            } finally {
                writer.interrupt();
            }
        }
    }

    /**
     * Waits until the first accepted connection has been closed. Then returns a
     * read-only view of the bytes accumulated by {@link #read(SocketChannel)},
     * positioned at 0 with the limit at the number of bytes.
     *
     * @return the bytes received so far
     * @throws InterruptedException if interrupted while waiting
     */
    public ByteBuffer read() throws InterruptedException {
        readReady.await();
        return read.duplicate().asReadOnlyBuffer().flip();
    }

    /**
     * Sets {@code SO_RCVBUF} for the server socket; must be called before
     * {@link #open()}. Values of 0 or less leave the system default.
     */
    public void setReceiveBufferSize(int bufsize) {
        assert ssc == null : "Must configure before calling open()";
        this.receiveBufferSize = bufsize;
    }

    /**
     * Binds to an ephemeral port on the loopback address and starts accepting
     * connections.
     *
     * @throws IllegalStateException if the server was already started
     * @throws IOException if the server socket cannot be opened or bound
     */
    public void open() throws IOException {
        err.println("Starting");
        if (!started.compareAndSet(false, true)) {
            throw new IllegalStateException("Already started");
        }
        ssc = ServerSocketChannel.open();
        try {
            ssc.configureBlocking(true);
            var bufsize = receiveBufferSize;
            if (bufsize > 0) {
                err.printf("Configuring receive buffer size to %d%n", bufsize);
                try {
                    ssc.setOption(StandardSocketOptions.SO_RCVBUF, bufsize);
                } catch (IOException x) {
                    err.printf("Failed to configure receive buffer size to %d%n", bufsize);
                }
            }
            ssc.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0));
            address = (InetSocketAddress) ssc.getLocalAddress();
            thread.start();
        } catch (IOException e) {
            close(ssc);
            throw e;
        }
        err.println("Started at: " + getURI());
    }

    /**
     * Stops the server: interrupts the accepting thread and closes the server
     * socket. Safe to call on a server that was never opened or failed to open.
     */
    @Override
    public void close() {
        // getURI() throws before open() has bound the socket; logging must not
        // make close() itself fail (e.g. in try-with-resources cleanup).
        String where = address == null ? "<not started>" : getURI().toString();
        err.println("Stopping: " + where);
        thread.interrupt();
        close(ssc);
    }

    /**
     * Returns the {@code ws://localhost:<port>} URI of the bound server.
     *
     * @throws IllegalStateException if the server has not been started/bound
     */
    URI getURI() {
        InetSocketAddress bound = address;
        if (!started.get() || bound == null) {
            throw new IllegalStateException("Not yet started");
        }
        return URI.create("ws://localhost:" + bound.getPort());
    }

    /**
     * Reads request headers, up to and including the first empty line, into
     * {@code request}.
     *
     * @return {@code true} if a complete header block was read, {@code false}
     *         if the input ended first
     * @throws IOException if reading or decoding fails
     */
    private boolean readRequest(SocketChannel channel, StringBuilder request)
            throws IOException
    {
        ByteBuffer buffer = ByteBuffer.allocate(512);
        while (channel.read(buffer) != -1) {
            // read the complete HTTP request headers, there should be no body
            buffer.flip();
            // ISO-8859-1 maps each byte to one char, so decoding chunk by chunk
            // never splits a character. A CharacterCodingException is an
            // IOException and is handled per connection by the caller.
            request.append(ISO_8859_1.newDecoder().decode(buffer));
            if (request.indexOf("\r\n\r\n") >= 0) {
                return true;
            }
            buffer.clear();
        }
        return false;
    }

    /**
     * Writes the response lines, CRLF-separated and terminated by an empty line.
     *
     * @throws IOException if writing fails or a line is not ISO-8859-1 encodable
     */
    private void writeResponse(SocketChannel channel, List<String> response)
            throws IOException
    {
        String s = String.join("\r\n", response) + "\r\n\r\n";
        // Unencodable characters raise CharacterCodingException, an IOException
        // the caller handles per connection instead of stopping the server.
        ByteBuffer encoded = ISO_8859_1.newEncoder().encode(CharBuffer.wrap(s));
        while (encoded.hasRemaining()) {
            channel.write(encoded);
        }
    }

    /**
     * Returns the default handshake mapping. It validates a WebSocket upgrade
     * request and answers with 101 plus {@code Sec-WebSocket-Accept}, or with
     * 401 when credentials are required but missing or wrong. Malformed or
     * unexpected requests cause an {@link IllegalStateException}.
     */
    private static BiFunction<List<String>,Credentials,List<String>> defaultMapping() {
        return (request, credentials) -> {
            List<String> response = new ArrayList<>();
            Iterator<String> iterator = request.iterator();
            if (!iterator.hasNext()) {
                throw new IllegalStateException("The request is empty");
            }
            String requestLine = iterator.next();
            if (!(requestLine.startsWith("GET /") && requestLine.endsWith(" HTTP/1.1"))) {
                throw new IllegalStateException
                        ("Unexpected status line: " + requestLine);
            }
            response.add("HTTP/1.1 101 Switching Protocols");
            // HTTP header field names are case-insensitive (RFC 7230 section 3.2)
            Map<String, List<String>> requestHeaders =
                    new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
            while (iterator.hasNext()) {
                String header = iterator.next();
                // Split at the first ": " only; a field value may contain ": " itself
                String[] split = header.split(": ", 2);
                if (split.length != 2) {
                    throw new IllegalStateException
                            ("Unexpected header: " + header
                                     + ", split=" + Arrays.toString(split));
                }
                requestHeaders.computeIfAbsent(split[0], k -> new ArrayList<>()).add(split[1]);
            }
            if (requestHeaders.containsKey("Sec-WebSocket-Protocol")) {
                throw new IllegalStateException("Subprotocols are not expected");
            }
            if (requestHeaders.containsKey("Sec-WebSocket-Extensions")) {
                throw new IllegalStateException("Extensions are not expected");
            }
            expectHeader(requestHeaders, "Connection", "Upgrade");
            response.add("Connection: Upgrade");
            expectHeader(requestHeaders, "Upgrade", "websocket");
            response.add("Upgrade: websocket");
            expectHeader(requestHeaders, "Sec-WebSocket-Version", "13");
            List<String> key = requestHeaders.get("Sec-WebSocket-Key");
            if (key == null || key.isEmpty()) {
                throw new IllegalStateException("Sec-WebSocket-Key is missing");
            }
            if (key.size() != 1) {
                throw new IllegalStateException("Sec-WebSocket-Key has too many values : " + key);
            }
            MessageDigest sha1;
            try {
                sha1 = MessageDigest.getInstance("SHA-1");
            } catch (NoSuchAlgorithmException e) {
                throw new InternalError(e);
            }
            // Accept value = base64(SHA-1(key + the fixed GUID from RFC 6455))
            String x = key.get(0) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
            sha1.update(x.getBytes(ISO_8859_1));
            String v = Base64.getEncoder().encodeToString(sha1.digest());
            response.add("Sec-WebSocket-Accept: " + v);

            // check authorization credentials, if required by the server
            if (credentials != null && !authorized(credentials, requestHeaders)) {
                response.clear();
                response.add("HTTP/1.1 401 Unauthorized");
                response.add("Content-Length: 0");
                response.add("WWW-Authenticate: Basic realm=\"dummy server realm\"");
            }

            return response;
        };
    }

    /**
     * Checks the request's Basic Authorization header against the server's
     * credentials.
     *
     * @return {@code false} if there is no Authorization header or the user
     *         name/password do not match
     * @throws IllegalStateException if the header is repeated, not Basic, or
     *         its decoded value has no {@code ':'} after a non-empty user name
     */
    private static boolean authorized(Credentials credentials,
                                      Map<String,List<String>> requestHeaders) {
        List<String> authorization = requestHeaders.get("Authorization");
        if (authorization == null) {
            return false;
        }
        if (authorization.size() != 1) {
            throw new IllegalStateException("Authorization unexpected count:" + authorization);
        }
        String header = authorization.get(0);
        if (!header.startsWith("Basic ")) {
            throw new IllegalStateException("Authorization not Basic: " + header);
        }
        String encoded = header.substring("Basic ".length());
        String values = new String(Base64.getDecoder().decode(encoded), UTF_8);
        int sep = values.indexOf(':');
        // sep == 0 (empty user name) is rejected as well
        if (sep < 1) {
            throw new IllegalStateException("Authorization not colon: " +  values);
        }
        String name = values.substring(0, sep);
        String password = values.substring(sep + 1);
        return name.equals(credentials.name()) && password.equals(credentials.password());
    }

    /**
     * Asserts that header {@code name} is present and has {@code value} among
     * its values.
     *
     * @return {@code value}
     * @throws IllegalStateException if the header is absent or lacks the value
     */
    protected static String expectHeader(Map<String, List<String>> headers,
                                         String name,
                                         String value) {
        List<String> v = headers.get(name);
        if (v == null) {
            throw new IllegalStateException(
                    format("Expected '%s' header, not present in %s",
                           name, headers));
        }
        if (!v.contains(value)) {
            throw new IllegalStateException(
                    format("Expected '%s: %s', actual: '%s: %s'",
                           name, value, name, v)
            );
        }
        return value;
    }

    /** Closes each non-null resource, ignoring any failure (best-effort cleanup). */
    private static void close(AutoCloseable... acs) {
        for (AutoCloseable ac : acs) {
            if (ac == null) {
                continue; // e.g. ssc before open() has been called
            }
            try {
                ac.close();
            } catch (Exception ignored) {
                // best effort
            }
        }
    }
}
429