1# This file is part of Scapy
2# Copyright (C) 2017 Maxence Tury
3# This program is published under a GPLv2 license
4
5"""
6TLS handshake extensions.
7"""
8
from __future__ import print_function

import os
import struct

from scapy.fields import ByteEnumField, ByteField, EnumField, FieldLenField, \
    FieldListField, IntField, PacketField, PacketLenField, PacketListField, \
    ShortEnumField, ShortField, StrFixedLenField, StrLenField, XStrLenField
from scapy.packet import Packet, Raw, Padding
from scapy.layers.x509 import X509_Extensions
from scapy.layers.tls.basefields import _tls_version
from scapy.layers.tls.keyexchange import (SigAndHashAlgsLenField,
                                          SigAndHashAlgsField, _tls_hash_sig)
from scapy.layers.tls.session import _GenericTLSSessionInheritance
from scapy.layers.tls.crypto.groups import _tls_named_groups
from scapy.layers.tls.crypto.suites import _tls_cipher_suites
from scapy.themes import AnsiColorTheme
from scapy.compat import raw
from scapy.config import conf
28
29
# Because ServerHello and HelloRetryRequest have the same
# msg_type, the only way to distinguish these messages is by
# checking the random_bytes. If the random_bytes are equal to
# SHA256('HelloRetryRequest') then we know this is a
# HelloRetryRequest and the TLS_Ext_KeyShare must be parsed as
# TLS_Ext_KeyShare_HRR and not as TLS_Ext_KeyShare_SH.
36
37# from cryptography.hazmat.backends import default_backend
38# from cryptography.hazmat.primitives import hashes
39# digest = hashes.Hash(hashes.SHA256(), backend=default_backend())
40# digest.update(b"HelloRetryRequest")
41# _tls_hello_retry_magic = digest.finalize()
42
43_tls_hello_retry_magic = (
44    b'\xcf!\xadt\xe5\x9aa\x11\xbe\x1d\x8c\x02\x1ee\xb8\x91\xc2\xa2\x11'
45    b'\x16z\xbb\x8c^\x07\x9e\t\xe2\xc8\xa83\x9c'
46)
47
48
49_tls_ext = {0: "server_name",             # RFC 4366
50            1: "max_fragment_length",     # RFC 4366
51            2: "client_certificate_url",  # RFC 4366
52            3: "trusted_ca_keys",         # RFC 4366
53            4: "truncated_hmac",          # RFC 4366
54            5: "status_request",          # RFC 4366
55            6: "user_mapping",            # RFC 4681
56            7: "client_authz",            # RFC 5878
57            8: "server_authz",            # RFC 5878
58            9: "cert_type",               # RFC 6091
59            # 10: "elliptic_curves",         # RFC 4492
60            10: "supported_groups",
61            11: "ec_point_formats",        # RFC 4492
62            13: "signature_algorithms",    # RFC 5246
63            0x0f: "heartbeat",             # RFC 6520
64            0x10: "alpn",                  # RFC 7301
65            0x12: "signed_certificate_timestamp",  # RFC 6962
66            0x13: "client_certificate_type",  # RFC 7250
67            0x14: "server_certificate_type",  # RFC 7250
68            0x15: "padding",               # RFC 7685
69            0x16: "encrypt_then_mac",      # RFC 7366
70            0x17: "extended_master_secret",  # RFC 7627
71            0x1c: "record_size_limit",     # RFC 8449
72            0x23: "session_ticket",        # RFC 5077
73            0x29: "pre_shared_key",
74            0x2a: "early_data_indication",
75            0x2b: "supported_versions",
76            0x2c: "cookie",
77            0x2d: "psk_key_exchange_modes",
78            0x2f: "certificate_authorities",
79            0x30: "oid_filters",
80            0x31: "post_handshake_auth",
81            0x32: "signature_algorithms_cert",
82            0x33: "key_share",
83            0x3374: "next_protocol_negotiation",
84            # RFC-draft-agl-tls-nextprotoneg-03
85            0xff01: "renegotiation_info",   # RFC 5746
86            0xffce: "encrypted_server_name"
87            }
88
89
class TLS_Ext_Unknown(_GenericTLSSessionInheritance):
    """
    Generic TLS extension: a 2-byte type, a 2-byte length and an opaque
    value. The other extension classes in this module derive from it and
    inherit its post_build(), which fills in 'len' when it was left unset.
    (Historically kept here, rather than in extensions.py, to avoid
    circular imports.)
    """
    name = "TLS Extension - Scapy Unknown"
    fields_desc = [
        ShortEnumField("type", None, _tls_ext),
        FieldLenField("len", None, fmt="!H", length_of="val"),
        StrLenField("val", "", length_from=lambda pkt: pkt.len),
    ]

    def post_build(self, p, pay):
        # A user-supplied 'len' is kept untouched.
        if self.len is not None:
            return p + pay
        # Otherwise rewrite bytes 2-3 with the size of everything that
        # follows the 4-byte type/len header.
        body_len = len(p) - 4
        return p[:2] + struct.pack("!H", body_len) + p[4:] + pay
