1# This file is part of Scapy
2# Copyright (C) 2007, 2008, 2009 Arnaud Ebalard
3#               2015, 2016, 2017 Maxence Tury
4#               2019 Romain Perez
5# This program is published under a GPLv2 license
6
7"""
8TLS handshake fields & logic.
9
10This module covers the handshake TLS subprotocol, except for the key exchange
11mechanisms which are addressed with keyexchange.py.
12"""
13
14from __future__ import absolute_import
15import math
16import os
17import struct
18
19from scapy.error import log_runtime, warning
20from scapy.fields import (
21    ByteEnumField,
22    ByteField,
23    Field,
24    FieldLenField,
25    IntField,
26    PacketField,
27    PacketListField,
28    ShortEnumField,
29    ShortField,
30    StrFixedLenField,
31    StrLenField,
32    ThreeBytesField,
33    UTCTimeField,
34)
35
36from scapy.compat import hex_bytes, orb, raw
37from scapy.config import conf
38from scapy.modules import six
39from scapy.packet import Packet, Raw, Padding
40from scapy.utils import randstring, repr_hex
41from scapy.layers.x509 import OCSP_Response
42from scapy.layers.tls.cert import Cert
43from scapy.layers.tls.basefields import (_tls_version, _TLSVersionField,
44                                         _TLSClientVersionField)
45from scapy.layers.tls.extensions import (_ExtensionsLenField, _ExtensionsField,
46                                         _cert_status_type,
47                                         TLS_Ext_SupportedVersion_CH,
48                                         TLS_Ext_SignatureAlgorithms,
49                                         TLS_Ext_SupportedVersion_SH,
50                                         TLS_Ext_EarlyDataIndication,
51                                         _tls_hello_retry_magic,
52                                         TLS_Ext_ExtendedMasterSecret,
53                                         TLS_Ext_EncryptThenMAC)
54from scapy.layers.tls.keyexchange import (_TLSSignature, _TLSServerParamsField,
55                                          _TLSSignatureField, ServerRSAParams,
56                                          SigAndHashAlgsField, _tls_hash_sig,
57                                          SigAndHashAlgsLenField)
58from scapy.layers.tls.session import (_GenericTLSSessionInheritance,
59                                      readConnState, writeConnState)
60from scapy.layers.tls.keyexchange_tls13 import TLS_Ext_PreSharedKey_CH
61from scapy.layers.tls.crypto.compression import (_tls_compression_algs,
62                                                 _tls_compression_algs_cls,
63                                                 Comp_NULL, _GenericComp,
64                                                 _GenericCompMetaclass)
65from scapy.layers.tls.crypto.suites import (_tls_cipher_suites,
66                                            _tls_cipher_suites_cls,
67                                            _GenericCipherSuite,
68                                            _GenericCipherSuiteMetaclass)
69from scapy.layers.tls.crypto.hkdf import TLS13_HKDF
70
71if conf.crypto_valid:
72    from cryptography.hazmat.backends import default_backend
73    from cryptography.hazmat.primitives import hashes
74
75
76###############################################################################
77#   Generic TLS Handshake message                                             #
78###############################################################################
79
80_tls_handshake_type = {0: "hello_request", 1: "client_hello",
81                       2: "server_hello", 3: "hello_verify_request",
82                       4: "session_ticket", 6: "hello_retry_request",
83                       8: "encrypted_extensions", 11: "certificate",
84                       12: "server_key_exchange", 13: "certificate_request",
85                       14: "server_hello_done", 15: "certificate_verify",
86                       16: "client_key_exchange", 20: "finished",
87                       21: "certificate_url", 22: "certificate_status",
88                       23: "supplemental_data", 24: "key_update"}
89
90
class _TLSHandshake(_GenericTLSSessionInheritance):
    """
    Common base of the Handshake message classes, which inherit its
    post_build() and session bookkeeping. Also used as a fallback for
    unknown TLS Handshake packets.
    """
    name = "TLS Handshake Generic message"
    fields_desc = [ByteEnumField("msgtype", None, _tls_handshake_type),
                   ThreeBytesField("msglen", None),
                   StrLenField("msg", "",
                               length_from=lambda pkt: pkt.msglen)]

    def post_build(self, p, pay):
        """
        When msglen was left to None, rewrite the 4-byte header so that the
        3 bytes after msgtype hold the size of the message body, then append
        the payload.
        """
        if self.msglen is not None:
            return p + pay
        body_len = len(p) - 4
        # keep msgtype in the top byte, body length in the lower 24 bits
        header = struct.pack("!I", (orb(p[0]) << 24) | body_len)
        return header + p[4:] + pay

    def guess_payload_class(self, p):
        # Whatever follows this message is handed to the padding layer.
        return conf.padding_layer

    def tls_session_update(self, msg_str):
        """
        Covers both post_build- and post_dissection- context updates:
        the raw message and the packet itself are appended to the session's
        handshake message lists.
        """
        session = self.tls_session
        session.handshake_messages.append(msg_str)
        session.handshake_messages_parsed.append(self)
119
120
121###############################################################################
122#   HelloRequest                                                              #
123###############################################################################
124
class TLSHelloRequest(_TLSHandshake):
    """HelloRequest: a handshake message with no body (msgtype 0)."""
    name = "TLS Handshake - Hello Request"
    fields_desc = [ByteEnumField("msgtype", 0, _tls_handshake_type),
                   ThreeBytesField("msglen", None)]

    def tls_session_update(self, msg_str):
        """
        Deliberately does nothing: this message must not be added to the
        list of handshake messages hashed for the Finished and
        CertificateVerify messages.
        """
        return None
136
137
138###############################################################################
139#   ClientHello fields                                                        #
140###############################################################################
141
class _GMTUnixTimeField(UTCTimeField):
    """
    "The current time and date in standard UNIX 32-bit format (seconds since
     the midnight starting Jan 1, 1970, GMT, ignoring leap seconds)."
    """

    def i2h(self, pkt, x):
        # An unset (None) value is presented as 0.
        return 0 if x is None else x
152
153
class _TLSRandomBytesField(StrFixedLenField):
    """Fixed-length random bytes, displayed as hexadecimal."""

    def i2repr(self, pkt, x):
        if x is not None:
            return repr_hex(self.i2h(pkt, x))
        return repr(x)
159
160
class _SessionIDField(StrLenField):
    """
    opaque SessionID<0..32>; section 7.4.1.2 of RFC 4346

    A plain StrLenField subclass with no overrides.
    """
166
167
class _CipherSuitesField(StrLenField):
    """
    List of fixed-size integers (cipher suite identifiers by default).

    The internal value is a list of ints, serialized back to back with
    ``itemfmt``. On assignment, each item may be an int, a cipher suite
    class or instance (its ``val`` is used), or a name found in ``dico``
    given either as text or as bytes.
    """
    __slots__ = ["itemfmt", "itemsize", "i2s", "s2i"]
    islist = 1

    def __init__(self, name, default, dico, length_from=None, itemfmt="!H"):
        StrLenField.__init__(self, name, default, length_from=length_from)
        self.itemfmt = itemfmt
        self.itemsize = struct.calcsize(itemfmt)
        # value -> name and name -> value lookup tables
        i2s = self.i2s = {}
        s2i = self.s2i = {}
        for k in six.iterkeys(dico):
            i2s[k] = dico[k]
            s2i[dico[k]] = k

    def any2i_one(self, pkt, x):
        """Convert one user-supplied item to its integer value."""
        if isinstance(x, (_GenericCipherSuite, _GenericCipherSuiteMetaclass)):
            x = x.val
        if isinstance(x, bytes) and not isinstance(x, str):
            # Python 3: the names in s2i are text, so decode bytes first
            # (latin-1 never fails; unknown names still raise KeyError).
            x = x.decode("latin-1")
        if isinstance(x, six.string_types):
            x = self.s2i[x]
        return x

    def i2repr_one(self, pkt, x):
        """Name of a known value, else its zero-padded hexadecimal form."""
        # two hex digits per byte of an item
        fmt = "0x%%0%dx" % (2 * self.itemsize)
        return self.i2s.get(x, fmt % x)

    def any2i(self, pkt, x):
        if x is None:
            return None
        if not isinstance(x, (list, tuple)):
            x = [x]
        return [self.any2i_one(pkt, z) for z in x]

    def i2repr(self, pkt, x):
        if x is None:
            return "None"
        reprs = [self.i2repr_one(pkt, z) for z in x]
        if len(reprs) == 1:
            return reprs[0]
        return "[%s]" % ", ".join(reprs)

    def i2m(self, pkt, val):
        if val is None:
            val = []
        return b"".join(struct.pack(self.itemfmt, x) for x in val)

    def m2i(self, pkt, m):
        # A trailing partial item makes struct.unpack raise struct.error.
        size = self.itemsize
        return [struct.unpack(self.itemfmt, m[i:i + size])[0]
                for i in range(0, len(m), size)]

    def i2len(self, pkt, i):
        if i is None:
            return 0
        return len(i) * self.itemsize
227
228
class _CompressionMethodsField(_CipherSuitesField):
    """
    Same list field as _CipherSuitesField, for compression methods: items
    may be ints, compression classes/instances (their ``val`` is used) or
    method names given as text or bytes.
    """

    def any2i_one(self, pkt, x):
        if isinstance(x, (_GenericComp, _GenericCompMetaclass)):
            x = x.val
        if isinstance(x, bytes) and not isinstance(x, str):
            # Python 3: the names in s2i are text, so decode bytes first
            # (latin-1 never fails; unknown names still raise KeyError).
            x = x.decode("latin-1")
        if isinstance(x, six.string_types):
            x = self.s2i[x]
        return x
