# This file is part of Scapy
# Copyright (C) 2017 Maxence Tury
#               2019 Romain Perez
# This program is published under a GPLv2 license

"""
TLS 1.3 key exchange logic.
"""

import struct

from scapy.config import conf, crypto_validator
from scapy.error import log_runtime
from scapy.fields import FieldLenField, IntField, PacketField, \
    PacketListField, ShortEnumField, ShortField, StrFixedLenField, \
    StrLenField
from scapy.packet import Packet, Padding
from scapy.layers.tls.extensions import TLS_Ext_Unknown, _tls_ext
from scapy.layers.tls.crypto.groups import (
    _tls_named_curves,
    _tls_named_ffdh_groups,
    _tls_named_groups,
    _tls_named_groups_generate,
    _tls_named_groups_import,
    _tls_named_groups_pubbytes,
)
import scapy.modules.six as six

if conf.crypto_valid:
    from cryptography.hazmat.primitives.asymmetric import ec


def _tls13_dhe_exchange(group_name, privkey, pubkey):
    """
    Compute the (EC)DHE shared secret for a TLS 1.3 named group.

    :param group_name: group name, as a value of _tls_named_groups
    :param privkey: our private key object for that group
    :param pubkey: the peer's public key object for that group
    :return: the result of privkey.exchange(...), or None when group_name is
        neither a known FFDH group nor a known curve
    """
    if group_name in six.itervalues(_tls_named_ffdh_groups):
        return privkey.exchange(pubkey)
    if group_name in six.itervalues(_tls_named_curves):
        # X25519/X448 keys take the peer key directly, whereas the other
        # curves go through the ECDH algorithm object.
        if group_name in ("x25519", "x448"):
            return privkey.exchange(pubkey)
        return privkey.exchange(ec.ECDH(), pubkey)
    return None


class KeyShareEntry(Packet):
    """
    When building from scratch, we create a DH private key, and when
    dissecting, we create a DH public key. Default group is secp256r1.
    """
    __slots__ = ["privkey", "pubkey"]
    name = "Key Share Entry"
    fields_desc = [ShortEnumField("group", None, _tls_named_groups),
                   FieldLenField("kxlen", None, length_of="key_exchange"),
                   StrLenField("key_exchange", "",
                               length_from=lambda pkt: pkt.kxlen)]

    def __init__(self, *args, **kargs):
        # privkey is filled by create_privkey() (build side),
        # pubkey by register_pubkey() (dissection side).
        self.privkey = None
        self.pubkey = None
        super(KeyShareEntry, self).__init__(*args, **kargs)

    def do_build(self):
        """
        We need this hack, else 'self' would be replaced by __iter__.next().

        The previous value of 'explicit' is restored even if the build fails.
        """
        tmp = self.explicit
        self.explicit = True
        try:
            return super(KeyShareEntry, self).do_build()
        finally:
            self.explicit = tmp

    @crypto_validator
    def create_privkey(self):
        """
        This is called by post_build() for key creation.

        Generates a private key for self.group and stores the matching
        public bytes in key_exchange.
        """
        self.privkey = _tls_named_groups_generate(self.group)
        self.key_exchange = _tls_named_groups_pubbytes(self.privkey)

    def post_build(self, pkt, pay):
        """
        Fill in group, key_exchange and kxlen when they were left unset, then
        re-serialize the entry from those (possibly updated) values.
        """
        if self.group is None:
            self.group = 23  # secp256r1

        if not self.key_exchange:
            try:
                self.create_privkey()
            except ImportError:
                # Crypto backend unavailable: keep the empty key_exchange.
                pass

        if self.kxlen is None:
            self.kxlen = len(self.key_exchange)

        group = struct.pack("!H", self.group)
        kxlen = struct.pack("!H", self.kxlen)
        return group + kxlen + self.key_exchange + pay

    @crypto_validator
    def register_pubkey(self):
        """Import the dissected key_exchange bytes as a public key object."""
        self.pubkey = _tls_named_groups_import(
            self.group,
            self.key_exchange
        )

    def post_dissection(self, r):
        try:
            self.register_pubkey()
        except ImportError:
            # Crypto backend unavailable: pubkey stays None.
            pass

    def extract_padding(self, s):
        # Entries live in a PacketListField: trailing bytes belong to the
        # next entry, not to this one's payload.
        return b"", s


