1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ 2 /* vim: set ts=2 et sw=2 tw=80: */ 3 /* This Source Code Form is subject to the terms of the Mozilla Public 4 * License, v. 2.0. If a copy of the MPL was not distributed with this file, 5 * You can obtain one at http://mozilla.org/MPL/2.0/. */ 6 7 #include "secerr.h" 8 #include "ssl.h" 9 #include "sslerr.h" 10 #include "sslproto.h" 11 12 #include "gtest_utils.h" 13 #include "nss_scoped_ptrs.h" 14 #include "tls_connect.h" 15 #include "tls_filter.h" 16 #include "tls_parser.h" 17 18 namespace nss_test { 19 20 // This class cuts every unencrypted handshake record into two parts. 21 class RecordFragmenter : public PacketFilter { 22 public: 23 RecordFragmenter(bool is_dtls13) 24 : is_dtls13_(is_dtls13), sequence_number_(0), splitting_(true) {} 25 26 private: 27 class HandshakeSplitter { 28 public: 29 HandshakeSplitter(bool is_dtls13, const DataBuffer& input, 30 DataBuffer* output, uint64_t* sequence_number) 31 : is_dtls13_(is_dtls13), 32 input_(input), 33 output_(output), 34 cursor_(0), 35 sequence_number_(sequence_number) {} 36 37 private: 38 void WriteRecord(TlsRecordHeader& record_header, 39 DataBuffer& record_fragment) { 40 TlsRecordHeader fragment_header( 41 record_header.variant(), record_header.version(), 42 record_header.content_type(), *sequence_number_); 43 ++*sequence_number_; 44 if (::g_ssl_gtest_verbose) { 45 std::cerr << "Fragment: " << fragment_header << ' ' << record_fragment 46 << std::endl; 47 } 48 cursor_ = fragment_header.Write(output_, cursor_, record_fragment); 49 } 50 51 bool SplitRecord(TlsRecordHeader& record_header, DataBuffer& record) { 52 TlsParser parser(record); 53 while (parser.remaining()) { 54 TlsHandshakeFilter::HandshakeHeader handshake_header; 55 DataBuffer handshake_body; 56 bool complete = false; 57 if (!handshake_header.Parse(&parser, record_header, DataBuffer(), 58 &handshake_body, &complete)) { 59 ADD_FAILURE() << "couldn't parse handshake header"; 60 return false; 61 } 62 if (!complete) { 63 ADD_FAILURE() << "don't want to deal with fragmented messages"; 64 return false; 65 } 66 67 DataBuffer record_fragment; 68 // We can't fragment handshake records that are too small. 69 if (handshake_body.len() < 2) { 70 handshake_header.Write(&record_fragment, 0U, handshake_body); 71 WriteRecord(record_header, record_fragment); 72 continue; 73 } 74 75 size_t cut = handshake_body.len() / 2; 76 handshake_header.WriteFragment(&record_fragment, 0U, handshake_body, 0U, 77 cut); 78 WriteRecord(record_header, record_fragment); 79 80 handshake_header.WriteFragment(&record_fragment, 0U, handshake_body, 81 cut, handshake_body.len() - cut); 82 WriteRecord(record_header, record_fragment); 83 } 84 return true; 85 } 86 87 public: 88 bool Split() { 89 TlsParser parser(input_); 90 while (parser.remaining()) { 91 TlsRecordHeader header; 92 DataBuffer record; 93 if (!header.Parse(is_dtls13_, 0, &parser, &record)) { 94 ADD_FAILURE() << "bad record header"; 95 return false; 96 } 97 98 if (::g_ssl_gtest_verbose) { 99 std::cerr << "Record: " << header << ' ' << record << std::endl; 100 } 101 102 // Don't touch packets from a non-zero epoch. Leave these unmodified. 103 if ((header.sequence_number() >> 48) != 0ULL) { 104 cursor_ = header.Write(output_, cursor_, record); 105 continue; 106 } 107 108 // Just rewrite the sequence number (CCS only). 109 if (header.content_type() != ssl_ct_handshake) { 110 EXPECT_EQ(ssl_ct_change_cipher_spec, header.content_type()); 111 WriteRecord(header, record); 112 continue; 113 } 114 115 if (!SplitRecord(header, record)) { 116 return false; 117 } 118 } 119 return true; 120 } 121 122 private: 123 bool is_dtls13_; 124 const DataBuffer& input_; 125 DataBuffer* output_; 126 size_t cursor_; 127 uint64_t* sequence_number_; 128 }; 129 130 protected: 131 virtual PacketFilter::Action Filter(const DataBuffer& input, 132 DataBuffer* output) override { 133 if (!splitting_) { 134 return KEEP; 135 } 136 137 output->Allocate(input.len()); 138 HandshakeSplitter splitter(is_dtls13_, input, output, &sequence_number_); 139 if (!splitter.Split()) { 140 // If splitting fails, we obviously reached encrypted packets. 141 // Stop splitting from that point onward. 142 splitting_ = false; 143 return KEEP; 144 } 145 146 return CHANGE; 147 } 148 149 private: 150 bool is_dtls13_; 151 uint64_t sequence_number_; 152 bool splitting_; 153 }; 154 155 TEST_P(TlsConnectDatagram, FragmentClientPackets) { 156 bool is_dtls13 = version_ >= SSL_LIBRARY_VERSION_TLS_1_3; 157 client_->SetFilter(std::make_shared<RecordFragmenter>(is_dtls13)); 158 Connect(); 159 SendReceive(); 160 } 161 162 TEST_P(TlsConnectDatagram, FragmentServerPackets) { 163 bool is_dtls13 = version_ >= SSL_LIBRARY_VERSION_TLS_1_3; 164 server_->SetFilter(std::make_shared<RecordFragmenter>(is_dtls13)); 165 Connect(); 166 SendReceive(); 167 } 168 169 } // namespace nss_test 170