237
238
239###############################################################################
240#   ClientHello                                                               #
241###############################################################################
242
class TLSClientHello(_TLSHandshake):
    """
    TLS ClientHello, with abilities to handle extensions.

    The Random structure follows the RFC 5246: while it is 32-byte long,
    many implementations use the first 4 bytes as a gmt_unix_time, and then
    the remaining 28 bytes should be completely random. This was designed in
    order to (sort of) mitigate broken RNGs. If you prefer to show the full
    32 random bytes without any GMT time, just comment in/out the lines below.
    """
    name = "TLS Handshake - Client Hello"
    fields_desc = [ByteEnumField("msgtype", 1, _tls_handshake_type),
                   ThreeBytesField("msglen", None),
                   _TLSClientVersionField("version", None, _tls_version),

                   # _TLSRandomBytesField("random_bytes", None, 32),
                   _GMTUnixTimeField("gmt_unix_time", None),
                   _TLSRandomBytesField("random_bytes", None, 28),

                   FieldLenField("sidlen", None, fmt="B", length_of="sid"),
                   _SessionIDField("sid", "",
                                   length_from=lambda pkt: pkt.sidlen),

                   FieldLenField("cipherslen", None, fmt="!H",
                                 length_of="ciphers"),
                   _CipherSuitesField("ciphers", None,
                                      _tls_cipher_suites, itemfmt="!H",
                                      length_from=lambda pkt: pkt.cipherslen),

                   FieldLenField("complen", None, fmt="B", length_of="comp"),
                   _CompressionMethodsField("comp", [0],
                                            _tls_compression_algs,
                                            itemfmt="B",
                                            length_from=lambda pkt: pkt.complen),  # noqa: E501

                   _ExtensionsLenField("extlen", None, length_of="ext"),
                   _ExtensionsField("ext", None,
                                    length_from=lambda pkt: (pkt.msglen -
                                                             (pkt.sidlen or 0) -  # noqa: E501
                                                             (pkt.cipherslen or 0) -  # noqa: E501
                                                             (pkt.complen or 0) -  # noqa: E501
                                                             40))]

    def post_build(self, p, pay):
        """
        Fill in the random bytes and, when no cipher suites were given, a
        default list of cipher suites plus (if no extensions were given
        either) a default set of extensions. msglen is then computed by
        _TLSHandshake.post_build().
        """
        if self.random_bytes is None:
            # msgtype (1) + msglen (3) + version (2) + gmt_unix_time (4)
            p = p[:10] + randstring(28) + p[10 + 28:]

        # if no ciphersuites were provided, we add a few usual, supported
        # ciphersuites along with the appropriate extensions
        if self.ciphers is None:
            # The session id starts at offset 39, right after its 1-byte
            # length. Count the sid bytes actually emitted: sidlen is None
            # when it is left to be computed, and relying on it would then
            # misplace the cipher suites whenever a sid is present.
            cipherstart = 39 + len(self.sid or b"")
            # 2-byte length (0x001a = 26) followed by 13 cipher suites
            s = b"001ac02bc023c02fc027009e0067009c003cc009c0130033002f000a"
            # replaces the 2-byte zero cipherslen emitted for ciphers=None
            p = p[:cipherstart] + hex_bytes(s) + p[cipherstart + 2:]
            if self.ext is None:
                # total extensions length (0x2c = 44), then renegotiation
                # info, server name (secdev.org), signature algorithms and
                # supported groups
                ext_len = b'\x00\x2c'
                ext_reneg = b'\xff\x01\x00\x01\x00'
                ext_sn = b'\x00\x00\x00\x0f\x00\r\x00\x00\nsecdev.org'
                ext_sigalg = b'\x00\r\x00\x08\x00\x06\x04\x03\x04\x01\x02\x01'
                ext_supgroups = b'\x00\n\x00\x04\x00\x02\x00\x17'
                p += ext_len + ext_reneg + ext_sn + ext_sigalg + ext_supgroups

        return super(TLSClientHello, self).post_build(p, pay)

    def tls_session_update(self, msg_str):
        """
        Either for parsing or building, we store the client_random
        along with the raw string representing this handshake message.
        The advertised version and signature algorithms are also taken from
        the supported_versions and signature_algorithms extensions.
        """
        super(TLSClientHello, self).tls_session_update(msg_str)
        s = self.tls_session
        s.advertised_tls_version = self.version
        # This ClientHello could be a 1.3 one. Let's store the sid
        # in all cases. sidlen is None when it was computed at build time,
        # hence the fallback on the sid itself.
        sidlen = self.sidlen
        if sidlen is None:
            sidlen = len(self.sid or b"")
        if sidlen > 0:
            s.sid = self.sid
        self.random_bytes = msg_str[10:38]
        # The 32-byte Random (gmt_unix_time + random_bytes) starts at offset
        # 6, after msgtype, msglen and version. Read it from the raw message
        # so that it always matches the serialized bytes: rebuilding it from
        # gmt_unix_time is wrong when that field was left to None, since
        # _GMTUnixTimeField.i2h() reports None as 0.
        s.client_random = msg_str[6:38]

        # No distinction between a TLS 1.2 ClientHello and a TLS
        # 1.3 ClientHello when dissecting : TLS 1.3 CH will be
        # parsed as TLSClientHello
        if self.ext:
            for e in self.ext:
                if isinstance(e, TLS_Ext_SupportedVersion_CH):
                    for ver in sorted(e.versions, reverse=True):
                        # RFC 8701: GREASE of TLS will send unknown versions
                        # here. We have to ignore them
                        if ver in _tls_version:
                            s.advertised_tls_version = ver
                            break
                    if s.sid:
                        s.middlebox_compatibility = True

                if isinstance(e, TLS_Ext_SignatureAlgorithms):
                    s.advertised_sig_algs = e.sig_algs
339
340
class TLS13ClientHello(_TLSHandshake):
    """
    TLS 1.3 ClientHello, with abilities to handle extensions.

    The Random structure is 32 random bytes without any GMT time
    """
    name = "TLS 1.3 Handshake - Client Hello"
    fields_desc = [ByteEnumField("msgtype", 1, _tls_handshake_type),
                   ThreeBytesField("msglen", None),
                   _TLSClientVersionField("version", None, _tls_version),

                   _TLSRandomBytesField("random_bytes", None, 32),

                   FieldLenField("sidlen", None, fmt="B", length_of="sid"),
                   _SessionIDField("sid", "",
                                   length_from=lambda pkt: pkt.sidlen),

                   FieldLenField("cipherslen", None, fmt="!H",
                                 length_of="ciphers"),
                   _CipherSuitesField("ciphers", None,
                                      _tls_cipher_suites, itemfmt="!H",
                                      length_from=lambda pkt: pkt.cipherslen),

                   FieldLenField("complen", None, fmt="B", length_of="comp"),
                   _CompressionMethodsField("comp", [0],
                                            _tls_compression_algs,
                                            itemfmt="B",
                                            length_from=lambda pkt: pkt.complen),  # noqa: E501

                   _ExtensionsLenField("extlen", None, length_of="ext"),
                   _ExtensionsField("ext", None,
                                    length_from=lambda pkt: (pkt.msglen -
                                                             (pkt.sidlen or 0) -  # noqa: E501
                                                             (pkt.cipherslen or 0) -  # noqa: E501
                                                             (pkt.complen or 0) -  # noqa: E501
                                                             40))]

    def post_build(self, p, pay):
        """
        Fill in the random bytes and msglen, then, when a PreSharedKey
        extension is present, recompute its binder over the partial
        ClientHello (RFC 8446, section 4.2.11.2).
        """
        if self.random_bytes is None:
            # msgtype (1) + msglen (3) + version (2)
            p = p[:6] + randstring(32) + p[6 + 32:]
        # We don't call the post_build function from class _TLSHandshake
        # to compute the message length because we need that value now
        # for the HMAC in binder
        tmp_len = len(p)
        if self.msglen is None:
            sz = tmp_len - 4
            p = struct.pack("!I", (orb(p[0]) << 24) | sz) + p[4:]
        s = self.tls_session
        if self.ext:
            for e in self.ext:
                if isinstance(e, TLS_Ext_PreSharedKey_CH):
                    if s.client_session_ticket:
                        # For a resumed PSK, the hash function use
                        # to compute the binder must be the same
                        # as the one used to establish the original
                        # conntection. For that, we assume that
                        # the ciphersuite associate with the ticket
                        # is given as argument to tlsSession
                        # (see layers/tls/automaton_cli.py for an
                        # example)
                        res_suite = s.tls13_ticket_ciphersuite
                        cs_cls = _tls_cipher_suites_cls[res_suite]
                        hkdf = TLS13_HKDF(cs_cls.hash_alg.name.lower())
                        hash_len = hkdf.hash.digest_size
                        s.compute_tls13_early_secrets(external=False)
                    else:
                        # For out of band PSK, SHA-256 is used as default
                        # hash functions for HKDF
                        hkdf = TLS13_HKDF("sha256")
                        hash_len = hkdf.hash.digest_size
                        s.compute_tls13_early_secrets(external=True)

                    # RFC8446 4.2.11.2
                    # "Each entry in the binders list is computed as an HMAC
                    # over a transcript hash (see Section 4.4.1) containing a
                    # partial ClientHello up to and including the
                    # PreSharedKeyExtension.identities field."
                    # PSK Binders field is :
                    #   - PSK Binders length (2 bytes)
                    #   - First PSK Binder length (1 byte) +
                    #         HMAC (hash_len bytes)
                    # The PSK Binder is computed in the same way as the
                    # Finished message with binder_key as BaseKey
                    # The slicing below assumes the PSK extension is the
                    # last extension and carries a single binder.

                    handshake_context = b""
                    if s.tls13_retry:
                        for m in s.handshake_messages:
                            handshake_context += m
                    handshake_context += p[:-hash_len - 3]

                    binder_key = s.tls13_derived_secrets["binder_key"]
                    psk_binder = hkdf.compute_verify_data(binder_key,
                                                          handshake_context)

                    # Here, we replace the last hash_len bytes of the packet
                    # by the new HMAC value computed over the ClientHello
                    # (without the binders)
                    p = p[:-hash_len] + psk_binder

        return p + pay

    def tls_session_update(self, msg_str):
        """
        Either for parsing or building, we store the client_random
        along with the raw string representing this handshake message.
        The advertised version and signature algorithms are also taken from
        the supported_versions and signature_algorithms extensions.
        """
        super(TLS13ClientHello, self).tls_session_update(msg_str)
        s = self.tls_session

        # sidlen is None when it was computed at build time, hence the
        # fallback on the sid itself.
        sidlen = self.sidlen
        if sidlen is None:
            sidlen = len(self.sid or b"")
        if sidlen > 0:
            s.sid = self.sid
            s.middlebox_compatibility = True

        # There is no gmt_unix_time here: the whole 32-byte Random starts
        # at offset 6, right after msgtype (1), msglen (3) and version (2).
        self.random_bytes = msg_str[6:38]
        s.client_random = self.random_bytes
        if self.ext:
            for e in self.ext:
                if isinstance(e, TLS_Ext_SupportedVersion_CH):
                    for ver in sorted(e.versions, reverse=True):
                        # RFC 8701: GREASE of TLS will send unknown versions
                        # here. We have to ignore them
                        if ver in _tls_version:
                            s.advertised_tls_version = ver
                            break
                if isinstance(e, TLS_Ext_SignatureAlgorithms):
                    s.advertised_sig_algs = e.sig_algs