class TLS_Ext_KeyShare_CH(TLS_Ext_Unknown):
    """
    Key Share extension as sent in a ClientHello (list of KeyShareEntry).

    Keys are recorded in the session under their group name
    (_tls_named_groups), the same key used by TLS_Ext_KeyShare_SH.
    """
    name = "TLS Extension - Key Share (for ClientHello)"
    fields_desc = [ShortEnumField("type", 0x33, _tls_ext),
                   ShortField("len", None),
                   FieldLenField("client_shares_len", None,
                                 length_of="client_shares"),
                   PacketListField("client_shares", [], KeyShareEntry,
                                   length_from=lambda pkt: pkt.client_shares_len)]  # noqa: E501

    def post_build(self, pkt, pay):
        """Record the private keys of the built shares in the session."""
        if not self.tls_session.frozen:
            privshares = self.tls_session.tls13_client_privshares
            for kse in self.client_shares:
                if kse.privkey:
                    group_name = _tls_named_groups[kse.group]
                    if group_name in privshares:
                        # 'pkt' is raw bytes here, so summarize the packet
                        # object itself.
                        pkt_info = self.firstlayer().summary()
                        log_runtime.info("TLS: group %s used twice in the same ClientHello [%s]", kse.group, pkt_info)  # noqa: E501
                        break
                    privshares[group_name] = kse.privkey
        return super(TLS_Ext_KeyShare_CH, self).post_build(pkt, pay)

    def post_dissection(self, r):
        """Record the public keys of the dissected shares in the session."""
        if not self.tls_session.frozen:
            pubshares = self.tls_session.tls13_client_pubshares
            for kse in self.client_shares:
                if kse.pubkey:
                    group_name = _tls_named_groups[kse.group]
                    if group_name in pubshares:
                        pkt_info = r.firstlayer().summary()
                        log_runtime.info("TLS: group %s used twice in the same ClientHello [%s]", kse.group, pkt_info)  # noqa: E501
                        break
                    pubshares[group_name] = kse.pubkey
        return super(TLS_Ext_KeyShare_CH, self).post_dissection(r)


class TLS_Ext_KeyShare_HRR(TLS_Ext_Unknown):
    """Key Share extension as sent in a HelloRetryRequest (group only)."""
    name = "TLS Extension - Key Share (for HelloRetryRequest)"
    fields_desc = [ShortEnumField("type", 0x33, _tls_ext),
                   ShortField("len", None),
                   ShortEnumField("selected_group", None, _tls_named_groups)]


class TLS_Ext_KeyShare_SH(TLS_Ext_Unknown):
    """
    Key Share extension as sent in a ServerHello (single KeyShareEntry).

    Building or dissecting it records the server share in the session and,
    when the matching counterpart key is known, stores the shared secret in
    tls_session.tls13_dhe_secret.
    """
    name = "TLS Extension - Key Share (for ServerHello)"
    fields_desc = [ShortEnumField("type", 0x33, _tls_ext),
                   ShortField("len", None),
                   PacketField("server_share", None, KeyShareEntry)]

    def _set_dhe_secret(self, group_name, privkey, pubkey):
        """Store the shared secret in the session when it can be computed."""
        pms = _tls13_dhe_exchange(group_name, privkey, pubkey)
        if pms is None:
            log_runtime.info("TLS: cannot compute the DHE secret for group %s", group_name)  # noqa: E501
            return
        self.tls_session.tls13_dhe_secret = pms

    def post_build(self, pkt, pay):
        share = self.server_share
        if (not self.tls_session.frozen and share is not None and
                share.privkey):
            # if there is a privkey, we assume the crypto library is ok
            privshare = self.tls_session.tls13_server_privshare
            if privshare:
                # 'pkt' is raw bytes here, so summarize the packet object.
                pkt_info = self.firstlayer().summary()
                log_runtime.info("TLS: overwriting previous server key share [%s]", pkt_info)  # noqa: E501
            group_name = _tls_named_groups[share.group]
            privshare[group_name] = share.privkey

            client_pubshares = self.tls_session.tls13_client_pubshares
            if group_name in client_pubshares:
                self._set_dhe_secret(group_name, share.privkey,
                                     client_pubshares[group_name])
        return super(TLS_Ext_KeyShare_SH, self).post_build(pkt, pay)

    def post_dissection(self, r):
        share = self.server_share
        if (not self.tls_session.frozen and share is not None and
                share.pubkey):
            # if there is a pubkey, we assume the crypto library is ok
            pubshare = self.tls_session.tls13_server_pubshare
            if pubshare:
                pkt_info = r.firstlayer().summary()
                log_runtime.info("TLS: overwriting previous server key share [%s]", pkt_info)  # noqa: E501
            group_name = _tls_named_groups[share.group]
            pubshare[group_name] = share.pubkey

            client_privshares = self.tls_session.tls13_client_privshares
            server_privshare = self.tls_session.tls13_server_privshare
            if group_name in client_privshares:
                # A private key for this group was recorded when building the
                # ClientHello: combine it with the dissected server share.
                self._set_dhe_secret(group_name,
                                     client_privshares[group_name],
                                     share.pubkey)
            elif group_name in server_privshare:
                # A private key for this group was recorded when building the
                # ServerHello: combine it with the client's public share, if
                # one was recorded.
                client_pubkey = self.tls_session.tls13_client_pubshares.get(
                    group_name)
                if client_pubkey is not None:
                    self._set_dhe_secret(group_name,
                                         server_privshare[group_name],
                                         client_pubkey)
        return super(TLS_Ext_KeyShare_SH, self).post_dissection(r)


