1 /*
2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 /*
25  * @test
26  * @bug 8206929 8212885
27  * @summary ensure that client only resumes a session if certain properties
28  *    of the session are compatible with the new connection
29  * @run main/othervm -Djdk.tls.client.protocols=TLSv1.2 ResumeChecksClient BASIC
30  * @run main/othervm -Djdk.tls.client.protocols=TLSv1.3 ResumeChecksClient BASIC
31  * @run main/othervm ResumeChecksClient BASIC
32  * @run main/othervm ResumeChecksClient VERSION_2_TO_3
33  * @run main/othervm ResumeChecksClient VERSION_3_TO_2
34  * @run main/othervm -Djdk.tls.client.protocols=TLSv1.3 ResumeChecksClient CIPHER_SUITE
35  * @run main/othervm -Djdk.tls.client.protocols=TLSv1.3 ResumeChecksClient SIGNATURE_SCHEME
36  *
37  */
38 
39 import javax.net.*;
40 import javax.net.ssl.*;
41 import java.io.*;
42 import java.security.*;
43 import java.net.*;
44 import java.util.*;
45 
/**
 * Client side session resumption test.
 *
 * The test starts an in-process TLS echo server, connects to it twice from
 * the same client {@link SSLContext} and then inspects the two resulting
 * {@link SSLSession} objects. Depending on the {@link TestMode} given as the
 * first command line argument, the second connection is either expected to
 * resume the first session (BASIC) or to establish a new one because the
 * client parameters were changed between the two connections.
 */
public class ResumeChecksClient {

    // Key store and trust store location. The directory is resolved relative
    // to the "test.src" system property (default "./") in main().
    static final String pathToStores = "../../../../javax/net/ssl/etc";
    static final String keyStoreFile = "keystore";
    static final String trustStoreFile = "truststore";
    static final String passwd = "passphrase";

    /**
     * Scenario selected by the first command line argument. See
     * {@link #connect} for the parameters each mode applies.
     */
    enum TestMode {
        /** Both connections use unmodified parameters. */
        BASIC,
        /** First connection offers only TLSv1.2, second only TLSv1.3. */
        VERSION_2_TO_3,
        /** First connection offers only TLSv1.3, second only TLSv1.2. */
        VERSION_3_TO_2,
        /**
         * First connection offers only TLS_AES_128_GCM_SHA256, second only
         * TLS_AES_256_GCM_SHA384.
         */
        CIPHER_SUITE,
        /**
         * First connection rejects algorithm names containing "rsa", second
         * rejects names containing "ecdsa".
         */
        SIGNATURE_SCHEME
    }

    /**
     * Runs one scenario.
     *
     * @param args {@code args[0]} is the name of a {@link TestMode} constant
     * @throws IllegalArgumentException if the mode is missing or unknown
     * @throws RuntimeException if the second session does not relate to the
     *         first one in the way the selected mode requires
     */
    public static void main(String[] args) throws Exception {

        // Fail with a usable message instead of an index error.
        if (args.length < 1) {
            throw new IllegalArgumentException(
                "Missing test mode argument; expected one of " +
                Arrays.toString(TestMode.values()));
        }
        TestMode mode = TestMode.valueOf(args[0]);

        String storeDir =
            System.getProperty("test.src", "./") + "/" + pathToStores;
        String keyFilename = storeDir + "/" + keyStoreFile;
        String trustFilename = storeDir + "/" + trustStoreFile;

        // Must be set before either side asks for the default SSLContext.
        System.setProperty("javax.net.ssl.keyStore", keyFilename);
        System.setProperty("javax.net.ssl.keyStorePassword", passwd);
        System.setProperty("javax.net.ssl.trustStore", trustFilename);
        System.setProperty("javax.net.ssl.trustStorePassword", passwd);

        Server server = startServer();
        SSLSession firstSession;
        SSLSession secondSession;
        long secondStartTime;
        try {
            server.signal();
            SSLContext sslContext = SSLContext.getDefault();
            waitForStart(server);
            firstSession = connect(sslContext, server.port, mode, false);

            // Let the server move on to accepting the second connection.
            server.signal();
            secondStartTime = System.currentTimeMillis();
            // Guarantee a measurable gap between the recorded start time and
            // the creation time of any session created from here on.
            Thread.sleep(10);
            secondSession = connect(sslContext, server.port, mode, true);
        } finally {
            // Always ask the server to stop, also when a connection failed.
            server.go = false;
            server.signal();
        }

        switch (mode) {
        case BASIC:
            // fail if session is not resumed
            checkResumedSession(firstSession, secondSession);
            break;
        case VERSION_2_TO_3:
        case VERSION_3_TO_2:
        case CIPHER_SUITE:
        case SIGNATURE_SCHEME:
            // fail if a new session is not created
            if (secondSession.getCreationTime() <= secondStartTime) {
                throw new RuntimeException("Existing session was used");
            }
            break;
        default:
            throw new RuntimeException("unknown mode: " + mode);
        }
    }

