1 /*
2  * Copyright (c) 2018, 2019 Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 /*
25  * @test
26  * @bug 8206929 8212885
27  * @summary ensure that client only resumes a session if certain properties
28  *    of the session are compatible with the new connection
29  * @run main/othervm -Djdk.tls.client.protocols=TLSv1.2 -Djdk.tls.server.enableSessionTicketExtension=false -Djdk.tls.client.enableSessionTicketExtension=false ResumeChecksClient BASIC
30  * @run main/othervm -Djdk.tls.client.protocols=TLSv1.2 -Djdk.tls.server.enableSessionTicketExtension=true -Djdk.tls.client.enableSessionTicketExtension=false ResumeChecksClient BASIC
31  * @run main/othervm -Djdk.tls.client.protocols=TLSv1.2 -Djdk.tls.server.enableSessionTicketExtension=true -Djdk.tls.client.enableSessionTicketExtension=true ResumeChecksClient BASIC
32  * @run main/othervm -Djdk.tls.client.protocols=TLSv1.3 -Djdk.tls.server.enableSessionTicketExtension=true -Djdk.tls.client.enableSessionTicketExtension=true ResumeChecksClient BASIC
33  * @run main/othervm -Djdk.tls.client.protocols=TLSv1.2 -Djdk.tls.server.enableSessionTicketExtension=false -Djdk.tls.client.enableSessionTicketExtension=true ResumeChecksClient BASIC
34  * @run main/othervm -Djdk.tls.client.protocols=TLSv1.3 -Djdk.tls.server.enableSessionTicketExtension=false -Djdk.tls.client.enableSessionTicketExtension=true ResumeChecksClient BASIC
35  * @run main/othervm -Djdk.tls.server.enableSessionTicketExtension=false -Djdk.tls.client.enableSessionTicketExtension=true ResumeChecksClient BASIC
36  * @run main/othervm -Djdk.tls.server.enableSessionTicketExtension=true -Djdk.tls.client.enableSessionTicketExtension=true ResumeChecksClient VERSION_2_TO_3
37  * @run main/othervm -Djdk.tls.server.enableSessionTicketExtension=true -Djdk.tls.client.enableSessionTicketExtension=true ResumeChecksClient VERSION_3_TO_2
38  * @run main/othervm -Djdk.tls.client.protocols=TLSv1.3 -Djdk.tls.server.enableSessionTicketExtension=true -Djdk.tls.client.enableSessionTicketExtension=true ResumeChecksClient CIPHER_SUITE
39  * @run main/othervm -Djdk.tls.client.protocols=TLSv1.3 -Djdk.tls.server.enableSessionTicketExtension=true -Djdk.tls.client.enableSessionTicketExtension=true ResumeChecksClient SIGNATURE_SCHEME
40  *
41  */
42 
43 import javax.net.*;
44 import javax.net.ssl.*;
45 import java.io.*;
46 import java.security.*;
47 import java.net.*;
48 import java.util.*;
49 
/**
 * Client-side checks for TLS session resumption.
 *
 * <p>Two connections are made to an in-process {@link Server} using the
 * default {@link SSLContext}. Depending on the {@link TestMode} named by the
 * first program argument, the client either leaves its parameters untouched
 * (and requires the second connection to resume the first session) or
 * changes the protocol version, cipher suite, or permitted signature
 * algorithms between the two connections (and requires a new session).
 */
public class ResumeChecksClient {

    static String pathToStores = "../../../../javax/net/ssl/etc";
    static String keyStoreFile = "keystore";
    static String trustStoreFile = "truststore";
    static String passwd = "passphrase";

    /** Which session property, if any, differs between the two connections. */
    enum TestMode {
        /** Nothing changes; the second connection must resume the first session. */
        BASIC,
        /** First connection is restricted to TLSv1.2, the second to TLSv1.3. */
        VERSION_2_TO_3,
        /** First connection is restricted to TLSv1.3, the second to TLSv1.2. */
        VERSION_3_TO_2,
        /** First connection uses TLS_AES_128_GCM_SHA256, the second TLS_AES_256_GCM_SHA384. */
        CIPHER_SUITE,
        /** First connection rejects RSA algorithms, the second rejects ECDSA algorithms. */
        SIGNATURE_SCHEME
    }

