1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4#  Project                     ___| | | |  _ \| |
5#                             / __| | | | |_) | |
6#                            | (__| |_| |  _ <| |___
7#                             \___|\___/|_| \_\_____|
8#
9# Copyright (C) 2017 - 2021, Daniel Stenberg, <daniel@haxx.se>, et al.
10#
11# This software is licensed as described in the file COPYING, which
12# you should have received as part of this distribution. The terms
13# are also available at https://curl.se/docs/copyright.html.
14#
15# You may opt to use, copy, modify, merge, publish, distribute and/or sell
16# copies of the Software, and permit persons to whom the Software is
17# furnished to do so, under the terms of the COPYING file.
18#
19# This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
20# KIND, either express or implied.
21#
22"""Server for testing SMB"""
23
24from __future__ import absolute_import, division, print_function
25# NOTE: the impacket configuration is not unicode_literals compatible!
26
27import argparse
28import logging
29import os
30import sys
31import tempfile
32
33# Import our curl test data helper
34from util import ClosingFileHandler, TestData
35
36if sys.version_info.major >= 3:
37    import configparser
38else:
39    import ConfigParser as configparser
40
41# impacket needs to be installed in the Python environment
42try:
43    import impacket
44except ImportError:
45    sys.stderr.write('Python package impacket needs to be installed!\n')
46    sys.stderr.write('Use pip or your package manager to install it.\n')
47    sys.exit(1)
48from impacket import smb as imp_smb
49from impacket import smbserver as imp_smbserver
50from impacket.nt_errors import (STATUS_ACCESS_DENIED, STATUS_NO_SUCH_FILE,
51                                STATUS_SUCCESS)
52
# Module-wide logger; handlers are attached by setup_logging().
log = logging.getLogger(__name__)

# Share "paths" configured in smbserver().  Requests arriving on these shares
# are intercepted by TestSmbServer.create_and_x and served from generated
# temporary files instead of the real filesystem.
SERVER_MAGIC = "SERVER_MAGIC"
TESTS_MAGIC = "TESTS_MAGIC"

# File name requested on the SERVER share to check that the server is alive,
# and the template of the reply written into that file ({pid} is filled in).
VERIFIED_REQ = "verifiedserver"
VERIFIED_RSP = "WE ROOLZ: {pid}\n"
58
59
def _build_smb_config():
    """Return the impacket ConfigParser describing the test server.

    It holds a [global] section plus two read-only shares whose paths are the
    SERVER_MAGIC and TESTS_MAGIC markers intercepted by TestSmbServer.
    """
    config = configparser.ConfigParser()

    config.add_section("global")
    for key, value in (("server_name", "SERVICE"),
                       ("server_os", "UNIX"),
                       ("server_domain", "WORKGROUP"),
                       ("log_file", ""),
                       ("credentials_file", "")):
        config.set("global", key, value)

    # SERVER lets the harness verify the server is running; TESTS serves
    # files autogenerated from the test input.
    shares = (("SERVER", "server function", SERVER_MAGIC),
              ("TESTS", "tests", TESTS_MAGIC))
    for section, comment, magic_path in shares:
        config.add_section(section)
        config.set(section, "comment", comment)
        config.set(section, "read only", "yes")
        config.set(section, "share type", "0")
        config.set(section, "path", magic_path)

    return config


def smbserver(options):
    """Start up a TCP SMB server that serves forever.

    Writes ``options.pidfile`` first (when given), then builds the share
    configuration and serves on ``(options.host, options.port)`` using the
    test data found under ``<options.srcdir>/data``.

    :param options: parsed command line options (see get_options()).
    :raises ScriptException: if ``--srcdir`` is missing or not a directory.
    :returns: 0, only reached if serve_forever() ever returns.
    """
    if options.pidfile:
        server_pid = os.getpid()
        # see tests/server/util.c function write_pidfile
        if os.name == "nt":
            server_pid += 65536
        with open(options.pidfile, "w") as pid_out:
            pid_out.write(str(server_pid))

    config = _build_smb_config()

    srcdir = options.srcdir
    if not (srcdir and os.path.isdir(srcdir)):
        raise ScriptException("--srcdir is mandatory")

    server = TestSmbServer((options.host, options.port),
                           config_parser=config,
                           test_data_directory=os.path.join(srcdir, "data"))
    log.info("[SMB] setting up SMB server on port %s", options.port)
    server.processConfigFile()
    server.serve_forever()
    return 0