106
107
108###############################################################################
109#   ClientHello/ServerHello extensions                                        #
110###############################################################################
111
112# We provide these extensions mostly for packet manipulation purposes.
113# For now, most of them are not considered by our automaton.
114
class TLS_Ext_PrettyPacketList(TLS_Ext_Unknown):
    """
    Dummy extension used for server_name/ALPN/NPN for a lighter representation:
    the final field is shown as a 1-line list rather than as lots of packets.
    XXX Define a new condition for packet lists in Packet._show_or_dump?
    """

    def _show_or_dump(self, dump=False, indent=3,
                      lvl="", label_lvl="", first_call=True):
        """
        Reproduced from packet.py, simplified: every field, including the
        final packet list, is rendered as one 'name = value' line through the
        field's own i2repr(), so list fields with a compact i2repr() stay on
        a single line.

        :param dump: use a plain AnsiColorTheme and return the text
        :param indent: indentation step applied to the payload
        :param lvl: current indentation prefix
        :param label_lvl: prefix written before every line
        :param first_call: True only for the outermost call
        :return: the rendered text, or None once it has been printed
        """
        ct = AnsiColorTheme() if dump else conf.color_theme
        ncol = ct.field_name
        vcol = ct.field_value
        s = "%s%s %s %s \n" % (label_lvl, ct.punct("###["),
                               ct.layer_name(self.name), ct.punct("]###"))
        # One loop for all fields: the last field used to be handled by an
        # identical copy of this loop body.
        for f in self.fields_desc:
            fvalue = self.getfieldval(f.name)
            begn = "%s  %-10s%s " % (label_lvl + lvl, ncol(f.name),
                                     ct.punct("="),)
            reprval = f.i2repr(self, fvalue)
            if isinstance(reprval, str):
                # Align continuation lines of a multi-line value.
                pad = " " * (len(label_lvl) + len(lvl) + len(f.name) + 4)
                reprval = reprval.replace("\n", "\n" + pad)
            s += "%s%s\n" % (begn, vcol(reprval))
        if self.payload:
            s += self.payload._show_or_dump(
                dump=dump, indent=indent,
                lvl=lvl + (" " * indent * self.show_indent),
                label_lvl=label_lvl, first_call=False)

        if first_call and not dump:
            print(s)
            return None
        return s
162
163
164_tls_server_name_types = {0: "host_name"}
165
166
class ServerName(Packet):
    """
    One ServerName entry: a 1-byte name type followed by the name, which is
    prefixed with its own length (namelen).
    """
    name = "HostName"
    fields_desc = [
        ByteEnumField("nametype", 0, _tls_server_name_types),
        FieldLenField("namelen", None, length_of="servername"),
        StrLenField("servername", "", length_from=lambda pkt: pkt.namelen),
    ]

    def guess_payload_class(self, p):
        # Trailing bytes belong to the next list entry, not to a sub-layer.
        return Padding
176
177
class ServerListField(PacketListField):
    """
    PacketListField holding ServerName packets, displayed as a one-line list.
    """

    def i2repr(self, pkt, x):
        """
        Return "[name1, name2, ...]" built from the 'servername' of each
        packet in x.

        Names are decoded for display. On Python 3, %-formatting the bytes
        produced by b", ".join() rendered "[b'a, b']", and a name held as str
        would have made that join raise TypeError.
        """
        names = []
        for p in x or []:
            name = getattr(p, "servername", None)
            if name is None:
                # e.g. a Raw layer left over from a failed dissection
                names.append(repr(p))
            elif isinstance(name, bytes):
                names.append(name.decode("utf-8", "replace"))
            else:
                names.append(str(name))
        return "[%s]" % ", ".join(names)
182
183
class ServerLenField(FieldLenField):
    """
    Length of the server name list; the field is left out entirely when
    there are no server names (as in a ServerHello).
    """

    def addfield(self, pkt, s, val):
        # No explicit length and nothing to list: emit no bytes at all.
        if not val and not pkt.servernames:
            return s
        return super(ServerLenField, self).addfield(pkt, s, val)
194
195
class TLS_Ext_ServerName(TLS_Ext_PrettyPacketList):                 # RFC 4366
    """
    server_name (SNI) extension. 'len' accounts for the 2-byte list length
    plus the list; the list length itself is omitted when it is unset and
    there are no names.
    """
    name = "TLS Extension - Server Name"
    fields_desc = [
        ShortEnumField("type", 0, _tls_ext),
        FieldLenField("len", None, length_of="servernames",
                      adjust=lambda pkt, x: x + 2),
        ServerLenField("servernameslen", None, length_of="servernames"),
        ServerListField("servernames", [], ServerName,
                        length_from=lambda pkt: pkt.servernameslen),
    ]
205
206
class TLS_Ext_EncryptedServerName(TLS_Ext_PrettyPacketList):
    """
    encrypted_server_name (ESNI draft) extension: cipher suite, key share
    group and key exchange, record digest and encrypted SNI; each opaque
    member is preceded by its own length field.
    """
    name = "TLS Extension - Encrypted Server Name"
    fields_desc = [
        ShortEnumField("type", 0xffce, _tls_ext),
        ShortField("len", None),
        EnumField("cipher", None, _tls_cipher_suites),
        ShortEnumField("key_exchange_group", None, _tls_named_groups),
        FieldLenField("key_exchange_len", None,
                      length_of="key_exchange", fmt="H"),
        XStrLenField("key_exchange", "",
                     length_from=lambda pkt: pkt.key_exchange_len),
        FieldLenField("record_digest_len", None,
                      length_of="record_digest"),
        XStrLenField("record_digest", "",
                     length_from=lambda pkt: pkt.record_digest_len),
        FieldLenField("encrypted_sni_len", None,
                      length_of="encrypted_sni", fmt="H"),
        XStrLenField("encrypted_sni", "",
                     length_from=lambda pkt: pkt.encrypted_sni_len),
    ]
226
227
class TLS_Ext_MaxFragLen(TLS_Ext_Unknown):                          # RFC 4366
    """
    max_fragment_length extension: one byte selecting the maximum fragment
    length, from 2^9 (1) to 2^12 (4, the default here).
    """
    name = "TLS Extension - Max Fragment Length"
    fields_desc = [
        ShortEnumField("type", 1, _tls_ext),
        ShortField("len", None),
        ByteEnumField("maxfraglen", 4, {1: "2^9",
                                        2: "2^10",
                                        3: "2^11",
                                        4: "2^12"}),
    ]
