1# -*- coding: utf-8 -*-
2# Copyright: (c) 2019, Jordan Borean (@jborean93) <jborean93@gmail.com>
3# MIT License (see LICENSE or https://opensource.org/licenses/MIT)
4
5import binascii
6import hashlib
7import hmac
8import logging
9import os
10import struct
11import time
12import threading
13
14from collections import (
15    OrderedDict,
16)
17
18from cryptography.hazmat.backends import (
19    default_backend,
20)
21
22from cryptography.hazmat.primitives import (
23    cmac,
24)
25
26from cryptography.hazmat.primitives.ciphers import (
27    aead,
28    algorithms,
29)
30
31from datetime import (
32    datetime,
33)
34
35from threading import (
36    Lock,
37)
38
39from smbprotocol import (
40    Dialects,
41    MAX_PAYLOAD_SIZE,
42)
43
44from smbprotocol._text import (
45    to_text,
46)
47
48from smbprotocol.exceptions import (
49    SMBConnectionClosed,
50    SMB2SymbolicLinkErrorResponse,
51    SMBException,
52    SMBResponseException,
53)
54
55from smbprotocol.header import (
56    Commands,
57    NtStatus,
58    Smb2Flags,
59    SMB2HeaderAsync,
60    SMB2HeaderRequest,
61    SMB2HeaderResponse,
62)
63
64from smbprotocol.open import (
65    Open,
66)
67
68from smbprotocol.structure import (
69    BytesField,
70    DateTimeField,
71    EnumField,
72    FlagField,
73    IntField,
74    ListField,
75    TextField,
76    Structure,
77    StructureField,
78    UuidField,
79)
80
81from smbprotocol.transport import (
82    Tcp,
83)
84
85try:
86    from queue import Queue, Empty
87except ImportError:  # pragma: no cover
88    from Queue import Queue, Empty
89
90log = logging.getLogger(__name__)
91
92
class SecurityMode(object):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.3 SMB2 NEGOTIATE Request SecurityMode
    Bit flags carried in the SecurityMode field of a NEGOTIATE message that
    state whether SMB signing is enabled or required.
    """
    # Signing is enabled.
    SMB2_NEGOTIATE_SIGNING_ENABLED = 1 << 0
    # Signing is required.
    SMB2_NEGOTIATE_SIGNING_REQUIRED = 1 << 1
102
103
class Capabilities(object):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.3 SMB2 NEGOTIATE Request Capabilities
    Bit flags exchanged in the Capabilities field of the NEGOTIATE request
    and response. Every constant occupies a single bit so they can be OR'd
    together.
    """
    SMB2_GLOBAL_CAP_DFS = 1 << 0
    SMB2_GLOBAL_CAP_LEASING = 1 << 1
    SMB2_GLOBAL_CAP_LARGE_MTU = 1 << 2
    SMB2_GLOBAL_CAP_MULTI_CHANNEL = 1 << 3
    SMB2_GLOBAL_CAP_PERSISTENT_HANDLES = 1 << 4
    SMB2_GLOBAL_CAP_DIRECTORY_LEASING = 1 << 5
    SMB2_GLOBAL_CAP_ENCRYPTION = 1 << 6
118
119
class NegotiateContextType(object):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.3.1 SMB2 NEGOTIATE_CONTEXT Request ContextType
    Identifies what kind of data an SMB2 NEGOTIATE_CONTEXT entry carries.
    """
    SMB2_PREAUTH_INTEGRITY_CAPABILITIES = 1
    SMB2_ENCRYPTION_CAPABILITIES = 2
    SMB2_COMPRESSION_CAPABILITIES = 3
    SMB2_NETNAME_NEGOTIATE_CONTEXT_ID = 5
    SMB2_TRANSPORT_CAPABILITIES = 6
    SMB2_RDMA_TRANSFORM_CAPABILITIES = 7
    SMB2_SIGNING_CAPABILITIES = 8
134
135
class HashAlgorithms(object):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.3.1.1 SMB2_PREAUTH_INTEGRITY_CAPABILITIES
    16-bit identifiers of the preauthentication integrity hash algorithms.
    """
    # SHA-512 is the only algorithm defined here.
    SHA_512 = 1
144
145
class Ciphers(object):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.3.1.2 SMB2_ENCRYPTION_CAPABILITIES
    16-bit identifiers of the encryption algorithms that can be negotiated.
    """
    # 128-bit key variants
    AES_128_CCM = 1
    AES_128_GCM = 2
    # 256-bit key variants
    AES_256_CCM = 3
    AES_256_GCM = 4
157
158
class SigningAlgorithms:
    """
    [MS-SMB2] 2.2.3.1.7 SMB2_SIGNING_CAPABILITIES

    https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-smb2/cb9b5d66-b6be-4d18-aa66-8784a871cc10
    16-bit identifiers of the message signing algorithms that can be
    negotiated.
    """
    HMAC_SHA256 = 0
    AES_CMAC = 1
    AES_GMAC = 2
169
170
class SMB2NegotiateRequest(Structure):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.3 SMB2 Negotiate Request
    Negotiate request without a negotiate context list. The client sends it
    to tell the server which SMB2 dialects it understands. It is meant for
    when the caller explicitly asks for a dialect below 3.1.1. Dialect 3.1.1
    introduced negotiate contexts, which SMB3NegotiateRequest supports.
    """
    COMMAND = Commands.SMB2_NEGOTIATE

    def __init__(self):
        # Element type of the dialects array: one 2-byte Dialects value each.
        dialect_entry = EnumField(size=2, enum_type=Dialects)

        field_list = [
            ('structure_size', IntField(size=2, default=36)),
            # Defaults to the number of entries in the dialects list.
            ('dialect_count', IntField(
                size=2,
                default=lambda msg: len(msg['dialects'].get_value()),
            )),
            ('security_mode', FlagField(size=2, flag_type=SecurityMode)),
            ('reserved', IntField(size=2)),
            ('capabilities', FlagField(size=4, flag_type=Capabilities)),
            ('client_guid', UuidField()),
            ('client_start_time', IntField(size=8)),
            # dialect_count entries, 2 bytes each.
            ('dialects', ListField(
                size=lambda msg: msg['dialect_count'].get_value() * 2,
                list_count=lambda msg: msg['dialect_count'].get_value(),
                list_type=dialect_entry,
            )),
        ]
        self.fields = OrderedDict(field_list)

        super(SMB2NegotiateRequest, self).__init__()
