1# Copyright (C) 2007 Giampaolo Rodola' <g.rodola@gmail.com>.
2# Use of this source code is governed by MIT license that can be
3# found in the LICENSE file.
4
5from __future__ import print_function
6import contextlib
7import errno
8import functools
9import logging
10import multiprocessing
11import os
12import shutil
13import socket
14import sys
15import tempfile
16import threading
17import time
18try:
19    from unittest import mock  # py3
20except ImportError:
21    import mock  # NOQA - requires "pip install mock"
22
23from pyftpdlib._compat import getcwdu
24from pyftpdlib._compat import u
25from pyftpdlib.authorizers import DummyAuthorizer
26from pyftpdlib.handlers import _import_sendfile
27from pyftpdlib.handlers import FTPHandler
28from pyftpdlib.ioloop import IOLoop
29from pyftpdlib.servers import FTPServer
30
31import psutil
32
33if sys.version_info < (2, 7):
34    import unittest2 as unittest  # pip install unittest2
35else:
36    import unittest
37
# Older unittest releases only provide the "Regexp" spelling of this
# assertion; alias the modern name so tests can use it unconditionally.
if not hasattr(unittest.TestCase, "assertRaisesRegex"):
    unittest.TestCase.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp

# sendfile() implementation found by pyftpdlib, or None when there is none.
sendfile = _import_sendfile()


# Attempt to use the IP rather than the hostname (the test suite runs a
# lot faster that way); fall back to the plain hostname if it can't be
# resolved.
try:
    HOST = socket.gethostbyname('localhost')
except socket.error:
    HOST = 'localhost'

# Credentials and home directory for the test FTP servers.
USER, PASSWD = 'user', '12345'
HOME = getcwdu()

# Names of scratch files created by the tests.
TESTFN = 'tmp-pyftpdlib'
# NOTE(review): on Python 3 '\xe2\x98\x83' is three Latin-1 code points
# (the UTF-8 encoding of U+2603 SNOWMAN), not the snowman character itself;
# it is still a non-ASCII name.
TESTFN_UNICODE = TESTFN + '-unicode-' + '\xe2\x98\x83'
TESTFN_UNICODE_2 = TESTFN_UNICODE + '-2'

# Tuning knobs shared by the tests.
TIMEOUT = 2
BUFSIZE = 1024
INTERRUPTED_TRANSF_SIZE = 32768
NO_RETRIES = 5

# Platform / environment detection.
POSIX = os.name == 'posix'
WINDOWS = os.name == 'nt'
OSX = sys.platform.startswith("darwin")
TRAVIS = bool(os.environ.get('TRAVIS'))
# Less verbose test output when the SILENT environment variable is set.
VERBOSITY = 2 if not os.getenv('SILENT') else 1
64
65
class TestCase(unittest.TestCase):
    """unittest.TestCase whose str() is the test's full dotted path."""

    def __str__(self):
        # Rendered as "module.ClassName.test_method".
        cls = type(self)
        return ".".join((cls.__module__, cls.__name__, self._testMethodName))


# Hack: replace unittest.TestCase with the subclass above so that test
# classes subclassing unittest.TestCase after this module is imported
# print a full path representation of each test being run.
unittest.TestCase = TestCase
77
78
def close_client(session):
    """Close an ftplib.FTP client *session*.

    If the control socket is still open, QUIT is sent first. When QUIT
    succeeds, its reply must start with '221'. The session is always
    closed at the end, even if QUIT fails or that check trips.
    """
    try:
        if session.sock is None:
            return
        try:
            reply = session.quit()
        except Exception:
            # Best effort: ignore any failure while sending QUIT.
            return
        # ...just to make sure the server isn't still replying to some
        # pending command.
        assert reply.startswith('221'), reply
    finally:
        session.close()
93
94
def try_address(host, port=0, family=socket.AF_INET):
    """Return True if a *family* socket can be bound to (host, port).

    The probe socket is always closed again. Any socket error during
    creation, binding or closing (including address-resolution errors)
    makes the function return False.
    """
    try:
        probe = socket.socket(family)
        try:
            probe.bind((host, port))
        finally:
            probe.close()
    except (socket.error, socket.gaierror):
        return False
    return True
105
106
# Feature probes, evaluated once at import time.
SUPPORTS_IPV4 = try_address('127.0.0.1')
# IPv6 needs both Python-level support and a bindable ::1 loopback.
SUPPORTS_IPV6 = (try_address('::1', family=socket.AF_INET6)
                 if socket.has_ipv6 else socket.has_ipv6)
