1 // Copyright 2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 package com.google.security.cryptauth.lib.securegcm;
16 
17 import com.google.protobuf.InvalidProtocolBufferException;
18 import com.google.security.cryptauth.lib.securegcm.DeviceToDeviceMessagesProto.DeviceToDeviceMessage;
19 import com.google.security.cryptauth.lib.securegcm.DeviceToDeviceMessagesProto.InitiatorHello;
20 import com.google.security.cryptauth.lib.securegcm.DeviceToDeviceMessagesProto.ResponderHello;
21 import com.google.security.cryptauth.lib.securegcm.TransportCryptoOps.Payload;
22 import com.google.security.cryptauth.lib.securegcm.TransportCryptoOps.PayloadType;
23 import com.google.security.cryptauth.lib.securemessage.PublicKeyProtoUtil;
24 import java.security.InvalidKeyException;
25 import java.security.KeyPair;
26 import java.security.NoSuchAlgorithmException;
27 import java.security.PublicKey;
28 import java.security.SignatureException;
29 import java.security.spec.InvalidKeySpecException;
30 import javax.crypto.SecretKey;
31 
32 /**
33  * Implements an unauthenticated EC Diffie Hellman Key Exchange Handshake
34  * <p>
35  * Initiator sends an InitiatorHello, which is a protobuf that contains a public key. Responder
36  * sends a responder hello, which a signed and encrypted message containing a payload, and a public
37  * key in the unencrypted header (payload is encrypted with the derived DH key).
38  * <p>
39  * Example Usage:
40  * <pre>
41  *    // initiator:
42  *    D2DHandshakeContext initiatorHandshakeContext =
43  *        D2DDiffieHellmanKeyExchangeHandshake.forInitiator();
44  *    byte[] initiatorHello = initiatorHandshakeContext.getNextHandshakeMessage();
45  *    // (send initiatorHello to responder)
46  *
47  *    // responder:
48  *    D2DHandshakeContext responderHandshakeContext =
49  *        D2DDiffieHellmanKeyExchangeHandshake.forResponder();
50  *    responderHandshakeContext.parseHandshakeMessage(initiatorHello);
51  *    byte[] responderHelloAndPayload = responderHandshakeContext.getNextHandshakeMessage(
52  *        toBytes(RESPONDER_HELLO_MESSAGE));
53  *    D2DConnectionContext responderCtx = responderHandshakeContext.toConnectionContext();
54  *    // (send responderHelloAndPayload to initiator)
55  *
56  *    // initiator
57  *    byte[] messageFromPayload =
58  *        initiatorHandshakeContext.parseHandshakeMessage(responderHelloAndPayload);
59  *    if (messageFromPayload.length > 0) {
60  *      handle(messageFromPayload);
61  *    }
62  *
63  *    D2DConnectionContext initiatorCtx = initiatorHandshakeContext.toConnectionContext();
64  * </pre>
65  */
66 public class D2DDiffieHellmanKeyExchangeHandshake implements D2DHandshakeContext {
67   private KeyPair ourKeyPair;
68   private PublicKey theirPublicKey;
69   private SecretKey initiatorEncodeKey;
70   private SecretKey responderEncodeKey;
71   private State handshakeState;
72   private boolean isInitiator;
73   private int protocolVersionToUse;
74 
75   private enum State {
76     // Initiator state
77     INITIATOR_START,
78     INITIATOR_WAITING_FOR_RESPONDER_HELLO,
79 
80     // Responder state
81     RESPONDER_START,
82     RESPONDER_AFTER_INITIATOR_HELLO,
83 
84     // Common completion state
85     HANDSHAKE_FINISHED,
86     HANDSHAKE_ALREADY_USED
87   }
88 
D2DDiffieHellmanKeyExchangeHandshake(State state)89   private D2DDiffieHellmanKeyExchangeHandshake(State state) {
90     ourKeyPair = PublicKeyProtoUtil.generateEcP256KeyPair();
91     theirPublicKey = null;
92     initiatorEncodeKey = null;
93     responderEncodeKey = null;
94     handshakeState = state;
95     isInitiator = state == State.INITIATOR_START;
96     protocolVersionToUse = D2DConnectionContextV1.PROTOCOL_VERSION;
97   }
98 
99   /**
100    * Creates a new Diffie Hellman handshake context for the handshake initiator
101    */
forInitiator()102   public static D2DDiffieHellmanKeyExchangeHandshake forInitiator() {
103     return new D2DDiffieHellmanKeyExchangeHandshake(State.INITIATOR_START);
104   }
105 
106   /**
107    * Creates a new Diffie Hellman handshake context for the handshake responder
108    */
forResponder()109   public static D2DDiffieHellmanKeyExchangeHandshake forResponder() {
110     return new D2DDiffieHellmanKeyExchangeHandshake(State.RESPONDER_START);
111   }
112 
113   @Override
isHandshakeComplete()114   public boolean isHandshakeComplete() {
115     return handshakeState == State.HANDSHAKE_FINISHED
116         || handshakeState == State.HANDSHAKE_ALREADY_USED;
117   }
118 
119   @Override
getNextHandshakeMessage()120   public byte[] getNextHandshakeMessage() throws HandshakeException {
121     switch(handshakeState) {
122       case INITIATOR_START:
123         handshakeState = State.INITIATOR_WAITING_FOR_RESPONDER_HELLO;
124         return InitiatorHello.newBuilder()
125             .setPublicDhKey(PublicKeyProtoUtil.encodePublicKey(ourKeyPair.getPublic()))
126             .setProtocolVersion(protocolVersionToUse)
127             .build()
128             .toByteArray();
129 
130       case RESPONDER_AFTER_INITIATOR_HELLO:
131         byte[] responderHello = makeResponderHelloWithPayload(new byte[0]);
132         handshakeState = State.HANDSHAKE_FINISHED;
133         return responderHello;
134 
135       default:
136         throw new HandshakeException("Cannot get next message in state: " + handshakeState);
137     }
138   }
139 
140   @Override
canSendPayloadInHandshakeMessage()141   public boolean canSendPayloadInHandshakeMessage() {
142     return handshakeState == State.RESPONDER_AFTER_INITIATOR_HELLO;
143   }
144 
145   @Override
getNextHandshakeMessage(byte[] payload)146   public byte[] getNextHandshakeMessage(byte[] payload) throws HandshakeException {
147     if (handshakeState != State.RESPONDER_AFTER_INITIATOR_HELLO) {
148       throw new HandshakeException(
149           "Cannot get next message with payload in state: " + handshakeState);
150     }
151 
152     byte[] responderHello = makeResponderHelloWithPayload(payload);
153     handshakeState = State.HANDSHAKE_FINISHED;
154 
155     return responderHello;
156   }
157 
makeResponderHelloWithPayload(byte[] payload)158   private byte[] makeResponderHelloWithPayload(byte[] payload) throws HandshakeException {
159     if (payload == null) {
160       throw new HandshakeException("Not expecting null payload");
161     }
162 
163     try {
164       SecretKey masterKey =
165           EnrollmentCryptoOps.doKeyAgreement(ourKeyPair.getPrivate(), theirPublicKey);
166 
167       // V0 uses the same key for encoding and decoding, but V1 uses separate keys.
168       switch (protocolVersionToUse) {
169         case D2DConnectionContextV0.PROTOCOL_VERSION:
170           initiatorEncodeKey = masterKey;
171           responderEncodeKey = masterKey;
172           break;
173         case D2DConnectionContextV1.PROTOCOL_VERSION:
174           initiatorEncodeKey = D2DCryptoOps.deriveNewKeyForPurpose(masterKey,
175               D2DCryptoOps.INITIATOR_PURPOSE);
176           responderEncodeKey = D2DCryptoOps.deriveNewKeyForPurpose(masterKey,
177               D2DCryptoOps.RESPONDER_PURPOSE);
178           break;
179         default:
180           throw new IllegalStateException("Unexpected protocol version: " + protocolVersionToUse);
181       }
182 
183       DeviceToDeviceMessage deviceToDeviceMessage =
184           D2DConnectionContext.createDeviceToDeviceMessage(payload, 1 /* sequence number */);
185 
186       return D2DCryptoOps.signcryptMessageAndResponderHello(
187           new Payload(PayloadType.DEVICE_TO_DEVICE_RESPONDER_HELLO_PAYLOAD,
188               deviceToDeviceMessage.toByteArray()),
189           responderEncodeKey,
190           ourKeyPair.getPublic(),
191           protocolVersionToUse);
192     } catch (InvalidKeyException|NoSuchAlgorithmException e) {
193       throw new HandshakeException(e);
194     }
195   }
196 
197   @Override
parseHandshakeMessage(byte[] handshakeMessage)198   public byte[] parseHandshakeMessage(byte[] handshakeMessage) throws HandshakeException {
199     if (handshakeMessage == null || handshakeMessage.length == 0) {
200       throw new HandshakeException("Handshake message too short");
201     }
202 
203     switch(handshakeState) {
204       case INITIATOR_WAITING_FOR_RESPONDER_HELLO:
205           byte[] payload = parseResponderHello(handshakeMessage);
206           handshakeState = State.HANDSHAKE_FINISHED;
207           return payload;
208 
209       case RESPONDER_START:
210           parseInitiatorHello(handshakeMessage);
211           handshakeState = State.RESPONDER_AFTER_INITIATOR_HELLO;
212           return new byte[0];
213 
214       default:
215         throw new HandshakeException("Cannot parse message in state: " + handshakeState);
216     }
217   }
218 
parseResponderHello(byte[] responderHello)219   private byte[] parseResponderHello(byte[] responderHello) throws HandshakeException {
220      try {
221         ResponderHello responderHelloProto =
222             D2DCryptoOps.parseAndValidateResponderHello(responderHello);
223 
224         // Downgrade to protocol version 0 if needed for backwards compatibility.
225         int protocolVersion = responderHelloProto.getProtocolVersion();
226         if (protocolVersion == D2DConnectionContextV0.PROTOCOL_VERSION) {
227           protocolVersionToUse = D2DConnectionContextV0.PROTOCOL_VERSION;
228         }
229 
230         SecretKey masterKey = D2DCryptoOps.deriveSharedKeyFromGenericPublicKey(
231             ourKeyPair.getPrivate(), responderHelloProto.getPublicDhKey());
232 
233         // V0 uses the same key for encoding and decoding, but V1 uses separate keys.
234         if (protocolVersionToUse == D2DConnectionContextV0.PROTOCOL_VERSION) {
235           initiatorEncodeKey = masterKey;
236           responderEncodeKey = masterKey;
237         } else {
238           initiatorEncodeKey = D2DCryptoOps.deriveNewKeyForPurpose(masterKey,
239               D2DCryptoOps.INITIATOR_PURPOSE);
240           responderEncodeKey = D2DCryptoOps.deriveNewKeyForPurpose(masterKey,
241               D2DCryptoOps.RESPONDER_PURPOSE);
242         }
243 
244         DeviceToDeviceMessage message =
245             D2DCryptoOps.decryptResponderHelloMessage(responderEncodeKey, responderHello);
246 
247         if (message.getSequenceNumber() != 1) {
248           throw new HandshakeException("Incorrect sequence number in responder hello");
249         }
250 
251         return message.getMessage().toByteArray();
252       } catch (SignatureException | InvalidProtocolBufferException
253                | NoSuchAlgorithmException | InvalidKeyException e) {
254         throw new HandshakeException(e);
255       }
256   }
257 
parseInitiatorHello(byte[] initiatorHello)258   private void parseInitiatorHello(byte[] initiatorHello) throws HandshakeException {
259     try {
260         InitiatorHello initiatorHelloProto = InitiatorHello.parseFrom(initiatorHello);
261 
262         if (!initiatorHelloProto.hasPublicDhKey()) {
263           throw new HandshakeException("Missing public key in initiator hello");
264         }
265 
266         theirPublicKey = PublicKeyProtoUtil.parsePublicKey(initiatorHelloProto.getPublicDhKey());
267 
268         // Downgrade to protocol version 0 if needed for backwards compatibility.
269         int protocolVersion = initiatorHelloProto.getProtocolVersion();
270         if (protocolVersion == D2DConnectionContextV0.PROTOCOL_VERSION) {
271           protocolVersionToUse = D2DConnectionContextV0.PROTOCOL_VERSION;
272         }
273       } catch (InvalidKeySpecException | InvalidProtocolBufferException e) {
274         throw new HandshakeException(e);
275       }
276   }
277 
278   @Override
toConnectionContext()279   public D2DConnectionContext toConnectionContext() throws HandshakeException {
280     if (handshakeState == State.HANDSHAKE_ALREADY_USED) {
281       throw new HandshakeException("Cannot reuse handshake context; is has already been used");
282     }
283 
284     if (!isHandshakeComplete()) {
285       throw new HandshakeException("Handshake is not complete; cannot create connection context");
286     }
287 
288     handshakeState = State.HANDSHAKE_ALREADY_USED;
289 
290     if (protocolVersionToUse == D2DConnectionContextV0.PROTOCOL_VERSION) {
291       // Both sides start with an initial sequence number of 1 because the last message of the
292       // handshake had an optional payload with sequence number 1.  D2DConnectionContext remembers
293       // the last sequence number used by each side.
294       // Note: initiatorEncodeKey == responderEncodeKey
295       return new D2DConnectionContextV0(initiatorEncodeKey, 1 /** initialSequenceNumber */);
296     } else {
297       SecretKey encodeKey = isInitiator ? initiatorEncodeKey : responderEncodeKey;
298       SecretKey decodeKey = isInitiator ? responderEncodeKey : initiatorEncodeKey;
299       // Only the responder sends a DeviceToDeviceMessage during the handshake, so it has an initial
300       // sequence number of 1.  The initiator will therefore have an initial sequence number of 0.
301       int initialEncodeSequenceNumber = isInitiator ? 0 : 1;
302       int initialDecodeSequenceNumber = isInitiator ? 1 : 0;
303       return new D2DConnectionContextV1(
304           encodeKey, decodeKey, initialEncodeSequenceNumber, initialDecodeSequenceNumber);
305     }
306   }
307 }
308