213
214
class SMB3NegotiateRequest(Structure):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.3 SMB2 Negotiate Request
    Like SMB2NegotiateRequest but with support for setting a list of
    Negotiate Context values. This is used by default and is for Dialects 3.1.1
    or greater.
    """
    COMMAND = Commands.SMB2_NEGOTIATE

    def __init__(self):
        self.fields = OrderedDict([
            ('structure_size', IntField(
                size=2,
                default=36,
            )),
            ('dialect_count', IntField(
                size=2,
                default=lambda s: len(s['dialects'].get_value()),
            )),
            ('security_mode', FlagField(
                size=2,
                flag_type=SecurityMode,
            )),
            ('reserved', IntField(size=2)),
            ('capabilities', FlagField(
                size=4,
                flag_type=Capabilities,
            )),
            ('client_guid', UuidField()),
            ('negotiate_context_offset', IntField(
                size=4,
                default=lambda s: self._negotiate_context_offset_value(s),
            )),
            ('negotiate_context_count', IntField(
                size=2,
                default=lambda s: len(s['negotiate_context_list'].get_value()),
            )),
            ('reserved2', IntField(size=2)),
            ('dialects', ListField(
                size=lambda s: s['dialect_count'].get_value() * 2,
                list_count=lambda s: s['dialect_count'].get_value(),
                list_type=EnumField(size=2, enum_type=Dialects),
            )),
            ('padding', BytesField(
                size=lambda s: self._padding_size(s),
                default=lambda s: b"\x00" * self._padding_size(s),
            )),
            ('negotiate_context_list', ListField(
                list_count=lambda s: s['negotiate_context_count'].get_value(),
                unpack_func=lambda s, d: self._negotiate_context_list(s, d),
            )),
        ])
        super(SMB3NegotiateRequest, self).__init__()

    def _negotiate_context_offset_value(self, structure):
        """
        Offset from the beginning of the SMB2 header to the first negotiate
        context. The padding field makes this offset 8-byte aligned.
        """
        header_size = 64
        negotiate_size = structure['structure_size'].get_value()
        dialect_size = len(structure['dialects'])
        padding_size = self._padding_size(structure)
        return header_size + negotiate_size + dialect_size + padding_size

    def _padding_size(self, structure):
        """
        Number of zero bytes between the end of the dialects array and the
        first negotiate context.

        Alignment is measured from the start of the 64 byte SMB2 header, not
        from the start of this structure. The fixed header and structure
        sizes therefore have to be included, not just the dialect bytes.
        With the default 36 byte structure this gives 4 bytes of padding when
        no dialects are specified, and 2 bytes for 5 dialects.
        """
        header_size = 64
        negotiate_size = structure['structure_size'].get_value()
        dialect_size = structure['dialect_count'].get_value() * 2
        end_of_dialects = header_size + negotiate_size + dialect_size
        # Distance to the next multiple of 8 (0 when already aligned).
        return (8 - end_of_dialects % 8) % 8

    def _negotiate_context_list(self, structure, data):
        """
        Unpack ``negotiate_context_count`` entries from the front of ``data``.

        :return: A list of SMB2NegotiateContextRequest structures.
        """
        context_count = structure['negotiate_context_count'].get_value()
        context_list = []
        for _ in range(context_count):
            field, data = self._parse_negotiate_context_entry(data)
            context_list.append(field)

        return context_list

    def _parse_negotiate_context_entry(self, data):
        """
        Unpack a single negotiate context from the front of ``data``.

        :return: Tuple of the parsed context and the bytes remaining after it
            and its alignment padding.
        """
        # data_length is the 2 byte little endian field after context_type.
        data_length = struct.unpack("<H", data[2:4])[0]
        # 8 byte context header (type, length, reserved) plus the data.
        entry_size = 8 + data_length
        negotiate_context = SMB2NegotiateContextRequest()
        negotiate_context.unpack(data[:entry_size])
        # The next context starts at the following 8-byte boundary.
        padded_size = (8 - data_length % 8) % 8

        return negotiate_context, data[entry_size + padded_size:]
305
306
class SMB2NegotiateContextRequest(Structure):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.3.1 SMB2 NEGOTIATE_CONTEXT Request Values
    The SMB2_NEGOTIATE_CONTEXT structure is used by the SMB2 NEGOTIATE Request
    and the SMB2 NEGOTIATE Response to encode additional properties.
    """
    COMMAND = Commands.SMB2_NEGOTIATE

    def __init__(self):
        self.fields = OrderedDict([
            ('context_type', EnumField(
                size=2,
                enum_type=NegotiateContextType,
            )),
            ('data_length', IntField(
                size=2,
                default=lambda s: len(s['data'].get_value()),
            )),
            ('reserved', IntField(size=4)),
            ('data', StructureField(
                size=lambda s: s['data_length'].get_value(),
                structure_type=lambda s: self._data_structure_type(s)
            )),
            # not actually a field but each list entry must start at the 8 byte
            # alignment
            ('padding', BytesField(
                size=lambda s: self._padding_size(s),
                default=lambda s: b"\x00" * self._padding_size(s),
            ))
        ])
        super(SMB2NegotiateContextRequest, self).__init__()

    def _data_structure_type(self, structure):
        """
        Select the Structure class used for the ``data`` field based on
        ``context_type``.

        :return: The structure class. Returns None for context types this
            module has no structure for.
        """
        con_type = structure['context_type'].get_value()
        if con_type == NegotiateContextType.SMB2_PREAUTH_INTEGRITY_CAPABILITIES:
            return SMB2PreauthIntegrityCapabilities
        elif con_type == NegotiateContextType.SMB2_ENCRYPTION_CAPABILITIES:
            return SMB2EncryptionCapabilities
        elif con_type == NegotiateContextType.SMB2_NETNAME_NEGOTIATE_CONTEXT_ID:
            return SMB2NetnameNegotiateContextId
        elif con_type == NegotiateContextType.SMB2_SIGNING_CAPABILITIES:
            return SMB2SigningCapabilities
        return None

    def _padding_size(self, structure):
        """
        Number of zero bytes appended so the next negotiate context starts on
        an 8-byte boundary.

        The context header (context_type, data_length, reserved) is exactly 8
        bytes, so only the data length decides the padding. A data length
        that is already a multiple of 8, including 0, needs no padding. This
        matches how _parse_negotiate_context_entry skips padding when
        unpacking.
        """
        data_size = len(structure['data'])
        return (8 - data_size % 8) % 8
355
356
class SMB2PreauthIntegrityCapabilities(Structure):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.3.1.1 SMB2_PREAUTH_INTEGRITY_CAPABILITIES
    Negotiate context data that lists the preauthentication integrity hash
    algorithms the client supports, optionally followed by a salt value.
    """

    def __init__(self):
        # Element type of the hash_algorithms array: 2-byte ids.
        hash_entry = EnumField(size=2, enum_type=HashAlgorithms)

        field_list = [
            ('hash_algorithm_count', IntField(
                size=2,
                default=lambda ctx: len(ctx['hash_algorithms'].get_value()),
            )),
            ('salt_length', IntField(
                size=2,
                default=lambda ctx: len(ctx['salt']),
            )),
            ('hash_algorithms', ListField(
                size=lambda ctx: ctx['hash_algorithm_count'].get_value() * 2,
                list_count=lambda ctx: ctx['hash_algorithm_count'].get_value(),
                list_type=hash_entry,
            )),
            ('salt', BytesField(
                size=lambda ctx: ctx['salt_length'].get_value(),
            )),
        ]
        self.fields = OrderedDict(field_list)

        super(SMB2PreauthIntegrityCapabilities, self).__init__()
388
389
class SMB2EncryptionCapabilities(Structure):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.3.1.2 SMB2_ENCRYPTION_CAPABILITIES
    Negotiate context data listing encryption algorithms. The client sends
    the ciphers it supports; this structure is also used to read the
    server's reply.
    """

    def __init__(self):
        # Element type of the ciphers array: 2-byte Ciphers ids.
        cipher_entry = EnumField(size=2, enum_type=Ciphers)

        field_list = [
            ('cipher_count', IntField(
                size=2,
                default=lambda ctx: len(ctx['ciphers'].get_value()),
            )),
            ('ciphers', ListField(
                size=lambda ctx: ctx['cipher_count'].get_value() * 2,
                list_count=lambda ctx: ctx['cipher_count'].get_value(),
                list_type=cipher_entry,
            )),
        ]
        self.fields = OrderedDict(field_list)

        super(SMB2EncryptionCapabilities, self).__init__()
413
414
class SMB2NetnameNegotiateContextId(Structure):
    """
    [MS-SMB2] 2.2.3.1.4 SMB2_NETNAME_NEGOTIATE_CONTEXT_ID

    https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-smb2/ca6726bd-b9cf-43d9-b0bc-d127d3c993b3

    Negotiate context data that carries the name of the server the client is
    connecting to. The whole payload is a single text field.
    """

    def __init__(self):
        fields = OrderedDict()
        fields['net_name'] = TextField()
        self.fields = fields

        super().__init__()
430
431
class SMB2SigningCapabilities(Structure):
    """
    [MS-SMB2] 2.2.3.1.7 SMB2_SIGNING_CAPABILITIES

    https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-smb2/cb9b5d66-b6be-4d18-aa66-8784a871cc10

    Negotiate context data listing signing algorithms: a count followed by
    that many 2-byte SigningAlgorithms ids.
    """

    def __init__(self):
        # Element type of the signing_algorithms array.
        algorithm_entry = EnumField(size=2, enum_type=SigningAlgorithms)

        field_list = [
            ('signing_algorithm_count', IntField(
                size=2,
                default=lambda ctx: len(ctx['signing_algorithms'].get_value()),
            )),
            ('signing_algorithms', ListField(
                size=lambda ctx: ctx['signing_algorithm_count'].get_value() * 2,
                list_count=lambda ctx: ctx['signing_algorithm_count'].get_value(),
                list_type=algorithm_entry,
            )),
        ]
        self.fields = OrderedDict(field_list)

        super().__init__()
