1 /*
2 * TLS Handshake IO
3 * (C) 2012,2014,2015 Jack Lloyd
4 *
5 * Botan is released under the Simplified BSD License (see license.txt)
6 */
7
8 #include <botan/internal/tls_handshake_io.h>
9 #include <botan/internal/tls_record.h>
10 #include <botan/internal/tls_seq_numbers.h>
11 #include <botan/tls_messages.h>
12 #include <botan/exceptn.h>
13 #include <botan/loadstor.h>
14 #include <chrono>
15
16 namespace Botan {
17
18 namespace TLS {
19
20 namespace {
21
/*
* Decode the 24-bit big-endian unsigned integer stored in q[0..2]
* (most significant byte first).
*/
inline size_t load_be24(const uint8_t q[3])
   {
   const size_t high = static_cast<size_t>(q[0]) << 16;
   const size_t middle = static_cast<size_t>(q[1]) << 8;
   const size_t low = static_cast<size_t>(q[2]);
   return high | middle | low;
   }
29
/*
* Write the low 24 bits of val into out[0..2], most significant byte
* first. Bits of val above bit 23 are discarded.
*/
void store_be24(uint8_t out[3], size_t val)
   {
   for(size_t idx = 0; idx != 3; ++idx)
      {
      const size_t shift = 8 * (2 - idx);
      out[idx] = static_cast<uint8_t>((val >> shift) & 0xFF);
      }
   }
36
/*
* Current reading of the monotonic std::chrono::steady_clock, expressed
* in whole milliseconds since that clock's (unspecified) epoch.
*/
uint64_t steady_clock_ms()
   {
   const auto elapsed = std::chrono::steady_clock::now().time_since_epoch();
   const auto as_ms = std::chrono::duration_cast<std::chrono::milliseconds>(elapsed);
   return as_ms.count();
   }
42
43 }
44
initial_record_version() const45 Protocol_Version Stream_Handshake_IO::initial_record_version() const
46 {
47 return Protocol_Version::TLS_V10;
48 }
49
add_record(const uint8_t record[],size_t record_len,Record_Type record_type,uint64_t)50 void Stream_Handshake_IO::add_record(const uint8_t record[],
51 size_t record_len,
52 Record_Type record_type, uint64_t)
53 {
54 if(record_type == HANDSHAKE)
55 {
56 m_queue.insert(m_queue.end(), record, record + record_len);
57 }
58 else if(record_type == CHANGE_CIPHER_SPEC)
59 {
60 if(record_len != 1 || record[0] != 1)
61 throw Decoding_Error("Invalid ChangeCipherSpec");
62
63 // Pretend it's a regular handshake message of zero length
64 const uint8_t ccs_hs[] = { HANDSHAKE_CCS, 0, 0, 0 };
65 m_queue.insert(m_queue.end(), ccs_hs, ccs_hs + sizeof(ccs_hs));
66 }
67 else
68 throw Decoding_Error("Unknown message type " + std::to_string(record_type) + " in handshake processing");
69 }
70
71 std::pair<Handshake_Type, std::vector<uint8_t>>
get_next_record(bool)72 Stream_Handshake_IO::get_next_record(bool)
73 {
74 if(m_queue.size() >= 4)
75 {
76 const size_t length = 4 + make_uint32(0, m_queue[1], m_queue[2], m_queue[3]);
77
78 if(m_queue.size() >= length)
79 {
80 Handshake_Type type = static_cast<Handshake_Type>(m_queue[0]);
81
82 if(type == HANDSHAKE_NONE)
83 throw Decoding_Error("Invalid handshake message type");
84
85 std::vector<uint8_t> contents(m_queue.begin() + 4,
86 m_queue.begin() + length);
87
88 m_queue.erase(m_queue.begin(), m_queue.begin() + length);
89
90 return std::make_pair(type, contents);
91 }
92 }
93
94 return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
95 }
96
97 std::vector<uint8_t>
format(const std::vector<uint8_t> & msg,Handshake_Type type) const98 Stream_Handshake_IO::format(const std::vector<uint8_t>& msg,
99 Handshake_Type type) const
100 {
101 std::vector<uint8_t> send_buf(4 + msg.size());
102
103 const size_t buf_size = msg.size();
104
105 send_buf[0] = static_cast<uint8_t>(type);
106
107 store_be24(&send_buf[1], buf_size);
108
109 if (msg.size() > 0)
110 {
111 copy_mem(&send_buf[4], msg.data(), msg.size());
112 }
113
114 return send_buf;
115 }
116
send_under_epoch(const Handshake_Message &,uint16_t)117 std::vector<uint8_t> Stream_Handshake_IO::send_under_epoch(const Handshake_Message& /*msg*/, uint16_t /*epoch*/)
118 {
119 throw Invalid_State("Not possible to send under arbitrary epoch with stream based TLS");
120 }
121
send(const Handshake_Message & msg)122 std::vector<uint8_t> Stream_Handshake_IO::send(const Handshake_Message& msg)
123 {
124 const std::vector<uint8_t> msg_bits = msg.serialize();
125
126 if(msg.type() == HANDSHAKE_CCS)
127 {
128 m_send_hs(CHANGE_CIPHER_SPEC, msg_bits);
129 return std::vector<uint8_t>(); // not included in handshake hashes
130 }
131
132 const std::vector<uint8_t> buf = format(msg_bits, msg.type());
133 m_send_hs(HANDSHAKE, buf);
134 return buf;
135 }
136
initial_record_version() const137 Protocol_Version Datagram_Handshake_IO::initial_record_version() const
138 {
139 return Protocol_Version::DTLS_V10;
140 }
141
retransmit_last_flight()142 void Datagram_Handshake_IO::retransmit_last_flight()
143 {
144 const size_t flight_idx = (m_flights.size() == 1) ? 0 : (m_flights.size() - 2);
145 retransmit_flight(flight_idx);
146 }
147
retransmit_flight(size_t flight_idx)148 void Datagram_Handshake_IO::retransmit_flight(size_t flight_idx)
149 {
150 const std::vector<uint16_t>& flight = m_flights.at(flight_idx);
151
152 BOTAN_ASSERT(flight.size() > 0, "Nonempty flight to retransmit");
153
154 uint16_t epoch = m_flight_data[flight[0]].epoch;
155
156 for(auto msg_seq : flight)
157 {
158 auto& msg = m_flight_data[msg_seq];
159
160 if(msg.epoch != epoch)
161 {
162 // Epoch gap: insert the CCS
163 std::vector<uint8_t> ccs(1, 1);
164 m_send_hs(epoch, CHANGE_CIPHER_SPEC, ccs);
165 }
166
167 send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits);
168 epoch = msg.epoch;
169 }
170 }
171
timeout_check()172 bool Datagram_Handshake_IO::timeout_check()
173 {
174 if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty()))
175 {
176 /*
177 If we haven't written anything yet obviously no timeout.
178 Also no timeout possible if we are mid-flight,
179 */
180 return false;
181 }
182
183 const uint64_t ms_since_write = steady_clock_ms() - m_last_write;
184
185 if(ms_since_write < m_next_timeout)
186 return false;
187
188 retransmit_last_flight();
189
190 m_next_timeout = std::min(2 * m_next_timeout, m_max_timeout);
191 return true;
192 }
193
add_record(const uint8_t record[],size_t record_len,Record_Type record_type,uint64_t record_sequence)194 void Datagram_Handshake_IO::add_record(const uint8_t record[],
195 size_t record_len,
196 Record_Type record_type,
197 uint64_t record_sequence)
198 {
199 const uint16_t epoch = static_cast<uint16_t>(record_sequence >> 48);
200
201 if(record_type == CHANGE_CIPHER_SPEC)
202 {
203 if(record_len != 1 || record[0] != 1)
204 throw Decoding_Error("Invalid ChangeCipherSpec");
205
206 // TODO: check this is otherwise empty
207 m_ccs_epochs.insert(epoch);
208 return;
209 }
210
211 const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
212
213 while(record_len)
214 {
215 if(record_len < DTLS_HANDSHAKE_HEADER_LEN)
216 return; // completely bogus? at least degenerate/weird
217
218 const uint8_t msg_type = record[0];
219 const size_t msg_len = load_be24(&record[1]);
220 const uint16_t message_seq = load_be<uint16_t>(&record[4], 0);
221 const size_t fragment_offset = load_be24(&record[6]);
222 const size_t fragment_length = load_be24(&record[9]);
223
224 const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;
225
226 if(record_len < total_size)
227 throw Decoding_Error("Bad lengths in DTLS header");
228
229 if(message_seq >= m_in_message_seq)
230 {
231 m_messages[message_seq].add_fragment(&record[DTLS_HANDSHAKE_HEADER_LEN],
232 fragment_length,
233 fragment_offset,
234 epoch,
235 msg_type,
236 msg_len);
237 }
238 else
239 {
240 // TODO: detect retransmitted flight
241 }
242
243 record += total_size;
244 record_len -= total_size;
245 }
246 }
247
248 std::pair<Handshake_Type, std::vector<uint8_t>>
get_next_record(bool expecting_ccs)249 Datagram_Handshake_IO::get_next_record(bool expecting_ccs)
250 {
251 // Expecting a message means the last flight is concluded
252 if(!m_flights.rbegin()->empty())
253 m_flights.push_back(std::vector<uint16_t>());
254
255 if(expecting_ccs)
256 {
257 if(!m_messages.empty())
258 {
259 const uint16_t current_epoch = m_messages.begin()->second.epoch();
260
261 if(m_ccs_epochs.count(current_epoch))
262 return std::make_pair(HANDSHAKE_CCS, std::vector<uint8_t>());
263 }
264 return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
265 }
266
267 auto i = m_messages.find(m_in_message_seq);
268
269 if(i == m_messages.end() || !i->second.complete())
270 {
271 return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
272 }
273
274 m_in_message_seq += 1;
275
276 return i->second.message();
277 }
278
add_fragment(const uint8_t fragment[],size_t fragment_length,size_t fragment_offset,uint16_t epoch,uint8_t msg_type,size_t msg_length)279 void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment(
280 const uint8_t fragment[],
281 size_t fragment_length,
282 size_t fragment_offset,
283 uint16_t epoch,
284 uint8_t msg_type,
285 size_t msg_length)
286 {
287 if(complete())
288 return; // already have entire message, ignore this
289
290 if(m_msg_type == HANDSHAKE_NONE)
291 {
292 m_epoch = epoch;
293 m_msg_type = msg_type;
294 m_msg_length = msg_length;
295 }
296
297 if(msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch)
298 throw Decoding_Error("Inconsistent values in fragmented DTLS handshake header");
299
300 if(fragment_offset > m_msg_length)
301 throw Decoding_Error("Fragment offset past end of message");
302
303 if(fragment_offset + fragment_length > m_msg_length)
304 throw Decoding_Error("Fragment overlaps past end of message");
305
306 if(fragment_offset == 0 && fragment_length == m_msg_length)
307 {
308 m_fragments.clear();
309 m_message.assign(fragment, fragment+fragment_length);
310 }
311 else
312 {
313 /*
314 * FIXME. This is a pretty lame way to do defragmentation, huge
315 * overhead with a tree node per byte.
316 *
317 * Also should confirm that all overlaps have no changes,
318 * otherwise we expose ourselves to the classic fingerprinting
319 * and IDS evasion attacks on IP fragmentation.
320 */
321 for(size_t i = 0; i != fragment_length; ++i)
322 m_fragments[fragment_offset+i] = fragment[i];
323
324 if(m_fragments.size() == m_msg_length)
325 {
326 m_message.resize(m_msg_length);
327 for(size_t i = 0; i != m_msg_length; ++i)
328 m_message[i] = m_fragments[i];
329 m_fragments.clear();
330 }
331 }
332 }
333
complete() const334 bool Datagram_Handshake_IO::Handshake_Reassembly::complete() const
335 {
336 return (m_msg_type != HANDSHAKE_NONE && m_message.size() == m_msg_length);
337 }
338
339 std::pair<Handshake_Type, std::vector<uint8_t>>
message() const340 Datagram_Handshake_IO::Handshake_Reassembly::message() const
341 {
342 if(!complete())
343 throw Internal_Error("Datagram_Handshake_IO - message not complete");
344
345 return std::make_pair(static_cast<Handshake_Type>(m_msg_type), m_message);
346 }
347
348 std::vector<uint8_t>
format_fragment(const uint8_t fragment[],size_t frag_len,uint16_t frag_offset,uint16_t msg_len,Handshake_Type type,uint16_t msg_sequence) const349 Datagram_Handshake_IO::format_fragment(const uint8_t fragment[],
350 size_t frag_len,
351 uint16_t frag_offset,
352 uint16_t msg_len,
353 Handshake_Type type,
354 uint16_t msg_sequence) const
355 {
356 std::vector<uint8_t> send_buf(12 + frag_len);
357
358 send_buf[0] = static_cast<uint8_t>(type);
359
360 store_be24(&send_buf[1], msg_len);
361
362 store_be(msg_sequence, &send_buf[4]);
363
364 store_be24(&send_buf[6], frag_offset);
365 store_be24(&send_buf[9], frag_len);
366
367 if (frag_len > 0)
368 {
369 copy_mem(&send_buf[12], fragment, frag_len);
370 }
371
372 return send_buf;
373 }
374
375 std::vector<uint8_t>
format_w_seq(const std::vector<uint8_t> & msg,Handshake_Type type,uint16_t msg_sequence) const376 Datagram_Handshake_IO::format_w_seq(const std::vector<uint8_t>& msg,
377 Handshake_Type type,
378 uint16_t msg_sequence) const
379 {
380 return format_fragment(msg.data(), msg.size(), 0, static_cast<uint16_t>(msg.size()), type, msg_sequence);
381 }
382
383 std::vector<uint8_t>
format(const std::vector<uint8_t> & msg,Handshake_Type type) const384 Datagram_Handshake_IO::format(const std::vector<uint8_t>& msg,
385 Handshake_Type type) const
386 {
387 return format_w_seq(msg, type, m_in_message_seq - 1);
388 }
389
send(const Handshake_Message & msg)390 std::vector<uint8_t> Datagram_Handshake_IO::send(const Handshake_Message& msg)
391 {
392 return this->send_under_epoch(msg, m_seqs.current_write_epoch());
393 }
394
395 std::vector<uint8_t>
send_under_epoch(const Handshake_Message & msg,uint16_t epoch)396 Datagram_Handshake_IO::send_under_epoch(const Handshake_Message& msg, uint16_t epoch)
397 {
398 const std::vector<uint8_t> msg_bits = msg.serialize();
399 const Handshake_Type msg_type = msg.type();
400
401 if(msg_type == HANDSHAKE_CCS)
402 {
403 m_send_hs(epoch, CHANGE_CIPHER_SPEC, msg_bits);
404 return std::vector<uint8_t>(); // not included in handshake hashes
405 }
406 else if(msg_type == HELLO_VERIFY_REQUEST)
407 {
408 // This message is not included in the handshake hashes
409 send_message(m_out_message_seq, epoch, msg_type, msg_bits);
410 m_out_message_seq += 1;
411 return std::vector<uint8_t>();
412 }
413
414 // Note: not saving CCS, instead we know it was there due to change in epoch
415 m_flights.rbegin()->push_back(m_out_message_seq);
416 m_flight_data[m_out_message_seq] = Message_Info(epoch, msg_type, msg_bits);
417
418 m_out_message_seq += 1;
419 m_last_write = steady_clock_ms();
420 m_next_timeout = m_initial_timeout;
421
422 return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits);
423 }
424
send_message(uint16_t msg_seq,uint16_t epoch,Handshake_Type msg_type,const std::vector<uint8_t> & msg_bits)425 std::vector<uint8_t> Datagram_Handshake_IO::send_message(uint16_t msg_seq,
426 uint16_t epoch,
427 Handshake_Type msg_type,
428 const std::vector<uint8_t>& msg_bits)
429 {
430 const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
431
432 const std::vector<uint8_t> no_fragment =
433 format_w_seq(msg_bits, msg_type, msg_seq);
434
435 if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu)
436 {
437 m_send_hs(epoch, HANDSHAKE, no_fragment);
438 }
439 else
440 {
441 size_t frag_offset = 0;
442
443 /**
444 * Largest possible overhead is for SHA-384 CBC ciphers, with 16 byte IV,
445 * 16+ for padding and 48 bytes for MAC. 128 is probably a strict
446 * over-estimate here. When CBC ciphers are removed this can be reduced
447 * since AEAD modes have no padding, at most 16 byte mac, and smaller
448 * per-record nonce.
449 */
450 const size_t ciphersuite_overhead = (epoch > 0) ? 128 : 0;
451 const size_t header_overhead = DTLS_HEADER_SIZE + DTLS_HANDSHAKE_HEADER_LEN;
452
453 if(m_mtu <= (header_overhead + ciphersuite_overhead))
454 throw Invalid_Argument("DTLS MTU is too small to send headers");
455
456 const size_t max_rec_size = m_mtu - (header_overhead + ciphersuite_overhead);
457
458 while(frag_offset != msg_bits.size())
459 {
460 const size_t frag_len = std::min<size_t>(msg_bits.size() - frag_offset, max_rec_size);
461
462 const std::vector<uint8_t> frag =
463 format_fragment(&msg_bits[frag_offset],
464 frag_len,
465 static_cast<uint16_t>(frag_offset),
466 static_cast<uint16_t>(msg_bits.size()),
467 msg_type,
468 msg_seq);
469
470 m_send_hs(epoch, HANDSHAKE, frag);
471
472 frag_offset += frag_len;
473 }
474 }
475
476 return no_fragment;
477 }
478
479 }
480 }
481