1"""
2Utility functions for SMB connections
3
:depends: smbprotocol
5"""
6
7
8import logging
9import socket
10import uuid
11
12import salt.utils.files
13import salt.utils.stringutils
14import salt.utils.versions
15from salt.exceptions import MissingSmb
16
log = logging.getLogger(__name__)


# smbprotocol is an optional dependency. The public wrapper functions below
# check HAS_SMBPROTOCOL before touching any of the names imported here.
try:
    from smbprotocol.connection import Connection
    from smbprotocol.session import Session
    from smbprotocol.tree import TreeConnect
    from smbprotocol.open import (
        Open,
        ImpersonationLevel,
        FilePipePrinterAccessMask,
        FileAttributes,
        CreateDisposition,
        CreateOptions,
        ShareAccess,
        DirectoryAccessMask,
        FileInformationClass,
    )
    from smbprotocol.create_contexts import (
        CreateContextName,
        SMB2CreateContextRequest,
        SMB2CreateQueryMaximalAccessRequest,
    )
    from smbprotocol.security_descriptor import (
        AccessAllowedAce,
        AccessMask,
        AclPacket,
        SDControl,
        SIDPacket,
        SMB2CreateSDBuffer,
    )

    # Only let WARNING and above through from smbprotocol's own logger.
    logging.getLogger("smbprotocol").setLevel(logging.WARNING)
    HAS_SMBPROTOCOL = True
except ImportError:
    HAS_SMBPROTOCOL = False
53
54
class SMBProto:
    """
    Thin wrapper around an smbprotocol ``Connection``/``Session`` pair.

    :param server: host name or address of the SMB server
    :param username: user to authenticate as
    :param password: password for ``username``
    :param port: TCP port of the SMB service (default 445)
    """

    def __init__(self, server, username, password, port=445):
        connection_id = uuid.uuid4()
        # Resolve the host once; the first TCP address returned is used.
        addr = socket.getaddrinfo(server, None, 0, 0, socket.IPPROTO_TCP)[0][4][0]
        self.server = server
        connection = Connection(connection_id, addr, port, require_signing=True)
        # The connection is only kept through the session (see ``connection``).
        self.session = Session(connection, username, password, require_encryption=False)

    def connect(self):
        """Open the transport connection, then authenticate the session."""
        self.connection.connect()
        self.session.connect()

    def close(self):
        """Disconnect the underlying connection (``disconnect(True)``)."""
        self.session.connection.disconnect(True)

    @property
    def connection(self):
        """The smbprotocol ``Connection`` this session was created on."""
        return self.session.connection

    def _share_path(self, share):
        r"""
        Return the full UNC path (``\\server\share``) for ``share``.

        ``share`` may be a bare share name such as ``C$`` or ``data``, or an
        already complete UNC path (starting with ``\\``), which is returned
        unchanged.
        """
        if share.startswith("\\\\"):
            return share
        return r"\\{}\{}".format(self.server, share)

    def tree_connect(self, share):
        """
        Connect to ``share`` on this server and return the ``TreeConnect``.

        Bare share names are expanded to a UNC path for every share, not only
        administrative (``$``) shares.
        """
        tree = TreeConnect(self.session, self._share_path(share))
        tree.connect()
        return tree

    @staticmethod
    def normalize_filename(file):
        """Strip leading backslashes so the path is relative to the share root."""
        return file.lstrip("\\")

    @classmethod
    def open_file(cls, tree, file):
        """
        Open ``file`` on ``tree`` for reading and writing.

        The file is created if it is missing and truncated if it exists
        (``FILE_OVERWRITE_IF``). Leading backslashes are stripped first.

        :return: the opened smbprotocol ``Open``; the caller must close it.
        """
        file = cls.normalize_filename(file)
        file_open = Open(tree, file)
        # NOTE: no create contexts (security descriptor / maximal-access query)
        # are sent, so the server applies its default ACL to a new file.
        file_open.create(
            ImpersonationLevel.Impersonation,
            FilePipePrinterAccessMask.GENERIC_READ
            | FilePipePrinterAccessMask.GENERIC_WRITE,
            FileAttributes.FILE_ATTRIBUTE_NORMAL,
            ShareAccess.FILE_SHARE_READ | ShareAccess.FILE_SHARE_WRITE,
            CreateDisposition.FILE_OVERWRITE_IF,
            CreateOptions.FILE_NON_DIRECTORY_FILE,
        )
        return file_open

    @staticmethod
    def open_directory(tree, name, create=False):
        """
        Return an smbprotocol ``Open`` for directory ``name`` on ``tree``.

        With ``create=True`` the directory is opened for read/write and
        created if it does not exist yet (``FILE_OPEN_IF``); the caller must
        close the handle. Without ``create`` the ``Open`` is returned without
        being sent to the server. Leading backslashes are stripped, as in
        :meth:`open_file`.
        """
        dir_open = Open(tree, SMBProto.normalize_filename(name))
        if create:
            dir_open.create(
                ImpersonationLevel.Impersonation,
                DirectoryAccessMask.GENERIC_READ | DirectoryAccessMask.GENERIC_WRITE,
                FileAttributes.FILE_ATTRIBUTE_DIRECTORY,
                ShareAccess.FILE_SHARE_READ | ShareAccess.FILE_SHARE_WRITE,
                CreateDisposition.FILE_OPEN_IF,
                CreateOptions.FILE_DIRECTORY_FILE,
            )
        return dir_open
