1# -*- coding: utf-8 -*-
2#
3# Copyright (C) 2012-2013 The Python Software Foundation.
4# See LICENSE.txt and CONTRIBUTORS.txt.
5#
import codecs
import functools
import os
import logging
import logging.handlers
import shutil
import socket
try:
    import ssl
except ImportError:
    ssl = None
import sys
import tempfile
try:
    import threading
except ImportError:
    import dummy_threading as threading
import weakref

from compat import (unittest, HTTPServer as BaseHTTPServer,
                    SimpleHTTPRequestHandler, urlparse)

from distlib import logger
28
# Directory containing this module; the test HTTP handler resolves its
# fixture file relative to it.
HERE = os.path.split(__file__)[0]
30
class _TestHandler(logging.handlers.BufferingHandler, object):
    """Handler that keeps every record it receives in ``self.buffer``.

    Stolen and adapted from test.support.  Records are only discarded when
    ``flush()`` (inherited from BufferingHandler) is called explicitly.
    """

    def __init__(self):
        # Capacity is irrelevant because shouldFlush never triggers a flush.
        super(_TestHandler, self).__init__(0)
        self.setLevel(logging.DEBUG)

    def shouldFlush(self, record=None):
        """Never auto-flush; tests flush explicitly.

        *record* is accepted (and ignored) so the signature matches
        ``BufferingHandler.shouldFlush(record)``.
        """
        return False

    def emit(self, record):
        """Append *record* to the buffer unconditionally."""
        self.buffer.append(record)
43
class LoggingCatcher(object):
    """TestCase-compatible mixin to receive logging calls.

    Upon setUp, instances of this class get a buffering handler that's
    attached to distlib's ``logger`` (whose level is temporarily lowered to
    DEBUG) so that every message logged there is recorded.

    Use get_logs to retrieve messages and self.loghandler.flush to discard
    them.  get_logs automatically flushes the logs, unless you pass
    *flush=False*, for example to make multiple calls to the method with
    different level arguments.  If your test calls some code that generates
    logging messages and then you don't call get_logs, you will need to flush
    manually before testing other code in the same test_* method, otherwise
    get_logs in the next lines will see messages from the previous lines.
    See example in test_command_check.
    """

    def setUp(self):
        super(LoggingCatcher, self).setUp()
        self.loghandler = handler = _TestHandler()
        # Remember the logger's level so tearDown can restore it.
        self._old_level = logger.level
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)  # we want all messages

    def tearDown(self):
        handler = self.loghandler
        # All this is necessary to properly shut down the logging system and
        # avoid a regrtest complaint.  Thanks to Vinay Sajip for the help.
        handler.close()
        logger.removeHandler(handler)
        # Drop the logging module's internal weak references to the handler
        # as well (private API, used deliberately here).
        for ref in weakref.getweakrefs(handler):
            logging._removeHandlerRef(ref)
        del self.loghandler
        logger.setLevel(self._old_level)
        super(LoggingCatcher, self).tearDown()

    def get_logs(self, level=logging.WARNING, flush=True):
        """Return the messages of all records logged at exactly *level*.

        *level* defaults to logging.WARNING.  Records of other levels
        (including higher ones) are not included.

        For log calls with arguments (i.e.  logger.info('bla bla %r', arg)),
        the messages will be formatted before being returned (e.g. "bla bla
        'thing'").

        Returns a list.  Automatically flushes the loghandler after being
        called, unless *flush* is False (this is useful to get e.g. all
        warnings then all info messages).
        """
        messages = [log.getMessage() for log in self.loghandler.buffer
                    if log.levelno == level]
        if flush:
            self.loghandler.flush()
        return messages