236
237
class TLS_Ext_ClientCertURL(TLS_Ext_Unknown):                       # RFC 4366
    """
    client_certificate_url extension; only the type and length are modeled.
    """
    name = "TLS Extension - Client Certificate URL"
    fields_desc = [
        ShortEnumField("type", 2, _tls_ext),
        ShortField("len", None),
    ]
242
243
244_tls_trusted_authority_types = {0: "pre_agreed",
245                                1: "key_sha1_hash",
246                                2: "x509_name",
247                                3: "cert_sha1_hash"}
248
249
class TAPreAgreed(Packet):
    """
    TrustedAuthority of type pre_agreed: nothing follows the 1-byte type.
    """
    name = "Trusted authority - pre_agreed"
    fields_desc = [
        ByteEnumField("idtype", 0, _tls_trusted_authority_types),
    ]

    def guess_payload_class(self, p):
        # Trailing bytes belong to the next list entry, not to a sub-layer.
        return Padding
256
257
class TAKeySHA1Hash(Packet):
    """
    TrustedAuthority of type key_sha1_hash: the 1-byte type followed by a
    20-byte SHA-1 hash.
    """
    name = "Trusted authority - key_sha1_hash"
    fields_desc = [
        ByteEnumField("idtype", 1, _tls_trusted_authority_types),
        StrFixedLenField("id", None, 20),
    ]

    def guess_payload_class(self, p):
        # Trailing bytes belong to the next list entry, not to a sub-layer.
        return Padding
265
266
class TAX509Name(Packet):
    """
    TrustedAuthority of type x509_name: the 1-byte type followed by a
    length-prefixed distinguished name.

    XXX Section 3.4 of RFC 4366. Implement a more specific DNField
    rather than current StrLenField.
    """
    name = "Trusted authority - x509_name"
    fields_desc = [
        ByteEnumField("idtype", 2, _tls_trusted_authority_types),
        FieldLenField("dnlen", None, length_of="dn"),
        StrLenField("dn", "", length_from=lambda pkt: pkt.dnlen),
    ]

    def guess_payload_class(self, p):
        # Trailing bytes belong to the next list entry, not to a sub-layer.
        return Padding
279
280
class TACertSHA1Hash(Packet):
    """
    TrustedAuthority of type cert_sha1_hash: the 1-byte type followed by a
    20-byte SHA-1 hash.
    """
    name = "Trusted authority - cert_sha1_hash"
    fields_desc = [
        ByteEnumField("idtype", 3, _tls_trusted_authority_types),
        StrFixedLenField("id", None, 20),
    ]

    def guess_payload_class(self, p):
        # Trailing bytes belong to the next list entry, not to a sub-layer.
        return Padding
288
289
# TrustedAuthority identifier type -> dissection class (see _TAListField).
_tls_trusted_authority_cls = {
    0: TAPreAgreed,
    1: TAKeySHA1Hash,
    2: TAX509Name,
    3: TACertSHA1Hash,
}
294
295
class _TAListField(PacketListField):
    """
    Specific version that selects the right Trusted Authority (previous TA*)
    class to be used for dissection based on idtype.
    """

    def m2i(self, pkt, m):
        """
        Dissect one TrustedAuthority entry from m.

        The first byte is the identifier type. Unknown types fall back to
        self.cls (Raw, as declared by TLS_Ext_TrustedCAInd).
        """
        first = m[0]
        # Indexing bytes yields an int on Python 3 but a 1-char str on
        # Python 2; calling ord() on the int raised TypeError.
        idtype = first if isinstance(first, int) else ord(first)
        cls = _tls_trusted_authority_cls.get(idtype, self.cls)
        return cls(m)
308
309
class TLS_Ext_TrustedCAInd(TLS_Ext_Unknown):                        # RFC 4366
    """
    trusted_ca_keys extension: a length-prefixed list of trusted
    authorities, each dissected by _TAListField according to its idtype.
    """
    name = "TLS Extension - Trusted CA Indication"
    fields_desc = [
        ShortEnumField("type", 3, _tls_ext),
        ShortField("len", None),
        FieldLenField("talen", None, length_of="ta"),
        _TAListField("ta", [], Raw, length_from=lambda pkt: pkt.talen),
    ]
317
318
class TLS_Ext_TruncatedHMAC(TLS_Ext_Unknown):                       # RFC 4366
    """
    truncated_hmac extension; only the type and length are modeled.
    """
    name = "TLS Extension - Truncated HMAC"
    fields_desc = [
        ShortEnumField("type", 4, _tls_ext),
        ShortField("len", None),
    ]
323
324
class ResponderID(Packet):
    """
    One OCSP ResponderID: an opaque value prefixed by its length.
    """
    name = "Responder ID structure"
    fields_desc = [
        FieldLenField("respidlen", None, length_of="respid"),
        StrLenField("respid", "", length_from=lambda pkt: pkt.respidlen),
    ]

    def guess_payload_class(self, p):
        # Trailing bytes belong to the next list entry, not to a sub-layer.
        return Padding
333
334
class OCSPStatusRequest(Packet):
    """
    This is the structure defined in RFC 6066, not in RFC 6960!

        ResponderID responder_id_list<0..2^16-1>;
        Extensions  request_extensions;    (opaque <0..2^16-1>)

    Both members are preceded by a 2-byte length.
    """
    name = "OCSPStatusRequest structure"
    fields_desc = [
        FieldLenField("respidlen", None, length_of="respid"),
        PacketListField("respid", [], ResponderID,
                        length_from=lambda pkt: pkt.respidlen),
        FieldLenField("reqextlen", None, length_of="reqext"),
        # Bound the DER extensions by reqextlen: a plain PacketField ignored
        # the parsed length and handed every remaining byte to
        # X509_Extensions.
        PacketLenField("reqext", "", X509_Extensions,
                       length_from=lambda pkt: pkt.reqextlen),
    ]

    def guess_payload_class(self, p):
        # Trailing bytes are not a sub-layer; hand them back as Padding.
        return Padding