    /**
     * Runs one scenario: starts the server, makes two connections and checks
     * whether the second one resumed the first session.
     *
     * @param args {@code args[0]} is the name of a {@link TestMode} constant
     * @throws Exception if a connection fails or the resumption check fails
     */
    public static void main(String[] args) throws Exception {

        TestMode mode = TestMode.valueOf(args[0]);

        String keyFilename =
            System.getProperty("test.src", "./") + "/" + pathToStores +
                "/" + keyStoreFile;
        String trustFilename =
            System.getProperty("test.src", "./") + "/" + pathToStores +
                "/" + trustStoreFile;

        // Configure key/trust material for the default SSLContext, which
        // both the server thread and this thread obtain below.
        System.setProperty("javax.net.ssl.keyStore", keyFilename);
        System.setProperty("javax.net.ssl.keyStorePassword", passwd);
        System.setProperty("javax.net.ssl.trustStore", trustFilename);
        System.setProperty("javax.net.ssl.trustStorePassword", passwd);

        Server server = startServer();
        // First signal releases the server into its accept loop.
        server.signal();
        SSLContext sslContext = SSLContext.getDefault();
        // 'started' is only set after 'port' has been published.
        while (!server.started) {
            Thread.yield();
        }
        SSLSession firstSession = connect(sslContext, server.port, mode, false);

        // Let the server move on to the next connection.
        server.signal();
        long secondStartTime = System.currentTimeMillis();
        // Make sure a session created by the second handshake has a creation
        // time strictly later than secondStartTime.
        Thread.sleep(10);
        SSLSession secondSession = connect(sslContext, server.port, mode, true);

        // Stop the server loop.
        server.go = false;
        server.signal();

        switch (mode) {
        case BASIC:
            // fail if session is not resumed
            checkResumedSession(firstSession, secondSession);
            break;
        case VERSION_2_TO_3:
        case VERSION_3_TO_2:
        case CIPHER_SUITE:
        case SIGNATURE_SCHEME:
            // fail if a new session is not created
            if (secondSession.getCreationTime() <= secondStartTime) {
                throw new RuntimeException("Existing session was used");
            }
            break;
        default:
            throw new RuntimeException("unknown mode: " + mode);
        }
    }

    /**
     * Algorithm constraints that reject every algorithm whose name contains
     * a given substring, compared case-insensitively. Keys are always
     * permitted.
     */
    private static final class NoSig implements AlgorithmConstraints {

        /** Lower-cased substring identifying the rejected algorithms. */
        private final String alg;

        NoSig(String alg) {
            // Locale.ROOT keeps the comparison independent of the default
            // locale (e.g. Turkish dotless-i rules).
            this.alg = alg.toLowerCase(Locale.ROOT);
        }

        /** Returns true if {@code a} does not mention the rejected algorithm. */
        private boolean test(String a) {
            return !a.toLowerCase(Locale.ROOT).contains(alg);
        }

        @Override
        public boolean permits(Set<CryptoPrimitive> primitives, Key key) {
            return true;
        }

        @Override
        public boolean permits(Set<CryptoPrimitive> primitives,
            String algorithm, AlgorithmParameters parameters) {

            return test(algorithm);
        }

        @Override
        public boolean permits(Set<CryptoPrimitive> primitives,
            String algorithm, Key key, AlgorithmParameters parameters) {

            return test(algorithm);
        }
    }

