1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this file,
5  * You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
7 /* This file contains functions for frobbing the internals of libssl */
8 #include "libssl_internals.h"
9 
10 #include "nss.h"
11 #include "pk11pub.h"
12 #include "seccomon.h"
13 #include "selfencrypt.h"
14 
SSLInt_IncrementClientHandshakeVersion(PRFileDesc * fd)15 SECStatus SSLInt_IncrementClientHandshakeVersion(PRFileDesc *fd) {
16   sslSocket *ss = ssl_FindSocket(fd);
17   if (!ss) {
18     return SECFailure;
19   }
20 
21   ++ss->clientHelloVersion;
22 
23   return SECSuccess;
24 }
25 
26 /* Use this function to update the ClientRandom of a client's handshake state
27  * after replacing its ClientHello message. We for example need to do this
28  * when replacing an SSLv3 ClientHello with its SSLv2 equivalent. */
SSLInt_UpdateSSLv2ClientRandom(PRFileDesc * fd,uint8_t * rnd,size_t rnd_len,uint8_t * msg,size_t msg_len)29 SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd,
30                                          size_t rnd_len, uint8_t *msg,
31                                          size_t msg_len) {
32   sslSocket *ss = ssl_FindSocket(fd);
33   if (!ss) {
34     return SECFailure;
35   }
36 
37   ssl3_RestartHandshakeHashes(ss);
38 
39   // Ensure we don't overrun hs.client_random.
40   rnd_len = PR_MIN(SSL3_RANDOM_LENGTH, rnd_len);
41 
42   // Zero the client_random.
43   PORT_Memset(ss->ssl3.hs.client_random, 0, SSL3_RANDOM_LENGTH);
44 
45   // Copy over the challenge bytes.
46   size_t offset = SSL3_RANDOM_LENGTH - rnd_len;
47   PORT_Memcpy(ss->ssl3.hs.client_random + offset, rnd, rnd_len);
48 
49   // Rehash the SSLv2 client hello message.
50   return ssl3_UpdateHandshakeHashes(ss, msg, msg_len);
51 }
52 
SSLInt_ExtensionNegotiated(PRFileDesc * fd,PRUint16 ext)53 PRBool SSLInt_ExtensionNegotiated(PRFileDesc *fd, PRUint16 ext) {
54   sslSocket *ss = ssl_FindSocket(fd);
55   return (PRBool)(ss && ssl3_ExtensionNegotiated(ss, ext));
56 }
57 
/* Reset the self-encrypt keys by calling ssl_ResetSelfEncryptKeys().
 * Declared with (void) so the definition is a real prototype; an empty
 * parameter list in C means "unspecified arguments". */
void SSLInt_ClearSelfEncryptKey(void) { ssl_ResetSelfEncryptKeys(); }
59 
60 sslSelfEncryptKeys *ssl_GetSelfEncryptKeysInt();
61 
SSLInt_SetSelfEncryptMacKey(PK11SymKey * key)62 void SSLInt_SetSelfEncryptMacKey(PK11SymKey *key) {
63   sslSelfEncryptKeys *keys = ssl_GetSelfEncryptKeysInt();
64 
65   PK11_FreeSymKey(keys->macKey);
66   keys->macKey = key;
67 }
68 
SSLInt_SetMTU(PRFileDesc * fd,PRUint16 mtu)69 SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu) {
70   sslSocket *ss = ssl_FindSocket(fd);
71   if (!ss) {
72     return SECFailure;
73   }
74   ss->ssl3.mtu = mtu;
75   ss->ssl3.hs.rtRetries = 0; /* Avoid DTLS shrinking the MTU any more. */
76   return SECSuccess;
77 }
78 
SSLInt_CountCipherSpecs(PRFileDesc * fd)79 PRInt32 SSLInt_CountCipherSpecs(PRFileDesc *fd) {
80   PRCList *cur_p;
81   PRInt32 ct = 0;
82 
83   sslSocket *ss = ssl_FindSocket(fd);
84   if (!ss) {
85     return -1;
86   }
87 
88   for (cur_p = PR_NEXT_LINK(&ss->ssl3.hs.cipherSpecs);
89        cur_p != &ss->ssl3.hs.cipherSpecs; cur_p = PR_NEXT_LINK(cur_p)) {
90     ++ct;
91   }
92   return ct;
93 }
94 
/* Write a header naming |label| followed by one line per cipher spec on the
 * handshake state's cipherSpecs list (direction, epoch, phase, refcount) to
 * stderr. Does nothing if |fd| is not an SSL socket. */