455
456
class SMB2NegotiateResponse(Structure):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.4 SMB2 NEGOTIATE Response
    Sent by the server to tell the client which dialect was selected, along
    with the server's security mode, capabilities, size limits and GSS
    token. For SMB 3.1.1 it also carries a list of negotiate contexts.
    """
    COMMAND = Commands.SMB2_NEGOTIATE

    def __init__(self):
        field_list = [
            ('structure_size', IntField(size=2, default=65)),
            ('security_mode', FlagField(size=2, flag_type=SecurityMode)),
            ('dialect_revision', EnumField(size=2, enum_type=Dialects)),
            # Only meaningful for SMB 3.1.1, reserved otherwise.
            ('negotiate_context_count', IntField(
                size=2,
                default=lambda resp: self._negotiate_context_count_value(resp),
            )),
            ('server_guid', UuidField()),
            ('capabilities', FlagField(size=4, flag_type=Capabilities)),
            ('max_transact_size', IntField(size=4)),
            ('max_read_size', IntField(size=4)),
            ('max_write_size', IntField(size=4)),
            ('system_time', DateTimeField()),
            ('server_start_time', DateTimeField()),
            # 64 byte SMB2 header + 64 byte fixed response body.
            ('security_buffer_offset', IntField(size=2, default=128)),
            ('security_buffer_length', IntField(
                size=2,
                default=lambda resp: len(resp['buffer'].get_value()),
            )),
            ('negotiate_context_offset', IntField(
                size=4,
                default=lambda resp: self._negotiate_context_offset_value(resp),
            )),
            ('buffer', BytesField(
                size=lambda resp: resp['security_buffer_length'].get_value(),
            )),
            ('padding', BytesField(
                size=lambda resp: self._padding_size(resp),
                default=lambda resp: b"\x00" * self._padding_size(resp),
            )),
            ('negotiate_context_list', ListField(
                list_count=lambda resp: resp['negotiate_context_count'].get_value(),
                unpack_func=lambda resp, raw: self._negotiate_context_list(resp, raw),
            )),
        ]
        self.fields = OrderedDict(field_list)

        super(SMB2NegotiateResponse, self).__init__()

    def _negotiate_context_count_value(self, structure):
        """
        Number of entries in negotiate_context_list for SMB 3.1.1. For any
        other dialect the field is reserved and no value is produced.
        """
        if structure['dialect_revision'].get_value() != Dialects.SMB_3_1_1:
            return None
        return len(structure['negotiate_context_list'].get_value())

    def _negotiate_context_offset_value(self, structure):
        """
        For SMB 3.1.1, the offset from the start of the SMB2 header to the
        first 8-byte aligned negotiate context. That is the end of the
        security buffer plus alignment padding. For other dialects the field
        is reserved and no value is produced.
        """
        if structure['dialect_revision'].get_value() != Dialects.SMB_3_1_1:
            return None
        return (structure['security_buffer_offset'].get_value() +
                structure['security_buffer_length'].get_value() +
                self._padding_size(structure))

    def _padding_size(self, structure):
        """
        Zero bytes between the end of the security buffer and the first
        negotiate context.

        The padding is based only on the buffer length, so it assumes the
        buffer starts on an 8-byte boundary (the default offset is 128). No
        padding is used when there are no negotiate contexts.
        """
        if structure['negotiate_context_count'].get_value() == 0:
            return 0

        remainder = structure['security_buffer_length'].get_value() % 8
        return (8 - remainder) % 8

    def _negotiate_context_list(self, structure, data):
        """
        Unpack ``negotiate_context_count`` consecutive negotiate contexts from
        ``data``.
        """
        remaining = data
        contexts = []
        for _ in range(structure['negotiate_context_count'].get_value()):
            context, remaining = self._parse_negotiate_context_entry(remaining)
            contexts.append(context)
        return contexts

    def _parse_negotiate_context_entry(self, data):
        """
        Unpack one negotiate context from the front of ``data``.

        :return: The parsed context, and the bytes that follow it and its
            8-byte alignment padding.
        """
        context_data_length = struct.unpack("<H", data[2:4])[0]
        # 8 byte context header followed by the context data.
        entry_size = 8 + context_data_length

        context = SMB2NegotiateContextRequest()
        context.unpack(data[:entry_size])

        padding = (8 - context_data_length % 8) % 8
        return context, data[entry_size + padding:]
572
573
class SMB2Echo(Structure):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.28 SMB2 Echo Request/Response
    Layout shared by the ECHO request and response: a 2-byte structure size
    (4) followed by 2 reserved bytes.
    """
    COMMAND = Commands.SMB2_ECHO

    def __init__(self):
        fields = OrderedDict()
        fields['structure_size'] = IntField(size=2, default=4)
        fields['reserved'] = IntField(size=2)
        self.fields = fields

        super(SMB2Echo, self).__init__()
592
593
class SMB2CancelRequest(Structure):
    """
    [MS-SMB2] 2.2.30 - SMB2 CANCEL Request
    https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-smb2/91913fc6-4ec9-4a83-961b-370070067e63

    Sent by the client to cancel a message it previously sent on the same
    SMB2 transport connection. The body is a 2-byte structure size (4) and 2
    reserved bytes.
    """
    COMMAND = Commands.SMB2_CANCEL

    def __init__(self):
        fields = OrderedDict()
        fields['structure_size'] = IntField(size=2, default=4)
        fields['reserved'] = IntField(size=2)
        self.fields = fields

        super(SMB2CancelRequest, self).__init__()
613
614
class SMB2TransformHeader(Structure):
    """
    [MS-SMB2] v53.0 2017-09-15

    2.2.41 SMB2 TRANSFORM_HEADER
    Header the client or server puts in front of an encrypted message. It is
    only valid for the SMB 3.x dialect family. The trailing ``data`` field
    is not part of the spec's header definition; it holds the bytes that
    follow the header.
    """

    def __init__(self):
        fields = OrderedDict()
        fields['protocol_id'] = BytesField(size=4, default=b"\xfdSMB")
        fields['signature'] = BytesField(size=16, default=b"\x00" * 16)
        fields['nonce'] = BytesField(size=16)
        fields['original_message_size'] = IntField(size=4)
        fields['reserved'] = IntField(size=2, default=0)
        fields['flags'] = IntField(size=2, default=1)
        fields['session_id'] = IntField(size=8)
        fields['data'] = BytesField()  # not in spec
        self.fields = fields

        super(SMB2TransformHeader, self).__init__()
