1 /*
2 * TLS Handshake IO
3 * (C) 2012,2014,2015 Jack Lloyd
4 *
5 * Botan is released under the Simplified BSD License (see license.txt)
6 */
7 
8 #include <botan/internal/tls_handshake_io.h>
9 #include <botan/internal/tls_record.h>
10 #include <botan/internal/tls_seq_numbers.h>
11 #include <botan/tls_messages.h>
12 #include <botan/exceptn.h>
13 #include <botan/loadstor.h>
14 #include <chrono>
15 
16 namespace Botan {
17 
18 namespace TLS {
19 
20 namespace {
21 
load_be24(const uint8_t q[3])22 inline size_t load_be24(const uint8_t q[3])
23    {
24    return make_uint32(0,
25                       q[0],
26                       q[1],
27                       q[2]);
28    }
29 
store_be24(uint8_t out[3],size_t val)30 void store_be24(uint8_t out[3], size_t val)
31    {
32    out[0] = get_byte(1, static_cast<uint32_t>(val));
33    out[1] = get_byte(2, static_cast<uint32_t>(val));
34    out[2] = get_byte(3, static_cast<uint32_t>(val));
35    }
36 
/*
* Return the current reading of std::chrono::steady_clock, expressed
* as a count of whole milliseconds since that clock's epoch.
*/
uint64_t steady_clock_ms()
   {
   using std::chrono::milliseconds;
   using std::chrono::steady_clock;

   const auto since_epoch = steady_clock::now().time_since_epoch();
   return std::chrono::duration_cast<milliseconds>(since_epoch).count();
   }
42 
43 }
44 
initial_record_version() const45 Protocol_Version Stream_Handshake_IO::initial_record_version() const
46    {
47    return Protocol_Version::TLS_V10;
48    }
49 
add_record(const uint8_t record[],size_t record_len,Record_Type record_type,uint64_t)50 void Stream_Handshake_IO::add_record(const uint8_t record[],
51                                      size_t record_len,
52                                      Record_Type record_type, uint64_t)
53    {
54    if(record_type == HANDSHAKE)
55       {
56       m_queue.insert(m_queue.end(), record, record + record_len);
57       }
58    else if(record_type == CHANGE_CIPHER_SPEC)
59       {
60       if(record_len != 1 || record[0] != 1)
61          throw Decoding_Error("Invalid ChangeCipherSpec");
62 
63       // Pretend it's a regular handshake message of zero length
64       const uint8_t ccs_hs[] = { HANDSHAKE_CCS, 0, 0, 0 };
65       m_queue.insert(m_queue.end(), ccs_hs, ccs_hs + sizeof(ccs_hs));
66       }
67    else
68       throw Decoding_Error("Unknown message type " + std::to_string(record_type) + " in handshake processing");
69    }
70 
71 std::pair<Handshake_Type, std::vector<uint8_t>>
get_next_record(bool)72 Stream_Handshake_IO::get_next_record(bool)
73    {
74    if(m_queue.size() >= 4)
75       {
76       const size_t length = 4 + make_uint32(0, m_queue[1], m_queue[2], m_queue[3]);
77 
78       if(m_queue.size() >= length)
79          {
80          Handshake_Type type = static_cast<Handshake_Type>(m_queue[0]);
81 
82          if(type == HANDSHAKE_NONE)
83             throw Decoding_Error("Invalid handshake message type");
84 
85          std::vector<uint8_t> contents(m_queue.begin() + 4,
86                                        m_queue.begin() + length);
87 
88          m_queue.erase(m_queue.begin(), m_queue.begin() + length);
89 
90          return std::make_pair(type, contents);
91          }
92       }
93 
94    return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
95    }
96 
97 std::vector<uint8_t>
format(const std::vector<uint8_t> & msg,Handshake_Type type) const98 Stream_Handshake_IO::format(const std::vector<uint8_t>& msg,
99                             Handshake_Type type) const
100    {
101    std::vector<uint8_t> send_buf(4 + msg.size());
102 
103    const size_t buf_size = msg.size();
104 
105    send_buf[0] = static_cast<uint8_t>(type);
106 
107    store_be24(&send_buf[1], buf_size);
108 
109    if (msg.size() > 0)
110       {
111       copy_mem(&send_buf[4], msg.data(), msg.size());
112       }
113 
114    return send_buf;
115    }
116 
send_under_epoch(const Handshake_Message &,uint16_t)117 std::vector<uint8_t> Stream_Handshake_IO::send_under_epoch(const Handshake_Message& /*msg*/, uint16_t /*epoch*/)
118    {
119    throw Invalid_State("Not possible to send under arbitrary epoch with stream based TLS");
120    }
121 
send(const Handshake_Message & msg)122 std::vector<uint8_t> Stream_Handshake_IO::send(const Handshake_Message& msg)
123    {
124    const std::vector<uint8_t> msg_bits = msg.serialize();
125 
126    if(msg.type() == HANDSHAKE_CCS)
127       {
128       m_send_hs(CHANGE_CIPHER_SPEC, msg_bits);
129       return std::vector<uint8_t>(); // not included in handshake hashes
130       }
131 
132    const std::vector<uint8_t> buf = format(msg_bits, msg.type());
133    m_send_hs(HANDSHAKE, buf);
134    return buf;
135    }
136 
initial_record_version() const137 Protocol_Version Datagram_Handshake_IO::initial_record_version() const
138    {
139    return Protocol_Version::DTLS_V10;
140    }
141 
retransmit_last_flight()142 void Datagram_Handshake_IO::retransmit_last_flight()
143    {
144    const size_t flight_idx = (m_flights.size() == 1) ? 0 : (m_flights.size() - 2);
145    retransmit_flight(flight_idx);
146    }
147 
retransmit_flight(size_t flight_idx)148 void Datagram_Handshake_IO::retransmit_flight(size_t flight_idx)
149    {
150    const std::vector<uint16_t>& flight = m_flights.at(flight_idx);
151 
152    BOTAN_ASSERT(flight.size() > 0, "Nonempty flight to retransmit");
153 
154    uint16_t epoch = m_flight_data[flight[0]].epoch;
155 
156    for(auto msg_seq : flight)
157       {
158       auto& msg = m_flight_data[msg_seq];
159 
160       if(msg.epoch != epoch)
161          {
162          // Epoch gap: insert the CCS
163          std::vector<uint8_t> ccs(1, 1);
164          m_send_hs(epoch, CHANGE_CIPHER_SPEC, ccs);
165          }
166 
167       send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits);
168       epoch = msg.epoch;
169       }
170    }
171 
timeout_check()172 bool Datagram_Handshake_IO::timeout_check()
173    {
174    if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty()))
175       {
176       /*
177       If we haven't written anything yet obviously no timeout.
178       Also no timeout possible if we are mid-flight,
179       */
180       return false;
181       }
182 
183    const uint64_t ms_since_write = steady_clock_ms() - m_last_write;
184 
185    if(ms_since_write < m_next_timeout)
186       return false;
187 
188    retransmit_last_flight();
189 
190    m_next_timeout = std::min(2 * m_next_timeout, m_max_timeout);
191    return true;
192    }
193 
add_record(const uint8_t record[],size_t record_len,Record_Type record_type,uint64_t record_sequence)194 void Datagram_Handshake_IO::add_record(const uint8_t record[],
195                                        size_t record_len,
196                                        Record_Type record_type,
197                                        uint64_t record_sequence)
198    {
199    const uint16_t epoch = static_cast<uint16_t>(record_sequence >> 48);
200 
201    if(record_type == CHANGE_CIPHER_SPEC)
202       {
203       if(record_len != 1 || record[0] != 1)
204          throw Decoding_Error("Invalid ChangeCipherSpec");
205 
206       // TODO: check this is otherwise empty
207       m_ccs_epochs.insert(epoch);
208       return;
209       }
210 
211    const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
212 
213    while(record_len)
214       {
215       if(record_len < DTLS_HANDSHAKE_HEADER_LEN)
216          return; // completely bogus? at least degenerate/weird
217 
218       const uint8_t msg_type = record[0];
219       const size_t msg_len = load_be24(&record[1]);
220       const uint16_t message_seq = load_be<uint16_t>(&record[4], 0);
221       const size_t fragment_offset = load_be24(&record[6]);
222       const size_t fragment_length = load_be24(&record[9]);
223 
224       const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;
225 
226       if(record_len < total_size)
227          throw Decoding_Error("Bad lengths in DTLS header");
228 
229       if(message_seq >= m_in_message_seq)
230          {
231          m_messages[message_seq].add_fragment(&record[DTLS_HANDSHAKE_HEADER_LEN],
232                                               fragment_length,
233                                               fragment_offset,
234                                               epoch,
235                                               msg_type,
236                                               msg_len);
237          }
238       else
239          {
240          // TODO: detect retransmitted flight
241          }
242 
243       record += total_size;
244       record_len -= total_size;
245       }
246    }
247 
248 std::pair<Handshake_Type, std::vector<uint8_t>>
get_next_record(bool expecting_ccs)249 Datagram_Handshake_IO::get_next_record(bool expecting_ccs)
250    {
251    // Expecting a message means the last flight is concluded
252    if(!m_flights.rbegin()->empty())
253       m_flights.push_back(std::vector<uint16_t>());
254 
255    if(expecting_ccs)
256       {
257       if(!m_messages.empty())
258          {
259          const uint16_t current_epoch = m_messages.begin()->second.epoch();
260 
261          if(m_ccs_epochs.count(current_epoch))
262             return std::make_pair(HANDSHAKE_CCS, std::vector<uint8_t>());
263          }
264       return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
265       }
266 
267    auto i = m_messages.find(m_in_message_seq);
268 
269    if(i == m_messages.end() || !i->second.complete())
270       {
271       return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
272       }
273 
274    m_in_message_seq += 1;
275 
276    return i->second.message();
277    }
278 
add_fragment(const uint8_t fragment[],size_t fragment_length,size_t fragment_offset,uint16_t epoch,uint8_t msg_type,size_t msg_length)279 void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment(
280    const uint8_t fragment[],
281    size_t fragment_length,
282    size_t fragment_offset,
283    uint16_t epoch,
284    uint8_t msg_type,
285    size_t msg_length)
286    {
287    if(complete())
288       return; // already have entire message, ignore this
289 
290    if(m_msg_type == HANDSHAKE_NONE)
291       {
292       m_epoch = epoch;
293       m_msg_type = msg_type;
294       m_msg_length = msg_length;
295       }
296 
297    if(msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch)
298       throw Decoding_Error("Inconsistent values in fragmented DTLS handshake header");
299 
300    if(fragment_offset > m_msg_length)
301       throw Decoding_Error("Fragment offset past end of message");
302 
303    if(fragment_offset + fragment_length > m_msg_length)
304       throw Decoding_Error("Fragment overlaps past end of message");
305 
306    if(fragment_offset == 0 && fragment_length == m_msg_length)
307       {
308       m_fragments.clear();
309       m_message.assign(fragment, fragment+fragment_length);
310       }
311    else
312       {
313       /*
314       * FIXME. This is a pretty lame way to do defragmentation, huge
315       * overhead with a tree node per byte.
316       *
317       * Also should confirm that all overlaps have no changes,
318       * otherwise we expose ourselves to the classic fingerprinting
319       * and IDS evasion attacks on IP fragmentation.
320       */
321       for(size_t i = 0; i != fragment_length; ++i)
322          m_fragments[fragment_offset+i] = fragment[i];
323 
324       if(m_fragments.size() == m_msg_length)
325          {
326          m_message.resize(m_msg_length);
327          for(size_t i = 0; i != m_msg_length; ++i)
328             m_message[i] = m_fragments[i];
329          m_fragments.clear();
330          }
331       }
332    }
333 
complete() const334 bool Datagram_Handshake_IO::Handshake_Reassembly::complete() const
335    {
336    return (m_msg_type != HANDSHAKE_NONE && m_message.size() == m_msg_length);
337    }
338 
339 std::pair<Handshake_Type, std::vector<uint8_t>>
message() const340 Datagram_Handshake_IO::Handshake_Reassembly::message() const
341    {
342    if(!complete())
343       throw Internal_Error("Datagram_Handshake_IO - message not complete");
344 
345    return std::make_pair(static_cast<Handshake_Type>(m_msg_type), m_message);
346    }
347 
348 std::vector<uint8_t>
format_fragment(const uint8_t fragment[],size_t frag_len,uint16_t frag_offset,uint16_t msg_len,Handshake_Type type,uint16_t msg_sequence) const349 Datagram_Handshake_IO::format_fragment(const uint8_t fragment[],
350                                        size_t frag_len,
351                                        uint16_t frag_offset,
352                                        uint16_t msg_len,
353                                        Handshake_Type type,
354                                        uint16_t msg_sequence) const
355    {
356    std::vector<uint8_t> send_buf(12 + frag_len);
357 
358    send_buf[0] = static_cast<uint8_t>(type);
359 
360    store_be24(&send_buf[1], msg_len);
361 
362    store_be(msg_sequence, &send_buf[4]);
363 
364    store_be24(&send_buf[6], frag_offset);
365    store_be24(&send_buf[9], frag_len);
366 
367    if (frag_len > 0)
368       {
369       copy_mem(&send_buf[12], fragment, frag_len);
370       }
371 
372    return send_buf;
373    }
374 
375 std::vector<uint8_t>
format_w_seq(const std::vector<uint8_t> & msg,Handshake_Type type,uint16_t msg_sequence) const376 Datagram_Handshake_IO::format_w_seq(const std::vector<uint8_t>& msg,
377                                     Handshake_Type type,
378                                     uint16_t msg_sequence) const
379    {
380    return format_fragment(msg.data(), msg.size(), 0, static_cast<uint16_t>(msg.size()), type, msg_sequence);
381    }
382 
383 std::vector<uint8_t>
format(const std::vector<uint8_t> & msg,Handshake_Type type) const384 Datagram_Handshake_IO::format(const std::vector<uint8_t>& msg,
385                               Handshake_Type type) const
386    {
387    return format_w_seq(msg, type, m_in_message_seq - 1);
388    }
389 
send(const Handshake_Message & msg)390 std::vector<uint8_t> Datagram_Handshake_IO::send(const Handshake_Message& msg)
391    {
392    return this->send_under_epoch(msg, m_seqs.current_write_epoch());
393    }
394 
395 std::vector<uint8_t>
send_under_epoch(const Handshake_Message & msg,uint16_t epoch)396 Datagram_Handshake_IO::send_under_epoch(const Handshake_Message& msg, uint16_t epoch)
397    {
398    const std::vector<uint8_t> msg_bits = msg.serialize();
399    const Handshake_Type msg_type = msg.type();
400 
401    if(msg_type == HANDSHAKE_CCS)
402       {
403       m_send_hs(epoch, CHANGE_CIPHER_SPEC, msg_bits);
404       return std::vector<uint8_t>(); // not included in handshake hashes
405       }
406    else if(msg_type == HELLO_VERIFY_REQUEST)
407       {
408       // This message is not included in the handshake hashes
409       send_message(m_out_message_seq, epoch, msg_type, msg_bits);
410       m_out_message_seq += 1;
411       return std::vector<uint8_t>();
412       }
413 
414    // Note: not saving CCS, instead we know it was there due to change in epoch
415    m_flights.rbegin()->push_back(m_out_message_seq);
416    m_flight_data[m_out_message_seq] = Message_Info(epoch, msg_type, msg_bits);
417 
418    m_out_message_seq += 1;
419    m_last_write = steady_clock_ms();
420    m_next_timeout = m_initial_timeout;
421 
422    return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits);
423    }
424 
send_message(uint16_t msg_seq,uint16_t epoch,Handshake_Type msg_type,const std::vector<uint8_t> & msg_bits)425 std::vector<uint8_t> Datagram_Handshake_IO::send_message(uint16_t msg_seq,
426                                                       uint16_t epoch,
427                                                       Handshake_Type msg_type,
428                                                       const std::vector<uint8_t>& msg_bits)
429    {
430    const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
431 
432    const std::vector<uint8_t> no_fragment =
433       format_w_seq(msg_bits, msg_type, msg_seq);
434 
435    if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu)
436       {
437       m_send_hs(epoch, HANDSHAKE, no_fragment);
438       }
439    else
440       {
441       size_t frag_offset = 0;
442 
443       /**
444       * Largest possible overhead is for SHA-384 CBC ciphers, with 16 byte IV,
445       * 16+ for padding and 48 bytes for MAC. 128 is probably a strict
446       * over-estimate here. When CBC ciphers are removed this can be reduced
447       * since AEAD modes have no padding, at most 16 byte mac, and smaller
448       * per-record nonce.
449       */
450       const size_t ciphersuite_overhead = (epoch > 0) ? 128 : 0;
451       const size_t header_overhead = DTLS_HEADER_SIZE + DTLS_HANDSHAKE_HEADER_LEN;
452 
453       if(m_mtu <= (header_overhead + ciphersuite_overhead))
454          throw Invalid_Argument("DTLS MTU is too small to send headers");
455 
456       const size_t max_rec_size = m_mtu - (header_overhead + ciphersuite_overhead);
457 
458       while(frag_offset != msg_bits.size())
459          {
460          const size_t frag_len = std::min<size_t>(msg_bits.size() - frag_offset, max_rec_size);
461 
462          const std::vector<uint8_t> frag =
463             format_fragment(&msg_bits[frag_offset],
464                             frag_len,
465                             static_cast<uint16_t>(frag_offset),
466                             static_cast<uint16_t>(msg_bits.size()),
467                             msg_type,
468                             msg_seq);
469 
470          m_send_hs(epoch, HANDSHAKE, frag);
471 
472          frag_offset += frag_len;
473          }
474       }
475 
476    return no_fragment;
477    }
478 
479 }
480 }
481