1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 4 -*- */
2 /* This Source Code Form is subject to the terms of the Mozilla Public
3  * License, v. 2.0. If a copy of the MPL was not distributed with this
4  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
5 
6 /*
7  * DTLS Protocol
8  */
9 
10 #include "ssl.h"
11 #include "sslimpl.h"
12 #include "sslproto.h"
13 #include "dtls13con.h"
14 
/* Fallback element-count helper for builds whose NSPR does not provide it.
 * Only valid on true arrays, not on pointers. */
#ifndef PR_ARRAY_SIZE
#define PR_ARRAY_SIZE(a) (sizeof(a) / sizeof((a)[0]))
#endif

/* Prototypes for static helpers that are defined later in this file but
 * referenced before their definitions. */
static SECStatus dtls_StartRetransmitTimer(sslSocket *ss);
static void dtls_RetransmitTimerExpiredCb(sslSocket *ss);
static SECStatus dtls_SendSavedWriteData(sslSocket *ss);
static void dtls_FinishedTimerCb(sslSocket *ss);
static void dtls_CancelAllTimers(sslSocket *ss);

/* Candidate MTU values, largest first.
 * -28 adjusts for the IP/UDP header (a 20-byte IPv4 header plus an 8-byte
 * UDP header).
 * NOTE(review): an IPv6 header is 40 bytes, so for IPv6 paths these values
 * overestimate the usable payload by 20 bytes -- confirm whether that
 * matters. */
static const PRUint16 COMMON_MTU_VALUES[] = {
    1500 - 28, /* Ethernet MTU */
    1280 - 28, /* IPv6 minimum MTU */
    576 - 28,  /* Common assumption */
    256 - 28   /* We're in serious trouble now */
};

/* Cookie size in bytes; presumably used for HelloVerifyRequest cookies
 * elsewhere in this file (not referenced in the visible part) -- confirm. */
#define DTLS_COOKIE_BYTES 32
/* Maximum DTLS expansion = header + IV + max CBC padding +
 * maximum MAC, i.e. record header + 16-byte IV + up to 16 bytes of
 * CBC padding + a 32-byte MAC. */
#define DTLS_MAX_EXPANSION (DTLS_RECORD_HEADER_LENGTH + 16 + 16 + 32)
37 
/* Cipher suites that must not be enabled on a DTLS socket.
 * List copied from ssl3con.c:cipherSuites; every entry is an RC4 suite.
 * RC4 is a stream cipher that cannot be randomly accessed, and RFC 6347
 * Sec. 4.1.2.2 forbids it in DTLS.  The list is terminated by 0, which
 * ssl3_DisableNonDTLSSuites() relies on. */
static const ssl3CipherSuite nonDTLSSuites[] = {
    TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
    TLS_ECDHE_RSA_WITH_RC4_128_SHA,
    TLS_DHE_DSS_WITH_RC4_128_SHA,
    TLS_ECDH_RSA_WITH_RC4_128_SHA,
    TLS_ECDH_ECDSA_WITH_RC4_128_SHA,
    TLS_RSA_WITH_RC4_128_MD5,
    TLS_RSA_WITH_RC4_128_SHA,
    0 /* End of list marker */
};
49 
50 /* Map back and forth between TLS and DTLS versions in wire format.
51  * Mapping table is:
52  *
53  * TLS             DTLS
54  * 1.1 (0302)      1.0 (feff)
55  * 1.2 (0303)      1.2 (fefd)
56  * 1.3 (0304)      1.3 (0304)
57  */
58 SSL3ProtocolVersion
dtls_TLSVersionToDTLSVersion(SSL3ProtocolVersion tlsv)59 dtls_TLSVersionToDTLSVersion(SSL3ProtocolVersion tlsv)
60 {
61     if (tlsv == SSL_LIBRARY_VERSION_TLS_1_1) {
62         return SSL_LIBRARY_VERSION_DTLS_1_0_WIRE;
63     }
64     if (tlsv == SSL_LIBRARY_VERSION_TLS_1_2) {
65         return SSL_LIBRARY_VERSION_DTLS_1_2_WIRE;
66     }
67     if (tlsv == SSL_LIBRARY_VERSION_TLS_1_3) {
68         return SSL_LIBRARY_VERSION_DTLS_1_3_WIRE;
69     }
70 
71     /* Anything else is an error, so return
72      * the invalid version 0xffff. */
73     return 0xffff;
74 }
75 
76 /* Map known DTLS versions to known TLS versions.
77  * - Invalid versions (< 1.0) return a version of 0
78  * - Versions > known return a version one higher than we know of
79  * to accomodate a theoretically newer version */
80 SSL3ProtocolVersion
dtls_DTLSVersionToTLSVersion(SSL3ProtocolVersion dtlsv)81 dtls_DTLSVersionToTLSVersion(SSL3ProtocolVersion dtlsv)
82 {
83     if (MSB(dtlsv) == 0xff) {
84         return 0;
85     }
86 
87     if (dtlsv == SSL_LIBRARY_VERSION_DTLS_1_0_WIRE) {
88         return SSL_LIBRARY_VERSION_TLS_1_1;
89     }
90     /* Handle the skipped version of DTLS 1.1 by returning
91      * an error. */
92     if (dtlsv == ((~0x0101) & 0xffff)) {
93         return 0;
94     }
95     if (dtlsv == SSL_LIBRARY_VERSION_DTLS_1_2_WIRE) {
96         return SSL_LIBRARY_VERSION_TLS_1_2;
97     }
98     if (dtlsv == SSL_LIBRARY_VERSION_DTLS_1_3_WIRE) {
99         return SSL_LIBRARY_VERSION_TLS_1_3;
100     }
101 
102     /* Return a fictional higher version than we know of */
103     return SSL_LIBRARY_VERSION_MAX_SUPPORTED + 1;
104 }
105 
106 /* On this socket, Disable non-DTLS cipher suites in the argument's list */
107 SECStatus
ssl3_DisableNonDTLSSuites(sslSocket * ss)108 ssl3_DisableNonDTLSSuites(sslSocket *ss)
109 {
110     const ssl3CipherSuite *suite;
111 
112     for (suite = nonDTLSSuites; *suite; ++suite) {
113         PORT_CheckSuccess(ssl3_CipherPrefSet(ss, *suite, PR_FALSE));
114     }
115     return SECSuccess;
116 }
117 
118 /* Allocate a DTLSQueuedMessage.
119  *
120  * Called from dtls_QueueMessage()
121  */
122 static DTLSQueuedMessage *
dtls_AllocQueuedMessage(ssl3CipherSpec * cwSpec,SSLContentType ct,const unsigned char * data,PRUint32 len)123 dtls_AllocQueuedMessage(ssl3CipherSpec *cwSpec, SSLContentType ct,
124                         const unsigned char *data, PRUint32 len)
125 {
126     DTLSQueuedMessage *msg;
127 
128     msg = PORT_ZNew(DTLSQueuedMessage);
129     if (!msg)
130         return NULL;
131 
132     msg->data = PORT_Alloc(len);
133     if (!msg->data) {
134         PORT_Free(msg);
135         return NULL;
136     }
137     PORT_Memcpy(msg->data, data, len);
138 
139     msg->len = len;
140     msg->cwSpec = cwSpec;
141     msg->type = ct;
142     /* Safe if we are < 1.3, since the refct is
143      * already very high. */
144     ssl_CipherSpecAddRef(cwSpec);
145 
146     return msg;
147 }
148 
149 /*
150  * Free a handshake message
151  *
152  * Called from dtls_FreeHandshakeMessages()
153  */
154 void
dtls_FreeHandshakeMessage(DTLSQueuedMessage * msg)155 dtls_FreeHandshakeMessage(DTLSQueuedMessage *msg)
156 {
157     if (!msg)
158         return;
159 
160     /* Safe if we are < 1.3, since the refct is
161      * already very high. */
162     ssl_CipherSpecRelease(msg->cwSpec);
163     PORT_ZFree(msg->data, msg->len);
164     PORT_Free(msg);
165 }
166 
167 /*
168  * Free a list of handshake messages
169  *
170  * Called from:
171  *              dtls_HandleHandshake()
172  *              ssl3_DestroySSL3Info()
173  */
174 void
dtls_FreeHandshakeMessages(PRCList * list)175 dtls_FreeHandshakeMessages(PRCList *list)
176 {
177     PRCList *cur_p;
178 
179     while (!PR_CLIST_IS_EMPTY(list)) {
180         cur_p = PR_LIST_TAIL(list);
181         PR_REMOVE_LINK(cur_p);
182         dtls_FreeHandshakeMessage((DTLSQueuedMessage *)cur_p);
183     }
184 }
185 
186 /* Called by dtls_HandleHandshake() and dtls_MaybeRetransmitHandshake() if a
187  * handshake message retransmission is detected. */
188 static SECStatus
dtls_RetransmitDetected(sslSocket * ss)189 dtls_RetransmitDetected(sslSocket *ss)
190 {
191     dtlsTimer *timer = ss->ssl3.hs.rtTimer;
192     SECStatus rv = SECSuccess;
193 
194     PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss));
195     PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
196 
197     if (timer->cb == dtls_RetransmitTimerExpiredCb) {
198         /* Check to see if we retransmitted recently. If so,
199          * suppress the triggered retransmit. This avoids
200          * retransmit wars after packet loss.
201          * This is not in RFC 5346 but it should be.
202          */
203         if ((PR_IntervalNow() - timer->started) >
204             (timer->timeout / 4)) {
205             SSL_TRC(30,
206                     ("%d: SSL3[%d]: Shortcutting retransmit timer",
207                      SSL_GETPID(), ss->fd));
208 
209             /* Cancel the timer and call the CB,
210              * which re-arms the timer */
211             dtls_CancelTimer(ss, ss->ssl3.hs.rtTimer);
212             dtls_RetransmitTimerExpiredCb(ss);
213         } else {
214             SSL_TRC(30,
215                     ("%d: SSL3[%d]: Ignoring retransmission: "
216                      "last retransmission %dms ago, suppressed for %dms",
217                      SSL_GETPID(), ss->fd,
218                      PR_IntervalNow() - timer->started,
219                      timer->timeout / 4));
220         }
221 
222     } else if (timer->cb == dtls_FinishedTimerCb) {
223         SSL_TRC(30, ("%d: SSL3[%d]: Retransmit detected in holddown",
224                      SSL_GETPID(), ss->fd));
225         /* Retransmit the messages and re-arm the timer
226          * Note that we are not backing off the timer here.
227          * The spec isn't clear and my reasoning is that this
228          * may be a re-ordered packet rather than slowness,
229          * so let's be aggressive. */
230         dtls_CancelTimer(ss, ss->ssl3.hs.rtTimer);
231         rv = dtls_TransmitMessageFlight(ss);
232         if (rv == SECSuccess) {
233             rv = dtls_StartHolddownTimer(ss);
234         }
235 
236     } else {
237         PORT_Assert(timer->cb == NULL);
238         /* ... and ignore it. */
239     }
240     return rv;
241 }
242 
243 static SECStatus
dtls_HandleHandshakeMessage(sslSocket * ss,PRUint8 * data,PRBool last)244 dtls_HandleHandshakeMessage(sslSocket *ss, PRUint8 *data, PRBool last)
245 {
246     ss->ssl3.hs.recvdHighWater = -1;
247 
248     return ssl3_HandleHandshakeMessage(ss, data, ss->ssl3.hs.msg_len,
249                                        last);
250 }
251 
/* Called only from ssl3_HandleRecord, for each (deciphered) DTLS record.
 * origBuf is the decrypted record content. It is expected to contain one or
 * more complete handshake fragments, each a 12-byte DTLS handshake header
 * followed by fragment_length bytes of body.
 * Caller must hold the handshake and RecvBuf locks.
 *
 * Note that this code uses msg_len for two purposes:
 *
 * (1) To pass the length to ssl3_HandleHandshakeMessage()
 * (2) To carry the length of a message currently being reassembled
 *
 * However, unlike ssl3_HandleHandshake(), it is not used to carry
 * the state of reassembly (i.e., whether one is in progress). That
 * state is carried in recvdHighWater (-1 when no reassembly is in progress)
 * and recvdFragments.
 */