467
468
469###############################################################################
470#   ServerHello                                                               #
471###############################################################################
472
473
class TLSServerHello(_TLSHandshake):
    """
    TLS ServerHello, with abilities to handle extensions.

    The Random structure follows the RFC 5246: while it is 32-byte long,
    many implementations use the first 4 bytes as a gmt_unix_time, and then
    the remaining 28 bytes should be completely random. This was designed in
    order to (sort of) mitigate broken RNGs. If you prefer to show the full
    32 random bytes without any GMT time, just comment in/out the lines below.
    """
    name = "TLS Handshake - Server Hello"
    fields_desc = [ByteEnumField("msgtype", 2, _tls_handshake_type),
                   ThreeBytesField("msglen", None),
                   _TLSVersionField("version", None, _tls_version),

                   # _TLSRandomBytesField("random_bytes", None, 32),
                   _GMTUnixTimeField("gmt_unix_time", None),
                   _TLSRandomBytesField("random_bytes", None, 28),

                   FieldLenField("sidlen", None, length_of="sid", fmt="B"),
                   _SessionIDField("sid", "",
                                   length_from=lambda pkt: pkt.sidlen),

                   ShortEnumField("cipher", None, _tls_cipher_suites),
                   _CompressionMethodsField("comp", [0],
                                            _tls_compression_algs,
                                            itemfmt="B",
                                            length_from=lambda pkt: 1),

                   _ExtensionsLenField("extlen", None, length_of="ext"),
                   _ExtensionsField("ext", None,
                                    length_from=lambda pkt: (
                                        pkt.msglen - (pkt.sidlen or 0) - 40
                                    ))]

    @classmethod
    def dispatch_hook(cls, _pkt=None, *args, **kargs):
        # The version field (offset 4) selects the TLS 1.3 class for 0x0304
        # and for values above 0x7f00.
        if _pkt and len(_pkt) >= 6:
            version = struct.unpack("!H", _pkt[4:6])[0]
            if version == 0x0304 or version > 0x7f00:
                return TLS13ServerHello
        return TLSServerHello

    def post_build(self, p, pay):
        if self.random_bytes is None:
            # msgtype (1) + msglen (3) + version (2) + gmt_unix_time (4)
            p = p[:10] + randstring(28) + p[10 + 28:]
        return super(TLSServerHello, self).post_build(p, pay)

    def tls_session_update(self, msg_str):
        """
        Either for parsing or building, we store the server_random
        along with the raw string representing this handshake message.
        We also store the session_id, the cipher suite (if recognized),
        the compression method, and finally we instantiate the pending write
        and read connection states. Usually they get updated later on in the
        negotiation when we learn the session keys, and eventually they
        are committed once a ChangeCipherSpec has been sent/received.
        """
        super(TLSServerHello, self).tls_session_update(msg_str)

        s = self.tls_session
        s.tls_version = self.version
        if hasattr(self, 'gmt_unix_time'):
            # RFC 5246 layout: the 28 random bytes follow the 4-byte
            # gmt_unix_time.
            self.random_bytes = msg_str[10:38]
        # Whatever the layout (this method is also reached from
        # TLS13ServerHello), the full 32-byte Random sits right after
        # msgtype (1), msglen (3) and version (2). Reading it from the raw
        # message keeps it equal to the serialized bytes, even when
        # gmt_unix_time or random_bytes were left to None at build time.
        s.server_random = msg_str[6:38]
        s.sid = self.sid

        if self.ext:
            for e in self.ext:
                if isinstance(e, TLS_Ext_ExtendedMasterSecret):
                    s.extms = True
                if isinstance(e, TLS_Ext_EncryptThenMAC):
                    s.encrypt_then_mac = True

        cs_cls = None
        if self.cipher:
            cs_val = self.cipher
            if cs_val not in _tls_cipher_suites_cls:
                warning("Unknown cipher suite %d from ServerHello", cs_val)
                # we do not try to set a default nor stop the execution
            else:
                cs_cls = _tls_cipher_suites_cls[cs_val]

        comp_cls = Comp_NULL
        if self.comp:
            comp_val = self.comp[0]
            if comp_val not in _tls_compression_algs_cls:
                err = "Unknown compression alg %d from ServerHello"
                warning(err, comp_val)
                comp_val = 0
            comp_cls = _tls_compression_algs_cls[comp_val]

        connection_end = s.connection_end
        s.pwcs = writeConnState(ciphersuite=cs_cls,
                                compression_alg=comp_cls,
                                connection_end=connection_end,
                                tls_version=self.version)
        s.prcs = readConnState(ciphersuite=cs_cls,
                               compression_alg=comp_cls,
                               connection_end=connection_end,
                               tls_version=self.version)
578
579
# Fields shared by TLS13ServerHello and TLS13HelloRetryRequest: version,
# 32-byte random, session id echo, cipher suite, compression method and
# extensions.
_tls_13_server_hello_fields = [
    ByteEnumField("msgtype", 2, _tls_handshake_type),
    ThreeBytesField("msglen", None),
    _TLSVersionField("version", 0x0303, _tls_version),
    _TLSRandomBytesField("random_bytes", None, 32),
    FieldLenField("sidlen", None, length_of="sid", fmt="B"),
    _SessionIDField("sid", "",
                    length_from=lambda pkt: pkt.sidlen),
    ShortEnumField("cipher", None, _tls_cipher_suites),
    _CompressionMethodsField("comp", [0],
                             _tls_compression_algs,
                             itemfmt="B",
                             length_from=lambda pkt: 1),
    _ExtensionsLenField("extlen", None, length_of="ext"),
    # The extensions fill the rest of the body after version (2) +
    # random (32) + sidlen (1) + sid + cipher (2) + comp (1) + extlen (2),
    # i.e. msglen - sidlen - 40, as in TLSServerHello. Ignoring the sid and
    # those 2 extra bytes would make the extensions swallow the start of any
    # handshake message following this one.
    _ExtensionsField("ext", None,
                     length_from=lambda pkt: (pkt.msglen -
                                              (pkt.sidlen or 0) -
                                              40))
]
598
599
class TLS13ServerHello(TLSServerHello):
    """ TLS 1.3 ServerHello """
    name = "TLS 1.3 Handshake - Server Hello"
    fields_desc = _tls_13_server_hello_fields

    # ServerHello and HelloRetryRequest has the same structure and the same
    # msgId. We need to check the server_random value to determine which it is.
    @classmethod
    def dispatch_hook(cls, _pkt=None, *args, **kargs):
        if _pkt and len(_pkt) >= 38:
            # If SHA-256("HelloRetryRequest") == server_random,
            # this message is a HelloRetryRequest
            random_bytes = _pkt[6:38]
            if random_bytes == _tls_hello_retry_magic:
                return TLS13HelloRetryRequest
        return TLS13ServerHello

    def post_build(self, p, pay):
        if self.random_bytes is None:
            # msgtype (1) + msglen (3) + version (2)
            p = p[:6] + randstring(32) + p[6 + 32:]
        return super(TLS13ServerHello, self).post_build(p, pay)

    def tls_session_update(self, msg_str):
        """
        Either for parsing or building, we store the server_random along with
        the raw string representing this handshake message. We also store the
        cipher suite (if recognized), and finally we instantiate the write and
        read connection states.
        """
        s = self.tls_session
        # The 32-byte Random follows msgtype (1), msglen (3) and version (2).
        # When it was generated by post_build(), random_bytes is still None:
        # recover it from the raw message so that the session (and the
        # TLS 1.2 fallback below, which reads random_bytes) get real bytes.
        if self.random_bytes is None:
            self.random_bytes = msg_str[6:38]
        s.server_random = msg_str[6:38]
        s.ciphersuite = self.cipher
        s.tls_version = self.version
        # Check extensions
        if self.ext:
            for e in self.ext:
                if isinstance(e, TLS_Ext_SupportedVersion_SH):
                    s.tls_version = e.version
                    break

        if s.tls_version < 0x304:
            # This means that the server does not support TLS 1.3 and ignored
            # the initial TLS 1.3 ClientHello. tls_version has been updated
            return TLSServerHello.tls_session_update(self, msg_str)
        else:
            _TLSHandshake.tls_session_update(self, msg_str)

        cs_cls = None
        if self.cipher:
            cs_val = self.cipher
            if cs_val not in _tls_cipher_suites_cls:
                warning("Unknown cipher suite %d from ServerHello", cs_val)
                # we do not try to set a default nor stop the execution
            else:
                cs_cls = _tls_cipher_suites_cls[cs_val]

        connection_end = s.connection_end
        if connection_end == "server":
            s.pwcs = writeConnState(ciphersuite=cs_cls,
                                    connection_end=connection_end,
                                    tls_version=s.tls_version)

            if not s.middlebox_compatibility:
                s.triggered_pwcs_commit = True
        elif connection_end == "client":

            s.prcs = readConnState(ciphersuite=cs_cls,
                                   connection_end=connection_end,
                                   tls_version=s.tls_version)
            if not s.middlebox_compatibility:
                s.triggered_prcs_commit = True

        if s.tls13_early_secret is None:
            # In case the connState was not pre-initialized, we could not
            # compute the early secrets at the ClientHello, so we do it here.
            s.compute_tls13_early_secrets()
        s.compute_tls13_handshake_secrets()
        if connection_end == "server":
            shts = s.tls13_derived_secrets["server_handshake_traffic_secret"]
            s.pwcs.tls13_derive_keys(shts)
        elif connection_end == "client":
            shts = s.tls13_derived_secrets["server_handshake_traffic_secret"]
            s.prcs.tls13_derive_keys(shts)
