1# -*- coding: utf-8 -*-
2
3###
4### oktest.py -- new style test utility
5###
6### $Release: 0.11.0 $
7### $Copyright: copyright(c) 2010-2011 kuwata-lab.com all rights reserved $
8### $License: MIT License $
9###
10
# Names exported by ``from oktest import *``.
__all__ = ('ok', 'NOT', 'NG', 'not_ok', 'run', 'spec', 'test', 'fail', 'skip', 'todo', 'subject', 'situation', 'main')
# Release number, taken from the second word of the "$Release: ... $" keyword string.
__version__ = "$Release: 0.11.0 $".split()[1]
13
14import sys, os, re, types, traceback, time, linecache
15
# Major-version flags of the running interpreter; used below to pick
# version-specific code paths.
python2 = sys.version_info[0] == 2
python3 = sys.version_info[0] == 3
if python2:
    from cStringIO import StringIO
if python3:
    xrange = range    # make the Python 2 name available on Python 3
    from io import StringIO
23
24
25def _new_module(name, local_vars, util=None):
26    try:
27        mod = type(sys)(name)
28    except:
29        # The module creation above does not work for Jython 2.5.2
30        import imp
31        mod = imp.new_module(name)
32    sys.modules[name] = mod
33    mod.__dict__.update(local_vars)
34    if util and getattr(mod, '__all__', None):
35        for k in mod.__all__:
36            util.__dict__[k] = mod.__dict__[k]
37        util.__all__ += mod.__all__
38    return mod
39
40
__unittest = True    # see unittest.TestResult._is_relevant_tb_level()


# Runtime settings, registered as the module 'oktest.config'.
config = _new_module('oktest.config', {
    "debug": False,
    #"color_enabled": _sys.platform.startswith(('darwin', 'linux', 'freebsd', 'netbsd'))  # not work on Python2.4
    #"color_enabled": any(lambda x: _sys.platform.startswith(x), ('darwin', 'linux', 'freebsd', 'netbsd'))  # not work on Python2.4
    # True when sys.platform starts with one of the listed platform names
    "color_available": bool([ 1 for p in ('darwin', 'linux', 'freebsd', 'netbsd') if sys.platform.startswith(p) ]),
    "color_enabled":  None,    # None means detect automatically
    "TARGET_PATTERN": '.*(Test|TestCase|_TC)$',   # class name pattern of test case
})
52
53
54## not used for compatibility with unittest
55#class TestFailed(AssertionError):
56#
57#    def __init__(self, mesg, file=None, line=None, diff=None):
58#        AssertionError.__init__(self, mesg)
59#        self.file = file
60#        self.line = line
61#        self.diff = diff
62#
63
# Exception class instantiated by AssertionObject._assertion_error() when an assertion fails.
ASSERTION_ERROR = AssertionError
65
66
67def _diff_p(target, op, other):
68    if op != '==':             return False
69    if target == other:        return False
70    #if not util._is_string(target): return False
71    #if not util._is_string(other):  return False
72    if not DIFF:               return False
73    is_a = isinstance
74    if is_a(target, str) and is_a(other, str):
75        return True
76    if python2 and is_a(target, unicode) and is_a(other, unicode):
77        return True
78    return False
79
80
81def _truncated_repr(obj, max=80+15):
82    s = repr(obj)
83    if len(s) > max:
84        return s[:max - 15] + ' [truncated]...'
85    return s
86
87
88def _msg(target, op, other=None):
89    if   op.endswith('()'):   msg = '%r%s'     % (target, op)
90    elif op.startswith('.'):  msg = '%r%s(%r)' % (target, op, other)
91    else:                     msg = '%r %s %r' % (target, op, other)
92    msg += " : failed."
93    return msg
94
95
def _msg2(target, op, other=None):
    """Build a failure message, with a diff attached when one is available.

    Returns a ``(message, diff)`` tuple when _diff_p() approves a diff and
    _diff() produces a non-empty one; otherwise returns the plain string
    built by _msg().
    """
    if _diff_p(target, op, other):
        diff_text = _diff(target, other)
        if diff_text:
            headline = "%s == %s : failed." % (_truncated_repr(target), _truncated_repr(other))
            return (headline, diff_text)
    return _msg(target, op, other)
104
105
106DIFF = True
107
108def _diff(target, other):
109    from difflib import unified_diff
110    if hasattr(DIFF, '__call__'):
111        expected = [ DIFF(line) + "\n" for line in other.splitlines(True) ]
112        actual   = [ DIFF(line) + "\n" for line in target.splitlines(True) ]
113    else:
114        if other.find("\n") == -1 and target.find("\n") == -1:
115            expected, actual = [other + "\n"], [target + "\n"]
116        else:
117            expected, actual = other.splitlines(True), target.splitlines(True)
118            if not expected: expected.append('')
119            if not actual:   actual.append('')
120            for lines in (expected, actual):
121                if not lines[-1].endswith("\n"):
122                    lines[-1] += "\n\\ No newline at end of string\n"
123    return ''.join(unified_diff(expected, actual, 'expected', 'actual', n=2))
124
125
def assertion(func):
    """decorator to declare assertion function.
       ex.
         @oktest.assertion
         def startswith(self, arg):
           boolean = self.target.startswith(arg)
           if boolean != self.boolean:
             self.failed("%r.startswith(%r) : failed." % (self.target, arg))
         #
         ok ("Sasaki").startswith("Sas")

       The function is registered on AssertionObject under its own name.
       The registered wrapper marks the assertion object as tested and then
       calls the function with the same positional and keyword arguments.
       Returns the wrapper.
    """
    def deco(self, *args, **kwargs):
        # keyword arguments are forwarded too, so that optional parameters
        # of an assertion (e.g. flags=, errmsg=) can be passed by name
        self._tested = True
        return func(self, *args, **kwargs)
    deco.__name__ = func.__name__
    deco.__doc__ = func.__doc__
    setattr(AssertionObject, func.__name__, deco)
    return deco
144
145
146#def deprecated(f):
147#    return f
148
149
class AssertionObject(object):
    """Holds a target value on which assertion methods are called.

    target  -- the value under test.
    boolean -- True for a positive assertion, False for a negated one.

    Assertion methods are attached to this class by the @assertion
    decorator.  An instance that is destroyed without having been tested
    writes a warning to stderr.
    """

    def __init__(self, target, boolean=True):
        self._location = None     # (file, line) of creation, filled in by the creator
        self._tested = False      # set to True once an assertion has been invoked
        self.boolean = boolean
        self.target = target

    def __del__(self):
        if self._tested is not False:
            return
        if self.boolean:
            caller = 'ok'
        else:
            caller = 'not_ok'
        msg = "%s() is called but not tested." % (caller,)
        if self._location:
            msg += " (file '%s', line %s)" % self._location
        sys.stderr.write("*** warning: oktest: %s\n" % msg)

    def failed(self, msg, depth=2, boolean=None):
        """Raise the assertion error for `msg`.

        msg     -- message string, or a (message, diff) tuple.
        depth   -- passed (plus one) to util.get_location() to find the
                   file and line attached to the error.
        boolean -- overrides self.boolean when not None; when the effective
                   value is False the message is prefixed with 'not '.
        """
        # must be called from this frame: the depth is relative to it
        file, line = util.get_location(depth + 1)
        if isinstance(msg, tuple):
            msg, diff = msg
        else:
            diff = None
        if boolean is None:
            boolean = self.boolean
        if boolean is False:
            msg = 'not ' + msg
        raise self._assertion_error(msg, file, line, diff)

    def _assertion_error(self, msg, file, line, diff):
        """Create (not raise) an ASSERTION_ERROR carrying file/line/diff/errmsg."""
        if diff:
            text = msg + "\n" + diff
        else:
            text = msg
        ex = ASSERTION_ERROR(text)
        ex.file = file
        ex.line = line
        ex.diff = diff
        ex.errmsg = msg
        ex._raised_by_oktest = True
        return ex

    @property
    def should(self):           # UNDOCUMENTED
        """(experimental) allows user to call True/False method as assertion.
           ex.
             ok ("SOS").should.startswith("S")   # same as ok ("SOS".startswith("S")) == True
             ok ("123").should.isdigit()         # same as ok ("123".isdigit()) == True
        """
        return Should(self, self.boolean)

    @property
    def should_not(self):       # UNDOCUMENTED
        """(experimental) allows user to call True/False method as assertion.
           ex.
             ok ("SOS").should_not.startswith("X")   # same as ok ("SOS".startswith("X")) == False
             ok ("123").should_not.isalpha()         # same as ok ("123".isalpha()) == False
        """
        return Should(self, not self.boolean)
204
205
def _f():
    """Define the built-in assertions and attach them to AssertionObject.

    Every function decorated with @assertion is registered on
    AssertionObject under its own name.  An assertion returns the assertion
    object when it holds (compared against self.boolean, so that negated
    assertion objects work too) and calls self.failed() otherwise.
    """

    @assertion
    def __eq__(self, other):
        """target == other"""
        boolean = self.target == other
        if boolean == self.boolean:  return self
        #self.failed(_msg(self.target, '==', other))
        self.failed(_msg2(self.target, '==', other))

    @assertion
    def __ne__(self, other):
        """target != other"""
        boolean = self.target != other
        if boolean == self.boolean:  return self
        self.failed(_msg(self.target, '!=', other))

    @assertion
    def __gt__(self, other):
        """target > other"""
        boolean = self.target > other
        if boolean == self.boolean:  return self
        self.failed(_msg(self.target, '>', other))

    @assertion
    def __ge__(self, other):
        """target >= other"""
        boolean = self.target >= other
        if boolean == self.boolean:  return self
        self.failed(_msg(self.target, '>=', other))

    @assertion
    def __lt__(self, other):
        """target < other"""
        boolean = self.target < other
        if boolean == self.boolean:  return self
        self.failed(_msg(self.target, '<', other))

    @assertion
    def __le__(self, other):
        """target <= other"""
        boolean = self.target <= other
        if boolean == self.boolean:  return self
        self.failed(_msg(self.target, '<=', other))

    @assertion
    def in_delta(self, other, delta):
        """other - delta < target < other + delta"""
        lower, upper = other - delta, other + delta
        above = self.target > lower
        below = self.target < upper
        if self.boolean:
            # positive form: report the first bound that is violated
            if not above:
                self.failed(_msg(self.target, '>', lower))
            if not below:
                self.failed(_msg(self.target, '<', upper))
        elif above and below:
            # negated form: fails only when the target lies inside the range
            # (failed() adds the leading 'not ')
            self.failed("%r < %r < %r : failed." % (lower, self.target, upper))
        return self