/* Reassembly bitmap helpers: recvdFragments holds one bit per message byte.
 * OFFSET_BYTE(o) is the index of the bitmap byte that covers message offset
 * o, and OFFSET_MASK(o) selects o's bit within that byte.  The argument is
 * parenthesized so that expressions such as OFFSET_BYTE(a + b) expand
 * correctly. */
#define OFFSET_BYTE(o) ((o) / 8)
#define OFFSET_MASK(o) (1 << ((o) % 8))
268 
/* Parse every DTLS handshake fragment in origBuf. A complete in-order
 * message is processed immediately; fragments of the next expected message
 * are reassembled into ss->ssl3.hs.msg_body.
 *
 * epoch/seqNum identify the record that carried origBuf. When
 * tls13_MaybeTls13() holds and nothing in the record was discarded, they are
 * remembered in dtlsRcvdHandshake for ACK purposes.
 *
 * Returns SECSuccess or SECFailure. origBuf->len is set to 0 on every
 * return path.
 */
SECStatus
dtls_HandleHandshake(sslSocket *ss, DTLSEpoch epoch, sslSequenceNumber seqNum,
                     sslBuffer *origBuf)
{
    sslBuffer buf = *origBuf; /* Local cursor; origBuf itself is not advanced. */
    SECStatus rv = SECSuccess;
    PRBool discarded = PR_FALSE; /* Set if any fragment was ignored (case 2). */

    ss->ssl3.hs.endOfFlight = PR_FALSE;

    PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss));
    PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));

    /* Each iteration consumes one 12-byte fragment header plus its body. */
    while (buf.len > 0) {
        PRUint8 type;
        PRUint32 message_length;
        PRUint16 message_seq;
        PRUint32 fragment_offset;
        PRUint32 fragment_length;
        PRUint32 offset;

        /* Not enough bytes left for a handshake fragment header. */
        if (buf.len < 12) {
            PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE);
            rv = SECFailure;
            goto loser;
        }

        /* Parse the header */
        type = buf.buf[0];
        message_length = (buf.buf[1] << 16) | (buf.buf[2] << 8) | buf.buf[3];
        message_seq = (buf.buf[4] << 8) | buf.buf[5];
        fragment_offset = (buf.buf[6] << 16) | (buf.buf[7] << 8) | buf.buf[8];
        fragment_length = (buf.buf[9] << 16) | (buf.buf[10] << 8) | buf.buf[11];

#define MAX_HANDSHAKE_MSG_LEN 0x1ffff /* 128k - 1 */
        if (message_length > MAX_HANDSHAKE_MSG_LEN) {
            (void)ssl3_DecodeError(ss);
            PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE);
            rv = SECFailure;
            goto loser;
        }