137
138
def _get_conn_smbprotocol(host="", username="", password="", client_name="", port=445):
    """
    Create an :class:`SMBProto` for ``host`` and connect/authenticate it.

    ``client_name`` is accepted for backward compatibility and is not used.

    If connecting or authenticating fails, the partially opened connection is
    closed (best effort) before the original exception is re-raised, so the
    TCP transport is not leaked.
    """
    conn = SMBProto(host, username, password, port)
    try:
        conn.connect()
    except Exception:
        try:
            conn.close()
        except Exception:  # pylint: disable=broad-except
            # The connect error is the useful one; don't let cleanup mask it.
            log.debug(
                "Error closing SMB connection to %s after failed connect",
                host,
                exc_info=True,
            )
        raise
    return conn
143
144
def get_conn(host="", username=None, password=None, port=445):
    """
    Get an SMB connection

    Returns a connected smbprotocol-backed connection object for ``host`` on
    ``port``, or ``False`` when smbprotocol is not installed.
    """
    if not HAS_SMBPROTOCOL:
        return False
    log.info("Get connection smbprotocol")
    return _get_conn_smbprotocol(host, username, password, port=port)
154
155
def _mkdirs_smbprotocol(
    path, share="C$", conn=None, host=None, username=None, password=None
):
    """
    Create ``path`` and every missing parent directory on ``share``.

    Components may be separated by ``/`` or ``\\``; empty components (from
    leading, trailing or doubled separators) are ignored.

    If ``conn`` is None, a connection is opened from ``host``/``username``/
    ``password`` and closed again before returning.

    :return: ``False`` if no connection could be obtained, otherwise ``None``.
    """
    close_conn = conn is None
    if conn is None:
        conn = get_conn(host, username, password)

    if conn is False:
        return False

    try:
        tree = conn.tree_connect(share)
        comps = [comp for comp in path.replace("/", "\\").split("\\") if comp]
        for depth in range(1, len(comps) + 1):
            cwd = "\\".join(comps[:depth])
            # create=True uses FILE_OPEN_IF: only creates it if it is missing.
            dir_open = conn.open_directory(tree, cwd, create=True)
            # Send the query and the close of the new handle as one compound.
            compound_messages = [
                dir_open.query_directory(
                    "*", FileInformationClass.FILE_NAMES_INFORMATION, send=False
                ),
                dir_open.close(False, send=False),
            ]
            requests = conn.session.connection.send_compound(
                [msg for msg, _ in compound_messages],
                conn.session.session_id,
                tree.tree_connect_id,
            )
            for (_, receive), request in zip(compound_messages, requests):
                # Processing the response raises if the server reported an error.
                receive(request)
    finally:
        if close_conn:
            conn.close()
185
186
def mkdirs(path, share="C$", conn=None, host=None, username=None, password=None):
    """
    Create ``path``, including any missing parent directories, on ``share``.

    Pass an existing connection as ``conn`` (see :func:`get_conn`), or pass
    ``host``/``username``/``password`` to have one opened for this call.

    :raises MissingSmb: if smbprotocol is not installed
    """
    if HAS_SMBPROTOCOL:
        return _mkdirs_smbprotocol(
            path, share, conn=conn, host=host, username=username, password=password
        )
    raise MissingSmb("SMB library required (smbprotocol)")