#    @assertion
#    def __contains__(self, other):
#        boolean = self.target in other
#        if boolean == self.boolean:  return self
#        self.failed(_msg(self.target, 'in', other))

    @assertion
    def in_(self, other):
        """target in other"""
        boolean = self.target in other
        if boolean == self.boolean:  return self
        self.failed(_msg(self.target, 'in', other))

    @assertion
    def not_in(self, other):
        """target not in other"""
        boolean = self.target not in other
        if boolean == self.boolean:  return self
        self.failed(_msg(self.target, 'not in', other))

    @assertion
    def contains(self, other):
        """other in target"""
        boolean = other in self.target
        if boolean == self.boolean:  return self
        self.failed(_msg(other, 'in', self.target))

    @assertion
    def not_contain(self, other):  # DEPRECATED
        """other not in target"""
        # the condition must be the negated membership test, like not_in()
        boolean = other not in self.target
        if boolean == self.boolean:  return self
        self.failed(_msg(other, 'not in', self.target))

    @assertion
    def is_(self, other):
        """target is other"""
        boolean = self.target is other
        if boolean == self.boolean:  return self
        self.failed(_msg(self.target, 'is', other))

    @assertion
    def is_not(self, other):
        """target is not other"""
        boolean = self.target is not other
        if boolean == self.boolean:  return self
        self.failed(_msg(self.target, 'is not', other))

    @assertion
    def is_a(self, other):
        """isinstance(target, other)"""
        boolean = isinstance(self.target, other)
        if boolean == self.boolean:  return self
        self.failed("isinstance(%r, %s) : failed." % (self.target, other.__name__))

    @assertion
    def is_not_a(self, other):
        """not isinstance(target, other)"""
        boolean = not isinstance(self.target, other)
        if boolean == self.boolean:  return self
        self.failed("not isinstance(%r, %s) : failed." % (self.target, other.__name__))

    @assertion
    def has_attr(self, name):
        """hasattr(target, name)"""
        boolean = hasattr(self.target, name)
        if boolean == self.boolean:  return self
        self.failed("hasattr(%r, %r) : failed." % (self.target, name))

    @assertion
    def attr(self, name, expected):
        """getattr(target, name) == expected (a missing attribute always fails)"""
        if not hasattr(self.target, name):
            # boolean=True: report the missing attribute without a 'not ' prefix
            self.failed("hasattr(%r, %r) : failed." % (self.target, name), boolean=True)
        boolean = getattr(self.target, name) == expected
        if boolean == self.boolean:  return self
        prefix = 'attr(%r): ' % name
        msg = _msg2(getattr(self.target, name), "==", expected)
        if isinstance(msg, tuple):
            msg = (prefix + msg[0], msg[1])
        else:
            msg = prefix + msg
        self.failed(msg)

    @assertion
    def matches(self, pattern, flags=0):
        """re.search(pattern, target); pattern is a string or a compiled pattern"""
        if isinstance(pattern, type(re.compile('x'))):
            boolean = bool(pattern.search(self.target))
            if boolean == self.boolean:  return self
            self.failed("re.search(%r, %r) : failed." % (pattern.pattern, self.target))
        else:
            rexp = re.compile(pattern, flags)
            boolean = bool(rexp.search(self.target))
            if boolean == self.boolean:  return self
            self.failed("re.search(%r, %r) : failed." % (pattern, self.target))

    @assertion
    def not_match(self, pattern, flag=0):
        """not re.search(pattern, target); pattern is a string or a compiled pattern"""
        if isinstance(pattern, type(re.compile('x'))):
            boolean = not pattern.search(self.target)
            if boolean == self.boolean:  return self
            self.failed("not re.search(%r, %r) : failed." % (pattern.pattern, self.target))
        else:
            rexp = re.compile(pattern, flag)
            boolean = not rexp.search(self.target)
            if boolean == self.boolean:  return self
            self.failed("not re.search(%r, %r) : failed." % (pattern, self.target))

    @assertion
    def length(self, n):
        """len(target) == n"""
        boolean = len(self.target) == n
        if boolean == self.boolean:  return self
        self.failed("len(%r) == %r : failed." % (self.target, n))

    @assertion
    def is_file(self):
        """os.path.isfile(target)"""
        boolean = os.path.isfile(self.target)
        if boolean == self.boolean:  return self
        self.failed('os.path.isfile(%r) : failed.' % self.target)

    @assertion
    def not_file(self):
        """not os.path.isfile(target)"""
        boolean = not os.path.isfile(self.target)
        if boolean == self.boolean:  return self
        self.failed('not os.path.isfile(%r) : failed.' % self.target)

    @assertion
    def is_dir(self):
        """os.path.isdir(target)"""
        boolean = os.path.isdir(self.target)
        if boolean == self.boolean:  return self
        self.failed('os.path.isdir(%r) : failed.' % self.target)

    @assertion
    def not_dir(self):
        """not os.path.isdir(target)"""
        boolean = not os.path.isdir(self.target)
        if boolean == self.boolean:  return self
        self.failed('not os.path.isdir(%r) : failed.' % self.target)

    @assertion
    def exists(self):
        """os.path.exists(target)"""
        boolean = os.path.exists(self.target)
        if boolean == self.boolean:  return self
        self.failed('os.path.exists(%r) : failed.' % self.target)

    @assertion
    def not_exist(self):
        """not os.path.exists(target)"""
        boolean = not os.path.exists(self.target)
        if boolean == self.boolean:  return self
        self.failed('not os.path.exists(%r) : failed.' % self.target)

    @assertion
    def raises(self, exception_class, errmsg=None):
        """calling target() raises exception_class (optionally with message errmsg)"""
        return self._raise_or_not(exception_class, errmsg, self.boolean)

    @assertion
    def not_raise(self, exception_class=Exception):
        """calling target() does not raise exception_class"""
        return self._raise_or_not(exception_class, None, not self.boolean)

    def _raise_or_not(self, exception_class, errmsg, flag_raise):
        # Shared implementation of raises() and not_raise().
        #   flag_raise true : target() must raise an instance of exception_class
        #                     whose str() equals errmsg (or matches it, when
        #                     errmsg is a compiled pattern)
        #   flag_raise false: target() must not raise exception_class
        # The caught exception is stored as target.exception.
        ex = None
        try:
            self.target()
        except:
            ex = sys.exc_info()[1]
            # an AssertionError that was not created by oktest is propagated as is
            if isinstance(ex, AssertionError) and not hasattr(ex, '_raised_by_oktest'):
                raise
            self.target.exception = ex
            if flag_raise:
                if not isinstance(ex, exception_class):
                    self.failed('%s%r is kind of %s : failed.' % (ex.__class__.__name__, ex.args, exception_class.__name__), depth=3)
                    #raise
                if errmsg is None:
                    pass
                elif isinstance(errmsg, _rexp_type):
                    if not errmsg.search(str(ex)):
                        self.failed("error message %r is not matched to pattern." % str(ex), depth=3)   # don't use ex2msg(ex)!
                else:
                    if str(ex) != errmsg:   # don't use ex2msg(ex)!
                        #self.failed("expected %r but got %r" % (errmsg, str(ex)))
                        self.failed("%r == %r : failed." % (str(ex), errmsg), depth=3)   # don't use ex2msg(ex)!
            else:
                if isinstance(ex, exception_class):
                    self.failed('%s should not be raised : failed, got %r.' % (exception_class.__name__, ex), depth=3)
        else:
            if flag_raise and ex is None:
                self.failed('%s should be raised : failed.' % exception_class.__name__, depth=3)
        return self

    AssertionObject._raise_or_not = _raise_or_not
    AssertionObject.hasattr = has_attr      # for backward compatibility
    AssertionObject.is_not_file = not_file  # for backward compatibility
    AssertionObject.is_not_dir  = not_dir   # for backward compatibility

_f()
del _f
440
# Type of compiled regular expression objects; AssertionObject._raise_or_not()
# uses it to tell a pattern from a plain error-message string.
_rexp_type = type(re.compile('x'))

# Class instantiated by ok(), NG(), not_ok() and NOT().
ASSERTION_OBJECT = AssertionObject
444
445
def ok(target):
    """Return a positive assertion object (boolean=True) wrapping `target`."""
    obj = ASSERTION_OBJECT(target, True)
    # Stored for later reporting; presumably the caller's (file, line) --
    # depends on util.get_location(), which is sensitive to call depth.
    obj._location = util.get_location(1)
    return obj
450
def NG(target):
    """Return a negated assertion object (boolean=False) wrapping `target`."""
    obj = ASSERTION_OBJECT(target, False)
    # Stored for later reporting; presumably the caller's (file, line) --
    # depends on util.get_location(), which is sensitive to call depth.
    obj._location = util.get_location(1)
    return obj
455
def not_ok(target):  # for backward compatibility
    """Return a negated assertion object (boolean=False) wrapping `target`; same body as NG()."""
    obj = ASSERTION_OBJECT(target, False)
    # Stored for later reporting; presumably the caller's (file, line) --
    # depends on util.get_location(), which is sensitive to call depth.
    obj._location = util.get_location(1)
    return obj
460
def NOT(target):     # experimental. prefer to NG()?
    """Return a negated assertion object (boolean=False) wrapping `target`; same body as NG()."""
    obj = ASSERTION_OBJECT(target, False)
    # Stored for later reporting; presumably the caller's (file, line) --
    # depends on util.get_location(), which is sensitive to call depth.
    obj._location = util.get_location(1)
    return obj
465
def fail(desc):
    """Fail unconditionally: raise AssertionError with `desc` as its argument."""
    error = AssertionError(desc)
    raise error
468
469
class Should(object):
    """(experimental) Turns a True/False method of the target into an assertion.

    assertion_object -- object providing `target`, `boolean`, `_tested`
                        and `failed()`.
    boolean          -- value the called method is expected to return;
                        defaults to assertion_object.boolean.

    Attribute access returns a function that calls the method of the same
    name on the target and reports a failure when the result differs from
    `boolean`.
    """

    def __init__(self, assertion_object, boolean=None):
        self.assertion_object = assertion_object
        if boolean is None:
            boolean = assertion_object.boolean
        self.boolean = boolean

    def __getattr__(self, key):
        """Return a function wrapping the target's method named `key`.

        Raises ValueError when the attribute is not callable, or (at call
        time) when the method returns something other than True or False.
        """
        ass = self.assertion_object
        tested = ass._tested
        ass._tested = True    # stays True when the lookup or the check below raises
        val = getattr(ass.target, key)
        if not hasattr(val, '__call__'):
            msg = "%s.%s: not a callable." % (type(ass.target).__name__, key)
            raise ValueError(msg)   # or TypeError?
        ass._tested = tested
        def f(*args, **kwargs):
            ass._tested = True
            ret = val(*args, **kwargs)
            if ret not in (True, False):
                msg = "%r.%s(): expected to return True or False but it returned %r." \
                      % (ass.target, val.__name__, ret)
                raise ValueError(msg)
            if ret != self.boolean:
                buf = [ repr(arg) for arg in args ]
                buf.extend([ "%s=%r" % (k, kwargs[k]) for k in kwargs ])
                msg = "%r.%s(%s) : failed." % (ass.target, val.__name__, ", ".join(buf))
                if self.boolean is False:
                    msg = "not " + msg
                # The message already reflects this object's expectation, so
                # failed() must not add another 'not ' based on the assertion
                # object's own flag.
                ass.failed(msg, boolean=True)
        return f
502
503
# Use unittest's SkipTest when the running Python provides it; otherwise
# fall back to a local exception class of the same name.
try:
    from unittest import SkipTest
except ImportError:
    if python2:
        sys.exc_clear()

    class SkipTest(Exception):
        pass
512
513
class SkipObject(object):
    """Callable helper behind the module-level `skip` object.

    skip(reason)                  -- raises SkipTest(reason) immediately.
    @skip.when(condition, reason) -- decorator; when `condition` is true the
                                     decorated method is replaced by one that
                                     raises SkipTest(reason), otherwise the
                                     method is returned unchanged.
    """

    def __call__(self, reason):
        raise SkipTest(reason)

    def when(self, condition, reason):
        if not condition:
            def deco(func):
                return func
            return deco

        def deco(func):
            def fn(self):
                raise SkipTest(reason)
            # keep the identity of the replaced method
            fn.__name__ = func.__name__
            fn.__doc__  = func.__doc__
            # stored on the replacement; read back via getattr(func, '_firstlineno') when sorting tests
            fn._firstlineno = util._func_firstlineno(func)
            return fn
        return deco


skip = SkipObject()
538
539
def todo(func):
    """Decorator for a test that is expected to fail for now.

    The wrapper calls `func`; an AssertionError raised by it is converted
    into _ExpectedFailure (carrying sys.exc_info()), and a normal return
    is converted into _UnexpectedSuccess.  Other exceptions propagate.
    The wrapper takes over the name and docstring of `func`.
    """
    def deco(*args, **kwargs):
        try:
            func(*args, **kwargs)
            raise _UnexpectedSuccess("test should be failed (because not implemented yet), but passed unexpectedly.")
        except AssertionError:
            failure = sys.exc_info()
            raise _ExpectedFailure(failure)
    deco.__doc__  = func.__doc__
    deco.__name__ = func.__name__
    return deco
551
class _ExpectedFailure(Exception):
    """Raised by @todo when the decorated test fails as expected.

    `exc_info`, when given and true, is kept as the attribute of the same name.
    """

    def __init__(self, exc_info=None):
        Exception.__init__(self, "expected failure")
        if not exc_info:
            return
        self.exc_info = exc_info


class _UnexpectedSuccess(Exception):
    """Raised by @todo when the decorated test passes unexpectedly."""


# Use unittest's own classes instead when they can be imported.
try:
    from unittest.case import _ExpectedFailure, _UnexpectedSuccess
except ImportError:
    if python2:
        sys.exc_clear()