97
98
class TempdirManager(object):
    """TestCase-compatible mixin to create temporary directories and files.

    Directories and files created in a test_* method will be removed after it
    has run.  Everything lives under a per-test base directory created in
    setUp; tearDown restores the working directory that was current at setUp
    time and then deletes that base directory.
    """

    def setUp(self):
        super(TempdirManager, self).setUp()
        self._olddir = os.getcwd()
        self._basetempdir = tempfile.mkdtemp()
        # (open file object or None, path) for each file handed out.
        self._files = []

    def tearDown(self):
        try:
            for handle, name in self._files:
                if handle is not None:
                    handle.close()
                # The test may have deleted the file itself; that must not
                # abort the remaining cleanup.
                if os.path.exists(name):
                    os.remove(name)
        finally:
            # Leave the temp tree before deleting it (a cwd inside it would
            # block removal on Windows), and delete it even if closing a
            # file above failed.
            os.chdir(self._olddir)
            shutil.rmtree(self._basetempdir)
        super(TempdirManager, self).tearDown()

    def temp_filename(self):
        """Create a read-write temporary file name and return it."""
        fd, fn = tempfile.mkstemp(dir=self._basetempdir)
        os.close(fd)
        self._files.append((None, fn))
        return fn

    def mktempfile(self):
        """Create a read-write temporary file and return it (opened 'w+')."""
        fd, fn = tempfile.mkstemp(dir=self._basetempdir)
        os.close(fd)
        fp = open(fn, 'w+')
        self._files.append((fp, fn))
        return fp

    def mkdtemp(self):
        """Create a temporary directory and return its path."""
        d = tempfile.mkdtemp(dir=self._basetempdir)
        return d

    def write_file(self, path, content='xxx', encoding=None):
        """Write a file at the given path.

        path can be a string, a tuple or a list; if it's a tuple or list,
        os.path.join will be used to produce a path.  *encoding* is passed
        to codecs.open.
        """
        if isinstance(path, (list, tuple)):
            path = os.path.join(*path)
        with codecs.open(path, 'w', encoding=encoding) as f:
            f.write(content)

    def assertIsFile(self, *args):
        """Fail unless os.path.join(*args) names an existing regular file.

        Raises AssertionError explicitly (not via ``assert``) so the check
        still works when Python runs with -O.
        """
        path = os.path.join(*args)
        # A bare file name has an empty dirname; it lives in the cwd.
        dirname = os.path.dirname(path) or os.curdir
        filename = os.path.basename(path)
        if not os.path.isdir(dirname):
            raise AssertionError(
                    '%s not found. %s does not exist' % (filename, dirname))
        if not os.path.isfile(path):
            files = os.listdir(dirname)
            raise AssertionError(
                    "%s not found in %s: %s" % (filename, dirname, files))

    def assertIsNotFile(self, *args):
        """Fail if os.path.join(*args) names an existing regular file."""
        path = os.path.join(*args)
        self.assertFalse(os.path.isfile(path), "%r exists" % path)
171
172
class EnvironRestorer(object):
    """TestCase-compatible mixin that undoes changes to chosen env variables.

    Subclasses must list the variable names in ``restore_environ``.  On
    tearDown each listed variable that existed at setUp time gets its old
    value back, and each one that did not exist is removed again.  Only the
    listed names are tracked, which keeps the test explicit about what it
    touches instead of snapshotting the whole environment.
    """

    def setUp(self):
        super(EnvironRestorer, self).setUp()
        previous = []
        missing = []
        self._saved = previous
        self._added = missing
        for name in self.restore_environ:
            current = os.environ.get(name)
            if current is None:
                missing.append(name)
            else:
                previous.append((name, current))

    def tearDown(self):
        env = os.environ
        for name, original in self._saved:
            env[name] = original
        for name in self._added:
            if name in env:
                del env[name]
        super(EnvironRestorer, self).tearDown()
198
class HTTPRequestHandler(SimpleHTTPRequestHandler):
    """Quiet request handler that always serves the same fixture file.

    Whatever path a client asks for, the response is built from
    ``testsrc/README.txt`` under ``HERE``; per-request logging is disabled.
    """

    server_version = "TestHTTPS/1.0"
    # A finite timeout stops a handler from hanging when the client
    # abandons a request part-way through.
    timeout = 5

    def translate_path(self, path):
        """Ignore *path* and map every request to the README fixture."""
        fixture = ('testsrc', 'README.txt')
        return os.path.join(HERE, *fixture)

    def log_message(self, format, *args):
        """Swallow the per-request log line instead of printing it."""
        return None
