1 /* -*- Mode: C; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 4 -*- */
2 /*
3  * SSL Primitives: Public HKDF and AEAD Functions
4  *
5  * This Source Code Form is subject to the terms of the Mozilla Public
6  * License, v. 2.0. If a copy of the MPL was not distributed with this
7  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
8 
9 #include "blapit.h"
10 #include "keyhi.h"
11 #include "pk11pub.h"
12 #include "sechash.h"
13 #include "ssl.h"
14 #include "sslexp.h"
15 #include "sslerr.h"
16 #include "sslproto.h"
17 
18 #include "sslimpl.h"
19 #include "tls13con.h"
20 #include "tls13hkdf.h"
21 
/* State behind the opaque SSLAeadContext handle of the experimental
 * SSLExp_*Aead API. */
struct SSLAeadContextStr {
    /* Sigh: the API creates a single context, but then uses it for either
     * encryption or decryption. We should take an encrypt/decrypt
     * variable at creation time, but for now create two contexts: one made
     * with CKA_ENCRYPT and one made with CKA_DECRYPT. */
    PK11Context *encryptContext; /* Used by SSLExp_AeadEncrypt. */
    PK11Context *decryptContext; /* Used by SSLExp_AeadDecrypt. */
    int tagLen;                  /* AEAD tag length in bytes (cipher->tag_size). */
    int ivLen;                   /* Number of valid bytes in iv. */
    /* Static IV derived from the "<prefix>iv" label; passed to tls13_AEAD
     * together with the caller-supplied per-message counter. */
    unsigned char iv[MAX_IV_LENGTH];
};
32 
33 SECStatus
SSLExp_MakeVariantAead(PRUint16 version,PRUint16 cipherSuite,SSLProtocolVariant variant,PK11SymKey * secret,const char * labelPrefix,unsigned int labelPrefixLen,SSLAeadContext ** ctx)34 SSLExp_MakeVariantAead(PRUint16 version, PRUint16 cipherSuite, SSLProtocolVariant variant,
35                        PK11SymKey *secret, const char *labelPrefix,
36                        unsigned int labelPrefixLen, SSLAeadContext **ctx)
37 {
38     SSLAeadContext *out = NULL;
39     char label[255]; // Maximum length label.
40     static const char *const keySuffix = "key";
41     static const char *const ivSuffix = "iv";
42     CK_MECHANISM_TYPE mech;
43     SECItem nullParams = { siBuffer, NULL, 0 };
44     PK11SymKey *key = NULL;
45 
46     PORT_Assert(strlen(keySuffix) >= strlen(ivSuffix));
47     if (secret == NULL || ctx == NULL ||
48         (labelPrefix == NULL && labelPrefixLen > 0) ||
49         labelPrefixLen + strlen(keySuffix) > sizeof(label)) {
50         PORT_SetError(SEC_ERROR_INVALID_ARGS);
51         goto loser;
52     }
53 
54     SSLHashType hash;
55     const ssl3BulkCipherDef *cipher;
56     SECStatus rv = tls13_GetHashAndCipher(version, cipherSuite,
57                                           &hash, &cipher);
58     if (rv != SECSuccess) {
59         goto loser; /* Code already set. */
60     }
61 
62     out = PORT_ZNew(SSLAeadContext);
63     if (out == NULL) {
64         goto loser;
65     }
66     mech = ssl3_Alg2Mech(cipher->calg);
67     out->ivLen = cipher->iv_size + cipher->explicit_nonce_size;
68     out->tagLen = cipher->tag_size;
69 
70     memcpy(label, labelPrefix, labelPrefixLen);
71     memcpy(label + labelPrefixLen, ivSuffix, strlen(ivSuffix));
72     unsigned int labelLen = labelPrefixLen + strlen(ivSuffix);
73     unsigned int ivLen = cipher->iv_size + cipher->explicit_nonce_size;
74     rv = tls13_HkdfExpandLabelRaw(secret, hash,
75                                   NULL, 0, // Handshake hash.
76                                   label, labelLen, variant,
77                                   out->iv, ivLen);
78     if (rv != SECSuccess) {
79         goto loser;
80     }
81 
82     memcpy(label + labelPrefixLen, keySuffix, strlen(keySuffix));
83     labelLen = labelPrefixLen + strlen(keySuffix);
84     rv = tls13_HkdfExpandLabel(secret, hash,
85                                NULL, 0, // Handshake hash.
86                                label, labelLen, mech, cipher->key_size,
87                                variant, &key);
88     if (rv != SECSuccess) {
89         goto loser;
90     }
91 
92     /* We really need to change the API to Create a context for each
93      * encrypt and decrypt rather than a single call that does both. it's
94      * almost certain that the underlying application tries to use the same
95      * context for both. */
96     out->encryptContext = PK11_CreateContextBySymKey(mech,
97                                                      CKA_NSS_MESSAGE | CKA_ENCRYPT,
98                                                      key, &nullParams);
99     if (out->encryptContext == NULL) {
100         goto loser;
101     }
102 
103     out->decryptContext = PK11_CreateContextBySymKey(mech,
104                                                      CKA_NSS_MESSAGE | CKA_DECRYPT,
105                                                      key, &nullParams);
106     if (out->decryptContext == NULL) {
107         goto loser;
108     }
109 
110     PK11_FreeSymKey(key);
111     *ctx = out;
112     return SECSuccess;
113 
114 loser:
115     PK11_FreeSymKey(key);
116     SSLExp_DestroyAead(out);
117     return SECFailure;
118 }
119 
120 SECStatus
SSLExp_MakeAead(PRUint16 version,PRUint16 cipherSuite,PK11SymKey * secret,const char * labelPrefix,unsigned int labelPrefixLen,SSLAeadContext ** ctx)121 SSLExp_MakeAead(PRUint16 version, PRUint16 cipherSuite, PK11SymKey *secret,
122                 const char *labelPrefix, unsigned int labelPrefixLen, SSLAeadContext **ctx)
123 {
124     return SSLExp_MakeVariantAead(version, cipherSuite, ssl_variant_stream, secret,
125                                   labelPrefix, labelPrefixLen, ctx);
126 }
127 
128 SECStatus
SSLExp_DestroyAead(SSLAeadContext * ctx)129 SSLExp_DestroyAead(SSLAeadContext *ctx)
130 {
131     if (!ctx) {
132         return SECSuccess;
133     }
134     if (ctx->encryptContext) {
135         PK11_DestroyContext(ctx->encryptContext, PR_TRUE);
136     }
137     if (ctx->decryptContext) {
138         PK11_DestroyContext(ctx->decryptContext, PR_TRUE);
139     }
140 
141     PORT_ZFree(ctx, sizeof(*ctx));
142     return SECSuccess;
143 }
144 
145 /* Bug 1529440 exists to refactor this and the other AEAD uses. */
146 static SECStatus
ssl_AeadInner(const SSLAeadContext * ctx,PK11Context * context,PRBool decrypt,PRUint64 counter,const PRUint8 * aad,unsigned int aadLen,const PRUint8 * in,unsigned int inLen,PRUint8 * out,unsigned int * outLen,unsigned int maxOut)147 ssl_AeadInner(const SSLAeadContext *ctx, PK11Context *context,
148               PRBool decrypt, PRUint64 counter,
149               const PRUint8 *aad, unsigned int aadLen,
150               const PRUint8 *in, unsigned int inLen,
151               PRUint8 *out, unsigned int *outLen, unsigned int maxOut)
152 {
153     if (ctx == NULL || (aad == NULL && aadLen > 0) || in == NULL ||
154         out == NULL || outLen == NULL) {
155         PORT_SetError(SEC_ERROR_INVALID_ARGS);
156         return SECFailure;
157     }
158 
159     // Setup the nonce.
160     PRUint8 nonce[sizeof(counter)] = { 0 };
161     sslBuffer nonceBuf = SSL_BUFFER_FIXED(nonce, sizeof(counter));
162     SECStatus rv = sslBuffer_AppendNumber(&nonceBuf, counter, sizeof(counter));
163     if (rv != SECSuccess) {
164         PORT_Assert(0);
165         return SECFailure;
166     }
167     /* at least on encrypt, we should not be using CKG_NO_GENERATE, but
168      * the current experimental API has the application tracking the counter
169      * rather than token. We should look at the QUIC code and see if the
170      * counter can be moved internally where it belongs. That would
171      * also get rid of the  formatting code above and have the API
172      * call tls13_AEAD directly in SSLExp_Aead* */
173     return tls13_AEAD(context, decrypt, CKG_NO_GENERATE, 0, ctx->iv, NULL,
174                       ctx->ivLen, nonce, sizeof(counter), aad, aadLen,
175                       out, outLen, maxOut, ctx->tagLen, in, inLen);
176 }
177 
178 SECStatus
SSLExp_AeadEncrypt(const SSLAeadContext * ctx,PRUint64 counter,const PRUint8 * aad,unsigned int aadLen,const PRUint8 * plaintext,unsigned int plaintextLen,PRUint8 * out,unsigned int * outLen,unsigned int maxOut)179 SSLExp_AeadEncrypt(const SSLAeadContext *ctx, PRUint64 counter,
180                    const PRUint8 *aad, unsigned int aadLen,
181                    const PRUint8 *plaintext, unsigned int plaintextLen,
182                    PRUint8 *out, unsigned int *outLen, unsigned int maxOut)
183 {
184     // false == encrypt
185     return ssl_AeadInner(ctx, ctx->encryptContext, PR_FALSE, counter,
186                          aad, aadLen, plaintext, plaintextLen,
187                          out, outLen, maxOut);
188 }
189 
190 SECStatus
SSLExp_AeadDecrypt(const SSLAeadContext * ctx,PRUint64 counter,const PRUint8 * aad,unsigned int aadLen,const PRUint8 * ciphertext,unsigned int ciphertextLen,PRUint8 * out,unsigned int * outLen,unsigned int maxOut)191 SSLExp_AeadDecrypt(const SSLAeadContext *ctx, PRUint64 counter,
192                    const PRUint8 *aad, unsigned int aadLen,
193                    const PRUint8 *ciphertext, unsigned int ciphertextLen,
194                    PRUint8 *out, unsigned int *outLen, unsigned int maxOut)
195 {
196     // true == decrypt
197     return ssl_AeadInner(ctx, ctx->decryptContext, PR_TRUE, counter,
198                          aad, aadLen, ciphertext, ciphertextLen,
199                          out, outLen, maxOut);
200 }
201 
202 SECStatus
SSLExp_HkdfExtract(PRUint16 version,PRUint16 cipherSuite,PK11SymKey * salt,PK11SymKey * ikm,PK11SymKey ** keyp)203 SSLExp_HkdfExtract(PRUint16 version, PRUint16 cipherSuite,
204                    PK11SymKey *salt, PK11SymKey *ikm, PK11SymKey **keyp)
205 {
206     if (keyp == NULL) {
207         PORT_SetError(SEC_ERROR_INVALID_ARGS);
208         return SECFailure;
209     }
210 
211     SSLHashType hash;
212     SECStatus rv = tls13_GetHashAndCipher(version, cipherSuite,
213                                           &hash, NULL);
214     if (rv != SECSuccess) {
215         return SECFailure; /* Code already set. */
216     }
217     return tls13_HkdfExtract(salt, ikm, hash, keyp);
218 }
219 
220 SECStatus
SSLExp_HkdfExpandLabel(PRUint16 version,PRUint16 cipherSuite,PK11SymKey * prk,const PRUint8 * hsHash,unsigned int hsHashLen,const char * label,unsigned int labelLen,PK11SymKey ** keyp)221 SSLExp_HkdfExpandLabel(PRUint16 version, PRUint16 cipherSuite, PK11SymKey *prk,
222                        const PRUint8 *hsHash, unsigned int hsHashLen,
223                        const char *label, unsigned int labelLen, PK11SymKey **keyp)
224 {
225     return SSLExp_HkdfVariantExpandLabel(version, cipherSuite, prk, hsHash, hsHashLen,
226                                          label, labelLen, ssl_variant_stream, keyp);
227 }
228 
229 SECStatus
SSLExp_HkdfVariantExpandLabel(PRUint16 version,PRUint16 cipherSuite,PK11SymKey * prk,const PRUint8 * hsHash,unsigned int hsHashLen,const char * label,unsigned int labelLen,SSLProtocolVariant variant,PK11SymKey ** keyp)230 SSLExp_HkdfVariantExpandLabel(PRUint16 version, PRUint16 cipherSuite, PK11SymKey *prk,
231                               const PRUint8 *hsHash, unsigned int hsHashLen,
232                               const char *label, unsigned int labelLen,
233                               SSLProtocolVariant variant, PK11SymKey **keyp)
234 {
235     if (prk == NULL || keyp == NULL ||
236         label == NULL || labelLen == 0) {
237         PORT_SetError(SEC_ERROR_INVALID_ARGS);
238         return SECFailure;
239     }
240 
241     SSLHashType hash;
242     SECStatus rv = tls13_GetHashAndCipher(version, cipherSuite,
243                                           &hash, NULL);
244     if (rv != SECSuccess) {
245         return SECFailure; /* Code already set. */
246     }
247     return tls13_HkdfExpandLabel(prk, hash, hsHash, hsHashLen, label, labelLen,
248                                  CKM_HKDF_DERIVE,
249                                  tls13_GetHashSizeForHash(hash), variant, keyp);
250 }
251 
252 SECStatus
SSLExp_HkdfExpandLabelWithMech(PRUint16 version,PRUint16 cipherSuite,PK11SymKey * prk,const PRUint8 * hsHash,unsigned int hsHashLen,const char * label,unsigned int labelLen,CK_MECHANISM_TYPE mech,unsigned int keySize,PK11SymKey ** keyp)253 SSLExp_HkdfExpandLabelWithMech(PRUint16 version, PRUint16 cipherSuite, PK11SymKey *prk,
254                                const PRUint8 *hsHash, unsigned int hsHashLen,
255                                const char *label, unsigned int labelLen,
256                                CK_MECHANISM_TYPE mech, unsigned int keySize,
257                                PK11SymKey **keyp)
258 {
259     return SSLExp_HkdfVariantExpandLabelWithMech(version, cipherSuite, prk, hsHash, hsHashLen,
260                                                  label, labelLen, mech, keySize,
261                                                  ssl_variant_stream, keyp);
262 }
263 
264 SECStatus
SSLExp_HkdfVariantExpandLabelWithMech(PRUint16 version,PRUint16 cipherSuite,PK11SymKey * prk,const PRUint8 * hsHash,unsigned int hsHashLen,const char * label,unsigned int labelLen,CK_MECHANISM_TYPE mech,unsigned int keySize,SSLProtocolVariant variant,PK11SymKey ** keyp)265 SSLExp_HkdfVariantExpandLabelWithMech(PRUint16 version, PRUint16 cipherSuite, PK11SymKey *prk,
266                                       const PRUint8 *hsHash, unsigned int hsHashLen,
267                                       const char *label, unsigned int labelLen,
268                                       CK_MECHANISM_TYPE mech, unsigned int keySize,
269                                       SSLProtocolVariant variant, PK11SymKey **keyp)
270 {
271     if (prk == NULL || keyp == NULL ||
272         label == NULL || labelLen == 0 ||
273         mech == CKM_INVALID_MECHANISM || keySize == 0) {
274         PORT_SetError(SEC_ERROR_INVALID_ARGS);
275         return SECFailure;
276     }
277 
278     SSLHashType hash;
279     SECStatus rv = tls13_GetHashAndCipher(version, cipherSuite,
280                                           &hash, NULL);
281     if (rv != SECSuccess) {
282         return SECFailure; /* Code already set. */
283     }
284     return tls13_HkdfExpandLabel(prk, hash, hsHash, hsHashLen, label, labelLen,
285                                  mech, keySize, variant, keyp);
286 }
287 
288 SECStatus
ssl_CreateMaskingContextInner(PRUint16 version,PRUint16 cipherSuite,SSLProtocolVariant variant,PK11SymKey * secret,const char * label,unsigned int labelLen,SSLMaskingContext ** ctx)289 ssl_CreateMaskingContextInner(PRUint16 version, PRUint16 cipherSuite,
290                               SSLProtocolVariant variant,
291                               PK11SymKey *secret,
292                               const char *label,
293                               unsigned int labelLen,
294                               SSLMaskingContext **ctx)
295 {
296     if (!secret || !ctx || (!label && labelLen)) {
297         PORT_SetError(SEC_ERROR_INVALID_ARGS);
298         return SECFailure;
299     }
300 
301     SSLMaskingContext *out = PORT_ZNew(SSLMaskingContext);
302     if (out == NULL) {
303         goto loser;
304     }
305 
306     SSLHashType hash;
307     const ssl3BulkCipherDef *cipher;
308     SECStatus rv = tls13_GetHashAndCipher(version, cipherSuite,
309                                           &hash, &cipher);
310     if (rv != SECSuccess) {
311         PORT_SetError(SEC_ERROR_INVALID_ARGS);
312         goto loser; /* Code already set. */
313     }
314 
315     out->mech = tls13_SequenceNumberEncryptionMechanism(cipher->calg);
316     if (out->mech == CKM_INVALID_MECHANISM) {
317         PORT_SetError(SEC_ERROR_INVALID_ARGS);
318         goto loser;
319     }
320 
321     // Derive the masking key
322     rv = tls13_HkdfExpandLabel(secret, hash,
323                                NULL, 0, // Handshake hash.
324                                label, labelLen,
325                                out->mech,
326                                cipher->key_size, variant,
327                                &out->secret);
328     if (rv != SECSuccess) {
329         goto loser;
330     }
331 
332     out->version = version;
333     out->cipherSuite = cipherSuite;
334 
335     *ctx = out;
336     return SECSuccess;
337 loser:
338     SSLExp_DestroyMaskingContext(out);
339     return SECFailure;
340 }
341 
/* Compute a header-protection mask of exactly |maskLen| bytes from |sample|,
 * using the masking key and mechanism stored in |ctx|.
 *
 * The mechanism is selected by ctx->mech:
 *   - CKM_AES_ECB: the first AES_BLOCK_SIZE bytes of |sample| are encrypted
 *     with the masking key. The mask is a prefix of that single output
 *     block, so a |maskLen| above AES_BLOCK_SIZE fails with
 *     SEC_ERROR_OUTPUT_LEN.
 *   - CKM_NSS_CHACHA20_CTR / CKM_CHACHA20: the first |paramLen| bytes of
 *     |sample| are passed as the mechanism parameter, and |maskLen| zero
 *     bytes are encrypted, so the output is the cipher's keystream. At most
 *     sizeof(zeros) == 128 bytes can be requested.
 *
 * Errors:
 *   - SEC_ERROR_INVALID_ARGS for NULL or zero arguments, a sample too short
 *     for the mechanism, or an unsupported mechanism;
 *   - SEC_ERROR_NO_KEY if the context has no key;
 *   - SEC_ERROR_OUTPUT_LEN if the mask would be too long;
 *   - SEC_ERROR_PKCS11_FUNCTION_FAILED if PK11_Encrypt fails.
 */