#undef MAX_HANDSHAKE_MSG_LEN

        buf.buf += 12;
        buf.len -= 12;

        /* This fragment must be complete */
        if (buf.len < fragment_length) {
            PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE);
            rv = SECFailure;
            goto loser;
        }

        /* Sanity check the packet contents */
        if ((fragment_length + fragment_offset) > message_length) {
            PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE);
            rv = SECFailure;
            goto loser;
        }

        /* If we're a server and we receive what appears to be a retried
         * ClientHello, and we are expecting a ClientHello, move the receive
         * sequence number forward.  This allows for a retried ClientHello if we
         * send a stateless HelloRetryRequest. */
        if (message_seq > ss->ssl3.hs.recvMessageSeq &&
            message_seq == 1 &&
            fragment_offset == 0 &&
            ss->ssl3.hs.ws == wait_client_hello &&
            (SSLHandshakeType)type == ssl_hs_client_hello) {
            SSL_TRC(5, ("%d: DTLS[%d]: Received apparent 2nd ClientHello",
                        SSL_GETPID(), ss->fd));
            ss->ssl3.hs.recvMessageSeq = 1;
            ss->ssl3.hs.helloRetry = PR_TRUE;
        }

        /* There are three ways we could not be ready for this packet.
         *
         * 1. It's a partial next message.
         * 2. It's a partial or complete message beyond the next
         * 3. It's a message we've already seen
         *
         * If it's the complete next message we accept it right away.
         * This is the common case for short messages
         */
        if ((message_seq == ss->ssl3.hs.recvMessageSeq) &&
            (fragment_offset == 0) &&
            (fragment_length == message_length)) {
            /* Complete next message. Process immediately */
            ss->ssl3.hs.msg_type = (SSLHandshakeType)type;
            ss->ssl3.hs.msg_len = message_length;

            /* Last argument: PR_TRUE when this fragment ends the record. */
            rv = dtls_HandleHandshakeMessage(ss, buf.buf,
                                             buf.len == fragment_length);
            if (rv == SECFailure) {
                goto loser;
            }
        } else {
            if (message_seq < ss->ssl3.hs.recvMessageSeq) {
                /* Case 3: we do an immediate retransmit if we're
                 * in a waiting state. */
                /* Any remaining fragments in this record are not processed,
                 * and the ACK bookkeeping after the loop is skipped; rv is
                 * whatever dtls_RetransmitDetected() returned. */
                rv = dtls_RetransmitDetected(ss);
                goto loser;
            } else if (message_seq > ss->ssl3.hs.recvMessageSeq) {
                /* Case 2
                 *
                 * Ignore this message. This means we don't handle out of
                 * order complete messages that well, but we're still
                 * compliant and this probably does not happen often
                 *
                 * XXX OK for now. Maybe do something smarter at some point?
                 */
                SSL_TRC(10, ("%d: SSL3[%d]: dtls_HandleHandshake, discarding handshake message",
                             SSL_GETPID(), ss->fd));
                discarded = PR_TRUE;
            } else {
                PRInt32 end = fragment_offset + fragment_length;

                /* Case 1
                 *
                 * Buffer the fragment for reassembly
                 */
                /* Make room for the message */
                /* recvdHighWater == -1 means no reassembly is in progress,
                 * so this fragment starts a new one. */
                if (ss->ssl3.hs.recvdHighWater == -1) {
                    /* One bitmap bit per message byte, rounded up. */
                    PRUint32 map_length = OFFSET_BYTE(message_length) + 1;

                    rv = sslBuffer_Grow(&ss->ssl3.hs.msg_body, message_length);
                    if (rv != SECSuccess)
                        goto loser;
                    /* Make room for the fragment map */
                    rv = sslBuffer_Grow(&ss->ssl3.hs.recvdFragments,
                                        map_length);
                    if (rv != SECSuccess)
                        goto loser;

                    /* Reset the reassembly map */
                    ss->ssl3.hs.recvdHighWater = 0;
                    PORT_Memset(ss->ssl3.hs.recvdFragments.buf, 0,
                                ss->ssl3.hs.recvdFragments.space);
                    ss->ssl3.hs.msg_type = (SSLHandshakeType)type;
                    ss->ssl3.hs.msg_len = message_length;
                }

                /* If we have a message length mismatch, abandon the reassembly
                 * in progress and hope that the next retransmit will give us
                 * something sane
                 */
                /* NOTE(review): only the length, not the handshake type, is
                 * compared against the reassembly in progress. */
                if (message_length != ss->ssl3.hs.msg_len) {
                    ss->ssl3.hs.recvdHighWater = -1;
                    PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE);
                    rv = SECFailure;
                    goto loser;
                }

                /* Now copy this fragment into the buffer. */
                /* Skipped only when every byte of the fragment lies below
                 * the contiguous high-water mark, i.e. is already held. */
                if (end > ss->ssl3.hs.recvdHighWater) {
                    PORT_Memcpy(ss->ssl3.hs.msg_body.buf + fragment_offset,
                                buf.buf, fragment_length);
                }

                /* This logic is a bit tricky. We have two values for
                 * reassembly state:
                 *
                 * - recvdHighWater contains the highest contiguous number of
                 *   bytes received
                 * - recvdFragments contains a bitmask of packets received
                 *   above recvdHighWater
                 *
                 * This avoids having to fill in the bitmask in the common
                 * case of adjacent fragments received in sequence
                 */
                if (fragment_offset <= (unsigned int)ss->ssl3.hs.recvdHighWater) {
                    /* Either this is the adjacent fragment or an overlapping
                     * fragment */
                    if (end > ss->ssl3.hs.recvdHighWater) {
                        ss->ssl3.hs.recvdHighWater = end;
                    }
                } else {
                    /* A gap precedes this fragment: mark each of its bytes
                     * in the bitmap. */
                    for (offset = fragment_offset; offset < end; offset++) {
                        ss->ssl3.hs.recvdFragments.buf[OFFSET_BYTE(offset)] |=
                            OFFSET_MASK(offset);
                    }
                }

                /* Now figure out the new high water mark if appropriate */
                for (offset = ss->ssl3.hs.recvdHighWater;
                     offset < ss->ssl3.hs.msg_len; offset++) {
                    /* Note that this loop is not efficient, since it counts
                     * bit by bit. If we have a lot of out-of-order packets,
                     * we should optimize this */
                    if (ss->ssl3.hs.recvdFragments.buf[OFFSET_BYTE(offset)] &
                        OFFSET_MASK(offset)) {
                        ss->ssl3.hs.recvdHighWater++;
                    } else {
                        break;
                    }
                }

                /* If we have all the bytes, then we are good to go */
                if (ss->ssl3.hs.recvdHighWater == ss->ssl3.hs.msg_len) {
                    /* dtls_HandleHandshakeMessage() resets recvdHighWater to
                     * -1, ending this reassembly. */
                    rv = dtls_HandleHandshakeMessage(ss, ss->ssl3.hs.msg_body.buf,
                                                     buf.len == fragment_length);

                    if (rv == SECFailure) {
                        goto loser;
                    }
                }
            }
        }

        buf.buf += fragment_length;
        buf.len -= fragment_length;
    }

    // This should never happen, but belt and suspenders.
    if (rv != SECSuccess) {
        PORT_Assert(0);
        goto loser;
    }

    /* If we processed all the fragments in this message, then mark it as remembered.
     * TODO(ekr@rtfm.com): Store out of order messages for DTLS 1.3 so ACKs work
     * better. Bug 1392620.*/
    if (!discarded && tls13_MaybeTls13(ss)) {
        rv = dtls13_RememberFragment(ss, &ss->ssl3.hs.dtlsRcvdHandshake,
                                     0, 0, 0, epoch, seqNum);
    }
    if (rv != SECSuccess) {
        goto loser;
    }

    rv = dtls13_SetupAcks(ss);