    /**
     * Opens a TLS connection to {@code localhost:port}, sends one line,
     * reads the echoed line and returns the negotiated session. The socket
     * is always closed, including when an I/O step fails.
     *
     * @param sslContext context used to create the client socket
     * @param port       port of the server on localhost
     * @param mode       scenario deciding which parameters to override
     * @param second     {@code true} for the second of the two connections
     * @return the session of the (now closed) connection
     * @throws UncheckedIOException if any socket I/O fails
     */
    private static SSLSession connect(SSLContext sslContext, int port,
        TestMode mode, boolean second) {

        try (SSLSocket sock = (SSLSocket)
                sslContext.getSocketFactory().createSocket()) {
            SSLParameters params = sock.getSSLParameters();

            switch (mode) {
            case BASIC:
                // do nothing to ensure resumption works
                break;
            case VERSION_2_TO_3:
                if (second) {
                    params.setProtocols(new String[] {"TLSv1.3"});
                } else {
                    params.setProtocols(new String[] {"TLSv1.2"});
                }
                break;
            case VERSION_3_TO_2:
                if (second) {
                    params.setProtocols(new String[] {"TLSv1.2"});
                } else {
                    params.setProtocols(new String[] {"TLSv1.3"});
                }
                break;
            case CIPHER_SUITE:
                if (second) {
                    params.setCipherSuites(
                        new String[] {"TLS_AES_256_GCM_SHA384"});
                } else {
                    params.setCipherSuites(
                        new String[] {"TLS_AES_128_GCM_SHA256"});
                }
                break;
            case SIGNATURE_SCHEME:
                // Reject RSA-based algorithms first and ECDSA-based ones
                // second, so the two handshakes should not be able to agree
                // on the same signature scheme.
                if (second) {
                    params.setAlgorithmConstraints(new NoSig("ecdsa"));
                } else {
                    params.setAlgorithmConstraints(new NoSig("rsa"));
                }
                break;
            default:
                throw new RuntimeException("unknown mode: " + mode);
            }
            sock.setSSLParameters(params);
            sock.connect(new InetSocketAddress("localhost", port));
            PrintWriter out = new PrintWriter(
                new OutputStreamWriter(sock.getOutputStream()));
            out.println("message");
            out.flush();
            BufferedReader reader = new BufferedReader(
                new InputStreamReader(sock.getInputStream()));
            String inMsg = reader.readLine();
            System.out.println("Client received: " + inMsg);
            // The return value is evaluated before try-with-resources
            // closes the socket.
            return sock.getSession();
        } catch (IOException ex) {
            // unexpected exception
            throw new UncheckedIOException(ex);
        }
    }

    /**
     * Checks that a resumed session matches the initial session in creation
     * time, local and peer certificates, buffer sizes, cipher suite, peer
     * host/port and protocol.
     *
     * @param initSession session from the first connection
     * @param resSession  session from the second connection
     * @throws RuntimeException listing every mismatch that was found
     * @throws Exception if the peer certificates cannot be obtained
     */
    private static void checkResumedSession(SSLSession initSession,
            SSLSession resSession) throws Exception {
        StringBuilder diffLog = new StringBuilder();

        // Initial and resumed SSLSessions should have the same creation
        // times so they get invalidated together.
        long initCt = initSession.getCreationTime();
        long resumeCt = resSession.getCreationTime();
        if (initCt != resumeCt) {
            diffLog.append("Session creation time is different. Initial: ").
                    append(initCt).append(", Resumed: ").append(resumeCt).
                    append("\n");
        }

        // Ensure that peer and local certificate lists are preserved
        if (!Arrays.equals(initSession.getLocalCertificates(),
                resSession.getLocalCertificates())) {
            diffLog.append("Local certificate mismatch between initial " +
                    "and resumed sessions\n");
        }

        if (!Arrays.equals(initSession.getPeerCertificates(),
                resSession.getPeerCertificates())) {
            diffLog.append("Peer certificate mismatch between initial " +
                    "and resumed sessions\n");
        }

        // Buffer sizes should also be the same
        if (initSession.getApplicationBufferSize() !=
                resSession.getApplicationBufferSize()) {
            diffLog.append(String.format(
                    "App Buffer sizes differ: Init: %d, Res: %d\n",
                    initSession.getApplicationBufferSize(),
                    resSession.getApplicationBufferSize()));
        }

        if (initSession.getPacketBufferSize() !=
                resSession.getPacketBufferSize()) {
            diffLog.append(String.format(
                    "Packet Buffer sizes differ: Init: %d, Res: %d\n",
                    initSession.getPacketBufferSize(),
                    resSession.getPacketBufferSize()));
        }

        // Cipher suite should match
        if (!initSession.getCipherSuite().equals(
                resSession.getCipherSuite())) {
            diffLog.append(String.format(
                    "CipherSuite does not match - Init: %s, Res: %s\n",
                    initSession.getCipherSuite(), resSession.getCipherSuite()));
        }

        // Peer host/port should match; getPeerHost() may be null, so
        // compare null-safely.
        if (!Objects.equals(initSession.getPeerHost(),
                    resSession.getPeerHost()) ||
                initSession.getPeerPort() != resSession.getPeerPort()) {
            diffLog.append(String.format(
                    "Host/Port mismatch - Init: %s/%d, Res: %s/%d\n",
                    initSession.getPeerHost(), initSession.getPeerPort(),
                    resSession.getPeerHost(), resSession.getPeerPort()));
        }

        // Check protocol
        if (!initSession.getProtocol().equals(resSession.getProtocol())) {
            diffLog.append(String.format(
                    "Protocol mismatch - Init: %s, Res: %s\n",
                    initSession.getProtocol(), resSession.getProtocol()));
        }

        // If the StringBuilder has any data in it then one of the checks
        // above failed and we should throw an exception.
        if (diffLog.length() > 0) {
            throw new RuntimeException(diffLog.toString());
        }
    }