645
646
647class Connection(object):
648
649    def __init__(self, guid, server_name, port=445, require_signing=True):
650        """
651        [MS-SMB2] v53.0 2017-09-15
652
653        3.2.1.2 Per SMB2 Transport Connection
654        Used as the transport interface for a server. Some values have been
655        omitted as they can be retrieved by the Server object stored in
656        self.server
657
658        :param guid: A unique guid that represents the client
659        :param server_name: The server to start the connection
660        :param port: The port to use for the transport, default is 445
661        :param require_signing: Whether signing is required on SMB messages
662            sent over this connection
663        """
664        log.info("Initialising connection, guid: %s, require_signing: %s, "
665                 "server_name: %s, port: %d"
666                 % (guid, require_signing, server_name, port))
667        self.server_name = server_name
668        self.port = port
669        self.transport = None  # Instanciated in .connect()
670
671        # Table of Session entries, the order is important for smbclient.
672        self.session_table = OrderedDict()
673
674        # Table of sessions that have not completed authentication, indexed by
675        # session_id
676        self.preauth_session_table = {}
677
678        # Table of Requests that have yet to be picked up by the application,
679        # it MAY contain a response from the server as well
680        self.outstanding_requests = dict()
681
682        # Table of available sequence numbers
683        self.sequence_window = dict(
684            low=0,
685            high=1
686        )
687        self.sequence_lock = Lock()
688
689        # Byte array containing the negotiate token and remembered for
690        # authentication
691        self.gss_negotiate_token = None
692
693        self.server_guid = None
694        self.max_transact_size = None
695        self.max_read_size = None
696        self.max_write_size = None
697        self.require_signing = require_signing
698
699        # SMB 2.1+
700        self.dialect = None
701        self.supports_file_leasing = None
702        # just go with False as a default for Dialect 2.0.2
703        self.supports_multi_credit = False
704        self.client_guid = guid
705
706        # SMB 3.x+
707        self.salt = None
708        self.supports_directory_leasing = None
709        self.supports_multi_channel = None
710        self.supports_persistent_handles = None
711        self.supports_encryption = None
712
713        # used for SMB 3.x for secure negotiate verification on tree connect
714        self.negotiated_dialects = []
715        self.client_capabilities = Capabilities.SMB2_GLOBAL_CAP_LARGE_MTU | \
716            Capabilities.SMB2_GLOBAL_CAP_ENCRYPTION | Capabilities.SMB2_GLOBAL_CAP_DFS
717
718        self.client_security_mode = \
719            SecurityMode.SMB2_NEGOTIATE_SIGNING_REQUIRED if \
720            require_signing else SecurityMode.SMB2_NEGOTIATE_SIGNING_ENABLED
721        self.server_security_mode = None
722        self.server_capabilities = None
723
724        # SMB 3.1.1+
725        # The hashing algorithm object that was negotiated
726        self.preauth_integrity_hash_id = None
727
728        # Preauth integrity hash value computed for the SMB2 NEGOTIATE request
729        # contains the messages used to compute the hash
730        self.preauth_integrity_hash_value = []
731
732        # The cipher object that was negotiated
733        self.cipher_id = None
734
735        # The signing algorithm that was negotiated
736        self.signing_algorithm_id = None
737
738        # Keep track of the message processing thread's potential traceback that it may raise.
739        self._t_exc = None
740
741    def connect(self, dialect=None, timeout=60, preferred_encryption_algos=None, preferred_signing_algos=None):
742        """
743        Will connect to the target server and negotiate the capabilities
744        with the client. Once setup, the client MUST call the disconnect()
745        function to close the listener thread. This function will populate
746        various connection properties that denote the capabilities of the
747        server.
748
749        If no preferred encryption or signing algorithms are specified then
750        all algorithms are offered during negotiation. Older dialects may not
751        be offered if a custom encryption or signing algorithm list is
752        specified without the algorithm required by that dialect.
753
754        By default the following encryption algorithms are used:
755
756            AES_128_GCM
757            AES_128_CCM (required for SMB 3.0.x)
758            AES_256_GCM
759            AES_256_CCM
760
761        By default the following signing algorithms are used:
762
763            AES_GMAC
764            AES_CMAC (required for SMB 3.0.x)
765            HMAC_SHA256 (required for SMB 2.x)
766
767        :param dialect: If specified, forces the dialect that is negotiated
768            with the server, if not set, then the newest dialect supported by
769            the server is used up to SMB 3.1.1
770        :param timeout: The timeout in seconds to wait for the initial
771            negotiation process to complete
772        :param preferred_encryption_algos: A list of encryption algorithm ids
773            in priority order from highest to lowest. See :class:`Ciphers` for
774            a list of known identifiers.
775        :param preferred_signing_algos: A list of signing algorithm ids in
776            priority order from highest to lowest.
777            See :class:`SigningAlgorithms` for a list of known identifiers.
778        """
779        log.info("Setting up transport connection")
780        self.transport = Tcp(self.server_name, self.port, timeout)
781        self.transport.connect()
782        t_worker = threading.Thread(target=self._process_message_thread,
783                                    name="msg_worker-%s:%s" % (self.server_name, self.port))
784        t_worker.daemon = True
785        t_worker.start()
786
787        log.info("Starting negotiation with SMB server")
788        enc_algos = preferred_encryption_algos or [
789            Ciphers.AES_128_GCM,
790            Ciphers.AES_128_CCM,
791            Ciphers.AES_256_GCM,
792            Ciphers.AES_256_CCM,
793        ]
794        sign_algos = preferred_signing_algos or [
795            SigningAlgorithms.AES_GMAC,
796            SigningAlgorithms.AES_CMAC,
797            SigningAlgorithms.HMAC_SHA256,
798        ]
799        smb_response = self._send_smb2_negotiate(dialect, timeout, enc_algos, sign_algos)
800        log.info("Negotiated dialect: %s"
801                 % str(smb_response['dialect_revision']))
802        self.dialect = smb_response['dialect_revision'].get_value()
803        self.max_transact_size = smb_response['max_transact_size'].get_value()
804        self.max_read_size = smb_response['max_read_size'].get_value()
805        self.max_write_size = smb_response['max_write_size'].get_value()
806        self.server_guid = smb_response['server_guid'].get_value()
807        self.gss_negotiate_token = smb_response['buffer'].get_value()
808
809        if not self.require_signing and \
810                smb_response['security_mode'].has_flag(
811                    SecurityMode.SMB2_NEGOTIATE_SIGNING_REQUIRED):
812            self.require_signing = True
813        log.info("Connection require signing: %s" % self.require_signing)
814
815        capabilities = smb_response['capabilities']
816        self.server_capabilities = capabilities
817        self.server_security_mode = smb_response['security_mode'].get_value()
818
819        # SMB 2.1
820        if self.dialect >= Dialects.SMB_2_1_0:
821            self.supports_file_leasing = \
822                capabilities.has_flag(Capabilities.SMB2_GLOBAL_CAP_LEASING)
823            self.supports_multi_credit = \
824                capabilities.has_flag(Capabilities.SMB2_GLOBAL_CAP_LARGE_MTU)
825
826        # SMB 3.x
827        if self.dialect >= Dialects.SMB_3_0_0:
828            self.supports_directory_leasing = capabilities.has_flag(
829                Capabilities.SMB2_GLOBAL_CAP_DIRECTORY_LEASING)
830            self.supports_multi_channel = capabilities.has_flag(
831                Capabilities.SMB2_GLOBAL_CAP_MULTI_CHANNEL)
832
833            # TODO: SMB2_GLOBAL_CAP_PERSISTENT_HANDLES
834            self.supports_persistent_handles = False
835            self.supports_encryption = capabilities.has_flag(
836                Capabilities.SMB2_GLOBAL_CAP_ENCRYPTION) \
837                and self.dialect < Dialects.SMB_3_1_1
838
839            # TODO: Check/add server to server_list in Client Page 203
840
841        # SMB 3.1
842        if self.dialect >= Dialects.SMB_3_1_1:
843            for context in smb_response['negotiate_context_list']:
844                context_type = context["context_type"].get_value()
845
846                if context_type == NegotiateContextType.SMB2_ENCRYPTION_CAPABILITIES:
847                    self.cipher_id = context['data']['ciphers'][0]
848                    self.supports_encryption = self.cipher_id != 0
849
850                elif context_type == NegotiateContextType.SMB2_PREAUTH_INTEGRITY_CAPABILITIES:
851                    self.preauth_integrity_hash_id = context['data']['hash_algorithms'][0]
852
853                elif context_type == NegotiateContextType.SMB2_SIGNING_CAPABILITIES:
854                    self.signing_algorithm_id = context['data']['signing_algorithms'][0]
855
856    def disconnect(self, close=True):
857        """
858        Closes the connection as well as logs off any of the
859        Disconnects the TCP connection and shuts down the socket listener
860        running in a thread.
861
862        :param close: Will close all sessions in the connection as well as the
863            tree connections of each session.
864        """
865        # We cannot close the session or tree if the socket has been closed.
866        if close and self.transport.connected:
867            for session in list(self.session_table.values()):
868                session.disconnect(True)
869
870        log.info("Disconnecting transport connection")
871        self.transport.close()
872
873    def send(self, message, sid=None, tid=None, credit_request=None, message_id=None, async_id=None,
874             force_signature=False):
875        """
876        Will send a message to the server that is passed in. The final unencrypted header is returned to the function
877        that called this.
878
879        :param message: An SMB message structure to send.
880        :param sid: A session_id that the message is sent for.
881        :param tid: A tree_id object that the message is sent for.
882        :param credit_request: Specifies extra credits to be requested with the SMB header.
883        :param message_id: The message_id for the header, only useful for a cancel request.
884        :param async_id: The async_id for the header, only useful for a cancel request.
885        :param force_signature: Force signing the SMB request even if not requested by the client/server.
886        :return: Request of the message that was sent.
887        """
888        return self._send([message], session_id=sid, tree_id=tid, message_id=message_id, credit_request=credit_request,
889                          async_id=async_id, force_signature=force_signature)[0]
890
891    def send_compound(self, messages, sid, tid, related=False):
892        """
893        Sends multiple messages within 1 TCP request, will fail if the size of the total length exceeds the maximum of
894        the transport max.
895
896        :param messages: A list of messages to send to the server.
897        :param sid: The session_id that the request is sent for.
898        :param tid: A tree_id object that the message is sent for.
899        :param related: Whether each message is related to each other, sets the Session, Tree, and File Id to the same
900            value as the first message.
901        :return: List<Request> for each request that was sent, each entry in the list is in the same order of the
902            message list that was passed in.
903        """
904        return self._send(messages, session_id=sid, tree_id=tid, related=related)
905
906    def receive(self, request, wait=True, timeout=None, resolve_symlinks=True):
907        """
908        Polls the message buffer of the TCP connection and waits until a valid
909        message is received based on the message_id passed in.
910
911        :param request: The Request object to wait get the response for
912        :param wait: Wait for the final response in the case of a STATUS_PENDING response, the pending response is
913            returned in the case of wait=False
914        :param timeout: Set a timeout used while waiting for the final response from the server.
915        :param resolve_symlinks: Set to automatically resolve symlinks in the path when opening a file or directory.
916        :return: SMB2HeaderResponse of the received message
917        """
918        # Make sure the receiver is still active, if not this raises an exception.
919        self._check_worker_running()
920
921        start_time = time.time()
922        while True:
923            iter_timeout = int(max(timeout - (time.time() - start_time), 1)) if timeout is not None else None
924            if not request.response_event.wait(timeout=iter_timeout):
925                raise SMBException("Connection timeout of %d seconds exceeded while waiting for a message id %s "
926                                   "response from the server" % (timeout, request.message['message_id'].get_value()))
927
928            # Use a lock on the request so that in the case of a pending response we have exclusive lock on the event
929            # flag and can clear it without the future pending response taking it over before we first clear the flag.
930            with request.response_event_lock:
931                self._check_worker_running()  # The worker may have failed while waiting for the response, check again
932
933                response = request.response
934                status = response['status'].get_value()
935                if status == NtStatus.STATUS_PENDING and wait:
936                    # Received a pending message, clear the response_event flag and wait again.
937                    request.response_event.clear()
938                    continue
939                elif status == NtStatus.STATUS_STOPPED_ON_SYMLINK and resolve_symlinks:
940                    # Received when we do an Open on a path that contains a symlink. Need to capture all related
941                    # requests and resend the Open + others with the redirected path. First we need to resolve the
942                    # symlink path. This will fail if the symlink is pointing to a location that is not in the same
943                    # tree/share as the original request.
944
945                    # First wait for the other remaining requests to be processed. Their status will also fail and we
946                    # need to make sure we update the old request with the new one properly.
947                    related_requests = [self.outstanding_requests[i] for i in request.related_ids]
948                    [r.response_event.wait() for r in related_requests]
949
950                    # Now create a new request with the new path the symlink points to.
951                    session = self.session_table[request.session_id]
952                    tree = session.tree_connect_table[request.message['tree_id'].get_value()]
953
954                    old_create = request.get_message_data()
955                    tree_share_name = tree.share_name + u'\\'
956                    original_path = tree_share_name + to_text(old_create['buffer_path'].get_value(),
957                                                              encoding='utf-16-le')
958
959                    exp = SMBResponseException(response)
960                    reparse_buffer = next((e for e in exp.error_details
961                                           if isinstance(e, SMB2SymbolicLinkErrorResponse)))
962                    new_path = reparse_buffer.resolve_path(original_path)[len(tree_share_name):]
963
964                    new_open = Open(tree, new_path)
965                    create_req = new_open.create(
966                        old_create['impersonation_level'].get_value(),
967                        old_create['desired_access'].get_value(),
968                        old_create['file_attributes'].get_value(),
969                        old_create['share_access'].get_value(),
970                        old_create['create_disposition'].get_value(),
971                        old_create['create_options'].get_value(),
972                        create_contexts=old_create['buffer_contexts'].get_value(),
973                        send=False
974                    )[0]
975
976                    # Now add all the related requests (if any) to send as a compound request.
977                    new_msgs = [create_req] + [r.get_message_data() for r in related_requests]
978                    new_requests = self.send_compound(new_msgs, session.session_id, tree.tree_connect_id, related=True)
979
980                    # Verify that the first request was successful before updating the related requests with the new
981                    # info.
982                    error = None
983                    try:
984                        new_response = self.receive(new_requests[0], wait=wait, timeout=timeout, resolve_symlinks=True)
985                    except SMBResponseException as exc:
986                        # We need to make sure we fix up the remaining responses before throwing this.
987                        error = exc
988                    [r.response_event.wait() for r in new_requests]
989
990                    # Update the old requests with the new response information
991                    for i, old_request in enumerate([request] + related_requests):
992                        del self.outstanding_requests[old_request.message['message_id'].get_value()]
993                        old_request.update_request(new_requests[i])
994
995                    if error:
996                        raise error
997
998                    return new_response
999                else:
1000                    # now we have a retrieval request for the response, we can delete
1001                    # the request from the outstanding requests
1002                    message_id = request.message['message_id'].get_value()
1003                    self.outstanding_requests.pop(message_id, None)
1004
1005                    if status not in [NtStatus.STATUS_SUCCESS, NtStatus.STATUS_PENDING]:
1006                        raise SMBResponseException(response)
1007
1008                    break
1009
1010        return response
1011
1012    def echo(self, sid=0, timeout=60, credit_request=1):
1013        """
1014        Sends an SMB2 Echo request to the server. This can be used to request
1015        more credits from the server with the credit_request param.
1016
1017        On a Samba server, the sid can be 0 but for a Windows SMB Server, the
1018        sid of an authenticated session must be passed into this function or
1019        else the socket will close.
1020
1021        :param sid: When talking to a Windows host this must be populated with
1022            a valid session_id from a negotiated session
1023        :param timeout: The timeout in seconds to wait for the Echo Response
1024        :param credit_request: The number of credits to request
1025        :return: the credits that were granted by the server
1026        """
1027        log.info("Sending Echo request with a timeout of %d and credit "
1028                 "request of %d" % (timeout, credit_request))
1029
1030        echo_msg = SMB2Echo()
1031        log.debug(echo_msg)
1032        req = self.send(echo_msg, sid=sid, credit_request=credit_request)
1033
1034        log.info("Receiving Echo response")
1035        response = self.receive(req, timeout=timeout)
1036        log.info("Credits granted from the server echo response: %d"
1037                 % response['credit_response'].get_value())
1038        echo_resp = SMB2Echo()
1039        echo_resp.unpack(response['data'].get_value())
1040        log.debug(echo_resp)
1041
1042        return response['credit_response'].get_value()
1043
1044    def verify_signature(self, header, session_id, force=False):
1045        """
1046        Verifies the SMB2 Header request/response signature.
1047
1048        :param header: The SMB2Header that will have its signature verified against the signing key specified.
1049        :param session_id: The Session Id to denote what session security verifies the message.
1050        :param force: Force verification of the header even if it does not match the criteria required in normal
1051            scenarios.
1052        """
1053        message_id = header['message_id'].get_value()
1054        flags = header['flags']
1055        status = header['status'].get_value()
1056        command = header['command'].get_value()
1057
1058        if not force and (message_id == 0xFFFFFFFFFFFFFFFF or
1059                          not flags.has_flag(Smb2Flags.SMB2_FLAGS_SIGNED) or
1060                          status == NtStatus.STATUS_PENDING or
1061                          command == Commands.SMB2_SESSION_SETUP):
1062            return
1063
1064        session = self.session_table.get(session_id, None)
1065        if session is None:
1066            raise SMBException("Failed to find session %s for message verification" % session_id)
1067
1068        expected = self._generate_signature(header.pack(), session.signing_key, message_id,
1069                                            flags.has_flag(Smb2Flags.SMB2_FLAGS_SERVER_TO_REDIR), command)
1070        actual = header['signature'].get_value()
1071        if actual != expected:
1072            raise SMBException("Server message signature could not be verified: %s != %s"
1073                               % (binascii.hexlify(actual).decode(), binascii.hexlify(expected).decode()))
1074
1075    def _check_worker_running(self):
1076        """ Checks that the message worker thread is still running and raises it's exception if it has failed. """
1077        if self._t_exc is not None:
1078            self.disconnect(False)
1079            raise self._t_exc
1080
1081        elif not self.transport.connected:
1082            raise SMBConnectionClosed('SMB socket was closed, cannot send or receive any more data')
1083
    def _send(self, messages, session_id=None, tree_id=None, message_id=None, credit_request=None, related=False,
              async_id=None, force_signature=False):
        """
        Builds one SMB2 header per message and sends all of them in a single transport write.

        Each message gets its own SMB2 header (an async header when async_id is set). Its message id comes from the low
        end of the sequence window unless message_id is given. Every non-cancel message is registered in
        ``outstanding_requests`` so the worker thread can match its response. A cancel request instead reuses the
        Request already registered under its message id. When the session or tree requires it, the whole payload is
        encrypted before it is sent.

        :param messages: The SMB2 message structures to send, chained as a compound request when there is more than 1.
        :param session_id: The session id the messages are sent for.
        :param tree_id: The tree connect id the messages are sent for.
        :param message_id: Explicit message id to use instead of the next sequence number (used for cancel requests).
        :param credit_request: Credits to request in each header, defaults to each message's credit charge.
        :param related: Mark the 2nd and later messages as related operations of the first.
        :param async_id: When set, an async header with this async id is used.
        :param force_signature: Sign the messages even if the session does not require signing.
        :return: List of Request objects, in the same order as messages.
        :raises SMBException: The tree id is not in the session's tree table, or a message needs more credits than are
            available.
        Any exception raised by the message worker thread is re-raised before the data is sent.
        """
        send_data = b""
        requests = []
        session = self.session_table.get(session_id, None)
        tree = None
        if tree_id and session:
            if tree_id not in session.tree_connect_table:
                raise SMBException("Cannot find Tree with the ID %d in the session tree table" % tree_id)
            tree = session.tree_connect_table[tree_id]

        total_requests = len(messages)
        for i, message in enumerate(messages):
            if i == total_requests - 1:
                # The last (or only) message terminates the chain.
                next_command = 0
                padding = b""
            else:
                # each compound message must start at the 8-byte boundary
                msg_length = 64 + len(message)
                mod = msg_length % 8
                padding_length = 8 - mod if mod > 0 else 0
                next_command = msg_length + padding_length
                padding = b"\x00" * padding_length

            # When running with multiple threads we need to ensure that getting the message id and adjusting the
            # sequence windows is done in a thread safe manner so we use a lock to ensure only 1 thread accesses the
            # sequence window at a time.
            with self.sequence_lock:
                sequence_window_low = self.sequence_window['low']
                sequence_window_high = self.sequence_window['high']
                credit_charge = self._calculate_credit_charge(message)
                credits_available = sequence_window_high - sequence_window_low
                if credit_charge > credits_available:
                    raise SMBException("Request requires %d credits but only %d credits are available"
                                       % (credit_charge, credits_available))

                # A cancel request does not consume a sequence number; a credit charge of 0 still consumes 1 id.
                current_id = message_id or sequence_window_low
                if message.COMMAND != Commands.SMB2_CANCEL:
                    self.sequence_window['low'] += credit_charge if credit_charge > 0 else 1

            if async_id is None:
                header = SMB2HeaderRequest()
                header['tree_id'] = tree_id or 0
            else:
                header = SMB2HeaderAsync()
                header['flags'].set_flag(Smb2Flags.SMB2_FLAGS_ASYNC_COMMAND)
                header['async_id'] = async_id

            header['credit_charge'] = credit_charge
            header['command'] = message.COMMAND
            header['credit_request'] = credit_request if credit_request else credit_charge
            header['message_id'] = current_id
            header['session_id'] = session_id if session_id and session_id > 0 else 0
            header['data'] = message.pack()
            header['next_command'] = next_command

            if i != 0 and related:
                # Related operations carry all-0xFF session/tree ids so the values of the first message are reused.
                header['session_id'] = b"\xff" * 8
                header['tree_id'] = b"\xff" * 4
                header['flags'].set_flag(Smb2Flags.SMB2_FLAGS_RELATED_OPERATIONS)

            # NOTE(review): force_signature with no matching session would fail on session.signing_key below.
            if force_signature or (session and session.signing_required and session.signing_key):
                header['flags'].set_flag(Smb2Flags.SMB2_FLAGS_SIGNED)
                b_header = header.pack() + padding
                signature = self._generate_signature(b_header, session.signing_key, current_id, False, message.COMMAND)

                # To save on unpacking and re-packing, manually adjust the signature and update the request object for
                # back-referencing.
                b_header = b_header[:48] + signature + b_header[64:]
                header['signature'] = signature
            else:
                b_header = header.pack() + padding

            send_data += b_header

            if message.COMMAND == Commands.SMB2_CANCEL:
                # A cancel targets an existing request, reuse the Request registered under that message id.
                request = self.outstanding_requests[header['message_id'].get_value()]
            else:
                request = Request(header, type(message), self, session_id=session_id)
                self.outstanding_requests[header['message_id'].get_value()] = request

            # Make sure the preauth integrity values are updated for a negotiate or session setup message.
            if message.COMMAND == Commands.SMB2_NEGOTIATE:
                self.preauth_integrity_hash_value.append(b_header)

            elif message.COMMAND == Commands.SMB2_SESSION_SETUP:
                self.preauth_session_table[session_id].preauth_integrity_hash_value.append(b_header)

            requests.append(request)

        if related:
            # The first request remembers the message ids of the requests chained to it (used for symlink retries).
            requests[0].related_ids = [r.message['message_id'].get_value() for r in requests][1:]

        if session and session.encrypt_data or tree and tree.encrypt_data:
            send_data = self._encrypt(send_data, session)

        self._check_worker_running()
        self.transport.send(send_data)
        return requests
