1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
5 * You can obtain one at http://mozilla.org/MPL/2.0/. */
6
7 #include "secerr.h"
8 #include "ssl.h"
9 #include "sslerr.h"
10 #include "sslproto.h"
11
12 #include "gtest_utils.h"
13 #include "nss_scoped_ptrs.h"
14 #include "tls_connect.h"
15 #include "tls_filter.h"
16 #include "tls_parser.h"
17
18 namespace nss_test {
19
20 // This class cuts every unencrypted handshake record into two parts.
21 class RecordFragmenter : public PacketFilter {
22 public:
RecordFragmenter(bool is_dtls13)23 RecordFragmenter(bool is_dtls13)
24 : is_dtls13_(is_dtls13), sequence_number_(0), splitting_(true) {}
25
26 private:
27 class HandshakeSplitter {
28 public:
HandshakeSplitter(bool is_dtls13,const DataBuffer & input,DataBuffer * output,uint64_t * sequence_number)29 HandshakeSplitter(bool is_dtls13, const DataBuffer& input,
30 DataBuffer* output, uint64_t* sequence_number)
31 : is_dtls13_(is_dtls13),
32 input_(input),
33 output_(output),
34 cursor_(0),
35 sequence_number_(sequence_number) {}
36
37 private:
WriteRecord(TlsRecordHeader & record_header,DataBuffer & record_fragment)38 void WriteRecord(TlsRecordHeader& record_header,
39 DataBuffer& record_fragment) {
40 TlsRecordHeader fragment_header(
41 record_header.variant(), record_header.version(),
42 record_header.content_type(), *sequence_number_);
43 ++*sequence_number_;
44 if (::g_ssl_gtest_verbose) {
45 std::cerr << "Fragment: " << fragment_header << ' ' << record_fragment
46 << std::endl;
47 }
48 cursor_ = fragment_header.Write(output_, cursor_, record_fragment);
49 }
50
SplitRecord(TlsRecordHeader & record_header,DataBuffer & record)51 bool SplitRecord(TlsRecordHeader& record_header, DataBuffer& record) {
52 TlsParser parser(record);
53 while (parser.remaining()) {
54 TlsHandshakeFilter::HandshakeHeader handshake_header;
55 DataBuffer handshake_body;
56 bool complete = false;
57 if (!handshake_header.Parse(&parser, record_header, DataBuffer(),
58 &handshake_body, &complete)) {
59 ADD_FAILURE() << "couldn't parse handshake header";
60 return false;
61 }
62 if (!complete) {
63 ADD_FAILURE() << "don't want to deal with fragmented messages";
64 return false;
65 }
66
67 DataBuffer record_fragment;
68 // We can't fragment handshake records that are too small.
69 if (handshake_body.len() < 2) {
70 handshake_header.Write(&record_fragment, 0U, handshake_body);
71 WriteRecord(record_header, record_fragment);
72 continue;
73 }
74
75 size_t cut = handshake_body.len() / 2;
76 handshake_header.WriteFragment(&record_fragment, 0U, handshake_body, 0U,
77 cut);
78 WriteRecord(record_header, record_fragment);
79
80 handshake_header.WriteFragment(&record_fragment, 0U, handshake_body,
81 cut, handshake_body.len() - cut);
82 WriteRecord(record_header, record_fragment);
83 }
84 return true;
85 }
86
87 public:
Split()88 bool Split() {
89 TlsParser parser(input_);
90 while (parser.remaining()) {
91 TlsRecordHeader header;
92 DataBuffer record;
93 if (!header.Parse(is_dtls13_, 0, &parser, &record)) {
94 ADD_FAILURE() << "bad record header";
95 return false;
96 }
97
98 if (::g_ssl_gtest_verbose) {
99 std::cerr << "Record: " << header << ' ' << record << std::endl;
100 }
101
102 // Don't touch packets from a non-zero epoch. Leave these unmodified.
103 if ((header.sequence_number() >> 48) != 0ULL) {
104 cursor_ = header.Write(output_, cursor_, record);
105 continue;
106 }
107
108 // Just rewrite the sequence number (CCS only).
109 if (header.content_type() != ssl_ct_handshake) {
110 EXPECT_EQ(ssl_ct_change_cipher_spec, header.content_type());
111 WriteRecord(header, record);
112 continue;
113 }
114
115 if (!SplitRecord(header, record)) {
116 return false;
117 }
118 }
119 return true;
120 }
121
122 private:
123 bool is_dtls13_;
124 const DataBuffer& input_;
125 DataBuffer* output_;
126 size_t cursor_;
127 uint64_t* sequence_number_;
128 };
129
130 protected:
Filter(const DataBuffer & input,DataBuffer * output)131 virtual PacketFilter::Action Filter(const DataBuffer& input,
132 DataBuffer* output) override {
133 if (!splitting_) {
134 return KEEP;
135 }
136
137 output->Allocate(input.len());
138 HandshakeSplitter splitter(is_dtls13_, input, output, &sequence_number_);
139 if (!splitter.Split()) {
140 // If splitting fails, we obviously reached encrypted packets.
141 // Stop splitting from that point onward.
142 splitting_ = false;
143 return KEEP;
144 }
145
146 return CHANGE;
147 }
148
149 private:
150 bool is_dtls13_;
151 uint64_t sequence_number_;
152 bool splitting_;
153 };
154
// Installs a RecordFragmenter on the client so its plaintext handshake
// messages are split in two, then runs Connect() followed by SendReceive().
TEST_P(TlsConnectDatagram, FragmentClientPackets) {
  const bool dtls13 = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3);
  auto fragmenter = std::make_shared<RecordFragmenter>(dtls13);
  client_->SetFilter(fragmenter);
  Connect();
  SendReceive();
}
161
// Installs a RecordFragmenter on the server so its plaintext handshake
// messages are split in two, then runs Connect() followed by SendReceive().
TEST_P(TlsConnectDatagram, FragmentServerPackets) {
  const bool dtls13 = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3);
  auto fragmenter = std::make_shared<RecordFragmenter>(dtls13);
  server_->SetFilter(fragmenter);
  Connect();
  SendReceive();
}
168
169 } // namespace nss_test
170