567
568
569
# Status values of a finished test case; returned by
# TestRunner._run_testcase() and passed on to the reporter.
ST_PASSED  = "passed"
ST_FAILED  = "failed"
ST_ERROR   = "error"
ST_SKIPPED = "skipped"
ST_TODO    = "todo"
#ST_UNEXPECTED = "unexpected"
576
577
578class TestRunner(object):
579
580    _filter_test = _filter_key = _filter_val = None
581
582    def __init__(self, reporter=None, filter=None):
583        self._reporter = reporter
584        self.filter = filter
585        filter = filter and filter.copy() or {}
586        if filter:
587            self._filter_test = filter.pop('test', None)
588        if filter:
589            self._filter_key  = list(filter.keys())[0]
590            self._filter_val  = filter.pop(self._filter_key)
591
592    def __get_reporter(self):
593        if self._reporter is None:
594            self._reporter = REPORTER()
595        return self._reporter
596
597    def __set_reporter(self, reporter):
598        self._reporter = reporter
599
600    reporter = property(__get_reporter, __set_reporter)
601
602    def _test_name(self, name):
603        return re.sub(r'^test_?', '', name)
604
605    def get_testnames(self, klass):
606        #names = [ name for name in dir(klass) if name.startswith('test') ]
607        #names.sort()
608        #return names
609        testnames = [ k for k in dir(klass) if k.startswith('test') and hasattr(getattr(klass, k), '__class__') ]
610        ## filter by test name or user-defined options
611        pattern, key, val = self._filter_test, self._filter_key, self._filter_val
612        if pattern or key:
613            testnames = [ s for s in testnames
614                              if _filtered(klass, getattr(klass, s), s, pattern, key, val) ]
615        ## filter by $TEST environment variable
616        pattern = os.environ.get('TEST')
617        if pattern:
618            rexp  = re.compile(pattern)
619            testnames = [ s for s in testnames
620                              if rexp.search(self._test_name(s)) ]
621        ## sort by linenumber
622        def fn(testname, klass=klass):
623            func = getattr(klass, testname)
624            lineno = getattr(func, '_firstlineno', None) or util._func_firstlineno(func)
625            return (lineno, testname)
626        testnames.sort(key=fn)
627        return testnames
628
629    def _invoke(self, obj, method1, method2):
630        meth = getattr(obj, method1, None) or getattr(obj, method2, None)
631        if not meth: return None, None
632        try:
633            meth()
634            return meth, None
635        except KeyboardInterrupt:
636            raise
637        except Exception:
638            return meth.__name__, sys.exc_info()
639
640    def run_class(self, klass, testnames=None):
641        self._enter_testclass(klass)
642        try:
643            method_name, exc_info = self._invoke(klass, 'before_all', 'setUpClass')
644            if not exc_info:
645                try:
646                    self.run_testcases(klass, testnames)
647                finally:
648                    method_name, exc_info = self._invoke(klass, 'after_all', 'tearDownClass')
649        finally:
650            if not exc_info: method_name = None
651            self._exit_testclass(klass, method_name, exc_info)
652
653    def run_testcases(self, klass, testnames=None):
654        if testnames is None:
655            testnames = self.get_testnames(klass)
656        context_list = getattr(klass, '_context_list', None)
657        if context_list:
658            items = []
659            for tname in testnames:
660                meth = getattr(klass, tname)
661                if not hasattr(meth, '_test_context'):
662                    items.append((tname, meth))
663            items.extend(context_list)
664            TestContext._sort_items(items)
665            allowed = dict.fromkeys(testnames)
666            self._run_items(klass, items, allowed)
667        else:
668            for testname in testnames:
669                testcase = self._new_testcase(klass, testname)
670                self.run_testcase(testcase, testname)
671
672    def _run_items(self, klass, items, allowed):
673        for item in items:
674            if isinstance(item, tuple):
675                testname = item[0]
676                if testname in allowed:
677                    testcase = self._new_testcase(klass, testname)
678                    self.run_testcase(testcase, testname)
679            else:
680                assert isinstance(item, TestContext)
681                context = item
682                self._enter_testcontext(context)
683                try:
684                    self._run_items(klass, context.items, allowed)
685                finally:
686                    self._exit_testcontext(context)
687
688    def _new_testcase(self, klass, method_name):
689        try:
690            obj = klass()
691        except ValueError:     # unittest.TestCase raises ValueError
692            obj = klass(method_name)
693        meth = getattr(obj, method_name)
694        obj.__name__ = self._test_name(method_name)
695        obj._testMethodName = method_name    # unittest.TestCase compatible
696        obj._testMethodDoc  = meth.__doc__   # unittest.TestCase compatible
697        obj._run_by_oktest  = True
698        obj._oktest_specs   = []
699        return obj
700
701    def run_testcase(self, testcase, testname):
702        self._enter_testcase(testcase, testname)
703        try:
704            _, exc_info = self._invoke(testcase, 'before', 'setUp')
705            if exc_info:
706                status = ST_ERROR
707            else:
708                try:
709                    status = None
710                    try:
711                        status, exc_info = self._run_testcase(testcase, testname)
712                    except:
713                        status, exc_info = ST_ERROR, sys.exc_info()
714                finally:
715                    _, ret = self._invoke(testcase, 'after', 'tearDown')
716                    if ret:
717                        status, exc_info = ST_ERROR, ret
718                    #else:
719                        #assert status is not None
720        finally:
721            self._exit_testcase(testcase, testname, status, exc_info)
722
723    def _run_testcase(self, testcase, testname):
724        try:
725            meth = getattr(testcase, testname)
726            meth()
727        except KeyboardInterrupt:
728            raise
729        except AssertionError:
730            return ST_FAILED, sys.exc_info()
731        except SkipTest:
732            return ST_SKIPPED, sys.exc_info()
733        except _ExpectedFailure:   # when failed expectedly
734            return ST_TODO, ()
735        except _UnexpectedSuccess: # when passed unexpectedly
736            #return ST_UNEXPECTED, ()
737            ex = sys.exc_info()[1]
738            if not ex.args:
739                ex.args = ("test should be failed (because not implemented yet), but passed unexpectedly.",)
740            return ST_FAILED, sys.exc_info()
741        except Exception:
742            return ST_ERROR, sys.exc_info()
743        else:
744            specs = getattr(testcase, '_oktest_specs', None)
745            arr = specs and [ spec for spec in specs if spec._exception ]
746            if arr: return ST_FAILED, arr
747            return ST_PASSED, ()
748
749    def _enter_testclass(self, testclass):
750        self.reporter.enter_testclass(testclass)
751
752    def _exit_testclass(self, testclass, method_name, exc_info):
753        self.reporter.exit_testclass(testclass, method_name, exc_info)
754
755    def _enter_testcase(self, testcase, testname):
756        self.reporter.enter_testcase(testcase, testname)
757
758    def _exit_testcase(self, testcase, testname, status, exc_info):
759        self.reporter.exit_testcase(testcase, testname, status, exc_info)
760
761    def _enter_testcontext(self, context):
762        self.reporter.enter_testcontext(context)
763
764    def _exit_testcontext(self, context):
765        self.reporter.exit_testcontext(context)
766
767    def __enter__(self):
768        self.reporter.enter_all()
769        return self
770
771    def __exit__(self, *args):
772        self.reporter.exit_all()
773
774
775def _filtered(klass, meth, tname, pattern, key, val, _rexp=re.compile(r'^test(_|_\d\d\d(_|: ))?')):
776    from fnmatch import fnmatch
777    if pattern:
778        if not fnmatch(_rexp.sub('', tname), pattern):
779            return False   # skip testcase
780    if key:
781        if not meth: meth = getattr(klass, tname)
782        d = getattr(meth, '_options', None)
783        if not (d and isinstance(d, dict) and fnmatch(str(d.get(key)), val)):
784            return False   # skip testcase
785    return True   # invoke testcase
786
787
# Runner class instantiated by run().
TEST_RUNNER = TestRunner
789
790
def run(*targets, **kwargs):
    """Run test classes and return the number of failed plus errored tests.

    targets -- test classes, or class-name patterns (strings or compiled
               regular expressions); defaults to config.TARGET_PATTERN.

    Keyword options (others are ignored):
      out, color     -- passed to the reporter constructor.
      filter         -- passed to the test runner.
      style          -- name of a registered report style; ValueError when
                        no reporter class is registered under it.
      reporter_class -- reporter class to use; takes precedence over style.
    """
    reporter_class = kwargs.pop('reporter_class', None)
    style  = kwargs.pop('style', None)
    filter = kwargs.pop('filter', {})
    color  = kwargs.pop('color', None)
    out    = kwargs.pop('out', None)
    #
    if not reporter_class:
        if not style:
            reporter_class = REPORTER
        else:
            reporter_class = BaseReporter.get_registered_class(style)
            if not reporter_class:
                raise ValueError("%r: unknown report style." % style)
    #
    runner = TEST_RUNNER(reporter=reporter_class(out=out, color=color), filter=filter)
    #
    if not targets:
        targets = (config.TARGET_PATTERN, )
    #
    runner.__enter__()
    try:
        # _target_classes() inspects the call stack, so it has to be called
        # directly from this function
        for testclass in _target_classes(targets):
            runner.run_class(testclass)
    finally:
        runner.__exit__(sys.exc_info())
    counts = runner.reporter.counts
    return counts.get(ST_FAILED, 0) + counts.get(ST_ERROR, 0)
822
823
def _target_classes(targets):
    """Resolve `targets` into a list of classes.

    A class is taken as is.  A string or compiled regular expression is
    searched in the names of the local variables of the frame two levels up
    the call stack; matching classes are added, sorted by
    TESTCLASS_SORT_KEY when that is set.  Anything else raises ValueError.
    """
    pattern_type = type(re.compile('x'))
    found = []
    namespace = None
    for arg in targets:
        if util._is_class(arg):
            found.append(arg)
            continue
        arg_is_string = util._is_string(arg)
        if not (arg_is_string or isinstance(arg, pattern_type)):
            raise ValueError("%r: not a class nor pattern string." % (arg, ))
        if arg_is_string:
            rexp = re.compile(arg)
        else:
            rexp = arg
        if namespace is None:
            # fetched lazily and only once; must stay in this function body
            # because the frame depth is counted from here
            namespace = sys._getframe(2).f_locals
        matched = []
        for name in namespace:
            if rexp.search(name) and util._is_class(namespace[name]):
                matched.append(namespace[name])
        if TESTCLASS_SORT_KEY:
            matched.sort(key=TESTCLASS_SORT_KEY)
        found.extend(matched)
    return found
842
843
844def _min_firstlineno_of_methods(klass):
845    func_types = (types.FunctionType, types.MethodType)
846    d = klass.__dict__
847    linenos = [ util._func_firstlineno(d[k]) for k in d
848                if k.startswith('test') and type(d[k]) in func_types ]
849    return linenos and min(linenos) or -1
850
851TESTCLASS_SORT_KEY = _min_firstlineno_of_methods
852
853
854##
855## Reporter
856##
857
class Reporter(object):
    """Reporter interface: one hook per event, all of them doing nothing."""

    def enter_all(self):
        """Hook with an empty body."""
        pass

    def exit_all(self):
        """Hook with an empty body."""
        pass

    def enter_testclass(self, testclass):
        """Hook with an empty body."""
        pass

    def exit_testclass(self, testclass, method_name, exc_info):
        """Hook with an empty body."""
        pass

    def enter_testcase(self, testcase, testname):
        """Hook with an empty body."""
        pass

    def exit_testcase(self, testcase, testname, status, exc_info):
        """Hook with an empty body."""
        pass

    def enter_testcontext(self, context):
        """Hook with an empty body."""
        pass

    def exit_testcontext(self, context):
        """Hook with an empty body."""
        pass
868
869
class BaseReporter(Reporter):
    """Common implementation shared by the concrete reporters.

    It keeps a counter per status, remembers the exceptions collected while
    the test cases of one test class run, and prints them (header, traceback
    body, footer) when the class is finished.  Output goes to ``self.out``
    and is colorized when ``self._color`` is true.
    """

    # label printed between brackets for each status
    INDICATOR = {
        ST_PASSED:  "passed",          # or "ok" ?
        ST_FAILED:  "Failed",
        ST_ERROR:   "ERROR",
        ST_SKIPPED: "skipped",
        ST_TODO:    "TODO",
        #ST_UNEXPECTED: "Unexpected",
    }

    separator =  "-" * 70

    def __init__(self, out=None, color=None):
        """
        out   -- output stream; when omitted, sys.stdout is used on first access.
        color -- true/false to force coloring on or off, None to detect it.
        """
        self._color = color
        self.out = out                 # goes through the 'out' property setter
        self.counts = {}
        self._context_stack = []

    def _set_color(self, color=None):
        """Decide the color mode.

        Priority: explicit *color* argument, then config.color_enabled, then
        config.color_available (no color when unavailable), and finally
        whether the current output stream is a tty.
        """
        if color is not None:
            self._color = color
        elif config.color_enabled is not None:
            self._color = config.color_enabled
        elif not config.color_available:
            self._color = False
        else:
            self._color = is_tty(self._out)

    def __get_out(self):
        # fall back to sys.stdout lazily; assigning through the property also
        # triggers color detection in __set_out()
        if not self._out:
            self.out = sys.stdout
        return self._out

    def __set_out(self, out):
        self._out = out
        # detect the color mode only once a real stream is known and no
        # mode has been chosen yet
        if out is not None and self._color is None:
            self._set_color(None)

    out = property(__get_out, __set_out)

    def clear_counts(self):
        """Reset every status counter to zero."""
        self.counts = {
            ST_PASSED:     0,
            ST_FAILED:     0,
            ST_ERROR:      0,
            ST_SKIPPED:    0,
            ST_TODO:       0,
            #ST_UNEXPECTED: 0,
        }

    # rows are (status, word shown in the summary, always shown?)
    _counts2str_table = [
        (ST_PASSED,     "passed",     True),
        (ST_FAILED,     "failed",     True),
        (ST_ERROR,      "error",      True),
        (ST_SKIPPED,    "skipped",    True),
        (ST_TODO,       "todo",       True),
        #(ST_UNEXPECTED, "unexpected", False),
    ]

    def counts2str(self):
        """Return the summary string 'total:N, <word>:<count>, ...'.

        An entry is included when its row is marked as required or when its
        count is non-zero; non-zero entries are colorized by their status.
        """
        buf = [None]; add = buf.append     # buf[0] is reserved for the total
        total = 0
        # unpack in the same order as the rows of _counts2str_table, so that
        # counters are looked up by status and the word is what gets printed
        for status, word, required in self._counts2str_table:
            n = self.counts.get(status, 0)
            s = "%s:%s" % (word, n)
            if n: s = self.colorize(s, status)
            if required or n:
                add(s)
            total += n
        buf[0] = "total:%s" % total
        return ", ".join(buf)

    def enter_all(self):
        """Reset the counters and remember the start time."""
        self.clear_counts()
        self._start_time = time.time()

    def exit_all(self):
        """Write the summary line with the elapsed time since enter_all()."""
        dt = time.time() - self._start_time
        minutes = int(int(dt) / 60)     # int / int is float on Python3
        sec = dt - (minutes * 60)
        if minutes:
            elapsed = "%s:%06.3f" % (minutes, sec)
        else:
            elapsed = "%.3f" % sec
        self.out.write("## %s  (%s sec)\n" % (self.counts2str(), elapsed))
        self.out.flush()

    def enter_testclass(self, testclass):
        """Start collecting exceptions for a new test class."""
        self._exceptions = []

    def exit_testclass(self, testclass, method_name, exc_info):
        """Report the collected exceptions, then *exc_info* itself when given.

        *exc_info* is reported against the class and *method_name* with
        status ST_ERROR and no context.
        """
        for tupl in self._exceptions:
            self.report_exceptions(*tupl)
        if exc_info:
            self.report_exception(testclass, method_name, ST_ERROR, exc_info, None)
        if self._exceptions or exc_info:
            self.write_separator()
        self.out.flush()

    def enter_testcase(self, testcase, testname):
        pass

    def exit_testcase(self, testcase, testname, status, exc_info):
        """Count *status* and remember *exc_info* (except for skipped tests)."""
        self.counts[status] = self.counts.setdefault(status, 0) + 1
        if exc_info and status != ST_SKIPPED:
            # the innermost context currently entered, if any
            context = self._context_stack and self._context_stack[-1] or None
            self._exceptions.append((testcase, testname, status, exc_info, context))

    def enter_testcontext(self, context):
        self._context_stack.append(context)

    def exit_testcontext(self, context):
        popped = self._context_stack.pop()
        assert popped is context

    def indicator(self, status):
        """Return the (possibly colorized) label for *status*, '???' if unknown."""
        indicator = self.INDICATOR.get(status) or '???'
        if self._color:
            indicator = self.colorize(indicator, status)
        return indicator

    def get_testclass_name(self, testclass):
        """Return the display name of a test class.

        The class attribute SUBJECT (looked up on the class itself only) takes
        precedence over the class; its __name__ is used when available,
        otherwise its str().
        """
        subject = testclass.__dict__.get('SUBJECT') or testclass
        return getattr(subject, '__name__', None) or str(subject)

    def get_testcase_desc(self, testcase, testname):
        """Return the docstring of the test method, or *testname* without one."""
        meth = getattr(testcase, testname)
        return meth and meth.__doc__ and meth.__doc__ or testname

    def report_exceptions(self, testcase, testname, status, exc_info, context):
        """Report one remembered entry.

        When *exc_info* is a list it is treated as a list of spec objects and
        each one is reported separately; otherwise it is reported as a single
        exception.
        """
        if isinstance(exc_info, list):
            specs = exc_info
            for spec in specs:
                self.report_spec_esception(testcase, testname, status, spec, context)
        else:
            self.report_exception(testcase, testname, status, exc_info, context)

    def report_exception(self, testcase, testname, status, exc_info, context):
        """Write header, body and footer for one exception."""
        self.report_exception_header(testcase, testname, status, exc_info, context)
        self.report_exception_body  (testcase, testname, status, exc_info, context)
        self.report_exception_footer(testcase, testname, status, exc_info, context)

    def report_exception_header(self, testcase, testname, status, exc_info, context):
        """Write the separator and the '[indicator] title' line.

        For a class (class-level error) the title is 'Class > method()'.
        For a test case object the title joins the class name, the
        descriptions of the enclosing contexts (outermost first) and the
        test name with ' > '; the description, if any, follows on its own line.
        """
        if isinstance(testcase, type):
            klass, method = testcase, testname
            title = "%s > %s()" % (self.get_testclass_name(klass), method)
            desc   = None
        else:
            parent, child, desc = self._get_testcase_header_items(testcase, testname)
            # collected innermost-first, then reversed
            items = [child]
            c = context
            while c:
                items.append(c.desc)
                c = c.parent
            items.append(parent)
            items.reverse()
            title = " > ".join(items)
        indicator = self.indicator(status)
        self.write_separator()
        self.out.write("[%s] %s\n" % (indicator, title))
        if desc: self.out.write(desc + "\n")

    def _get_testcase_header_items(self, testcase, testname):
        """Return (class name, test label, description) for a test case.

        Names of the form 'test_NNN: ...' are shown without the leading
        'test_' and have no description; other names are shown as 'name()'
        with the method docstring as description.
        """
        parent = self.get_testclass_name(testcase.__class__)
        if re.match(r'^test_\d\d\d: ', testname):
            child = testname[5:]
            desc  = None
        else:
            child = testname + '()'
            desc  = getattr(testcase, testname).__doc__
        return parent, child, desc

    def _filter(self, tb, filename, linenum, funcname):
        """Traceback filter: keep frames whose globals lack '__unittest'.

        Only *tb* is used.  NOTE(review): format_traceback() passes the line
        number before the file name, unlike the parameter names here.
        """
        #return not filename.startswith(_oktest_filepath)
        return "__unittest" not in tb.tb_frame.f_globals

    def report_exception_body(self, testcase, testname, status, exc_info, context):
        """Write the formatted traceback followed by 'ExceptionClass: message'.

        Frames are filtered with _filter() unless config.debug is set.  Only
        the first line of the message is colorized.
        """
        assert exc_info
        ex_class, ex, ex_traceback = exc_info
        filter = not config.debug and self._filter or None
        arr = format_traceback(ex, ex_traceback, filter=filter)
        for x in arr:
            self.out.write(x)
        errmsg = "%s: %s" % (ex_class.__name__, ex)
        tupl = errmsg.split("\n", 1)
        if len(tupl) == 1:
            first_line, rest = tupl[0], None
        else:
            first_line, rest = tupl
        self.out.write(self.colorize(first_line, status) + "\n")
        if rest:
            self.out.write(rest)
            if not rest.endswith("\n"): self.out.write("\n")
        self.out.flush()

    def report_exception_footer(self, testcase, testname, status, exc_info, context):
        pass

    def _print_temporary_str(self, string):
        """Write *string* only when the output is a tty (it is erased later)."""
        if is_tty(self.out):
            #self.__string = string
            self.out.write(string)
            self.out.flush()

    def _erase_temporary_str(self, _eraser="\b"*255):
        """Write backspaces to a tty output to move back over the temporary string."""
        if is_tty(self.out):
            #n = len(self.__string) + 1    # why '+1' ?
            #self.out.write("\b" * n)      # not work with wide-chars
            #self.out.flush()
            #del self.__string
            self.out.write(_eraser)
            self.out.flush()

    def report_spec_esception(self, testcase, testname, status, spec, context):
        """Report the exception stored in a spec object.

        Writes its own header (including the spec description), the filtered
        stack trace recorded by the spec, then the usual body and footer.
        """
        ex = spec._exception
        exc_info = (ex.__class__, ex, spec._traceback)
        #self.report_exception_header(testcase, testname, status, exc_info, context)
        parent, child, desc = self._get_testcase_header_items(testcase, testname)
        indicator = self.indicator(status)
        self.write_separator()
        self.out.write("[%s] %s > %s > %s\n" % (indicator, parent, child, spec.desc))
        if desc: self.out.write(desc + "\n")
        #
        stacktrace = self._filter_stacktrace(spec._stacktrace, spec._traceback)
        self._print_stacktrace(stacktrace)
        #
        self.report_exception_body(testcase, testname, status, exc_info, context)
        self.report_exception_footer(testcase, testname, status, exc_info, context)

    def _filter_stacktrace(self, stacktrace, traceback_):
        """Return the slice of *stacktrace* between oktest and the traceback.

        'bottom' is the last stack entry with the same file and function as
        the first traceback entry; 'top' is just below the nearest entry
        above it that belongs to oktest (see _is_oktest_py).  The entry at
        'bottom' itself is excluded.
        """
        entries = traceback.extract_tb(traceback_)
        file, line, func, text = entries[0]
        i = len(stacktrace) - 1
        while i >= 0 and not (stacktrace[i][0] == file and stacktrace[i][2] == func):
            i -= 1
        bottom = i
        while i >= 0 and not _is_oktest_py(stacktrace[i][0]):
            i -= 1
        top = i + 1
        return stacktrace[top:bottom]

    def _print_stacktrace(self, stacktrace):
        """Write stack entries in the same layout as a Python traceback."""
        for file, line, func, text in stacktrace:
            self.out.write('  File "%s", line %s, in %s\n' % (file, line, func))
            self.out.write('    %s\n' % text)

    def colorize(self, string, kind):
        """Return *string* colorized for *kind*; unchanged when color is off.

        *kind* is a status constant or one of "topic", "sep", "context";
        anything else is shown in yellow.
        """
        if not self._color:
            return string
        if kind == ST_PASSED:  return util.Color.green(string, bold=True)
        if kind == ST_FAILED:  return util.Color.red(string, bold=True)
        if kind == ST_ERROR:   return util.Color.red(string, bold=True)
        if kind == ST_SKIPPED: return util.Color.yellow(string, bold=True)
        if kind == ST_TODO:    return util.Color.yellow(string, bold=True)
        #if kind == ST_UNEXPECTED: return util.Color.red(string, bold=True)
        if kind == "topic":    return util.Color.bold(string)
        if kind == "sep":      return util.Color.red(string)
        if kind == "context":  return util.Color.bold(string)
        return util.Color.yellow(string)

    def write_separator(self):
        self.out.write(self.colorize(self.separator, "sep") + "\n")

    def status_char(self, status):
        """Return the one-character mark for *status* ('?' when unknown).

        The table is built on first use, so the marks keep the color mode
        that was in effect at that moment.
        """
        if not hasattr(self, '_status_chars'):
            self._status_chars = {
                ST_PASSED : ".",
                ST_FAILED : self.colorize("f", ST_FAILED ),
                ST_ERROR  : self.colorize("E", ST_ERROR  ),
                ST_SKIPPED: self.colorize("s", ST_SKIPPED),
                ST_TODO   : self.colorize("t", ST_TODO),
                #ST_UNEXPECTED: self.colorize("u", ST_UNEXPECTED),
                None      : self.colorize("?", None),
            }
        return self._status_chars.get(status) or self._status_chars.get(None)

    # name -> reporter class; one dict shared through BaseReporter
    _registered = {}

    @classmethod
    def register_class(cls, name, klass):
        """Register reporter class *klass* under *name*."""
        cls._registered[name] = klass

    @classmethod
    def get_registered_class(cls, name):
        """Return the class registered under *name*, or None."""
        return cls._registered.get(name)
