1# This file is part of Scapy
2# Copyright (C) 2007, 2008, 2009 Arnaud Ebalard
3#               2015, 2016, 2017 Maxence Tury
4#               2019 Romain Perez
5# This program is published under a GPLv2 license
6
7"""
8TLS session handler.
9"""
10
11import socket
12import struct
13
14from scapy.config import conf
15from scapy.compat import raw
16import scapy.modules.six as six
17from scapy.error import log_runtime, warning
18from scapy.packet import Packet
19from scapy.pton_ntop import inet_pton
20from scapy.sessions import DefaultSession
21from scapy.utils import repr_hex, strxor
22from scapy.layers.inet import TCP
23from scapy.layers.tls.crypto.compression import Comp_NULL
24from scapy.layers.tls.crypto.hkdf import TLS13_HKDF
25from scapy.layers.tls.crypto.prf import PRF
26
# Note that the following import may happen inside connState.__init__()
# in order to avoid cyclical dependencies.
29# from scapy.layers.tls.crypto.suites import TLS_NULL_WITH_NULL_NULL
30
31
32###############################################################################
33#   Connection states                                                         #
34###############################################################################
35
class connState(object):
    """
    From RFC 5246, section 6.1:
    A TLS connection state is the operating environment of the TLS Record
    Protocol.  It specifies a compression algorithm, an encryption
    algorithm, and a MAC algorithm.  In addition, the parameters for
    these algorithms are known: the MAC key and the bulk encryption keys
    for the connection in both the read and the write directions.
    Logically, there are always four connection states outstanding: the
    current read and write states, and the pending read and write states.
    All records are processed under the current read and write states.
    The security parameters for the pending states can be set by the TLS
    Handshake Protocol, and the ChangeCipherSpec can selectively make
    either of the pending states current, in which case the appropriate
    current state is disposed of and replaced with the pending state; the
    pending state is then reinitialized to an empty state.  It is illegal
    to make a state that has not been initialized with security
    parameters a current state.  The initial current state always
    specifies that no encryption, compression, or MAC will be used.

    (For practical reasons, Scapy does away with the last two sentences,
    through the implementation of dummy ciphers and MAC with
    TLS_NULL_WITH_NULL_NULL.)

    These attributes and behaviours are mostly mapped in this class.
    Also, note that Scapy may make a current state out of a pending state
    which has been initialized with dummy security parameters. We need
    this in order to know when the content of a TLS message is encrypted,
    whether we possess the right keys to decipher/verify it or not.
    For instance, when Scapy parses a CKE without knowledge of any secret,
    and then a CCS, it needs to know that the following Finished
    is encrypted and signed according to a new cipher suite, even though
    it cannot decipher the message nor verify its integrity.
    """

    def __init__(self,
                 connection_end="server",
                 read_or_write="read",
                 seq_num=0,
                 compression_alg=Comp_NULL,
                 ciphersuite=None,
                 tls_version=0x0303):
        """
        Build the state of one direction of a TLS connection.

        :param connection_end: "client" or "server", the side we play.
        :param read_or_write: "read" or "write"; stored as self.row.
        :param seq_num: initial record sequence number.
        :param compression_alg: compression class, instantiated here.
        :param ciphersuite: cipher suite class; TLS_NULL_WITH_NULL_NULL
            when None.
        :param tls_version: protocol version (0x0303 is TLS 1.2).

        If the cipher suite is not usable, a warning is logged and the
        algorithm attributes (compression, cipher, hmac...) are left unset.
        """
        self.tls_version = tls_version

        # It is the user's responsibility to keep the record seq_num
        # under 2**64-1. If this value gets maxed out, the TLS class in
        # record.py will crash when trying to encode it with struct.pack().
        self.seq_num = seq_num

        self.connection_end = connection_end
        self.row = read_or_write

        # At most one of these is set below, depending on tls_version.
        # Both always exist so that callers can test e.g. ``state.hkdf``
        # without risking an AttributeError, including when the cipher
        # suite turns out not to be usable.
        self.prf = None
        self.hkdf = None

        if ciphersuite is None:
            # Local import, to avoid a cyclical dependency with suites.py.
            from scapy.layers.tls.crypto.suites import TLS_NULL_WITH_NULL_NULL
            ciphersuite = TLS_NULL_WITH_NULL_NULL
        self.ciphersuite = ciphersuite(tls_version=tls_version)

        if not self.ciphersuite.usable:
            warning("TLS cipher suite not usable. "
                    "Is the cryptography Python module installed?")
            return

        self.compression = compression_alg()
        self.key_exchange = ciphersuite.kx_alg()
        self.cipher = ciphersuite.cipher_alg()
        self.hash = ciphersuite.hash_alg()

        if tls_version > 0x0200:
            if ciphersuite.cipher_alg.type == "aead":
                # AEAD ciphers authenticate data themselves: no HMAC.
                self.hmac = None
                self.mac_len = self.cipher.tag_len
            else:
                self.hmac = ciphersuite.hmac_alg()
                self.mac_len = self.hmac.hmac_len
        else:
            self.hmac = ciphersuite.hmac_alg()          # should be Hmac_NULL
            self.mac_len = self.hash.hash_len

        if tls_version >= 0x0304:
            self.hkdf = TLS13_HKDF(self.hash.name.lower())
        else:
            self.prf = PRF(ciphersuite.hash_alg.name, tls_version)

    def debug_repr(self, name, secret):
        """Log `secret` in hex under `name` when conf.debug_tls is set."""
        if conf.debug_tls and secret:
            log_runtime.debug("TLS: %s %s %s: %s",
                              self.connection_end,
                              self.row,
                              name,
                              repr_hex(secret))

    def derive_keys(self,
                    client_random=b"",
                    server_random=b"",
                    master_secret=b""):
        """
        Derive this direction's MAC secret, cipher key and IV from the PRF
        key block (SSLv3 to TLS 1.2), then instantiate self.hmac and
        self.cipher with them.

        Called with the default empty arguments, it instantiates the
        algorithms of the (dummy) cipher suite with dummy secrets.
        """
        # XXX Can this be called over a non-usable suite? What happens then?
        # (self.prf is None in that case, and for TLS 1.3 states.)

        cs = self.ciphersuite

        # Derive the keys according to the cipher type and protocol version
        key_block = self.prf.derive_key_block(master_secret,
                                              server_random,
                                              client_random,
                                              cs.key_block_len)

        # Every item of the key block comes as a (client write, server write)
        # pair. The client reads and the server writes with the second one.
        skip_first = False
        if ((self.connection_end == "client" and self.row == "read") or
                (self.connection_end == "server" and self.row == "write")):
            skip_first = True

        pos = 0
        cipher_alg = cs.cipher_alg

        # MAC secret (for block and stream ciphers)
        if (cipher_alg.type == "stream") or (cipher_alg.type == "block"):
            start = pos
            if skip_first:
                start += cs.hmac_alg.key_len
            end = start + cs.hmac_alg.key_len
            mac_secret = key_block[start:end]
            self.debug_repr("mac_secret", mac_secret)
            pos += 2 * cs.hmac_alg.key_len
        else:
            mac_secret = None

        # Cipher secret
        start = pos
        if skip_first:
            start += cipher_alg.key_len
        end = start + cipher_alg.key_len
        cipher_secret = key_block[start:end]
        if cs.kx_alg.export:
            reqLen = cipher_alg.expanded_key_len
            cipher_secret = self.prf.postprocess_key_for_export(cipher_secret,
                                                                client_random,
                                                                server_random,
                                                                self.connection_end,  # noqa: E501
                                                                self.row,
                                                                reqLen)
        self.debug_repr("cipher_secret", cipher_secret)
        pos += 2 * cipher_alg.key_len

        # Implicit IV (for block and AEAD ciphers)
        start = pos
        if cipher_alg.type == "block":
            if skip_first:
                start += cipher_alg.block_size
            end = start + cipher_alg.block_size
        elif cipher_alg.type == "aead":
            if skip_first:
                start += cipher_alg.fixed_iv_len
            end = start + cipher_alg.fixed_iv_len

        # Now we have the secrets, we can instantiate the algorithms
        if cs.hmac_alg is None:         # AEAD
            self.hmac = None
            self.mac_len = cipher_alg.tag_len
        else:
            self.hmac = cs.hmac_alg(mac_secret)
            self.mac_len = self.hmac.hmac_len

        if cipher_alg.type == "stream":
            cipher = cipher_alg(cipher_secret)
        elif cipher_alg.type == "block":
            # We set an IV every time, even though it does not matter for
            # TLS 1.1+ as it requires an explicit IV. Indeed the cipher.iv
            # would get updated in TLS.post_build() or TLS.pre_dissect().
            iv = key_block[start:end]
            if cs.kx_alg.export:
                reqLen = cipher_alg.block_size
                iv = self.prf.generate_iv_for_export(client_random,
                                                     server_random,
                                                     self.connection_end,
                                                     self.row,
                                                     reqLen)
            cipher = cipher_alg(cipher_secret, iv)
            self.debug_repr("block iv", iv)
        elif cipher_alg.type == "aead":
            fixed_iv = key_block[start:end]
            nonce_explicit_init = 0
            # If you ever wanted to set a random nonce_explicit, use this:
            # exp_bit_len = cipher_alg.nonce_explicit_len * 8
            # nonce_explicit_init = random.randint(0, 2**exp_bit_len - 1)
            cipher = cipher_alg(cipher_secret, fixed_iv, nonce_explicit_init)
            self.debug_repr("aead fixed iv", fixed_iv)
        self.cipher = cipher

    def sslv2_derive_keys(self, key_material):
        """
        There is actually only one key, the CLIENT-READ-KEY or -WRITE-KEY.

        Note that skip_first is opposite from the one with SSLv3 derivation.

        Also, if needed, the IV should be set elsewhere.
        """
        skip_first = True
        if ((self.connection_end == "client" and self.row == "read") or
                (self.connection_end == "server" and self.row == "write")):
            skip_first = False

        cipher_alg = self.ciphersuite.cipher_alg

        start = 0
        if skip_first:
            start += cipher_alg.key_len
        end = start + cipher_alg.key_len
        cipher_secret = key_material[start:end]
        self.cipher = cipher_alg(cipher_secret)
        self.debug_repr("cipher_secret", cipher_secret)

    def tls13_derive_keys(self, key_material):
        """
        Expand a TLS 1.3 traffic secret into a write key and a write IV
        (HKDF-Expand-Label with the "key" and "iv" labels), then
        instantiate self.cipher with them.
        """
        cipher_alg = self.ciphersuite.cipher_alg
        key_len = cipher_alg.key_len
        iv_len = cipher_alg.fixed_iv_len
        write_key = self.hkdf.expand_label(key_material, b"key", b"", key_len)
        write_iv = self.hkdf.expand_label(key_material, b"iv", b"", iv_len)
        self.cipher = cipher_alg(write_key, write_iv)

    def snapshot(self):
        """
        This is used mostly as a way to keep the cipher state and the seq_num.
        """
        snap = connState(connection_end=self.connection_end,
                         read_or_write=self.row,
                         seq_num=self.seq_num,
                         compression_alg=type(self.compression),
                         ciphersuite=type(self.ciphersuite),
                         tls_version=self.tls_version)
        snap.cipher = self.cipher.snapshot()
        if self.hmac:
            snap.hmac.key = self.hmac.key
        return snap

    def __repr__(self):
        res = "Connection end : %s\n" % self.connection_end.upper()
        res += "Cipher suite   : %s (0x%04x)\n" % (self.ciphersuite.name,
                                                   self.ciphersuite.val)
        res += "Compression    : %s (0x%02x)\n" % (self.compression.name,
                                                   self.compression.val)
        tabsize = 4
        return res.expandtabs(tabsize)
