# -*- coding: utf-8 -*-
#
# Copyright (C) 2012-2013 The Python Software Foundation.
# See LICENSE.txt and CONTRIBUTORS.txt.
#
"""Shared helpers for the distlib test suite.

Provides TestCase-compatible mixins (log capture, temporary directories,
environment-variable restoration), a small HTTPS server that serves
``testsrc/README.txt`` for every request path, and a few skip helpers /
environment probes.
"""
import codecs
import errno
import os
import logging
import logging.handlers
import shutil
import socket
try:
    import ssl
except ImportError:
    ssl = None
import sys
import tempfile
try:
    import threading
except ImportError:
    import dummy_threading as threading
import weakref

from compat import (unittest, HTTPServer as BaseHTTPServer,
                    SimpleHTTPRequestHandler, urlparse)

from distlib import logger

HERE = os.path.dirname(__file__)


class _TestHandler(logging.handlers.BufferingHandler, object):
    """Logging handler that appends every record to ``self.buffer``.

    The handler never decides to flush on its own; records accumulate until
    ``flush()`` is called explicitly.
    """
    # stolen and adapted from test.support

    def __init__(self):
        # The capacity is irrelevant because shouldFlush() always says no.
        super(_TestHandler, self).__init__(0)
        self.setLevel(logging.DEBUG)

    def shouldFlush(self, record=None):
        # BufferingHandler.shouldFlush() receives the record being emitted;
        # accept it (optionally) so the override matches the base signature.
        return False

    def emit(self, record):
        self.buffer.append(record)


class LoggingCatcher(object):
    """TestCase-compatible mixin to receive logging calls.

    Upon setUp, instances of this class get a BufferingHandler that's
    configured to record all messages logged to the distlib logger
    (``distlib.logger``).

    Use get_logs to retrieve messages and self.loghandler.flush to discard
    them. get_logs automatically flushes the logs, unless you pass
    *flush=False*, for example to make multiple calls to the method with
    different level arguments. If your test calls some code that generates
    logging messages and then you don't call get_logs, you will need to flush
    manually before testing other code in the same test_* method, otherwise
    get_logs in the next lines will see messages from the previous lines.
    """

    def setUp(self):
        super(LoggingCatcher, self).setUp()
        self.loghandler = handler = _TestHandler()
        self._old_level = logger.level
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)  # we want all messages

    def tearDown(self):
        handler = self.loghandler
        # All this is necessary to properly shut down the logging system and
        # avoid a regrtest complaint. Thanks to Vinay Sajip for the help.
        handler.close()
        logger.removeHandler(handler)
        # Presumably drops the handler from logging's internal weakref
        # registry (private API) -- confirm if logging internals change.
        for ref in weakref.getweakrefs(handler):
            logging._removeHandlerRef(ref)
        del self.loghandler
        logger.setLevel(self._old_level)
        super(LoggingCatcher, self).tearDown()

    def get_logs(self, level=logging.WARNING, flush=True):
        """Return all log messages with given level.

        *level* defaults to logging.WARNING; only records whose level is
        exactly *level* are returned.

        For log calls with arguments (i.e. logger.info('bla bla %r', arg)),
        the messages will be formatted before being returned (e.g. "bla bla
        'thing'").

        Returns a list. Automatically flushes the loghandler after being
        called, unless *flush* is False (this is useful to get e.g. all
        warnings then all info messages).
        """
        messages = [log.getMessage() for log in self.loghandler.buffer
                    if log.levelno == level]
        if flush:
            self.loghandler.flush()
        return messages