1153
1154
def _is_oktest_py(filepath, _dirpath=os.path.dirname(__file__)):
    """Tell whether *filepath* starts with the directory of this module.

    *_dirpath* is computed once, at definition time, from ``__file__``.
    NOTE(review): the check is a plain string-prefix test, so an empty
    *_dirpath* matches every path.
    """
    #return re.search(r'oktest.py[co]?$', filepath)
    has_prefix = filepath.startswith(_dirpath)
    return has_prefix
1158
1159
def is_tty(out):
    """Return out.isatty() when *out* has that method, otherwise False."""
    if not hasattr(out, 'isatty'):
        return False
    return out.isatty()
1162
1163
def traceback_formatter(file, line, func, linestr):
    """Format one traceback entry in the layout used by Python tracebacks.

    The source line is stripped of surrounding whitespace.  The ', in <func>'
    part is left out when *func* is empty or None.
    """
    source = linestr.strip()
    if func:
        return '  File "%s", line %s, in %s\n    %s\n' % (file, line, func, source)
    return '  File "%s", line %s\n    %s\n' % (file, line, source)
1168
1169
def format_traceback(exception, traceback, filter=None, formatter=traceback_formatter):
    """Turn a traceback chain into a list of formatted entries.

    exception -- the raised exception; when it has the attribute
                 '_raised_by_oktest', its 'file' and 'line' attributes mark
                 the entry after which the result is cut off.
    traceback -- first traceback object of the chain.
    filter    -- optional callable invoked as
                 filter(tb, linenum, filename, funcname); entries for which
                 it returns a false value are dropped.
    formatter -- callable invoked as formatter(filename, linenum, funcname,
                 linestr); when false, plain argument tuples are collected.

    At most sys.tracebacklimit (200 when unset) entries are collected.

    NOTE(review): the filter receives the line number before the file name,
    which differs from the parameter order declared by BaseReporter._filter.
    """
    max_entries = getattr(sys, 'tracebacklimit', 200)
    if not formatter:
        formatter = lambda *args: args
    if hasattr(exception, '_raised_by_oktest'):
        target_file, target_line = exception.file, exception.line
    else:
        target_file, target_line = False, -1
    formatted = []
    cut_at = -1        # index of the entry matching the oktest failure location
    count = 0          # number of entries accepted so far
    tb = traceback
    while tb and count < max_entries:
        linenum = tb.tb_lineno
        code = tb.tb_frame.f_code
        filename, funcname = code.co_filename, code.co_name
        if not filter or filter(tb, linenum, filename, funcname):
            linecache.checkcache(filename)
            source_line = linecache.getline(filename, linenum)
            formatted.append(formatter(filename, linenum, funcname, source_line))
            if linenum == target_line and filename == target_file:
                cut_at = count
            count += 1
        tb = tb.tb_next
    if cut_at >= 0:
        # drop everything after the line where the oktest assertion was made
        del formatted[cut_at + 1:]
    return formatted
1197
1198
class VerboseReporter(BaseReporter):
    """Reporter writing one line per test case, indented by context depth."""

    _super = BaseReporter

    def __init__(self, *args, **kwargs):
        self._super.__init__(self, *args, **kwargs)
        self.depth = 1       # current indentation level (in units of two spaces)

    def enter_testclass(self, testclass):
        """Write the '* ClassName' line."""
        self._super.enter_testclass(self, testclass)
        self.out.write("* %s\n" % self.colorize(self.get_testclass_name(testclass), "topic"))
        self.out.flush()

    def enter_testcase(self, testcase, testname):
        """Show a temporary line with an empty indicator (tty output only)."""
        desc = self.get_testcase_desc(testcase, testname)
        self._print_temporary_str("  " * self.depth + "- [      ] " + desc)

    def exit_testcase(self, testcase, testname, status, exc_info):
        """Replace the temporary line with the final '- [status] desc' line.

        For skipped tests the first argument of the exception is appended as
        the reason and the exception is not passed on to the base class.
        """
        s = ""
        if status == ST_SKIPPED:
            ex = exc_info[1]
            #reason = getattr(ex, 'reason', '')
            # an exception raised without arguments has empty args;
            # fall back to an empty reason instead of failing with IndexError
            if ex.args:
                reason = ex.args[0]
            else:
                reason = ''
            s = " (reason: %s)" % (reason, )
            exc_info = ()
        self._super.exit_testcase(self, testcase, testname, status, exc_info)
        self._erase_temporary_str()
        indicator = self.indicator(status)
        desc = self.get_testcase_desc(testcase, testname)
        self.out.write("  " * self.depth + "- [%s] %s%s\n" % (indicator, desc, s))
        self.out.flush()

    def enter_testcontext(self, context):
        """Write the '+ description' line and indent what follows.

        Descriptions starting with "when " or equal to "else:" are not
        colorized.
        """
        self._super.enter_testcontext(self, context)
        s = context.desc
        if not (s.startswith("when ") or s == "else:"):
            s = self.colorize(s, "context")
        self.out.write("  " * self.depth + "+ %s\n" % s)
        self.depth += 1

    def exit_testcontext(self, context):
        self._super.exit_testcontext(self, context)
        self.depth -= 1

BaseReporter.register_class("verbose", VerboseReporter)
1244
1245
class SimpleReporter(BaseReporter):
    """Reporter writing '* ClassName: ' followed by one character per test."""

    _super = BaseReporter

    def __init__(self, *args, **kwargs):
        self._super.__init__(self, *args, **kwargs)

    def enter_testclass(self, testclass):
        """Write the class name; status characters follow on the same line."""
        self._super.enter_testclass(self, testclass)
        # fetch the stream first: the first access may settle the color mode
        stream = self.out
        label = self.colorize(self.get_testclass_name(testclass), "topic")
        stream.write("* %s: " % label)
        stream.flush()

    def exit_testclass(self, *args):
        """End the line of status characters, then report as the base class does."""
        self.out.write("\n")
        self._super.exit_testclass(self, *args)

    def exit_testcase(self, testcase, testname, status, exc_info):
        """Record the result and write its status character."""
        self._super.exit_testcase(self, testcase, testname, status, exc_info)
        stream = self.out
        stream.write(self.status_char(status))
        stream.flush()

BaseReporter.register_class("simple", SimpleReporter)
1268
1269
class PlainReporter(BaseReporter):
    """Reporter writing only one status character per test case."""

    _super = BaseReporter

    def __init__(self, *args, **kwargs):
        self._super.__init__(self, *args, **kwargs)

    def exit_testclass(self, testclass, method_name, exc_info):
        """Break the line before exceptions are reported, if there are any."""
        has_report = self._exceptions or exc_info
        if has_report:
            self.out.write("\n")
        self._super.exit_testclass(self, testclass, method_name, exc_info)

    def exit_testcase(self, testcase, testname, status, exc_info):
        """Record the result and write its status character."""
        self._super.exit_testcase(self, testcase, testname, status, exc_info)
        # fetch the stream first: the first access may settle the color mode
        stream = self.out
        stream.write(self.status_char(status))
        stream.flush()

    def exit_all(self):
        """End the line of status characters, then write the summary."""
        self.out.write("\n")
        self._super.exit_all(self)