278
279
class readConnState(connState):
    """A connState whose direction is always "read"."""

    def __init__(self, **kargs):
        super(readConnState, self).__init__(read_or_write="read", **kargs)
283
284
class writeConnState(connState):
    """A connState whose direction is always "write"."""

    def __init__(self, **kargs):
        super(writeConnState, self).__init__(read_or_write="write", **kargs)
288
289
290###############################################################################
291#   TLS session                                                               #
292###############################################################################
293
294class tlsSession(object):
295    """
296    This is our TLS context, which gathers information from both sides of the
297    TLS connection. These sides are represented by a readConnState instance and
298    a writeConnState instance. Along with overarching network attributes, a
299    tlsSession object also holds negotiated, shared information, such as the
300    key exchange parameters and the master secret (when available).
301
302    The default connection_end is "server". This corresponds to the expected
303    behaviour for static exchange analysis (with a ClientHello parsed first).
304    """
305
    def __init__(self,
                 ipsrc=None, ipdst=None,
                 sport=None, dport=None, sid=None,
                 connection_end="server",
                 wcs=None, rcs=None):
        """
        Create an empty TLS session context.

        :param ipsrc: source IP address, stored as self.ipsrc.
        :param ipdst: destination IP address, stored as self.ipdst.
        :param sport: source port, stored as self.sport.
        :param dport: destination port, stored as self.dport.
        :param sid: stored as self.sid (presumably a session identifier;
            TODO confirm).
        :param connection_end: "client" or "server", the side we play.
        :param wcs: initial current write state. When None, a dummy
            writeConnState (default connState parameters) is created and
            its keys are derived with empty secrets.
        :param rcs: same as wcs, for the current read state.
        """

        # Use this switch to prevent additions to the 'handshake_messages'.
        self.frozen = False

        # Network settings
        self.ipsrc = ipsrc
        self.ipdst = ipdst
        self.sport = sport
        self.dport = dport
        self.sid = sid

        # Our TCP socket. None until we send (or receive) a packet.
        self.sock = None

        # Connection states
        # This must stay before the wcs/rcs assignments: at this point those
        # attributes do not exist yet, so __setattr__ does not push this
        # value onto caller-provided wcs/rcs, which keep their own
        # connection_end.
        self.connection_end = connection_end

        if wcs is None:
            # Instantiate wcs with dummy values.
            self.wcs = writeConnState(connection_end=connection_end)
            self.wcs.derive_keys()
        else:
            self.wcs = wcs

        if rcs is None:
            # Instantiate rcs with dummy values.
            self.rcs = readConnState(connection_end=connection_end)
            self.rcs.derive_keys()
        else:
            self.rcs = rcs

        # The pending write/read states are updated by the building/parsing
        # of various TLS packets. They get committed to self.wcs/self.rcs
        # once Scapy builds/parses a ChangeCipherSpec message, or for certain
        # other messages in case of TLS 1.3.
        self.pwcs = None
        self.triggered_pwcs_commit = False
        self.prcs = None
        self.triggered_prcs_commit = False

        # Certificates and private keys

        # The server certificate chain, as a list of Cert instances.
        # Either we act as server and it has to be provided, or it is expected
        # to be sent by the server through a Certificate message.
        # The server certificate should be self.server_certs[0].
        self.server_certs = []

        # The server private key, as a PrivKey instance, when acting as server.
        # XXX It would be nice to be able to provide both an RSA and an ECDSA
        # key in order for the same Scapy server to support both families of
        # cipher suites. See INIT_TLS_SESSION() in automaton_srv.py.
        # (For now server_key holds either one of both types for DHE
        # authentication, while server_rsa_key is used only for RSAkx.)
        self.server_key = None
        self.server_rsa_key = None
        # self.server_ecdsa_key = None

        # Back in the dreadful EXPORT days, US servers were forbidden to use
        # RSA keys longer than 512 bits for RSAkx. When their usual RSA key
        # was longer than this, they had to create a new key and send it via
        # a ServerRSAParams message. When receiving such a message,
        # Scapy stores this key in server_tmp_rsa_key as a PubKey instance.
        self.server_tmp_rsa_key = None

        # When client authentication is performed, we need at least a
        # client certificate chain. If we act as client, we also have
        # to provide the key associated with the first certificate.
        self.client_certs = []
        self.client_key = None

        # Ephemeral key exchange parameters

        # These are the group/curve parameters, needed to hold the information
        # e.g. from receiving an SKE to sending a CKE. Usually, only one of
        # these attributes will be different from None.
        self.client_kx_ffdh_params = None
        self.client_kx_ecdh_params = None

        # These are PrivateKeys and PublicKeys from the appropriate FFDH/ECDH
        # cryptography module, i.e. these are not raw bytes. Usually, only one
        # in two will be different from None, e.g. when being a TLS client you
        # will need the client_kx_privkey (the serialized public key is not
        # actually registered) and you will receive a server_kx_pubkey.
        self.client_kx_privkey = None
        self.client_kx_pubkey = None
        self.server_kx_privkey = None
        self.server_kx_pubkey = None

        # When using TLS 1.3, tls13_client_pubshares will contain every
        # potential key share the client offered (the TLS 1.3 counterpart of
        # client_kx_pubkey above), indexed by the id of the FFDH/ECDH group.
        # These dicts effectively replace the four previous attributes.
        self.tls13_client_privshares = {}
        self.tls13_client_pubshares = {}
        self.tls13_server_privshare = {}
        self.tls13_server_pubshare = {}

        # Negotiated session parameters

        # The advertised TLS version found in the ClientHello (and
        # EncryptedPreMasterSecret if used). If acting as server, it is set to
        # the value advertised by the client in its ClientHello.
        # The default value corresponds to TLS 1.2 (and TLS 1.3, incidentally).
        self.advertised_tls_version = 0x0303

        # The agreed-upon TLS version found in the ServerHello.
        self.tls_version = None

        # These attributes should eventually be known to both sides (SSLv3-TLS 1.2).  # noqa: E501
        self.client_random = None
        self.server_random = None
        self.pre_master_secret = None
        self.master_secret = None

        # The agreed-upon signature algorithm (for TLS 1.2-TLS 1.3 only)
        self.selected_sig_alg = None

        # A session ticket received by the client.
        self.client_session_ticket = None

        # These attributes should only be used with SSLv2 connections.
        # We need to keep the KEY-MATERIAL here because it may be reused.
        self.sslv2_common_cs = []
        self.sslv2_connection_id = None
        self.sslv2_challenge = None
        self.sslv2_challenge_clientcert = None
        self.sslv2_key_material = None

        # These attributes should only be used with TLS 1.3 connections.
        self.tls13_psk_secret = None
        self.tls13_early_secret = None
        self.tls13_dhe_secret = None
        self.tls13_handshake_secret = None
        self.tls13_master_secret = None
        self.tls13_derived_secrets = {}
        self.post_handshake_auth = False
        self.tls13_ticket_ciphersuite = None
        self.tls13_retry = False
        self.middlebox_compatibility = False

        # Handshake messages needed for Finished computation/validation.
        # No record layer headers, no HelloRequests, no ChangeCipherSpecs.
        self.handshake_messages = []
        self.handshake_messages_parsed = []

        # Flag, whether we derive the secret as Extended MS or not
        self.extms = False
        self.session_hash = None

        self.encrypt_then_mac = False

        # All exchanged TLS packets.
        # XXX no support for now
        # self.exchanged_pkts = []
