1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this file,
5  * You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
7 #include "tls_protect.h"
8 #include "sslproto.h"
9 #include "tls_filter.h"
10 
11 namespace nss_test {
12 
// Returns the initial record sequence number for a cipher spec.
//
// For DTLS the 16-bit epoch is shifted into the top 16 bits of the 64-bit
// value; for stream TLS the counter simply starts at zero.
static uint64_t FirstSeqno(bool dtls, uint16_t epoc) {
  return dtls ? (static_cast<uint64_t>(epoc) << 48) : 0;
}
19 
TlsCipherSpec(bool dtls,uint16_t epoc)20 TlsCipherSpec::TlsCipherSpec(bool dtls, uint16_t epoc)
21     : dtls_(dtls),
22       epoch_(epoc),
23       in_seqno_(FirstSeqno(dtls, epoc)),
24       out_seqno_(FirstSeqno(dtls, epoc)) {}
25 
SetKeys(SSLCipherSuiteInfo * cipherinfo,PK11SymKey * secret)26 bool TlsCipherSpec::SetKeys(SSLCipherSuiteInfo* cipherinfo,
27                             PK11SymKey* secret) {
28   SSLAeadContext* aead_ctx;
29   SSLProtocolVariant variant =
30       dtls_ ? ssl_variant_datagram : ssl_variant_stream;
31   SECStatus rv =
32       SSL_MakeVariantAead(SSL_LIBRARY_VERSION_TLS_1_3, cipherinfo->cipherSuite,
33                           variant, secret, "", 0,  // Use the default labels.
34                           &aead_ctx);
35   if (rv != SECSuccess) {
36     return false;
37   }
38   aead_.reset(aead_ctx);
39 
40   SSLMaskingContext* mask_ctx;
41   const char kHkdfPurposeSn[] = "sn";
42   rv = SSL_CreateVariantMaskingContext(
43       SSL_LIBRARY_VERSION_TLS_1_3, cipherinfo->cipherSuite, variant, secret,
44       kHkdfPurposeSn, strlen(kHkdfPurposeSn), &mask_ctx);
45   if (rv != SECSuccess) {
46     return false;
47   }
48   mask_.reset(mask_ctx);
49   return true;
50 }
51 
Unprotect(const TlsRecordHeader & header,const DataBuffer & ciphertext,DataBuffer * plaintext,TlsRecordHeader * out_header)52 bool TlsCipherSpec::Unprotect(const TlsRecordHeader& header,
53                               const DataBuffer& ciphertext,
54                               DataBuffer* plaintext,
55                               TlsRecordHeader* out_header) {
56   if (!aead_ || !out_header) {
57     return false;
58   }
59   *out_header = header;
60 
61   // Make space.
62   plaintext->Allocate(ciphertext.len());
63 
64   unsigned int len;
65   uint64_t seqno = dtls_ ? header.sequence_number() : in_seqno_;
66   SECStatus rv;
67 
68   if (header.is_dtls13_ciphertext()) {
69     if (!mask_ || !out_header) {
70       return false;
71     }
72     PORT_Assert(ciphertext.len() >= 16);
73     DataBuffer mask(2);
74     rv = SSL_CreateMask(mask_.get(), ciphertext.data(), ciphertext.len(),
75                         mask.data(), mask.len());
76     if (rv != SECSuccess) {
77       return false;
78     }
79 
80     if (!out_header->MaskSequenceNumber(mask)) {
81       return false;
82     }
83     seqno = out_header->sequence_number();
84   }
85 
86   auto header_bytes = out_header->header();
87   rv = SSL_AeadDecrypt(aead_.get(), seqno, header_bytes.data(),
88                        header_bytes.len(), ciphertext.data(), ciphertext.len(),
89                        plaintext->data(), &len, plaintext->len());
90   if (rv != SECSuccess) {
91     return false;
92   }
93 
94   RecordUnprotected(seqno);
95   plaintext->Truncate(static_cast<size_t>(len));
96 
97   return true;
98 }
99 
Protect(const TlsRecordHeader & header,const DataBuffer & plaintext,DataBuffer * ciphertext,TlsRecordHeader * out_header)100 bool TlsCipherSpec::Protect(const TlsRecordHeader& header,
101                             const DataBuffer& plaintext, DataBuffer* ciphertext,
102                             TlsRecordHeader* out_header) {
103   if (!aead_ || !out_header) {
104     return false;
105   }
106 
107   *out_header = header;
108 
109   // Make a padded buffer.
110   ciphertext->Allocate(plaintext.len() +
111                        32);  // Room for any plausible auth tag
112   unsigned int len;
113 
114   DataBuffer header_bytes;
115   (void)header.WriteHeader(&header_bytes, 0, plaintext.len() + 16);
116   uint64_t seqno = dtls_ ? header.sequence_number() : out_seqno_;
117 
118   SECStatus rv =
119       SSL_AeadEncrypt(aead_.get(), seqno, header_bytes.data(),
120                       header_bytes.len(), plaintext.data(), plaintext.len(),
121                       ciphertext->data(), &len, ciphertext->len());
122   if (rv != SECSuccess) {
123     return false;
124   }
125 
126   if (header.is_dtls13_ciphertext()) {
127     if (!mask_ || !out_header) {
128       return false;
129     }
130     PORT_Assert(ciphertext->len() >= 16);
131     DataBuffer mask(2);
132     rv = SSL_CreateMask(mask_.get(), ciphertext->data(), ciphertext->len(),
133                         mask.data(), mask.len());
134     if (rv != SECSuccess) {
135       return false;
136     }
137     if (!out_header->MaskSequenceNumber(mask)) {
138       return false;
139     }
140   }
141 
142   RecordProtected();
143   ciphertext->Truncate(len);
144 
145   return true;
146 }
147 
148 }  // namespace nss_test
149