BaseReporter.register_class("plain", PlainReporter)
1292
1293
class UnittestStyleReporter(BaseReporter):
    """Reporter with unittest-like output.

    Status characters are written while tests run; all exceptions, from
    every test class, are reported together in exit_all().  Color is
    always off.
    """

    _super = BaseReporter

    def __init__(self, *args, **kwargs):
        self._super.__init__(self, *args, **kwargs)
        self._color = False
        self.separator = "-" * 70

    def enter_testclass(self, testclass):
        # keep exceptions of earlier classes: they are reported in exit_all()
        if getattr(self, '_exceptions', None) is None:
            self._exceptions = []

    def exit_testclass(self, testclass, method_name, exc_info):
        if exc_info:
            # entries are unpacked into report_exceptions(), which takes five
            # values; a class-level error has no context, hence the None
            self._exceptions.append((testclass, method_name, ST_ERROR, exc_info, None))

    def enter_testcase(self, *args):
        self._super.enter_testcase(self, *args)

    def exit_testcase(self, testcase, testname, status, exc_info):
        """Record the result and write its status character."""
        self._super.exit_testcase(self, testcase, testname, status, exc_info)
        self.out.write(self.status_char(status))
        self.out.flush()

    def exit_all(self):
        """Report every collected exception, then write the summary."""
        self.out.write("\n")
        for tupl in self._exceptions:
            self.report_exceptions(*tupl)
        self._super.exit_all(self)

    def report_exception_header(self, testcase, testname, status, exc_info, context):
        """Write '=' * 70, 'INDICATOR: Class#method()' and '-' * 70."""
        if isinstance(testcase, type):
            klass, method = testcase, testname
            parent = self.get_testclass_name(klass)
            child  = method
        else:
            parent = testcase.__class__.__name__
            child  = testname
        indicator = self.INDICATOR.get(status) or '???'
        self.out.write("=" * 70 + "\n")
        self.out.write("%s: %s#%s()\n" % (indicator, parent, child))
        self.out.write("-" * 70 + "\n")

# register this class itself under the name "unittest"
BaseReporter.register_class("unittest", UnittestStyleReporter)
1339
1340
class OldStyleReporter(BaseReporter):
    """Reporter writing '* Class.test ... [result]' lines.

    Failures and errors are printed right after the test line; no summary
    is written and the status counters are not updated.
    """

    _super = BaseReporter

    def enter_all(self):
        pass

    def exit_all(self):
        pass

    # NOTE(review): these two names are not among the hooks declared by
    # Reporter (enter_testclass / exit_testclass) -- confirm whether they
    # are still called by anything.
    def enter_class(self, testcase, testname):
        pass

    def exit_class(self, testcase, testname):
        pass

    def enter_testcase(self, testcase, testname):
        self.out.write("* %s.%s ... " % (testcase.__class__.__name__, testname))

    def exit_testcase(self, testcase, testname, status, exc_info):
        """Write the result of one test case.

        For failures the message is ex.errmsg when the exception was raised
        by oktest (falling back to util.ex2msg when that is empty), followed
        by the filtered traceback and, when present, ex.diff.  For errors
        the exception class, message and filtered traceback are written.
        """
        if status == ST_PASSED:
            self.out.write("[ok]\n")
        elif status == ST_FAILED:
            ex_class, ex, ex_traceback = exc_info
            flag = hasattr(ex, '_raised_by_oktest')
            self.out.write("[NG] %s\n" % (flag and ex.errmsg or util.ex2msg(ex)))
            def formatter(filepath, lineno, funcname, linestr):
                return "   %s:%s: %s\n" % (filepath, lineno, linestr.strip())
            arr = format_traceback(ex, ex_traceback, filter=self._filter, formatter=formatter)
            for x in arr:
                self.out.write(x)
            if flag and getattr(ex, 'diff', None):
                self.out.write(ex.diff)
        elif status == ST_ERROR:
            ex_class, ex, ex_traceback = exc_info
            self.out.write("[ERROR] %s: %s\n" % (ex_class.__name__, util.ex2msg(ex)))
            def formatter(filepath, lineno, funcname, linestr):
                return "  - %s:%s:  %s\n" % (filepath, lineno, linestr.strip())
            arr = format_traceback(ex, ex_traceback, filter=self._filter, formatter=formatter)
            for x in arr:
                self.out.write(x)
        elif status == ST_SKIPPED:
            self.out.write("[skipped]\n")
        elif status == ST_TODO:
            self.out.write("[TODO]\n")
        #elif status == ST_UNEXPECTED:
        #    self.out.write("[Unexpected]\n")
        else:
            assert False, "UNREACHABLE: status=%r" % (status,)

# register this class itself under the name "oldstyle"
BaseReporter.register_class("oldstyle", OldStyleReporter)
1392
1393
# Default reporter class.
REPORTER = VerboseReporter
#REPORTER = SimpleReporter
#REPORTER = PlainReporter
#REPORTER = OldStyleReporter
# The environment variable OKTEST_REPORTER, when set to a non-empty value,
# names a module-level object (looked up in globals(), e.g. "SimpleReporter")
# that replaces the default; an unknown name raises ValueError at import time.
if os.environ.get('OKTEST_REPORTER'):
    REPORTER = globals().get(os.environ.get('OKTEST_REPORTER'))
    if not REPORTER:
        raise ValueError("%s: reporter class not found." % os.environ.get('OKTEST_REPORTER'))
1402
1403
1404##
1405## util
1406##
def _dummy():
    """Build the namespace of the 'oktest.util' module.

    Everything defined in this function body becomes an attribute of that
    module, because the function returns its locals().  Which variant of the
    version-dependent helpers is defined depends on the module-level flags
    python2 / python3.
    """

    __all__ = ('chdir', 'rm_rf')

    if python2:
        def _is_string(val):
            return isinstance(val, (str, unicode))
        def _is_class(obj):
            return isinstance(obj, (types.TypeType, types.ClassType))
        def _is_unbound(method):
            return not method.im_self
        def _func_name(func):
            return func.func_name
        def _func_firstlineno(func):
            func = getattr(func, 'im_func', func)
            return func.func_code.co_firstlineno
    if python3:
        def _is_string(val):
            return isinstance(val, (str, bytes))
        def _is_class(obj):
            return isinstance(obj, (type, ))
        def _is_unbound(method):
            return not method.__self__
        def _func_name(func):
            return func.__name__
        def _func_firstlineno(func):
            return func.__code__.co_firstlineno

    ##
    ## Context
    ##
    class Context(object):
        """Minimal context manager: enters returning itself, exits returning None."""

        def __enter__(self):
            return self

        def __exit__(self, *args):
            return None


    class RunnableContext(Context):
        """Context that can also wrap a function call or decorate a function."""

        def run(self, func, *args, **kwargs):
            """Call func(*args, **kwargs) between __enter__() and __exit__()."""
            self.__enter__()
            try:
                return func(*args, **kwargs)
            finally:
                self.__exit__(*sys.exc_info())

        def deco(self, func):
            """Return a function which calls *func* through run()."""
            def f(*args, **kwargs):
                return self.run(func, *args, **kwargs)
            return f

        __call__ = run    # for backward compatibility


    class Chdir(RunnableContext):
        """Change the working directory on enter and restore it on exit.

        The directory to return to is the working directory at the time the
        object is created.
        """

        def __init__(self, dirname):
            self.dirname = dirname
            self.path    = os.path.abspath(dirname)
            self.back_to = os.getcwd()

        def __enter__(self, *args):
            os.chdir(self.path)
            return self

        def __exit__(self, *args):
            os.chdir(self.back_to)


    class Using(Context):
        """ex.
             class MyTest(object):
                pass
             with oktest.util.Using(MyTest):
                def test_1(self):
                  ok (1+1) == 2
             if __name__ == '__main__':
                oktest.run(MyTest)
        """
        def __init__(self, klass):
            self.klass = klass

        def __enter__(self):
            # remember which names the caller already has
            localvars = sys._getframe(1).f_locals
            self._start_names = localvars.keys()
            if python3: self._start_names = list(self._start_names)
            return self

        def __exit__(self, *args):
            # every name that appeared in the caller since __enter__() is
            # attached to the class
            localvars  = sys._getframe(1).f_locals
            curr_names = localvars.keys()
            diff_names = list(set(curr_names) - set(self._start_names))
            for name in diff_names:
                setattr(self.klass, name, localvars[name])


    def chdir(path, func=None):
        """Return a Chdir object for *path*, or run *func* inside that directory.

        When *func* is given it is called with the working directory changed
        to *path* and its result is returned, whatever its truth value.
        """
        cd = Chdir(path)
        if func is not None:
            return cd.run(func)
        return cd

    def using(klass):                       ## undocumented
        return Using(klass)


    def ex2msg(ex):
        """Return str(ex), or '(no error message)' when that is empty."""
        #return ex.message   # deprecated since Python 2.6
        #return str(ex)      # may be empty
        #return ex.args[0]   # ex.args may be empty (ex. AssertionError)
        #return (ex.args or ['(no error message)'])[0]
        return str(ex) or '(no error message)'

    def flatten(arr, type=(list, tuple)):   ## undocumented
        """Return a flat list of the items of *arr*.

        Items which are instances of *type* are expanded recursively; the
        same *type* applies at every nesting level.
        """
        L = []
        for x in arr:
            if isinstance(x, type):
                L.extend(flatten(x, type))
            else:
                L.append(x)
        return L

    def rm_rf(*fnames):
        """Remove the given files and directory trees; missing paths are ignored.

        Arguments may be nested lists or tuples of names.
        """
        for fname in flatten(fnames):
            if os.path.isfile(fname):
                os.unlink(fname)
            elif os.path.isdir(fname):
                from shutil import rmtree
                rmtree(fname)

    def get_location(depth=0):
        """Return (filename, line number) of the caller *depth* levels up."""
        frame = sys._getframe(depth+1)
        return (frame.f_code.co_filename, frame.f_lineno)

    def read_binary_file(fname):
        """Return the whole content of *fname* read in binary mode."""
        f = open(fname, 'rb')
        try:
            b = f.read()
        finally:
            f.close()
        return b

    # read_text_file() decodes a file with the encoding named by a
    # '# ... coding: NAME' comment at the start of the file (optionally after
    # a '#!' line), or with utf-8 when there is no such comment.
    if python2:
        _rexp = re.compile(r'(?:^#!.*?\r?\n)?#.*?coding:[ \t]*([-\w]+)')
        def read_text_file(fname,  _rexp=_rexp, _read_binary_file=read_binary_file):
            b = _read_binary_file(fname)
            m = _rexp.match(b)
            encoding = m and m.group(1) or 'utf-8'
            u = b.decode(encoding)
            assert isinstance(u, unicode)
            return u
    if python3:
        _rexp = re.compile(r'(?:^#!.*?\r?\n)?#.*?coding:[ \t]*([-\w]+)'.encode('us-ascii'))
        def read_text_file(fname,  _rexp=_rexp, _read_binary_file=read_binary_file):
            b = _read_binary_file(fname)
            m = _rexp.match(b)
            encoding = m and m.group(1).decode('us-ascii') or 'utf-8'
            u = b.decode(encoding)
            assert isinstance(u, str)
            return u

    from types import MethodType as _MethodType

    # func_argnames() returns the positional argument names of a function,
    # without the first one for a method object; func_defaults() returns the
    # tuple of default values (None when there is none).
    if python2:
        def func_argnames(func):
            if isinstance(func, _MethodType):
                codeobj = func.im_func.func_code
                index = 1
            else:
                codeobj = func.func_code
                index = 0
            return codeobj.co_varnames[index:codeobj.co_argcount]
        def func_defaults(func):
            if isinstance(func, _MethodType):
                return func.im_func.func_defaults
            else:
                return func.func_defaults
    if python3:
        def func_argnames(func):
            if isinstance(func, _MethodType):
                codeobj = func.__func__.__code__
                index = 1
            else:
                codeobj = func.__code__
                index = 0
            return codeobj.co_varnames[index:codeobj.co_argcount]
        def func_defaults(func):
            if isinstance(func, _MethodType):
                return func.__func__.__defaults__
            else:
                return func.__defaults__

    ##
    ## color
    ##
    class Color(object):
        """Helpers wrapping a string in terminal escape sequences."""

        @staticmethod
        def bold(s):
            return "\033[0;1m" + s + "\033[22m"

        @staticmethod
        def black(s, bold=False):
            return "\033[%s;30m%s\033[0m" % (bold and 1 or 0, s)

        @staticmethod
        def red(s, bold=False):
            return "\033[%s;31m%s\033[0m" % (bold and 1 or 0, s)

        @staticmethod
        def green(s, bold=False):
            return "\033[%s;32m%s\033[0m" % (bold and 1 or 0, s)

        @staticmethod
        def yellow(s, bold=False):
            return "\033[%s;33m%s\033[0m" % (bold and 1 or 0, s)

        @staticmethod
        def blue(s, bold=False):
            return "\033[%s;34m%s\033[0m" % (bold and 1 or 0, s)

        @staticmethod
        def magenta(s, bold=False):
            return "\033[%s;35m%s\033[0m" % (bold and 1 or 0, s)

        @staticmethod
        def cyan(s, bold=False):
            return "\033[%s;36m%s\033[0m" % (bold and 1 or 0, s)

        @staticmethod
        def white(s, bold=False):
            return "\033[%s;37m%s\033[0m" % (bold and 1 or 0, s)

        @staticmethod
        def _colorize(s):
            """Replace <b>, <R>, <r>, <G> and <Y> markup in *s* with colors."""
            s = re.sub(r'<b>(.*?)</b>', lambda m: Color.bold(m.group(1)), s)
            s = re.sub(r'<R>(.*?)</R>', lambda m: Color.red(m.group(1), bold=True), s)
            s = re.sub(r'<r>(.*?)</r>', lambda m: Color.red(m.group(1), bold=False), s)
            s = re.sub(r'<G>(.*?)</G>', lambda m: Color.green(m.group(1), bold=True), s)
            s = re.sub(r'<Y>(.*?)</Y>', lambda m: Color.yellow(m.group(1), bold=True), s)
            return s


    return locals()
1652
# Create the 'oktest.util' module from the names defined in _dummy().
util = _new_module('oktest.util', _dummy())
del _dummy

