1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4#  Project                     ___| | | |  _ \| |
5#                             / __| | | | |_) | |
6#                            | (__| |_| |  _ <| |___
7#                             \___|\___/|_| \_\_____|
8#
9# Copyright (C) 2017 - 2021, Daniel Stenberg, <daniel@haxx.se>, et al.
10#
11# This software is licensed as described in the file COPYING, which
12# you should have received as part of this distribution. The terms
13# are also available at https://curl.se/docs/copyright.html.
14#
15# You may opt to use, copy, modify, merge, publish, distribute and/or sell
16# copies of the Software, and permit persons to whom the Software is
17# furnished to do so, under the terms of the COPYING file.
18#
19# This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
20# KIND, either express or implied.
21#
22"""Server for testing SMB"""
23
24from __future__ import (absolute_import, division, print_function,
25                        unicode_literals)
26
27import argparse
28import logging
29import os
30import sys
31import tempfile
32
33# Import our curl test data helper
34from util import ClosingFileHandler, TestData
35
36if sys.version_info.major >= 3:
37    import configparser
38else:
39    import ConfigParser as configparser
40
41# impacket needs to be installed in the Python environment
42try:
43    import impacket
44except ImportError:
45    sys.stderr.write('Python package impacket needs to be installed!\n')
46    sys.stderr.write('Use pip or your package manager to install it.\n')
47    sys.exit(1)
48from impacket import smb as imp_smb
49from impacket import smbserver as imp_smbserver
50from impacket.nt_errors import (STATUS_ACCESS_DENIED, STATUS_NO_SUCH_FILE,
51                                STATUS_SUCCESS)
52
log = logging.getLogger(__name__)

# Share paths used purely as markers for the two generated-on-the-fly shares.
SERVER_MAGIC, TESTS_MAGIC = "SERVER_MAGIC", "TESTS_MAGIC"

# Liveness check: the file name a client requests on the SERVER share, and
# the template of the reply written into it.
VERIFIED_REQ, VERIFIED_RSP = "verifiedserver", "WE ROOLZ: {pid}\n"
58
59
def smbserver(options):
    """Start up a TCP SMB server that serves forever.

    :param options: parsed command line options; ``srcdir``, ``pidfile``,
        ``host`` and ``port`` are read.
    :return: 0, only reached if ``serve_forever`` ever returns.
    :raises ScriptException: if ``--srcdir`` is missing or not a directory.
    """
    # Validate before any side effects, so a bad invocation fails before a
    # pidfile announcing a running server is written.
    if not options.srcdir or not os.path.isdir(options.srcdir):
        raise ScriptException("--srcdir is mandatory")

    if options.pidfile:
        pid = os.getpid()
        # see tests/server/util.c function write_pidfile
        if os.name == "nt":
            pid += 65536
        with open(options.pidfile, "w") as f:
            f.write(str(pid))

    # Here we write a mini config for the server
    smb_config = configparser.ConfigParser()
    smb_config.add_section("global")
    smb_config.set("global", "server_name", "SERVICE")
    smb_config.set("global", "server_os", "UNIX")
    smb_config.set("global", "server_domain", "WORKGROUP")
    smb_config.set("global", "log_file", "")
    smb_config.set("global", "credentials_file", "")

    # We need a share which allows us to test that the server is running
    smb_config.add_section("SERVER")
    smb_config.set("SERVER", "comment", "server function")
    smb_config.set("SERVER", "read only", "yes")
    smb_config.set("SERVER", "share type", "0")
    smb_config.set("SERVER", "path", SERVER_MAGIC)

    # Have a share for tests.  These files will be autogenerated from the
    # test input.
    smb_config.add_section("TESTS")
    smb_config.set("TESTS", "comment", "tests")
    smb_config.set("TESTS", "read only", "yes")
    smb_config.set("TESTS", "share type", "0")
    smb_config.set("TESTS", "path", TESTS_MAGIC)

    test_data_dir = os.path.join(options.srcdir, "data")

    smb_server = TestSmbServer((options.host, options.port),
                               config_parser=smb_config,
                               test_data_directory=test_data_dir)
    log.info("[SMB] setting up SMB server on port %s", options.port)
    smb_server.processConfigFile()
    smb_server.serve_forever()
    return 0
