1# This module contains some code copied from unittest2/loader.py and other
2# code developed in reference to that module and others within unittest2.
3
4# unittest2 is Copyright (c) 2001-2010 Python Software Foundation; All
5# Rights Reserved. See: http://docs.python.org/license.html
6
7import logging
8import os
9import types
10import re
11import sys
12import traceback
13import platform
14import six
15import inspect
16from inspect import isgeneratorfunction
17
18
# Marker: frames whose module globals contain ``__unittest`` are treated as
# test-framework internals by _is_relevant_tb_level(), so format_traceback()
# skips them when rendering failures. This module marks itself that way.
__unittest = True
# A legal Python identifier (used by ispackage() to vet directory names).
IDENT_RE = re.compile(r'^[_a-zA-Z]\w*$', re.UNICODE)
# A name ending in ``<identifier>.py`` (used by valid_module_name()).
VALID_MODULE_RE = re.compile(r'[_a-zA-Z]\w*\.py$', re.UNICODE)
22
23
def ln(label, char='-', width=70):
    """Return a ``width``-wide divider line with ``label`` centred in it.

    >>> ln('hello there')
    '---------------------------- hello there -----------------------------'

    The fill character ``char`` defaults to ``'-'`` and ``width`` to ``70``.
    When the fill cannot be split evenly, the extra goes on the right-hand
    side.

    """
    side = char * ((width - (len(label) + 2)) // 2)
    line = '%s %s %s' % (side, label, side)
    shortfall = width - len(line)
    if shortfall <= 0:
        return line
    return line + char * shortfall
41
42
def valid_module_name(path):
    """Is ``path`` a valid module name?

    The ``<identifier>.py`` pattern is searched for (it is not anchored at
    the start), so only the trailing portion of ``path`` has to match.
    Returns the ``re`` match object (truthy) or ``None``.
    """
    found = VALID_MODULE_RE.search(path)
    return found
46
47
def name_from_path(path):
    """Translate ``path`` into module name

    Returns a two-element tuple:

    1. a dotted module name that can be used in an import statement
       (e.g., ``pkg1.test.test_things``)

    2. the filesystem directory (derived from ``path``; may be ``''``)
       which must be on ``sys.path`` for the import to succeed.

    The directory is found by climbing upward from ``path`` for as long as
    each containing directory is a package (see :func:`ispackage`).

    """
    stem = os.path.splitext(os.path.normpath(path))[0]
    root, leaf = os.path.split(stem)
    names = [leaf]
    # climb while the containing directory is itself a package
    while root and ispackage(root):
        root, leaf = os.path.split(root)
        names.append(leaf)
    names.reverse()
    return '.'.join(names), root
73
74
def module_from_name(name):
    """Import the module called ``name`` and return it.

    ``__import__`` hands back the top-level package for a dotted name, so
    the module itself is fetched from ``sys.modules`` afterwards.
    """
    __import__(name)
    module = sys.modules[name]
    return module
79
80
def test_from_name(name, module):
    """Import test from ``name``.

    ``name`` may carry a ``:<index>`` suffix (e.g. ``'pkg.mod.test_x:2'``).
    When the text after the first ``:`` parses as an integer, it is split
    off and returned as ``index``. Otherwise ``name`` is used unchanged and
    ``index`` is ``None``.

    Returns ``(parent, obj, name, index)``: ``parent`` and ``obj`` come from
    :func:`object_from_name`, and ``name`` is the (possibly stripped) dotted
    name.
    """
    index = None
    head, sep, tail = name.partition(':')
    if sep:
        try:
            index = int(tail)
        except ValueError:
            pass
        else:
            name = head
    parent, obj = object_from_name(name, module)
    return parent, obj, name, index
95
96
def object_from_name(name, module=None):
    """
    Given a dotted name, return the corresponding object.

    Returns a ``(parent, obj)`` tuple, where ``parent`` is the object on
    which the last name component was looked up (``None`` if no attribute
    lookup happened).

    When ``module`` is given, every component of ``name`` is resolved as an
    attribute starting from ``module``. Otherwise the longest importable
    prefix of ``name`` is imported first (see
    :func:`try_import_module_from_name`). The remaining components are then
    resolved from the top-level package that ``__import__`` returns.

    Getting the object can fail for two reasons:

        - the object is a module that cannot be imported.
        - the object is a class or a function that does not exist.

    Since we cannot distinguish between these two cases, we assume we are in
    the first one. We expect the stacktrace to be explicit enough for the
    user to understand the error.
    """
    import_error = None
    remaining = name.split('.')
    if module is None:
        module, import_error = try_import_module_from_name(list(remaining))
        # __import__ returned the top-level package; walk down from there.
        remaining = remaining[1:]
    parent, obj = None, module
    for attr_name in remaining:
        try:
            child = getattr(obj, attr_name)
        except AttributeError as exc:
            if is_package_or_module(obj) and import_error:
                # Re-raise the import error which got us here, since
                # it probably better describes the issue.
                _raise_custom_attribute_error(obj, attr_name, exc, import_error)
            else:
                raise
        else:
            parent, obj = obj, child

    return parent, obj
129
130
def _raise_custom_attribute_error(obj, attr, attr_error_exc, prev_exc):
    """Raise ``attr_error_exc``, linked to an earlier import failure.

    :param obj: the object the attribute was looked up on
    :param attr: the attribute name that was missing
    :param attr_error_exc: the ``AttributeError`` instance that was caught
    :param prev_exc: a ``sys.exc_info()``-style tuple for the earlier
        import error that probably explains why ``attr`` is missing

    Always raises; never returns.
    """
    if sys.version_info >= (3, 0):
        # Python 3: chain the import error as the cause of the AttributeError.
        six.raise_from(attr_error_exc, prev_exc[1])

    # for python 2, do exception chaining manually by embedding the earlier
    # traceback in the message. format_exception() lines already end in a
    # newline, so they are joined with '' (joining with '\n' doubled them).
    raise AttributeError(
        "'%s' has no attribute '%s'\n\nMaybe caused by\n\n%s" % (
            obj, attr, ''.join(traceback.format_exception(*prev_exc))))
140
141
def is_package_or_module(obj):
    """Return ``True`` if ``obj`` looks like a package or a module.

    Anything with a ``__path__`` attribute qualifies, as does any instance
    of :class:`types.ModuleType`.
    """
    return hasattr(obj, '__path__') or isinstance(obj, types.ModuleType)
146
147
def try_import_module_from_name(splitted_name):
    """
    Try to find the longest importable prefix of ``splitted_name``, and
    return what ``__import__`` gave back for it, as well as the potential
    ``ImportError`` information from trying to import a longer name.

    For instance, if ``splitted_name`` is ['a', 'b', 'c'] but only ``a.b`` is
    importable, this function:

        1. tries to import ``a.b.c`` and fails
        2. tries to import ``a.b`` and succeeds
        3. returns the result of ``__import__('a.b')`` (the top-level
           package ``a``) and the ``sys.exc_info()`` tuple from step 1.

    ``splitted_name`` is shortened in place. If no prefix can be imported,
    the last import error is re-raised. ``KeyboardInterrupt`` is never
    treated as an import failure; it propagates immediately.
    """
    module = None
    import_error = None
    while splitted_name:
        try:
            module = __import__('.'.join(splitted_name))
            break
        except KeyboardInterrupt:
            # A user interrupt does not mean "this name is not importable".
            # Swallowing it here would let a shorter prefix succeed and
            # turn Ctrl-C into an AttributeError further up.
            raise
        except BaseException:
            # Deliberately broad: any failure while importing (SyntaxError,
            # errors from module-level code, SystemExit, ...) means
            # "try a shorter name".
            import_error = sys.exc_info()
            del splitted_name[-1]
            if not splitted_name:
                raise
    return (module, import_error)
173
174
def name_from_args(name, index, args):
    """Create test name from test args.

    The result is ``<name>:<index + 1>`` (a 1-based position), then a
    newline, then the ``repr`` of each argument joined with ``', '`` and
    cut to at most 79 characters.
    """
    rendered = [repr(value) for value in args]
    summary = ', '.join(rendered)[:79]
    return '%s:%s\n%s' % (name, index + 1, summary)
179
180
181def test_name(test, qualname=True):
182    # XXX does not work for test funcs; test.id() lacks module
183    if hasattr(test, '_funcName'):
184        tid = test._funcName
185    elif hasattr(test, '_testFunc'):
186        tid = "%s.%s" % (test._testFunc.__module__, test._testFunc.__name__)
187    else:
188        if sys.version_info >= (3, 5) and not qualname:
189            test_module = test.__class__.__module__
190            test_class = test.__class__.__name__
191            test_method = test._testMethodName
192            tid = "%s.%s.%s" % (test_module, test_class, test_method)
193        else:
194            tid = test.id()
195    if '\n' in tid:
196        tid = tid.split('\n')[0]
197    # subtest support
198    if ' ' in tid:
199        tid = tid.split(' ')[0]
200    return tid
201
202
def ispackage(path):
    """Is this path a package directory?

    ``path`` must be an existing directory whose base name is a legal Python
    identifier, and it must contain ``__init__.py``, ``__init__.pyc`` or
    ``__init__.pyo`` (or ``__init__$py.class`` on a ``java`` platform,
    i.e. Jython).
    """
    if not os.path.isdir(path):
        return False
    if not IDENT_RE.match(os.path.basename(path)):
        return False
    markers = ['__init__.py', '__init__.pyc', '__init__.pyo']
    if sys.platform.startswith('java'):
        markers.append('__init__$py.class')
    return any(os.path.isfile(os.path.join(path, marker)) for marker in markers)
217
218
def ensure_importable(dirname):
    """Ensure a directory is on ``sys.path``.

    A missing ``dirname`` is inserted at the front of ``sys.path``. One that
    is already present is left where it is.
    """
    if dirname in sys.path:
        return
    sys.path.insert(0, dirname)
223
224
def isgenerator(obj):
    """Is this object a generator?

    True for generator functions, and for any object carrying a
    ``testGenerator`` attribute that is not ``None``.
    """
    if isgeneratorfunction(obj):
        return True
    return getattr(obj, 'testGenerator', None) is not None
229
230
def has_module_fixtures(test):
    """Does this test live in a module with module fixtures?

    Returns ``True`` if the module named by the test class's ``__module__``
    defines ``setUpModule`` or ``tearDownModule``. Otherwise it returns
    ``False``, including when that module is not in ``sys.modules``
    (previously ``None`` was returned in that case).
    """
    mod = sys.modules.get(test.__class__.__module__)
    if mod is None:
        # Not imported (or mapped to None): nothing to inspect.
        return False
    return hasattr(mod, 'setUpModule') or hasattr(mod, 'tearDownModule')
239
240
def has_class_fixtures(test):
    """Does the test's class hierarchy define class-level fixtures?

    True when ``setUpClass`` or ``tearDownClass`` appears directly in the
    ``__dict__`` of any class in the test class's MRO. Classes whose module
    name contains ``unittest.case`` or ``unittest2.case`` are ignored.
    """
    # hasattr() would be the obvious check, but every TestCase inherits
    # setUpClass/tearDownClass from the unittest (or unittest2) base class,
    # so those base classes have to be left out of the search.
    def _is_framework_class(klass):
        module_name = klass.__module__
        return "unittest.case" in module_name or "unittest2.case" in module_name

    own_classes = [k for k in test.__class__.__mro__
                   if not _is_framework_class(k)]
    return any('setUpClass' in k.__dict__ or 'tearDownClass' in k.__dict__
               for k in own_classes)
255
256
def safe_decode(string):
    """Safely decode a byte string into unicode.

    ``None``, and objects without a ``decode`` method (e.g. ``str`` on
    Python 3), are returned unchanged, as is text that is already unicode.
    Bytes are decoded with the default codec, then with UTF-8. If both fail,
    the placeholder ``'<unable to decode>'`` is returned.
    """
    if string is None:
        return string
    try:
        return string.decode()
    except (AttributeError, UnicodeEncodeError):
        # AttributeError: nothing to decode (already text on Python 3).
        # UnicodeEncodeError: on Python 2, unicode.decode() first encodes
        # with the default (normally ASCII) codec, so non-ASCII text lands
        # here. It is already text, so return it as-is rather than crash.
        return string
    except UnicodeDecodeError:
        pass
    try:
        return string.decode('utf-8')
    except UnicodeDecodeError:
        return six.u('<unable to decode>')
271
272
def safe_encode(string, encoding='utf-8'):
    """Encode text ``string`` with ``encoding`` (``None`` means UTF-8).

    ``None``, and values that cannot or need not be encoded (no ``encode``
    method, or a Python 2 byte string that is already encoded), are
    returned unchanged. Text the codec cannot represent becomes the
    placeholder ``'<unable to encode>'``.
    """
    if string is None:
        return None
    codec = 'utf-8' if encoding is None else encoding
    try:
        return string.encode(codec)
    except (AttributeError, UnicodeDecodeError):
        # no .encode(), or an implicit decode failed: already encoded
        return string
    except UnicodeEncodeError:
        return six.u('<unable to encode>')
287
288
def exc_info_to_string(err, test):
    """Format the ``sys.exc_info()``-style tuple ``err`` for output.

    If ``test`` has a ``formatTraceback(err)`` method, it is used; otherwise
    the module-level :func:`format_traceback` is.
    """
    if getattr(test, 'formatTraceback', None) is None:
        return format_traceback(test, err)
    return test.formatTraceback(err)
296
297
def format_traceback(test, err):
    """Converts a :func:`sys.exc_info` -style tuple of values into a string.

    ``err`` is ``(exctype, value, tb)``. ``tb`` may be a real traceback,
    ``None`` (an exception without a traceback), or something already
    formatted: any non-``None`` object without ``tb_next`` (e.g. a list of
    lines), which is joined as-is.

    Leading frames accepted by :func:`_is_relevant_tb_level` (runner
    internals) are skipped. If ``exctype`` is the test's
    ``failureException`` (default :class:`AssertionError`), the traceback is
    also cut off at the first such frame after that, hiding ``assert*()``
    helper frames.
    """
    exctype, value, tb = err
    if tb is not None and not hasattr(tb, 'tb_next'):
        # Already formatted elsewhere. (tb=None used to land here and make
        # ''.join(None) raise TypeError.)
        return ''.join(tb)
    # Skip test runner traceback levels
    while tb and _is_relevant_tb_level(tb):
        tb = tb.tb_next
    failure = getattr(test, 'failureException', AssertionError)
    if exctype is failure:
        # Skip assert*() traceback levels
        length = _count_relevant_tb_levels(tb)
        msg_lines = traceback.format_exception(exctype, value, tb, length)
    else:
        msg_lines = traceback.format_exception(exctype, value, tb)
    return ''.join(msg_lines)
316
317
def transplant_class(cls, module):
    """Make ``cls`` appear to reside in ``module``.

    :param cls: A class
    :param module: A module name
    :returns: A subclass of ``cls`` that appears to have been defined in ``module``.

    The returned class's ``__name__`` and ``__qualname__`` will be equal to
    ``cls.__name__``, and its ``__module__`` equal to ``module``.

    """
    class C(cls):
        pass
    C.__module__ = module
    C.__name__ = cls.__name__
    # Without this, Python 3 keeps the qualified name
    # 'transplant_class.<locals>.C', which leaks into repr() and anything
    # else that formats the class by __module__ + __qualname__.
    C.__qualname__ = cls.__name__
    return C
334
335
def parse_log_level(lvl):
    """Return numeric log level given a string.

    ``lvl`` may be an integer, an integer string such as ``"10"``, or a
    level name such as ``"debug"`` (case-insensitive). Unknown names fall
    back to ``logging.WARN``.
    """
    try:
        return int(lvl)
    except ValueError:
        pass
    level = getattr(logging, lvl.upper(), logging.WARN)
    # The logging module also has upper-case attributes that are not levels
    # (e.g. BASIC_FORMAT, a format string); accept only numeric levels.
    if not isinstance(level, int):
        return logging.WARN
    return level
343
344
345def _is_relevant_tb_level(tb):
346    return '__unittest' in tb.tb_frame.f_globals
347
348
349def _count_relevant_tb_levels(tb):
350    length = 0
351    while tb and not _is_relevant_tb_level(tb):
352        length += 1
353        tb = tb.tb_next
354    return length
355
356
357class _WritelnDecorator(object):
358
359    """Used to decorate file-like objects with a handy :func:`writeln` method"""
360
361    def __init__(self, stream):
362        self.stream = stream
363
364    def __getattr__(self, attr):
365        if attr in ('stream', '__getstate__'):
366            raise AttributeError(attr)
367        return getattr(self.stream, attr)
368
369    def write(self, arg):
370        if sys.version_info[0] == 2:
371            arg = safe_encode(arg, getattr(self.stream, 'encoding', 'utf-8'))
372        self.stream.write(arg)
373
374    def writeln(self, arg=None):
375        if arg:
376            self.write(arg)
377        self.write('\n')  # text-mode streams translate to \r\n if needed
378
379
def ancestry(layer):
    """Return ``layer``'s ancestry as a list of generations, oldest first.

    Each generation is a list of classes. The last generation is
    ``[layer]``; the one before it holds the direct bases and mixins of
    ``layer`` (see :func:`bases_and_mixins`), and so on. ``object`` is
    always excluded, and duplicates are not removed.
    """
    generations = [[layer]]
    parents = [base for base in bases_and_mixins(layer) if base is not object]
    while parents:
        generations.append(parents)
        parents = [grandparent
                   for parent in parents
                   for grandparent in bases_and_mixins(parent)
                   if grandparent is not object]
    return generations[::-1]
394
395
def bases_and_mixins(layer):
    """Return ``layer``'s direct bases followed by its ``mixins``.

    ``mixins`` is an optional attribute of the layer class. It may be any
    iterable of classes (tuple, list, ...) or ``None``. The result is always
    a tuple.
    """
    # tuple() lets a list-valued ``mixins`` be concatenated with the
    # ``__bases__`` tuple (previously a TypeError); ``or ()`` treats
    # ``mixins = None`` as "no mixins".
    mixins = getattr(layer, 'mixins', None) or ()
    return layer.__bases__ + tuple(mixins)
398
399
def num_expected_args(func):
    """Return the number of named positional parameters ``func`` declares.

    The count comes from the argument spec, so for a bound method (or bound
    classmethod) the already-bound ``self``/``cls`` parameter is included.
    ``*args``, ``**kwargs`` and keyword-only parameters are not counted.
    """
    if sys.version_info[0] == 2:
        return len(inspect.getargspec(func)[0])
    return len(inspect.getfullargspec(func)[0])


def call_with_args_if_expected(func, *args):
    """Take :func: and call it with supplied :args:, in case that signature expects any.
    Otherwise call the function without any arguments.

    An already-bound ``self`` or ``cls`` (bound methods and bound
    classmethods) does not count as an expected argument, so e.g. a
    ``@classmethod def setUp(cls)`` is called without ``args``.

    Returns whatever ``func`` returns.
    """
    expected = num_expected_args(func)
    if inspect.ismethod(func) and getattr(func, '__self__', None) is not None:
        # The argspec still lists the bound first parameter. Python supplies
        # it automatically, so it must not make us pass ``args`` (doing so
        # raised TypeError for e.g. ``def setUp(cls)``).
        expected -= 1
    if expected > 0:
        return func(*args)
    return func()
416