1# This file is part of Scapy
2# Copyright (C) 2017 Maxence Tury
3#               2019 Romain Perez
4# This program is published under a GPLv2 license
5
6"""
7TLS 1.3 key exchange logic.
8"""
9
10import struct
11
12from scapy.config import conf, crypto_validator
13from scapy.error import log_runtime
14from scapy.fields import FieldLenField, IntField, PacketField, \
15    PacketListField, ShortEnumField, ShortField, StrFixedLenField, \
16    StrLenField
17from scapy.packet import Packet, Padding
18from scapy.layers.tls.extensions import TLS_Ext_Unknown, _tls_ext
19from scapy.layers.tls.crypto.groups import (
20    _tls_named_curves,
21    _tls_named_ffdh_groups,
22    _tls_named_groups,
23    _tls_named_groups_generate,
24    _tls_named_groups_import,
25    _tls_named_groups_pubbytes,
26)
27import scapy.modules.six as six
28
29if conf.crypto_valid:
30    from cryptography.hazmat.primitives.asymmetric import ec
31
32
class KeyShareEntry(Packet):
    """
    When building from scratch, we create a DH private key, and when
    dissecting, we create a DH public key. Default group is secp256r1.

    Wire layout: group (2 bytes), kxlen (2 bytes), key_exchange (kxlen bytes).
    The generated/imported key objects are kept on the instance in the
    ``privkey`` / ``pubkey`` slots so the enclosing Key Share extension can
    register them in the TLS session.
    """
    __slots__ = ["privkey", "pubkey"]
    name = "Key Share Entry"
    fields_desc = [ShortEnumField("group", None, _tls_named_groups),
                   FieldLenField("kxlen", None, length_of="key_exchange"),
                   StrLenField("key_exchange", "",
                               length_from=lambda pkt: pkt.kxlen)]

    def __init__(self, *args, **kargs):
        # Set by create_privkey() (building) / register_pubkey() (dissecting).
        self.privkey = None
        self.pubkey = None
        super(KeyShareEntry, self).__init__(*args, **kargs)

    def do_build(self):
        """
        We need this hack, else 'self' would be replaced by __iter__.next().

        Forcing ``explicit`` to True makes the parent build run on this very
        instance, so the field values and the private key assigned in
        post_build() stay attached to it. The previous value is restored
        afterwards.
        """
        tmp = self.explicit
        self.explicit = True
        b = super(KeyShareEntry, self).do_build()
        self.explicit = tmp
        return b

    @crypto_validator
    def create_privkey(self):
        """
        This is called by post_build() for key creation.

        Generates a private key for ``self.group`` and stores its public
        bytes in ``key_exchange``.
        """
        self.privkey = _tls_named_groups_generate(self.group)
        self.key_exchange = _tls_named_groups_pubbytes(self.privkey)

    def post_build(self, pkt, pay):
        """
        Fill in defaults and re-serialize the entry.

        Missing group defaults to 23 (secp256r1). An empty key_exchange
        triggers key generation, silently skipped when the crypto library is
        unavailable (crypto_validator raises ImportError). Note that group
        and kxlen are written back onto the instance.
        """
        if self.group is None:
            self.group = 23     # secp256r1

        if not self.key_exchange:
            try:
                self.create_privkey()
            except ImportError:
                pass

        if self.kxlen is None:
            self.kxlen = len(self.key_exchange)

        group = struct.pack("!H", self.group)
        kxlen = struct.pack("!H", self.kxlen)
        return group + kxlen + self.key_exchange + pay

    @crypto_validator
    def register_pubkey(self):
        """Import ``key_exchange`` as a public key object for ``self.group``."""
        self.pubkey = _tls_named_groups_import(
            self.group,
            self.key_exchange
        )

    def post_dissection(self, r):
        """Best-effort public key import; skipped without a crypto library."""
        try:
            self.register_pubkey()
        except ImportError:
            pass

    def extract_padding(self, s):
        """
        An entry never carries a payload: everything after it goes back to
        the parent as padding. The empty payload is returned as bytes,
        consistently with the rest of Scapy's dissection data.
        """
        return b"", s
