1 /*
2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 // SunJSSE does not support dynamic system properties, no way to re-use
25 // system properties in samevm/agentvm mode.
26 
27 /*
28  * @test
29  * @bug 8211806
30  * @summary TLS 1.3 handshake server name indication is missing on a session resume
31  * @run main/othervm ResumeTLS13withSNI
32  */
33 
import javax.net.ssl.*;
import javax.net.ssl.SSLEngineResult.*;
import java.io.*;
import java.security.*;
import java.nio.*;
import java.nio.charset.StandardCharsets;
import java.util.List;
40 
public class ResumeTLS13withSNI {

    /*
     * Enables logging of the SSLEngine operations.
     */
    private static final boolean logging = false;

    /*
     * Enables the JSSE system debugging system property:
     *
     *     -Djavax.net.debug=ssl:handshake
     *
     * This gives a lot of low-level information about operations underway,
     * including specific handshake messages, and might be best examined
     * after gaining some familiarity with this application.
     */
    private static final boolean debug = true;

    // Application data each side sends during the initial handshake loop.
    private static final ByteBuffer clientOut =
            ByteBuffer.wrap("Hi Server, I'm Client".getBytes());
    private static final ByteBuffer serverOut =
            ByteBuffer.wrap("Hello Client, I'm Server".getBytes());

    /*
     * The following is to set up the keystores.
     */
    private static final String pathToStores = "../etc";
    private static final String keyStoreFile = "keystore";
    private static final String trustStoreFile = "truststore";
    private static final char[] passphrase = "passphrase".toCharArray();

    private static final String keyFilename =
            System.getProperty("test.src", ".") + "/" + pathToStores +
                "/" + keyStoreFile;
    private static final String trustFilename =
            System.getProperty("test.src", ".") + "/" + pathToStores +
                "/" + trustStoreFile;

    // SNI host name requested by the original client.  The server's matcher
    // is built from the same name, and the resumed ClientHello is expected
    // to carry it too.
    private static final String HOST_NAME = "arf.yak.foo";
    private static final SNIHostName SNI_NAME = new SNIHostName(HOST_NAME);
    private static final SNIMatcher SNI_MATCHER =
            SNIHostName.createSNIMatcher("arf\\.yak\\.foo");

    /*
     * Main entry point for this test.
     *
     * Performs a full TLSv1.3 handshake with an explicit SNI value. It then
     * creates a second client engine from the same SSLContext and captures
     * that engine's first ClientHello. The ClientHello must contain both a
     * pre_shared_key extension (resumption) and the original server_name.
     */
    public static void main(String[] args) throws Exception {
        if (debug) {
            System.setProperty("javax.net.debug", "ssl:handshake");
        }

        KeyManagerFactory kmf = makeKeyManagerFactory(keyFilename,
                passphrase);
        TrustManagerFactory tmf = makeTrustManagerFactory(trustFilename,
                passphrase);

        SSLContext sslCtx = SSLContext.getInstance("TLS");
        sslCtx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);

        // Make client and server engines, then customize as needed
        SSLEngine clientEngine = makeEngine(sslCtx, kmf, tmf, true);
        SSLParameters cliSSLParams = clientEngine.getSSLParameters();
        cliSSLParams.setServerNames(List.of(SNI_NAME));
        clientEngine.setSSLParameters(cliSSLParams);
        clientEngine.setEnabledProtocols(new String[] { "TLSv1.3" });

        SSLEngine serverEngine = makeEngine(sslCtx, kmf, tmf, false);
        SSLParameters servSSLParams = serverEngine.getSSLParameters();
        servSSLParams.setSNIMatchers(List.of(SNI_MATCHER));
        serverEngine.setSSLParameters(servSSLParams);

        initialHandshake(clientEngine, serverEngine);

        // Create a new client-side engine which can initiate TLS session
        // resumption.  It deliberately does not set server names itself;
        // the test expects its ClientHello to carry the SNI value from the
        // original handshake.
        SSLEngine newCliEngine = makeEngine(sslCtx, kmf, tmf, true);
        newCliEngine.setEnabledProtocols(new String[] { "TLSv1.3" });
        ByteBuffer resCliHello = getResumptionClientHello(newCliEngine);

        dumpBuffer("Resumed ClientHello Data", resCliHello);