_tls_ext_keyshare_cls = {1: TLS_Ext_KeyShare_CH,
                         2: TLS_Ext_KeyShare_SH}

_tls_ext_keyshare_hrr_cls = {2: TLS_Ext_KeyShare_HRR}


class Ticket(Packet):
    """Ticket layout: key_name, iv, length-prefixed encstate, then mac."""
    name = "Recommended Ticket Construction (from RFC 5077)"
    fields_desc = [StrFixedLenField("key_name", None, 16),
                   StrFixedLenField("iv", None, 16),
                   FieldLenField("encstatelen", None, length_of="encstate"),
                   StrLenField("encstate", "",
                               length_from=lambda pkt: pkt.encstatelen),
                   StrFixedLenField("mac", None, 32)]


class TicketField(PacketField):
    """
    PacketField holding a Ticket whose size is given by length_from(pkt).

    The bytes past that size are attached to the dissected Ticket as a
    Padding layer. When length_from is None, the whole buffer is used.
    """
    __slots__ = ["length_from"]

    def __init__(self, name, default, length_from=None, **kargs):
        self.length_from = length_from
        PacketField.__init__(self, name, default, Ticket, **kargs)

    def m2i(self, pkt, m):
        if self.length_from is None:
            tmp_len = len(m)
        else:
            tmp_len = self.length_from(pkt)
        tbd, rem = m[:tmp_len], m[tmp_len:]
        return self.cls(tbd) / Padding(rem)


class PSKIdentity(Packet):
    """One entry of the PreSharedKey 'identities' list."""
    name = "PSK Identity"
    fields_desc = [FieldLenField("identity_len", None,
                                 length_of="identity"),
                   TicketField("identity", "",
                               length_from=lambda pkt: pkt.identity_len),
                   IntField("obfuscated_ticket_age", 0)]

    def extract_padding(self, s):
        # Listed in a PacketListField: trailing bytes are the next identity.
        return b"", s


class PSKBinderEntry(Packet):
    """One entry of the PreSharedKey 'binders' list (1-byte length)."""
    name = "PSK Binder Entry"
    fields_desc = [FieldLenField("binder_len", None, fmt="B",
                                 length_of="binder"),
                   StrLenField("binder", "",
                               length_from=lambda pkt: pkt.binder_len)]

    def extract_padding(self, s):
        # Listed in a PacketListField: trailing bytes are the next binder.
        return b"", s


class TLS_Ext_PreSharedKey_CH(TLS_Ext_Unknown):
    """PreSharedKey extension as sent in a ClientHello."""
    # XXX define post_build and post_dissection methods
    name = "TLS Extension - Pre Shared Key (for ClientHello)"
    fields_desc = [ShortEnumField("type", 0x29, _tls_ext),
                   ShortField("len", None),
                   FieldLenField("identities_len", None,
                                 length_of="identities"),
                   PacketListField("identities", [], PSKIdentity,
                                   length_from=lambda pkt: pkt.identities_len),
                   FieldLenField("binders_len", None,
                                 length_of="binders"),
                   PacketListField("binders", [], PSKBinderEntry,
                                   length_from=lambda pkt: pkt.binders_len)]


class TLS_Ext_PreSharedKey_SH(TLS_Ext_Unknown):
    """PreSharedKey extension as sent in a ServerHello (selected index)."""
    name = "TLS Extension - Pre Shared Key (for ServerHello)"
    fields_desc = [ShortEnumField("type", 0x29, _tls_ext),
                   ShortField("len", None),
                   ShortField("selected_identity", None)]


_tls_ext_presharedkey_cls = {1: TLS_Ext_PreSharedKey_CH,
                             2: TLS_Ext_PreSharedKey_SH}