/* Reached on success as well as on failure. */
loser:
    origBuf->len = 0; /* So ssl3_GatherAppDataRecord will keep looping. */
    return rv;
}
505 
506 /* Enqueue a message (either handshake or CCS)
507  *
508  * Called from:
509  *              dtls_StageHandshakeMessage()
510  *              ssl3_SendChangeCipherSpecs()
511  */
512 SECStatus
dtls_QueueMessage(sslSocket * ss,SSLContentType ct,const PRUint8 * pIn,PRInt32 nIn)513 dtls_QueueMessage(sslSocket *ss, SSLContentType ct,
514                   const PRUint8 *pIn, PRInt32 nIn)
515 {
516     SECStatus rv = SECSuccess;
517     DTLSQueuedMessage *msg = NULL;
518     ssl3CipherSpec *spec;
519 
520     PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
521     PORT_Assert(ss->opt.noLocks || ssl_HaveXmitBufLock(ss));
522 
523     spec = ss->ssl3.cwSpec;
524     msg = dtls_AllocQueuedMessage(spec, ct, pIn, nIn);
525 
526     if (!msg) {
527         PORT_SetError(SEC_ERROR_NO_MEMORY);
528         rv = SECFailure;
529     } else {
530         PR_APPEND_LINK(&msg->link, &ss->ssl3.hs.lastMessageFlight);
531     }
532 
533     return rv;
534 }
535 
536 /* Add DTLS handshake message to the pending queue
537  * Empty the sendBuf buffer.
538  * Always set sendBuf.len to 0, even when returning SECFailure.
539  *
540  * Called from:
541  *              ssl3_AppendHandshakeHeader()
542  *              dtls_FlushHandshake()
543  */
544 SECStatus
dtls_StageHandshakeMessage(sslSocket * ss)545 dtls_StageHandshakeMessage(sslSocket *ss)
546 {
547     SECStatus rv = SECSuccess;
548 
549     PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
550     PORT_Assert(ss->opt.noLocks || ssl_HaveXmitBufLock(ss));
551 
552     /* This function is sometimes called when no data is actually to
553      * be staged, so just return SECSuccess. */
554     if (!ss->sec.ci.sendBuf.buf || !ss->sec.ci.sendBuf.len)
555         return rv;
556 
557     rv = dtls_QueueMessage(ss, ssl_ct_handshake,
558                            ss->sec.ci.sendBuf.buf, ss->sec.ci.sendBuf.len);
559 
560     /* Whether we succeeded or failed, toss the old handshake data. */
561     ss->sec.ci.sendBuf.len = 0;
562     return rv;
563 }
564 
565 /* Enqueue the handshake message in sendBuf (if any) and then
566  * transmit the resulting flight of handshake messages.
567  *
568  * Called from:
569  *              ssl3_FlushHandshake()
570  */
571 SECStatus
dtls_FlushHandshakeMessages(sslSocket * ss,PRInt32 flags)572 dtls_FlushHandshakeMessages(sslSocket *ss, PRInt32 flags)
573 {
574     SECStatus rv = SECSuccess;
575 
576     PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
577     PORT_Assert(ss->opt.noLocks || ssl_HaveXmitBufLock(ss));
578 
579     rv = dtls_StageHandshakeMessage(ss);
580     if (rv != SECSuccess)
581         return rv;
582 
583     if (!(flags & ssl_SEND_FLAG_FORCE_INTO_BUFFER)) {
584         rv = dtls_TransmitMessageFlight(ss);
585         if (rv != SECSuccess) {
586             return rv;
587         }
588 
589         if (!(flags & ssl_SEND_FLAG_NO_RETRANSMIT)) {
590             rv = dtls_StartRetransmitTimer(ss);
591         } else {
592             PORT_Assert(ss->version < SSL_LIBRARY_VERSION_TLS_1_3);
593         }
594     }
595 
596     return rv;
597 }
598 
599 /* The callback for when the retransmit timer expires
600  *
601  * Called from:
602  *              dtls_CheckTimer()
603  *              dtls_HandleHandshake()
604  */
605 static void
dtls_RetransmitTimerExpiredCb(sslSocket * ss)606 dtls_RetransmitTimerExpiredCb(sslSocket *ss)
607 {
608     SECStatus rv;
609     dtlsTimer *timer = ss->ssl3.hs.rtTimer;
610     ss->ssl3.hs.rtRetries++;
611 
612     if (!(ss->ssl3.hs.rtRetries % 3)) {
613         /* If one of the messages was potentially greater than > MTU,
614          * then downgrade. Do this every time we have retransmitted a
615          * message twice, per RFC 6347 Sec. 4.1.1 */
616         dtls_SetMTU(ss, ss->ssl3.hs.maxMessageSent - 1);
617     }
618 
619     rv = dtls_TransmitMessageFlight(ss);
620     if (rv == SECSuccess) {
621         /* Re-arm the timer */
622         timer->timeout *= 2;
623         if (timer->timeout > DTLS_RETRANSMIT_MAX_MS) {
624             timer->timeout = DTLS_RETRANSMIT_MAX_MS;
625         }
626 
627         timer->started = PR_IntervalNow();
628         timer->cb = dtls_RetransmitTimerExpiredCb;
629 
630         SSL_TRC(30,
631                 ("%d: SSL3[%d]: Retransmit #%d, next in %d",
632                  SSL_GETPID(), ss->fd,
633                  ss->ssl3.hs.rtRetries, timer->timeout));
634     }
635     /* else: OK for now. In future maybe signal the stack that we couldn't
636      * transmit. For now, let the read handle any real network errors */
637 }
638 
/* Size of a DTLS handshake fragment header: msg_type(1) + length(3) +
 * message_seq(2) + fragment_offset(3) + fragment_length(3). */
#define DTLS_HS_HDR_LEN 12
/* Smallest useful handshake record: a header, one byte of body and the
 * worst-case record expansion.  dtls_SendFragment() flushes pending
 * records once less than this much room is left in the MTU estimate. */