193
194
195def _put_str_smbprotocol(
196    content, path, share="C$", conn=None, host=None, username=None, password=None
197):
198    if conn is None:
199        conn = get_conn(host, username, password)
200    if conn is False:
201        return False
202    tree = conn.tree_connect(share)
203    try:
204        file_open = conn.open_file(tree, path)
205        file_open.write(salt.utils.stringutils.to_bytes(content), 0)
206    finally:
207        file_open.close()
208
209
def put_str(
    content, path, share="C$", conn=None, host=None, username=None, password=None
):
    """
    Upload the string ``content`` to ``path`` on ``share`` without first
    writing it to a local file. The remote file is created or overwritten.

    Pass an existing connection as ``conn`` (see :func:`get_conn`), or pass
    ``host``/``username``/``password`` to have one opened for this call.

    :raises MissingSmb: if smbprotocol is not installed
    """
    if HAS_SMBPROTOCOL:
        return _put_str_smbprotocol(
            content,
            path,
            share,
            conn=conn,
            host=host,
            username=username,
            password=password,
        )
    raise MissingSmb("SMB library required (smbprotocol)")
228
229
def _put_file_smbprotocol(
    local_path,
    path,
    share="C$",
    conn=None,
    host=None,
    username=None,
    password=None,
    chunk_size=1024 * 1024,
):
    """
    Copy the local file ``local_path`` to ``path`` on ``share``.

    The file is streamed in ``chunk_size``-byte pieces. If ``conn`` is None,
    a connection is opened from ``host``/``username``/``password`` and closed
    again before returning.

    :return: ``False`` if no connection could be obtained, otherwise ``None``.
    """
    close_conn = conn is None
    if conn is None:
        conn = get_conn(host, username, password)
    if conn is False:
        return False

    try:
        # Open the local file first. If it is missing or unreadable, we fail
        # before the remote file is created/truncated (FILE_OVERWRITE_IF) and
        # no remote handle is left open.
        with salt.utils.files.fopen(local_path, "rb") as fh_:
            tree = conn.tree_connect(share)
            file_open = conn.open_file(tree, path)
            try:
                position = 0
                while True:
                    chunk = fh_.read(chunk_size)
                    if not chunk:
                        break
                    file_open.write(chunk, position)
                    position += len(chunk)
            finally:
                file_open.close(False)
    finally:
        if close_conn:
            conn.close()
258
259
def put_file(
    local_path,
    path,
    share="C$",
    conn=None,
    host=None,
    username=None,
    password=None,
    chunk_size=1024 * 1024,
):
    """
    Upload the local file ``local_path`` to ``path`` on ``share``.

    The remote file is created or overwritten. The data is written in pieces
    of at most ``chunk_size`` bytes (default 1 MiB).

    Example usage:

        import salt.utils.smb
        smb_conn = salt.utils.smb.get_conn('10.0.0.45', 'vagrant', 'vagrant')
        salt.utils.smb.put_file('/root/test.pdf', 'temp\\myfiles\\test1.pdf', conn=smb_conn)

    :raises MissingSmb: if smbprotocol is not installed
    """
    if HAS_SMBPROTOCOL:
        return _put_file_smbprotocol(
            local_path,
            path,
            share,
            conn=conn,
            host=host,
            username=username,
            password=password,
            chunk_size=chunk_size,
        )
    raise MissingSmb("SMB library required (smbprotocol)")