108
109
class TestSmbServer(imp_smbserver.SMBSERVER):
    """
    Test server for SMB which subclasses the impacket SMBSERVER and provides
    test functionality.

    NT_CREATE_ANDX requests on the SERVER_MAGIC and TESTS_MAGIC shares are
    answered with temporary files generated on the fly; requests on any other
    share are passed on to impacket's stock handler.
    """

    def __init__(self,
                 address,
                 config_parser=None,
                 test_data_directory=None):
        imp_smbserver.SMBSERVER.__init__(self,
                                         address,
                                         config_parser=config_parser)

        # Set up a test data object so we can get test data later.
        self.ctd = TestData(test_data_directory)

        # Override smbComNtCreateAndX so we can pretend to have files which
        # don't exist.
        self.hookSmbCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX,
                            self.create_and_x)

    @staticmethod
    def _allocate_fid(opened_files):
        """
        Return the lowest positive FID that is not a key of *opened_files*.

        Searching for a free slot (rather than deriving the FID from the most
        recently inserted key) guarantees the result is unused whatever order
        the existing FIDs were added in, and works on Python 3 where dict
        views cannot be indexed.
        """
        fid = 1
        while fid in opened_files:
            fid += 1
        return fid

    @staticmethod
    def _discard_temp_file(fid, filename):
        """
        Best-effort cleanup of a file created by tempfile.mkstemp(): close
        the descriptor *fid* and delete *filename*, ignoring OS errors.
        """
        # Close first so removal also works on platforms that refuse to
        # delete files that are still open.
        try:
            os.close(fid)
        except OSError:
            pass
        try:
            os.remove(filename)
        except OSError:
            pass

    def create_and_x(self, conn_id, smb_server, smb_command, recv_packet):
        """
        Our version of smbComNtCreateAndX looks for special test files and
        fools the rest of the framework into opening them as if they were
        normal files.

        Requests whose share path is neither SERVER_MAGIC nor TESTS_MAGIC are
        delegated to impacket's smbComNtCreateAndX.  Otherwise a temporary
        file is generated, registered in the connection's "OpenedFiles" table
        under a free FID, and its metadata is returned to the client.

        :returns: ``([response_command], None, nt_status)``; on an
            SmbException the response has empty parameters/data and the
            exception's error code as status.
        """
        conn_data = smb_server.getConnectionData(conn_id)

        # Wrap processing in a try block which allows us to throw SmbException
        # to control the flow.
        try:
            ncax_parms = imp_smb.SMBNtCreateAndX_Parameters(
                smb_command["Parameters"])

            path = self.get_share_path(conn_data,
                                       ncax_parms["RootFid"],
                                       recv_packet["Tid"])
            log.info("[SMB] Requested share path: %s", path)

            disposition = ncax_parms["Disposition"]
            log.debug("[SMB] Requested disposition: %s", disposition)

            # Currently we only support reading files.
            if disposition != imp_smb.FILE_OPEN:
                raise SmbException(STATUS_ACCESS_DENIED,
                                   "Only support reading files")

            # Check to see if the path we were given is actually a
            # magic path which needs generating on the fly.
            if path not in [SERVER_MAGIC, TESTS_MAGIC]:
                # Pass the command onto the original handler.
                return imp_smbserver.SMBCommands.smbComNtCreateAndX(
                    conn_id, smb_server, smb_command, recv_packet)

            flags2 = recv_packet["Flags2"]
            ncax_data = imp_smb.SMBNtCreateAndX_Data(flags=flags2,
                                                     data=smb_command["Data"])
            requested_file = imp_smbserver.decodeSMBString(
                flags2, ncax_data["FileName"])
            log.debug("[SMB] User requested file '%s'", requested_file)

            if path == SERVER_MAGIC:
                fid, full_path = self.get_server_path(requested_file)
            else:
                assert (path == TESTS_MAGIC)
                fid, full_path = self.get_test_path(requested_file)

            resp_parms = imp_smb.SMBNtCreateAndXResponse_Parameters()
            resp_data = ""

            fakefid = self._allocate_fid(conn_data["OpenedFiles"])
            resp_parms["Fid"] = fakefid
            resp_parms["CreateAction"] = disposition

            if os.path.isdir(path):
                resp_parms[
                    "FileAttributes"] = imp_smb.SMB_FILE_ATTRIBUTE_DIRECTORY
                resp_parms["IsDirectory"] = 1
            else:
                resp_parms["IsDirectory"] = 0
                resp_parms["FileAttributes"] = ncax_parms["FileAttributes"]

            # Get this file's information
            resp_info, error_code = imp_smbserver.queryPathInformation(
                "", full_path, level=imp_smb.SMB_QUERY_FILE_ALL_INFO)

            if error_code != STATUS_SUCCESS:
                # The temp file will never be registered as open, so nothing
                # else would ever close or delete it.
                self._discard_temp_file(fid, full_path)
                raise SmbException(error_code, "Failed to query path info")

            resp_parms["CreateTime"] = resp_info["CreationTime"]
            resp_parms["LastAccessTime"] = resp_info["LastAccessTime"]
            resp_parms["LastWriteTime"] = resp_info["LastWriteTime"]
            resp_parms["LastChangeTime"] = resp_info["LastChangeTime"]
            resp_parms["FileAttributes"] = resp_info["ExtFileAttributes"]
            resp_parms["AllocationSize"] = resp_info["AllocationSize"]
            resp_parms["EndOfFile"] = resp_info["EndOfFile"]

            # Let's store the fid for the connection.  FileName records the
            # magic share path; FileHandle is the temp file's descriptor.
            conn_data["OpenedFiles"][fakefid] = {
                "FileHandle": fid,
                "FileName": path,
                "DeleteOnClose": False,
            }

        except SmbException as s:
            log.debug("[SMB] SmbException hit: %s", s)
            error_code = s.error_code
            resp_parms = ""
            resp_data = ""

        resp_cmd = imp_smb.SMBCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX)
        resp_cmd["Parameters"] = resp_parms
        resp_cmd["Data"] = resp_data
        smb_server.setConnectionData(conn_id, conn_data)

        return [resp_cmd], None, error_code

    def get_share_path(self, conn_data, root_fid, tid):
        """
        Return the path a create request refers to.

        If *root_fid* is positive the path is the FileName recorded for that
        already-open FID; otherwise it is the "path" of the share connected
        under *tid*.

        :raises SmbException: STATUS_SMB_BAD_TID if *tid* is not a connected
            share, STATUS_ACCESS_DENIED if the share has no path.
        """
        conn_shares = conn_data["ConnectedShares"]

        if tid not in conn_shares:
            raise SmbException(imp_smbserver.STATUS_SMB_BAD_TID,
                               "TID was invalid")

        if root_fid > 0:
            # If we have a rootFid, the path is relative to that fid
            path = conn_data["OpenedFiles"][root_fid]["FileName"]
            log.debug("RootFid present %s!", path)
        elif "path" in conn_shares[tid]:
            path = conn_shares[tid]["path"]
        else:
            raise SmbException(STATUS_ACCESS_DENIED,
                               "Connection share had no path")

        return path

    def get_server_path(self, requested_filename):
        """
        Generate a file for a request on the SERVER_MAGIC share.

        Only VERIFIED_REQ is known; its contents are VERIFIED_RSP formatted
        with this process's PID (offset by 65536 on Windows, matching
        tests/server/util.c write_pidfile).

        :returns: ``(fd, filename)`` of a temp file rewound to offset 0.
        :raises SmbException: STATUS_NO_SUCH_FILE for any other name.
        """
        log.debug("[SMB] Get server path '%s'", requested_filename)

        if requested_filename not in [VERIFIED_REQ]:
            raise SmbException(STATUS_NO_SUCH_FILE, "Couldn't find the file")

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing '%s'",
                  filename, fid, requested_filename)

        # bytes, because os.write() requires bytes on Python 3.
        contents = b""

        if requested_filename == VERIFIED_REQ:
            log.debug("[SMB] Verifying server is alive")
            pid = os.getpid()
            # see tests/server/util.c function write_pidfile
            if os.name == "nt":
                pid += 65536
            contents = VERIFIED_RSP.format(pid=pid).encode('utf-8')

        self.write_to_fid(fid, contents)
        return fid, filename

    def write_to_fid(self, fid, contents):
        """
        Write the bytes *contents* to descriptor *fid*, fsync, and rewind to
        the start so subsequent reads return the contents.
        """
        os.write(fid, contents)
        os.fsync(fid)

        # Rewind the file to the beginning so a read gets us the contents
        os.lseek(fid, 0, os.SEEK_SET)

    def get_test_path(self, requested_filename):
        """
        Generate a file for a request on the TESTS_MAGIC share.

        The contents come from ``self.ctd.get_test_data(requested_filename)``
        encoded as UTF-8 (presumably the test's reply data, per the log
        message; see util.TestData).

        :returns: ``(fd, filename)`` of a temp file rewound to offset 0.
        :raises SmbException: STATUS_NO_SUCH_FILE if the data cannot be
            produced; the temp file is cleaned up in that case.
        """
        log.info("[SMB] Get reply data from 'test%s'", requested_filename)

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing test '%s'",
                  filename, fid, requested_filename)

        try:
            contents = self.ctd.get_test_data(requested_filename).encode('utf-8')
            self.write_to_fid(fid, contents)
            return fid, filename

        except Exception:
            log.exception("Failed to make test file")
            # The caller never receives the descriptor, so release it here.
            self._discard_temp_file(fid, filename)
            raise SmbException(STATUS_NO_SUCH_FILE, "Failed to make test file")