    /**
     * Spins until the server reports that it has started.
     *
     * @throws RuntimeException if the server thread recorded a failure
     *         before it started; without this check the caller would spin
     *         forever on a server that can never start
     */
    private static void waitForStart(Server server) {
        while (!server.started) {
            Throwable failure = server.failure;
            if (failure != null) {
                throw new RuntimeException("Server failed to start", failure);
            }
            Thread.yield();
        }
    }

    /**
     * Algorithm constraints that reject every algorithm whose name contains
     * a given substring, compared case-insensitively. Keys are never
     * restricted on their own.
     */
    private static class NoSig implements AlgorithmConstraints {

        // Excluded name fragment, already lower-cased.
        private final String alg;

        /**
         * @param alg fragment of the algorithm names to reject
         */
        NoSig(String alg) {
            // Locale.ROOT keeps the folding independent of the default locale.
            this.alg = alg.toLowerCase(Locale.ROOT);
        }

        // Returns true if the name does NOT contain the excluded fragment.
        private boolean test(String a) {
            return !a.toLowerCase(Locale.ROOT).contains(alg);
        }

        @Override
        public boolean permits(Set<CryptoPrimitive> primitives, Key key) {
            return true;
        }

        @Override
        public boolean permits(Set<CryptoPrimitive> primitives,
            String algorithm, AlgorithmParameters parameters) {

            return test(algorithm);
        }

        @Override
        public boolean permits(Set<CryptoPrimitive> primitives,
            String algorithm, Key key, AlgorithmParameters parameters) {

            return test(algorithm);
        }
    }

    /**
     * Opens one connection to the local server, exchanges a single line and
     * returns the session that was used.
     *
     * @param sslContext context whose socket factory creates the socket
     * @param port port the server listens on
     * @param mode scenario deciding which parameters are modified
     * @param second {@code false} for the first connection of a scenario,
     *        {@code true} for the second one
     * @return the session of the connection; the socket itself is closed
     * @throws RuntimeException wrapping any failure
     */
    private static SSLSession connect(SSLContext sslContext, int port,
        TestMode mode, boolean second) {

        // try-with-resources closes the socket on the failure paths as well.
        try (SSLSocket sock = (SSLSocket)
                sslContext.getSocketFactory().createSocket()) {
            SSLParameters params = sock.getSSLParameters();

            switch (mode) {
            case BASIC:
                // do nothing to ensure resumption works
                break;
            case VERSION_2_TO_3:
                if (second) {
                    params.setProtocols(new String[] {"TLSv1.3"});
                } else {
                    params.setProtocols(new String[] {"TLSv1.2"});
                }
                break;
            case VERSION_3_TO_2:
                if (second) {
                    params.setProtocols(new String[] {"TLSv1.2"});
                } else {
                    params.setProtocols(new String[] {"TLSv1.3"});
                }
                break;
            case CIPHER_SUITE:
                if (second) {
                    params.setCipherSuites(
                        new String[] {"TLS_AES_256_GCM_SHA384"});
                } else {
                    params.setCipherSuites(
                        new String[] {"TLS_AES_128_GCM_SHA256"});
                }
                break;
            case SIGNATURE_SCHEME:
                if (second) {
                    params.setAlgorithmConstraints(new NoSig("ecdsa"));
                } else {
                    params.setAlgorithmConstraints(new NoSig("rsa"));
                }
                break;
            default:
                throw new RuntimeException("unknown mode: " + mode);
            }
            sock.setSSLParameters(params);
            sock.connect(new InetSocketAddress("localhost", port));
            PrintWriter out = new PrintWriter(
                new OutputStreamWriter(sock.getOutputStream()));
            out.println("message");
            out.flush();
            BufferedReader reader = new BufferedReader(
                new InputStreamReader(sock.getInputStream()));
            String inMsg = reader.readLine();
            System.out.println("Client received: " + inMsg);
            // The session is fetched before the socket is closed.
            return sock.getSession();
        } catch (Exception ex) {
            // unexpected exception
            throw new RuntimeException(ex);
        }
    }