348
349
350_cert_status_type = {1: "ocsp"}
# CertificateStatusType -> request structure (see _StatusReqField).
_cert_status_req_cls = {
    1: OCSPStatusRequest,
}
352
353
class _StatusReqField(PacketListField):
    """
    PacketListField whose element class is chosen from the status type
    ('stype') already dissected in the enclosing packet.
    """

    def m2i(self, pkt, m):
        # Unknown status types fall back to the field's declared class.
        return _cert_status_req_cls.get(pkt.stype, self.cls)(m)
361
362
class TLS_Ext_CSR(TLS_Ext_Unknown):                                 # RFC 4366
    """
    status_request extension: a 1-byte status type, then the matching
    request structure spanning the rest of the body (len - 1 bytes).
    """
    name = "TLS Extension - Certificate Status Request"
    fields_desc = [
        ShortEnumField("type", 5, _tls_ext),
        ShortField("len", None),
        ByteEnumField("stype", None, _cert_status_type),
        _StatusReqField("req", [], Raw,
                        length_from=lambda pkt: pkt.len - 1),
    ]
370
371
class TLS_Ext_UserMapping(TLS_Ext_Unknown):                         # RFC 4681
    """
    user_mapping extension: a list of 1-byte user mapping types, prefixed by
    a 1-byte length.
    """
    name = "TLS Extension - User Mapping"
    fields_desc = [
        ShortEnumField("type", 6, _tls_ext),
        ShortField("len", None),
        FieldLenField("umlen", None, fmt="B", length_of="um"),
        FieldListField("um", [], ByteField("umtype", 0),
                       length_from=lambda pkt: pkt.umlen),
    ]
380
381
class TLS_Ext_ClientAuthz(TLS_Ext_Unknown):                         # RFC 5878
    """
    XXX Unsupported: only the type and length of client_authz are modeled.
    """
    name = "TLS Extension - Client Authz"
    fields_desc = [
        ShortEnumField("type", 7, _tls_ext),
        ShortField("len", None),
    ]
388
389
class TLS_Ext_ServerAuthz(TLS_Ext_Unknown):                         # RFC 5878
    """
    XXX Unsupported: only the type and length of server_authz are modeled.
    """
    name = "TLS Extension - Server Authz"
    fields_desc = [
        ShortEnumField("type", 8, _tls_ext),
        ShortField("len", None),
    ]
396
397
398_tls_cert_types = {0: "X.509", 1: "OpenPGP"}
399
400
class TLS_Ext_ClientCertType(TLS_Ext_Unknown):      # RFC 6091 (was 5081)
    """
    cert_type extension as sent by the client: a list of 1-byte certificate
    types. RFC 6091 declares it as certificate_types<1..2^8-1>, so the list
    is prefixed by a 1-byte length, not the 2-byte FieldLenField default.
    """
    name = "TLS Extension - Certificate Type (client version)"
    fields_desc = [
        ShortEnumField("type", 9, _tls_ext),
        ShortField("len", None),
        FieldLenField("ctypeslen", None, fmt="B", length_of="ctypes"),
        FieldListField("ctypes", [0, 1],
                       ByteEnumField("certtypes", None, _tls_cert_types),
                       length_from=lambda pkt: pkt.ctypeslen),
    ]
410
411
class TLS_Ext_ServerCertType(TLS_Ext_Unknown):      # RFC 6091 (was 5081)
    """
    cert_type extension as sent by the server: the single selected type.
    """
    name = "TLS Extension - Certificate Type (server version)"
    fields_desc = [
        ShortEnumField("type", 9, _tls_ext),
        ShortField("len", None),
        ByteEnumField("ctype", None, _tls_cert_types),
    ]
417
418
def _TLS_Ext_CertTypeDispatcher(m, *args, **kargs):
    """
    We need to select the correct one on dissection. We use the length for
    that, as 1 for client version would imply an empty list.

    :param m: raw extension bytes (type, 2-byte length, body)
    :return: a TLS_Ext_ServerCertType when the body length is exactly 1,
             otherwise a TLS_Ext_ClientCertType (also used for input too
             short to hold the length, instead of raising struct.error)
    """
    if len(m) >= 4 and struct.unpack("!H", m[2:4])[0] == 1:
        cls = TLS_Ext_ServerCertType
    else:
        cls = TLS_Ext_ClientCertType
    return cls(m, *args, **kargs)
430
431
class TLS_Ext_SupportedGroups(TLS_Ext_Unknown):
    """
    supported_groups extension, formerly 'Supported Elliptic Curves': TLS 1.3
    merged the ECDH and FFDH group selection mechanisms. Holds a
    length-prefixed list of 2-byte named groups.
    """
    name = "TLS Extension - Supported Groups"
    fields_desc = [
        ShortEnumField("type", 10, _tls_ext),
        ShortField("len", None),
        FieldLenField("groupslen", None, length_of="groups"),
        FieldListField("groups", [],
                       ShortEnumField("ng", None, _tls_named_groups),
                       length_from=lambda pkt: pkt.groupslen),
    ]
445
446
class TLS_Ext_SupportedEllipticCurves(TLS_Ext_SupportedGroups):     # RFC 4492
    """
    Alias of TLS_Ext_SupportedGroups under its RFC 4492 name; adds nothing.
    """
