# This module contains some code copied from unittest2/loader.py and other
# code developed in reference to that module and others within unittest2.

# unittest2 is Copyright (c) 2001-2010 Python Software Foundation; All
# Rights Reserved. See: http://docs.python.org/license.html

import logging
import os
import types
import re
import sys
import traceback
import platform
import six
import inspect
from inspect import isgeneratorfunction


# Marker checked by ``_is_relevant_tb_level``: traceback frames whose module
# globals contain ``__unittest`` are treated as runner internals and trimmed
# from formatted tracebacks (this module included).
__unittest = True
# An identifier-like name (used to validate package directory names).
IDENT_RE = re.compile(r'^[_a-zA-Z]\w*$', re.UNICODE)
# A path that ends in ``<identifier-like name>.py``.
VALID_MODULE_RE = re.compile(r'[_a-zA-Z]\w*\.py$', re.UNICODE)


def ln(label, char='-', width=70):
    """Draw a divider, with ``label`` in the middle.

    >>> ln('hello there')
    '---------------------------- hello there -----------------------------'

    ``width`` and divider ``char`` may be specified. Defaults are ``70`` and
    ``'-'``, respectively. When the padding cannot be split evenly, the extra
    ``char`` goes on the right so that (for labels short enough to fit) the
    result is exactly ``width`` characters long.

    """
    label_len = len(label) + 2  # the label plus one space on each side
    chunk = (width - label_len) // 2
    out = '%s %s %s' % (char * chunk, label, char * chunk)
    pad = width - len(out)
    if pad > 0:
        out = out + (char * pad)
    return out


def valid_module_name(path):
    """Is ``path`` a valid module name?

    Returns a (truthy) match object when ``path`` ends in
    ``<identifier-like name>.py``, otherwise ``None``.
    """
    return VALID_MODULE_RE.search(path)


def name_from_path(path):
    """Translate ``path`` into module name

    Returns a two-element tuple:

    1. a dotted module name that can be used in an import statement
       (e.g., ``pkg1.test.test_things``)

    2. a full path to filesystem directory, which must be on ``sys.path``
       for the import to succeed.

    """
    # back up to find module root: keep climbing while the parent directory
    # is itself a package (see ``ispackage``)
    parts = []
    path = os.path.normpath(path)
    base = os.path.splitext(path)[0]
    candidate, top = os.path.split(base)
    parts.append(top)
    while candidate:
        if ispackage(candidate):
            candidate, top = os.path.split(candidate)
            parts.append(top)
        else:
            break
    return '.'.join(reversed(parts)), candidate


def module_from_name(name):
    """Import module from ``name``

    ``__import__`` of a dotted name returns the top-level package, so the
    requested module itself is fetched from ``sys.modules``.
    """
    __import__(name)
    return sys.modules[name]


def test_from_name(name, module):
    """Import test from ``name``

    ``name`` is a dotted name, optionally followed by ``:<int>`` (the same
    ``name:N`` form produced by :func:`name_from_args`). If everything after
    the first colon parses as an integer, it is stripped from ``name`` and
    returned as ``index``; otherwise ``name`` is used unchanged and ``index``
    is ``None``.

    :param name: dotted name, possibly with a ``:<int>`` suffix
    :param module: module to resolve ``name`` in, or ``None`` to import it
    :returns: ``(parent, obj, name, index)``
    """
    pos = name.find(':')
    index = None
    if pos != -1:
        real_name, digits = name[:pos], name[pos + 1:]
        try:
            index = int(digits)
        except ValueError:
            pass
        else:
            name = real_name
    parent, obj = object_from_name(name, module)
    return parent, obj, name, index


# Not a test despite its name: keep test collectors that honor ``__test__``
# (e.g. pytest) from running it when it is imported into a test module.
test_from_name.__test__ = False