100
101
class TLS_Ext_KeyShare_CH(TLS_Ext_Unknown):
    """
    Key Share extension as sent in a ClientHello: a length-prefixed
    client_shares list of KeyShareEntry packets.

    Unless the TLS session is frozen:
      - building records each generated private key in
        ``tls13_client_privshares``;
      - dissecting records each imported public key in
        ``tls13_client_pubshares``.
    Both dictionaries are keyed by the group name from ``_tls_named_groups``,
    which covers finite-field groups as well as curves (indexing
    ``_tls_named_curves`` would raise KeyError for FFDH groups).
    A group appearing twice is logged and stops the registration loop.
    """
    name = "TLS Extension - Key Share (for ClientHello)"
    fields_desc = [ShortEnumField("type", 0x33, _tls_ext),
                   ShortField("len", None),
                   FieldLenField("client_shares_len", None,
                                 length_of="client_shares"),
                   PacketListField("client_shares", [], KeyShareEntry,
                                   length_from=lambda pkt: pkt.client_shares_len)]  # noqa: E501

    def post_build(self, pkt, pay):
        if not self.tls_session.frozen:
            privshares = self.tls_session.tls13_client_privshares
            for kse in self.client_shares:
                if not kse.privkey:
                    continue
                # Same key for the duplicate check and the store below.
                group_name = _tls_named_groups[kse.group]
                if group_name in privshares:
                    pkt_info = pkt.firstlayer().summary()
                    log_runtime.info("TLS: group %s used twice in the same ClientHello [%s]", kse.group, pkt_info)  # noqa: E501
                    break
                privshares[group_name] = kse.privkey
        return super(TLS_Ext_KeyShare_CH, self).post_build(pkt, pay)

    def post_dissection(self, r):
        if not self.tls_session.frozen:
            pubshares = self.tls_session.tls13_client_pubshares
            for kse in self.client_shares:
                if not kse.pubkey:
                    continue
                # _tls_named_groups is also the key space the ServerHello
                # Key Share extension uses to look these shares up.
                group_name = _tls_named_groups[kse.group]
                if group_name in pubshares:
                    pkt_info = r.firstlayer().summary()
                    log_runtime.info("TLS: group %s used twice in the same ClientHello [%s]", kse.group, pkt_info)  # noqa: E501
                    break
                pubshares[group_name] = kse.pubkey
        return super(TLS_Ext_KeyShare_CH, self).post_dissection(r)
134
135
class TLS_Ext_KeyShare_HRR(TLS_Ext_Unknown):
    """
    Key Share extension as found in a HelloRetryRequest.

    Instead of a key share, it only carries ``selected_group``, a 2-byte
    named-group code.
    """
    name = "TLS Extension - Key Share (for HelloRetryRequest)"
    fields_desc = [
        ShortEnumField("type", 0x33, _tls_ext),
        ShortField("len", None),
        ShortEnumField("selected_group", None, _tls_named_groups),
    ]