void SSLInt_PrintCipherSpecs(const char *label, PRFileDesc *fd) {
  sslSocket *sock = ssl_FindSocket(fd);
  if (!sock) {
    return;
  }

  PRCList *head = &sock->ssl3.hs.cipherSpecs;
  fprintf(stderr, "Cipher specs for %s\n", label);
  for (PRCList *link = PR_NEXT_LINK(head); link != head;
       link = PR_NEXT_LINK(link)) {
    ssl3CipherSpec *s = (ssl3CipherSpec *)link;
    fprintf(stderr, "  %s spec epoch=%d (%s) refct=%d\n", SPEC_DIR(s),
            s->epoch, s->phase, s->refCt);
  }
}
111 
112 /* Force a timer expiry by backdating when all active timers were started. We
113  * could set the remaining time to 0 but then backoff would not work properly if
114  * we decide to test it. */
SSLInt_ShiftDtlsTimers(PRFileDesc * fd,PRIntervalTime shift)115 SECStatus SSLInt_ShiftDtlsTimers(PRFileDesc *fd, PRIntervalTime shift) {
116   size_t i;
117   sslSocket *ss = ssl_FindSocket(fd);
118   if (!ss) {
119     return SECFailure;
120   }
121 
122   for (i = 0; i < PR_ARRAY_SIZE(ss->ssl3.hs.timers); ++i) {
123     if (ss->ssl3.hs.timers[i].cb) {
124       ss->ssl3.hs.timers[i].started -= shift;
125     }
126   }
127   return SECSuccess;
128 }
129 
/* For use inside a function with a local |ss| (sslSocket *): if
 * ss->ssl3.hs.<secret> is non-NULL, print "<secret> != NULL" to stderr and
 * return PR_FALSE from the enclosing function.
 * Wrapped in do { } while (0) so that the macro is a single statement and
 * remains correct inside an unbraced if/else. */
