1from __future__ import absolute_import
2from __future__ import print_function
3
4import struct
5from enum import IntEnum
6
7from tls_parser.byte_utils import int_to_bytes
8from tls_parser.tls_version import TlsVersionEnum
9
10from tls_parser.exceptions import NotEnoughData, UnknownTypeByte
11from tls_parser.record_protocol import TlsSubprotocolMessage, TlsRecord, TlsRecordHeader, TlsRecordTypeByte
12from typing import Tuple, List
13
14
class TlsHandshakeTypeByte(IntEnum):
    """Handshake message type byte (the first byte of every handshake message).

    TlsHandshakeMessage.from_bytes converts the wire byte with TlsHandshakeTypeByte(value), so any
    type missing from this enum makes parsing fail with ValueError. NEW_SESSION_TICKET and
    CERTIFICATE_STATUS are listed because servers send them unencrypted during TLS 1.2 and earlier
    handshakes (session tickets, OCSP stapling).
    """
    HELLO_REQUEST = 0x00
    CLIENT_HELLO = 0x01
    SERVER_HELLO = 0x02
    NEW_SESSION_TICKET = 0x04
    CERTIFICATE = 0x0b
    SERVER_KEY_EXCHANGE = 0x0c
    CERTIFICATE_REQUEST = 0x0d
    SERVER_DONE = 0x0e  # ServerHelloDone
    CERTIFICATE_VERIFY = 0x0f
    CLIENT_KEY_EXCHANGE = 0x10
    FINISHED = 0x14
    CERTIFICATE_STATUS = 0x16
26
27
class TlsHandshakeMessage(TlsSubprotocolMessage):
    """A single handshake message, i.e. (part of) the payload of a handshake record.

    Wire format: a 1-byte handshake type, a 3-byte big-endian body length, then the body.
    """

    def __init__(self, handshake_type, handshake_data):
        # type: (TlsHandshakeTypeByte, bytes) -> None
        self.handshake_type = handshake_type
        self.handshake_data = handshake_data

    @classmethod
    def from_bytes(cls, raw_bytes):
        # type: (bytes) -> Tuple[TlsHandshakeMessage, int]
        """Parse one handshake message from the start of raw_bytes.

        Returns:
            The parsed message and the number of bytes consumed (4-byte header plus body).

        Raises:
            NotEnoughData: if raw_bytes does not hold the full 4-byte header and the body it announces.
            ValueError: if the type byte is not a member of TlsHandshakeTypeByte.
        """
        if len(raw_bytes) < 4:
            raise NotEnoughData()

        handshake_type = TlsHandshakeTypeByte(struct.unpack('B', raw_bytes[0:1])[0])
        # The length is a 3-byte big-endian integer: prepend a zero byte to unpack it as a 4-byte one
        message_length = struct.unpack('!I', b'\x00' + raw_bytes[1:4])[0]
        message = raw_bytes[4:message_length + 4]
        if len(message) < message_length:
            raise NotEnoughData()

        return TlsHandshakeMessage(handshake_type, message), 4 + message_length

    def to_bytes(self):
        # type: () -> bytes
        """Serialize the message: type byte, 3-byte big-endian length, then the body.

        Raises:
            ValueError: if handshake_data is too long to be described by the 3-byte length field.
        """
        data_length = len(self.handshake_data)
        if data_length > 0xFFFFFF:
            # Packing would silently drop the most significant byte and corrupt the length field
            raise ValueError('Handshake data too long for a 3-byte length field: {} bytes'.format(data_length))

        # int() accepts both a TlsHandshakeTypeByte member and a plain integer type byte
        type_byte = struct.pack('B', int(self.handshake_type))
        # Pack as a 4-byte big-endian integer and keep only the last 3 (low-order) bytes
        length_bytes = struct.pack('!I', data_length)[1:4]
        return type_byte + length_bytes + self.handshake_data