helper = util  ## 'helper' is an alias of 'util' (for backward compatibility)
sys.modules['oktest.helper'] = sys.modules['oktest.util']
1658
1659
1660##
1661## spec()   # deprecated
1662##
class Spec(util.Context):   # deprecated
    """Deprecated context object describing one spec inside a test method.

    On enter it looks for the running test case object in the calling frames
    and, when that test case is run by oktest, registers itself in the test
    case's '_oktest_specs' list.  On exit it captures (and suppresses) an
    exception raised by oktest so that it can be reported later.
    """

    # filled in by __exit__() when an oktest exception is captured
    _exception  = None
    _traceback  = None
    _stacktrace = None

    def __init__(self, desc):
        self.desc = desc
        self._testcase = None

    def __enter__(self):
        self._testcase = tc = self._find_testcase_object()
        if getattr(tc, '_run_by_oktest', None):
            tc._oktest_specs.append(self)
        return self

    def _find_testcase_object(self):
        """Search the calling frames for the test case object.

        Frames from depth 2 up to (but excluding) depth 10 are inspected.
        In a frame whose function name starts with "test", the first local
        variable is returned when it has a '_testMethodName' or
        '_TestCase__testMethodName' attribute.  Returns None otherwise.
        """
        max_depth = 10
        for i in xrange(2, max_depth):
            try:
                frame = sys._getframe(i)   # raises ValueError when too deep
            except ValueError:
                break
            method = frame.f_code.co_name
            if method.startswith("test"):
                arg_name = frame.f_code.co_varnames[0]
                testcase = frame.f_locals.get(arg_name, None)
                if hasattr(testcase, "_testMethodName") or hasattr(testcase, "_TestCase__testMethodName"):
                    return testcase
        return None

    def __exit__(self, *args):
        """Capture an exception raised by oktest and suppress it.

        The exception is captured only when it has the '_raised_by_oktest'
        attribute and the test case found on enter has '_run_by_oktest'.
        In that case True is returned; otherwise None is returned and the
        exception, if any, propagates.
        """
        ex = args[1]
        tc = self._testcase
        if ex and hasattr(ex, '_raised_by_oktest') and hasattr(tc, '_run_by_oktest'):
            self._exception  = ex
            self._traceback  = args[2]
            self._stacktrace = traceback.extract_stack()
            return True

    def __iter__(self):
        """Generator form (``for _ in spec(...):``) yielding this object once."""
        self.__enter__()
        #try:
        #    yield self  # (Python2.4) SyntaxError: 'yield' not allowed in a 'try' block with a 'finally' clause
        #finally:
        #    self.__exit__(*sys.exc_info())
        ex = None
        try:
            yield self
        except:
            # NOTE(review): 'ex' stays None here, so the 'raise ex' below is
            # never reached -- confirm whether that is intended.
            ex = None
        self.__exit__(*sys.exc_info())
        if ex:
            raise ex

    def __call__(self, func):
        """Call func() between __enter__() and __exit__()."""
        self.__enter__()
        try:
            func()
        finally:
            self.__exit__(*sys.exc_info())

    def __bool__(self):       # for Python3
        """True unless the environment variable SPEC is set to a non-empty
        value which is not a substring of this spec's description."""
        filter = os.environ.get('SPEC')
        return not filter or (filter in self.desc)

    __nonzero__ = __bool__    # for Python2
1730
1731
def spec(desc):   # deprecated
    """Return a new Spec object for the description text *desc* (deprecated helper)."""
    #if not os.getenv('OKTEST_WARNING_DISABLED'):
    #    import warnings
    #    warnings.warn("oktest.spec() is deprecated.", DeprecationWarning, 2)
    spec_obj = Spec(desc)
    return spec_obj
1737
1738
1739##
1740## @test() decorator
1741##
1742
def test(description_text=None, **options):
    """Return a decorator which turns a function into a described test method.

    description_text
        description of the test; stored on the test object as '_description'
        and used as docstring when the function has none.
    options
        user-defined options; stored as '_options' on both the wrapper
        function and the test object.

    A counter is kept under the key '__n' in the local namespace of the
    caller.  When the decorated function's name does not start with 'test',
    the wrapper is named "test_NNN: description" (NNN is that counter) and is
    also stored in the caller's namespace under this name.  When the function
    takes arguments other than 'self', it is invoked through
    fixture_injector so that those arguments are provided as fixtures.
    """
    caller = sys._getframe(1)          # scope in which test() is called
    namespace  = caller.f_locals
    globalvars = caller.f_globals
    seq = namespace.get('__n', 0) + 1
    namespace['__n'] = seq

    def deco(orig_func):
        ## every argument except 'self' is regarded as a fixture name
        uses_fixtures = bool(util.func_argnames(orig_func)[1:])

        def newfunc(self):
            self._options = options
            self._description = description_text
            if uses_fixtures:
                return fixture_injector.invoke(self, orig_func, globalvars)
            return orig_func(self)

        name = orig_func.__name__
        if not name.startswith('test'):
            ## not a regular test name: generate one and publish the wrapper under it
            name = "test_%03d: %s" % (seq, description_text)
            namespace[name] = newfunc
        newfunc.__name__ = name
        newfunc.__doc__  = orig_func.__doc__ or description_text
        newfunc._options = options
        lineno = getattr(orig_func, '_firstlineno', None)
        if not lineno:
            lineno = util._func_firstlineno(orig_func)
        newfunc._firstlineno = lineno
        return newfunc

    return deco
1773
1774
1775##
1776## fixture manager and injector
1777##
1778
class FixtureManager(object):
    """Fallback source of fixtures.

    FixtureInjector asks the module-level 'fixture_manager' object for a
    fixture when it finds no provide_xxx() function, and hands the value back
    to it afterwards.  This default implementation provides nothing.
    """

    def provide(self, name):
        """Return the fixture value for *name*; this implementation always raises NameError."""
        message = "Fixture provider for '%s' not found." % (name,)
        raise NameError(message)

    def release(self, name, value):
        """Dispose the fixture *value* of *name*; this implementation does nothing."""
        return None


fixture_manager = FixtureManager()
1788
1789
class FixtureInjector(object):
    """Calls a test function with its arguments provided as fixtures.

    For an argument 'xxx', the provider is 'provide_xxx' and the optional
    releaser is 'release_xxx'; both are looked up on the test object first
    and then in the global variables passed as the first extra argument
    (see find()).  When no provider is found, the module-level
    'fixture_manager' is asked instead.
    """

    def invoke(self, object, func, *opts):
        """invoke function with fixtures.

        object
            test object; it is available to func and to providers as 'self'.
        func
            function to call; argument names are taken from it with
            util.func_argnames().  Arguments having default values are not
            injected, but their defaults are visible to providers.
        opts
            passed through to find(); the first item is expected to be the
            dict of global variables to search for providers.

        Returns the return value of func.  Fixtures are released even when
        func, or a provider of a later fixture, raises an exception.
        """
        releasers = {"self": None}     # {"arg_name": releaser_func()}
        resolved  = {"self": object}   # {"arg_name": arg_value}
        in_progress = []               # names of fixtures whose provider arguments are being resolved
        ##
        arg_names = util.func_argnames(func)
        ## default arg values of test method are stored into 'resolved' dict
        ## in order for providers to access to them
        defaults = util.func_defaults(func)
        if defaults:
            idx = - len(defaults)
            for aname, default in zip(arg_names[idx:], defaults):
                resolved[aname] = default
            arg_names = arg_names[:idx]
        ##
        def _resolve(arg_name):
            ## returns fixture value, creating it on first request
            aname = arg_name
            if aname not in resolved:
                pair = self.find(aname, object, *opts)
                if pair:
                    provider, releaser = pair
                    resolved[aname] = _call(provider, aname)
                    releasers[aname] = releaser
                else:
                    resolved[aname] = fixture_manager.provide(aname)
            return resolved[aname]
        def _call(provider, resolving_arg_name):
            ## calls provider with its own arguments resolved as fixtures
            arg_names = util.func_argnames(provider)
            if not arg_names:
                return provider()
            in_progress.append(resolving_arg_name)
            defaults = util.func_defaults(provider)
            if not defaults:
                arg_values = [ _get_value(aname) for aname in arg_names ]
            else:
                idx  = - len(defaults)
                arg_values = [ _get_value(aname) for aname in arg_names[:idx] ]
                ## optional args take already-resolved value if any, or else their default
                for aname, default in zip(arg_names[idx:], defaults):
                    arg_values.append(resolved.get(aname, default))
            in_progress.remove(resolving_arg_name)
            return provider(*arg_values)
        def _get_value(arg_name):
            ## a name which is requested again while being resolved means a dependency loop
            if arg_name in resolved:        return resolved[arg_name]
            if arg_name not in in_progress: return _resolve(arg_name)
            raise self._looped_dependency_error(arg_name, in_progress, object)
        ##
        ## fixtures are resolved inside 'try' so that the ones already provided
        ## are released even if a later provider fails
        try:
            arguments = [ _resolve(aname) for aname in arg_names ]
            assert not in_progress
            #return func(object, *arguments)
            return func(*arguments)
        finally:
            self._release_fixtures(resolved, releasers)

    def _release_fixtures(self, resolved, releasers):
        """Release every resolved value.

        Values which have an entry in 'releasers' are passed to their releaser
        function (if it is not None); the test object is passed as well when
        the first argument name of the releaser is 'self'.  All other values
        are handed to fixture_manager.release().
        """
        for name in resolved:
            if name in releasers:
                releaser = releasers[name]
                if releaser:
                    names = util.func_argnames(releaser)
                    if names and names[0] == "self":
                        releaser(resolved["self"], resolved[name])
                    else:
                        releaser(resolved[name])
            else:
                fixture_manager.release(name, resolved[name])

    def find(self, name, object, *opts):
        """return provide_xxx() and release_xxx() functions.

        Looks for 'provide_<name>' as an attribute of *object*, then as a key
        of the global variables dict given as opts[0] (an empty dict is
        assumed when it is omitted).  Returns a pair (provider, releaser)
        where releaser may be None, or returns None when no provider exists.
        Raises TypeError when the provider found in global variables is not a
        function.
        """
        globalvars = opts and opts[0] or {}
        provider_name = 'provide_' + name
        releaser_name = 'release_' + name
        meth = getattr(object, provider_name, None)
        if meth:
            provider = meth
            ## unwrap method into plain function
            if python2:
                if hasattr(meth, 'im_func'):  provider = meth.im_func
            elif python3:
                if hasattr(meth, '__func__'): provider = meth.__func__
            releaser = getattr(object, releaser_name, None)
            return (provider, releaser)
        elif provider_name in globalvars:
            provider = globalvars[provider_name]
            if not isinstance(provider, types.FunctionType):
                raise TypeError("%s: expected function but got %s." % (provider_name, type(provider)))
            releaser = globalvars.get(releaser_name)
            return (provider, releaser)
        #else:
        #    raise NameError("%s: no such fixture provider for '%s'." % (provider_name, name))
        return None

    def _looped_dependency_error(self, aname, in_progress, object):
        """Build (not raise) LoopedDependencyError describing the dependency loop.

        The chain leading to the loop is joined with '->' and the loop itself
        with '=>' in the message.
        """
        names = in_progress + [aname]
        pos   = names.index(aname)
        loop  = '=>'.join(names[pos:])
        if pos > 0:
            loop = '->'.join(names[0:pos]) + '->' + loop
        classname = object.__class__.__name__
        ## '_description' exists only on objects run via @test; don't let its absence hide the real error
        testdesc  = getattr(object, '_description', None)
        return LoopedDependencyError("fixture dependency is looped: %s (class: %s, test: '%s')" % (loop, classname, testdesc))


fixture_injector = FixtureInjector()
1896
1897
class LoopedDependencyError(ValueError):
    """Error which reports that fixture dependency is looped."""