449
450
451_tls_ecpoint_format = {0: "uncompressed",
452                       1: "ansiX962_compressed_prime",
453                       2: "ansiX962_compressed_char2"}
454
455
class TLS_Ext_SupportedPointFormat(TLS_Ext_Unknown):                # RFC 4492
    """
    ec_point_formats extension: a list of 1-byte point formats, prefixed by
    a 1-byte length (defaults to [uncompressed]).
    """
    name = "TLS Extension - Supported Point Format"
    fields_desc = [
        ShortEnumField("type", 11, _tls_ext),
        ShortField("len", None),
        FieldLenField("ecpllen", None, fmt="B", length_of="ecpl"),
        FieldListField("ecpl", [0],
                       ByteEnumField("nc", None, _tls_ecpoint_format),
                       length_from=lambda pkt: pkt.ecpllen),
    ]
465
466
class TLS_Ext_SignatureAlgorithms(TLS_Ext_Unknown):                 # RFC 5246
    """
    signature_algorithms extension: a length-prefixed list of
    hash/signature algorithm pairs.
    """
    name = "TLS Extension - Signature Algorithms"
    fields_desc = [
        ShortEnumField("type", 13, _tls_ext),
        ShortField("len", None),
        SigAndHashAlgsLenField("sig_algs_len", None, length_of="sig_algs"),
        SigAndHashAlgsField("sig_algs", [],
                            EnumField("hash_sig", None, _tls_hash_sig),
                            length_from=lambda pkt: pkt.sig_algs_len),
    ]
477
478
class TLS_Ext_Heartbeat(TLS_Ext_Unknown):                           # RFC 6520
    """
    heartbeat extension: a single 1-byte mode (default: peer not allowed to
    send).
    """
    name = "TLS Extension - Heartbeat"
    fields_desc = [
        ShortEnumField("type", 0x0f, _tls_ext),
        ShortField("len", None),
        ByteEnumField("heartbeat_mode", 2,
                      {1: "peer_allowed_to_send",
                       2: "peer_not_allowed_to_send"}),
    ]
486
487
class ProtocolName(Packet):
    """
    One protocol name (ALPN/NPN), prefixed by its 1-byte length.
    """
    name = "Protocol Name"
    fields_desc = [
        FieldLenField("len", None, fmt='B', length_of="protocol"),
        StrLenField("protocol", "", length_from=lambda pkt: pkt.len),
    ]

    def guess_payload_class(self, p):
        # Trailing bytes belong to the next list entry, not to a sub-layer.
        return Padding
496
497
class ProtocolListField(PacketListField):
    """
    PacketListField holding ProtocolName packets, displayed as a one-line
    list.
    """

    def i2repr(self, pkt, x):
        """
        Return "[proto1, proto2, ...]" built from the 'protocol' of each
        packet in x.

        Names are decoded for display. On Python 3, %-formatting the bytes
        produced by b", ".join() rendered "[b'h2, http/1.1']", and a name
        held as str would have made that join raise TypeError.
        """
        names = []
        for p in x or []:
            name = getattr(p, "protocol", None)
            if name is None:
                # e.g. a Raw layer left over from a failed dissection
                names.append(repr(p))
            elif isinstance(name, bytes):
                names.append(name.decode("utf-8", "replace"))
            else:
                names.append(str(name))
        return "[%s]" % ", ".join(names)
502
503
class TLS_Ext_ALPN(TLS_Ext_PrettyPacketList):                       # RFC 7301
    """
    ALPN extension: a length-prefixed list of ProtocolName entries.
    """
    name = "TLS Extension - Application Layer Protocol Negotiation"
    fields_desc = [
        ShortEnumField("type", 0x10, _tls_ext),
        ShortField("len", None),
        FieldLenField("protocolslen", None, length_of="protocols"),
        ProtocolListField("protocols", [], ProtocolName,
                          length_from=lambda pkt: pkt.protocolslen),
    ]
511
512
class TLS_Ext_Padding(TLS_Ext_Unknown):                             # RFC 7685
    """
    padding extension: 'len' is the size of the padding bytes that follow.
    """
    name = "TLS Extension - Padding"
    fields_desc = [
        ShortEnumField("type", 0x15, _tls_ext),
        FieldLenField("len", None, length_of="padding"),
        StrLenField("padding", "", length_from=lambda pkt: pkt.len),
    ]
519
520
class TLS_Ext_EncryptThenMAC(TLS_Ext_Unknown):                      # RFC 7366
    """
    encrypt_then_mac extension; only the type and length are modeled.
    """
    name = "TLS Extension - Encrypt-then-MAC"
    fields_desc = [
        ShortEnumField("type", 0x16, _tls_ext),
        ShortField("len", None),
    ]
525
526
class TLS_Ext_ExtendedMasterSecret(TLS_Ext_Unknown):                # RFC 7627
    """
    extended_master_secret extension; only the type and length are modeled.
    """
    name = "TLS Extension - Extended Master Secret"
    fields_desc = [
        ShortEnumField("type", 0x17, _tls_ext),
        ShortField("len", None),
    ]
531
532
class TLS_Ext_SessionTicket(TLS_Ext_Unknown):                       # RFC 5077
    """
    session_ticket extension. The ticket directly follows the extension
    'len' field: as in most implementations (RFC 5077 updating RFC 4507),
    there is no additional 'ticketlen' field.
    """
    name = "TLS Extension - Session Ticket"
    fields_desc = [
        ShortEnumField("type", 0x23, _tls_ext),
        FieldLenField("len", None, length_of="ticket"),
        StrLenField("ticket", "", length_from=lambda pkt: pkt.len),
    ]
