1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this file,
5  * You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
7 #include "secerr.h"
8 #include "ssl.h"
9 #include "sslerr.h"
10 #include "sslproto.h"
11 
12 #include "gtest_utils.h"
13 #include "nss_scoped_ptrs.h"
14 #include "tls_connect.h"
15 #include "tls_filter.h"
16 #include "tls_parser.h"
17 
18 namespace nss_test {
19 
20 // This class cuts every unencrypted handshake record into two parts.
21 class RecordFragmenter : public PacketFilter {
22  public:
RecordFragmenter(bool is_dtls13)23   RecordFragmenter(bool is_dtls13)
24       : is_dtls13_(is_dtls13), sequence_number_(0), splitting_(true) {}
25 
26  private:
27   class HandshakeSplitter {
28    public:
HandshakeSplitter(bool is_dtls13,const DataBuffer & input,DataBuffer * output,uint64_t * sequence_number)29     HandshakeSplitter(bool is_dtls13, const DataBuffer& input,
30                       DataBuffer* output, uint64_t* sequence_number)
31         : is_dtls13_(is_dtls13),
32           input_(input),
33           output_(output),
34           cursor_(0),
35           sequence_number_(sequence_number) {}
36 
37    private:
WriteRecord(TlsRecordHeader & record_header,DataBuffer & record_fragment)38     void WriteRecord(TlsRecordHeader& record_header,
39                      DataBuffer& record_fragment) {
40       TlsRecordHeader fragment_header(
41           record_header.variant(), record_header.version(),
42           record_header.content_type(), *sequence_number_);
43       ++*sequence_number_;
44       if (::g_ssl_gtest_verbose) {
45         std::cerr << "Fragment: " << fragment_header << ' ' << record_fragment
46                   << std::endl;
47       }
48       cursor_ = fragment_header.Write(output_, cursor_, record_fragment);
49     }
50 
SplitRecord(TlsRecordHeader & record_header,DataBuffer & record)51     bool SplitRecord(TlsRecordHeader& record_header, DataBuffer& record) {
52       TlsParser parser(record);
53       while (parser.remaining()) {
54         TlsHandshakeFilter::HandshakeHeader handshake_header;
55         DataBuffer handshake_body;
56         bool complete = false;
57         if (!handshake_header.Parse(&parser, record_header, DataBuffer(),
58                                     &handshake_body, &complete)) {
59           ADD_FAILURE() << "couldn't parse handshake header";
60           return false;
61         }
62         if (!complete) {
63           ADD_FAILURE() << "don't want to deal with fragmented messages";
64           return false;
65         }
66 
67         DataBuffer record_fragment;
68         // We can't fragment handshake records that are too small.
69         if (handshake_body.len() < 2) {
70           handshake_header.Write(&record_fragment, 0U, handshake_body);
71           WriteRecord(record_header, record_fragment);
72           continue;
73         }
74 
75         size_t cut = handshake_body.len() / 2;
76         handshake_header.WriteFragment(&record_fragment, 0U, handshake_body, 0U,
77                                        cut);
78         WriteRecord(record_header, record_fragment);
79 
80         handshake_header.WriteFragment(&record_fragment, 0U, handshake_body,
81                                        cut, handshake_body.len() - cut);
82         WriteRecord(record_header, record_fragment);
83       }
84       return true;
85     }
86 
87    public:
Split()88     bool Split() {
89       TlsParser parser(input_);
90       while (parser.remaining()) {
91         TlsRecordHeader header;
92         DataBuffer record;
93         if (!header.Parse(is_dtls13_, 0, &parser, &record)) {
94           ADD_FAILURE() << "bad record header";
95           return false;
96         }
97 
98         if (::g_ssl_gtest_verbose) {
99           std::cerr << "Record: " << header << ' ' << record << std::endl;
100         }
101 
102         // Don't touch packets from a non-zero epoch.  Leave these unmodified.
103         if ((header.sequence_number() >> 48) != 0ULL) {
104           cursor_ = header.Write(output_, cursor_, record);
105           continue;
106         }
107 
108         // Just rewrite the sequence number (CCS only).
109         if (header.content_type() != ssl_ct_handshake) {
110           EXPECT_EQ(ssl_ct_change_cipher_spec, header.content_type());
111           WriteRecord(header, record);
112           continue;
113         }
114 
115         if (!SplitRecord(header, record)) {
116           return false;
117         }
118       }
119       return true;
120     }
121 
122    private:
123     bool is_dtls13_;
124     const DataBuffer& input_;
125     DataBuffer* output_;
126     size_t cursor_;
127     uint64_t* sequence_number_;
128   };
129 
130  protected:
Filter(const DataBuffer & input,DataBuffer * output)131   virtual PacketFilter::Action Filter(const DataBuffer& input,
132                                       DataBuffer* output) override {
133     if (!splitting_) {
134       return KEEP;
135     }
136 
137     output->Allocate(input.len());
138     HandshakeSplitter splitter(is_dtls13_, input, output, &sequence_number_);
139     if (!splitter.Split()) {
140       // If splitting fails, we obviously reached encrypted packets.
141       // Stop splitting from that point onward.
142       splitting_ = false;
143       return KEEP;
144     }
145 
146     return CHANGE;
147   }
148 
149  private:
150   bool is_dtls13_;
151   uint64_t sequence_number_;
152   bool splitting_;
153 };
154 
// Fragments every plaintext handshake record the client sends and checks that
// the handshake still completes and application data flows afterwards.
TEST_P(TlsConnectDatagram, FragmentClientPackets) {
  const bool dtls13 = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3);
  auto fragmenter = std::make_shared<RecordFragmenter>(dtls13);
  client_->SetFilter(fragmenter);

  Connect();
  SendReceive();
}
161 
// Fragments every plaintext handshake record the server sends and checks that
// the handshake still completes and application data flows afterwards.
TEST_P(TlsConnectDatagram, FragmentServerPackets) {
  const bool dtls13 = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3);
  auto fragmenter = std::make_shared<RecordFragmenter>(dtls13);
  server_->SetFilter(fragmenter);

  Connect();
  SendReceive();
}
168 
169 }  // namespace nss_test
170