141
142
class TLS_Ext_KeyShare_SH(TLS_Ext_Unknown):
    """
    Key Share extension as sent in a ServerHello: a single KeyShareEntry
    (``server_share``).

    Unless the TLS session is frozen:
      - building records the server private key in
        ``tls13_server_privshare``;
      - dissecting records the server public key in
        ``tls13_server_pubshare``.
    Both are keyed by the group name from ``_tls_named_groups``. Whenever
    the matching key of the other side is already registered in the
    session, the (EC)DHE shared secret is computed and saved as
    ``tls13_dhe_secret``.
    """
    name = "TLS Extension - Key Share (for ServerHello)"
    fields_desc = [ShortEnumField("type", 0x33, _tls_ext),
                   ShortField("len", None),
                   PacketField("server_share", None, KeyShareEntry)]

    @staticmethod
    def _compute_dhe_secret(group_name, privkey, pubkey):
        """
        Return the shared secret of ``privkey`` and ``pubkey`` for the group
        called ``group_name``. Returns None when the group is neither a known
        finite-field group nor a known curve.
        """
        if group_name in six.itervalues(_tls_named_ffdh_groups):
            return privkey.exchange(pubkey)
        if group_name in six.itervalues(_tls_named_curves):
            # X25519/X448 keys take the peer key directly; the other curves
            # go through cryptography's ECDH algorithm object.
            if group_name in ["x25519", "x448"]:
                return privkey.exchange(pubkey)
            return privkey.exchange(ec.ECDH(), pubkey)
        return None

    def _store_dhe_secret(self, group_name, privkey, pubkey):
        """
        Compute the shared secret and save it in the session. An
        unsupported group is logged and leaves ``tls13_dhe_secret``
        untouched instead of failing on an unbound local.
        """
        pms = self._compute_dhe_secret(group_name, privkey, pubkey)
        if pms is None:
            log_runtime.info("TLS: cannot compute shared secret for group %s", group_name)  # noqa: E501
            return
        self.tls_session.tls13_dhe_secret = pms

    def post_build(self, pkt, pay):
        server_share = self.server_share
        # server_share defaults to None, so guard before reading privkey.
        if (not self.tls_session.frozen and server_share is not None and
                server_share.privkey):
            # if there is a privkey, we assume the crypto library is ok
            session = self.tls_session
            privshare = session.tls13_server_privshare
            if privshare:
                pkt_info = pkt.firstlayer().summary()
                log_runtime.info("TLS: overwriting previous server key share [%s]", pkt_info)  # noqa: E501
            group_name = _tls_named_groups[server_share.group]
            privshare[group_name] = server_share.privkey

            # Server side: combine our private key with the client's share.
            if group_name in session.tls13_client_pubshares:
                self._store_dhe_secret(
                    group_name,
                    server_share.privkey,
                    session.tls13_client_pubshares[group_name],
                )
        return super(TLS_Ext_KeyShare_SH, self).post_build(pkt, pay)

    def post_dissection(self, r):
        server_share = self.server_share
        if (not self.tls_session.frozen and server_share is not None and
                server_share.pubkey):
            # if there is a pubkey, we assume the crypto library is ok
            session = self.tls_session
            pubshare = session.tls13_server_pubshare
            if pubshare:
                pkt_info = r.firstlayer().summary()
                log_runtime.info("TLS: overwriting previous server key share [%s]", pkt_info)  # noqa: E501
            group_name = _tls_named_groups[server_share.group]
            pubshare[group_name] = server_share.pubkey

            if group_name in session.tls13_client_privshares:
                # Client side: our private key with the server's public key.
                self._store_dhe_secret(
                    group_name,
                    session.tls13_client_privshares[group_name],
                    server_share.pubkey,
                )
            elif (group_name in session.tls13_server_privshare and
                    group_name in session.tls13_client_pubshares):
                # Server side: our private key with the client's public key,
                # only if the client share was actually registered.
                self._store_dhe_secret(
                    group_name,
                    session.tls13_server_privshare[group_name],
                    session.tls13_client_pubshares[group_name],
                )
        return super(TLS_Ext_KeyShare_SH, self).post_dissection(r)
205
206
# Key Share extension class to use, keyed by the handshake message type of
# the enclosing message (1: ClientHello, 2: ServerHello).
_tls_ext_keyshare_cls = {
    1: TLS_Ext_KeyShare_CH,
    2: TLS_Ext_KeyShare_SH,
}