# sendfile support comes either from the module-level `sendfile`
# implementation or from os.sendfile.
SUPPORTS_SENDFILE = sendfile is not None or hasattr(os, 'sendfile')
110
111
def safe_remove(*files):
    """Convenience function for removing temporary test files.

    Each path in *files* is removed in turn. A missing file (ENOENT) is
    ignored, and any other OSError is re-raised. On Windows every OSError
    is ignored; presumably files there may still be held open, but
    confirm. A failure on one path never stops the remaining paths from
    being processed.
    """
    for path in files:
        try:
            os.remove(path)
        except OSError as err:
            if os.name == 'nt':
                # Skip just this path and keep going with the rest
                # (a `return` here would leave the other files behind).
                continue
            if err.errno != errno.ENOENT:
                raise
122
123
def safe_rmdir(dir):
    """Convenience function for removing a temporary test directory.

    A missing directory (ENOENT) is ignored and any other OSError is
    re-raised, except on Windows, where every OSError is ignored.
    """
    try:
        os.rmdir(dir)
    except OSError as exc:
        ignorable = os.name == 'nt' or exc.errno == errno.ENOENT
        if not ignorable:
            raise
133
134
def safe_mkdir(dir):
    """Convenience function for creating a directory.

    Does nothing if something already exists at *dir* (EEXIST). Every
    other OSError is re-raised.
    """
    try:
        os.mkdir(dir)
    except OSError as exc:
        if exc.errno == errno.EEXIST:
            return
        raise
142
143
def touch(name):
    """Create (or truncate) the file *name* and return its name."""
    with open(name, 'w') as fobj:
        created = fobj.name
    return created
148
149
def remove_test_files():
    """Remove files and directories created during tests.

    Every entry in the current directory whose name starts with
    tempfile.template is deleted. Directories are removed recursively;
    anything else goes through safe_remove().
    NOTE: with the stdlib default template ('tmp') this matches any name
    starting with 'tmp' in the current directory.
    """
    prefix = tempfile.template
    for entry in os.listdir(u('.')):
        if not entry.startswith(prefix):
            continue
        if os.path.isdir(entry):
            shutil.rmtree(entry)
        else:
            safe_remove(entry)
158
159
def configure_logging():
    """Set the "pyftpdlib" logger to WARNING and attach a StreamHandler.

    Note that every call attaches one more handler.
    """
    logger = logging.getLogger('pyftpdlib')
    logger.setLevel(logging.WARNING)
    logger.addHandler(logging.StreamHandler())
166
167
def disable_log_warning(fun):
    """Decorator: run the wrapped test method with the "pyftpdlib" logger at ERROR.

    The logger's own level is restored afterwards, even if *fun* raises.
    """
    @functools.wraps(fun)
    def wrapper(self, *args, **kwargs):
        logger = logging.getLogger('pyftpdlib')
        # Save the logger's *own* level (which may be NOTSET), not its
        # effective level. Restoring getEffectiveLevel() would pin an
        # explicit level on a logger that was inheriting its parent's.
        saved_level = logger.level
        logger.setLevel(logging.ERROR)
        try:
            return fun(self, *args, **kwargs)
        finally:
            logger.setLevel(saved_level)
    return wrapper
180
181
def cleanup():
    """Cleanup function executed on interpreter exit.

    Removes the test files, then reports every object still registered
    in the IOLoop socket map as "garbage" on stderr and closes it (best
    effort). Finally the map is emptied.
    """
    remove_test_files()
    socket_map = IOLoop.instance().socket_map
    leftovers = list(socket_map.values())
    for obj in leftovers:
        try:
            sys.stderr.write("garbage: %s\n" % repr(obj))
            obj.close()
        except Exception:
            # Best effort only: one broken object must not stop the cleanup.
            pass
    socket_map.clear()
193
194
def retry_on_failure(ntimes=None):
    """Decorator to retry a test in case of failure.

    The wrapped callable is attempted up to *ntimes* times (NO_RETRIES
    when *ntimes* is falsy). Only AssertionError triggers a retry. The
    first successful result is returned; if every attempt fails, the last
    AssertionError is re-raised. Other exceptions propagate immediately.
    """
    def decorator(fun):
        @functools.wraps(fun)
        def wrapper(*args, **kwargs):
            attempts = ntimes or NO_RETRIES
            for _ in range(attempts):
                try:
                    return fun(*args, **kwargs)
                except AssertionError as exc:
                    last_failure = exc
            raise last_failure
        return wrapper
    return decorator