108
109
class TestSmbServer(imp_smbserver.SMBSERVER):
    """
    Test server for SMB which subclasses the impacket SMBSERVER and provides
    test functionality.

    NT_CREATE_ANDX requests against the ``SERVER_MAGIC`` and ``TESTS_MAGIC``
    share paths are answered with temporary files generated on the fly;
    requests against any other share go to impacket's stock handler.
    """

    def __init__(self,
                 address,
                 config_parser=None,
                 test_data_directory=None):
        imp_smbserver.SMBSERVER.__init__(self,
                                         address,
                                         config_parser=config_parser)

        # Set up a test data object so we can get test data later.
        self.ctd = TestData(test_data_directory)

        # Override smbComNtCreateAndX so we can pretend to have files which
        # don't exist.
        self.hookSmbCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX,
                            self.create_and_x)

    def create_and_x(self, conn_id, smb_server, smb_command, recv_packet):
        """
        Our version of smbComNtCreateAndX looks for special test files and
        fools the rest of the framework into opening them as if they were
        normal files.

        Only plain opens (``FILE_OPEN`` disposition) are accepted. Requests
        for shares other than the magic ones are handed to impacket's
        original handler unchanged.

        :return: ``([response_command], None, error_code)``, the same shape
            that is returned when delegating to impacket's own handler.
        """
        conn_data = smb_server.getConnectionData(conn_id)

        # Wrap processing in a try block which allows us to throw SmbException
        # to control the flow.
        try:
            ncax_parms = imp_smb.SMBNtCreateAndX_Parameters(
                smb_command["Parameters"])

            path = self.get_share_path(conn_data,
                                       ncax_parms["RootFid"],
                                       recv_packet["Tid"])
            log.info("[SMB] Requested share path: %s", path)

            disposition = ncax_parms["Disposition"]
            log.debug("[SMB] Requested disposition: %s", disposition)

            # Currently we only support reading files.
            if disposition != imp_smb.FILE_OPEN:
                raise SmbException(STATUS_ACCESS_DENIED,
                                   "Only support reading files")

            # Check to see if the path we were given is actually a
            # magic path which needs generating on the fly.
            if path not in [SERVER_MAGIC, TESTS_MAGIC]:
                # Pass the command onto the original handler.
                return imp_smbserver.SMBCommands.smbComNtCreateAndX(conn_id,
                                                                    smb_server,
                                                                    smb_command,
                                                                    recv_packet)

            flags2 = recv_packet["Flags2"]
            ncax_data = imp_smb.SMBNtCreateAndX_Data(flags=flags2,
                                                     data=smb_command[
                                                         "Data"])
            requested_file = imp_smbserver.decodeSMBString(
                flags2,
                ncax_data["FileName"])
            log.debug("[SMB] User requested file '%s'", requested_file)

            if path == SERVER_MAGIC:
                fid, full_path = self.get_server_path(requested_file)
            else:
                assert path == TESTS_MAGIC
                fid, full_path = self.get_test_path(requested_file)

            resp_parms = imp_smb.SMBNtCreateAndXResponse_Parameters()
            resp_data = ""

            fakefid = self._next_fid(conn_data["OpenedFiles"])
            resp_parms["Fid"] = fakefid
            resp_parms["CreateAction"] = disposition

            # The share path is only a magic marker, not a real location, so
            # describe the temporary file that actually backs this request.
            if os.path.isdir(full_path):
                resp_parms[
                    "FileAttributes"] = imp_smb.SMB_FILE_ATTRIBUTE_DIRECTORY
                resp_parms["IsDirectory"] = 1
            else:
                resp_parms["IsDirectory"] = 0
                resp_parms["FileAttributes"] = ncax_parms["FileAttributes"]

            # Get this file's information
            resp_info, error_code = imp_smbserver.queryPathInformation(
                os.path.dirname(full_path), os.path.basename(full_path),
                level=imp_smb.SMB_QUERY_FILE_ALL_INFO)

            if error_code != STATUS_SUCCESS:
                # The fid is never registered or handed to the client, so
                # nothing else would ever close or delete the temporary file.
                self._discard_temp_file(fid, full_path)
                raise SmbException(error_code, "Failed to query path info")

            resp_parms["CreateTime"] = resp_info["CreationTime"]
            resp_parms["LastAccessTime"] = resp_info["LastAccessTime"]
            resp_parms["LastWriteTime"] = resp_info["LastWriteTime"]
            resp_parms["LastChangeTime"] = resp_info["LastChangeTime"]
            resp_parms["FileAttributes"] = resp_info["ExtFileAttributes"]
            resp_parms["AllocationSize"] = resp_info["AllocationSize"]
            resp_parms["EndOfFile"] = resp_info["EndOfFile"]

            # Let's store the fid for the connection.
            # NOTE(review): FileName records the magic share path rather than
            # full_path, so the temporary file is not removed on close.
            # TODO confirm whether that is intended.
            conn_data["OpenedFiles"][fakefid] = {
                "FileHandle": fid,
                "FileName": path,
                "DeleteOnClose": False,
            }

        except SmbException as s:
            log.debug("[SMB] SmbException hit: %s", s)
            error_code = s.error_code
            resp_parms = ""
            resp_data = ""

        resp_cmd = imp_smb.SMBCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX)
        resp_cmd["Parameters"] = resp_parms
        resp_cmd["Data"] = resp_data
        smb_server.setConnectionData(conn_id, conn_data)

        return [resp_cmd], None, error_code

    @staticmethod
    def _next_fid(opened_files):
        """Return a fid not present in ``opened_files`` (a dict keyed by fid)."""
        # dict.keys() is not indexable on Python 3, and its order says nothing
        # about which fid is largest, so take the maximum explicitly.
        if not opened_files:
            return 1
        fid = max(opened_files) + 1
        if fid > 0xFFFF:
            # SMB1 fids are 16-bit; fall back to the lowest free value.
            fid = next(f for f in range(1, 0x10000) if f not in opened_files)
        return fid

    @staticmethod
    def _discard_temp_file(fid, filename):
        """Best-effort close and delete of a temporary file."""
        # Close first: Windows refuses to delete a file that is still open.
        try:
            os.close(fid)
        except OSError:
            log.debug("[SMB] Could not close temporary fd %d", fid)
        try:
            os.remove(filename)
        except OSError:
            log.debug("[SMB] Could not remove temporary file %s", filename)

    def get_share_path(self, conn_data, root_fid, tid):
        """
        Return the path a request refers to.

        With a positive ``root_fid`` this is the ``FileName`` recorded for
        that already-open fid; otherwise it is the ``path`` of the share
        connected as ``tid``.

        :raises SmbException: if ``tid`` is not a connected share, or the
            share has no path.
        """
        conn_shares = conn_data["ConnectedShares"]

        if tid not in conn_shares:
            raise SmbException(imp_smbserver.STATUS_SMB_BAD_TID,
                               "TID was invalid")

        if root_fid > 0:
            # If we have a rootFid, the path is relative to that fid
            path = conn_data["OpenedFiles"][root_fid]["FileName"]
            log.debug("RootFid present %s!", path)
            return path

        if "path" not in conn_shares[tid]:
            raise SmbException(STATUS_ACCESS_DENIED,
                               "Connection share had no path")

        return conn_shares[tid]["path"]

    def get_server_path(self, requested_filename):
        """
        Create the temporary file served from the SERVER share.

        Only ``VERIFIED_REQ`` is known; its content is ``VERIFIED_RSP`` filled
        in with this process's pid (offset by 65536 on Windows).

        :return: ``(fd, filename)`` of the temporary file, rewound to offset 0.
        :raises SmbException: STATUS_NO_SUCH_FILE for any other name.
        """
        log.debug("[SMB] Get server path '%s'", requested_filename)

        if requested_filename != VERIFIED_REQ:
            raise SmbException(STATUS_NO_SUCH_FILE, "Couldn't find the file")

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing '%s'",
                  filename, fid, requested_filename)

        log.debug("[SMB] Verifying server is alive")
        pid = os.getpid()
        # see tests/server/util.c function write_pidfile
        if os.name == "nt":
            pid += 65536
        contents = VERIFIED_RSP.format(pid=pid).encode('utf-8')

        self.write_to_fid(fid, contents)
        return fid, filename

    def write_to_fid(self, fid, contents):
        """Write ``contents`` (bytes) to ``fid``, fsync, and rewind to 0."""
        # Write the contents to file descriptor
        os.write(fid, contents)
        os.fsync(fid)

        # Rewind the file to the beginning so a read gets us the contents
        os.lseek(fid, 0, os.SEEK_SET)

    def get_test_path(self, requested_filename):
        """
        Create a temporary file holding the test data for ``requested_filename``.

        :return: ``(fd, filename)`` of the temporary file, rewound to offset 0.
        :raises SmbException: STATUS_NO_SUCH_FILE if the data cannot be
            produced; the temporary file is cleaned up in that case.
        """
        log.info("[SMB] Get reply data from 'test%s'", requested_filename)

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing test '%s'",
                  filename, fid, requested_filename)

        try:
            contents = self.ctd.get_test_data(requested_filename).encode('utf-8')
            self.write_to_fid(fid, contents)
        except Exception:
            log.exception("Failed to make test file")
            # Nobody will ever read this file, so release it rather than
            # leaking both the descriptor and the file on disk.
            self._discard_temp_file(fid, filename)
            raise SmbException(STATUS_NO_SUCH_FILE, "Failed to make test file")

        return fid, filename