    /**
     * Verifies that a resumed session carries the same properties as the
     * session it was resumed from. All mismatches are collected and reported
     * together.
     *
     * @param initSession session of the first connection
     * @param resSession session of the second connection
     * @throws RuntimeException listing every property that differs
     */
    private static void checkResumedSession(SSLSession initSession,
            SSLSession resSession) throws Exception {
        StringBuilder diffLog = new StringBuilder();

        // Initial and resumed SSLSessions should have the same creation
        // times so they get invalidated together.
        long initCt = initSession.getCreationTime();
        long resumeCt = resSession.getCreationTime();
        if (initCt != resumeCt) {
            diffLog.append("Session creation time is different. Initial: ").
                    append(initCt).append(", Resumed: ").append(resumeCt).
                    append("\n");
        }

        // Ensure that peer and local certificate lists are preserved
        if (!Arrays.equals(initSession.getLocalCertificates(),
                resSession.getLocalCertificates())) {
            diffLog.append("Local certificate mismatch between initial " +
                    "and resumed sessions\n");
        }

        if (!Arrays.equals(initSession.getPeerCertificates(),
                resSession.getPeerCertificates())) {
            diffLog.append("Peer certificate mismatch between initial " +
                    "and resumed sessions\n");
        }

        // Buffer sizes should also be the same
        if (initSession.getApplicationBufferSize() !=
                resSession.getApplicationBufferSize()) {
            diffLog.append(String.format(
                    "App Buffer sizes differ: Init: %d, Res: %d\n",
                    initSession.getApplicationBufferSize(),
                    resSession.getApplicationBufferSize()));
        }

        if (initSession.getPacketBufferSize() !=
                resSession.getPacketBufferSize()) {
            diffLog.append(String.format(
                    "Packet Buffer sizes differ: Init: %d, Res: %d\n",
                    initSession.getPacketBufferSize(),
                    resSession.getPacketBufferSize()));
        }

        // Cipher suite should match
        if (!initSession.getCipherSuite().equals(
                resSession.getCipherSuite())) {
            diffLog.append(String.format(
                    "CipherSuite does not match - Init: %s, Res: %s\n",
                    initSession.getCipherSuite(), resSession.getCipherSuite()));
        }

        // Peer host/port should match
        if (!initSession.getPeerHost().equals(resSession.getPeerHost()) ||
                initSession.getPeerPort() != resSession.getPeerPort()) {
            diffLog.append(String.format(
                    "Host/Port mismatch - Init: %s/%d, Res: %s/%d\n",
                    initSession.getPeerHost(), initSession.getPeerPort(),
                    resSession.getPeerHost(), resSession.getPeerPort()));
        }

        // Check protocol
        if (!initSession.getProtocol().equals(resSession.getProtocol())) {
            diffLog.append(String.format(
                    "Protocol mismatch - Init: %s, Res: %s\n",
                    initSession.getProtocol(), resSession.getProtocol()));
        }

        // If the StringBuilder has any data in it then one of the checks
        // above failed and we should throw an exception.
        if (diffLog.length() > 0) {
            throw new RuntimeException(diffLog.toString());
        }
    }

    /**
     * Creates the server and runs it on a new daemon thread, so that a
     * server blocked in accept() cannot keep the VM alive after the main
     * thread has failed.
     *
     * @return the running server
     */
    private static Server startServer() {
        Server server = new Server();
        Thread serverThread = new Thread(server);
        serverThread.setDaemon(true);
        serverThread.start();
        return server;
    }

    /**
     * Echo server that handles one connection at a time. After each
     * connection it waits for {@link #signal()} before it checks {@link #go}
     * and accepts the next one.
     */
    private static class Server implements Runnable {

        // Cleared by the client to end the accept loop.
        public volatile boolean go = true;
        // Pending-signal flag, guarded by this object's monitor.
        private boolean signal = false;
        // Local port of the listening socket; 0 until the socket exists.
        public volatile int port = 0;
        // Set once the listening socket exists and the first signal arrived.
        public volatile boolean started = false;
        // Exception that terminated run(), or null.
        public volatile Throwable failure = null;

        // Blocks until signal() has been called, then consumes the signal.
        private synchronized void waitForSignal() {
            boolean interrupted = false;
            while (!signal) {
                try {
                    wait();
                } catch (InterruptedException ex) {
                    // keep waiting; the interrupt is re-asserted below
                    interrupted = true;
                }
            }
            signal = false;
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }

        // Records a signal and wakes a thread blocked in waitForSignal().
        public synchronized void signal() {
            signal = true;
            notify();
        }

        @Override
        public void run() {
            try {

                SSLContext sc = SSLContext.getDefault();
                ServerSocketFactory fac = sc.getServerSocketFactory();
                try (SSLServerSocket ssock = (SSLServerSocket)
                        fac.createServerSocket(0)) {
                    this.port = ssock.getLocalPort();

                    waitForSignal();
                    started = true;
                    while (go) {
                        System.out.println("Waiting for connection");
                        try (Socket sock = ssock.accept()) {
                            BufferedReader reader = new BufferedReader(
                                new InputStreamReader(sock.getInputStream()));
                            String line = reader.readLine();
                            System.out.println("server read: " + line);
                            PrintWriter out = new PrintWriter(
                                new OutputStreamWriter(sock.getOutputStream()));
                            out.println(line);
                            out.flush();
                            // The connection is closed only after the client
                            // has signalled, so the close does not overlap
                            // with the client reading the reply.
                            waitForSignal();
                        } catch (Exception ex) {
                            ex.printStackTrace();
                        }
                    }
                }
            } catch (Exception ex) {
                // Publish the failure so a waiting client can stop spinning.
                failure = ex;
                throw new RuntimeException(ex);
            }
        }
    }
}
347