1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this file,
5  * You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
7 #ifndef tls_parser_h_
8 #define tls_parser_h_
9 
10 #include <cstdint>
11 #include <cstring>
12 #include <memory>
13 #if defined(WIN32) || defined(WIN64)
14 #include <winsock2.h>
15 #else
16 #include <arpa/inet.h>
17 #endif
18 #include "databuffer.h"
19 #include "sslt.h"
20 
21 namespace nss_test {
22 
// TLS handshake message types: the value of the one-byte type field that
// starts every handshake message.
constexpr uint8_t kTlsHandshakeClientHello = 1;
constexpr uint8_t kTlsHandshakeServerHello = 2;
constexpr uint8_t kTlsHandshakeNewSessionTicket = 4;
// NOTE: 6 is the HelloRetryRequest code used by the TLS 1.3 drafts; RFC 8446
// lists it as hello_retry_request_RESERVED.
constexpr uint8_t kTlsHandshakeHelloRetryRequest = 6;
constexpr uint8_t kTlsHandshakeEncryptedExtensions = 8;
constexpr uint8_t kTlsHandshakeCertificate = 11;
constexpr uint8_t kTlsHandshakeServerKeyExchange = 12;
constexpr uint8_t kTlsHandshakeCertificateRequest = 13;
constexpr uint8_t kTlsHandshakeCertificateVerify = 15;
constexpr uint8_t kTlsHandshakeClientKeyExchange = 16;
constexpr uint8_t kTlsHandshakeFinished = 20;
constexpr uint8_t kTlsHandshakeKeyUpdate = 24;
35 
// Alert levels (first byte of an alert message).
constexpr uint8_t kTlsAlertWarning = 1;
constexpr uint8_t kTlsAlertFatal = 2;

// Alert descriptions (second byte of an alert message).
constexpr uint8_t kTlsAlertCloseNotify = 0;
constexpr uint8_t kTlsAlertUnexpectedMessage = 10;
constexpr uint8_t kTlsAlertBadRecordMac = 20;
constexpr uint8_t kTlsAlertRecordOverflow = 22;
constexpr uint8_t kTlsAlertHandshakeFailure = 40;
constexpr uint8_t kTlsAlertBadCertificate = 42;
constexpr uint8_t kTlsAlertCertificateRevoked = 44;
constexpr uint8_t kTlsAlertCertificateExpired = 45;
constexpr uint8_t kTlsAlertIllegalParameter = 47;
constexpr uint8_t kTlsAlertDecodeError = 50;
constexpr uint8_t kTlsAlertDecryptError = 51;
constexpr uint8_t kTlsAlertProtocolVersion = 70;
constexpr uint8_t kTlsAlertInsufficientSecurity = 71;
constexpr uint8_t kTlsAlertInternalError = 80;
constexpr uint8_t kTlsAlertInappropriateFallback = 86;
constexpr uint8_t kTlsAlertMissingExtension = 109;
constexpr uint8_t kTlsAlertUnsupportedExtension = 110;
constexpr uint8_t kTlsAlertUnrecognizedName = 112;
constexpr uint8_t kTlsAlertCertificateRequired = 116;
constexpr uint8_t kTlsAlertNoApplicationProtocol = 120;
constexpr uint8_t kTlsAlertEchRequired = 121;
60 
61 const uint8_t kTlsFakeChangeCipherSpec[] = {
62     ssl_ct_change_cipher_spec,  // Type
63     0xfe,
64     0xff,  // Version
65     0x00,
66     0x00,
67     0x00,
68     0x00,
69     0x00,
70     0x00,
71     0x00,
72     0x10,  // Fictitious sequence #
73     0x00,
74     0x01,  // Length
75     0x01   // Value
76 };
77 
// Bits of the first byte of a DTLS 1.3 unified (ciphertext) record header.
// Masking with kCtDtlsCiphertextMask and comparing against kCtDtlsCiphertext
// identifies the header. The remaining flags are tested individually.
constexpr uint8_t kCtDtlsCiphertext = 0x20;
constexpr uint8_t kCtDtlsCiphertextMask = 0xE0;
constexpr uint8_t kCtDtlsCiphertext16bSeqno = 0x08;       // 16-bit seq number
constexpr uint8_t kCtDtlsCiphertextLengthPresent = 0x04;  // length field sent
82 
// TLS 1.3 PSK key-exchange modes: PSK only, and PSK combined with (EC)DHE.
static constexpr uint8_t kTls13PskKe = 0;
static constexpr uint8_t kTls13PskDhKe = 1;
// PSK authentication modes. NOTE(review): RFC 8446 defines no such field.
// These look like leftovers from pre-RFC drafts; confirm they are still used.
static constexpr uint8_t kTls13PskAuth = 0;
static constexpr uint8_t kTls13PskSignAuth = 1;
87 
88 inline std::ostream& operator<<(std::ostream& os, SSLProtocolVariant v) {
89   return os << ((v == ssl_variant_stream) ? "TLS" : "DTLS");
90 }
91 
92 inline std::ostream& operator<<(std::ostream& os, SSLContentType v) {
93   switch (v) {
94     case ssl_ct_change_cipher_spec:
95       return os << "CCS";
96     case ssl_ct_alert:
97       return os << "alert";
98     case ssl_ct_handshake:
99       return os << "handshake";
100     case ssl_ct_application_data:
101       return os << "application data";
102     case ssl_ct_ack:
103       return os << "ack";
104   }
105   return os << "UNKNOWN content type " << static_cast<int>(v);
106 }
107 
108 inline std::ostream& operator<<(std::ostream& os, SSLSecretDirection v) {
109   switch (v) {
110     case ssl_secret_read:
111       return os << "read";
112     case ssl_secret_write:
113       return os << "write";
114   }
115   return os << "UNKNOWN secret direction " << static_cast<int>(v);
116 }
117 
// True when the high bit of the 16-bit wire version is set, which is how
// DTLS version numbers (e.g. 0xfeff) differ from TLS ones (e.g. 0x0303).
inline bool IsDtls(uint16_t version) {
  const uint16_t kHighBit = 0x8000;
  return (version & kHighBit) != 0;
}
119 
NormalizeTlsVersion(uint16_t version)120 inline uint16_t NormalizeTlsVersion(uint16_t version) {
121   if (version == 0xfeff) {
122     return 0x0302;  // special: DTLS 1.0 == TLS 1.1
123   }
124   if (IsDtls(version)) {
125     return (version ^ 0xffff) + 0x0201;
126   }
127   return version;
128 }
129 
// Maps a TLS version number to the DTLS wire version used in its place.
// TLS 1.1 (0x0302) maps to DTLS 1.0 (0xfeff). Other versions are mirrored
// (0x0303 -> 0xfefd), except 0x0304, which is returned unchanged.
// NOTE(review): because 0x0304 is not mapped to 0xfefc, this is not the
// inverse of NormalizeTlsVersion for 0xfefc. Confirm that callers rely on this.
inline uint16_t TlsVersionToDtlsVersion(uint16_t version) {
  switch (version) {
    case 0x0302:
      return 0xfeff;
    case 0x0304:
      return version;
    default:
      // For a 16-bit value, (version ^ 0xffff) == 0xffff - version.
      return static_cast<uint16_t>((version ^ 0xffff) + 0x0201);
  }
}
139 
WriteVariable(DataBuffer * target,size_t index,const DataBuffer & buf,size_t len_size)140 inline size_t WriteVariable(DataBuffer* target, size_t index,
141                             const DataBuffer& buf, size_t len_size) {
142   index = target->Write(index, static_cast<uint32_t>(buf.len()), len_size);
143   return target->Write(index, buf.data(), buf.len());
144 }
145 
146 class TlsParser {
147  public:
TlsParser(const uint8_t * data,size_t len)148   TlsParser(const uint8_t* data, size_t len) : buffer_(data, len), offset_(0) {}
TlsParser(const DataBuffer & buf)149   explicit TlsParser(const DataBuffer& buf) : buffer_(buf), offset_(0) {}
150 
151   bool Read(uint8_t* val);
152   // Read an integral type of specified width.
153   bool Read(uint32_t* val, size_t size);
154   // Reads len bytes into dest buffer, overwriting it.
155   bool Read(DataBuffer* dest, size_t len);
156   bool ReadFromMark(DataBuffer* val, size_t len, size_t mark);
157   // Reads bytes into dest buffer, overwriting it.  The number of bytes is
158   // determined by reading from len_size bytes from the stream first.
159   bool ReadVariable(DataBuffer* dest, size_t len_size);
160 
161   bool Skip(size_t len);
162   bool SkipVariable(size_t len_size);
163 
consumed()164   size_t consumed() const { return offset_; }
remaining()165   size_t remaining() const { return buffer_.len() - offset_; }
166 
167  private:
consume(size_t len)168   void consume(size_t len) { offset_ += len; }
ptr()169   const uint8_t* ptr() const { return buffer_.data() + offset_; }
170 
171   DataBuffer buffer_;
172   size_t offset_;
173 };
174 
175 }  // namespace nss_test
176 
177 #endif
178