208
209
def call_until(fun, expr, timeout=TIMEOUT):
    """Keep calling *fun* for up to *timeout* seconds until *expr* is true.

    *expr* is a string passed to eval(). It is evaluated in this
    function's scope, so it can refer to ``ret``, the latest return value
    of *fun*. The loop sleeps about 1 ms between attempts.

    Returns the value of *fun* that satisfied *expr*.
    Raises RuntimeError if *timeout* elapses first.
    """
    # Bound up front so that a non-positive timeout (zero iterations)
    # raises the intended RuntimeError rather than UnboundLocalError.
    ret = None
    stop_at = time.time() + timeout
    while time.time() < stop_at:
        ret = fun()
        # NOTE: eval() of a caller-supplied string; only pass trusted,
        # test-authored expressions.
        if eval(expr):
            return ret
        time.sleep(0.001)
    raise RuntimeError('timed out (ret=%r)' % ret)
221
222
def get_server_handler():
    """Return the first FTPHandler instance found in the IOLoop socket map.

    Raises RuntimeError if the socket map contains no FTPHandler.
    """
    socket_map = IOLoop.instance().socket_map
    for candidate in socket_map.values():
        if isinstance(candidate, FTPHandler):
            return candidate
    raise RuntimeError("can't find any FTPHandler instance")
231
232
# Overriding tempfile.template is intentionally disabled because of
# http://bugs.python.org/issue10354:
# tempfile.template = 'tmp-pyftpdlib'
235
def setup_server(handler, server_class, addr=None):
    """Configure *handler* for testing and return ``server_class(addr, handler)``.

    *addr* defaults to (HOST, 0), i.e. port 0, so the OS picks a free
    port. The handler gets a fresh DummyAuthorizer with user USER/PASSWD
    (full permissions) plus an anonymous user, both rooted at HOME.
    """
    if addr is None:
        addr = (HOST, 0)
    auth = DummyAuthorizer()
    auth.add_user(USER, PASSWD, HOME, perm='elradfmwMT')  # full perms
    auth.add_anonymous(HOME)
    handler.authorizer = auth
    handler.auth_failed_timeout = 0.001
    # Lower buffer sizes mean more "loops" while transferring data, and
    # therefore fewer false positives.
    dtp = handler.dtp_handler
    dtp.ac_in_buffer_size = 4096
    dtp.ac_out_buffer_size = 4096
    return server_class(addr, handler)
250
251
def assert_free_resources():
    """Assert that no threads, child processes or TCP sockets were leaked.

    Checks, in order:
    * exactly one thread is alive (presumably the main thread);
    * the current process has no children; any that exist are killed and
      waited for (up to 1 second each) before the assertion fails;
    * the current process has no TCP connections other than ones in
      CLOSE_WAIT state.
    """
    threads = threading.enumerate()
    assert len(threads) == 1, threads
    this_proc = psutil.Process()
    children = this_proc.children()
    if children:
        # Kill leftovers first, then fail. The loop variable must not
        # reuse the name of the Process object inspected below.
        for child in children:
            child.kill()
            child.wait(1)
        assert not children, children
    # psutil >= 6.0 renamed Process.connections() to net_connections();
    # prefer the new name and fall back for older psutil releases.
    get_conns = getattr(this_proc, 'net_connections', None)
    if get_conns is None:
        get_conns = this_proc.connections
    cons = [x for x in get_conns('tcp')
            if x.status != psutil.CONN_CLOSE_WAIT]
    assert not cons, cons