543
544
class TLS_Ext_KeyShare(TLS_Ext_Unknown):
    """
    Placeholder for key_share: only the type and length are modeled here
    (presumably specialized per handshake message elsewhere).
    """
    name = "TLS Extension - Key Share (dummy class)"
    fields_desc = [
        ShortEnumField("type", 0x33, _tls_ext),
        ShortField("len", None),
    ]
549
550
class TLS_Ext_PreSharedKey(TLS_Ext_Unknown):
    """
    Placeholder for pre_shared_key: only the type and length are modeled.
    """
    name = "TLS Extension - Pre Shared Key (dummy class)"
    fields_desc = [
        ShortEnumField("type", 0x29, _tls_ext),
        ShortField("len", None),
    ]
555
556
class TLS_Ext_EarlyDataIndication(TLS_Ext_Unknown):
    """
    early_data extension with an empty body (type and length only).
    """
    name = "TLS Extension - Early Data"
    fields_desc = [
        ShortEnumField("type", 0x2a, _tls_ext),
        ShortField("len", None),
    ]
561
562
class TLS_Ext_EarlyDataIndicationTicket(TLS_Ext_Unknown):
    """
    early_data extension in its ticket form: a 4-byte max_early_data_size.
    """
    name = "TLS Extension - Ticket Early Data Info"
    fields_desc = [
        ShortEnumField("type", 0x2a, _tls_ext),
        ShortField("len", None),
        IntField("max_early_data_size", 0),
    ]
568
569
# early_data class per context. The keys presumably are handshake message
# types (1: ClientHello, 4: NewSessionTicket, 8: EncryptedExtensions);
# confirm against the callers.
_tls_ext_early_data_cls = {
    1: TLS_Ext_EarlyDataIndication,
    4: TLS_Ext_EarlyDataIndicationTicket,
    8: TLS_Ext_EarlyDataIndication,
}
573
574
class TLS_Ext_SupportedVersions(TLS_Ext_Unknown):
    """
    Placeholder for supported_versions: only the type and length are
    modeled. See the _CH/_SH variants for the actual bodies.
    """
    name = "TLS Extension - Supported Versions (dummy class)"
    fields_desc = [
        ShortEnumField("type", 0x2b, _tls_ext),
        ShortField("len", None),
    ]
579
580
class TLS_Ext_SupportedVersion_CH(TLS_Ext_Unknown):
    """
    supported_versions as sent in a ClientHello: a list of 2-byte versions,
    prefixed by a 1-byte length.
    """
    name = "TLS Extension - Supported Versions (for ClientHello)"
    fields_desc = [
        ShortEnumField("type", 0x2b, _tls_ext),
        ShortField("len", None),
        FieldLenField("versionslen", None, fmt='B', length_of="versions"),
        FieldListField("versions", [],
                       ShortEnumField("version", None, _tls_version),
                       length_from=lambda pkt: pkt.versionslen),
    ]
591
592
class TLS_Ext_SupportedVersion_SH(TLS_Ext_Unknown):
    """
    supported_versions as sent in a ServerHello: the single selected
    version.
    """
    name = "TLS Extension - Supported Versions (for ServerHello)"
    fields_desc = [
        ShortEnumField("type", 0x2b, _tls_ext),
        ShortField("len", None),
        ShortEnumField("version", None, _tls_version),
    ]
598
599
# supported_versions class per context. The keys presumably are handshake
# message types (1: ClientHello, 2: ServerHello); confirm against callers.
_tls_ext_supported_version_cls = {
    1: TLS_Ext_SupportedVersion_CH,
    2: TLS_Ext_SupportedVersion_SH,
}
602
603
class TLS_Ext_Cookie(TLS_Ext_Unknown):
    """
    cookie extension: an opaque cookie prefixed by its 2-byte length.
    """
    name = "TLS Extension - Cookie"
    fields_desc = [
        ShortEnumField("type", 0x2c, _tls_ext),
        ShortField("len", None),
        FieldLenField("cookielen", None, length_of="cookie"),
        XStrLenField("cookie", "", length_from=lambda pkt: pkt.cookielen),
    ]

    def build(self):
        """
        Build the extension, first filling an unset or empty cookie with 32
        random bytes.

        NOTE: the generated cookie is stored on the packet, so later builds
        reuse it.
        """
        # A falsy check covers None, b"" and also a str "" value, which
        # compares unequal to b"" on Python 3 and slipped through the
        # previous "fval == b''" test.
        if not self.getfieldval("cookie"):
            self.cookie = os.urandom(32)
        return TLS_Ext_Unknown.build(self)
617
618
619_tls_psk_kx_modes = {0: "psk_ke", 1: "psk_dhe_ke"}
620
621
class TLS_Ext_PSKKeyExchangeModes(TLS_Ext_Unknown):
    """
    psk_key_exchange_modes extension: a list of 1-byte modes, prefixed by a
    1-byte length.
    """
    name = "TLS Extension - PSK Key Exchange Modes"
    fields_desc = [
        ShortEnumField("type", 0x2d, _tls_ext),
        ShortField("len", None),
        FieldLenField("kxmodeslen", None, fmt='B', length_of="kxmodes"),
        FieldListField("kxmodes", [],
                       ByteEnumField("kxmode", None, _tls_psk_kx_modes),
                       length_from=lambda pkt: pkt.kxmodeslen),
    ]
632
633
class TLS_Ext_TicketEarlyDataInfo(TLS_Ext_Unknown):
    """
    Extension type 0x2e carrying a 4-byte max_early_data_size (presumably
    the ticket_early_data_info of a TLS 1.3 draft). It is not registered in
    _tls_ext_cls.
    """
    name = "TLS Extension - Ticket Early Data Info"
    fields_desc = [
        ShortEnumField("type", 0x2e, _tls_ext),
        ShortField("len", None),
        IntField("max_early_data_size", 0),
    ]