def object_from_name(name, module=None):
    """
    Given a dotted name, return the corresponding object.

    Returns a ``(parent, obj)`` tuple, where ``parent`` is the object on which
    the last attribute was looked up (``None`` when no attribute lookup was
    needed). When ``module`` is ``None``, the longest importable prefix of
    ``name`` is imported first.

    Getting the object can fail for two reasons:

    - the object is a module that cannot be imported.
    - the object is a class or a function that does not exist.

    Since we cannot distinguish between these two cases, we assume we are in
    the first one. We expect the stacktrace is explicit enough for the user to
    understand the error.

    :raises AttributeError: when an attribute along the path is missing; if a
        longer import failed earlier, that failure is chained as the cause.
    :raises: whatever importing the first component of ``name`` raised, when
        not even that could be imported.
    """
    import_error = None
    parts = name.split('.')
    if module is None:
        # Pass a copy: the helper shrinks the list it is given.
        (module, import_error) = try_import_module_from_name(parts[:])
        # ``module`` is the top-level package; walk the rest by attribute.
        parts = parts[1:]
    parent = None
    obj = module
    for part in parts:
        try:
            parent, obj = obj, getattr(obj, part)
        except AttributeError as e:
            if is_package_or_module(obj) and import_error:
                # Re-raise the import error which got us here, since
                # it probably better describes the issue.
                _raise_custom_attribute_error(obj, part, e, import_error)
            else:
                raise

    return parent, obj


def _raise_custom_attribute_error(obj, attr, attr_error_exc, prev_exc):
    """Raise an ``AttributeError`` for ``obj.attr`` that carries ``prev_exc``.

    :param obj: object the failed attribute lookup was performed on
    :param attr: name of the missing attribute
    :param attr_error_exc: the ``AttributeError`` that was caught
    :param prev_exc: :func:`sys.exc_info` triple of the earlier import failure
    """
    if sys.version_info >= (3, 0):
        # Native chaining: the import error becomes ``__cause__``.
        six.raise_from(attr_error_exc, prev_exc[1])

    # for python 2, do exception chaining manually. format_exception lines
    # already end with a newline, so they are joined with ''.
    raise AttributeError(
        "'%s' has no attribute '%s'\n\nMaybe caused by\n\n%s" % (
            obj, attr, ''.join(traceback.format_exception(*prev_exc))))


def is_package_or_module(obj):
    """Is ``obj`` a module, or something with a ``__path__`` (a package)?"""
    return hasattr(obj, '__path__') or isinstance(obj, types.ModuleType)


def try_import_module_from_name(splitted_name):
    """
    Try to import the longest importable prefix of ``splitted_name``.

    Returns ``(module, import_error)``. Because :func:`__import__` is called
    without a ``fromlist``, ``module`` is the *top-level* package (named by
    ``splitted_name[0]``), not the deepest module imported; the deeper
    modules are reachable from it as attributes. ``import_error`` is the
    :func:`sys.exc_info` triple of the most recent (i.e. shortest) failed
    attempt, or ``None`` if the full name imported. Any exception raised
    while importing counts as a failure, except ``KeyboardInterrupt``, which
    is propagated immediately.

    For instance, if ``splitted_name`` is ['a', 'b', 'c'] but only ``a.b`` is
    importable, this function:

    1. tries to import ``a.b.c`` and fails
    2. tries to import ``a.b`` and succeeds
    3. returns the ``a`` package (with ``a.b`` imported) and the exception
       that occurred at step 1.

    If no prefix can be imported, the exception raised while importing
    ``splitted_name[0]`` is re-raised. An empty ``splitted_name`` yields
    ``(None, None)``. ``splitted_name`` is mutated (items are removed from its
    end), so callers should pass a copy.
    """
    module = None
    import_error = None
    while splitted_name:
        try:
            module = __import__('.'.join(splitted_name))
            break
        except KeyboardInterrupt:
            # Never turn a user's Ctrl-C into an import failure.
            raise
        except:  # noqa: E722 - test modules may raise anything at import
            import_error = sys.exc_info()
            del splitted_name[-1]
            if not splitted_name:
                six.reraise(*sys.exc_info())
    return (module, import_error)


def name_from_args(name, index, args):
    """Create test name from test args

    The result is ``name:N`` (``N`` being ``index + 1``), a newline, and then
    the comma-separated ``repr`` of each item in ``args``, truncated to 79
    characters.
    """
    summary = ', '.join(repr(arg) for arg in args)
    return '%s:%s\n%s' % (name, index + 1, summary[:79])