265
266
def reset_server_opts():
    """Restore default values of pyftpdlib's configurable class attributes.

    All pyftpdlib configurable "options" are class attributes, so they are
    reset here at module.class level.
    """
    import pyftpdlib.handlers
    import pyftpdlib.servers
    from pyftpdlib.handlers import _import_sendfile

    handlers = pyftpdlib.handlers
    servers = pyftpdlib.servers

    # Control handlers. TLS_FTPHandler may be absent, in which case
    # FTPHandler is simply reset twice.
    control_classes = (
        handlers.FTPHandler,
        getattr(handlers, "TLS_FTPHandler", handlers.FTPHandler),
    )
    for cls in control_classes:
        control_defaults = {
            'auth_failed_timeout': 0.001,
            'authorizer': DummyAuthorizer(),
            'banner': "pyftpdlib ready.",
            'masquerade_address': None,
            'masquerade_address_map': {},
            'max_login_attempts': 3,
            'passive_ports': None,
            'permit_foreign_addresses': False,
            'permit_privileged_ports': False,
            'tcp_no_delay': hasattr(socket, 'TCP_NODELAY'),
            'timeout': 300,
            'unicode_errors': "replace",
            'use_gmt_times': True,
            'use_sendfile': _import_sendfile() is not None,
            'ac_in_buffer_size': 4096,
            'ac_out_buffer_size': 4096,
        }
        if cls.__name__ == 'TLS_FTPHandler':
            control_defaults['tls_control_required'] = False
            control_defaults['tls_data_required'] = False
        for attr_name, value in control_defaults.items():
            setattr(cls, attr_name, value)

    # Data handlers (same TLS fallback as above).
    data_classes = (
        handlers.DTPHandler,
        getattr(handlers, "TLS_DTPHandler", handlers.DTPHandler),
    )
    data_defaults = (
        ('timeout', 300),
        ('ac_in_buffer_size', 4096),
        ('ac_out_buffer_size', 4096),
    )
    for cls in data_classes:
        for attr_name, value in data_defaults:
            setattr(cls, attr_name, value)
    throttled = handlers.ThrottledDTPHandler
    throttled.read_limit = 0
    throttled.write_limit = 0
    throttled.auto_sized_buffers = True

    # Acceptors; MultiprocessFTPServer is only reset on POSIX.
    acceptors = [servers.FTPServer, servers.ThreadedFTPServer]
    if os.name == 'posix':
        acceptors.append(servers.MultiprocessFTPServer)
    for cls in acceptors:
        cls.max_cons = 0
        cls.max_cons_per_ip = 0
317
318
class ThreadedTestFTPd(threading.Thread):
    """A threaded FTP server used for running tests.

    This is basically a modified version of the FTPServer class which
    wraps the polling loop into a thread. The instance returned can be
    start()ed and stop()ped.

    Each poll of the loop runs while holding ``self.lock``; presumably this
    lets tests hold the lock to pause the server between polls.
    """
    handler = FTPHandler
    server_class = FTPServer
    # Poll timeout passed to serve_forever(); larger on Travis CI.
    poll_interval = 0.001 if TRAVIS else 0.000001
    # Makes the thread stop on interpreter exit.
    # NOTE(review): this class attribute shadows threading.Thread's
    # ``daemon`` property; setting ``self.daemon = True`` in __init__
    # would be the conventional spelling.
    daemon = True

    def __init__(self, addr=None):
        super(ThreadedTestFTPd, self).__init__(name='test-ftpd')
        self.server = setup_server(self.handler, self.server_class, addr=addr)
        self.host, self.port = self.server.socket.getsockname()[:2]

        self.lock = threading.Lock()
        self._stop_flag = False
        self._event_stop = threading.Event()

    def run(self):
        """Poll the server without blocking until stop() raises the flag."""
        try:
            while not self._stop_flag:
                with self.lock:
                    self.server.serve_forever(timeout=self.poll_interval,
                                              blocking=False)
        finally:
            # Tell stop() that the loop has exited (also if it raised).
            self._event_stop.set()

    def stop(self):
        """Stop the polling loop, close all connections and check for leaks.

        If the thread was never started, skip waiting for the loop and
        joining the thread: the wait would block forever and join() would
        raise.
        """
        self._stop_flag = True  # signal the main loop to exit
        started = self.ident is not None
        if started:
            self._event_stop.wait()
        self.server.close_all()
        if started:
            self.join()
        reset_server_opts()
        assert_free_resources()
356
357
class MProcessTestFTPd(multiprocessing.Process):
    """Like ThreadedTestFTPd, but serves from a child process."""
    handler = FTPHandler
    server_class = FTPServer

    def __init__(self, addr=None):
        super(MProcessTestFTPd, self).__init__(name='test-ftpd')
        self.server = setup_server(self.handler, self.server_class, addr=addr)
        sockname = self.server.socket.getsockname()
        self.host = sockname[0]
        self.port = sockname[1]
        # Guards against run() being entered twice in the same process.
        self._started = False

    def run(self):
        """Serve forever; invoked in the child process by start()."""
        assert not self._started
        self._started = True
        self.server.serve_forever()

    def stop(self):
        """Close the server, terminate and join the child, check for leaks."""
        self.server.close_all()
        self.terminate()
        self.join()
        reset_server_opts()
        assert_free_resources()
380