639
640
class TLS_Ext_NPN(TLS_Ext_PrettyPacketList):
    """
    Next Protocol Negotiation, from RFC-draft-agl-tls-nextprotoneg-03
    (deprecated in favour of ALPN). The ProtocolName entries directly follow
    'len'; there is no separate list length.
    """
    name = "TLS Extension - Next Protocol Negotiation"
    fields_desc = [
        ShortEnumField("type", 0x3374, _tls_ext),
        FieldLenField("len", None, length_of="protocols"),
        ProtocolListField("protocols", [], ProtocolName,
                          length_from=lambda pkt: pkt.len),
    ]
650
651
class TLS_Ext_PostHandshakeAuth(TLS_Ext_Unknown):                   # RFC 8446
    """
    post_handshake_auth extension; only the type and length are modeled.
    """
    name = "TLS Extension - Post Handshake Auth"
    fields_desc = [
        ShortEnumField("type", 0x31, _tls_ext),
        ShortField("len", None),
    ]
656
657
class TLS_Ext_SignatureAlgorithmsCert(TLS_Ext_Unknown):    # RFC 8446
    """
    signature_algorithms_cert extension: same layout as
    signature_algorithms, a length-prefixed list of algorithm pairs.
    """
    name = "TLS Extension - Signature Algorithms Cert"
    fields_desc = [
        ShortEnumField("type", 0x32, _tls_ext),
        ShortField("len", None),
        SigAndHashAlgsLenField("sig_algs_len", None, length_of="sig_algs"),
        SigAndHashAlgsField("sig_algs", [],
                            EnumField("hash_sig", None, _tls_hash_sig),
                            length_from=lambda pkt: pkt.sig_algs_len),
    ]
668
669
class TLS_Ext_RenegotiationInfo(TLS_Ext_Unknown):                   # RFC 5746
    """
    renegotiation_info extension: renegotiated_connection, prefixed by a
    1-byte length.
    """
    name = "TLS Extension - Renegotiation Indication"
    fields_desc = [
        ShortEnumField("type", 0xff01, _tls_ext),
        ShortField("len", None),
        FieldLenField("reneg_conn_len", None, fmt='B',
                      length_of="renegotiated_connection"),
        StrLenField("renegotiated_connection", "",
                    length_from=lambda pkt: pkt.reneg_conn_len),
    ]
678
679
class TLS_Ext_RecordSizeLimit(TLS_Ext_Unknown):  # RFC 8449
    """
    record_size_limit extension: a single 2-byte limit.
    """
    name = "TLS Extension - Record Size Limit"
    fields_desc = [
        ShortEnumField("type", 0x1c, _tls_ext),
        ShortField("len", None),
        ShortField("record_size_limit", None),
    ]
685
686
# Extension type -> dissection class (a dispatcher function for type 9,
# whose client and server forms differ).
_tls_ext_cls = {
    0x0000: TLS_Ext_ServerName,
    0x0001: TLS_Ext_MaxFragLen,
    0x0002: TLS_Ext_ClientCertURL,
    0x0003: TLS_Ext_TrustedCAInd,
    0x0004: TLS_Ext_TruncatedHMAC,
    0x0005: TLS_Ext_CSR,
    0x0006: TLS_Ext_UserMapping,
    0x0007: TLS_Ext_ClientAuthz,
    0x0008: TLS_Ext_ServerAuthz,
    0x0009: _TLS_Ext_CertTypeDispatcher,
    # 0x000a: TLS_Ext_SupportedEllipticCurves,
    0x000a: TLS_Ext_SupportedGroups,
    0x000b: TLS_Ext_SupportedPointFormat,
    0x000d: TLS_Ext_SignatureAlgorithms,
    0x000f: TLS_Ext_Heartbeat,
    0x0010: TLS_Ext_ALPN,
    0x0015: TLS_Ext_Padding,
    0x0016: TLS_Ext_EncryptThenMAC,
    0x0017: TLS_Ext_ExtendedMasterSecret,
    0x001c: TLS_Ext_RecordSizeLimit,
    0x0023: TLS_Ext_SessionTicket,
    # 0x0028: TLS_Ext_KeyShare,  (presumably the TLS 1.3 draft number)
    0x0029: TLS_Ext_PreSharedKey,
    0x002a: TLS_Ext_EarlyDataIndication,
    0x002b: TLS_Ext_SupportedVersions,
    0x002c: TLS_Ext_Cookie,
    0x002d: TLS_Ext_PSKKeyExchangeModes,
    # 0x002e: TLS_Ext_TicketEarlyDataInfo,
    0x0031: TLS_Ext_PostHandshakeAuth,
    0x0032: TLS_Ext_SignatureAlgorithmsCert,
    0x0033: TLS_Ext_KeyShare,
    # 0x002f: TLS_Ext_CertificateAuthorities,       #XXX
    # 0x0030: TLS_Ext_OIDFilters,                   #XXX
    0x3374: TLS_Ext_NPN,
    0xff01: TLS_Ext_RenegotiationInfo,
    0xffce: TLS_Ext_EncryptedServerName,
}
724
725
class _ExtensionsLenField(FieldLenField):
    """
    Length of the extensions block, which may be omitted entirely before
    TLS 1.3 when there are no extensions.
    """

    def getfield(self, pkt, s):
        """
        We try to compute a length, usually from a msglen parsed earlier.
        If this length is 0, we consider 'selection_present' (from RFC 5246)
        to be False. This means that there should not be any length field.
        However, with TLS 1.3, zero lengths are always explicit.
        """
        ext = pkt.get_field(self.length_of)
        tmp_len = ext.length_from(pkt)
        if tmp_len is None or tmp_len <= 0:
            v = pkt.tls_session.tls_version
            if v is None or v < 0x0304:
                return s, None
        return super(_ExtensionsLenField, self).getfield(pkt, s)

    def addfield(self, pkt, s, i):
        """
        There is a hack with the _ExtensionsField.i2len. It works only because
        we expect _ExtensionsField.i2m to return a string of the same size (if
        not of the same value) upon successive calls (e.g. through i2len here,
        then i2m when directly building the _ExtensionsField).

        XXX A proper way to do this would be to keep the extensions built from
        the i2len call here, instead of rebuilding them later on.
        """
        if i is None and self.length_of is not None:
            fld, fval = pkt.getfield_and_val(self.length_of)
            session = pkt.tls_session
            # Freeze the session while sizing the extensions (presumably so
            # this trial build leaves session state untouched). Restore the
            # previous flag even if i2len raises; it used to stay frozen.
            was_frozen = session.frozen
            session.frozen = True
            try:
                ext_len = fld.i2len(pkt, fval)
            finally:
                session.frozen = was_frozen
            i = self.adjust(pkt, ext_len)
            if i == 0:
                v = session.tls_version
                # Before TLS 1.3 an empty extension block is left out
                # entirely; with TLS 1.3, zero lengths are always explicit.
                if v is None or v < 0x0304:
                    return s
        return s + struct.pack(self.fmt, i)
