1 /*
2  * Copyright (c) 2016, 2019, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
import java.io.Closeable;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.StandardSocketOptions;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.channels.ClosedByInterruptException;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.CharacterCodingException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiFunction;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static java.lang.String.format;
import static java.lang.System.err;
import static java.nio.charset.StandardCharsets.ISO_8859_1;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Arrays.asList;
import static java.util.Objects.requireNonNull;
59 
60 /**
61  * Dummy WebSocket Server.
62  *
63  * Performs simpler version of the WebSocket Opening Handshake over HTTP (i.e.
64  * no proxying, cookies, etc.) Supports sequential connections, one at a time,
65  * i.e. in order for a client to connect to the server the previous client must
66  * disconnect first.
67  *
68  * Expected client request:
69  *
70  *     GET /chat HTTP/1.1
71  *     Host: server.example.com
72  *     Upgrade: websocket
73  *     Connection: Upgrade
74  *     Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
75  *     Origin: http://example.com
76  *     Sec-WebSocket-Protocol: chat, superchat
77  *     Sec-WebSocket-Version: 13
78  *
79  * This server response:
80  *
81  *     HTTP/1.1 101 Switching Protocols
82  *     Upgrade: websocket
83  *     Connection: Upgrade
84  *     Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
85  *     Sec-WebSocket-Protocol: chat
86  */
87 public class DummyWebSocketServer implements Closeable {
88 
89     private final AtomicBoolean started = new AtomicBoolean();
90     private final Thread thread;
91     private volatile ServerSocketChannel ssc;
92     private volatile InetSocketAddress address;
93     private ByteBuffer read = ByteBuffer.allocate(16384);
94     private final CountDownLatch readReady = new CountDownLatch(1);
95 
96     private static class Credentials {
97         private final String name;
98         private final String password;
Credentials(String name, String password)99         private Credentials(String name, String password) {
100             this.name = name;
101             this.password = password;
102         }
name()103         public String name() { return name; }
password()104         public String password() { return password; }
105     }
106 
DummyWebSocketServer()107     public DummyWebSocketServer() {
108         this(defaultMapping(), null, null);
109     }
110 
DummyWebSocketServer(String username, String password)111     public DummyWebSocketServer(String username, String password) {
112         this(defaultMapping(), username, password);
113     }
114 
DummyWebSocketServer(BiFunction<List<String>,Credentials,List<String>> mapping, String username, String password)115     public DummyWebSocketServer(BiFunction<List<String>,Credentials,List<String>> mapping,
116                                 String username,
117                                 String password) {
118         requireNonNull(mapping);
119         Credentials credentials = username != null ?
120                 new Credentials(username, password) : null;
121 
122         thread = new Thread(() -> {
123             try {
124                 while (!Thread.currentThread().isInterrupted()) {
125                     err.println("Accepting next connection at: " + ssc);
126                     SocketChannel channel = ssc.accept();
127                     err.println("Accepted: " + channel);
128                     try {
129                         channel.setOption(StandardSocketOptions.TCP_NODELAY, true);
130                         channel.configureBlocking(true);
131                         while (true) {
132                             StringBuilder request = new StringBuilder();
133                             if (!readRequest(channel, request)) {
134                                 throw new IOException("Bad request:[" + request + "]");
135                             }
136                             List<String> strings = asList(request.toString().split("\r\n"));
137                             List<String> response = mapping.apply(strings, credentials);
138                             writeResponse(channel, response);
139 
140                             if (response.get(0).startsWith("HTTP/1.1 401")) {
141                                 err.println("Sent 401 Authentication response " + channel);
142                                 continue;
143                             } else {
144                                 serve(channel);
145                                 break;
146                             }
147                         }
148                     } catch (IOException e) {
149                         err.println("Error in connection: " + channel + ", " + e);
150                     } finally {
151                         err.println("Closed: " + channel);
152                         close(channel);
153                         readReady.countDown();
154                     }
155                 }
156             } catch (ClosedByInterruptException ignored) {
157             } catch (Exception e) {
158                 e.printStackTrace(err);
159             } finally {
160                 close(ssc);
161                 err.println("Stopped at: " + getURI());
162             }
163         });
164         thread.setName("DummyWebSocketServer");
165         thread.setDaemon(false);
166     }
167 
read(SocketChannel ch)168     protected void read(SocketChannel ch) throws IOException {
169         // Read until the thread is interrupted or an error occurred
170         // or the input is shutdown
171         ByteBuffer b = ByteBuffer.allocate(65536);
172         while (ch.read(b) != -1) {
173             b.flip();
174             if (read.remaining() < b.remaining()) {
175                 int required = read.capacity() - read.remaining() + b.remaining();
176                 int log2required = 32 - Integer.numberOfLeadingZeros(required - 1);
177                 ByteBuffer newBuffer = ByteBuffer.allocate(1 << log2required);
178                 newBuffer.put(read.flip());
179                 read = newBuffer;
180             }
181             read.put(b);
182             b.clear();
183         }
184     }
185 
write(SocketChannel ch)186     protected void write(SocketChannel ch) throws IOException { }
187 
serve(SocketChannel channel)188     protected final void serve(SocketChannel channel)
189             throws InterruptedException
190     {
191         Thread reader = new Thread(() -> {
192             try {
193                 read(channel);
194             } catch (IOException ignored) { }
195         });
196         Thread writer = new Thread(() -> {
197             try {
198                 write(channel);
199             } catch (IOException ignored) { }
200         });
201         reader.start();
202         writer.start();
203         try {
204             reader.join();
205         } finally {
206             reader.interrupt();
207             try {
208                 writer.join();
209             } finally {
210                 writer.interrupt();
211             }
212         }
213     }
214 
read()215     public ByteBuffer read() throws InterruptedException {
216         readReady.await();
217         return read.duplicate().asReadOnlyBuffer().flip();
218     }
219 
open()220     public void open() throws IOException {
221         err.println("Starting");
222         if (!started.compareAndSet(false, true)) {
223             throw new IllegalStateException("Already started");
224         }
225         ssc = ServerSocketChannel.open();
226         try {
227             ssc.configureBlocking(true);
228             ssc.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0));
229             address = (InetSocketAddress) ssc.getLocalAddress();
230             thread.start();
231         } catch (IOException e) {
232             close(ssc);
233             throw e;
234         }
235         err.println("Started at: " + getURI());
236     }
237 
238     @Override
close()239     public void close() {
240         err.println("Stopping: " + getURI());
241         thread.interrupt();
242         close(ssc);
243     }
244 
getURI()245     URI getURI() {
246         if (!started.get()) {
247             throw new IllegalStateException("Not yet started");
248         }
249         return URI.create("ws://localhost:" + address.getPort());
250     }
251 
readRequest(SocketChannel channel, StringBuilder request)252     private boolean readRequest(SocketChannel channel, StringBuilder request)
253             throws IOException
254     {
255         ByteBuffer buffer = ByteBuffer.allocate(512);
256         while (channel.read(buffer) != -1) {
257             // read the complete HTTP request headers, there should be no body
258             CharBuffer decoded;
259             buffer.flip();
260             try {
261                 decoded = ISO_8859_1.newDecoder().decode(buffer);
262             } catch (CharacterCodingException e) {
263                 throw new UncheckedIOException(e);
264             }
265             request.append(decoded);
266             if (Pattern.compile("\r\n\r\n").matcher(request).find())
267                 return true;
268             buffer.clear();
269         }
270         return false;
271     }
272 
writeResponse(SocketChannel channel, List<String> response)273     private void writeResponse(SocketChannel channel, List<String> response)
274             throws IOException
275     {
276         String s = response.stream().collect(Collectors.joining("\r\n"))
277                 + "\r\n\r\n";
278         ByteBuffer encoded;
279         try {
280             encoded = ISO_8859_1.newEncoder().encode(CharBuffer.wrap(s));
281         } catch (CharacterCodingException e) {
282             throw new UncheckedIOException(e);
283         }
284         while (encoded.hasRemaining()) {
285             channel.write(encoded);
286         }
287     }
288 
defaultMapping()289     private static BiFunction<List<String>,Credentials,List<String>> defaultMapping() {
290         return (request, credentials) -> {
291             List<String> response = new LinkedList<>();
292             Iterator<String> iterator = request.iterator();
293             if (!iterator.hasNext()) {
294                 throw new IllegalStateException("The request is empty");
295             }
296             String statusLine = iterator.next();
297             if (!(statusLine.startsWith("GET /") && statusLine.endsWith(" HTTP/1.1"))) {
298                 throw new IllegalStateException
299                         ("Unexpected status line: " + request.get(0));
300             }
301             response.add("HTTP/1.1 101 Switching Protocols");
302             Map<String, List<String>> requestHeaders = new HashMap<>();
303             while (iterator.hasNext()) {
304                 String header = iterator.next();
305                 String[] split = header.split(": ");
306                 if (split.length != 2) {
307                     throw new IllegalStateException
308                             ("Unexpected header: " + header
309                                      + ", split=" + Arrays.toString(split));
310                 }
311                 requestHeaders.computeIfAbsent(split[0], k -> new ArrayList<>()).add(split[1]);
312 
313             }
314             if (requestHeaders.containsKey("Sec-WebSocket-Protocol")) {
315                 throw new IllegalStateException("Subprotocols are not expected");
316             }
317             if (requestHeaders.containsKey("Sec-WebSocket-Extensions")) {
318                 throw new IllegalStateException("Extensions are not expected");
319             }
320             expectHeader(requestHeaders, "Connection", "Upgrade");
321             response.add("Connection: Upgrade");
322             expectHeader(requestHeaders, "Upgrade", "websocket");
323             response.add("Upgrade: websocket");
324             expectHeader(requestHeaders, "Sec-WebSocket-Version", "13");
325             List<String> key = requestHeaders.get("Sec-WebSocket-Key");
326             if (key == null || key.isEmpty()) {
327                 throw new IllegalStateException("Sec-WebSocket-Key is missing");
328             }
329             if (key.size() != 1) {
330                 throw new IllegalStateException("Sec-WebSocket-Key has too many values : " + key);
331             }
332             MessageDigest sha1 = null;
333             try {
334                 sha1 = MessageDigest.getInstance("SHA-1");
335             } catch (NoSuchAlgorithmException e) {
336                 throw new InternalError(e);
337             }
338             String x = key.get(0) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
339             sha1.update(x.getBytes(ISO_8859_1));
340             String v = Base64.getEncoder().encodeToString(sha1.digest());
341             response.add("Sec-WebSocket-Accept: " + v);
342 
343             // check authorization credentials, if required by the server
344             if (credentials != null && !authorized(credentials, requestHeaders)) {
345                 response.clear();
346                 response.add("HTTP/1.1 401 Unauthorized");
347                 response.add("Content-Length: 0");
348                 response.add("WWW-Authenticate: Basic realm=\"dummy server realm\"");
349             }
350 
351             return response;
352         };
353     }
354 
355     // Checks credentials in the request against those allowable by the server.
authorized(Credentials credentials, Map<String,List<String>> requestHeaders)356     private static boolean authorized(Credentials credentials,
357                                       Map<String,List<String>> requestHeaders) {
358         List<String> authorization = requestHeaders.get("Authorization");
359         if (authorization == null)
360             return false;
361 
362         if (authorization.size() != 1) {
363             throw new IllegalStateException("Authorization unexpected count:" + authorization);
364         }
365         String header = authorization.get(0);
366         if (!header.startsWith("Basic "))
367             throw new IllegalStateException("Authorization not Basic: " + header);
368 
369         header = header.substring("Basic ".length());
370         String values = new String(Base64.getDecoder().decode(header), UTF_8);
371         int sep = values.indexOf(':');
372         if (sep < 1) {
373             throw new IllegalStateException("Authorization not colon: " +  values);
374         }
375         String name = values.substring(0, sep);
376         String password = values.substring(sep + 1);
377 
378         if (name.equals(credentials.name()) && password.equals(credentials.password()))
379             return true;
380 
381         return false;
382     }
383 
expectHeader(Map<String, List<String>> headers, String name, String value)384     protected static String expectHeader(Map<String, List<String>> headers,
385                                          String name,
386                                          String value) {
387         List<String> v = headers.get(name);
388         if (v == null) {
389             throw new IllegalStateException(
390                     format("Expected '%s' header, not present in %s",
391                            name, headers));
392         }
393         if (!v.contains(value)) {
394             throw new IllegalStateException(
395                     format("Expected '%s: %s', actual: '%s: %s'",
396                            name, value, name, v)
397             );
398         }
399         return value;
400     }
401 
close(AutoCloseable... acs)402     private static void close(AutoCloseable... acs) {
403         for (AutoCloseable ac : acs) {
404             try {
405                 ac.close();
406             } catch (Exception ignored) { }
407         }
408     }
409 }
410