683
684
685###############################################################################
686#   HelloRetryRequest                                                         #
687###############################################################################
688
class TLS13HelloRetryRequest(_TLSHandshake):
    """
    TLS 1.3 HelloRetryRequest: laid out like a TLS 1.3 ServerHello, with
    _tls_hello_retry_magic as its random value.
    """
    name = "TLS 1.3 Handshake - Hello Retry Request"

    fields_desc = _tls_13_server_hello_fields

    def build(self):
        # Default the random to the HelloRetryRequest magic value.
        if self.getfieldval("random_bytes") is None:
            self.random_bytes = _tls_hello_retry_magic
        return _TLSHandshake.build(self)

    def tls_session_update(self, msg_str):
        """
        Flag the session as retrying, reset the client key shares and
        replace the first recorded handshake message (the first ClientHello)
        by a synthetic message_hash message before recording this one.
        """
        s = self.tls_session
        s.tls13_retry = True
        s.tls13_client_pubshares = {}

        # Pick the hash: the resumption ticket's cipher suite when there is
        # a ticket, the cipher suite selected in this message otherwise.
        if s.client_session_ticket:
            suite_val = s.tls13_ticket_ciphersuite
        else:
            suite_val = self.cipher
        suite_cls = _tls_cipher_suites_cls[suite_val]
        hkdf = TLS13_HKDF(suite_cls.hash_alg.name.lower())
        hash_len = hkdf.hash.digest_size

        # message_hash: type 254, 3-byte length, then Hash(ClientHello1)
        # (RFC 8446, section 4.4.1).
        header = struct.pack("!BBBB", 254, 0, 0, hash_len)
        digest = hashes.Hash(hkdf.hash, backend=default_backend())
        digest.update(s.handshake_messages[0])
        s.handshake_messages[0] = header + digest.finalize()
        super(TLS13HelloRetryRequest, self).tls_session_update(msg_str)
724
725
726###############################################################################
727#   EncryptedExtensions                                                       #
728###############################################################################
729
730
class TLSEncryptedExtensions(_TLSHandshake):
    """
    TLS 1.3 EncryptedExtensions handshake message.

    Besides carrying extensions, this message is where the session switches
    the client-to-server direction to the client handshake traffic keys.
    If an early_data extension is present (the server accepted early data),
    that switch is left for after the EndOfEarlyData message instead.
    """
    name = "TLS 1.3 Handshake - Encrypted Extensions"
    fields_desc = [ByteEnumField("msgtype", 8, _tls_handshake_type),
                   ThreeBytesField("msglen", None),
                   _ExtensionsLenField("extlen", None, length_of="ext"),
                   _ExtensionsField("ext", None,
                                    length_from=lambda pkt: pkt.msglen - 2)]

    def _early_data_accepted(self):
        """
        Return True when this message carries an early_data extension
        (TLS_Ext_EarlyDataIndication), i.e. the server accepted early data.
        """
        return any(isinstance(e, TLS_Ext_EarlyDataIndication)
                   for e in (self.ext or []))

    def post_build_tls_session_update(self, msg_str):
        """
        Server side: after building EncryptedExtensions, install a pending
        read state keyed with client_handshake_traffic_secret, unless the
        server accepted early data (keys change after EndOfEarlyData then).
        """
        self.tls_session_update(msg_str)
        s = self.tls_session
        if s.connection_end != "server" or self._early_data_accepted():
            return

        s.prcs = readConnState(ciphersuite=type(s.wcs.ciphersuite),
                               connection_end=s.connection_end,
                               tls_version=s.tls_version)
        chts = s.tls13_derived_secrets["client_handshake_traffic_secret"]
        s.prcs.tls13_derive_keys(chts)

        if not s.middlebox_compatibility:
            # Switch to the new read keys immediately.
            s.rcs = s.prcs
            s.triggered_prcs_commit = False
        else:
            # Keep the new read state pending and flag it for commit
            # (presumably when the peer's ChangeCipherSpec is processed).
            s.triggered_prcs_commit = True

    def post_dissection_tls_session_update(self, msg_str):
        """
        Client side: after dissecting EncryptedExtensions, install a pending
        write state keyed with client_handshake_traffic_secret, unless the
        server accepted early data (keys change after EndOfEarlyData then).
        """
        self.tls_session_update(msg_str)
        s = self.tls_session
        if s.connection_end != "client" or self._early_data_accepted():
            return

        s.pwcs = writeConnState(ciphersuite=type(s.rcs.ciphersuite),
                                connection_end=s.connection_end,
                                tls_version=s.tls_version)
        chts = s.tls13_derived_secrets["client_handshake_traffic_secret"]
        s.pwcs.tls13_derive_keys(chts)

        if not s.middlebox_compatibility:
            # Switch to the new write keys immediately.
            s.wcs = s.pwcs
            s.triggered_pwcs_commit = False
        else:
            # The state replaced here is the pending *write* state, so it is
            # the write-side commit that must be flagged (mirrors the
            # server-side build path, which flags the read-side commit).
            s.triggered_pwcs_commit = True


801###############################################################################
802#   Certificate                                                               #
803###############################################################################
804
805# XXX It might be appropriate to rewrite this mess with basic 3-byte FieldLenField.  # noqa: E501
806
807
class _ASN1CertLenField(FieldLenField):
    """
    Big-endian 3-byte length field (mostly a 3-byte FieldLenField).

    Values are handled with a 4-byte "!I" format: the leading byte is
    dropped when building and re-added as zero when dissecting.
    """
    def __init__(self, name, default, length_of=None, adjust=lambda pkt, x: x):
        self.length_of = length_of
        self.adjust = adjust
        Field.__init__(self, name, default, fmt="!I")

    def i2m(self, pkt, x):
        # An explicit value, or no field to measure: use the value as is.
        if x is not None or self.length_of is None:
            return x
        target_fld, target_val = pkt.getfield_and_val(self.length_of)
        return self.adjust(pkt, target_fld.i2len(pkt, target_val))

    def addfield(self, pkt, s, val):
        four_bytes = struct.pack(self.fmt, self.i2m(pkt, val))
        return s + four_bytes[1:4]

    def getfield(self, pkt, s):
        three_bytes, remain = s[:3], s[3:]
        length = struct.unpack(self.fmt, b"\x00" + three_bytes)[0]
        return remain, self.m2i(pkt, length)
830
831
class _ASN1CertListField(StrLenField):
    """
    List of certificates, each preceded on the wire by a 3-byte length.

    Dissection yields (length, Cert) tuples. When building, the value may
    be None, raw bytes (emitted verbatim, length prefixes included), a
    single Cert, or an iterable whose items are raw bytes, Cert objects or
    (length, Cert-or-bytes) tuples.
    """
    islist = 1

    def i2len(self, pkt, i):
        if i is None:
            return 0
        return len(self.i2m(pkt, i))

    def getfield(self, pkt, s):
        """
        Extract Certs in a loop.
        XXX We should provide safeguards when trying to parse a Cert.
        """
        tmp_len = None
        if self.length_from is not None:
            tmp_len = self.length_from(pkt)

        lst = []
        ret = b""
        m = s
        if tmp_len is not None:
            m, ret = s[:tmp_len], s[tmp_len:]
        while m:
            clen = struct.unpack("!I", b'\x00' + m[:3])[0]
            lst.append((clen, Cert(m[3:3 + clen])))
            m = m[3 + clen:]
        return m + ret, lst

    def i2m(self, pkt, i):
        def i2m_one(i):
            # Raw bytes are emitted verbatim. The previous `str` check only
            # matched bytes on Python 2; on Python 3 raw bytes fell through
            # to the tuple unpacking below and failed.
            if isinstance(i, bytes):
                return i
            if isinstance(i, six.text_type):
                return i.encode("utf-8")
            if isinstance(i, Cert):
                s = i.der
                tmp_len = struct.pack("!I", len(s))[1:4]
                return tmp_len + s

            (tmp_len, s) = i
            if isinstance(s, Cert):
                s = s.der
            return struct.pack("!I", tmp_len)[1:4] + s

        if i is None:
            return b""
        if isinstance(i, (bytes, six.text_type, Cert)):
            return i2m_one(i)
        return b"".join(i2m_one(x) for x in i)

    def any2i(self, pkt, x):
        return x