#define DTLS_MIN_FRAGMENT (DTLS_HS_HDR_LEN + 1 + DTLS_MAX_EXPANSION)
641 
642 /* Encrypt and encode a handshake message fragment.  Flush the data out to the
643  * network if there is insufficient space for any fragment. */
644 static SECStatus
dtls_SendFragment(sslSocket * ss,DTLSQueuedMessage * msg,PRUint8 * data,unsigned int len)645 dtls_SendFragment(sslSocket *ss, DTLSQueuedMessage *msg, PRUint8 *data,
646                   unsigned int len)
647 {
648     PRInt32 sent;
649     SECStatus rv;
650 
651     PRINT_BUF(40, (ss, "dtls_SendFragment", data, len));
652     sent = ssl3_SendRecord(ss, msg->cwSpec, msg->type, data, len,
653                            ssl_SEND_FLAG_FORCE_INTO_BUFFER);
654     if (sent != len) {
655         if (sent != -1) {
656             PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
657         }
658         return SECFailure;
659     }
660 
661     /* If another fragment won't fit, flush. */
662     if (ss->ssl3.mtu < ss->pendingBuf.len + DTLS_MIN_FRAGMENT) {
663         SSL_TRC(20, ("%d: DTLS[%d]: dtls_SendFragment: flush",
664                      SSL_GETPID(), ss->fd));
665         rv = dtls_SendSavedWriteData(ss);
666         if (rv != SECSuccess) {
667             return SECFailure;
668         }
669     }
670     return SECSuccess;
671 }
672 
673 /* Fragment a handshake message into multiple records and send them. */
674 static SECStatus
dtls_FragmentHandshake(sslSocket * ss,DTLSQueuedMessage * msg)675 dtls_FragmentHandshake(sslSocket *ss, DTLSQueuedMessage *msg)
676 {
677     PRBool fragmentWritten = PR_FALSE;
678     PRUint16 msgSeq;
679     PRUint8 *fragment;
680     PRUint32 fragmentOffset = 0;
681     PRUint32 fragmentLen;
682     const PRUint8 *content = msg->data + DTLS_HS_HDR_LEN;
683     PRUint32 contentLen = msg->len - DTLS_HS_HDR_LEN;
684     SECStatus rv;
685 
686     /* The headers consume 12 bytes so the smallest possible message (i.e., an
687      * empty one) is 12 bytes. */
688     PORT_Assert(msg->len >= DTLS_HS_HDR_LEN);
689 
690     /* DTLS only supports fragmenting handshaking messages. */
691     PORT_Assert(msg->type == ssl_ct_handshake);
692 
693     msgSeq = (msg->data[4] << 8) | msg->data[5];
694 
695     /* do {} while() so that empty messages are sent at least once. */
696     do {
697         PRUint8 buf[DTLS_MAX_MTU]; /* >= than largest plausible MTU */
698         PRBool hasUnackedRange;
699         PRUint32 end;
700 
701         hasUnackedRange = dtls_NextUnackedRange(ss, msgSeq,
702                                                 fragmentOffset, contentLen,
703                                                 &fragmentOffset, &end);
704         if (!hasUnackedRange) {
705             SSL_TRC(20, ("%d: SSL3[%d]: FragmentHandshake %d: all acknowledged",
706                          SSL_GETPID(), ss->fd, msgSeq));
707             break;
708         }
709 
710         SSL_TRC(20, ("%d: SSL3[%d]: FragmentHandshake %d: unacked=%u-%u",
711                      SSL_GETPID(), ss->fd, msgSeq, fragmentOffset, end));
712 
713         /* Cut down to the data we have available. */
714         PORT_Assert(fragmentOffset <= contentLen);
715         PORT_Assert(fragmentOffset <= end);
716         PORT_Assert(end <= contentLen);
717         fragmentLen = PR_MIN(end, contentLen) - fragmentOffset;
718 
719         /* Limit further by the record size limit.  Account for the header. */
720         fragmentLen = PR_MIN(fragmentLen,
721                              msg->cwSpec->recordSizeLimit - DTLS_HS_HDR_LEN);
722 
723         /* Reduce to the space remaining in the MTU. */
724         fragmentLen = PR_MIN(fragmentLen,
725                              ss->ssl3.mtu -           /* MTU estimate. */
726                                  ss->pendingBuf.len - /* Less any unsent records. */
727                                  DTLS_MAX_EXPANSION - /* Allow for expansion. */
728                                  DTLS_HS_HDR_LEN);    /* And the handshake header. */
729         PORT_Assert(fragmentLen > 0 || fragmentOffset == 0);
730 
731         /* Make totally sure that we will fit in the buffer. This should be
732          * impossible; DTLS_MAX_MTU should always be more than ss->ssl3.mtu. */
733         if (fragmentLen >= (DTLS_MAX_MTU - DTLS_HS_HDR_LEN)) {
734             PORT_Assert(0);
735             PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
736             return SECFailure;
737         }
738 
739         if (fragmentLen == contentLen) {
740             fragment = msg->data;
741         } else {
742             sslBuffer tmp = SSL_BUFFER_FIXED(buf, sizeof(buf));
743 
744             /* Construct an appropriate-sized fragment */
745             /* Type, length, sequence */
746             rv = sslBuffer_Append(&tmp, msg->data, 6);
747             if (rv != SECSuccess) {
748                 return SECFailure;
749             }
750             /* Offset. */
751             rv = sslBuffer_AppendNumber(&tmp, fragmentOffset, 3);
752             if (rv != SECSuccess) {
753                 return SECFailure;
754             }
755             /* Length. */
756             rv = sslBuffer_AppendNumber(&tmp, fragmentLen, 3);
757             if (rv != SECSuccess) {
758                 return SECFailure;
759             }
760             /* Data. */
761             rv = sslBuffer_Append(&tmp, content + fragmentOffset, fragmentLen);
762             if (rv != SECSuccess) {
763                 return SECFailure;
764             }
765 
766             fragment = SSL_BUFFER_BASE(&tmp);
767         }
768 
769         /* Record that we are sending first, because encrypting
770          * increments the sequence number. */
771         rv = dtls13_RememberFragment(ss, &ss->ssl3.hs.dtlsSentHandshake,
772                                      msgSeq, fragmentOffset, fragmentLen,
773                                      msg->cwSpec->epoch,
774                                      msg->cwSpec->nextSeqNum);
775         if (rv != SECSuccess) {
776             return SECFailure;
777         }
778 
779         rv = dtls_SendFragment(ss, msg, fragment,
780                                fragmentLen + DTLS_HS_HDR_LEN);
781         if (rv != SECSuccess) {
782             return SECFailure;
783         }
784 
785         fragmentWritten = PR_TRUE;
786         fragmentOffset += fragmentLen;
787     } while (fragmentOffset < contentLen);
788 
789     if (!fragmentWritten) {
790         /* Nothing was written if we got here, so the whole message must have
791          * been acknowledged.  Discard it. */
792         SSL_TRC(10, ("%d: SSL3[%d]: FragmentHandshake %d: removed",
793                      SSL_GETPID(), ss->fd, msgSeq));
794         PR_REMOVE_LINK(&msg->link);
795         dtls_FreeHandshakeMessage(msg);
796     }
797 
798     return SECSuccess;
799 }
800 
801 /* Transmit a flight of handshake messages, stuffing them
802  * into as few records as seems reasonable.
803  *
804  * TODO: Space separate UDP packets out a little.
805  *
806  * Called from:
807  *             dtls_FlushHandshake()
808  *             dtls_RetransmitTimerExpiredCb()
809  */
810 SECStatus
dtls_TransmitMessageFlight(sslSocket * ss)811 dtls_TransmitMessageFlight(sslSocket *ss)
812 {
813     SECStatus rv = SECSuccess;
814     PRCList *msg_p;
815 
816     SSL_TRC(10, ("%d: SSL3[%d]: dtls_TransmitMessageFlight",
817                  SSL_GETPID(), ss->fd));
818 
819     ssl_GetXmitBufLock(ss);
820     ssl_GetSpecReadLock(ss);
821 
822     /* DTLS does not buffer its handshake messages in ss->pendingBuf, but rather
823      * in the lastMessageFlight structure. This is just a sanity check that some
824      * programming error hasn't inadvertantly stuffed something in
825      * ss->pendingBuf.  This function uses ss->pendingBuf temporarily and it
826      * needs to be empty to start.
827      */
828     PORT_Assert(!ss->pendingBuf.len);
829 
830     for (msg_p = PR_LIST_HEAD(&ss->ssl3.hs.lastMessageFlight);
831          msg_p != &ss->ssl3.hs.lastMessageFlight;) {
832         DTLSQueuedMessage *msg = (DTLSQueuedMessage *)msg_p;
833 
834         /* Move the pointer forward so that the functions below are free to
835          * remove messages from the list. */
836         msg_p = PR_NEXT_LINK(msg_p);
837 
838         /* Note: This function fragments messages so that each record is close
839          * to full.  This produces fewer records, but it means that messages can
840          * be quite fragmented.  Adding an extra flush here would push new
841          * messages into new records and reduce fragmentation. */
842 
843         if (msg->type == ssl_ct_handshake) {
844             rv = dtls_FragmentHandshake(ss, msg);
845         } else {
846             PORT_Assert(!tls13_MaybeTls13(ss));
847             rv = dtls_SendFragment(ss, msg, msg->data, msg->len);
848         }
849         if (rv != SECSuccess) {
850             break;
851         }
852     }
853 
854     /* Finally, flush any data that wasn't flushed already. */
855     if (rv == SECSuccess) {
856         rv = dtls_SendSavedWriteData(ss);
857     }
858 
859     /* Give up the locks */
860     ssl_ReleaseSpecReadLock(ss);
861     ssl_ReleaseXmitBufLock(ss);
862 
863     return rv;
864 }
865 
866 /* Flush the data in the pendingBuf and update the max message sent
867  * so we can adjust the MTU estimate if we need to.
868  * Wrapper for ssl_SendSavedWriteData.
869  *
870  * Called from dtls_TransmitMessageFlight()
871  */
872 static SECStatus
dtls_SendSavedWriteData(sslSocket * ss)873 dtls_SendSavedWriteData(sslSocket *ss)
874 {
875     PRInt32 sent;
876 
877     sent = ssl_SendSavedWriteData(ss);
878     if (sent < 0)
879         return SECFailure;
880 
881     /* We should always have complete writes b/c datagram sockets
882      * don't really block */
883     if (ss->pendingBuf.len > 0) {
884         ssl_MapLowLevelError(SSL_ERROR_SOCKET_WRITE_FAILURE);
885         return SECFailure;
886     }
887 
888     /* Update the largest message sent so we can adjust the MTU
889      * estimate if necessary */
890     if (sent > ss->ssl3.hs.maxMessageSent)
891         ss->ssl3.hs.maxMessageSent = sent;
892 
893     return SECSuccess;
894 }
895 
896 void
dtls_InitTimers(sslSocket * ss)897 dtls_InitTimers(sslSocket *ss)
898 {
899     unsigned int i;
900     dtlsTimer **timers[PR_ARRAY_SIZE(ss->ssl3.hs.timers)] = {
901         &ss->ssl3.hs.rtTimer,
902         &ss->ssl3.hs.ackTimer,
903         &ss->ssl3.hs.hdTimer
904     };
905     static const char *timerLabels[] = {
906         "retransmit", "ack", "holddown"
907     };
908 
909     PORT_Assert(PR_ARRAY_SIZE(timers) == PR_ARRAY_SIZE(timerLabels));
910     for (i = 0; i < PR_ARRAY_SIZE(ss->ssl3.hs.timers); ++i) {
911         *timers[i] = &ss->ssl3.hs.timers[i];
912         ss->ssl3.hs.timers[i].label = timerLabels[i];
913     }
914 }
915 
916 SECStatus
dtls_StartTimer(sslSocket * ss,dtlsTimer * timer,PRUint32 time,DTLSTimerCb cb)917 dtls_StartTimer(sslSocket *ss, dtlsTimer *timer, PRUint32 time, DTLSTimerCb cb)
918 {
919     PORT_Assert(timer->cb == NULL);
920 
921     SSL_TRC(10, ("%d: SSL3[%d]: %s dtls_StartTimer %s timeout=%d",
922                  SSL_GETPID(), ss->fd, SSL_ROLE(ss), timer->label, time));
923 
924     timer->started = PR_IntervalNow();
925     timer->timeout = time;
926     timer->cb = cb;
927     return SECSuccess;
928 }
929 
930 SECStatus
dtls_RestartTimer(sslSocket * ss,dtlsTimer * timer)931 dtls_RestartTimer(sslSocket *ss, dtlsTimer *timer)
932 {
933     timer->started = PR_IntervalNow();
934     return SECSuccess;
935 }
936 
937 PRBool
dtls_TimerActive(sslSocket * ss,dtlsTimer * timer)938 dtls_TimerActive(sslSocket *ss, dtlsTimer *timer)
939 {
940     return timer->cb != NULL;
941 }
942 /* Start a timer for retransmission. */
943 static SECStatus
dtls_StartRetransmitTimer(sslSocket * ss)944 dtls_StartRetransmitTimer(sslSocket *ss)
945 {
946     ss->ssl3.hs.rtRetries = 0;
947     return dtls_StartTimer(ss, ss->ssl3.hs.rtTimer,
948                            DTLS_RETRANSMIT_INITIAL_MS,
949                            dtls_RetransmitTimerExpiredCb);
950 }
951 
952 /* Start a timer for holding an old cipher spec. */
953 SECStatus
dtls_StartHolddownTimer(sslSocket * ss)954 dtls_StartHolddownTimer(sslSocket *ss)
955 {
956     ss->ssl3.hs.rtRetries = 0;
957     return dtls_StartTimer(ss, ss->ssl3.hs.rtTimer,
958                            DTLS_RETRANSMIT_FINISHED_MS,
959                            dtls_FinishedTimerCb);
960 }
961 
962 /* Cancel a pending timer
963  *
964  * Called from:
965  *              dtls_HandleHandshake()
966  *              dtls_CheckTimer()
967  */
968 void
dtls_CancelTimer(sslSocket * ss,dtlsTimer * timer)969 dtls_CancelTimer(sslSocket *ss, dtlsTimer *timer)
970 {
971     SSL_TRC(30, ("%d: SSL3[%d]: %s dtls_CancelTimer %s",
972                  SSL_GETPID(), ss->fd, SSL_ROLE(ss),
973                  timer->label));
974 
975     PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss));
976 
977     timer->cb = NULL;
978 }
979 
980 static void
dtls_CancelAllTimers(sslSocket * ss)981 dtls_CancelAllTimers(sslSocket *ss)
982 {
983     unsigned int i;
984 
985     for (i = 0; i < PR_ARRAY_SIZE(ss->ssl3.hs.timers); ++i) {
986         dtls_CancelTimer(ss, &ss->ssl3.hs.timers[i]);
987     }
988 }
989 
990 /* Check the pending timer and fire the callback if it expired
991  *
992  * Called from ssl3_GatherCompleteHandshake()
993  */
994 void
dtls_CheckTimer(sslSocket * ss)995 dtls_CheckTimer(sslSocket *ss)
996 {
997     unsigned int i;
998     SSL_TRC(30, ("%d: SSL3[%d]: dtls_CheckTimer (%s)",
999                  SSL_GETPID(), ss->fd, ss->sec.isServer ? "server" : "client"));
1000 
1001     ssl_GetSSL3HandshakeLock(ss);
1002 
1003     for (i = 0; i < PR_ARRAY_SIZE(ss->ssl3.hs.timers); ++i) {
1004         dtlsTimer *timer = &ss->ssl3.hs.timers[i];
1005         if (!timer->cb) {
1006             continue;
1007         }
1008 
1009         if ((PR_IntervalNow() - timer->started) >=
1010             PR_MillisecondsToInterval(timer->timeout)) {
1011             /* Timer has expired */
1012             DTLSTimerCb cb = timer->cb;
1013 
1014             SSL_TRC(10, ("%d: SSL3[%d]: %s firing timer %s",
1015                          SSL_GETPID(), ss->fd, SSL_ROLE(ss),
1016                          timer->label));
1017 
1018             /* Cancel the timer so that we can call the CB safely */
1019             dtls_CancelTimer(ss, timer);
1020 
1021             /* Now call the CB */
1022             cb(ss);
1023         }
1024     }
1025     ssl_ReleaseSSL3HandshakeLock(ss);
1026 }
1027 
1028 /* The callback to fire when the holddown timer for the Finished
1029  * message expires and we can delete it
1030  *
1031  * Called from dtls_CheckTimer()
1032  */
1033 static void
dtls_FinishedTimerCb(sslSocket * ss)1034 dtls_FinishedTimerCb(sslSocket *ss)
1035 {
1036     dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight);
1037 }
1038 
1039 /* Cancel the Finished hold-down timer and destroy the
1040  * pending cipher spec. Note that this means that
1041  * successive rehandshakes will fail if the Finished is
1042  * lost.
1043  *
1044  * XXX OK for now. Figure out how to handle the combination
1045  * of Finished lost and rehandshake
1046  */
1047 void
dtls_RehandshakeCleanup(sslSocket * ss)1048 dtls_RehandshakeCleanup(sslSocket *ss)
1049 {
1050     /* Skip this if we are handling a second ClientHello. */
1051     if (ss->ssl3.hs.helloRetry) {
1052         return;
1053     }
1054     PORT_Assert((ss->version < SSL_LIBRARY_VERSION_TLS_1_3));
1055     dtls_CancelAllTimers(ss);
1056     dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight);
1057     ss->ssl3.hs.sendMessageSeq = 0;
1058     ss->ssl3.hs.recvMessageSeq = 0;
1059 }
1060 
1061 /* Set the MTU to the next step less than or equal to the
1062  * advertised value. Also used to downgrade the MTU by
1063  * doing dtls_SetMTU(ss, biggest packet set).
1064  *
1065  * Passing 0 means set this to the largest MTU known
1066  * (effectively resetting the PMTU backoff value).
1067  *
1068  * Called by:
1069  *            ssl3_InitState()
1070  *            dtls_RetransmitTimerExpiredCb()
1071  */
1072 void
dtls_SetMTU(sslSocket * ss,PRUint16 advertised)1073 dtls_SetMTU(sslSocket *ss, PRUint16 advertised)
1074 {
1075     int i;
1076 
1077     if (advertised == 0) {
1078         ss->ssl3.mtu = COMMON_MTU_VALUES[0];
1079         SSL_TRC(30, ("Resetting MTU to %d", ss->ssl3.mtu));
1080         return;
1081     }
1082 
1083     for (i = 0; i < PR_ARRAY_SIZE(COMMON_MTU_VALUES); i++) {
1084         if (COMMON_MTU_VALUES[i] <= advertised) {
1085             ss->ssl3.mtu = COMMON_MTU_VALUES[i];
1086             SSL_TRC(30, ("Resetting MTU to %d", ss->ssl3.mtu));
1087             return;
1088         }
1089     }
1090 
1091     /* Fallback */
1092     ss->ssl3.mtu = COMMON_MTU_VALUES[PR_ARRAY_SIZE(COMMON_MTU_VALUES) - 1];
1093     SSL_TRC(30, ("Resetting MTU to %d", ss->ssl3.mtu));
1094 }
1095 
/* Called from ssl3_HandleHandshakeMessage() when it has deciphered a
 * DTLS hello_verify_request
 * Caller must hold Handshake and RecvBuf locks.
 *
 * |b| / |length| describe the message body.  It is parsed as a protocol
 * version (only sanity-checked, see below) followed by a variable-length
 * cookie (read with ssl3_ConsumeHandshakeVariable into ss->ssl3.hs.cookie,
 * at most DTLS_COOKIE_BYTES long).  The ClientHello is then re-sent as
 * client_hello_retransmit while holding the XmitBuf lock.
 *
 * Returns SECSuccess once the ClientHello has been re-sent.  Otherwise it
 * returns SECFailure after calling ssl_MapLowLevelError(errCode).  A fatal
 * alert is sent either here (alert_loser) or, per the comments below, by the
 * parsing helper that failed.
 */