class TempdirManager(object):
    """TestCase-compatible mixin to create temporary directories and files.

    Directories and files created in a test_* method will be removed after it
    has run, and the working directory is restored to what it was at setUp.
    """

    def setUp(self):
        super(TempdirManager, self).setUp()
        self._olddir = os.getcwd()
        self._basetempdir = tempfile.mkdtemp()
        # (open file object or None, path) pairs to clean up in tearDown.
        self._files = []

    def tearDown(self):
        try:
            for handle, name in self._files:
                if handle is not None:
                    handle.close()
                try:
                    os.remove(name)
                except OSError as e:
                    # The test may legitimately have deleted the file already;
                    # that must not abort the rest of the cleanup.
                    if e.errno != errno.ENOENT:
                        raise
        finally:
            # Always restore the cwd and remove the scratch tree, even if
            # closing/removing an individual file failed.
            os.chdir(self._olddir)
            shutil.rmtree(self._basetempdir)
        super(TempdirManager, self).tearDown()

    def temp_filename(self):
        """Create a read-write temporary file name and return it."""
        fd, fn = tempfile.mkstemp(dir=self._basetempdir)
        os.close(fd)
        self._files.append((None, fn))
        return fn

    def mktempfile(self):
        """Create a read-write temporary file and return it."""
        fd, fn = tempfile.mkstemp(dir=self._basetempdir)
        os.close(fd)
        fp = open(fn, 'w+')
        self._files.append((fp, fn))
        return fp

    def mkdtemp(self):
        """Create a temporary directory and return its path."""
        d = tempfile.mkdtemp(dir=self._basetempdir)
        return d

    def write_file(self, path, content='xxx', encoding=None):
        """Write a file at the given path.

        path can be a string, a tuple or a list; if it's a tuple or list,
        os.path.join will be used to produce a path.
        """
        if isinstance(path, (list, tuple)):
            path = os.path.join(*path)
        f = codecs.open(path, 'w', encoding=encoding)
        try:
            f.write(content)
        finally:
            f.close()

    def assertIsFile(self, *args):
        """Fail unless ``os.path.join(*args)`` is an existing regular file.

        Raises AssertionError naming the missing parent directory, or listing
        the parent directory's contents when only the file is missing.
        """
        path = os.path.join(*args)
        # A bare relative name has an empty dirname; treat it as the cwd.
        dirname = os.path.dirname(path) or os.curdir
        filename = os.path.basename(path)
        if not os.path.isdir(dirname):
            raise AssertionError(
                '%s not found. %s does not exist' % (filename, dirname))
        if not os.path.isfile(path):
            # Explicit raise rather than ``assert`` so the check survives -O.
            files = os.listdir(dirname)
            raise AssertionError(
                "%s not found in %s: %s" % (filename, dirname, files))

    def assertIsNotFile(self, *args):
        path = os.path.join(*args)
        self.assertFalse(os.path.isfile(path), "%r exists" % path)


class EnvironRestorer(object):
    """TestCase-compatible mixin to restore or delete environment variables.

    The variables to restore (or delete if they were not originally present)
    must be explicitly listed in self.restore_environ. It's better to be
    aware of what we're modifying instead of saving and restoring the whole
    environment.
    """

    def setUp(self):
        super(EnvironRestorer, self).setUp()
        self._saved = []
        self._added = []
        for key in self.restore_environ:
            if key in os.environ:
                self._saved.append((key, os.environ[key]))
            else:
                self._added.append(key)

    def tearDown(self):
        for key, value in self._saved:
            os.environ[key] = value
        for key in self._added:
            os.environ.pop(key, None)
        super(EnvironRestorer, self).tearDown()


class HTTPRequestHandler(SimpleHTTPRequestHandler):
    """Request handler that serves testsrc/README.txt for every path."""

    server_version = "TestHTTPS/1.0"
    # Avoid hanging when a request gets interrupted by the client
    timeout = 5

    def translate_path(self, path):
        return os.path.join(HERE, 'testsrc', 'README.txt')

    def log_message(self, format, *args):
        # Keep test output quiet.
        pass