def test_name(test, qualname=True):
    """Return a dotted id for ``test``.

    Uses ``test._funcName`` if present, else ``module.name`` of
    ``test._testFunc``, else (on Python 3.5+ with ``qualname=False``)
    ``module.ClassName.method`` built from the class ``__name__``, else
    ``test.id()``. Anything after the first newline, and then anything after
    the first space (e.g. a subtest description), is dropped.
    """
    # XXX does not work for test funcs; test.id() lacks module
    if hasattr(test, '_funcName'):
        tid = test._funcName
    elif hasattr(test, '_testFunc'):
        tid = "%s.%s" % (test._testFunc.__module__, test._testFunc.__name__)
    else:
        if sys.version_info >= (3, 5) and not qualname:
            test_module = test.__class__.__module__
            test_class = test.__class__.__name__
            test_method = test._testMethodName
            tid = "%s.%s.%s" % (test_module, test_class, test_method)
        else:
            tid = test.id()
    if '\n' in tid:
        tid = tid.split('\n')[0]
    # subtest support
    if ' ' in tid:
        tid = tid.split(' ')[0]
    return tid


# Not a test despite its name: keep test collectors that honor ``__test__``
# (e.g. pytest) from running it when it is imported into a test module.
test_name.__test__ = False


def ispackage(path):
    """Is this path a package directory?"""
    if os.path.isdir(path):
        # at least the end of the path must be a legal python identifier
        # and __init__.py[co] must exist
        end = os.path.basename(path)
        if IDENT_RE.match(end):
            for init in ('__init__.py', '__init__.pyc', '__init__.pyo'):
                if os.path.isfile(os.path.join(path, init)):
                    return True
            # Jython compiles __init__.py to __init__$py.class
            if sys.platform.startswith('java') and \
                    os.path.isfile(os.path.join(path, '__init__$py.class')):
                return True
    return False


def ensure_importable(dirname):
    """Ensure a directory is on ``sys.path`` (prepending it if missing)."""
    if dirname not in sys.path:
        sys.path.insert(0, dirname)


def isgenerator(obj):
    """Is this object a generator?

    True for generator functions and for objects carrying a non-``None``
    ``testGenerator`` attribute.
    """
    return (isgeneratorfunction(obj)
            or getattr(obj, 'testGenerator', None) is not None)


def has_module_fixtures(test):
    """Does this test live in a module with module fixtures?

    Returns ``False`` when the test's module is not in ``sys.modules``.
    """
    modname = test.__class__.__module__
    try:
        mod = sys.modules[modname]
    except KeyError:
        return False
    return hasattr(mod, 'setUpModule') or hasattr(mod, 'tearDownModule')


def has_class_fixtures(test):
    """Does the test's class (or a non-unittest base) define class fixtures?"""
    # hasattr would be the obvious thing to use here. Unfortunately, all tests
    # inherit from unittest2.case.TestCase, and that *always* has setUpClass and
    # tearDownClass methods. Thus, exclude the unittest and unittest2 base
    # classes from the lookup.
    def is_not_base_class(c):
        return (
            "unittest.case" not in c.__module__ and
            "unittest2.case" not in c.__module__)
    has_class_setups = any(
        'setUpClass' in c.__dict__ for c in test.__class__.__mro__ if is_not_base_class(c))
    has_class_teardowns = any(
        'tearDownClass' in c.__dict__ for c in test.__class__.__mro__ if is_not_base_class(c))
    return has_class_setups or has_class_teardowns


def safe_decode(string):
    """Safely decode a byte string into unicode

    ``None`` is returned as-is, as is any object without a ``decode`` method.
    Otherwise the default codec is tried, then UTF-8; if both fail, a
    ``'<unable to decode>'`` placeholder is returned.
    """
    if string is None:
        return string
    try:
        return string.decode()
    except AttributeError:
        return string
    except UnicodeDecodeError:
        pass
    try:
        return string.decode('utf-8')
    except UnicodeDecodeError:
        return six.u('<unable to decode>')


def safe_encode(string, encoding='utf-8'):
    """Encode ``string`` with ``encoding`` (UTF-8 if ``None``), best-effort.

    ``None`` and objects without an ``encode`` method are returned as-is.
    If encoding fails, a ``'<unable to encode>'`` placeholder is returned.
    """
    if string is None:
        return string
    if encoding is None:
        encoding = 'utf-8'
    try:
        return string.encode(encoding)
    except AttributeError:
        return string
    except UnicodeDecodeError:
        # already encoded
        return string
    except UnicodeEncodeError:
        return six.u('<unable to encode>')


def exc_info_to_string(err, test):
    """Format exception info for output

    Defers to ``test.formatTraceback(err)`` when the test provides one.
    """
    formatTraceback = getattr(test, 'formatTraceback', None)
    if formatTraceback is not None:
        return test.formatTraceback(err)
    else:
        return format_traceback(test, err)