307
308
class SmbException(Exception):
    """
    Abort processing of an SMB request with a specific NT status.

    :param error_code: NT status code to return to the client; kept on the
        instance as ``error_code``.
    :param error_message: human-readable description, used as the exception
        message.
    """

    def __init__(self, error_code, error_message):
        self.error_code = error_code
        super(SmbException, self).__init__(error_message)
313
314
class ScriptRC(object):
    """Process exit codes returned by this script."""

    # smbserver() completed normally.
    SUCCESS = 0
    # Generic failure.
    FAILURE = 1
    # smbserver() raised an unhandled exception.
    EXCEPTION = 2
320
321
class ScriptException(Exception):
    """Raised for invalid script configuration, e.g. a missing --srcdir."""
324
325
def get_options(argv=None):
    """
    Parse the command line options.

    :param argv: list of arguments to parse; ``None`` (the default) parses
        ``sys.argv[1:]``, exactly as argparse does.
    :returns: an ``argparse.Namespace`` with the attributes port, host,
        verbose, pidfile, logfile, srcdir, id and ipv4.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--port", action="store", default=9017,
                        type=int, help="port to listen on")
    parser.add_argument("--host", action="store", default="127.0.0.1",
                        help="host to listen on")
    parser.add_argument("--verbose", action="store", type=int, default=0,
                        help="verbose output")
    parser.add_argument("--pidfile", action="store",
                        help="file name for the PID")
    parser.add_argument("--logfile", action="store",
                        help="file name for the log")
    parser.add_argument("--srcdir", action="store", help="test directory")
    parser.add_argument("--id", action="store", help="server ID")
    # store_true produces a bool, so the default is a bool as well.
    parser.add_argument("--ipv4", action="store_true", default=False,
                        help="IPv4 flag")

    return parser.parse_args(argv)
345
346
def setup_logging(options):
    """
    Configure the root logger from the command line options.

    * ``--logfile`` given: attach a DEBUG-level ClosingFileHandler for it.
    * Root logger level: DEBUG with ``--verbose``, INFO otherwise.
    * Also attach a DEBUG-level stdout handler when no logfile was given or
      when ``--verbose`` is set.
    """
    root = logging.getLogger()
    fmt = logging.Formatter("%(asctime)s %(levelname)-5.5s %(message)s")

    if options.logfile:
        file_handler = ClosingFileHandler(options.logfile)
        file_handler.setFormatter(fmt)
        file_handler.setLevel(logging.DEBUG)
        root.addHandler(file_handler)

    root.setLevel(logging.DEBUG if options.verbose else logging.INFO)

    # stdout is the fallback destination without a logfile, and verbose
    # mode mirrors everything to stdout as well.
    if options.verbose or not options.logfile:
        console = logging.StreamHandler(sys.stdout)
        console.setFormatter(fmt)
        console.setLevel(logging.DEBUG)
        root.addHandler(console)
378
379
if __name__ == '__main__':
    # Get the options from the user.
    options = get_options()

    # Setup logging using the user options
    setup_logging(options)

    # Run main script.  The pidfile is removed in ``finally`` so that it is
    # also cleaned up when something that is not an Exception subclass
    # (e.g. KeyboardInterrupt) escapes smbserver().
    try:
        rc = smbserver(options)
    except Exception as e:
        log.exception(e)
        rc = ScriptRC.EXCEPTION
    finally:
        if options.pidfile and os.path.isfile(options.pidfile):
            os.unlink(options.pidfile)

    log.info("[SMB] Returning %d", rc)
    sys.exit(rc)
399