# HelloRetryRequest variant; it reuses message type 2.
_tls_ext_keyshare_hrr_cls = {
    2: TLS_Ext_KeyShare_HRR,
}
211
212
class Ticket(Packet):
    """
    Session ticket laid out as recommended by RFC 5077.

    Fields: a 16-byte key_name, a 16-byte iv, a 2-byte-length-prefixed
    encstate, and a trailing 32-byte mac.
    """
    name = "Recommended Ticket Construction (from RFC 5077)"
    fields_desc = [
        StrFixedLenField("key_name", None, 16),
        StrFixedLenField("iv", None, 16),
        FieldLenField("encstatelen", None, length_of="encstate"),
        StrLenField("encstate", "",
                    length_from=lambda p: p.encstatelen),
        StrFixedLenField("mac", None, 32),
    ]
221
222
class TicketField(PacketField):
    """
    PacketField holding a Ticket whose size is given by ``length_from``.

    Bytes beyond that size are handed back to the enclosing packet through a
    trailing Padding layer. When ``length_from`` is None, the whole buffer
    is dissected as the Ticket, like a plain PacketField. Previously this
    crashed with ``TypeError``.
    """
    __slots__ = ["length_from"]

    def __init__(self, name, default, length_from=None, **kargs):
        self.length_from = length_from
        PacketField.__init__(self, name, default, Ticket, **kargs)

    def m2i(self, pkt, m):
        if self.length_from is None:
            tmp_len = len(m)
        else:
            tmp_len = self.length_from(pkt)
        tbd, rem = m[:tmp_len], m[tmp_len:]
        # The Padding layer carries the remainder back to the parent.
        return self.cls(tbd) / Padding(rem)
234
235
class PSKIdentity(Packet):
    """
    One PSK identity: a length-prefixed Ticket followed by a 4-byte
    obfuscated_ticket_age.
    """
    name = "PSK Identity"
    fields_desc = [
        FieldLenField("identity_len", None, length_of="identity"),
        TicketField("identity", "",
                    length_from=lambda p: p.identity_len),
        IntField("obfuscated_ticket_age", 0),
    ]
243
244
class PSKBinderEntry(Packet):
    """One PSK binder value, prefixed by a 1-byte length."""
    name = "PSK Binder Entry"
    fields_desc = [
        FieldLenField("binder_len", None, fmt="B", length_of="binder"),
        StrLenField("binder", "",
                    length_from=lambda p: p.binder_len),
    ]
251
252
class TLS_Ext_PreSharedKey_CH(TLS_Ext_Unknown):
    """
    Pre Shared Key extension as sent in a ClientHello.

    Contains two length-prefixed lists: the PSK identities and the
    corresponding binders.
    """
    # TODO: define post_build and post_dissection methods
    name = "TLS Extension - Pre Shared Key (for ClientHello)"
    fields_desc = [
        ShortEnumField("type", 0x29, _tls_ext),
        ShortField("len", None),
        FieldLenField("identities_len", None, length_of="identities"),
        PacketListField("identities", [], PSKIdentity,
                        length_from=lambda p: p.identities_len),
        FieldLenField("binders_len", None, length_of="binders"),
        PacketListField("binders", [], PSKBinderEntry,
                        length_from=lambda p: p.binders_len),
    ]
266
267
class TLS_Ext_PreSharedKey_SH(TLS_Ext_Unknown):
    """
    Pre Shared Key extension as sent in a ServerHello: a 2-byte
    selected_identity.
    """
    name = "TLS Extension - Pre Shared Key (for ServerHello)"
    fields_desc = [
        ShortEnumField("type", 0x29, _tls_ext),
        ShortField("len", None),
        ShortField("selected_identity", None),
    ]
273
274
# Pre Shared Key extension class to use, keyed by the handshake message type
# of the enclosing message (1: ClientHello, 2: ServerHello).
_tls_ext_presharedkey_cls = {
    1: TLS_Ext_PreSharedKey_CH,
    2: TLS_Ext_PreSharedKey_SH,
}
277