1900
1901
1902##
1903## test context
1904##
1905def context():
1906
1907    __all__ = ('subject', 'situation', )
1908    global TestContext
1909
1910    class TestContext(object):
1911        """grouping test methods.
1912
1913        normally created with subject() or situation() helpers.
1914
1915        ex::
1916            class HelloClassTest(unittest.TestCase):
1917                SUBJECT = Hello
1918                with subject('#method1()'):
1919                    @test("spec1")
1920                    def _(self):
1921                        ...
1922                    @test("spec2")
1923                    def _(self):
1924                        ...
1925                with subject('#method2()'):
1926                    with situation('when condition:'):
1927                        @test("spec3")
1928                        def _(self):
1929                    with situation('else:')
1930                        @test("spec3")
1931                        def _(self):
1932                        ...
1933        """
1934
1935        def __init__(self, desc, _lineno=None):
1936            self.desc = desc
1937            self.items = []
1938            self.parent = None
1939            self._lineno = _lineno
1940
1941        def __repr__(self):
1942            return "<TestContext desc=%r items=[%s]>" % \
1943                       (self.desc, ','.join(repr(x) for x in self.items))
1944
1945        def __enter__(self):
1946            f_locals = sys._getframe(1).f_locals
1947            self._f_locals = f_locals
1948            self._varnames = set(f_locals.keys())
1949            stack = f_locals.setdefault('_context_stack', [])
1950            if not stack:
1951                f_locals.setdefault('_context_list', []).append(self)
1952            else:
1953                self.parent = stack[-1]
1954                self.parent.items.append(self)
1955            stack.append(self)
1956            return self
1957
1958        def __exit__(self, *args):
1959            f_locals = self._f_locals
1960            popped = f_locals['_context_stack'].pop()
1961            assert popped is self
1962            newvars = set(f_locals.keys()) - self._varnames
1963            for name in newvars:
1964                if name.startswith('test'):
1965                    func = f_locals[name]
1966                    if not hasattr(func, '_test_context'):
1967                        func._test_context = self.desc
1968                        self.items.append((name, func))
1969            self._sort_items(self.items)
1970            del self._f_locals
1971            del self._varnames
1972
1973        @staticmethod
1974        def _sort_items(items):
1975            def fn(item):
1976                if isinstance(item, tuple):
1977                    return getattr(item[1], '_firstlineno', None) or \
1978                           util._func_firstlineno(item[1])
1979                elif isinstance(item, TestContext):
1980                    return item._lineno or 0
1981                else:
1982                    assert False, "** item=%r" % (item, )
1983            items.sort(key=fn)
1984
1985        @staticmethod
1986        def _inspect_items(items):
1987            def _inspect(items, depth, add):
1988                for item in items:
1989                    if isinstance(item, tuple):
1990                        add("  " * depth + "- %s()\n" % item[0])
1991                    else:
1992                        add("  " * depth + "- Context: %r\n" % item.desc)
1993                        _inspect(item.items, depth+1, add)
1994            buf = []
1995            _inspect(items, 0, buf.append)
1996            return "".join(buf)
1997
1998
1999    def subject(desc):
2000        """helper to group test methods by subject"""
2001        lineno = sys._getframe(1).f_lineno
2002        return TestContext(desc, _lineno=lineno)
2003
2004    def situation(desc):
2005        """helper to group test methods by situation or condition"""
2006        lineno = sys._getframe(1).f_lineno
2007        return TestContext(desc, _lineno=lineno)
2008
2009
2010    return locals()
2011
## build 'oktest.context' module from the namespace returned by context();
## notice that the name 'context' is rebound from the builder function to the module
context = _new_module("oktest.context", context())
context.TestContext = TestContext   ## declared 'global' in context(), so it is not included in the namespace above
subject   = context.subject
situation = context.situation
2016
2017
2018##
2019## dummy
2020##
def _dummy():
    """Return the namespace (locals) from which the 'oktest.dummy' module is built."""

    __all__ = ('dummy_file', 'dummy_dir', 'dummy_values', 'dummy_attrs', 'dummy_environ_vars', 'dummy_io')


    class DummyFile(util.RunnableContext):
        """context which creates a file on enter and removes it on exit."""

        def __init__(self, filename, content):
            self.filename = filename
            self.content  = content
            self.path     = os.path.abspath(filename)

        def __enter__(self, *args):
            fobj = open(self.path, 'w')
            try:
                fobj.write(self.content)
            finally:
                fobj.close()
            return self

        def __exit__(self, *args):
            os.unlink(self.path)


    class DummyDir(util.RunnableContext):
        """context which creates a directory on enter and removes the whole tree on exit."""

        def __init__(self, dirname):
            self.dirname = dirname
            self.path    = os.path.abspath(dirname)

        def __enter__(self, *args):
            os.mkdir(self.path)
            return self

        def __exit__(self, *args):
            from shutil import rmtree
            rmtree(self.path)


    class DummyValues(util.RunnableContext):
        """context which sets items of a dictionary on enter and restores them on exit."""

        def __init__(self, dictionary, items_=None, **kwargs):
            overrides = {}
            if isinstance(items_, dict):
                overrides.update(items_)
            overrides.update(kwargs)      # keyword arguments win over items_
            self.dict  = dictionary
            self.items = overrides

        def __enter__(self):
            target = self.dict
            saved = self.original = {}
            ## keep only values of keys which exist now; other keys are deleted on exit
            for key in self.items:
                if key in target:
                    saved[key] = target[key]
            target.update(self.items)
            return self

        def __exit__(self, *args):
            target, saved = self.dict, self.original
            for key in self.items:
                if key in saved:
                    target[key] = saved[key]
                else:
                    del target[key]
            ## drops every attribute of this object (dict, items, original, ...)
            self.__dict__.clear()


    class DummyIO(util.RunnableContext):
        """context which replaces sys.stdin, sys.stdout and sys.stderr with StringIO objects.

        After exit, 'stdout' and 'stderr' attributes hold the captured text.
        """

        def __init__(self, stdin_content=None):
            self.stdin_content = stdin_content

        def __enter__(self):
            ## keep the real streams in attributes while dummy streams are installed
            self.stdout = sys.stdout
            sys.stdout  = StringIO()
            self.stderr = sys.stderr
            sys.stderr  = StringIO()
            self.stdin  = sys.stdin
            sys.stdin   = StringIO(self.stdin_content or "")
            return self

        def __exit__(self, *args):
            captured_out = sys.stdout.getvalue()
            captured_err = sys.stderr.getvalue()
            ## put the real streams back; attributes now hold the captured text
            sys.stdout, self.stdout = self.stdout, captured_out
            sys.stderr, self.stderr = self.stderr, captured_err
            sys.stdin,  self.stdin  = self.stdin,  self.stdin_content

        def __call__(self, func, *args, **kwargs):
            ## run() is not defined in this class -- presumably inherited from util.RunnableContext
            self.returned = self.run(func, *args, **kwargs)
            return self

        def __iter__(self):
            ## allows 'sout, serr = dummy_io(...)'
            yield self.stdout
            yield self.stderr


    def dummy_file(filename, content):
        """return DummyFile object for *filename* with *content*."""
        return DummyFile(filename, content)

    def dummy_dir(dirname):
        """return DummyDir object for *dirname*."""
        return DummyDir(dirname)

    def dummy_values(dictionary, items_=None, **kwargs):
        """return DummyValues object which changes items of *dictionary* temporarily."""
        return DummyValues(dictionary, items_, **kwargs)

    def dummy_attrs(object, items_=None, **kwargs):
        """return DummyValues object which changes attributes of *object* (via its __dict__) temporarily."""
        return DummyValues(object.__dict__, items_, **kwargs)

    def dummy_environ_vars(**kwargs):
        """return DummyValues object which changes environment variables temporarily."""
        return DummyValues(os.environ, **kwargs)

    def dummy_io(stdin_content="", func=None, *args, **kwargs):
        """return DummyIO object; when *func* is given, it is called with I/O replaced before returning."""
        obj = dummy.DummyIO(stdin_content)
        if func is not None:
            obj.__enter__()
            try:
                func(*args, **kwargs)
            finally:
                obj.__exit__(*sys.exc_info())
        ## without func, the object is intended for with-statement
        #return obj.stdout, obj.stderr
        return obj


    return locals()
2142
2143
## build 'oktest.dummy' module; passing 'util' makes _new_module() copy the names in its __all__ into 'util', too
dummy = _new_module('oktest.dummy', _dummy(), util)
del _dummy   ## builder function is no longer necessary
2146
2147
2148
2149##
2150## Tracer
2151##
def _dummy():
    """Return the namespace (locals) from which the 'oktest.tracer' module is built."""

    __all__ = ('Tracer', )


    class Call(object):
        """record of a function or method call: receiver, name, args, kwargs and return value."""

        ## 'list' or 'tuple' after this object is compared with a list or a tuple (see __eq__())
        __repr_style = None

        def __init__(self, receiver=None, name=None, args=None, kwargs=None, ret=None):
            self.receiver = receiver
            self.name   = name     # method name
            self.args   = args
            self.kwargs = kwargs
            self.ret    = ret

        def __repr__(self):
            """Return 'name(args..., key=value...) #=> ret', or list/tuple style after such a comparison."""
            #return '%s(args=%r, kwargs=%r, ret=%r)' % (self.name, self.args, self.kwargs, self.ret)
            if self.__repr_style == 'list':
                return repr(self.list())
            if self.__repr_style == 'tuple':
                return repr(self.tuple())
            ## args and kwargs are None by default; regard None as 'no arguments'
            ## instead of failing to iterate over it
            args   = self.args or ()
            kwargs = self.kwargs or {}
            parts = [ repr(arg) for arg in args ]
            parts.extend([ "%s=%s" % (k, repr(kwargs[k])) for k in kwargs ])
            return "%s(%s) #=> %s" % (self.name, ", ".join(parts), repr(self.ret))

        def __iter__(self):
            yield self.receiver
            yield self.name
            yield self.args
            yield self.kwargs
            yield self.ret

        def list(self):
            """return [receiver, name, args, kwargs, ret]."""
            return list(self)

        def tuple(self):
            """return (receiver, name, args, kwargs, ret)."""
            return tuple(self)

        def __eq__(self, other):
            """Compare with a list, a tuple, or another Call object.

            Comparison with a list or a tuple covers all five values and
            switches the style of __repr__() to that type.  Comparison with
            another Call object ignores the receiver.
            """
            if isinstance(other, list):
                self.__repr_style = 'list'
                return list(self) == other
            elif isinstance(other, tuple):
                self.__repr_style = 'tuple'
                return tuple(self) == other
            elif isinstance(other, self.__class__):
                return self.name == other.name and self.args == other.args \
                    and self.kwargs == other.kwargs and self.ret == other.ret
            else:
                return False

        def __ne__(self, other):
            return not self.__eq__(other)


    class FakeObject(object):
        """object whose methods are defined by keyword arguments and whose calls are recorded."""

        def __init__(self, **kwargs):
            """Define a method for each keyword; see __new_method() for the meaning of values."""
            self._calls = self.__calls = []
            for name in kwargs:
                setattr(self, name, self.__new_method(name, kwargs[name]))

        def __new_method(self, name, val):
            """Create a bound method *name* which records each call.

            When *val* is a function, it is called as the method body and its
            result is returned; otherwise *val* itself is the return value.
            """
            fake_obj = self
            if isinstance(val, types.FunctionType):
                func = val
                def f(self, *args, **kwargs):
                    ## the call is recorded before the body runs; the result is filled in later
                    r = Call(fake_obj, name, args, kwargs, None)
                    fake_obj.__calls.append(r)
                    r.ret = func(self, *args, **kwargs)
                    return r.ret
            else:
                def f(self, *args, **kwargs):
                    r = Call(fake_obj, name, args, kwargs, val)
                    fake_obj.__calls.append(r)
                    return val
            f.func_name = f.__name__ = name
            if python2: return types.MethodType(f, self, self.__class__)
            if python3: return types.MethodType(f, self)


    class Tracer(object):
        """trace function or method call to record arguments and return value.
           see README.txt for details.
        """

        def __init__(self):
            self.calls = []     # list of Call objects, in order of invocation

        def __getitem__(self, index):
            return self.calls[index]

        def __len__(self):
            return len(self.calls)

        def __iter__(self):
            return self.calls.__iter__()

        def _copy_attrs(self, func, newfunc):
            """Copy name and docstring attributes of func to newfunc, when func has them."""
            for k in ('func_name', '__name__', '__doc__'):
                if hasattr(func, k):
                    setattr(newfunc, k, getattr(func, k))

        def _wrap_func(self, func, block):
            """Return a function which records each call and then calls func.

            When *block* is given, 'block(func, *args, **kwargs)' is called
            instead of func and its result is used.
            """
            tr = self
            def newfunc(*args, **kwargs):                # no 'self'
                call = Call(None, util._func_name(func), args, kwargs, None)
                tr.calls.append(call)
                if block:
                    ret = block(func, *args, **kwargs)
                else:
                    ret = func(*args, **kwargs)
                #newfunc._return = ret
                call.ret = ret
                return ret
            self._copy_attrs(func, newfunc)
            return newfunc

        def _wrap_method(self, method_obj, block):
            """Return a method, bound like method_obj, which records each call and then calls it.

            When *block* is given, 'block(method_obj, *args, **kwargs)' is
            called instead and its result is used.
            """
            func = method_obj
            tr = self
            def newfunc(self, *args, **kwargs):          # has 'self'
                call = Call(self, util._func_name(func), args, kwargs, None)
                tr.calls.append(call)
                if util._is_unbound(func): args = (self, ) + args   # call with 'self' if unbound method
                if block:
                    ret = block(func, *args, **kwargs)
                else:
                    ret = func(*args, **kwargs)
                call.ret = ret
                return ret
            self._copy_attrs(func, newfunc)
            if python2:  return types.MethodType(newfunc, func.im_self, func.im_class)
            if python3:  return types.MethodType(newfunc, func.__self__)

        def trace_func(self, func):
            """return a wrapper of func which records calls."""
            newfunc = self._wrap_func(func, None)
            return newfunc

        def fake_func(self, func, block):
            """return a wrapper of func which records calls and delegates to block."""
            newfunc = self._wrap_func(func, block)
            return newfunc

        def trace_method(self, obj, *method_names):
            """Replace the named methods of obj with wrappers which record calls.

            Raises NameError when obj has no such method.
            """
            for method_name in method_names:
                method_obj = getattr(obj, method_name, None)
                if method_obj is None:
                    raise NameError("%s: method not found on %r." % (method_name, obj))
                setattr(obj, method_name, self._wrap_method(method_obj, None))
            return None

        def fake_method(self, obj, **kwargs):
            """Replace (or add) methods of obj by fake ones which record calls.

            Each keyword is a method name.  When its value is a function, it
            is used as block (called with the original method followed by the
            arguments); otherwise the value is the return value of the method.
            """
            def _new_block(ret_val):
                def _block(*args, **kwargs):
                    return ret_val
                return _block
            def _dummy_method(obj, name):
                ## stand-in used when obj has no method of that name
                fn = lambda *args, **kwargs: None
                fn.__name__ = name
                if python2: fn.func_name = name
                if python2: return types.MethodType(fn, obj, type(obj))
                if python3: return types.MethodType(fn, obj)
            for method_name in kwargs:
                method_obj = getattr(obj, method_name, None)
                if method_obj is None:
                    method_obj = _dummy_method(obj, method_name)
                block = kwargs[method_name]
                if not isinstance(block, types.FunctionType):
                    block = _new_block(block)
                setattr(obj, method_name, self._wrap_method(method_obj, block))
            return None

        def trace(self, target, *args):
            """Trace a function (returns the wrapper) or methods of an object named by args (returns None)."""
            if type(target) is types.FunctionType:       # function
                func = target
                return self.trace_func(func)
            else:
                obj = target
                return self.trace_method(obj, *args)

        def fake(self, target, *args, **kwargs):
            """Fake a function with block args[0] (returns the wrapper) or methods of an object given by kwargs."""
            if type(target) is types.FunctionType:       # function
                func = target
                block = args and args[0] or None
                return self.fake_func(func, block)
            else:
                obj = target
                return self.fake_method(obj, **kwargs)

        def fake_obj(self, **kwargs):
            """return a FakeObject whose calls are recorded into this tracer."""
            obj = FakeObject(**kwargs)
            ## share the call list; the methods append to the name-mangled attribute
            obj._calls = obj._FakeObject__calls = self.calls
            return obj


    return locals()
2357
2358
## build 'oktest.tracer' module; passing 'util' makes _new_module() copy the names in its __all__ into 'util', too
tracer = _new_module('oktest.tracer', _dummy(), util)
del _dummy   ## builder function is no longer necessary
2361
2362
2363
2364##
2365## mainapp
2366##
2367import unittest
2368
def load_module(mod_name, filepath, content=None):
    """Create a new module object and execute source code in its namespace.

    mod_name
        value of __name__ of the module.
    filepath
        value of __file__ of the module.  When it is not empty, the code is
        compiled with it as file name; when *content* is None, the source is
        read from it.
    content
        source code; None means to read the file.

    The module is returned without being registered in sys.modules.
    """
    module = type(os)(mod_name)
    namespace = module.__dict__
    namespace["__name__"] = mod_name
    namespace["__file__"] = filepath
    #mod.__dict__["__file__"] = os.path.abspath(filepath)
    if content is None:
        if python2:
            content = util.read_binary_file(filepath)
        if python3:
            content = util.read_text_file(filepath)
    code = content
    if filepath:
        code = compile(content, filepath, "exec")
    exec(code, namespace, namespace)
    return module