466
467    def __setattr__(self, name, val):
468        if name == "connection_end":
469            if hasattr(self, "rcs") and self.rcs:
470                self.rcs.connection_end = val
471            if hasattr(self, "wcs") and self.wcs:
472                self.wcs.connection_end = val
473            if hasattr(self, "prcs") and self.prcs:
474                self.prcs.connection_end = val
475            if hasattr(self, "pwcs") and self.pwcs:
476                self.pwcs.connection_end = val
477        super(tlsSession, self).__setattr__(name, val)
478
479    # Mirroring
480
481    def mirror(self):
482        """
483        This function takes a tlsSession object and swaps the IP addresses,
484        ports, connection ends and connection states. The triggered_commit are
485        also swapped (though it is probably overkill, it is cleaner this way).
486
487        It is useful for static analysis of a series of messages from both the
488        client and the server. In such a situation, it should be used every
489        time the message being read comes from a different side than the one
490        read right before, as the reading state becomes the writing state, and
491        vice versa. For instance you could do:
492
493        client_hello = open('client_hello.raw').read()
494        <read other messages>
495
496        m1 = TLS(client_hello)
497        m2 = TLS(server_hello, tls_session=m1.tls_session.mirror())
498        m3 = TLS(server_cert, tls_session=m2.tls_session)
499        m4 = TLS(client_keyexchange, tls_session=m3.tls_session.mirror())
500        """
501
502        self.ipdst, self.ipsrc = self.ipsrc, self.ipdst
503        self.dport, self.sport = self.sport, self.dport
504
505        self.rcs, self.wcs = self.wcs, self.rcs
506        if self.rcs:
507            self.rcs.row = "read"
508        if self.wcs:
509            self.wcs.row = "write"
510
511        self.prcs, self.pwcs = self.pwcs, self.prcs
512        if self.prcs:
513            self.prcs.row = "read"
514        if self.pwcs:
515            self.pwcs.row = "write"
516
517        self.triggered_prcs_commit, self.triggered_pwcs_commit = \
518            self.triggered_pwcs_commit, self.triggered_prcs_commit
519
520        if self.connection_end == "client":
521            self.connection_end = "server"
522        elif self.connection_end == "server":
523            self.connection_end = "client"
524
525        return self
526
527    # Secrets management for SSLv3 to TLS 1.2
528
529    def compute_master_secret(self):
530        if self.pre_master_secret is None:
531            warning("Missing pre_master_secret while computing master_secret!")
532        if self.client_random is None:
533            warning("Missing client_random while computing master_secret!")
534        if self.server_random is None:
535            warning("Missing server_random while computing master_secret!")
536        if self.extms and self.session_hash is None:
537            warning("Missing session hash while computing master secret!")
538
539        ms = self.pwcs.prf.compute_master_secret(self.pre_master_secret,
540                                                 self.client_random,
541                                                 self.server_random,
542                                                 self.extms,
543                                                 self.session_hash)
544        self.master_secret = ms
545        if conf.debug_tls:
546            log_runtime.debug("TLS: master secret: %s", repr_hex(ms))
547
548    def compute_ms_and_derive_keys(self):
549        self.compute_master_secret()
550        self.prcs.derive_keys(client_random=self.client_random,
551                              server_random=self.server_random,
552                              master_secret=self.master_secret)
553        self.pwcs.derive_keys(client_random=self.client_random,
554                              server_random=self.server_random,
555                              master_secret=self.master_secret)
556
557    # Secrets management for SSLv2
558
559    def compute_sslv2_key_material(self):
560        if self.master_secret is None:
561            warning("Missing master_secret while computing key_material!")
562        if self.sslv2_challenge is None:
563            warning("Missing challenge while computing key_material!")
564        if self.sslv2_connection_id is None:
565            warning("Missing connection_id while computing key_material!")
566
567        km = self.pwcs.prf.derive_key_block(self.master_secret,
568                                            self.sslv2_challenge,
569                                            self.sslv2_connection_id,
570                                            2 * self.pwcs.cipher.key_len)
571        self.sslv2_key_material = km
572        if conf.debug_tls:
573            log_runtime.debug("TLS: master secret: %s", repr_hex(self.master_secret))  # noqa: E501
574            log_runtime.debug("TLS: key material: %s", repr_hex(km))
575
576    def compute_sslv2_km_and_derive_keys(self):
577        self.compute_sslv2_key_material()
578        self.prcs.sslv2_derive_keys(key_material=self.sslv2_key_material)
579        self.pwcs.sslv2_derive_keys(key_material=self.sslv2_key_material)
580
581    # Secrets management for TLS 1.3
582
583    def compute_tls13_early_secrets(self, external=False):
584        """
585        This function computes the Early Secret, the binder_key,
586        the client_early_traffic_secret and the
587        early_exporter_master_secret (See RFC8446, section 7.1).
588
589        The parameter external is used for the computation of the
590        binder_key:
591
592        - For external PSK provisioned outside out of TLS, the parameter
593          external must be True.
594        - For resumption PSK, the parameter external must be False.
595
596        If no argument is specified, the label "res binder" will be
597        used by default.
598
599        Ciphers key and IV are updated accordingly for 0-RTT data.
600        self.handshake_messages should be ClientHello only.
601        """
602
603        # if no hash algorithm is set, default to SHA-256
604        if self.prcs and self.prcs.hkdf:
605            hkdf = self.prcs.hkdf
606        elif self.pwcs and self.pwcs.hkdf:
607            hkdf = self.pwcs.hkdf
608        else:
609            hkdf = TLS13_HKDF("sha256")
610
611        if self.tls13_early_secret is None:
612            self.tls13_early_secret = hkdf.extract(None,
613                                                   self.tls13_psk_secret)
614
615        if "binder_key" not in self.tls13_derived_secrets:
616            if external:
617                bk = hkdf.derive_secret(self.tls13_early_secret,
618                                        b"ext binder",
619                                        b"")
620            else:
621                bk = hkdf.derive_secret(self.tls13_early_secret,
622                                        b"res binder",
623                                        b"")
624
625            self.tls13_derived_secrets["binder_key"] = bk
626
627        cets = hkdf.derive_secret(self.tls13_early_secret,
628                                  b"c e traffic",
629                                  b"".join(self.handshake_messages))
630
631        self.tls13_derived_secrets["client_early_traffic_secret"] = cets
632        ees = hkdf.derive_secret(self.tls13_early_secret,
633                                 b"e exp master",
634                                 b"".join(self.handshake_messages))
635        self.tls13_derived_secrets["early_exporter_secret"] = ees
636
637        if self.connection_end == "server":
638            if self.prcs:
639                self.prcs.tls13_derive_keys(cets)
640        elif self.connection_end == "client":
641            if self.pwcs:
642                self.pwcs.tls13_derive_keys(cets)
643
644    def compute_tls13_handshake_secrets(self):
645        """
646        Ciphers key and IV are updated accordingly for Handshake data.
647        self.handshake_messages should be ClientHello...ServerHello.
648        """
649        if self.prcs:
650            hkdf = self.prcs.hkdf
651        elif self.pwcs:
652            hkdf = self.pwcs.hkdf
653        else:
654            warning("No HKDF. This is abnormal.")
655            return
656
657        if self.tls13_early_secret is None:
658            self.tls13_early_secret = hkdf.extract(None,
659                                                   self.tls13_psk_secret)
660
661        secret = hkdf.derive_secret(self.tls13_early_secret, b"derived", b"")
662        self.tls13_handshake_secret = hkdf.extract(secret, self.tls13_dhe_secret)  # noqa: E501
663
664        chts = hkdf.derive_secret(self.tls13_handshake_secret,
665                                  b"c hs traffic",
666                                  b"".join(self.handshake_messages))
667        self.tls13_derived_secrets["client_handshake_traffic_secret"] = chts
668
669        shts = hkdf.derive_secret(self.tls13_handshake_secret,
670                                  b"s hs traffic",
671                                  b"".join(self.handshake_messages))
672        self.tls13_derived_secrets["server_handshake_traffic_secret"] = shts
673
674    def compute_tls13_traffic_secrets(self):
675        """
676        Ciphers key and IV are updated accordingly for Application data.
677        self.handshake_messages should be ClientHello...ServerFinished.
678        """
679        if self.prcs and self.prcs.hkdf:
680            hkdf = self.prcs.hkdf
681        elif self.pwcs and self.pwcs.hkdf:
682            hkdf = self.pwcs.hkdf
683        else:
684            warning("No HKDF. This is abnormal.")
685            return
686
687        tmp = hkdf.derive_secret(self.tls13_handshake_secret,
688                                 b"derived",
689                                 b"")
690        self.tls13_master_secret = hkdf.extract(tmp, None)
691
692        cts0 = hkdf.derive_secret(self.tls13_master_secret,
693                                  b"c ap traffic",
694                                  b"".join(self.handshake_messages))
695        self.tls13_derived_secrets["client_traffic_secrets"] = [cts0]
696
697        sts0 = hkdf.derive_secret(self.tls13_master_secret,
698                                  b"s ap traffic",
699                                  b"".join(self.handshake_messages))
700        self.tls13_derived_secrets["server_traffic_secrets"] = [sts0]
701
702        es = hkdf.derive_secret(self.tls13_master_secret,
703                                b"exp master",
704                                b"".join(self.handshake_messages))
705        self.tls13_derived_secrets["exporter_secret"] = es
706
707        if self.connection_end == "server":
708            # self.prcs.tls13_derive_keys(cts0)
709            self.pwcs.tls13_derive_keys(sts0)
710        elif self.connection_end == "client":
711            # self.pwcs.tls13_derive_keys(cts0)
712            self.prcs.tls13_derive_keys(sts0)
713
714    def compute_tls13_traffic_secrets_end(self):
715        cts0 = self.tls13_derived_secrets["client_traffic_secrets"][0]
716        if self.connection_end == "server":
717            self.prcs.tls13_derive_keys(cts0)
718        elif self.connection_end == "client":
719            self.pwcs.tls13_derive_keys(cts0)
720
721    def compute_tls13_verify_data(self, connection_end, read_or_write):
722        shts = "server_handshake_traffic_secret"
723        chts = "client_handshake_traffic_secret"
724        if read_or_write == "read":
725            hkdf = self.rcs.hkdf
726            if connection_end == "client":
727                basekey = self.tls13_derived_secrets[shts]
728            elif connection_end == "server":
729                basekey = self.tls13_derived_secrets[chts]
730        elif read_or_write == "write":
731            hkdf = self.wcs.hkdf
732            if connection_end == "client":
733                basekey = self.tls13_derived_secrets[chts]
734            elif connection_end == "server":
735                basekey = self.tls13_derived_secrets[shts]
736
737        if not hkdf or not basekey:
738            warning("Missing arguments for verify_data computation!")
739            return None
740        # XXX this join() works in standard cases, but does it in all of them?
741        handshake_context = b"".join(self.handshake_messages)
742        return hkdf.compute_verify_data(basekey, handshake_context)
743
744    def compute_tls13_resumption_secret(self):
745        """
746        self.handshake_messages should be ClientHello...ClientFinished.
747        """
748        if self.connection_end == "server":
749            hkdf = self.prcs.hkdf
750        elif self.connection_end == "client":
751            hkdf = self.pwcs.hkdf
752        rs = hkdf.derive_secret(self.tls13_master_secret,
753                                b"res master",
754                                b"".join(self.handshake_messages))
755        self.tls13_derived_secrets["resumption_secret"] = rs
756
757    def compute_tls13_next_traffic_secrets(self, connection_end, read_or_write):  # noqa : E501
758        """
759        Ciphers key and IV are updated accordingly.
760        """
761        if self.rcs.hkdf:
762            hkdf = self.rcs.hkdf
763            hl = hkdf.hash.digest_size
764        elif self.wcs.hkdf:
765            hkdf = self.wcs.hkdf
766            hl = hkdf.hash.digest_size
767
768        if read_or_write == "read":
769            if connection_end == "client":
770                cts = self.tls13_derived_secrets["client_traffic_secrets"]
771                ctsN = cts[-1]
772                ctsN_1 = hkdf.expand_label(ctsN, b"traffic upd", b"", hl)
773                cts.append(ctsN_1)
774                self.prcs.tls13_derive_keys(ctsN_1)
775            elif connection_end == "server":
776                sts = self.tls13_derived_secrets["server_traffic_secrets"]
777                stsN = sts[-1]
778                stsN_1 = hkdf.expand_label(stsN, b"traffic upd", b"", hl)
779                sts.append(stsN_1)
780
781                self.prcs.tls13_derive_keys(stsN_1)
782
783        elif read_or_write == "write":
784            if connection_end == "client":
785                cts = self.tls13_derived_secrets["client_traffic_secrets"]
786                ctsN = cts[-1]
787                ctsN_1 = hkdf.expand_label(ctsN, b"traffic upd", b"", hl)
788                cts.append(ctsN_1)
789                self.pwcs.tls13_derive_keys(ctsN_1)
790            elif connection_end == "server":
791                sts = self.tls13_derived_secrets["server_traffic_secrets"]
792                stsN = sts[-1]
793                stsN_1 = hkdf.expand_label(stsN, b"traffic upd", b"", hl)
794                sts.append(stsN_1)
795
796                self.pwcs.tls13_derive_keys(stsN_1)
797
798    # Tests for record building/parsing
799
800    def consider_read_padding(self):
801        # Return True if padding is needed. Used by TLSPadField.
802        return (self.rcs.cipher.type == "block" and
803                not (False in six.itervalues(self.rcs.cipher.ready)))
804
805    def consider_write_padding(self):
806        # Return True if padding is needed. Used by TLSPadField.
807        return self.wcs.cipher.type == "block"
808
809    def use_explicit_iv(self, version, cipher_type):
810        # Return True if an explicit IV is needed. Required for TLS 1.1+
811        # when either a block or an AEAD cipher is used.
812        if cipher_type == "stream":
813            return False
814        return version >= 0x0302
815
816    # Python object management
817
818    def hash(self):
819        s1 = struct.pack("!H", self.sport)
820        s2 = struct.pack("!H", self.dport)
821        family = socket.AF_INET
822        if ':' in self.ipsrc:
823            family = socket.AF_INET6
824        s1 += inet_pton(family, self.ipsrc)
825        s2 += inet_pton(family, self.ipdst)
826        return strxor(s1, s2)
827
828    def eq(self, other):
829        ok = False
830        if (self.sport == other.sport and self.dport == other.dport and
831                self.ipsrc == other.ipsrc and self.ipdst == other.ipdst):
832            ok = True
833
834        if (not ok and
835            self.dport == other.sport and self.sport == other.dport and
836                self.ipdst == other.ipsrc and self.ipsrc == other.ipdst):
837            ok = True
838
839        if ok:
840            if self.sid and other.sid:
841                return self.sid == other.sid
842            return True
843
844        return False
845
846    def __repr__(self):
847        sid = repr(self.sid)
848        if len(sid) > 12:
849            sid = sid[:11] + "..."
850        return "%s:%s > %s:%s" % (self.ipsrc, str(self.sport),
851                                  self.ipdst, str(self.dport))
852
853###############################################################################
854#   Session singleton                                                         #
855###############################################################################
856
857
class _GenericTLSSessionInheritance(Packet):
    """
    Many classes inside the TLS module need to get access to session-related
    information. For instance, an encrypted TLS record cannot be parsed without
    some knowledge of the cipher suite being used and the secrets which have
    been negotiated. Passing information is also essential to the handshake.
    To this end, various TLS objects inherit from the present class.
    """
    # tls_session: the tlsSession this packet is attached to.
    # rcs_snap_init / wcs_snap_init: snapshots of the session's current read
    # and write connection states, taken when the packet is instantiated.
    __slots__ = ["tls_session", "rcs_snap_init", "wcs_snap_init"]
    name = "Dummy Generic TLS Packet"
    fields_desc = []

    def __init__(self, _pkt="", post_transform=None, _internal=0,
                 _underlayer=None, tls_session=None, **fields):
        """
        Attach the packet to a TLS session, then run Packet.__init__.

        :param tls_session: session to use when the packet has none yet; a
            fresh tlsSession is created when this is None.
        The remaining parameters are forwarded to Packet.__init__.

        When the underlayer is TCP, the session endpoints are copied from
        the TCP (and, if present, IP) headers. If conf.tls_session_enable is
        set, a freshly created session is then replaced by the matching
        registered session (or its mirror), or registered itself.
        """
        try:
            setme = self.tls_session is None
        except Exception:
            setme = True

        newses = False
        if setme:
            if tls_session is None:
                newses = True
                self.tls_session = tlsSession()
            else:
                self.tls_session = tls_session

        self.rcs_snap_init = self.tls_session.rcs.snapshot()
        self.wcs_snap_init = self.tls_session.wcs.snapshot()

        if isinstance(_underlayer, TCP):
            tcp = _underlayer
            self.tls_session.sport = tcp.sport
            self.tls_session.dport = tcp.dport
            try:
                self.tls_session.ipsrc = tcp.underlayer.src
                self.tls_session.ipdst = tcp.underlayer.dst
            except AttributeError:
                # No network layer with src/dst below TCP.
                pass
            if conf.tls_session_enable:
                if newses:
                    s = conf.tls_sessions.find(self.tls_session)
                    if s:
                        # Same direction as the registered session: reuse it.
                        # Otherwise use its mirror (presumably the same
                        # session seen from the other endpoint).
                        if s.dport == self.tls_session.dport:
                            self.tls_session = s
                        else:
                            self.tls_session = s.mirror()
                    else:
                        conf.tls_sessions.add(self.tls_session)
            if self.tls_session.connection_end == "server":
                # Fall back to the registry-wide server RSA key when the
                # session has none of its own.
                srk = conf.tls_sessions.server_rsa_key
                if not self.tls_session.server_rsa_key and \
                        srk:
                    self.tls_session.server_rsa_key = srk

        Packet.__init__(self, _pkt=_pkt, post_transform=post_transform,
                        _internal=_internal, _underlayer=_underlayer,
                        **fields)

    def __getattr__(self, attr):
        """
        The tls_session should be found only through the normal mechanism.
        """
        if attr == "tls_session":
            return None
        return super(_GenericTLSSessionInheritance, self).__getattr__(attr)

    def tls_session_update(self, msg_str):
        """
        post_{build, dissection}_tls_session_update() are used to update the
        tlsSession context. The default definitions below, along with
        tls_session_update(), may prevent code duplication in some cases.
        """
        pass

    def post_build_tls_session_update(self, msg_str):
        self.tls_session_update(msg_str)

    def post_dissection_tls_session_update(self, msg_str):
        self.tls_session_update(msg_str)

    def copy(self):
        # The copy shares the same tlsSession object (it is not duplicated).
        pkt = Packet.copy(self)
        pkt.tls_session = self.tls_session
        return pkt

    def clone_with(self, payload=None, **kargs):
        # The clone shares the same tlsSession object (it is not duplicated).
        pkt = Packet.clone_with(self, payload=payload, **kargs)
        pkt.tls_session = self.tls_session
        return pkt

    def raw_stateful(self):
        """
        Build the packet without the snapshot/restore done by __bytes__, so
        any connection-state changes made during the build persist.
        """
        return super(_GenericTLSSessionInheritance, self).__bytes__()

    def str_stateful(self):
        return self.raw_stateful()

    def __bytes__(self):
        """
        The __bytes__ call has to leave the connection states unchanged.
        We also have to delete raw_packet_cache in order to access post_build.

        For performance, the pending connStates are not snapshotted.
        This should not be an issue, but maybe pay attention to this.

        The previous_freeze_state prevents issues with calling a raw() calling
        in turn another raw(), which would unfreeze the session too soon.

        The session states, freeze flag and packet caches are restored even
        if building raises.
        """
        s = self.tls_session
        rcs_snap = s.rcs.snapshot()
        wcs_snap = s.wcs.snapshot()
        rpc_snap = self.raw_packet_cache
        rpcf_snap = self.raw_packet_cache_fields

        # NOTE(review): the write state is seeded from the *read* snapshot
        # taken at instantiation (rcs_snap_init), not from wcs_snap_init;
        # presumably so a dissected record is rebuilt with the state it was
        # read with. Confirm before changing.
        s.wcs = self.rcs_snap_init

        self.raw_packet_cache = None
        previous_freeze_state = s.frozen
        s.frozen = True
        parent = super(_GenericTLSSessionInheritance, self)
        try:
            built_packet = parent.__bytes__()
        finally:
            s.frozen = previous_freeze_state
            s.rcs = rcs_snap
            s.wcs = wcs_snap
            self.raw_packet_cache = rpc_snap
            self.raw_packet_cache_fields = rpcf_snap

        return built_packet
    __str__ = __bytes__

    def show2(self):
        """
        Rebuild the TLS packet with the same context, and then .show() it.
        We need self.__class__ to call the subclass in a dynamic way.

        However we do not want the tls_session.{r,w}cs.seq_num to be updated.
        We have to bring back the init states (it's possible the cipher context
        has been updated because of parsing) but also to keep the current state
        and restore it afterwards (the raw() call may also update the states).

        The freeze flag is restored to its previous value (not forced to
        False), and the states are restored even if building or showing
        raises.
        """
        s = self.tls_session
        rcs_snap = s.rcs.snapshot()
        wcs_snap = s.wcs.snapshot()
        previous_freeze_state = s.frozen

        s.rcs = self.rcs_snap_init
        try:
            built_packet = raw(self)
            s.frozen = True
            self.__class__(built_packet, tls_session=s).show()
        finally:
            s.frozen = previous_freeze_state
            s.rcs = rcs_snap
            s.wcs = wcs_snap

    def mysummary(self):
        return "TLS %s / %s" % (repr(self.tls_session),
                                getattr(self, "_name", self.name))

    @classmethod
    def tcp_reassemble(cls, data, metadata):
        """
        TCP reassembly hook, used with TLSSession.

        For TLS/TLS13, return the dissected packet once `data` holds at least
        one complete record (5-byte header plus the length announced in
        bytes 3-4), or None while more data is needed. When trailing bytes
        follow the first record and the payload layer has its own
        tcp_reassemble, the packet is returned only if that hook accepts the
        trailing bytes. Other classes are dissected directly.
        """
        from scapy.layers.tls.record import TLS
        from scapy.layers.tls.record_tls13 import TLS13
        if cls not in (TLS, TLS13):
            return cls(data)
        if len(data) < 5:
            # Record header incomplete: the length field cannot be read yet.
            return None
        length = struct.unpack("!H", data[3:5])[0] + 5
        if len(data) == length:
            return cls(data)
        if len(data) > length:
            pkt = cls(data)
            if not hasattr(pkt.payload, "tcp_reassemble"):
                return pkt
            if pkt.payload.tcp_reassemble(data[length:], metadata):
                return pkt
        return None