SECStatus
ssl_CreateMaskInner(SSLMaskingContext *ctx, const PRUint8 *sample,
                    unsigned int sampleLen, PRUint8 *outMask,
                    unsigned int maskLen)
{
    if (!ctx || !sample || !sampleLen || !outMask || !maskLen) {
        PORT_SetError(SEC_ERROR_INVALID_ARGS);
        return SECFailure;
    }

    if (ctx->secret == NULL) {
        PORT_SetError(SEC_ERROR_NO_KEY);
        return SECFailure;
    }

    SECStatus rv = SECFailure;
    unsigned int outMaskLen = 0;
    int paramLen = 0;

    /* Internal output buffer and length, used when the caller's buffer is
     * smaller than one block of output. |oneBlock| should have a size equal
     * to the largest block size supported below. */
    PRUint8 oneBlock[AES_BLOCK_SIZE];
    PRUint8 *outMask_ = outMask;
    unsigned int maskLen_ = maskLen;

    switch (ctx->mech) {
        case CKM_AES_ECB:
            if (sampleLen < AES_BLOCK_SIZE) {
                PORT_SetError(SEC_ERROR_INVALID_ARGS);
                return SECFailure;
            }
            /* ECB always emits a whole block, so encrypt into the local
             * block when the caller asked for less, then copy the prefix. */
            if (maskLen_ < AES_BLOCK_SIZE) {
                outMask_ = oneBlock;
                maskLen_ = sizeof(oneBlock);
            }
            rv = PK11_Encrypt(ctx->secret,
                              ctx->mech,
                              NULL,
                              outMask_, &outMaskLen, maskLen_,
                              sample, AES_BLOCK_SIZE);
            if (rv == SECSuccess &&
                maskLen < AES_BLOCK_SIZE) {
                memcpy(outMask, outMask_, maskLen);
            }
            break;
        case CKM_NSS_CHACHA20_CTR:
            /* 16 bytes of |sample| form the parameter for the NSS CTR
             * mechanism. */
            paramLen = 16;
        /* fall through */
        case CKM_CHACHA20:
            /* NOTE(review): for CKM_CHACHA20 the parameter is presumably
             * expected to be a CK_CHACHA20_PARAMS structure (which contains
             * pointers). Here, raw sample bytes of that size are passed
             * instead. Confirm this path is unreachable (ctx->mech comes from
             * tls13_SequenceNumberEncryptionMechanism) or build a proper
             * parameter structure. */
            paramLen = (paramLen) ? paramLen : sizeof(CK_CHACHA20_PARAMS);
            if (sampleLen < paramLen) {
                PORT_SetError(SEC_ERROR_INVALID_ARGS);
                return SECFailure;
            }

            SECItem param;
            param.type = siBuffer;
            param.len = paramLen;
            param.data = (PRUint8 *)sample; // const-cast :(
            /* Encrypting zeros yields the keystream itself; this caps the
             * mask length at 128 bytes. */
            unsigned char zeros[128] = { 0 };

            if (maskLen > sizeof(zeros)) {
                PORT_SetError(SEC_ERROR_OUTPUT_LEN);
                return SECFailure;
            }

            rv = PK11_Encrypt(ctx->secret,
                              ctx->mech,
                              &param,
                              outMask, &outMaskLen,
                              maskLen,
                              zeros, maskLen);
            break;
        default:
            PORT_SetError(SEC_ERROR_INVALID_ARGS);
            return SECFailure;
    }

    if (rv != SECSuccess) {
        PORT_SetError(SEC_ERROR_PKCS11_FUNCTION_FAILED);
        return SECFailure;
    }

    // Ensure we produced at least as much material as requested.
    if (outMaskLen < maskLen) {
        PORT_SetError(SEC_ERROR_OUTPUT_LEN);
        return SECFailure;
    }

    return SECSuccess;
}
434 
435 SECStatus
ssl_DestroyMaskingContextInner(SSLMaskingContext * ctx)436 ssl_DestroyMaskingContextInner(SSLMaskingContext *ctx)
437 {
438     if (!ctx) {
439         return SECSuccess;
440     }
441 
442     PK11_FreeSymKey(ctx->secret);
443     PORT_ZFree(ctx, sizeof(*ctx));
444     return SECSuccess;
445 }
446 
447 SECStatus
SSLExp_CreateMask(SSLMaskingContext * ctx,const PRUint8 * sample,unsigned int sampleLen,PRUint8 * outMask,unsigned int maskLen)448 SSLExp_CreateMask(SSLMaskingContext *ctx, const PRUint8 *sample,
449                   unsigned int sampleLen, PRUint8 *outMask,
450                   unsigned int maskLen)
451 {
452     return ssl_CreateMaskInner(ctx, sample, sampleLen, outMask, maskLen);
453 }
454 
455 SECStatus
SSLExp_CreateMaskingContext(PRUint16 version,PRUint16 cipherSuite,PK11SymKey * secret,const char * label,unsigned int labelLen,SSLMaskingContext ** ctx)456 SSLExp_CreateMaskingContext(PRUint16 version, PRUint16 cipherSuite,
457                             PK11SymKey *secret,
458                             const char *label,
459                             unsigned int labelLen,
460                             SSLMaskingContext **ctx)
461 {
462     return ssl_CreateMaskingContextInner(version, cipherSuite, ssl_variant_stream, secret,
463                                          label, labelLen, ctx);
464 }
465 
466 SECStatus
SSLExp_CreateVariantMaskingContext(PRUint16 version,PRUint16 cipherSuite,SSLProtocolVariant variant,PK11SymKey * secret,const char * label,unsigned int labelLen,SSLMaskingContext ** ctx)467 SSLExp_CreateVariantMaskingContext(PRUint16 version, PRUint16 cipherSuite,
468                                    SSLProtocolVariant variant,
469                                    PK11SymKey *secret,
470                                    const char *label,
471                                    unsigned int labelLen,
472                                    SSLMaskingContext **ctx)
473 {
474     return ssl_CreateMaskingContextInner(version, cipherSuite, variant, secret,
475                                          label, labelLen, ctx);
476 }
477 
478 SECStatus
SSLExp_DestroyMaskingContext(SSLMaskingContext * ctx)479 SSLExp_DestroyMaskingContext(SSLMaskingContext *ctx)
480 {
481     return ssl_DestroyMaskingContextInner(ctx);
482 }
483