884
885
class _ASN1CertField(StrLenField):
    """
    A single certificate preceded on the wire by a 3-byte length.

    Dissection yields a (length, Cert) tuple. When building, the value may
    be None, raw bytes (emitted verbatim, length prefix included), a Cert,
    or a (length, Cert-or-bytes) tuple.
    """
    def i2len(self, pkt, i):
        if i is None:
            return 0
        return len(self.i2m(pkt, i))

    def getfield(self, pkt, s):
        tmp_len = None
        if self.length_from is not None:
            tmp_len = self.length_from(pkt)
        ret = b""
        m = s
        if tmp_len is not None:
            m, ret = s[:tmp_len], s[tmp_len:]
        clen = struct.unpack("!I", b'\x00' + m[:3])[0]
        len_cert = (clen, Cert(m[3:3 + clen]))
        m = m[3 + clen:]
        return m + ret, len_cert

    def i2m(self, pkt, i):
        def i2m_one(i):
            # Raw bytes pass through. Text (e.g. this field's "" default on
            # Python 3) is encoded instead of being concatenated to bytes.
            if isinstance(i, bytes):
                return i
            if isinstance(i, six.text_type):
                return i.encode("utf-8")
            if isinstance(i, Cert):
                s = i.der
                tmp_len = struct.pack("!I", len(s))[1:4]
                return tmp_len + s

            (tmp_len, s) = i
            if isinstance(s, Cert):
                s = s.der
            return struct.pack("!I", tmp_len)[1:4] + s

        if i is None:
            return b""
        return i2m_one(i)

    def any2i(self, pkt, x):
        return x
925
926
class TLSCertificate(_TLSHandshake):
    """
    Certificate handshake message (non-TLS 1.3 layout). When the session
    already runs TLS 1.3, dispatch_hook hands over to TLS13Certificate.

    XXX We do not support RFC 5081, i.e. OpenPGP certificates.
    """
    name = "TLS Handshake - Certificate"
    fields_desc = [
        ByteEnumField("msgtype", 11, _tls_handshake_type),
        ThreeBytesField("msglen", None),
        _ASN1CertLenField("certslen", None, length_of="certs"),
        _ASN1CertListField("certs", [],
                           length_from=lambda pkt: pkt.certslen),
    ]

    @classmethod
    def dispatch_hook(cls, _pkt=None, *args, **kargs):
        if not _pkt:
            return TLSCertificate
        session = kargs.get("tls_session", None)
        if session and (session.tls_version or 0) >= 0x0304:
            return TLS13Certificate
        return TLSCertificate

    def post_dissection_tls_session_update(self, msg_str):
        """Store the dissected Certs as the peer's certificate chain."""
        self.tls_session_update(msg_str)
        session = self.tls_session
        peer_chain = [entry[1] for entry in self.certs]
        if session.connection_end == "client":
            session.server_certs = peer_chain
        else:
            session.client_certs = peer_chain
953
954
class _ASN1CertAndExt(_GenericTLSSessionInheritance):
    """
    One entry of a TLS 1.3 Certificate message: a length-prefixed
    certificate followed by its length-prefixed extensions.
    """
    name = "Certificate and Extensions"
    fields_desc = [
        _ASN1CertField("cert", ""),
        FieldLenField("extlen", None, length_of="ext"),
        _ExtensionsField("ext", [], length_from=lambda pkt: pkt.extlen),
    ]

    def extract_padding(self, s):
        # Nothing after the extensions belongs to this entry: hand all
        # remaining bytes back to the enclosing list field.
        payload = b""
        return payload, s
964
965
class _ASN1CertAndExtListField(PacketListField):
    """
    PacketListField whose items are dissected with the TLS session of the
    packet holding the list.
    """
    def m2i(self, pkt, m):
        session = pkt.tls_session
        item_cls = self.cls
        return item_cls(m, tls_session=session)
969
970
class TLS13Certificate(_TLSHandshake):
    """
    TLS 1.3 Certificate handshake message: a certificate request context
    followed by a list of certificate entries with per-entry extensions.
    """
    name = "TLS 1.3 Handshake - Certificate"
    fields_desc = [
        ByteEnumField("msgtype", 11, _tls_handshake_type),
        ThreeBytesField("msglen", None),
        FieldLenField("cert_req_ctxt_len", None, fmt="B",
                      length_of="cert_req_ctxt"),
        StrLenField("cert_req_ctxt", "",
                    length_from=lambda pkt: pkt.cert_req_ctxt_len),
        _ASN1CertLenField("certslen", None, length_of="certs"),
        _ASN1CertAndExtListField("certs", [], _ASN1CertAndExt,
                                 length_from=lambda pkt: pkt.certslen),
    ]

    def post_dissection_tls_session_update(self, msg_str):
        """
        Record the dissected Certs as the peer's chain; an empty list
        leaves the session's certificates untouched.
        """
        self.tls_session_update(msg_str)
        session = self.tls_session
        if not self.certs:
            return
        chain = [entry.cert[1] for entry in self.certs]
        if session.connection_end == "client":
            session.server_certs = chain
        else:
            session.client_certs = chain
994
995
996###############################################################################
997#   ServerKeyExchange                                                         #
998###############################################################################
999
class TLSServerKeyExchange(_TLSHandshake):
    """
    ServerKeyExchange handshake message: key exchange parameters followed
    by a signature over them (except for anonymous key exchanges).
    """
    name = "TLS Handshake - Server Key Exchange"
    fields_desc = [
        ByteEnumField("msgtype", 12, _tls_handshake_type),
        ThreeBytesField("msglen", None),
        _TLSServerParamsField("params", None,
                              length_from=lambda pkt: pkt.msglen),
        _TLSSignatureField("sig", None,
                           length_from=lambda pkt: pkt.msglen - len(pkt.params)),  # noqa: E501
    ]

    def _default_params(self, s):
        """
        Build default 'params' from the pending write state, or Raw() when
        there is none.
        """
        if not s.pwcs:
            return Raw()
        kx = s.pwcs.key_exchange
        if kx.export:
            params = ServerRSAParams(tls_session=s)
        else:
            params = kx.server_kx_msg_cls(b"\x03")(tls_session=s)
        try:
            params.fill_missing()
        except Exception:
            # Best effort: keep partially filled params unless debugging.
            if conf.debug_dissector:
                raise
        return params

    def _default_sig(self, s):
        """
        Sign client_random + server_random + params with the server key,
        or return Raw() when signing is impossible or not needed.
        """
        if not (s.pwcs and s.client_random):
            return Raw()
        if s.pwcs.key_exchange.anonymous:
            return Raw()
        params = self.params
        if params is None:
            params = b""
        to_sign = s.client_random + s.server_random + raw(params)
        signature = _TLSSignature(tls_session=s)
        signature._update_sig(to_sign, s.server_key)
        return signature

    def build(self, *args, **kargs):
        r"""
        Fill in 'params' and 'sig' from the TLS session when left unset.
        This cannot be done in i2m() since that runs on a packet copy.

        'params' follows key_exchange.server_kx_msg_cls, set once a cipher
        suite was received in a previous ServerHello. Usual cases:

        - None: RSA encryption or fixed FF/ECDH. Should never happen, since
          no ServerKeyExchange would be generated in the first place.
        - ServerDHParams: ephemeral FFDH; the argument passed to
          server_kx_msg_cls does not matter.
        - ServerECDH\*Params: ephemeral ECDH. _tls_server_ecdh_cls_guess
          picks one of three classes from the first byte; the b"\03" used
          here selects ServerECDHNamedCurveParams (implicit curves).

        Building Server\*DHParams via .fill_missing() updates the session
        server_kx_privkey accordingly.
        """
        if self.getfieldval("params") is None:
            self.params = self._default_params(self.tls_session)
        if self.getfieldval("sig") is None:
            self.sig = self._default_sig(self.tls_session)
        return _TLSHandshake.build(self, *args, **kargs)

    def post_dissection(self, pkt):
        """
        Log a ServerKeyExchange that the key exchange does not need, then
        verify the signature against the first server certificate when
        enough context is available. The session server_kx_pubkey should
        already have been updated while dissecting Server*DHParams.

        XXX Add a 'fixed_dh' OR condition to the 'anonymous' test.
        """
        s = self.tls_session
        pending = s.prcs
        if pending and pending.key_exchange.no_ske:
            log_runtime.info("TLS: useless ServerKeyExchange [%s]",
                             pkt.firstlayer().summary())
        verifiable = (pending and
                      not pending.key_exchange.anonymous and
                      s.client_random and s.server_random and
                      s.server_certs and len(s.server_certs) > 0)
        if not verifiable:
            return
        signed = s.client_random + s.server_random + raw(self.params)
        if not self.sig._verify_sig(signed, s.server_certs[0]):
            log_runtime.info("TLS: invalid ServerKeyExchange signature [%s]",
                             pkt.firstlayer().summary())
1088
1089
1090###############################################################################
1091#   CertificateRequest                                                        #
1092###############################################################################
1093
1094_tls_client_certificate_types = {1: "rsa_sign",
1095                                 2: "dss_sign",
1096                                 3: "rsa_fixed_dh",
1097                                 4: "dss_fixed_dh",
1098                                 5: "rsa_ephemeral_dh_RESERVED",
1099                                 6: "dss_ephemeral_dh_RESERVED",
1100                                 20: "fortezza_dms_RESERVED",
1101                                 64: "ecdsa_sign",
1102                                 65: "rsa_fixed_ecdh",
1103                                 66: "ecdsa_fixed_ecdh"}
1104
1105
class _CertTypesField(_CipherSuitesField):
    """
    List of certificate type codes. Reuses the _CipherSuitesField list
    handling; callers choose the item size (itemfmt="!B" in
    TLSCertificateRequest).
    """