1034
1035
1036###############################################################################
1037#   Multiple TLS sessions                                                     #
1038###############################################################################
1039
class _tls_sessions(object):
    """
    Registry of known TLS sessions.

    Sessions are bucketed by their hash() value; several sessions may share
    a bucket, and lookups inside a bucket use their eq() method.
    """

    def __init__(self):
        # hash value -> list of sessions having that hash
        self.sessions = {}
        # Fallback RSA key handed to server-side sessions that lack one.
        self.server_rsa_key = None

    def add(self, session):
        """
        Register `session` unless a matching session is already registered.
        """
        s = self.find(session)
        if s:
            log_runtime.info("TLS: previous session shall not be overwritten")
            return

        h = session.hash()
        if h in self.sessions:
            self.sessions[h].append(session)
        else:
            self.sessions[h] = [session]

    def rem(self, session):
        """
        Unregister the session matching `session`, if any.

        The registered match found by find() is removed from its bucket, and
        the bucket is dropped once empty. If nothing matches, this only logs.
        """
        s = self.find(session)
        if s is None:
            log_runtime.info("TLS: no matching session to remove")
            return

        # find() located `s` in the bucket keyed by session.hash().
        h = session.hash()
        bucket = self.sessions[h]
        bucket.remove(s)
        if not bucket:
            del self.sessions[h]

    def find(self, session):
        """
        Return the first registered session for which eq(session) holds, or
        None (also when session.hash() raises).
        """
        try:
            h = session.hash()
        except Exception:
            return None
        if h in self.sessions:
            for k in self.sessions[h]:
                if k.eq(session):
                    if conf.tls_verbose:
                        log_runtime.info("TLS: found session matching %s", k)
                    return k
        if conf.tls_verbose:
            log_runtime.info("TLS: did not find session matching %s", session)
        return None

    def __repr__(self):
        # One aligned row per registered session, after a header row.
        res = [("First endpoint", "Second endpoint", "Session ID")]
        for li in six.itervalues(self.sessions):
            for s in li:
                src = "%s[%d]" % (s.ipsrc, s.sport)
                dst = "%s[%d]" % (s.ipdst, s.dport)
                sid = repr(s.sid)
                if len(sid) > 12:
                    sid = sid[:11] + "..."
                res.append((src, dst, sid))
        colwidth = (max(len(y) for y in x) for x in zip(*res))
        fmt = "  ".join(map(lambda x: "%%-%ds" % x, colwidth))
        return "\n".join(map(lambda x: fmt % x, res))