2385
def rglob(dirpath, pattern, _entries=None):
    """Return paths under *dirpath* (recursively) whose base name matches *pattern*.

    *pattern* is a fnmatch-style pattern and is applied to every entry, files
    and directories alike.  Directories whose name starts with '.' are not
    descended into.  *_entries* is the list to which results are appended
    (used for recursion); it is also the return value.  When *dirpath* is not
    a directory, nothing is added.
    """
    import fnmatch
    if _entries is None:
        _entries = []
    if not os.path.isdir(dirpath):
        return _entries
    names = os.listdir(dirpath)
    ## matched entries of this directory come first, then sub directories are visited
    _entries.extend([ os.path.join(dirpath, name)
                          for name in fnmatch.filter(names, pattern) ])
    for name in names:
        if name.startswith('.'):
            continue
        subpath = os.path.join(dirpath, name)
        if os.path.isdir(subpath):
            rglob(subpath, pattern, _entries)
    return _entries
2401
2402
2403def _dummy():
2404
2405    global optparse
2406    import optparse
2407
2408    class MainApp(object):
2409
2410        debug = False
2411
2412        def __init__(self, command=None):
2413            self.command = command
2414
2415        def _new_cmdopt_parser(self):
2416            #import cmdopt
2417            #parser = cmdopt.Parser()
2418            #parser.opt("-h").name("help")                         .desc("show help")
2419            #parser.opt("-v").name("version")                      .desc("version of oktest.py")
2420            ##parser.opt("-s").name("testdir").arg("DIR[,DIR2,..]") .desc("test directory (default 'test' or 'tests')")
2421            #parser.opt("-p").name("pattern").arg("PAT[,PAT2,..]") .desc("test script pattern (default '*_test.py,test_*.py')")
2422            #parser.opt("-x").name("exclude").arg("PAT[,PAT2,..]") .desc("exclue file pattern")
2423            #parser.opt("-D").name("debug")                        .desc("debug mode")
2424            #return parser
2425            parser = optparse.OptionParser(conflict_handler="resolve")
2426            parser.add_option("-h", "--help",       action="store_true",     help="show help")
2427            parser.add_option("-v", "--version",    action="store_true",     help="verion of oktest.py")
2428            parser.add_option("-s", dest="style",   metavar="STYLE",         help="reporting style (plain/simple/verbose, or p/s/v)")
2429            parser.add_option(      "--color",      metavar="true|false",    help="enable/disable output color")
2430            parser.add_option("-K", dest="encoding", metavar="ENCODING",     help="output encoding (utf-8 when system default is US-ASCII)")
2431            parser.add_option("-p", dest="pattern", metavar="PAT[,PAT2,..]", help="test script pattern (default '*_test.py,test_*.py')")
2432            #parser.add_option("-x", dest="exclude", metavar="PAT[,PAT2,..]", help="exclue file pattern")
2433            parser.add_option("-U", dest="unittest", action="store_true",    help="run testcases with unittest.main instead of oktest.run")
2434            parser.add_option("-D", dest="debug",   action="store_true",     help="debug mode")
2435            parser.add_option("-f", dest="filter",  metavar="FILTER",        help="filter (class=xxx/test=xxx/useroption=xxx)")
2436            return parser
2437
2438        def _load_modules(self, filepaths, pattern=None):
2439            from fnmatch import fnmatch
2440            modules = []
2441            for fpath in filepaths:
2442                mod_name = os.path.basename(fpath).replace('.py', '')
2443                if pattern and not fnmatch(mod_name, pattern):
2444                    continue
2445                mod = load_module(mod_name, fpath)
2446                modules.append(mod)
2447            self._trace("modules: ", modules)
2448            return modules
2449
2450        def _load_classes(self, modules, pattern=None):
2451            from fnmatch import fnmatch
2452            testclasses = []
2453            unittest_testclasses = []
2454            oktest_testclasses   = []
2455            for mod in modules:
2456                for k in dir(mod):
2457                    #if k.startswith('_'): continue
2458                    v = getattr(mod, k)
2459                    if not isinstance(v, type): continue
2460                    klass = v
2461                    if pattern and not fnmatch(klass.__name__, pattern):
2462                        continue
2463                    if issubclass(klass, unittest.TestCase):
2464                        testclasses.append(klass)
2465                        unittest_testclasses.append(klass)
2466                    elif re.search(config.TARGET_PATTERN, klass.__name__):
2467                        testclasses.append(klass)
2468                        oktest_testclasses.append(klass)
2469            return testclasses, unittest_testclasses, oktest_testclasses
2470
2471        def _run_unittest(self, klasses, pattern=None, filters=None):
2472            self._trace("test_pattern: %r" % (pattern,))
2473            self._trace("unittest_testclasses: ", klasses)
2474            loader = unittest.TestLoader()
2475            the_suite = unittest.TestSuite()
2476            rexp = re.compile(r'^test(_|_\d\d\d(_|: ))?')
2477            if filters:
2478                key = list(filters.keys())[0]
2479                val = filters[key]
2480            else:
2481                key = val = None
2482            for klass in klasses:
2483                if pattern or filters:
2484                    testnames = loader.getTestCaseNames(klass)
2485                    testcases = [ klass(tname) for tname in testnames
2486                                      if _filtered(klass, None, tname, pattern, key, val) ]
2487                    suite = loader.suiteClass(testcases)
2488                else:
2489                    suite = loader.loadTestsFromTestCase(klass)
2490                the_suite.addTest(suite)
2491            #runner = unittest.TextTestRunner()
2492            runner = unittest.TextTestRunner(stream=sys.stderr)
2493            result = runner.run(the_suite)
2494            n_errors = len(result.errors) + len(result.failures)
2495            return n_errors
2496
2497        def _run_oktest(self, klasses, pattern, kwargs):
2498            self._trace("test_pattern: %r" % (pattern,))
2499            self._trace("oktest_testclasses: ", klasses)
2500            if pattern:
2501                kwargs.setdefault('filter', {})['test'] = pattern
2502            import oktest; run = oktest.run    # don't remove!
2503            n_errors = run(*klasses, **kwargs)
2504            return n_errors
2505
2506        def _trace(self, msg, items=None):
2507            write = sys.stderr.write
2508            if items is None:
2509                write("** DEBUG: %s\n" % msg)
2510            else:
2511                write("** DEBUG: %s[\n" % msg)
2512                for item in items:
2513                    write("**   %r,\n" % (item,))
2514                write("** ]\n")
2515
2516        def _help_message(self, parser):
2517            buf = []; add = buf.append
2518            add("Usage: python -m oktest [options] file_or_directory...\n")
2519            #add(parser.help_message(20))
2520            add(re.sub(r'^.*\n.*\n[oO]ptions:\n', '', parser.format_help()))
2521            add("Example:\n")
2522            add("   ## run test scripts in plain format\n")
2523            add("   $ python -m oktest -sp tests/*_test.py\n")
2524            add("   ## run test scripts in 'tests' dir with pattern '*_test.py'\n")
2525            add("   $ python -m oktest -p '*_test.py' tests\n")
2526            add("   ## filter by class name\n")
2527            add("   $ python -m oktest -f class='ClassName*' tests\n")
2528            add("   ## filter by test method name\n")
2529            add("   $ python -m oktest -f '*method*' tests   # or -f test='*method*'\n")
2530            add("   ## filter by user-defined option added by @test decorator\n")
2531            add("   $ python -m oktest -f tag='*value*' tests\n")
2532            return "".join(buf)
2533
2534        def _version_info(self):
2535            buf = []; add = buf.append
2536            add("oktest: " + __version__)
2537            add("python: " + sys.version.split("\n")[0])
2538            add("")
2539            return "\n".join(buf)
2540
2541        def _get_files(self, args, pattern):
2542            filepaths = []
2543            for arg in args:
2544                if os.path.isfile(arg):
2545                    filepaths.append(arg)
2546                elif os.path.isdir(arg):
2547                    files = self._find_files_recursively(arg, pattern)
2548                    filepaths.extend(files)
2549                else:
2550                    raise ValueError("%s: file or directory expected." % (arg,))
2551            return filepaths
2552
2553        def _find_files_recursively(self, testdir, pattern):
2554            isdir = os.path.isdir
2555            assert isdir(testdir)
2556            filepaths = []
2557            for pat in pattern.split(","):
2558                files = rglob(testdir, pat)
2559                if files:
2560                    filepaths.extend(files)
2561                    self._trace("testdir: %r, pattern: %r, files: " % (testdir, pat), files)
2562            return filepaths
2563
2564        #def _exclude_files(self, filepaths, pattern):
2565        #    from fnmatch import fnmatch
2566        #    _trace = self._trace
2567        #    basename = os.path.basename
2568        #    original = filepaths[:]
2569        #    for pat in pattern.split(","):
2570        #        filepaths = [ fpath for fpath in filepaths
2571        #                          if not fnmatch(basename(fpath), pat) ]
2572        #    _trace("excluded: %r" % (list(set(original) - set(filepaths)), ))
2573        #    return filepaths
2574
2575        def _get_filters(self, opts_filter):
2576            filters = {}
2577            if opts_filter:
2578                pair = opts_filter.split('=', 2)
2579                if len(pair) != 2:
2580                    pair = ('test', pair[0])
2581                filters[pair[0]] = pair[1]
2582            return filters
2583
2584        def _handle_opt_report(self, opt_report, parser):
2585            key = None
2586            d = {"p": "plain", "s": "simple", "v": "verbose"}
2587            key = d.get(opt_report, opt_report)
2588            self._trace("reporter: %s" % key)
2589            if not BaseReporter.get_registered_class(key):
2590                #raise optparse.OptionError("%r: unknown report sytle (plain/simple/verbose, or p/s/v)" % opt_report)
2591                parser.error("%r: unknown report sytle (plain/simple/verbose, or p/s/v)" % opt_report)
2592            return key
2593
2594        def _handle_opt_color(self, opt_color, parser):
2595            import oktest.config
2596            if   opt_color in ('true', 'yes', 'on'):
2597                oktest.config.color_enabled = True
2598            elif opt_color in ('false', 'no', 'off'):
2599                oktest.config.color_enabled = False
2600            else:
2601                #raise optparse.OptionError("--color=%r: 'true' or 'false' expected" % opt_color)
2602                parser.error("--color=%r: 'true' or 'false' expected" % opt_color)
2603            return oktest.config.color_enabled
2604
2605        def _get_output_writer(self, encoding):
2606            self._trace('output encoding: ' + encoding)
2607            if python2:
2608                import codecs
2609                return codecs.getwriter(encoding)(sys.stdout)
2610            if python3:
2611                import io
2612                return io.TextIOWrapper(sys.stdout.buffer, encoding=encoding)
2613
2614        def run(self, args=None, **kwargs):
2615            if args is None: args = sys.argv[1:]
2616            parser = self._new_cmdopt_parser()
2617            #opts = parser.parse(args)
2618            opts, args = parser.parse_args(args)
2619            if opts.debug:
2620                self.debug = True
2621                _trace = self._trace
2622                import oktest.config
2623                oktest.config.debug = True
2624            else:
2625                _trace = self._trace = lambda msg, items=None: None
2626            _trace("python: " + sys.version.split()[0])
2627            _trace("oktest: " + __version__)
2628            _trace("opts: %r" % (opts,))
2629            _trace("args: %r" % (args,))
2630            if opts.help:
2631                print(self._help_message(parser))
2632                return
2633            if opts.version:
2634                print(self._version_info())
2635                return
2636            #
2637            if opts.style:
2638                kwargs['style'] = self._handle_opt_report(opts.style, parser)
2639            if opts.color:
2640                kwargs['color'] = self._handle_opt_color(opts.color, parser)
2641            if 'out' not in kwargs:
2642                if opts.encoding:
2643                    kwargs['out'] = self._get_output_writer(opts.encoding)
2644                elif not hasattr(sys.stdout, 'encoding') or sys.stdout.encoding == 'US-ASCII':
2645                    kwargs['out'] = self._get_output_writer('utf-8')
2646            #
2647            pattern = opts.pattern or '*_test.py,test_*.py'
2648            filepaths = self._get_files(args, pattern)
2649            #if opts.exclude:
2650            #    filepaths = self._exclude_files(filepaths, opts.exclude)
2651            filters = self._get_filters(opts.filter)
2652            fval = lambda key, filters=filters: filters.pop(key, None)
2653            modules = self._load_modules(filepaths, fval('module'))
2654            tupl = self._load_classes(modules, fval('class'))
2655            testclasses, unittest_testclasses, oktest_testclasses = tupl
2656            kwargs['filter'] = filters
2657            if opts.unittest:
2658                n_errors = 0
2659                if unittest_testclasses:
2660                    n_errors += self._run_unittest(unittest_testclasses, fval('test'), filters)
2661                if oktest_testclasses:
2662                    n_errors += self._run_oktest(oktest_testclasses, fval('test'), kwargs)
2663            else:
2664                n_errors = self._run_oktest(testclasses, fval('test'), kwargs)
2665            return n_errors
2666
2667        @classmethod
2668        def main(cls, sys_argv=None):
2669            #import cmdopt
2670            if sys_argv is None: sys_argv = sys.argv
2671            #app = cls(sys_argv[0])
2672            #try:
2673            #    app.run(sys_argv[1:])
2674            #    sys.exit(0)
2675            #except cmdopt.ParseError:
2676            #    ex = sys.exc_info()[1]
2677            #    sys.stderr.write("%s" % (ex, ))
2678            #    sys.exit(1)
2679            app = cls(sys_argv[0])
2680            n_errors = app.run(sys_argv[1:])
2681            sys.exit(n_errors)
2682
2683    return locals()
2684
2685
# Create and register the 'oktest.mainapp' module from the namespace returned
# by _dummy(); names listed in its __all__ (if any) are also exported via util.
mainapp = _new_module('oktest.mainapp', _dummy(), util)
del _dummy    # the namespace-building function is not needed any more
2688
2689
def main(*args):
    """Run the command-line application from within a script.

    The argument vector handed to MainApp.main() is this file's path,
    followed by all of sys.argv, followed by the given extra arguments.
    """
    argv = [__file__]
    argv.extend(sys.argv)
    argv.extend(args)
    mainapp.MainApp.main(argv)
2693
2694
# When executed as a script, run the command-line application with sys.argv.
if __name__ == '__main__':
    mainapp.MainApp.main()
2697