SECStatus
dtls_HandleHelloVerifyRequest(sslSocket *ss, PRUint8 *b, PRUint32 length)
{
    /* Defaults used unless a specific failure below overrides them. */
    int errCode = SSL_ERROR_RX_MALFORMED_HELLO_VERIFY_REQUEST;
    SECStatus rv;
    SSL3ProtocolVersion temp;
    SSL3AlertDescription desc = illegal_parameter;

    SSL_TRC(3, ("%d: SSL3[%d]: handle hello_verify_request handshake",
                SSL_GETPID(), ss->fd));
    PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss));
    PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));

    /* Only acceptable while we are waiting for the ServerHello. */
    if (ss->ssl3.hs.ws != wait_server_hello) {
        errCode = SSL_ERROR_RX_UNEXPECTED_HELLO_VERIFY_REQUEST;
        desc = unexpected_message;
        goto alert_loser;
    }

    /* Our ClientHello flight has been answered: since ws is not idle, this
     * frees that flight and cancels its retransmit timer. */
    dtls_ReceivedFirstMessageInFlight(ss);

    /* The version.
     *
     * RFC 4347 required that you verify that the server versions
     * match (Section 4.2.1) in the HelloVerifyRequest and the
     * ServerHello.
     *
     * RFC 6347 suggests (SHOULD) that servers always use 1.0 in
     * HelloVerifyRequest and allows the versions not to match,
     * especially when 1.2 is being negotiated.
     *
     * Therefore we do not do anything to enforce a match, just
     * read and check that this value is sane.
     */
    rv = ssl_ClientReadVersion(ss, &b, &length, &temp);
    if (rv != SECSuccess) {
        goto loser; /* alert has been sent */
    }

    /* Read the cookie.
     * IMPORTANT: The value of ss->ssl3.hs.cookie is only valid while the
     * HelloVerifyRequest message remains valid. */
    rv = ssl3_ConsumeHandshakeVariable(ss, &ss->ssl3.hs.cookie, 1, &b, &length);
    if (rv != SECSuccess) {
        goto loser; /* alert has been sent */
    }
    if (ss->ssl3.hs.cookie.len > DTLS_COOKIE_BYTES) {
        desc = decode_error;
        goto alert_loser; /* malformed. */
    }

    ssl_GetXmitBufLock(ss); /*******************************/

    /* Now re-send the client hello */
    rv = ssl3_SendClientHello(ss, client_hello_retransmit);

    ssl_ReleaseXmitBufLock(ss); /*******************************/

    if (rv == SECSuccess)
        return rv;

    /* NOTE(review): a failure from ssl3_SendClientHello() falls through to
     * alert_loser.  An illegal_parameter alert is therefore sent, and
     * SSL_ERROR_RX_MALFORMED_HELLO_VERIFY_REQUEST is passed to
     * ssl_MapLowLevelError(), even though the message parsed cleanly.
     * Confirm this is intended. */