1108
1109
class _CertAuthoritiesField(StrLenField):
    """
    List of certificate authority names, each a (length, raw bytes) tuple
    with a 2-byte big-endian length prefix on the wire.

    XXX Rework this with proper ASN.1 parsing.
    """
    islist = 1

    def getfield(self, pkt, s):
        tmp_len = self.length_from(pkt)
        return s[tmp_len:], self.m2i(pkt, s[:tmp_len])

    def m2i(self, pkt, m):
        res = []
        # A single trailing byte cannot hold a length prefix and is dropped.
        while len(m) > 1:
            tmp_len = struct.unpack("!H", m[:2])[0]
            if len(m) < tmp_len + 2:
                # Truncated entry: keep the announced length and whatever
                # bytes remain.
                res.append((tmp_len, m[2:]))
                break
            dn = m[2:2 + tmp_len]
            res.append((tmp_len, dn))
            m = m[2 + tmp_len:]
        return res

    def i2m(self, pkt, i):
        # None builds as empty, consistently with i2len() below (previously
        # this crashed when iterating over None).
        if i is None:
            return b""
        return b"".join(struct.pack("!H", x_y[0]) + x_y[1] for x_y in i)

    def addfield(self, pkt, s, val):
        return s + self.i2m(pkt, val)

    def i2len(self, pkt, val):
        if val is None:
            return 0
        else:
            return len(self.i2m(pkt, val))
1143
1144
class TLSCertificateRequest(_TLSHandshake):
    """
    CertificateRequest handshake message (the non-TLS 1.3 layout; see
    TLS13CertificateRequest for 1.3): accepted certificate types, supported
    signature/hash algorithms and acceptable certificate authorities, each
    list preceded by its own length field.
    """
    name = "TLS Handshake - Certificate Request"
    fields_desc = [
        ByteEnumField("msgtype", 13, _tls_handshake_type),
        ThreeBytesField("msglen", None),
        FieldLenField("ctypeslen", None, fmt="B", length_of="ctypes"),
        _CertTypesField("ctypes", [1, 64], _tls_client_certificate_types,
                        itemfmt="!B",
                        length_from=lambda pkt: pkt.ctypeslen),
        SigAndHashAlgsLenField("sig_algs_len", None, length_of="sig_algs"),
        SigAndHashAlgsField("sig_algs", [0x0403, 0x0401, 0x0201],
                            ShortEnumField("hash_sig", None, _tls_hash_sig),
                            length_from=lambda pkt: pkt.sig_algs_len),
        FieldLenField("certauthlen", None, fmt="!H", length_of="certauth"),
        _CertAuthoritiesField("certauth", [],
                              length_from=lambda pkt: pkt.certauthlen),
    ]
1164
1165
class TLS13CertificateRequest(_TLSHandshake):
    """
    TLS 1.3 CertificateRequest: a certificate request context followed by
    extensions.
    """
    name = "TLS 1.3 Handshake - Certificate Request"
    fields_desc = [
        ByteEnumField("msgtype", 13, _tls_handshake_type),
        ThreeBytesField("msglen", None),
        FieldLenField("cert_req_ctxt_len", None, fmt="B",
                      length_of="cert_req_ctxt"),
        StrLenField("cert_req_ctxt", "",
                    length_from=lambda pkt: pkt.cert_req_ctxt_len),
        _ExtensionsLenField("extlen", None, length_of="ext"),
        # Extensions fill the rest of the message: msglen minus the context,
        # its 1-byte length and the remaining 2 bytes (for extlen).
        _ExtensionsField("ext", None,
                         length_from=lambda pkt: (pkt.msglen -
                                                  pkt.cert_req_ctxt_len - 3)),
    ]

1178
1179###############################################################################
1180#   ServerHelloDone                                                           #
1181###############################################################################
1182
1183
class TLSServerHelloDone(_TLSHandshake):
    """ServerHelloDone handshake message: a bare header with no body."""
    name = "TLS Handshake - Server Hello Done"
    fields_desc = [
        ByteEnumField("msgtype", 14, _tls_handshake_type),
        ThreeBytesField("msglen", None),
    ]
1188
1189
1190###############################################################################
1191#   CertificateVerify                                                         #
1192###############################################################################
1193
class TLSCertificateVerify(_TLSHandshake):
    """
    CertificateVerify handshake message: a signature over the handshake
    messages exchanged so far (wrapped in the TLS 1.3 signed-content
    format when running TLS 1.3).
    """
    name = "TLS Handshake - Certificate Verify"
    fields_desc = [
        ByteEnumField("msgtype", 15, _tls_handshake_type),
        ThreeBytesField("msglen", None),
        _TLSSignatureField("sig", None,
                           length_from=lambda pkt: pkt.msglen),
    ]

    @staticmethod
    def _tls13_signed_content(transcript, signer, conn_state):
        """
        Return 64 spaces, the context string naming `signer` ("client" or
        "server"), a zero byte and `transcript` hashed with conn_state.hash.
        """
        if signer == "client":
            context_string = b"TLS 1.3, client CertificateVerify"
        elif signer == "server":
            context_string = b"TLS 1.3, server CertificateVerify"
        return (b"\x20" * 64 + context_string + b"\x00" +
                conn_state.hash.digest(transcript))

    def build(self, *args, **kargs):
        if self.getfieldval("sig") is None:
            s = self.tls_session
            to_sign = b"".join(s.handshake_messages)
            version = s.tls_version
            if version is None:
                version = s.advertised_tls_version
            if version >= 0x0304:
                # We are the signer: our own end names the context string.
                to_sign = self._tls13_signed_content(to_sign,
                                                     s.connection_end, s.wcs)
            self.sig = _TLSSignature(tls_session=s)
            if s.connection_end == "client":
                self.sig._update_sig(to_sign, s.client_key)
            elif s.connection_end == "server":
                # should be TLS 1.3 only
                self.sig._update_sig(to_sign, s.server_key)
        return _TLSHandshake.build(self, *args, **kargs)

    def post_dissection(self, pkt):
        s = self.tls_session
        signed = b"".join(s.handshake_messages)
        version = s.tls_version
        if version is None:
            version = s.advertised_tls_version
        if version >= 0x0304:
            # The peer produced this signature, so its end names the context.
            peer = {"client": "server",
                    "server": "client"}.get(s.connection_end)
            signed = self._tls13_signed_content(signed, peer, s.rcs)

        if s.connection_end == "server":
            peer_certs = s.client_certs
        elif s.connection_end == "client":
            # should be TLS 1.3 only
            peer_certs = s.server_certs
        else:
            return
        if not (peer_certs and len(peer_certs) > 0):
            return
        if not self.sig._verify_sig(signed, peer_certs[0]):
            pkt_info = pkt.firstlayer().summary()
            log_runtime.info("TLS: invalid CertificateVerify signature [%s]", pkt_info)  # noqa: E501
1249
1250
1251###############################################################################
1252#   ClientKeyExchange                                                         #
1253###############################################################################
1254
class _TLSCKExchKeysField(PacketField):
    """
    ClientKeyExchange payload, dissected with the client key exchange
    class of the session's pending read state (or as Raw when unknown).
    """
    __slots__ = ["length_from"]
    holds_packet = 1

    def __init__(self, name, length_from=None):
        self.length_from = length_from
        PacketField.__init__(self, name, None, None)

    def m2i(self, pkt, m):
        """
        Dissect the first length_from(pkt) bytes as the client_kx_msg
        (None, EncryptedPreMasterSecret, ClientDiffieHellmanPublic or
        ClientECDiffieHellmanPublic); the rest becomes Padding. Dissecting
        a key exchange class updates the session context.
        """
        cut = self.length_from(pkt)
        body, trailing = m[:cut], m[cut:]

        session = pkt.tls_session
        pending = session.prcs
        kx = pending.key_exchange if pending else None
        kx_msg_cls = kx.client_kx_msg_cls if kx else None

        if kx_msg_cls is None:
            return Raw(body) / Padding(trailing)
        return kx_msg_cls(body, tls_session=session) / Padding(trailing)
1283
1284
class TLSClientKeyExchange(_TLSHandshake):
    """
    ClientKeyExchange handshake message. Its 'exchkeys' field plays the
    role that 'params' plays in TLSServerKeyExchange.
    """
    name = "TLS Handshake - Client Key Exchange"
    fields_desc = [
        ByteEnumField("msgtype", 16, _tls_handshake_type),
        ThreeBytesField("msglen", None),
        _TLSCKExchKeysField("exchkeys",
                            length_from=lambda pkt: pkt.msglen),
    ]

    def build(self, *args, **kargs):
        if self.getfieldval("exchkeys") is None:
            s = self.tls_session
            if not s.prcs:
                self.exchkeys = Raw()
            else:
                kx_msg_cls = s.prcs.key_exchange.client_kx_msg_cls
                self.exchkeys = kx_msg_cls(tls_session=s)
        return _TLSHandshake.build(self, *args, **kargs)

    def tls_session_update(self, msg_str):
        """
        With Extended Master Secret, hash the handshake messages into the
        session hash and derive the master secret and keys from it.
        """
        super(TLSClientKeyExchange, self).tls_session_update(msg_str)

        s = self.tls_session
        if not s.extms:
            return
        transcript = b''.join(s.handshake_messages)
        # https://tools.ietf.org/html/rfc7627#section-3
        if s.tls_version >= 0x303:
            # TLS 1.2: hash with the same hash as the PRF
            from scapy.layers.tls.crypto.hash import _tls_hash_algs
            prf_hash_cls = _tls_hash_algs.get(s.prcs.prf.hash_name)
            s.session_hash = prf_hash_cls().digest(transcript)
        else:
            # Older versions: MD5 digest followed by SHA-1 digest
            from scapy.layers.tls.crypto.hash import Hash_MD5, Hash_SHA
            md5_part = Hash_MD5().digest(transcript)
            sha_part = Hash_SHA().digest(transcript)
            s.session_hash = md5_part + sha_part
        s.compute_ms_and_derive_keys()