class HTTPSServer(BaseHTTPServer):
    """HTTP server that wraps every accepted connection in server-side TLS.

    *certfile* is passed as both certificate chain and private key, so it
    presumably contains both in one PEM file -- confirm against fixtures.
    """
    # Adapted from the one in Python's test suite.

    def __init__(self, server_address, handler_class, certfile):
        BaseHTTPServer.__init__(self, server_address, handler_class)
        self.certfile = certfile

    def _wrap_socket(self, sock):
        """Return *sock* wrapped as a server-side TLS socket."""
        if hasattr(ssl, 'SSLContext'):
            # PROTOCOL_SSLv23 is a deprecated alias on modern Pythons; use the
            # server-specific constant when the ssl module provides it.
            protocol = getattr(ssl, 'PROTOCOL_TLS_SERVER',
                               ssl.PROTOCOL_SSLv23)
            context = ssl.SSLContext(protocol)
            context.load_cert_chain(self.certfile)
            return context.wrap_socket(sock, server_side=True)
        return ssl.wrap_socket(sock, server_side=True,
                               certfile=self.certfile,
                               keyfile=self.certfile,
                               ssl_version=ssl.PROTOCOL_SSLv23)

    def get_request(self):
        """Accept a connection; return ``(tls_socket, client_address)``."""
        try:
            sock, addr = self.socket.accept()
        except socket.error as e:
            # socket errors are silenced by the caller, print them here
            sys.stderr.write("Got an error:\n%s\n" % e)
            raise
        try:
            return self._wrap_socket(sock), addr
        except Exception as e:
            if isinstance(e, socket.error):
                # socket errors are silenced by the caller, print them here
                sys.stderr.write("Got an error:\n%s\n" % e)
            # Don't leak the accepted connection when TLS setup fails.
            sock.close()
            raise


class HTTPSServerThread(threading.Thread):
    """Daemon thread running an HTTPSServer on an ephemeral localhost port.

    The listening port is available as ``self.port`` right after
    construction.
    """

    def __init__(self, certfile):
        self.flag = None
        self.server = HTTPSServer(('localhost', 0),
                                  HTTPRequestHandler, certfile)
        self.port = self.server.server_port
        threading.Thread.__init__(self)
        self.daemon = True

    def start(self, flag=None):
        # *flag*, if given, must have a set() method (presumably a
        # threading.Event); it is set once the thread begins running.
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        if self.flag:
            self.flag.set()
        try:
            self.server.serve_forever(0.05)
        finally:
            self.server.server_close()

    def stop(self):
        self.server.shutdown()


try:
    import zlib
except ImportError:
    zlib = None

requires_zlib = unittest.skipUnless(zlib, 'requires zlib')

# Cached result of can_symlink(); None means "not probed yet".
_can_symlink = None


def can_symlink():
    """Return True if os.symlink works here (result cached after first call)."""
    global _can_symlink
    if _can_symlink is not None:
        return _can_symlink
    fd, target = tempfile.mkstemp()
    os.close(fd)
    os.remove(target)
    symlink_path = target + "can_symlink"
    try:
        # The target no longer exists; only link creation is being probed.
        os.symlink(target, symlink_path)
        can = True
    except (OSError, NotImplementedError, AttributeError):
        can = False
    else:
        os.remove(symlink_path)
    _can_symlink = can
    return can


def skip_unless_symlink(test):
    """Skip decorator for tests that require functional symlink"""
    ok = can_symlink()
    msg = "Requires functional symlink implementation"
    return test if ok else unittest.skip(msg)(test)


def fake_dec(*args, **kw):
    """Fake decorator"""
    def _wrap(func):
        def __wrap(*args, **kw):
            return func(*args, **kw)
        return __wrap
    return _wrap


def in_github_workflow():
    """Return True when the GITHUB_WORKFLOW environment variable is set."""
    return 'GITHUB_WORKFLOW' in os.environ


SEP = '-' * 80


class DistlibTestCase(unittest.TestCase):
    """Base TestCase that logs a separator banner with each test's name."""

    def setUp(self):
        # Cooperate with mixins in the MRO, like the other setUp()s here.
        super(DistlibTestCase, self).setUp()
        logger.debug(SEP)
        logger.debug(self.id().rsplit('.', 1)[-1])
        logger.debug(SEP)