1183
1184    def _process_message_thread(self):
1185        try:
1186            while True:
1187                # Wait for a max of 10 minutes before sending an echo that tells the SMB server the client is still
1188                # available. This stops the server from closing the connection and the associated sessions on a long
1189                # lived connection. A brief test shows Windows kills a connection at ~16 minutes so 10 minutes is a
1190                # safe choice.
1191                # https://github.com/jborean93/smbprotocol/issues/31
1192                try:
1193                    b_msg = self.transport.recv(600)
1194                except TimeoutError as ex:
1195                    # Check if the connection has unanswered keepalive echo requests with the reserved field set.
1196                    # When unanswered keep alive echo exists, the server did not respond withing two times the timeout.
1197                    # We assume that the server connection is dead and close it.
1198                    for r in self.outstanding_requests.values():
1199                        if r.response is None and \
1200                                r.message['command'].get_value() == Commands.SMB2_ECHO and \
1201                                r.message['reserved'].get_value() == 1:
1202                            # connection will be closed in finally block
1203                            raise SMBConnectionClosed('Connection timed out. Server did not respond within timeout.') \
1204                                from ex
1205
1206                    log.debug("Sending SMB2 Echo to keep connection alive")
1207                    for sid in self.session_table.keys():
1208                        req = self.send(SMB2Echo(), sid=sid)
1209                        # Set this reserved field to 1 as we use that internally to check whether the outstanding
1210                        # requests queue should be cleared in this thread or not.
1211                        req.message['reserved'] = 1
1212                    continue
1213
1214                # If recv didn't return any data then the socket is considered to be closed.
1215                if not b_msg:
1216                    return
1217
1218                is_encrypted = b_msg[:4] == b"\xfdSMB"
1219                if is_encrypted:
1220                    msg = SMB2TransformHeader()
1221                    msg.unpack(b_msg)
1222                    b_msg = self._decrypt(msg)
1223
1224                next_command = -1
1225                while next_command != 0:
1226                    next_command = struct.unpack("<L", b_msg[20:24])[0]
1227                    header_length = next_command if next_command != 0 else len(b_msg)
1228                    b_header = b_msg[:header_length]
1229                    b_msg = b_msg[header_length:]
1230
1231                    header = SMB2HeaderResponse()
1232                    header.unpack(b_header)
1233
1234                    message_id = header['message_id'].get_value()
1235                    request = self.outstanding_requests[message_id]
1236
1237                    # Typically you want to get the Session Id from the first message in a compound request but that is
1238                    # unreliable for async responses. Instead get the Session Id from the original request object if
1239                    # the Session Id is 0xFFFFFFFFFFFFFFFF.
1240                    # https://social.msdn.microsoft.com/Forums/en-US/a580f7bc-6746-4876-83db-6ac209b202c4/mssmb2-change-notify-response-sessionid?forum=os_fileservices
1241                    session_id = header['session_id'].get_value()
1242                    if session_id == 0xFFFFFFFFFFFFFFFF:
1243                        session_id = request.session_id
1244
1245                    # No need to waste CPU cycles to verify the signature if we already decrypted the header.
1246                    if not is_encrypted:
1247                        self.verify_signature(header, session_id)
1248
1249                    credit_response = header['credit_response'].get_value()
1250                    if credit_response == 0 and not self.supports_multi_credit:
1251                        # If the dialect does not support credits we still need to adjust our sequence window.
1252                        # Otherwise the credit response may be 0 in the case of compound responses and the last
1253                        # response contains the credits that were granted.
1254                        credit_response += 1
1255
1256                    with self.sequence_lock:
1257                        self.sequence_window['high'] += credit_response
1258
1259                    command = header['command'].get_value()
1260                    status = header['status'].get_value()
1261                    if command == Commands.SMB2_NEGOTIATE:
1262                        self.preauth_integrity_hash_value.append(b_header)
1263
1264                    elif command == Commands.SMB2_SESSION_SETUP and status == NtStatus.STATUS_MORE_PROCESSING_REQUIRED:
1265                        self.preauth_session_table[message_id] = b_header
1266
1267                    with request.response_event_lock:
1268                        if header['flags'].has_flag(Smb2Flags.SMB2_FLAGS_ASYNC_COMMAND):
1269                            request.async_id = b_header[32:40]
1270
1271                        request.response = header
1272                        request.response_event.set()
1273
1274                        # When we send a ping in this thread we want to make sure it doesn't linger in the outstanding
1275                        # request queue.
1276                        if request.message['reserved'].get_value() == 1:
1277                            self.outstanding_requests.pop(message_id, None)
1278        except Exception as exc:
1279            # The exception is raised in _check_worker_running by the main thread when send/receive is called next.
1280            self._t_exc = exc
1281
1282            # While a caller of send/receive could theoretically catch this exception, we consider any failures
1283            # here a fatal errors and the connection should be closed so we exit the worker thread.
1284            self.disconnect(False)
1285
1286        finally:
1287            # Make sure we fire all the request events to ensure the main thread isn't waiting on a receive.
1288            for request in self.outstanding_requests.values():
1289                request.response_event.set()
1290
1291    def _generate_signature(self, b_header, signing_key, message_id, response, command):
1292        b_header = b_header[:48] + (b"\x00" * 16) + b_header[64:]
1293
1294        if self.dialect >= Dialects.SMB_3_1_1 and self.signing_algorithm_id is not None:
1295            sign_id = self.signing_algorithm_id
1296
1297        elif self.dialect >= Dialects.SMB_3_0_0:
1298            sign_id = SigningAlgorithms.AES_CMAC
1299
1300        else:
1301            sign_id = SigningAlgorithms.HMAC_SHA256
1302
1303        if sign_id == SigningAlgorithms.AES_GMAC:
1304            message_info = 0
1305            if response:
1306                message_info |= 1
1307
1308            if command == Commands.SMB2_CANCEL:
1309                message_info |= 2
1310
1311            nonce = b"".join([
1312                message_id.to_bytes(8, byteorder="little"),
1313                message_info.to_bytes(4, byteorder="little"),
1314            ])
1315            signature = aead.AESGCM(signing_key).encrypt(nonce, b"", b_header)
1316
1317        elif sign_id == SigningAlgorithms.AES_CMAC:
1318            c = cmac.CMAC(algorithms.AES(signing_key), backend=default_backend())
1319            c.update(b_header)
1320            signature = c.finalize()
1321
1322        else:
1323            hmac_algo = hmac.new(signing_key, msg=b_header, digestmod=hashlib.sha256)
1324            signature = hmac_algo.digest()[:16]
1325
1326        return signature
1327
1328    def _encrypt(self, b_data, session):
1329        header = SMB2TransformHeader()
1330        header['original_message_size'] = len(b_data)
1331        header['session_id'] = session.session_id
1332
1333        encryption_key = session.encryption_key
1334        if self.dialect >= Dialects.SMB_3_1_1:
1335            cipher_id = self.cipher_id
1336        else:
1337            cipher_id = Ciphers.AES_128_CCM
1338
1339        if cipher_id in [Ciphers.AES_128_GCM, Ciphers.AES_256_GCM]:
1340            cipher = aead.AESGCM
1341            nonce = os.urandom(12)
1342            header['nonce'] = nonce + (b"\x00" * 4)
1343        else:
1344            cipher = aead.AESCCM
1345            nonce = os.urandom(11)
1346            header['nonce'] = nonce + (b"\x00" * 5)
1347
1348        cipher_text = cipher(encryption_key).encrypt(nonce, b_data, header.pack()[20:])
1349        signature = cipher_text[-16:]
1350        enc_message = cipher_text[:-16]
1351
1352        header['signature'] = signature
1353        header['data'] = enc_message
1354
1355        return header
1356
1357    def _decrypt(self, message):
1358        if message['flags'].get_value() != 0x0001:
1359            error_msg = "Expecting flag of 0x0001 but got %s in the SMB Transform Header Response" \
1360                        % format(message['flags'].get_value(), 'x')
1361            raise SMBException(error_msg)
1362
1363        session_id = message['session_id'].get_value()
1364        session = self.session_table.get(session_id, None)
1365        if session is None:
1366            error_msg = "Failed to find valid session %s for message decryption" % session_id
1367            raise SMBException(error_msg)
1368
1369        if self.dialect >= Dialects.SMB_3_1_1:
1370            cipher_id = self.cipher_id
1371        else:
1372            cipher_id = Ciphers.AES_128_CCM
1373
1374        if cipher_id in [Ciphers.AES_128_GCM, Ciphers.AES_256_GCM]:
1375            cipher = aead.AESGCM
1376            nonce_length = 12
1377        else:
1378            cipher = aead.AESCCM
1379            nonce_length = 11
1380
1381        nonce = message['nonce'].get_value()[:nonce_length]
1382        signature = message['signature'].get_value()
1383        enc_message = message['data'].get_value() + signature
1384
1385        c = cipher(session.decryption_key)
1386        dec_message = c.decrypt(nonce, enc_message, message.pack()[20:52])
1387        return dec_message
1388
1389    def _send_smb2_negotiate(self, dialect, timeout, encryption_algorithms, signing_algorithms):
1390        self.salt = os.urandom(32)
1391
1392        if dialect is None:
1393            neg_req = SMB3NegotiateRequest()
1394            negotiated_dialects = [
1395                Dialects.SMB_2_0_2,
1396                Dialects.SMB_2_1_0,
1397                Dialects.SMB_3_0_0,
1398                Dialects.SMB_3_0_2,
1399                Dialects.SMB_3_1_1
1400            ]
1401
1402            if SigningAlgorithms.HMAC_SHA256 not in signing_algorithms:
1403                if Dialects.SMB_2_0_2 in negotiated_dialects:
1404                    negotiated_dialects.remove(Dialects.SMB_2_0_2)
1405                if Dialects.SMB_2_1_0 in negotiated_dialects:
1406                    negotiated_dialects.remove(Dialects.SMB_2_1_0)
1407
1408            if (
1409                SigningAlgorithms.AES_CMAC not in signing_algorithms or
1410                Ciphers.AES_128_CCM not in encryption_algorithms
1411            ):
1412                if Dialects.SMB_3_0_0 in negotiated_dialects:
1413                    negotiated_dialects.remove(Dialects.SMB_3_0_0)
1414                if Dialects.SMB_3_0_2 in negotiated_dialects:
1415                    negotiated_dialects.remove(Dialects.SMB_3_0_2)
1416        else:
1417            if dialect >= Dialects.SMB_3_1_1:
1418                neg_req = SMB3NegotiateRequest()
1419            else:
1420                neg_req = SMB2NegotiateRequest()
1421            negotiated_dialects = [dialect]
1422
1423        highest_dialect = sorted(negotiated_dialects)[-1]
1424        self.negotiated_dialects = neg_req['dialects'] = negotiated_dialects
1425        log.info("Negotiating with SMB2 protocol with highest client dialect "
1426                 "of: %s" % [dialect for dialect, v in vars(Dialects).items()
1427                             if v == highest_dialect][0])
1428
1429        neg_req['security_mode'] = self.client_security_mode
1430
1431        if highest_dialect >= Dialects.SMB_2_1_0:
1432            log.debug("Adding client guid %s to negotiate request"
1433                      % self.client_guid)
1434            neg_req['client_guid'] = self.client_guid
1435
1436        else:
1437            # Must be None, this value is used to verify the negotiation info.
1438            self.client_guid = None
1439
1440        if highest_dialect >= Dialects.SMB_3_0_0:
1441            log.debug("Adding client capabilities %d to negotiate request"
1442                      % self.client_capabilities)
1443            neg_req['capabilities'] = self.client_capabilities
1444
1445        else:
1446            # Must be 0, this value is used to verify the negotiation info.
1447            self.client_capabilities = 0
1448
1449        if highest_dialect >= Dialects.SMB_3_1_1:
1450            int_cap = SMB2NegotiateContextRequest()
1451            int_cap['context_type'] = \
1452                NegotiateContextType.SMB2_PREAUTH_INTEGRITY_CAPABILITIES
1453            int_cap['data'] = SMB2PreauthIntegrityCapabilities()
1454            int_cap['data']['hash_algorithms'] = [
1455                HashAlgorithms.SHA_512
1456            ]
1457            int_cap['data']['salt'] = self.salt
1458            log.debug("Adding preauth integrity capabilities of hash SHA512 "
1459                      "and salt %s to negotiate request" % self.salt)
1460
1461            enc_cap = SMB2NegotiateContextRequest()
1462            enc_cap['context_type'] = \
1463                NegotiateContextType.SMB2_ENCRYPTION_CAPABILITIES
1464            enc_cap['data'] = SMB2EncryptionCapabilities()
1465            supported_ciphers = encryption_algorithms
1466            enc_cap['data']['ciphers'] = supported_ciphers
1467            log.debug("Adding encryption capabilities of AES128|256 GCM and "
1468                      "AES128|256 CCM to negotiate request")
1469
1470            netname_id = SMB2NegotiateContextRequest()
1471            netname_id['context_type'] = NegotiateContextType.SMB2_NETNAME_NEGOTIATE_CONTEXT_ID
1472            netname_id['data'] = SMB2NetnameNegotiateContextId()
1473            netname_id['data']['net_name'] = self.server_name
1474            log.debug(f"Adding netname context id of {self.server_name} to negotiate request")
1475
1476            signing_cap = SMB2NegotiateContextRequest()
1477            signing_cap['context_type'] = NegotiateContextType.SMB2_SIGNING_CAPABILITIES
1478            signing_cap['data'] = SMB2SigningCapabilities()
1479            signing_cap['data']['signing_algorithms'] = signing_algorithms
1480            log.debug("Adding signing algorithms AES_GMAC, AES_CMAC, and HMAC_SHA256 to negotiate request")
1481
1482            # remove extra padding for last list entry
1483            signing_cap['padding'].size = 0
1484            signing_cap['padding'] = b""
1485
1486            neg_req['negotiate_context_list'] = [
1487                int_cap,
1488                enc_cap,
1489                netname_id,
1490                signing_cap,
1491            ]
1492
1493        log.info("Sending SMB2 Negotiate message")
1494        log.debug(neg_req)
1495        request = self.send(neg_req)
1496
1497        response = self.receive(request, timeout=timeout)
1498        log.info("Receiving SMB2 Negotiate response")
1499        log.debug(response)
1500
1501        smb_response = SMB2NegotiateResponse()
1502        smb_response.unpack(response['data'].get_value())
1503
1504        return smb_response
1505
1506    def _calculate_credit_charge(self, message):
1507        """
1508        Calculates the credit charge for a request based on the command. If
1509        connection.supports_multi_credit is not True then the credit charge
1510        isn't valid so it returns 0.
1511
1512        The credit charge is the number of credits that are required for
1513        sending/receiving data over 64 kilobytes, in the existing messages only
1514        the Read, Write, Query Directory or IOCTL commands will end in this
1515        scenario and each require their own calculation to get the proper
1516        value. The generic formula for calculating the credit charge is
1517
1518        https://msdn.microsoft.com/en-us/library/dn529312.aspx
1519        (max(SendPayloadSize, Expected ResponsePayloadSize) - 1) / 65536 + 1
1520
1521        :param message: The message being sent
1522        :return: The credit charge to set on the header
1523        """
1524        if (not self.supports_multi_credit) or (message.COMMAND == Commands.SMB2_CANCEL):
1525            return 0
1526
1527        elif message.COMMAND == Commands.SMB2_READ:
1528            payload_size = message['length'].get_value() + message['read_channel_info_length'].get_value()
1529
1530        elif message.COMMAND == Commands.SMB2_WRITE:
1531            payload_size = message['length'].get_value() + message['write_channel_info_length'].get_value()
1532
1533        elif message.COMMAND == Commands.SMB2_IOCTL:
1534            max_in_size = len(message['buffer'])
1535            max_out_size = message['max_output_response'].get_value()
1536            payload_size = max(max_in_size, max_out_size)
1537
1538        elif message.COMMAND == Commands.SMB2_QUERY_DIRECTORY:
1539            max_in_size = len(message['buffer'])
1540            max_out_size = message['output_buffer_length'].get_value()
1541            payload_size = max(max_in_size, max_out_size)
1542
1543        else:
1544            payload_size = 1
1545
1546        credit_charge = (max(0, payload_size - 1) // MAX_PAYLOAD_SIZE) + 1
1547        return credit_charge
1548
1549
class Request(object):
    """
    Book-keeping for a single SMB2 request that has been sent to the server.

    Holds the sent message, the response once the receiving thread has stored
    it, and the synchronisation primitives used to hand that response over.
    """

    def __init__(self, message, message_type, connection, session_id=None):
        """
        [MS-SMB2] v53.0 2017-09-15

        3.2.1.7 Per Pending Request
        For each request that was sent to the server and is awaiting a response.

        :param message: The message to be sent in the request.
        :param message_type: The type of message that is set in the header's data field.
        :param connection: The Connection the request was sent under.
        :param session_id: The Session Id the request was for.
        """
        self.async_id = None
        self.message = message
        self.timestamp = datetime.now()
        self.cancelled = False

        # Used to contain the corresponding response from the server as the receiving is done in a separate thread.
        self.response = None

        # Used by the recv processing thread to say the response has been received and is ready for consumption.
        self.response_event = threading.Event()

        # Used to lock the request when the main thread is processing the PENDING result in case the background thread
        # receives the final result and fires the event before main clears it.
        self.response_event_lock = threading.Lock()

        # Stores the message_ids of related messages that are sent in a compound request. This is only set on the 1st
        # message in the request. Used when STATUS_STOPPED_ON_SYMLINK is set and we need to send the whole compound
        # request again with the new path.
        self.related_ids = []

        # Cannot rely on the message values as it could be a related compound msg which does not set these values.
        self.session_id = session_id

        self._connection = connection
        self._message_type = message_type  # Used to rehydrate the message data in case it's needed again.

    def cancel(self):
        """
        Send an SMB2 CANCEL for this request.

        Only one cancel is sent per request; later calls return immediately.
        The cancel reuses this request's message_id, async_id and session_id
        and is sent with credit_request=0. The request is only flagged as
        cancelled after the send call returns, so if sending raises, cancel()
        can be called again.
        """
        if self.cancelled:
            return

        message_id = self.message['message_id'].get_value()
        # Lazy %-args: the message is only formatted if INFO logging is enabled.
        log.info("Cancelling message %s", message_id)
        self._connection.send(SMB2CancelRequest(), sid=self.session_id, credit_request=0, message_id=message_id,
                              async_id=self.async_id)
        self.cancelled = True

    def get_message_data(self):
        """
        Rebuild the typed request structure from the raw bytes stored in the
        header's data field.

        :return: A new instance of the message_type passed to __init__,
            unpacked from self.message['data'].
        """
        message_obj = self._message_type()
        message_obj.unpack(self.message['data'].get_value())
        return message_obj

    def update_request(self, new_request):
        """
        Copy the in-flight state of another Request onto this object.

        Only async_id, message, timestamp, response, response_event,
        response_event_lock and related_ids are copied. session_id, cancelled,
        the connection and the message type set at construction are left
        unchanged.

        :param new_request: The Request whose state should be adopted.
        """
        self.async_id = new_request.async_id
        self.message = new_request.message
        self.timestamp = new_request.timestamp
        self.response = new_request.response
        self.response_event = new_request.response_event
        self.response_event_lock = new_request.response_event_lock
        self.related_ids = new_request.related_ids
1612