1330
1331
1332###############################################################################
1333#   Finished                                                                  #
1334###############################################################################
1335
class _VerifyDataField(StrLenField):
    """
    Finished verify_data. It has no length prefix: its size follows from
    the TLS version (36 bytes for SSLv3, the handshake hash length for
    TLS 1.3, 12 bytes otherwise).
    """
    def getfield(self, pkt, s):
        session = pkt.tls_session
        # Fall back to the advertised version like TLSFinished does; with
        # no version at all, a None >= int comparison used to raise on
        # Python 3. Default to the 12-byte size in that case.
        tls_version = session.tls_version
        if tls_version is None:
            tls_version = session.advertised_tls_version
        if tls_version == 0x0300:
            sep = 36
        elif tls_version is not None and tls_version >= 0x0304:
            sep = session.rcs.hash.hash_len
        else:
            sep = 12
        return s[sep:], s[:sep]
1345
1346
class TLSFinished(_TLSHandshake):
    """
    Finished handshake message carrying verify_data. In TLS 1.3, handling
    it also installs pending connection states and computes the traffic
    secrets.
    """
    name = "TLS Handshake - Finished"
    fields_desc = [
        ByteEnumField("msgtype", 20, _tls_handshake_type),
        ThreeBytesField("msglen", None),
        _VerifyDataField("vdata", None),
    ]

    def _negotiated_version(self):
        """Session TLS version, falling back to the advertised one."""
        s = self.tls_session
        if s.tls_version is None:
            return s.advertised_tls_version
        return s.tls_version

    def build(self, *args, **kargs):
        if self.getfieldval("vdata") is None:
            s = self.tls_session
            con_end = s.connection_end
            if self._negotiated_version() >= 0x0304:
                self.vdata = s.compute_tls13_verify_data(con_end, "write")
            else:
                transcript = b"".join(s.handshake_messages)
                self.vdata = s.wcs.prf.compute_verify_data(con_end, "write",
                                                           transcript,
                                                           s.master_secret)
        return _TLSHandshake.build(self, *args, **kargs)

    def post_dissection(self, pkt):
        s = self.tls_session
        if s.frozen:
            return
        version = self._negotiated_version()
        con_end = s.connection_end
        if version < 0x0304 and s.master_secret is not None:
            transcript = b"".join(s.handshake_messages)
            expected = s.rcs.prf.compute_verify_data(con_end, "read",
                                                     transcript,
                                                     s.master_secret)
        elif version >= 0x0304:
            expected = s.compute_tls13_verify_data(con_end, "read")
        else:
            # Pre-1.3 without a master secret: nothing to check against.
            return
        if self.vdata != expected:
            pkt_info = pkt.firstlayer().summary()
            log_runtime.info("TLS: invalid Finished received [%s]", pkt_info)  # noqa: E501

    def post_build_tls_session_update(self, msg_str):
        self.tls_session_update(msg_str)
        s = self.tls_session
        if self._negotiated_version() < 0x0304:
            return
        s.pwcs = writeConnState(ciphersuite=type(s.wcs.ciphersuite),
                                connection_end=s.connection_end,
                                tls_version=s.tls_version)
        s.triggered_pwcs_commit = True
        if s.connection_end == "server":
            s.compute_tls13_traffic_secrets()
        elif s.connection_end == "client":
            s.compute_tls13_traffic_secrets_end()
            s.compute_tls13_resumption_secret()

    def post_dissection_tls_session_update(self, msg_str):
        self.tls_session_update(msg_str)
        s = self.tls_session
        if self._negotiated_version() < 0x0304:
            return
        s.prcs = readConnState(ciphersuite=type(s.rcs.ciphersuite),
                               connection_end=s.connection_end,
                               tls_version=s.tls_version)
        s.triggered_prcs_commit = True
        if s.connection_end == "client":
            s.compute_tls13_traffic_secrets()
        elif s.connection_end == "server":
            s.compute_tls13_traffic_secrets_end()
            s.compute_tls13_resumption_secret()
1425
1426
1427# Additional handshake messages
1428
1429###############################################################################
1430#   HelloVerifyRequest                                                        #
1431###############################################################################
1432
class TLSHelloVerifyRequest(_TLSHandshake):
    """
    Defined for DTLS, see RFC 6347.

    The default msgtype is 3 (hello_verify_request). This is also the key
    under which the class is registered in _tls_handshake_cls. The previous
    default of 21 is the certificate_url type.

    NOTE(review): RFC 6347 section 4.2.1 also places a server_version field
    before the cookie. That field is not modeled here.
    """
    name = "TLS Handshake - Hello Verify Request"
    fields_desc = [ByteEnumField("msgtype", 3, _tls_handshake_type),
                   ThreeBytesField("msglen", None),
                   # One-byte length prefix of the opaque cookie.
                   FieldLenField("cookielen", None,
                                 fmt="B", length_of="cookie"),
                   StrLenField("cookie", "",
                               length_from=lambda pkt: pkt.cookielen)]
1444
1445
1446###############################################################################
1447#   CertificateURL                                                            #
1448###############################################################################
1449
1450_tls_cert_chain_types = {0: "individual_certs",
1451                         1: "pkipath"}
1452
1453
class URLAndOptionalHash(Packet):
    """
    One URL entry (with optional hash) of a TLSCertificateURL message.

    The URL has a 2-byte length prefix. hash_present is a one-byte count of
    20-byte hash blocks. When building, it is computed as
    ceil(len(hash) / 20). When dissecting, the hash is read as
    20 * hash_present bytes.
    """
    name = "URLAndOptionHash structure for TLSCertificateURL"
    fields_desc = [
        FieldLenField("urllen", None, length_of="url"),
        StrLenField("url", "", length_from=lambda entry: entry.urllen),
        FieldLenField("hash_present", None, fmt="B", length_of="hash",
                      adjust=lambda entry, n: int(math.ceil(n / 20.))),
        StrLenField("hash", "",
                    length_from=lambda entry: 20 * entry.hash_present),
    ]

    def guess_payload_class(self, p):
        # Bytes after this entry belong to the next one: keep them as Padding.
        return Padding
1467
1468
class TLSCertificateURL(_TLSHandshake):
    """
    CertificateURL handshake message (type 21), defined in RFC 4366.

    The body is a one-byte chain type followed by a length-prefixed list of
    URLAndOptionalHash entries. The PkiPath structure of section 8 is not
    implemented yet.
    """
    name = "TLS Handshake - Certificate URL"
    fields_desc = [
        ByteEnumField("msgtype", 21, _tls_handshake_type),
        ThreeBytesField("msglen", None),
        ByteEnumField("certchaintype", None, _tls_cert_chain_types),
        FieldLenField("uahlen", None, length_of="uah"),
        PacketListField("uah", [], URLAndOptionalHash,
                        length_from=lambda msg: msg.uahlen),
    ]
1480
1481
1482###############################################################################
1483#   CertificateStatus                                                         #
1484###############################################################################
1485
class ThreeBytesLenField(FieldLenField):
    """
    A length field serialized on 3 big-endian bytes.

    Values are packed as a 4-byte "!I" integer, and the most significant
    byte is dropped on output. On input, a zero byte is prepended before
    unpacking.
    """

    def __init__(self, name, default, length_of=None, adjust=lambda pkt, x: x):
        FieldLenField.__init__(self, name, default, length_of=length_of,
                               fmt='!I', adjust=adjust)

    def i2repr(self, pkt, x):
        # i2repr must return text. An unset (None) length is displayed as 0.
        if x is None:
            return "0"
        return repr(self.i2h(pkt, x))

    def addfield(self, pkt, s, val):
        # Keep the 3 low-order bytes of the big-endian 32-bit encoding.
        # NOTE: a value that does not fit in 24 bits loses its top byte.
        return s + struct.pack(self.fmt, self.i2m(pkt, val))[1:4]

    def getfield(self, pkt, s):
        # Widen the 3 wire bytes to 4 so that "!I" can unpack them.
        return s[3:], self.m2i(pkt, struct.unpack(self.fmt, b"\x00" + s[:3])[0])  # noqa: E501
1501
1502
# status_type value -> class used by _StatusField to decode the response.
# Values missing from this table fall back to the field's own class.
_cert_status_cls = {
    1: OCSP_Response,
}
1504
1505
class _StatusField(PacketField):
    """
    PacketField whose decoding class is chosen by the packet's status_type.

    Known status types are looked up in _cert_status_cls. Any other value
    falls back to the class given when the field was declared.
    """

    def m2i(self, pkt, m):
        decoder = _cert_status_cls.get(pkt.status_type, self.cls)
        return decoder(m)
1513
1514
class TLSCertificateStatus(_TLSHandshake):
    """
    CertificateStatus handshake message (type 22).

    responselen is a 3-byte length prefix. The response body is decoded
    according to status_type (see _StatusField), with Raw as the fallback.

    NOTE(review): the response is a plain PacketField, so responselen is
    not used to bound it during dissection.
    """
    name = "TLS Handshake - Certificate Status"
    fields_desc = [
        ByteEnumField("msgtype", 22, _tls_handshake_type),
        ThreeBytesField("msglen", None),
        ByteEnumField("status_type", 1, _cert_status_type),
        ThreeBytesLenField("responselen", None, length_of="response"),
        _StatusField("response", None, Raw),
    ]
1523
1524
1525###############################################################################
1526#   SupplementalData                                                          #
1527###############################################################################
1528
class SupDataEntry(Packet):
    """
    Generic supplemental data entry.

    The entry is a type, a length, and that many bytes of opaque data.
    Trailing bytes are returned as Padding so that entries can follow each
    other inside a PacketListField.
    """
    name = "Supplemental Data Entry - Generic"
    fields_desc = [
        ShortField("sdtype", None),
        FieldLenField("len", None, length_of="data"),
        StrLenField("data", "", length_from=lambda entry: entry.len),
    ]

    def guess_payload_class(self, p):
        return Padding