308
309
class SmbException(Exception):
    """Abort an SMB request with a specific NT status code.

    ``error_code`` holds the status to report back to the client; the message
    is kept as the exception text for diagnostics.
    """

    def __init__(self, error_code, error_message):
        Exception.__init__(self, error_message)
        self.error_code = error_code
314
315
class ScriptRC(object):
    """Process exit codes used by this script."""
    SUCCESS, FAILURE, EXCEPTION = range(3)
321
322
class ScriptException(Exception):
    """Fatal script-level error, e.g. a missing mandatory option."""
325
326
def get_options(args=None):
    """Parse the command line options.

    :param args: list of argument strings to parse; ``None`` (the default)
        means ``sys.argv[1:]``, exactly as before.
    :return: the parsed :class:`argparse.Namespace`.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--port", action="store", default=9017,
                        type=int, help="port to listen on")
    parser.add_argument("--host", action="store", default="127.0.0.1",
                        help="host to listen on")
    parser.add_argument("--verbose", action="store", type=int, default=0,
                        help="verbose output")
    parser.add_argument("--pidfile", action="store",
                        help="file name for the PID")
    parser.add_argument("--logfile", action="store",
                        help="file name for the log")
    parser.add_argument("--srcdir", action="store", help="test directory")
    parser.add_argument("--id", action="store", help="server ID")
    # store_true yields a bool, so the default should be a bool too.
    parser.add_argument("--ipv4", action="store_true", default=False,
                        help="IPv4 flag")

    return parser.parse_args(args)
346
347
def setup_logging(options):
    """
    Configure the root logger from the command line options.

    A file handler is attached when ``--logfile`` is given. A stdout handler
    is attached when there is no log file, or in verbose mode. The root level
    is DEBUG in verbose mode and INFO otherwise; every handler accepts DEBUG.
    """
    root_logger = logging.getLogger()
    formatter = logging.Formatter("%(asctime)s %(levelname)-5.5s %(message)s")

    def _attach(handler):
        handler.setFormatter(formatter)
        handler.setLevel(logging.DEBUG)
        root_logger.addHandler(handler)

    if options.logfile:
        _attach(ClosingFileHandler(options.logfile))

    root_logger.setLevel(logging.DEBUG if options.verbose else logging.INFO)

    # Without a log file stdout is the only destination; verbose mode
    # mirrors output there as well.
    if not options.logfile or options.verbose:
        _attach(logging.StreamHandler(sys.stdout))
379
380
if __name__ == '__main__':
    # Get the options from the user.
    options = get_options()

    # Setup logging using the user options
    setup_logging(options)

    # Run main script. The pidfile is removed in ``finally`` so it also goes
    # away when the server is stopped with Ctrl-C: KeyboardInterrupt is not an
    # Exception subclass and would otherwise skip the cleanup.
    try:
        rc = smbserver(options)
    except Exception as e:
        log.exception(e)
        rc = ScriptRC.EXCEPTION
    finally:
        if options.pidfile and os.path.isfile(options.pidfile):
            os.unlink(options.pidfile)

    log.info("[SMB] Returning %d", rc)
    sys.exit(rc)
400