alert_loser:
    (void)SSL3_SendAlert(ss, alert_fatal, desc);

loser:
    ssl_MapLowLevelError(errCode);
    return SECFailure;
}
1168 
1169 /* Initialize the DTLS anti-replay window
1170  *
1171  * Called from:
1172  *              ssl3_SetupPendingCipherSpec()
1173  *              ssl3_InitCipherSpec()
1174  */
1175 void
dtls_InitRecvdRecords(DTLSRecvdRecords * records)1176 dtls_InitRecvdRecords(DTLSRecvdRecords *records)
1177 {
1178     PORT_Memset(records->data, 0, sizeof(records->data));
1179     records->left = 0;
1180     records->right = DTLS_RECVD_RECORDS_WINDOW - 1;
1181 }
1182 
1183 /*
1184  * Has this DTLS record been received? Return values are:
1185  * -1 -- out of range to the left
1186  *  0 -- not received yet
1187  *  1 -- replay
1188  *
1189  *  Called from: ssl3_HandleRecord()
1190  */
1191 int
dtls_RecordGetRecvd(const DTLSRecvdRecords * records,sslSequenceNumber seq)1192 dtls_RecordGetRecvd(const DTLSRecvdRecords *records, sslSequenceNumber seq)
1193 {
1194     PRUint64 offset;
1195 
1196     /* Out of range to the left */
1197     if (seq < records->left) {
1198         return -1;
1199     }
1200 
1201     /* Out of range to the right; since we advance the window on
1202      * receipt, that means that this packet has not been received
1203      * yet */
1204     if (seq > records->right)
1205         return 0;
1206 
1207     offset = seq % DTLS_RECVD_RECORDS_WINDOW;
1208 
1209     return !!(records->data[offset / 8] & (1 << (offset % 8)));
1210 }
1211 
/* Mark |seq| as received in the DTLS anti-replay window.  If |seq| lies
 * beyond the right edge, the window is first slid right so that it ends on
 * the byte boundary containing |seq|.  Sequence numbers left of the window
 * are ignored.
 *
 * Called from ssl3_HandleRecord()
 */
void
dtls_RecordSetRecvd(DTLSRecvdRecords *records, sslSequenceNumber seq)
{
    PRUint64 offset;

    /* Too old to track; nothing to record. */
    if (seq < records->left)
        return;

    if (seq > records->right) {
        sslSequenceNumber new_left;
        sslSequenceNumber new_right;
        sslSequenceNumber right;

        /* Slide to the right; this is the tricky part
         *
         * 1. new_right is set to have room for seq, on the
         *    next byte boundary, by setting the low 3 bits
         *    of seq.
         * 2. new_left is set so that the window still spans
         *    DTLS_RECVD_RECORDS_WINDOW sequence numbers.
         * 3. Zero all bits between right and new_right. Since
         *    this is a ring, this zeroes everything as-yet
         *    unseen. Because we always operate on byte
         *    boundaries, we can zero one byte at a time.
         *    If the window moves by more than its whole
         *    size, every byte is stale, so clear them all.
         *
         * NOTE(review): zeroing one byte at a time assumes
         * DTLS_RECVD_RECORDS_WINDOW is a multiple of 8, so that
         * records->right always ends a byte -- confirm.
         */
        new_right = seq | 0x07;
        new_left = (new_right - DTLS_RECVD_RECORDS_WINDOW) + 1;

        if (new_right > records->right + DTLS_RECVD_RECORDS_WINDOW) {
            PORT_Memset(records->data, 0, sizeof(records->data));
        } else {
            for (right = records->right + 8; right <= new_right; right += 8) {
                offset = right % DTLS_RECVD_RECORDS_WINDOW;
                records->data[offset / 8] = 0;
            }
        }

        records->right = new_right;
        records->left = new_left;
    }

    /* Set the bit for seq in the ring. */
    offset = seq % DTLS_RECVD_RECORDS_WINDOW;

    records->data[offset / 8] |= (1 << (offset % 8));
}
1260 
1261 SECStatus
DTLS_GetHandshakeTimeout(PRFileDesc * socket,PRIntervalTime * timeout)1262 DTLS_GetHandshakeTimeout(PRFileDesc *socket, PRIntervalTime *timeout)
1263 {
1264     sslSocket *ss = NULL;
1265     PRBool found = PR_FALSE;
1266     PRIntervalTime now = PR_IntervalNow();
1267     PRIntervalTime to;
1268     unsigned int i;
1269 
1270     *timeout = PR_INTERVAL_NO_TIMEOUT;
1271 
1272     ss = ssl_FindSocket(socket);
1273 
1274     if (!ss) {
1275         PORT_SetError(SEC_ERROR_INVALID_ARGS);
1276         return SECFailure;
1277     }
1278 
1279     if (!IS_DTLS(ss)) {
1280         PORT_SetError(SEC_ERROR_INVALID_ARGS);
1281         return SECFailure;
1282     }
1283 
1284     for (i = 0; i < PR_ARRAY_SIZE(ss->ssl3.hs.timers); ++i) {
1285         PRIntervalTime elapsed;
1286         PRIntervalTime desired;
1287         dtlsTimer *timer = &ss->ssl3.hs.timers[i];
1288 
1289         if (!timer->cb) {
1290             continue;
1291         }
1292         found = PR_TRUE;
1293 
1294         elapsed = now - timer->started;
1295         desired = PR_MillisecondsToInterval(timer->timeout);
1296         if (elapsed > desired) {
1297             /* Timer expired */
1298             *timeout = PR_INTERVAL_NO_WAIT;
1299             return SECSuccess;
1300         } else {
1301             to = desired - elapsed;
1302         }
1303 
1304         if (*timeout > to) {
1305             *timeout = to;
1306         }
1307     }
1308 
1309     if (!found) {
1310         PORT_SetError(SSL_ERROR_NO_TIMERS_FOUND);
1311         return SECFailure;
1312     }
1313 
1314     return SECSuccess;
1315 }
1316 
1317 PRBool
dtls_IsLongHeader(SSL3ProtocolVersion version,PRUint8 firstOctet)1318 dtls_IsLongHeader(SSL3ProtocolVersion version, PRUint8 firstOctet)
1319 {
1320 #ifndef UNSAFE_FUZZER_MODE
1321     return version < SSL_LIBRARY_VERSION_TLS_1_3 ||
1322            firstOctet == ssl_ct_handshake ||
1323            firstOctet == ssl_ct_ack ||
1324            firstOctet == ssl_ct_alert;
1325 #else
1326     return PR_TRUE;
1327 #endif
1328 }
1329 
1330 PRBool
dtls_IsDtls13Ciphertext(SSL3ProtocolVersion version,PRUint8 firstOctet)1331 dtls_IsDtls13Ciphertext(SSL3ProtocolVersion version, PRUint8 firstOctet)
1332 {
1333     // Allow no version in case we haven't negotiated one yet.
1334     return (version == 0 || version >= SSL_LIBRARY_VERSION_TLS_1_3) &&
1335            (firstOctet & 0xe0) == 0x20;
1336 }
1337 
1338 DTLSEpoch
dtls_ReadEpoch(const ssl3CipherSpec * crSpec,const PRUint8 * hdr)1339 dtls_ReadEpoch(const ssl3CipherSpec *crSpec, const PRUint8 *hdr)
1340 {
1341     DTLSEpoch epoch;
1342     DTLSEpoch maxEpoch;
1343     DTLSEpoch partial;
1344 
1345     if (dtls_IsLongHeader(crSpec->version, hdr[0])) {
1346         return ((DTLSEpoch)hdr[3] << 8) | hdr[4];
1347     }
1348 
1349     /* A lot of how we recover the epoch here will depend on how we plan to
1350      * manage KeyUpdate.  In the case that we decide to install a new read spec
1351      * as a KeyUpdate is handled, crSpec will always be the highest epoch we can
1352      * possibly receive.  That makes this easier to manage.
1353      */
1354     if (dtls_IsDtls13Ciphertext(crSpec->version, hdr[0])) {
1355         /* TODO(ekr@rtfm.com: do something with the two-bit epoch. */
1356         /* Use crSpec->epoch, or crSpec->epoch - 1 if the last bit differs. */
1357         return crSpec->epoch - ((hdr[0] ^ crSpec->epoch) & 0x3);
1358     }
1359 
1360     /* dtls_GatherData should ensure that this works. */
1361     PORT_Assert(hdr[0] == ssl_ct_application_data);
1362 
1363     /* This uses the same method as is used to recover the sequence number in
1364      * dtls_ReadSequenceNumber, except that the maximum value is set to the
1365      * current epoch. */
1366     partial = hdr[1] >> 6;
1367     maxEpoch = PR_MAX(crSpec->epoch, 3);
1368     epoch = (maxEpoch & 0xfffc) | partial;
1369     if (partial > (maxEpoch & 0x03)) {
1370         epoch -= 4;
1371     }
1372     return epoch;
1373 }
1374 
/* Recover the full sequence number of the record whose header starts at
 * |hdr|.
 *
 * A long header carries 48 bits: six octets starting at offset 5.  A short
 * header carries only the low 16 bits (when bit 0x08 of hdr[0] is set) or
 * the low 8 bits.  Those bits are expanded relative to spec->nextSeqNum.
 */