def format_traceback(test, err):
    """Converts a :func:`sys.exc_info` -style tuple of values into a string.

    If the third item is not a traceback object (has no ``tb_next``), it is
    assumed to already be an iterable of strings and is joined as-is.
    """
    exctype, value, tb = err
    if not hasattr(tb, 'tb_next'):
        msgLines = tb
    else:
        # Skip test runner traceback levels
        while tb and _is_relevant_tb_level(tb):
            tb = tb.tb_next
        failure = getattr(test, 'failureException', AssertionError)
        if exctype is failure:
            # Skip assert*() traceback levels
            length = _count_relevant_tb_levels(tb)
            msgLines = traceback.format_exception(exctype, value, tb, length)
        else:
            msgLines = traceback.format_exception(exctype, value, tb)

    return ''.join(msgLines)


def transplant_class(cls, module):
    """Make ``cls`` appear to reside in ``module``.

    :param cls: A class
    :param module: A module name
    :returns: A subclass of ``cls`` that appears to have been defined in ``module``.

    The returned class's ``__name__`` and ``__qualname__`` will be equal to
    ``cls.__name__``, and its ``__module__`` equal to ``module``.

    """
    class C(cls):
        pass
    C.__module__ = module
    C.__name__ = cls.__name__
    # Without this, Python 3 reports 'transplant_class.<locals>.C' as the
    # qualified name, which leaks into unittest's TestCase.id().
    C.__qualname__ = cls.__name__
    return C


def parse_log_level(lvl):
    """Return numeric log level given a string

    Integer strings are converted directly; otherwise ``lvl`` is looked up
    (case-insensitively) on the ``logging`` module, defaulting to
    ``logging.WARN``.
    """
    try:
        return int(lvl)
    except ValueError:
        pass
    return getattr(logging, lvl.upper(), logging.WARN)


def _is_relevant_tb_level(tb):
    # True for frames from modules that define ``__unittest`` (runner code).
    return '__unittest' in tb.tb_frame.f_globals


def _count_relevant_tb_levels(tb):
    # Number of leading frames before the first runner (``__unittest``) frame.
    length = 0
    while tb and not _is_relevant_tb_level(tb):
        length += 1
        tb = tb.tb_next
    return length


class _WritelnDecorator(object):

    """Used to decorate file-like objects with a handy :func:`writeln` method"""

    def __init__(self, stream):
        self.stream = stream

    def __getattr__(self, attr):
        # Refuse these so copying/pickling and a missing ``stream`` do not
        # recurse into ``self.stream``.
        if attr in ('stream', '__getstate__'):
            raise AttributeError(attr)
        return getattr(self.stream, attr)

    def write(self, arg):
        if sys.version_info[0] == 2:
            arg = safe_encode(arg, getattr(self.stream, 'encoding', 'utf-8'))
        self.stream.write(arg)

    def writeln(self, arg=None):
        if arg:
            self.write(arg)
        self.write('\n')  # text-mode streams translate to \r\n if needed


def ancestry(layer):
    """Return the layer hierarchy of ``layer`` as a list of lists.

    Each inner list holds the bases (plus any ``mixins``) found one level
    further up, ``object`` excluded; the list is ordered from the most
    distant level first to ``[layer]`` last.
    """
    layers = [[layer]]
    bases = [base for base in bases_and_mixins(layer)
             if base is not object]
    while bases:
        layers.append(bases)
        newbases = []
        for b in bases:
            for bb in bases_and_mixins(b):
                if bb is not object:
                    newbases.append(bb)
        bases = newbases
    layers.reverse()
    return layers


def bases_and_mixins(layer):
    """Return ``layer.__bases__`` followed by its ``mixins`` tuple, if any."""
    return (layer.__bases__ + getattr(layer, 'mixins', ()))


def num_expected_args(func):
    """Return the number of arguments that :func: expects

    Counts named positional parameters. NOTE: for bound methods (including
    bound classmethods) ``inspect`` still includes the already-bound first
    parameter (``self``/``cls``) in the count.
    """
    if six.PY2:
        return len(inspect.getargspec(func)[0])
    else:
        return len(inspect.getfullargspec(func)[0])


def call_with_args_if_expected(func, *args):
    """Take :func: and call it with supplied :args:, in case that signature expects any.
    Otherwise call the function without any arguments.
    """
    if num_expected_args(func) > 0:
        func(*args)
    else:
        func()