1538
1539
class UserMappingData(Packet):
    """
    A single user mapping data item.

    The item is a one-byte version, a length, and that many bytes of opaque
    data. Trailing bytes are returned as Padding so that items can be listed
    back to back.
    """
    name = "User Mapping Data"
    fields_desc = [
        ByteField("version", None),
        FieldLenField("len", None, length_of="data"),
        StrLenField("data", "", length_from=lambda item: item.len),
    ]

    def guess_payload_class(self, p):
        return Padding
1549
1550
class SupDataEntryUM(Packet):
    """
    Supplemental data entry carrying a list of UserMappingData items.

    ``len`` covers the ``dlen`` prefix as well as the data, hence its
    ``+ 2`` adjustment. ``dlen`` covers the UserMappingData list alone.
    """
    name = "Supplemental Data Entry - User Mapping"
    fields_desc = [
        ShortField("sdtype", None),
        FieldLenField("len", None, length_of="data",
                      adjust=lambda entry, size: size + 2),
        FieldLenField("dlen", None, length_of="data"),
        PacketListField("data", [], UserMappingData,
                        length_from=lambda entry: entry.dlen),
    ]

    def guess_payload_class(self, p):
        return Padding
1562
1563
class TLSSupplementalData(_TLSHandshake):
    """
    SupplementalData handshake message (type 23).

    The body is a 3-byte length prefix followed by a list of generic
    SupDataEntry items.
    """
    name = "TLS Handshake - Supplemental Data"
    fields_desc = [
        ByteEnumField("msgtype", 23, _tls_handshake_type),
        ThreeBytesField("msglen", None),
        ThreeBytesLenField("sdatalen", None, length_of="sdata"),
        PacketListField("sdata", [], SupDataEntry,
                        length_from=lambda msg: msg.sdatalen),
    ]
1571
1572
1573###############################################################################
1574#   NewSessionTicket                                                          #
1575###############################################################################
1576
class TLSNewSessionTicket(_TLSHandshake):
    """
    NewSessionTicket handshake message (type 4), pre-TLS 1.3 layout.

    XXX When knowing the right secret, we should be able to read the ticket.
    """
    name = "TLS Handshake - New Session Ticket"
    fields_desc = [
        ByteEnumField("msgtype", 4, _tls_handshake_type),
        ThreeBytesField("msglen", None),
        IntField("lifetime", 0xffffffff),
        FieldLenField("ticketlen", None, length_of="ticket"),
        StrLenField("ticket", "", length_from=lambda msg: msg.ticketlen),
    ]

    @classmethod
    def dispatch_hook(cls, _pkt=None, *args, **kargs):
        # A session already on TLS 1.3 (0x0304) or later uses the
        # TLS13NewSessionTicket layout instead.
        session = kargs.get("tls_session")
        if not session or not session.tls_version:
            return TLSNewSessionTicket
        if session.tls_version >= 0x0304:
            return TLS13NewSessionTicket
        return TLSNewSessionTicket

    def post_dissection_tls_session_update(self, msg_str):
        self.tls_session_update(msg_str)
        session = self.tls_session
        # Only the client end records the received ticket in the session.
        if session.connection_end == "client":
            session.client_session_ticket = self.ticket
1600
1601
class TLS13NewSessionTicket(_TLSHandshake):
    """
    TLS 1.3 NewSessionTicket handshake message (type 4).

    Uncomment the TicketField line for parsing a RFC 5077 ticket.
    """
    name = "TLS 1.3 Handshake - New Session Ticket"
    fields_desc = [ByteEnumField("msgtype", 4, _tls_handshake_type),
                   ThreeBytesField("msglen", None),
                   IntField("ticket_lifetime", 0xffffffff),
                   IntField("ticket_age_add", 0),
                   FieldLenField("noncelen", None, fmt="B",
                                 length_of="ticket_nonce"),
                   StrLenField("ticket_nonce", "",
                               length_from=lambda pkt: pkt.noncelen),
                   FieldLenField("ticketlen", None, length_of="ticket"),
                   # TicketField("ticket", "",
                   StrLenField("ticket", "",
                               length_from=lambda pkt: pkt.ticketlen),
                   _ExtensionsLenField("extlen", None, length_of="ext"),
                   # The extensions take what remains of msglen once the
                   # fixed-size fields are subtracted, along with the
                   # nonce and ticket bytes. The fixed-size fields are
                   # ticket_lifetime 4 + ticket_age_add 4 + noncelen 1 +
                   # ticketlen 2 + extlen 2 = 13 bytes. Each optional
                   # length is parenthesized so that a None value counts
                   # as 0 instead of raising TypeError.
                   _ExtensionsField("ext", None,
                                    length_from=lambda pkt: (
                                        pkt.msglen -
                                        (pkt.ticketlen or 0) -
                                        (pkt.noncelen or 0) - 13))]

    def build(self):
        """
        Fill in generated values for fields left at their defaults, then
        build.

        The generated values are stored on the packet itself, so later
        builds reuse them:
          - empty ticket        -> 48 random bytes
          - empty ticket_nonce  -> 32 random bytes
          - ticket_lifetime of 0xffffffff -> 43200 seconds (12 hours)
          - ticket_age_add of 0 -> random 32-bit value
        """
        fval = self.getfieldval("ticket")
        if fval == b"":
            # Here, the ticket is just a random 48-byte label
            # The ticket may also be a self-encrypted and self-authenticated
            # value
            self.ticket = os.urandom(48)

        fval = self.getfieldval("ticket_nonce")
        if fval == b"":
            # Nonce is randomly chosen
            self.ticket_nonce = os.urandom(32)

        fval = self.getfieldval("ticket_lifetime")
        if fval == 0xffffffff:
            # ticket_lifetime is set to 12 hours
            self.ticket_lifetime = 43200

        fval = self.getfieldval("ticket_age_add")
        if fval == 0:
            # ticket_age_add is a random 32-bit value
            self.ticket_age_add = struct.unpack("!I", os.urandom(4))[0]

        return _TLSHandshake.build(self)

    def post_dissection_tls_session_update(self, msg_str):
        self.tls_session_update(msg_str)
        # Only the client end records the received ticket in the session.
        if self.tls_session.connection_end == "client":
            self.tls_session.client_session_ticket = self.ticket
1654
1655
1656###############################################################################
1657#   EndOfEarlyData                                                            #
1658###############################################################################
1659
class TLS13EndOfEarlyData(_TLSHandshake):
    """
    TLS 1.3 EndOfEarlyData handshake message (type 5).

    It has no body. Only the handshake type and length are declared.
    """
    name = "TLS 1.3 Handshake - End Of Early Data"
    fields_desc = [
        ByteEnumField("msgtype", 5, _tls_handshake_type),
        ThreeBytesField("msglen", None),
    ]
1665
1666
1667###############################################################################
1668#   KeyUpdate                                                                 #
1669###############################################################################
1670_key_update_request = {0: "update_not_requested", 1: "update_requested"}
1671
1672
class TLS13KeyUpdate(_TLSHandshake):
    """
    TLS 1.3 KeyUpdate handshake message (type 24).

    Building this message prepares a pending write connection state.
    Dissecting it prepares a pending read connection state. In both cases,
    compute_tls13_next_traffic_secrets is then called for that direction.

    NOTE(review): unlike the other handshake hooks in this module, these
    hooks do not call tls_session_update(msg_str).
    """
    name = "TLS 1.3 Handshake - Key Update"
    fields_desc = [ByteEnumField("msgtype", 24, _tls_handshake_type),
                   ThreeBytesField("msglen", None),
                   ByteEnumField("request_update", 0, _key_update_request)]

    def post_build_tls_session_update(self, msg_str):
        # Our own sending direction moves on to the next traffic secret.
        s = self.tls_session
        s.pwcs = writeConnState(ciphersuite=type(s.wcs.ciphersuite),
                                connection_end=s.connection_end,
                                tls_version=s.tls_version)
        s.triggered_pwcs_commit = True
        s.compute_tls13_next_traffic_secrets(s.connection_end, "write")

    def post_dissection_tls_session_update(self, msg_str):
        # The peer's sending direction (our reading one) moves on. The
        # pending state stored in prcs must be a read state; it was
        # previously built with writeConnState by mistake.
        s = self.tls_session
        s.prcs = readConnState(ciphersuite=type(s.rcs.ciphersuite),
                               connection_end=s.connection_end,
                               tls_version=s.tls_version)
        s.triggered_prcs_commit = True
        if s.connection_end == "server":
            s.compute_tls13_next_traffic_secrets("client", "read")
        elif s.connection_end == "client":
            s.compute_tls13_next_traffic_secrets("server", "read")
1697
1698
1699###############################################################################
1700#   All handshake messages defined in this module                             #
1701###############################################################################
1702
# Handshake message type -> class used to dissect it.
# _tls13_handshake_cls holds the TLS 1.3 variants of these messages.
_tls_handshake_cls = {
    0: TLSHelloRequest,
    1: TLSClientHello,
    2: TLSServerHello,
    3: TLSHelloVerifyRequest,
    4: TLSNewSessionTicket,
    8: TLSEncryptedExtensions,
    11: TLSCertificate,
    12: TLSServerKeyExchange,
    13: TLSCertificateRequest,
    14: TLSServerHelloDone,
    15: TLSCertificateVerify,
    16: TLSClientKeyExchange,
    20: TLSFinished,
    21: TLSCertificateURL,
    22: TLSCertificateStatus,
    23: TLSSupplementalData,
}

_tls13_handshake_cls = {
    1: TLS13ClientHello,
    2: TLS13ServerHello,
    4: TLS13NewSessionTicket,
    5: TLS13EndOfEarlyData,
    8: TLSEncryptedExtensions,
    11: TLS13Certificate,
    13: TLS13CertificateRequest,
    15: TLSCertificateVerify,
    20: TLSFinished,
    24: TLS13KeyUpdate,
}
1718