static sslSequenceNumber
dtls_ReadSequenceNumber(const ssl3CipherSpec *spec, const PRUint8 *hdr)
{
    sslSequenceNumber cap;
    sslSequenceNumber partial;
    sslSequenceNumber seqNum;
    sslSequenceNumber mask;

    if (dtls_IsLongHeader(spec->version, hdr[0])) {
        static const unsigned int seqNumOffset = 5; /* type, version, epoch */
        static const unsigned int seqNumLength = 6;
        sslReader r = SSL_READER(hdr + seqNumOffset, seqNumLength);
        (void)sslRead_ReadNumber(&r, seqNumLength, &seqNum);
        return seqNum;
    }

    /* Only the least significant bits of the sequence number are available
     * here.  This recovers the value based on the next expected sequence
     * number.
     *
     * This works by determining the maximum possible sequence number, which is
     * half the range of possible values above the expected next value (the
     * expected next value is in |spec->nextSeqNum|).  Then, the last part of
     * the sequence number is replaced.  If that causes the value to exceed the
     * maximum, subtract an entire range.
     */
    if (hdr[0] & 0x08) {
        /* 16-bit partial sequence number in hdr[1..2]. */
        cap = spec->nextSeqNum + (1ULL << 15);
        partial = (((sslSequenceNumber)hdr[1]) << 8) |
                  (sslSequenceNumber)hdr[2];
        mask = (1ULL << 16) - 1;
    } else {
        /* 8-bit partial sequence number in hdr[1]. */
        cap = spec->nextSeqNum + (1ULL << 7);
        partial = (sslSequenceNumber)hdr[1];
        mask = (1ULL << 8) - 1;
    }
    seqNum = (cap & ~mask) | partial;
    /* The second check prevents the value from underflowing if we get a large
     * gap at the start of a connection, where this subtraction would cause the
     * sequence number to wrap to near UINT64_MAX. */
    if ((partial > (cap & mask)) && (seqNum > mask)) {
        seqNum -= mask + 1;
    }
    return seqNum;
}
1419 
1420 /*
1421  * DTLS relevance checks:
1422  * Note that this code currently ignores all out-of-epoch packets,
1423  * which means we lose some in the case of rehandshake +
1424  * loss/reordering. Since DTLS is explicitly unreliable, this
1425  * seems like a good tradeoff for implementation effort and is
1426  * consistent with the guidance of RFC 6347 Sections 4.1 and 4.2.4.1.
1427  *
1428  * If the packet is not relevant, this function returns PR_FALSE.  If the packet
1429  * is relevant, this function returns PR_TRUE and sets |*seqNumOut| to the
1430  * packet sequence number (removing the epoch).
1431  */
1432 PRBool
dtls_IsRelevant(sslSocket * ss,const ssl3CipherSpec * spec,const SSL3Ciphertext * cText,sslSequenceNumber * seqNumOut)1433 dtls_IsRelevant(sslSocket *ss, const ssl3CipherSpec *spec,
1434                 const SSL3Ciphertext *cText,
1435                 sslSequenceNumber *seqNumOut)
1436 {
1437     sslSequenceNumber seqNum = dtls_ReadSequenceNumber(spec, cText->hdr);
1438     if (dtls_RecordGetRecvd(&spec->recvdRecords, seqNum) != 0) {
1439         SSL_TRC(10, ("%d: SSL3[%d]: dtls_IsRelevant, rejecting "
1440                      "potentially replayed packet",
1441                      SSL_GETPID(), ss->fd));
1442         return PR_FALSE;
1443     }
1444 
1445     *seqNumOut = seqNum;
1446     return PR_TRUE;
1447 }
1448 
1449 void
dtls_ReceivedFirstMessageInFlight(sslSocket * ss)1450 dtls_ReceivedFirstMessageInFlight(sslSocket *ss)
1451 {
1452     if (!IS_DTLS(ss))
1453         return;
1454 
1455     /* At this point we are advancing our state machine, so we can free our last
1456      * flight of messages. */
1457     if (ss->ssl3.hs.ws != idle_handshake ||
1458         ss->version >= SSL_LIBRARY_VERSION_TLS_1_3) {
1459         /* We need to keep our last flight around in DTLS 1.2 and below,
1460          * so we can retransmit it in response to other people's
1461          * retransmits. */
1462         dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight);
1463 
1464         /* Reset the timer to the initial value if the retry counter
1465          * is 0, per RFC 6347, Sec. 4.2.4.1 */
1466         dtls_CancelTimer(ss, ss->ssl3.hs.rtTimer);
1467         if (ss->ssl3.hs.rtRetries == 0) {
1468             ss->ssl3.hs.rtTimer->timeout = DTLS_RETRANSMIT_INITIAL_MS;
1469         }
1470     }
1471 
1472     /* Empty the ACK queue (TLS 1.3 only). */
1473     ssl_ClearPRCList(&ss->ssl3.hs.dtlsRcvdHandshake, NULL);
1474 }
1475