284
285
def _delete_file_smbprotocol(
    path, share="C$", conn=None, host=None, username=None, password=None
):
    """
    Delete the file ``path`` on ``share``.

    The file is opened with ``FILE_DELETE_ON_CLOSE`` and closed in the same
    related compound request. Leading backslashes are stripped from ``path``,
    as :meth:`SMBProto.open_file` does. If ``conn`` is None, a connection is
    opened from ``host``/``username``/``password`` and closed again before
    returning.

    :return: ``False`` if no connection could be obtained, otherwise ``None``.
    """
    close_conn = conn is None
    if conn is None:
        conn = get_conn(host, username, password)
    if conn is False:
        return False
    try:
        tree = conn.tree_connect(share)
        file_open = Open(tree, SMBProto.normalize_filename(path))
        delete_msgs = [
            file_open.create(
                ImpersonationLevel.Impersonation,
                FilePipePrinterAccessMask.GENERIC_READ
                | FilePipePrinterAccessMask.DELETE,
                FileAttributes.FILE_ATTRIBUTE_NORMAL,
                ShareAccess.FILE_SHARE_READ | ShareAccess.FILE_SHARE_WRITE,
                CreateDisposition.FILE_OPEN,
                CreateOptions.FILE_NON_DIRECTORY_FILE
                | CreateOptions.FILE_DELETE_ON_CLOSE,
                send=False,
            ),
            file_open.close(False, send=False),
        ]
        requests = conn.connection.send_compound(
            [msg for msg, _ in delete_msgs],
            conn.session.session_id,
            tree.tree_connect_id,
            related=True,
        )
        for (_, receive), request in zip(delete_msgs, requests):
            # A SMBResponseException will be raised if something went wrong
            receive(request)
    finally:
        if close_conn:
            conn.close()
318
319
def delete_file(path, share="C$", conn=None, host=None, username=None, password=None):
    """
    Delete the file ``path`` on ``share``.

    Pass an existing connection as ``conn`` (see :func:`get_conn`), or pass
    ``host``/``username``/``password`` to have one opened for this call.
    Errors reported by the server propagate as smbprotocol exceptions.

    :raises MissingSmb: if smbprotocol is not installed
    """
    if HAS_SMBPROTOCOL:
        return _delete_file_smbprotocol(
            path, share, conn=conn, host=host, username=username, password=password
        )
    raise MissingSmb("SMB library required (smbprotocol)")
326
327
def _delete_directory_smbprotocol(
    path, share="C$", conn=None, host=None, username=None, password=None
):
    """
    Delete the directory ``path`` on ``share``.

    The directory is opened with ``FILE_DELETE_ON_CLOSE`` and closed in the
    same related compound request. Leading backslashes are stripped from
    ``path``. If ``conn`` is None, a connection is opened from ``host``/
    ``username``/``password`` and closed again before returning.

    :return: ``False`` if no connection could be obtained, otherwise ``None``.
    """
    close_conn = conn is None
    if conn is None:
        conn = get_conn(host, username, password)
    if conn is False:
        return False
    log.debug("_delete_directory_smbprotocol - share: %s, path: %s", share, path)
    try:
        tree = conn.tree_connect(share)
        dir_open = Open(tree, SMBProto.normalize_filename(path))
        delete_msgs = [
            dir_open.create(
                ImpersonationLevel.Impersonation,
                DirectoryAccessMask.DELETE,
                FileAttributes.FILE_ATTRIBUTE_DIRECTORY,
                0,  # no share access: nobody else may open it meanwhile
                CreateDisposition.FILE_OPEN,
                CreateOptions.FILE_DIRECTORY_FILE | CreateOptions.FILE_DELETE_ON_CLOSE,
                send=False,
            ),
            dir_open.close(False, send=False),
        ]
        delete_reqs = conn.connection.send_compound(
            [msg for msg, _ in delete_msgs],
            sid=conn.session.session_id,
            tid=tree.tree_connect_id,
            related=True,
        )
        for (_, receive), request in zip(delete_msgs, delete_reqs):
            # A SMBResponseException will be raised if something went wrong
            receive(request)
    finally:
        if close_conn:
            conn.close()
360
361
def delete_directory(
    path, share="C$", conn=None, host=None, username=None, password=None
):
    """
    Delete the directory ``path`` on ``share``.

    NOTE(review): the server presumably refuses to delete a non-empty
    directory; this helper does not remove contents first.

    Pass an existing connection as ``conn`` (see :func:`get_conn`), or pass
    ``host``/``username``/``password`` to have one opened for this call.

    :raises MissingSmb: if smbprotocol is not installed
    """
    if HAS_SMBPROTOCOL:
        return _delete_directory_smbprotocol(
            path, share, conn=conn, host=host, username=username, password=password
        )
    raise MissingSmb("SMB library required (smbprotocol)")
370