770
771
class _ExtensionsField(StrLenField):
    """
    List of TLS extensions, serialized back to back as
    type (2 bytes) / length (2 bytes) / data records.
    """
    islist = 1
    holds_packets = 1

    def i2len(self, pkt, i):
        """Return the serialized length of the extension list i."""
        if i is None:
            return 0
        return len(self.i2m(pkt, i))

    def getfield(self, pkt, s):
        """
        Consume length_from(pkt) bytes (0 if unknown) and dissect them
        into a list of extension packets.
        """
        tmp_len = self.length_from(pkt) or 0
        if tmp_len <= 0:
            return s, []
        return s[tmp_len:], self.m2i(pkt, s[:tmp_len])

    def i2m(self, pkt, i):
        """
        Serialize the extension list.

        If the parent packet carries a TLS session that is not frozen, the
        session is propagated to each session-aware extension and
        raw_stateful() is used. Otherwise plain raw() is used.
        """
        if i is None:
            return b""
        if isinstance(pkt, _GenericTLSSessionInheritance) and \
                not pkt.tls_session.frozen:
            parts = []
            for ext in i:
                if isinstance(ext, _GenericTLSSessionInheritance):
                    ext.tls_session = pkt.tls_session
                    parts.append(ext.raw_stateful())
                else:
                    parts.append(raw(ext))
            return b"".join(parts)
        return b"".join(map(raw, i))

    def m2i(self, pkt, m):
        """
        Dissect raw bytes m into a list of extension packets.

        The extension class is chosen from _tls_ext_cls by type. For
        key_share, pre_shared_key, supported_versions and early_data, it is
        then refined according to the parent message type (and, for
        key_share, whether the message is a HelloRetryRequest). Trailing
        bytes too short to hold an extension header are kept as a raw layer
        so that rebuilding the packet does not silently drop them.
        """
        res = []
        # Not every message holding extensions necessarily has these
        # fields; a missing one makes the lookups fall back to
        # TLS_Ext_Unknown instead of raising AttributeError.
        msgtype = getattr(pkt, "msgtype", None)
        while len(m) >= 4:
            t, tmp_len = struct.unpack("!HH", m[:4])
            cls = _tls_ext_cls.get(t, TLS_Ext_Unknown)
            if cls is TLS_Ext_KeyShare:
                # TLS_Ext_KeyShare can be :
                #  - TLS_Ext_KeyShare_CH if the message is a ClientHello
                #  - TLS_Ext_KeyShare_SH if the message is a ServerHello
                #    and all parameters are accepted by the server
                #  - TLS_Ext_KeyShare_HRR if message is a ServerHello and
                #    the client has not provided a sufficient "key_share"
                #    extension
                from scapy.layers.tls.keyexchange_tls13 import (
                    _tls_ext_keyshare_cls, _tls_ext_keyshare_hrr_cls)
                # If SHA-256("HelloRetryRequest") == server_random,
                # this message is a HelloRetryRequest
                random_bytes = getattr(pkt, "random_bytes", None)
                if random_bytes and random_bytes == _tls_hello_retry_magic:
                    cls = _tls_ext_keyshare_hrr_cls.get(msgtype, TLS_Ext_Unknown)  # noqa: E501
                else:
                    cls = _tls_ext_keyshare_cls.get(msgtype, TLS_Ext_Unknown)  # noqa: E501
            elif cls is TLS_Ext_PreSharedKey:
                from scapy.layers.tls.keyexchange_tls13 import _tls_ext_presharedkey_cls  # noqa: E501
                cls = _tls_ext_presharedkey_cls.get(msgtype, TLS_Ext_Unknown)  # noqa: E501
            elif cls is TLS_Ext_SupportedVersions:
                cls = _tls_ext_supported_version_cls.get(msgtype, TLS_Ext_Unknown)  # noqa: E501
            elif cls is TLS_Ext_EarlyDataIndication:
                cls = _tls_ext_early_data_cls.get(msgtype, TLS_Ext_Unknown)
            res.append(cls(m[:tmp_len + 4], tls_session=pkt.tls_session))
            m = m[tmp_len + 4:]
        if m:
            # Leftover bytes shorter than an extension header: keep them
            # instead of discarding them, so i2m re-emits them on build.
            res.append(conf.raw_layer(m))
        return res
835