#define CHECK_SECRET(secret)                    \
  do {                                          \
    if (ss->ssl3.hs.secret) {                   \
      fprintf(stderr, "%s != NULL\n", #secret); \
      return PR_FALSE;                          \
    }                                           \
  } while (0)
135 
SSLInt_CheckSecretsDestroyed(PRFileDesc * fd)136 PRBool SSLInt_CheckSecretsDestroyed(PRFileDesc *fd) {
137   sslSocket *ss = ssl_FindSocket(fd);
138   if (!ss) {
139     return PR_FALSE;
140   }
141 
142   CHECK_SECRET(currentSecret);
143   CHECK_SECRET(dheSecret);
144   CHECK_SECRET(clientEarlyTrafficSecret);
145   CHECK_SECRET(clientHsTrafficSecret);
146   CHECK_SECRET(serverHsTrafficSecret);
147 
148   return PR_TRUE;
149 }
150 
sslint_DamageTrafficSecret(PRFileDesc * fd,size_t offset)151 PRBool sslint_DamageTrafficSecret(PRFileDesc *fd, size_t offset) {
152   unsigned char data[32] = {0};
153   PK11SymKey **keyPtr;
154   PK11SlotInfo *slot = PK11_GetInternalSlot();
155   SECItem key_item = {siBuffer, data, sizeof(data)};
156   sslSocket *ss = ssl_FindSocket(fd);
157   if (!ss) {
158     return PR_FALSE;
159   }
160   if (!slot) {
161     return PR_FALSE;
162   }
163   keyPtr = (PK11SymKey **)((char *)&ss->ssl3.hs + offset);
164   if (!*keyPtr) {
165     return PR_FALSE;
166   }
167   PK11_FreeSymKey(*keyPtr);
168   *keyPtr = PK11_ImportSymKey(slot, CKM_NSS_HKDF_SHA256, PK11_OriginUnwrap,
169                               CKA_DERIVE, &key_item, NULL);
170   PK11_FreeSlot(slot);
171   if (!*keyPtr) {
172     return PR_FALSE;
173   }
174 
175   return PR_TRUE;
176 }
177 
SSLInt_DamageClientHsTrafficSecret(PRFileDesc * fd)178 PRBool SSLInt_DamageClientHsTrafficSecret(PRFileDesc *fd) {
179   return sslint_DamageTrafficSecret(
180       fd, offsetof(SSL3HandshakeState, clientHsTrafficSecret));
181 }
182 
SSLInt_DamageServerHsTrafficSecret(PRFileDesc * fd)183 PRBool SSLInt_DamageServerHsTrafficSecret(PRFileDesc *fd) {
184   return sslint_DamageTrafficSecret(
185       fd, offsetof(SSL3HandshakeState, serverHsTrafficSecret));
186 }
187 
SSLInt_DamageEarlyTrafficSecret(PRFileDesc * fd)188 PRBool SSLInt_DamageEarlyTrafficSecret(PRFileDesc *fd) {
189   return sslint_DamageTrafficSecret(
190       fd, offsetof(SSL3HandshakeState, clientEarlyTrafficSecret));
191 }
192 
SSLInt_Set0RttAlpn(PRFileDesc * fd,PRUint8 * data,unsigned int len)193 SECStatus SSLInt_Set0RttAlpn(PRFileDesc *fd, PRUint8 *data, unsigned int len) {
194   sslSocket *ss = ssl_FindSocket(fd);
195   if (!ss) {
196     return SECFailure;
197   }
198 
199   ss->xtnData.nextProtoState = SSL_NEXT_PROTO_EARLY_VALUE;
200   if (ss->xtnData.nextProto.data) {
201     SECITEM_FreeItem(&ss->xtnData.nextProto, PR_FALSE);
202   }
203   if (!SECITEM_AllocItem(NULL, &ss->xtnData.nextProto, len)) {
204     return SECFailure;
205   }
206   PORT_Memcpy(ss->xtnData.nextProto.data, data, len);
207 
208   return SECSuccess;
209 }
210 
SSLInt_HasCertWithAuthType(PRFileDesc * fd,SSLAuthType authType)211 PRBool SSLInt_HasCertWithAuthType(PRFileDesc *fd, SSLAuthType authType) {
212   sslSocket *ss = ssl_FindSocket(fd);
213   if (!ss) {
214     return PR_FALSE;
215   }
216 
217   return (PRBool)(!!ssl_FindServerCert(ss, authType, NULL));
218 }
219 
SSLInt_SendAlert(PRFileDesc * fd,uint8_t level,uint8_t type)220 PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type) {
221   sslSocket *ss = ssl_FindSocket(fd);
222   if (!ss) {
223     return PR_FALSE;
224   }
225 
226   SECStatus rv = SSL3_SendAlert(ss, level, type);
227   if (rv != SECSuccess) return PR_FALSE;
228 
229   return PR_TRUE;
230 }
231 
SSLInt_AdvanceReadSeqNum(PRFileDesc * fd,PRUint64 to)232 SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) {
233   sslSocket *ss;
234   ssl3CipherSpec *spec;
235 
236   ss = ssl_FindSocket(fd);
237   if (!ss) {
238     return SECFailure;
239   }
240   if (to >= RECORD_SEQ_MAX) {
241     PORT_SetError(SEC_ERROR_INVALID_ARGS);
242     return SECFailure;
243   }
244   ssl_GetSpecWriteLock(ss);
245   spec = ss->ssl3.crSpec;
246   spec->seqNum = to;
247 
248   /* For DTLS, we need to fix the record sequence number.  For this, we can just
249    * scrub the entire structure on the assumption that the new sequence number
250    * is far enough past the last received sequence number. */
251   if (spec->seqNum <= spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) {
252     PORT_SetError(SEC_ERROR_INVALID_ARGS);
253     return SECFailure;
254   }
255   dtls_RecordSetRecvd(&spec->recvdRecords, spec->seqNum);
256 
257   ssl_ReleaseSpecWriteLock(ss);
258   return SECSuccess;
259 }
260 
SSLInt_AdvanceWriteSeqNum(PRFileDesc * fd,PRUint64 to)261 SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to) {
262   sslSocket *ss;
263 
264   ss = ssl_FindSocket(fd);
265   if (!ss) {
266     return SECFailure;
267   }
268   if (to >= RECORD_SEQ_MAX) {
269     PORT_SetError(SEC_ERROR_INVALID_ARGS);
270     return SECFailure;
271   }
272   ssl_GetSpecWriteLock(ss);
273   ss->ssl3.cwSpec->seqNum = to;
274   ssl_ReleaseSpecWriteLock(ss);
275   return SECSuccess;
276 }
277 
SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc * fd,PRInt32 extra)278 SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra) {
279   sslSocket *ss;
280   sslSequenceNumber to;
281 
282   ss = ssl_FindSocket(fd);
283   if (!ss) {
284     return SECFailure;
285   }
286   ssl_GetSpecReadLock(ss);
287   to = ss->ssl3.cwSpec->seqNum + DTLS_RECVD_RECORDS_WINDOW + extra;
288   ssl_ReleaseSpecReadLock(ss);
289   return SSLInt_AdvanceWriteSeqNum(fd, to);
290 }
291 
SSLInt_GetKEAType(SSLNamedGroup group)292 SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group) {
293   const sslNamedGroupDef *groupDef = ssl_LookupNamedGroup(group);
294   if (!groupDef) return ssl_kea_null;
295 
296   return groupDef->keaType;
297 }
298 
SSLInt_SetCipherSpecChangeFunc(PRFileDesc * fd,sslCipherSpecChangedFunc func,void * arg)299 SECStatus SSLInt_SetCipherSpecChangeFunc(PRFileDesc *fd,
300                                          sslCipherSpecChangedFunc func,
301                                          void *arg) {
302   sslSocket *ss;
303 
304   ss = ssl_FindSocket(fd);
305   if (!ss) {
306     return SECFailure;
307   }
308 
309   ss->ssl3.changedCipherSpecFunc = func;
310   ss->ssl3.changedCipherSpecArg = arg;
311 
312   return SECSuccess;
313 }
314 
SSLInt_CipherSpecToKey(const ssl3CipherSpec * spec)315 PK11SymKey *SSLInt_CipherSpecToKey(const ssl3CipherSpec *spec) {
316   return spec->keyMaterial.key;
317 }
318 
SSLInt_CipherSpecToAlgorithm(const ssl3CipherSpec * spec)319 SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(const ssl3CipherSpec *spec) {
320   return spec->cipherDef->calg;
321 }
322 
SSLInt_CipherSpecToIv(const ssl3CipherSpec * spec)323 const PRUint8 *SSLInt_CipherSpecToIv(const ssl3CipherSpec *spec) {
324   return spec->keyMaterial.iv;
325 }
326 
SSLInt_CipherSpecToEpoch(const ssl3CipherSpec * spec)327 PRUint16 SSLInt_CipherSpecToEpoch(const ssl3CipherSpec *spec) {
328   return spec->epoch;
329 }
330 
SSLInt_SetTicketLifetime(uint32_t lifetime)331 void SSLInt_SetTicketLifetime(uint32_t lifetime) {
332   ssl_ticket_lifetime = lifetime;
333 }
334 
SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc * fd,uint32_t size)335 SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size) {
336   sslSocket *ss;
337 
338   ss = ssl_FindSocket(fd);
339   if (!ss) {
340     return SECFailure;
341   }
342 
343   /* This only works when resuming. */
344   if (!ss->statelessResume) {
345     PORT_SetError(SEC_INTERNAL_ONLY);
346     return SECFailure;
347   }
348 
349   /* Modifying both specs allows this to be used on either peer. */
350   ssl_GetSpecWriteLock(ss);
351   ss->ssl3.crSpec->earlyDataRemaining = size;
352   ss->ssl3.cwSpec->earlyDataRemaining = size;
353   ssl_ReleaseSpecWriteLock(ss);
354 
355   return SECSuccess;
356 }
357 
SSLInt_RolloverAntiReplay(void)358 void SSLInt_RolloverAntiReplay(void) {
359   tls13_AntiReplayRollover(ssl_TimeUsec());
360 }
361 
SSLInt_GetEpochs(PRFileDesc * fd,PRUint16 * readEpoch,PRUint16 * writeEpoch)362 SECStatus SSLInt_GetEpochs(PRFileDesc *fd, PRUint16 *readEpoch,
363                            PRUint16 *writeEpoch) {
364   sslSocket *ss = ssl_FindSocket(fd);
365   if (!ss || !readEpoch || !writeEpoch) {
366     return SECFailure;
367   }
368 
369   ssl_GetSpecReadLock(ss);
370   *readEpoch = ss->ssl3.crSpec->epoch;
371   *writeEpoch = ss->ssl3.cwSpec->epoch;
372   ssl_ReleaseSpecReadLock(ss);
373   return SECSuccess;
374 }
375