1 /*
2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 /*
25  * @test
 * @summary Test that SSLEngine.beginHandshake() triggers a KeyUpdate handshake
27  * in TLSv1.3
28  * @run main/othervm TLS13BeginHandshake
29  */
30 
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLSession;
import javax.net.ssl.TrustManagerFactory;
import java.io.File;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.KeyStore;
import java.security.SecureRandom;
42 
/**
 * Drives a client and a server {@link SSLEngine}, both created from
 * "TLSv1.3" contexts, against each other over in-memory buffers.
 *
 * <p>Once both sides have received each other's application data, the test
 * calls {@code clientEngine.beginHandshake()} on the already-established
 * connection. Per the jtreg summary, this is expected to trigger a TLS 1.3
 * KeyUpdate. The test then re-verifies the data, closes both outbound sides,
 * and keeps pumping records until both engines report closed.
 */
public class TLS13BeginHandshake {
    static String pathToStores =
            System.getProperty("test.src") + "/../../../../javax/net/ssl/etc/";
    static String keyStoreFile = "keystore";
    static String passwd = "passphrase";

    private SSLEngine serverEngine, clientEngine;
    SSLEngineResult clientResult, serverResult;
    private ByteBuffer clientOut, clientIn;
    private ByteBuffer serverOut, serverIn;
    private ByteBuffer cTOs, sTOc;

    public static void main(String[] args) throws Exception {
        new TLS13BeginHandshake().runDemo();
    }

    /**
     * Runs the wrap/unwrap loop until both engines are fully closed, or until
     * the client's unwrap reports {@code CLOSED}.
     *
     * @throws Exception if an engine operation fails or the exchanged
     *         application data does not match what was sent
     */
    private void runDemo() throws Exception {
        // Verification stage:
        //   0 - waiting for the first complete exchange of application data
        //   1 - beginHandshake() has been called on the client; waiting to re-check
        //   2 - both outbound sides closed; just drain until the engines close
        int done = 0;

        createSSLEngines();
        createBuffers();

        while (!isEngineClosed(clientEngine) || !isEngineClosed(serverEngine)) {

            System.out.println("================");
            clientResult = clientEngine.wrap(clientOut, cTOs);
            System.out.println("client wrap: " + clientResult);
            runDelegatedTasks(clientResult, clientEngine);
            serverResult = serverEngine.wrap(serverOut, sTOc);
            System.out.println("server wrap: " + serverResult);
            runDelegatedTasks(serverResult, serverEngine);

            // Switch the network buffers to read mode for the peers' unwrap().
            cTOs.flip();
            sTOc.flip();

            System.out.println("----");
            clientResult = clientEngine.unwrap(sTOc, clientIn);
            System.out.println("client unwrap: " + clientResult);
            if (clientResult.getStatus() == SSLEngineResult.Status.CLOSED) {
                break;
            }
            runDelegatedTasks(clientResult, clientEngine);
            serverResult = serverEngine.unwrap(cTOs, serverIn);
            System.out.println("server unwrap: " + serverResult);
            runDelegatedTasks(serverResult, serverEngine);

            // Keep any unconsumed network bytes and return to write mode.
            cTOs.compact();
            sTOc.compact();

            // Each side has received as many bytes as its peer's outbound
            // application buffer holds.
            boolean allDataDelivered =
                    clientOut.limit() == serverIn.position()
                    && serverOut.limit() == clientIn.position();

            if (done < 2 && allDataDelivered) {
                checkTransfer(serverOut, clientIn);
                checkTransfer(clientOut, serverIn);
                if (done == 0) {
                    // Request a new handshake on the established connection.
                    clientEngine.beginHandshake();
                } else {
                    System.out.println("\tClosing...");
                    clientEngine.closeOutbound();
                    serverEngine.closeOutbound();
                }
                done++;
            }
        }
    }

    /**
     * Reports whether both directions of {@code engine} are done. Each closed
     * direction is also logged.
     */
    private static boolean isEngineClosed(SSLEngine engine) {
        if (engine.isInboundDone())
            System.out.println("inbound closed");
        if (engine.isOutboundDone())
            System.out.println("outbound closed");
        return (engine.isOutboundDone() && engine.isInboundDone());
    }

    /**
     * Compares the bytes written so far into {@code a} and {@code b}.
     *
     * <p>Both buffers are expected to be in write mode. They are flipped to
     * compare their contents, then compacted. This restores each buffer's
     * position and leaves its limit at capacity.
     *
     * @throws Exception if the two buffers' contents differ
     */
    private static void checkTransfer(ByteBuffer a, ByteBuffer b)
            throws Exception {
        a.flip();
        b.flip();

        if (!a.equals(b)) {
            throw new Exception("Data didn't transfer cleanly");
        } else {
            System.out.println("\tData transferred cleanly");
        }

        a.compact();
        b.compact();
    }

    /**
     * Allocates the application and network buffers.
     *
     * <p>Sizes come from the client engine's session. The outbound application
     * data is fixed ASCII text, encoded explicitly so it does not depend on
     * the platform default charset.
     */
    private void createBuffers() {
        SSLSession session = clientEngine.getSession();
        int appBufferMax = session.getApplicationBufferSize();
        int netBufferMax = session.getPacketBufferSize();

        // Extra headroom beyond the reported application buffer size.
        clientIn = ByteBuffer.allocate(appBufferMax + 50);
        serverIn = ByteBuffer.allocate(appBufferMax + 50);

        cTOs = ByteBuffer.allocateDirect(netBufferMax);
        sTOc = ByteBuffer.allocateDirect(netBufferMax);

        clientOut = ByteBuffer.wrap("client".getBytes(StandardCharsets.US_ASCII));
        serverOut = ByteBuffer.wrap("server".getBytes(StandardCharsets.US_ASCII));
    }

    /**
     * Creates the server engine, which requires client authentication, and
     * the client engine, with peer host "client" and port 80.
     */
    private void createSSLEngines() throws Exception {
        serverEngine = initContext().createSSLEngine();
        serverEngine.setUseClientMode(false);
        serverEngine.setNeedClientAuth(true);

        clientEngine = initContext().createSSLEngine("client", 80);
        clientEngine.setUseClientMode(true);
    }

    /**
     * Builds a "TLSv1.3" {@link SSLContext}.
     *
     * <p>The shared keystore file serves as both the key material and the
     * trust store.
     */
    private SSLContext initContext() throws Exception {
        SSLContext sc = SSLContext.getInstance("TLSv1.3");
        KeyStore ks = KeyStore.getInstance(new File(pathToStores + keyStoreFile),
                passwd.toCharArray());
        KeyManagerFactory kmf =
                KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
        kmf.init(ks, passwd.toCharArray());
        TrustManagerFactory tmf =
                TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
        tmf.init(ks);
        sc.init(kmf.getKeyManagers(), tmf.getTrustManagers(), new SecureRandom());
        return sc;
    }

    /**
     * Runs, on the calling thread, all delegated tasks the engine has queued
     * when {@code result} reports {@code NEED_TASK}.
     *
     * @throws Exception if the engine still needs tasks after the queue has
     *         been drained
     */
    private static void runDelegatedTasks(SSLEngineResult result,
            SSLEngine engine) throws Exception {

        if (result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
            Runnable runnable;
            while ((runnable = engine.getDelegatedTask()) != null) {
                runnable.run();
            }
            HandshakeStatus hsStatus = engine.getHandshakeStatus();
            if (hsStatus == HandshakeStatus.NEED_TASK) {
                throw new Exception(
                    "handshake shouldn't need additional tasks");
            }
        }
    }
}
192