1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
5 * You can obtain one at http://mozilla.org/MPL/2.0/. */
6
7 #include "tls_protect.h"
8 #include "sslproto.h"
9 #include "tls_filter.h"
10
11 namespace nss_test {
12
// Returns the initial sequence number for records of epoch |epoc|.
//
// For DTLS the 16-bit epoch is placed in the top 16 bits of the 64-bit
// sequence number (a shift by 48), so counting starts at |epoc| << 48.
// For TLS (|dtls| == false) the result is always zero.
static uint64_t FirstSeqno(bool dtls, uint16_t epoc) {
  constexpr unsigned kEpochShift = 48;
  if (!dtls) {
    return 0;
  }
  const uint64_t epoch_bits = epoc;
  return epoch_bits << kEpochShift;
}
19
TlsCipherSpec(bool dtls,uint16_t epoc)20 TlsCipherSpec::TlsCipherSpec(bool dtls, uint16_t epoc)
21 : dtls_(dtls),
22 epoch_(epoc),
23 in_seqno_(FirstSeqno(dtls, epoc)),
24 out_seqno_(FirstSeqno(dtls, epoc)) {}
25
SetKeys(SSLCipherSuiteInfo * cipherinfo,PK11SymKey * secret)26 bool TlsCipherSpec::SetKeys(SSLCipherSuiteInfo* cipherinfo,
27 PK11SymKey* secret) {
28 SSLAeadContext* aead_ctx;
29 SSLProtocolVariant variant =
30 dtls_ ? ssl_variant_datagram : ssl_variant_stream;
31 SECStatus rv =
32 SSL_MakeVariantAead(SSL_LIBRARY_VERSION_TLS_1_3, cipherinfo->cipherSuite,
33 variant, secret, "", 0, // Use the default labels.
34 &aead_ctx);
35 if (rv != SECSuccess) {
36 return false;
37 }
38 aead_.reset(aead_ctx);
39
40 SSLMaskingContext* mask_ctx;
41 const char kHkdfPurposeSn[] = "sn";
42 rv = SSL_CreateVariantMaskingContext(
43 SSL_LIBRARY_VERSION_TLS_1_3, cipherinfo->cipherSuite, variant, secret,
44 kHkdfPurposeSn, strlen(kHkdfPurposeSn), &mask_ctx);
45 if (rv != SECSuccess) {
46 return false;
47 }
48 mask_.reset(mask_ctx);
49 return true;
50 }
51
Unprotect(const TlsRecordHeader & header,const DataBuffer & ciphertext,DataBuffer * plaintext,TlsRecordHeader * out_header)52 bool TlsCipherSpec::Unprotect(const TlsRecordHeader& header,
53 const DataBuffer& ciphertext,
54 DataBuffer* plaintext,
55 TlsRecordHeader* out_header) {
56 if (!aead_ || !out_header) {
57 return false;
58 }
59 *out_header = header;
60
61 // Make space.
62 plaintext->Allocate(ciphertext.len());
63
64 unsigned int len;
65 uint64_t seqno = dtls_ ? header.sequence_number() : in_seqno_;
66 SECStatus rv;
67
68 if (header.is_dtls13_ciphertext()) {
69 if (!mask_ || !out_header) {
70 return false;
71 }
72 PORT_Assert(ciphertext.len() >= 16);
73 DataBuffer mask(2);
74 rv = SSL_CreateMask(mask_.get(), ciphertext.data(), ciphertext.len(),
75 mask.data(), mask.len());
76 if (rv != SECSuccess) {
77 return false;
78 }
79
80 if (!out_header->MaskSequenceNumber(mask)) {
81 return false;
82 }
83 seqno = out_header->sequence_number();
84 }
85
86 auto header_bytes = out_header->header();
87 rv = SSL_AeadDecrypt(aead_.get(), seqno, header_bytes.data(),
88 header_bytes.len(), ciphertext.data(), ciphertext.len(),
89 plaintext->data(), &len, plaintext->len());
90 if (rv != SECSuccess) {
91 return false;
92 }
93
94 RecordUnprotected(seqno);
95 plaintext->Truncate(static_cast<size_t>(len));
96
97 return true;
98 }
99
Protect(const TlsRecordHeader & header,const DataBuffer & plaintext,DataBuffer * ciphertext,TlsRecordHeader * out_header)100 bool TlsCipherSpec::Protect(const TlsRecordHeader& header,
101 const DataBuffer& plaintext, DataBuffer* ciphertext,
102 TlsRecordHeader* out_header) {
103 if (!aead_ || !out_header) {
104 return false;
105 }
106
107 *out_header = header;
108
109 // Make a padded buffer.
110 ciphertext->Allocate(plaintext.len() +
111 32); // Room for any plausible auth tag
112 unsigned int len;
113
114 DataBuffer header_bytes;
115 (void)header.WriteHeader(&header_bytes, 0, plaintext.len() + 16);
116 uint64_t seqno = dtls_ ? header.sequence_number() : out_seqno_;
117
118 SECStatus rv =
119 SSL_AeadEncrypt(aead_.get(), seqno, header_bytes.data(),
120 header_bytes.len(), plaintext.data(), plaintext.len(),
121 ciphertext->data(), &len, ciphertext->len());
122 if (rv != SECSuccess) {
123 return false;
124 }
125
126 if (header.is_dtls13_ciphertext()) {
127 if (!mask_ || !out_header) {
128 return false;
129 }
130 PORT_Assert(ciphertext->len() >= 16);
131 DataBuffer mask(2);
132 rv = SSL_CreateMask(mask_.get(), ciphertext->data(), ciphertext->len(),
133 mask.data(), mask.len());
134 if (rv != SECSuccess) {
135 return false;
136 }
137 if (!out_header->MaskSequenceNumber(mask)) {
138 return false;
139 }
140 }
141
142 RecordProtected();
143 ciphertext->Truncate(len);
144
145 return true;
146 }
147
148 } // namespace nss_test
149