1094
1095
class TLSSession(DefaultSession):
    """
    DefaultSession subclass that switches TLS session tracking on.

    conf.tls_session_enable is forced to True on construction and put back
    to its former value by toPacketList(). An optional `server_rsa_key`
    keyword argument becomes the registry-wide default server RSA key
    (conf.tls_sessions.server_rsa_key).
    """

    def __init__(self, *args, **kwargs):
        rsa_key = kwargs.pop("server_rsa_key", None)
        super(TLSSession, self).__init__(*args, **kwargs)
        # Saved so toPacketList() can restore the global switch.
        self._old_conf_status = conf.tls_session_enable
        conf.tls_session_enable = True
        if not rsa_key:
            return
        conf.tls_sessions.server_rsa_key = rsa_key

    def toPacketList(self):
        """Restore the saved tracking switch, then defer to the parent."""
        conf.tls_session_enable = self._old_conf_status
        return super(TLSSession, self).toPacketList()
1108
1109
# Global registry of TLS sessions, consulted and filled by
# _GenericTLSSessionInheritance when conf.tls_session_enable is set.
conf.tls_sessions = _tls_sessions()
# Session tracking is off by default; TLSSession turns it on while in use.
conf.tls_session_enable = False
# When True, _tls_sessions.find() logs the outcome of each lookup.
conf.tls_verbose = False
1113