    /**
     * Creates a {@link Server} and runs it on a new daemon thread. A client
     * failure that skips the shutdown signal therefore cannot keep the JVM
     * alive.
     *
     * @return the started (but not yet signalled) server
     */
    private static Server startServer() {
        Server server = new Server();
        Thread serverThread = new Thread(server);
        serverThread.setDaemon(true);
        serverThread.start();
        return server;
    }

    /**
     * In-process TLS echo server. After the first {@link #signal()} it
     * accepts connections one at a time. For each connection it echoes one
     * line back, then waits for another {@link #signal()} before accepting
     * the next one. Clearing {@link #go} before a signal ends the loop.
     */
    private static class Server implements Runnable {

        /** Cleared by the client to stop the accept loop. */
        public volatile boolean go = true;
        /** Pending-signal flag; guarded by {@code this}. */
        private boolean signal = false;
        /** Port of the listening socket; published before {@link #started}. */
        public volatile int port = 0;
        /** Set once the server has been released into its accept loop. */
        public volatile boolean started = false;

        /**
         * Blocks until {@link #signal()} has been called, then consumes the
         * signal. An interrupt does not end the wait, but the thread's
         * interrupt status is restored before returning.
         */
        private synchronized void waitForSignal() {
            boolean interrupted = false;
            while (!signal) {
                try {
                    wait();
                } catch (InterruptedException ex) {
                    // keep waiting for the signal; restore status afterwards
                    interrupted = true;
                }
            }
            signal = false;
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }

        /** Records a pending signal and wakes the waiting server thread. */
        public synchronized void signal() {
            signal = true;
            notify();
        }

        @Override
        public void run() {
            // Accepted connections are closed only when the server stops,
            // so tearing one down never overlaps the next client handshake.
            List<Socket> accepted = new ArrayList<>();
            try {
                SSLContext sc = SSLContext.getDefault();
                ServerSocketFactory fac = sc.getServerSocketFactory();
                try (SSLServerSocket ssock = (SSLServerSocket)
                        fac.createServerSocket(0)) {
                    this.port = ssock.getLocalPort();

                    waitForSignal();
                    started = true;
                    while (go) {
                        try {
                            System.out.println("Waiting for connection");
                            Socket sock = ssock.accept();
                            accepted.add(sock);
                            echoLine(sock);
                            waitForSignal();
                        } catch (Exception ex) {
                            ex.printStackTrace();
                        }
                    }
                }
            } catch (Exception ex) {
                throw new RuntimeException(ex);
            } finally {
                closeAll(accepted);
            }
        }

        /** Reads one line from {@code sock} and writes it straight back. */
        private static void echoLine(Socket sock) throws IOException {
            BufferedReader reader = new BufferedReader(
                new InputStreamReader(sock.getInputStream()));
            String line = reader.readLine();
            System.out.println("server read: " + line);
            PrintWriter out = new PrintWriter(
                new OutputStreamWriter(sock.getOutputStream()));
            out.println(line);
            out.flush();
        }

        /** Closes every socket in {@code sockets}, ignoring close failures. */
        private static void closeAll(List<Socket> sockets) {
            for (Socket s : sockets) {
                try {
                    s.close();
                } catch (IOException ignored) {
                    // cleanup only; a failed close does not affect the test
                }
            }
        }
    }
}
351