        // Parse the client hello message and make sure it is a resumption
        // hello and has SNI in it.
        checkResumedClientHelloSNI(resCliHello);
    }

    /*
     * Run a complete handshake plus one round of application data between
     * the two engines.
     *
     * Sit in a tight loop, both engines calling wrap/unwrap regardless
     * of whether data is available or not.  The loop ends once clientOut
     * has fully arrived in the server's buffer and serverOut has fully
     * arrived in the client's buffer.  The engines are NOT closed afterward.
     * The only goal is an established session that a later client engine
     * from the same SSLContext can try to resume.
     *
     * Any exception thrown by wrap()/unwrap() is reported and rethrown.
     * Continuing after such a failure would leave the loop spinning forever,
     * because the data-transfer condition could never become true.
     */
    private static void initialHandshake(SSLEngine clientEngine,
            SSLEngine serverEngine) throws Exception {
        boolean dataDone = false;

        // Create all the buffers
        SSLSession session = clientEngine.getSession();
        int appBufferMax = session.getApplicationBufferSize();
        int netBufferMax = session.getPacketBufferSize();
        ByteBuffer clientIn = ByteBuffer.allocate(appBufferMax + 50);
        ByteBuffer serverIn = ByteBuffer.allocate(appBufferMax + 50);
        ByteBuffer cTOs = ByteBuffer.allocateDirect(netBufferMax);
        ByteBuffer sTOc = ByteBuffer.allocateDirect(netBufferMax);

        /*
         * Examining the SSLEngineResults could be much more involved,
         * and may alter the overall flow of the application.
         *
         * For example, if we received a BUFFER_OVERFLOW when trying
         * to write to the output pipe, we could reallocate a larger
         * pipe, but instead we wait for the peer to drain it.
         */
        while (!dataDone) {
            log("================");

            try {
                SSLEngineResult clientResult =
                        clientEngine.wrap(clientOut, cTOs);
                log("client wrap: ", clientResult);
            } catch (Exception e) {
                System.err.println("Client wrap() threw: " + e.getMessage());
                throw e;
            }
            logEngineStatus(clientEngine);
            runDelegatedTasks(clientEngine);

            log("----");

            try {
                SSLEngineResult serverResult =
                        serverEngine.wrap(serverOut, sTOc);
                log("server wrap: ", serverResult);
            } catch (Exception e) {
                System.err.println("Server wrap() threw: " + e.getMessage());
                throw e;
            }
            logEngineStatus(serverEngine);
            runDelegatedTasks(serverEngine);

            // Switch the network buffers from "being written" to "being read".
            cTOs.flip();
            sTOc.flip();

            log("--------");

            try {
                SSLEngineResult clientResult =
                        clientEngine.unwrap(sTOc, clientIn);
                log("client unwrap: ", clientResult);
            } catch (Exception e) {
                System.err.println("Client unwrap() threw: " + e.getMessage());
                throw e;
            }
            logEngineStatus(clientEngine);
            runDelegatedTasks(clientEngine);

            log("----");

            try {
                SSLEngineResult serverResult =
                        serverEngine.unwrap(cTOs, serverIn);
                log("server unwrap: ", serverResult);
            } catch (Exception e) {
                System.err.println("Server unwrap() threw: " + e.getMessage());
                throw e;
            }
            logEngineStatus(serverEngine);
            runDelegatedTasks(serverEngine);

            // Keep any unconsumed network bytes and get ready to write again.
            cTOs.compact();
            sTOc.compact();

            /*
             * Once every byte of application data has reached the peer,
             * verify the contents and stop.
             */
            if ((clientOut.limit() == serverIn.position()) &&
                    (serverOut.limit() == clientIn.position())) {

                /*
                 * A sanity check to ensure we got what was sent.
                 */
                checkTransfer(serverOut, clientIn);
                checkTransfer(clientOut, serverIn);

                dataDone = true;
            }
        }
    }

    /**
     * Start a TLS session resumption attempt and capture the client's first
     * flight so it can be inspected.
     *
     * @param clientEngine a client-mode engine that has not yet handshaken.
     *      It should be created from the same SSLContext as a previously
     *      completed handshake.
     *
     * @return a ByteBuffer, flipped for reading, holding whatever the client
     *      produced on its first wrap().  In this test that is expected to
     *      be the ClientHello TLS record.
     *
     * @throws Exception if wrap() fails or any delegated task processing
     *      goes wrong.
     */
    private static ByteBuffer getResumptionClientHello(SSLEngine clientEngine)
            throws Exception {
        // Only an outbound network buffer is needed: the client is asked for
        // its first flight and nothing is ever fed back to it.
        SSLSession session = clientEngine.getSession();
        int netBufferMax = session.getPacketBufferSize();
        ByteBuffer cTOs = ByteBuffer.allocateDirect(netBufferMax);

        log("================");

        // Start by having the client create a new ClientHello.  It should
        // contain PSK info that allows it to attempt session resumption.
        try {
            SSLEngineResult clientResult = clientEngine.wrap(clientOut, cTOs);
            log("client wrap: ", clientResult);
        } catch (Exception e) {
            // Fail here rather than passing an empty buffer to the parser,
            // which would only report a confusing buffer underflow.
            System.err.println("Client wrap() threw: " + e.getMessage());
            throw e;
        }
        logEngineStatus(clientEngine);
        runDelegatedTasks(clientEngine);

        log("----");

        cTOs.flip();
        return cTOs;
    }

    /**
     * Walk a ClientHello TLS record, looking for two things: a server_name
     * hostname value matching the original handshake, and a pre_shared_key
     * extension.  In the context of this test, the PSK extension shows that
     * this is a resumed handshake.
     *
     * @param resCliHello a ByteBuffer consisting of a complete TLS handshake
     *      record that is a ClientHello message.  The position of the buffer
     *      must be at the beginning of the TLS record header.  The buffer's
     *      position is advanced by the parse.
     *
     * @throws Exception if the record is not a handshake record, the message
     *      is not a ClientHello, the legacy version is not 0x0303, or a
     *      server name of unknown type is present.  It also throws if the
     *      server_name extension has no matching hostname or the
     *      pre_shared_key extension is missing.
     */
    private static void checkResumedClientHelloSNI(ByteBuffer resCliHello)
            throws Exception {
        boolean foundMatchingSNI = false;
        boolean foundPSK = false;

        // TLS record header: ContentType (1), legacy version (2), length (2).
        // The content type must be handshake (22); the rest is skipped.
        byte recType = resCliHello.get();
        if (recType != 0x16) {
            throw new Exception("TLS record is not a handshake record, " +
                    "ContentType = " + recType);
        }
        resCliHello.position(resCliHello.position() + 4);

        // Get the next byte and make sure it is a Client Hello
        byte hsMsgType = resCliHello.get();
        if (hsMsgType != 0x01) {
            throw new Exception("Message is not a ClientHello, MsgType = " +
                    hsMsgType);
        }

        // Skip past the length (3 bytes)
        resCliHello.position(resCliHello.position() + 3);

        // Protocol version should be TLSv1.2 (0x03, 0x03)
        short chProto = resCliHello.getShort();
        if (chProto != 0x0303) {
            throw new Exception(
                    "Client Hello protocol version is not TLSv1.2: Got " +
                            String.format("0x%04X", chProto));
        }

        // Skip 32-bytes of random data
        resCliHello.position(resCliHello.position() + 32);

        // Get the legacy session length and skip that many bytes
        int sessIdLen = Byte.toUnsignedInt(resCliHello.get());
        resCliHello.position(resCliHello.position() + sessIdLen);

        // Skip over all the cipher suites
        int csLen = Short.toUnsignedInt(resCliHello.getShort());
        resCliHello.position(resCliHello.position() + csLen);

        // Skip compression methods
        int compLen = Byte.toUnsignedInt(resCliHello.get());
        resCliHello.position(resCliHello.position() + compLen);

        // Parse the extensions.  Get length first, then walk the extensions
        // List and look for the presence of the PSK extension and server_name.
        // For server_name, make sure it is the same as what was provided
        // in the original handshake.
        System.err.println("ClientHello Extensions Check");
        int extListLen = Short.toUnsignedInt(resCliHello.getShort());
        while (extListLen > 0) {
            // Get the Extension type and length
            int extType = Short.toUnsignedInt(resCliHello.getShort());
            int extLen = Short.toUnsignedInt(resCliHello.getShort());
            // Where this extension's data ends.  Jumping here after every
            // extension keeps the walk aligned even if a body is only
            // partially consumed below.
            int extEnd = resCliHello.position() + extLen;
            switch (extType) {
                case 0: {               // server_name
                    System.err.println("* Found server_name Extension");
                    int snListLen = Short.toUnsignedInt(resCliHello.getShort());
                    while (snListLen > 0) {
                        int nameType = Byte.toUnsignedInt(resCliHello.get());
                        if (nameType != 0) {
                            // Only host_name (0) is understood; for any other
                            // type the entry length is unknown, so we cannot
                            // advance.
                            throw new Exception("Unknown server name type: " +
                                    nameType);
                        }
                        int hostNameLen =
                                Short.toUnsignedInt(resCliHello.getShort());
                        byte[] hostNameData = new byte[hostNameLen];
                        resCliHello.get(hostNameData);
                        // SNI host names are ASCII byte strings; do not
                        // depend on the platform default charset.
                        String hostNameStr = new String(hostNameData,
                                StandardCharsets.US_ASCII);
                        System.err.println("\tHostname: " + hostNameStr);
                        if (hostNameStr.equals(HOST_NAME)) {
                            foundMatchingSNI = true;
                        }
                        snListLen -= 3 + hostNameLen;   // type, len, data
                    }
                    break;
                }
                case 41:                // pre_shared_key
                    // We're not going to bother checking the value.  The
                    // presence of the extension in the context of this test
                    // is good enough to tell us this is a resumed ClientHello.
                    foundPSK = true;
                    System.err.println("* Found pre_shared_key Extension");
                    break;
                default:
                    System.err.format("* Found extension %d (%d bytes)\n",
                            extType, extLen);
                    break;
            }
            resCliHello.position(extEnd);
            extListLen -= extLen + 4;   // Ext type(2), length(2), data(var.)
        }

        // At the end of all the extension processing, either we've found
        // both extensions and the server_name matches our expected value
        // or we throw an exception.
        if (!foundMatchingSNI) {
            throw new Exception("Could not find a matching server_name");
        } else if (!foundPSK) {
            throw new Exception("Missing PSK extension, not a resumption?");
        }
    }

    /**
     * Create a TrustManagerFactory from a given keystore.
     *
     * @param tsPath the path to the trust store file (loaded as JKS).
     * @param pass the password for the trust store.
     *
     * @return a new "SunX509" TrustManagerFactory built from the trust store
     *      provided.
     *
     * @throws GeneralSecurityException if any processing errors occur
     *      with the Keystore instantiation or TrustManagerFactory creation.
     * @throws IOException if any loading error with the trust store occurs.
     */
    private static TrustManagerFactory makeTrustManagerFactory(String tsPath,
            char[] pass) throws GeneralSecurityException, IOException {
        TrustManagerFactory tmf;
        KeyStore ts = KeyStore.getInstance("JKS");

        try (FileInputStream fsIn = new FileInputStream(tsPath)) {
            ts.load(fsIn, pass);
            tmf = TrustManagerFactory.getInstance("SunX509");
            tmf.init(ts);
        }
        return tmf;
    }

    /**
     * Create a KeyManagerFactory from a given keystore.
     *
     * @param ksPath the path to the keystore file (loaded as JKS).
     * @param pass the password for the keystore; also used as the key
     *      password.
     *
     * @return a new "SunX509" KeyManagerFactory built from the keystore
     *      provided.
     *
     * @throws GeneralSecurityException if any processing errors occur
     *      with the Keystore instantiation or KeyManagerFactory creation.
     * @throws IOException if any loading error with the keystore occurs
     */
    private static KeyManagerFactory makeKeyManagerFactory(String ksPath,
            char[] pass) throws GeneralSecurityException, IOException {
        KeyManagerFactory kmf;
        KeyStore ks = KeyStore.getInstance("JKS");

        try (FileInputStream fsIn = new FileInputStream(ksPath)) {
            ks.load(fsIn, pass);
            kmf = KeyManagerFactory.getInstance("SunX509");
            kmf.init(ks, pass);
        }
        return kmf;
    }

    /**
     * Create an SSLEngine from the given SSLContext, KeyManagerFactory and
     * TrustManagerFactory.
     *
     * The context is (re)initialized with the managers from the supplied
     * factories before the engine is created.  A null factory causes null
     * to be passed to {@link SSLContext#init}, which then falls back to its
     * default managers.
     *
     * @param ctx the SSLContext used to create the SSLEngine
     * @param kmf an initialized KeyManagerFactory.  May be null.
     * @param tmf an initialized TrustManagerFactory.  May be null.
     * @param isClient true if it intended to create a client engine, false
     *      for a server engine.
     *
     * @return an SSLEngine for peer "localhost":8443, in client or server
     *      mode according to {@code isClient}, with client authentication
     *      not required.
     *
     * @throws GeneralSecurityException if any errors occur during the
     *      initialization of the context.
     */
    private static SSLEngine makeEngine(SSLContext ctx,
            KeyManagerFactory kmf, TrustManagerFactory tmf, boolean isClient)
            throws GeneralSecurityException {
        ctx.init(kmf != null ? kmf.getKeyManagers() : null,
                tmf != null ? tmf.getTrustManagers() : null, null);
        SSLEngine ssle = ctx.createSSLEngine("localhost", 8443);
        ssle.setUseClientMode(isClient);
        ssle.setNeedClientAuth(false);
        return ssle;
    }

    /*
     * Log the engine's handshake status and inbound/outbound closure state.
     */
    private static void logEngineStatus(SSLEngine engine) {
        log("\tCurrent HS State  " + engine.getHandshakeStatus().toString());
        log("\tisInboundDone():  " + engine.isInboundDone());
        log("\tisOutboundDone(): " + engine.isOutboundDone());
    }

    /*
     * If the engine reports NEED_TASK, run all of its delegated tasks in
     * this thread.  Fails if the engine still needs tasks afterward.
     */
    private static void runDelegatedTasks(SSLEngine engine) throws Exception {

        if (engine.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
            Runnable runnable;
            while ((runnable = engine.getDelegatedTask()) != null) {
                log("    running delegated task...");
                runnable.run();
            }
            HandshakeStatus hsStatus = engine.getHandshakeStatus();
            if (hsStatus == HandshakeStatus.NEED_TASK) {
                throw new Exception(
                    "handshake shouldn't need additional tasks");
            }
            logEngineStatus(engine);
        }
    }

    /*
     * True when both directions of the engine are done.  Not currently
     * called by this test.
     */
    private static boolean isEngineClosed(SSLEngine engine) {
        return (engine.isOutboundDone() && engine.isInboundDone());
    }

    /*
     * Compare the first position() bytes of each buffer, failing if they
     * differ.  Afterward each buffer's position is restored to where it
     * was, and its limit is set to its capacity.
     */
    private static void checkTransfer(ByteBuffer a, ByteBuffer b)
            throws Exception {
        a.flip();
        b.flip();

        if (!a.equals(b)) {
            throw new Exception("Data didn't transfer cleanly");
        } else {
            log("\tData transferred cleanly");
        }

        a.position(a.limit());
        b.position(b.limit());
        a.limit(a.capacity());
        b.limit(b.capacity());
    }

    /*
     * Logging code
     */
    private static boolean resultOnce = true;

    /*
     * Log an SSLEngineResult (status, handshake status, byte counts) when
     * logging is enabled.  The format legend is printed only the first time.
     */
    private static void log(String str, SSLEngineResult result) {
        if (!logging) {
            return;
        }
        if (resultOnce) {
            resultOnce = false;
            System.err.println("The format of the SSLEngineResult is: \n" +
                    "\t\"getStatus() / getHandshakeStatus()\" +\n" +
                    "\t\"bytesConsumed() / bytesProduced()\"\n");
        }
        HandshakeStatus hsStatus = result.getHandshakeStatus();
        log(str +
                result.getStatus() + "/" + hsStatus + ", " +
                result.bytesConsumed() + "/" + result.bytesProduced() +
                " bytes");
        if (hsStatus == HandshakeStatus.FINISHED) {
            log("\t...ready for application data");
        }
    }

    /*
     * Print a line to stderr when logging is enabled.
     */
    private static void log(String str) {
        if (logging) {
            System.err.println(str);
        }
    }

    /*
     * Hex-dump the buffer's remaining bytes to stderr, 16 per line.  The
     * buffer's position is restored afterward via mark()/reset().
     */
    private static void dumpBuffer(String header, ByteBuffer data) {
        data.mark();
        System.err.format("========== %s ==========\n", header);
        int i = 0;
        while (data.remaining() > 0) {
            if (i != 0 && i % 16 == 0) {
                System.err.print("\n");
            }
            System.err.format("%02X ", data.get());
            i++;
        }
        System.err.println();
        data.reset();
    }

}
587