210
class HTTPSServer(BaseHTTPServer):
    """HTTP server that wraps every accepted connection in TLS.

    Adapted from the one in Python's test suite.  *certfile* is given to
    ``load_cert_chain`` without a separate key file (or, on Pythons lacking
    ``ssl.SSLContext``, used as both certfile and keyfile), so it must hold
    the private key as well as the certificate.
    """

    def __init__(self, server_address, handler_class, certfile):
        BaseHTTPServer.__init__(self, server_address, handler_class)
        self.certfile = certfile

    def _ssl_wrap(self, sock):
        """Return *sock* wrapped as the server side of a TLS connection."""
        if hasattr(ssl, 'SSLContext'):
            # PROTOCOL_TLS_SERVER (3.6+) supersedes the deprecated
            # PROTOCOL_SSLv23 alias; use the old name only where needed.
            protocol = getattr(ssl, 'PROTOCOL_TLS_SERVER', None)
            if protocol is None:
                protocol = ssl.PROTOCOL_SSLv23
            context = ssl.SSLContext(protocol)
            context.load_cert_chain(self.certfile)
            return context.wrap_socket(sock, server_side=True)
        return ssl.wrap_socket(sock, server_side=True,
                               certfile=self.certfile,
                               keyfile=self.certfile,
                               ssl_version=ssl.PROTOCOL_SSLv23)

    def get_request(self):
        """Accept a connection and return ``(tls_socket, client_address)``."""
        try:
            sock, addr = self.socket.accept()
            try:
                sock = self._ssl_wrap(sock)
            except Exception:
                # The handshake or cert loading failed: close the accepted
                # plain socket rather than leaking it.
                sock.close()
                raise
        except socket.error as e:
            # socket errors are silenced by the caller, print them here
            sys.stderr.write("Got an error:\n%s\n" % e)
            raise
        return sock, addr
234
class HTTPSServerThread(threading.Thread):
    """Daemon thread running an HTTPSServer on an OS-chosen localhost port.

    The port picked by the OS is available as ``self.port`` right after
    construction.  ``start(flag)`` optionally takes an event-like object
    whose ``set()`` is called just before the serve loop begins.
    """

    def __init__(self, certfile):
        self.flag = None
        server = HTTPSServer(('localhost', 0),
                             HTTPRequestHandler, certfile)
        self.server = server
        self.port = server.server_port
        super(HTTPSServerThread, self).__init__()
        self.daemon = True

    def start(self, flag=None):
        """Remember *flag* and launch the thread."""
        self.flag = flag
        super(HTTPSServerThread, self).start()

    def run(self):
        """Signal readiness, serve until shut down, then close the socket."""
        ready = self.flag
        if ready:
            ready.set()
        try:
            self.server.serve_forever(poll_interval=0.05)
        finally:
            self.server.server_close()

    def stop(self):
        """Ask the serve loop to exit."""
        self.server.shutdown()
259
# zlib is an optional module; fall back to None when it is unavailable.
zlib = None
try:
    import zlib
except ImportError:
    pass

# Decorator: skip the decorated test when zlib could not be imported.
requires_zlib = unittest.skipUnless(zlib is not None, 'requires zlib')
266
_can_symlink = None


def can_symlink():
    """Return True if os.symlink works here, probing only on the first call.

    The answer is cached in the module-level ``_can_symlink``.
    """
    global _can_symlink
    if _can_symlink is None:
        # Reserve a unique name in the temp dir, then free it so the probe
        # link points at a target that does not exist.
        handle, target = tempfile.mkstemp()
        os.close(handle)
        os.remove(target)
        link = target + "can_symlink"
        try:
            os.symlink(target, link)
        except (OSError, NotImplementedError, AttributeError):
            # AttributeError presumably covers platforms whose os module
            # has no symlink at all.
            _can_symlink = False
        else:
            os.remove(link)
            _can_symlink = True
    return _can_symlink
285
def skip_unless_symlink(test):
    """Decorator: leave *test* as-is when symlinks work, otherwise skip it."""
    if can_symlink():
        return test
    return unittest.skip("Requires functional symlink implementation")(test)
291
def fake_dec(*args, **kw):
    """Fake decorator factory: accepts any arguments and ignores them.

    ``fake_dec(...)(func)`` returns a wrapper that simply forwards every call
    to *func*.  ``functools.wraps`` copies *func*'s ``__name__``, ``__doc__``
    and related metadata onto the wrapper, so introspection still reports
    the original function.
    """
    def _wrap(func):
        @functools.wraps(func)
        def __wrap(*args, **kw):
            return func(*args, **kw)
        return __wrap
    return _wrap
299
def in_github_workflow():
    """Return True when GITHUB_WORKFLOW is present in the environment."""
    return os.environ.get('GITHUB_WORKFLOW') is not None
302
SEP = '-' * 80


class DistlibTestCase(unittest.TestCase):
    """Base TestCase that logs a banner with the test method's name."""

    def setUp(self):
        # Keep the cooperative setUp chain intact so that classes after this
        # one in the MRO (e.g. other TestCase subclasses) are set up too.
        super(DistlibTestCase, self).setUp()
        logger.debug(SEP)
        # self.id() is "module.Class.method"; log just the method name.
        logger.debug(self.id().rsplit('.', 1)[-1])
        logger.debug(SEP)
310