61
62
class TlsHandshakeRecord(TlsRecord):
    """A record of type HANDSHAKE, holding one or more handshake messages."""

    def __init__(self, record_header, handshake_messages):
        # type: (TlsRecordHeader, List[TlsHandshakeMessage]) -> None
        super(TlsHandshakeRecord, self).__init__(record_header, handshake_messages)

    @classmethod
    def from_bytes(cls, raw_bytes):
        # type: (bytes) -> Tuple[TlsHandshakeRecord, int]
        """Parse a handshake record from the start of raw_bytes.

        Only the header.length bytes that follow the record header are parsed as handshake messages;
        bytes belonging to subsequent records are never read.

        Returns:
            The parsed record and the number of bytes consumed (record header plus record body).

        Raises:
            UnknownTypeByte: if the record header's type is not HANDSHAKE.
            NotEnoughData: if raw_bytes does not contain the whole record body.
            ValueError: if a handshake message does not fit within the record body (for example a
                message fragmented across several records, which is not reassembled here), or if a
                handshake type byte is not a member of TlsHandshakeTypeByte.
        """
        header, len_consumed_for_header = TlsRecordHeader.from_bytes(raw_bytes)

        if header.type != TlsRecordTypeByte.HANDSHAKE:
            raise UnknownTypeByte()

        # Restrict parsing to this record's body so a malformed length cannot spill into the next record
        record_end = len_consumed_for_header + header.length
        if len(raw_bytes) < record_end:
            raise NotEnoughData()
        record_body = raw_bytes[len_consumed_for_header:record_end]

        # There may be multiple messages packed in the record
        messages = []
        offset = 0
        while offset < header.length:
            try:
                message, len_consumed_for_message = TlsHandshakeMessage.from_bytes(record_body[offset:])
            except NotEnoughData:
                # The whole record is available, so a truncated message means it overflows the record
                raise ValueError('Handshake message extends past the end of its record')
            messages.append(message)
            offset += len_consumed_for_message

        parsed_record = TlsHandshakeRecord(header, messages)
        return parsed_record, record_end
89
90
class TlsRsaClientKeyExchangeRecord(TlsHandshakeRecord):
    """A handshake record carrying a single RSA ClientKeyExchange message."""

    @classmethod
    def from_parameters(cls, tls_version, public_exponent, public_modulus, pre_master_secret_with_padding):
        # type: (TlsVersionEnum, int, int, int) -> TlsHandshakeRecord
        """Build a ClientKeyExchange record holding an RSA-encrypted pre-master secret.

        Args:
            tls_version: version written into the record header.
            public_exponent: RSA public exponent used for the encryption.
            public_modulus: RSA modulus used for the encryption.
            pre_master_secret_with_padding: the already-padded pre-master secret, as an integer.

        Returns:
            A TlsRsaClientKeyExchangeRecord wrapping one CLIENT_KEY_EXCHANGE message.
        """
        # Raw RSA encryption of the (already padded) secret: m^e mod n
        ciphertext = pow(pre_master_secret_with_padding, public_exponent, public_modulus)

        # Encode the ciphertext on as many bytes as the modulus takes
        ciphertext_bytes = int_to_bytes(ciphertext, expected_length=len(int_to_bytes(public_modulus)))

        # Per RFC 5246 the EncryptedPreMasterSecret is preceded by a 2-byte length; it is redundant for
        # RSA because that value is the only content of the ClientKeyExchange body
        length_prefix = struct.pack('!I', len(ciphertext_bytes))[-2:]
        cke_message = TlsHandshakeMessage(
            TlsHandshakeTypeByte.CLIENT_KEY_EXCHANGE,
            b'' + length_prefix + ciphertext_bytes,
        )

        # The record header announces the size of the fully serialized handshake message
        record_header = TlsRecordHeader(TlsRecordTypeByte.HANDSHAKE, tls_version, len(cke_message.to_bytes()))
        return TlsRsaClientKeyExchangeRecord(record_header, [cke_message])
115