1import sys
2import os
3import unittest
4from unittest import TestResult, TestLoader
5import time
6
7
class TagTestLoader(TestLoader):
    """A TestLoader which handles additional __tags__ attributes for
    test functions.

    Test case classes and individual test methods may carry a ``__tags__``
    iterable. A class or method with at least one tag contained in
    ``excludetags`` is left out when test method names are collected.
    """
    def __init__(self, excludetags, randomizer=None):
        """Create the loader.

        :param excludetags: container of tags; classes/methods tagged with
            any of them are not loaded.
        :param randomizer: optional object with a ``shuffle(list)`` method
            (for example a ``random.Random`` instance). When set, test
            method names are shuffled instead of sorted.
        """
        TestLoader.__init__(self)
        self.excludetags = excludetags
        self.randomizer = randomizer

    def _is_excluded(self, obj):
        """Return True if *obj* carries a ``__tags__`` entry that is
        contained in ``self.excludetags``."""
        tags = getattr(obj, "__tags__", None)
        if not tags:
            return False
        excluded = self.excludetags or ()
        return any(tag in excluded for tag in tags)

    def getTestCaseNames(self, testCaseClass):
        """Gets only the tests, which are not within the tag exclusion.

        The method overrides the original TestLoader.getTestCaseNames()
        method, so we need to keep them in sync on updates. Like the base
        implementation it honours ``testMethodPrefix``,
        ``sortTestMethodsUsing`` and, where the running Python provides it,
        ``testNamePatterns``.

        :param testCaseClass: class whose test method names are collected.
        :return: list of method names; empty if the class itself carries an
            excluded tag. Shuffled with ``self.randomizer`` when one is set,
            otherwise ordered by ``sortTestMethodsUsing`` (if set).
        """
        import fnmatch
        import functools

        if self._is_excluded(testCaseClass):
            return []

        prefix = self.testMethodPrefix
        # testNamePatterns (the ``-k`` option) only exists on newer loaders.
        patterns = getattr(self, "testNamePatterns", None)
        qualname = getattr(testCaseClass, "__qualname__",
                           testCaseClass.__name__)

        def isTestMethod(attrname):
            if not attrname.startswith(prefix):
                return False
            testfunc = getattr(testCaseClass, attrname)
            if not callable(testfunc):
                return False
            if self._is_excluded(testfunc):
                return False
            if patterns is None:
                return True
            # Same full-name format the stdlib loader matches patterns on.
            fullname = "%s.%s.%s" % (testCaseClass.__module__, qualname,
                                     attrname)
            return any(fnmatch.fnmatchcase(fullname, pattern)
                       for pattern in patterns)

        testFnNames = [name for name in dir(testCaseClass)
                       if isTestMethod(name)]

        if self.randomizer:
            self.randomizer.shuffle(testFnNames)
        elif self.sortTestMethodsUsing:
            # functools.cmp_to_key is the supported helper; the unittest
            # private names only existed on very old Pythons. Without it a
            # custom comparator would silently be ignored on Python 3.
            cmpkey = getattr(functools, "cmp_to_key", None) or \
                getattr(unittest, "_CmpToKey", None) or \
                getattr(unittest, "CmpToKey", None)
            if cmpkey:
                testFnNames.sort(key=cmpkey(self.sortTestMethodsUsing))
            else:
                testFnNames.sort()
        return testFnNames
54
55
class SimpleTestResult(TestResult):
    """A simple TestResult class with output capabilities.

    Each outcome hook overridden below records the outcome via the base
    TestResult, then (when ``verbose`` is true) writes a one-line status
    message to ``stream`` and (when given) invokes the ``countcall``
    callback, e.g. to drive a progress counter.
    """
    def __init__(self, stream=None, verbose=False, countcall=None):
        """Create the result collector.

        :param stream: writable text stream for status lines; when None,
            the ``sys.stderr`` current at construction time is used.
        :param verbose: write a status line per reported outcome if true.
        :param countcall: optional zero-argument callable invoked once per
            reported outcome.
        """
        TestResult.__init__(self)
        self.stream = sys.stderr if stream is None else stream
        # Elapsed run time in seconds; SimpleTestRunner.run() stores it here.
        self.duration = 0
        self.verbose = verbose
        self.countcall = countcall

    def _report(self, label, test, detail=""):
        """Write ``label + test + detail`` (if verbose) and notify
        ``countcall`` (if set)."""
        if self.verbose:
            # "\n", not os.linesep: text streams already translate "\n" to
            # the platform line ending, so os.linesep would give "\r\r\n"
            # on Windows.
            self.stream.write("%s%s%s\n" % (label, test, detail))
            self.stream.flush()
        if self.countcall is not None:
            self.countcall()

    def addSkip(self, test, reason):
        TestResult.addSkip(self, test, reason)
        self._report("SKIPPED: ", test, " [%s]" % (reason,))

    def addSuccess(self, test):
        TestResult.addSuccess(self, test)
        self._report("OK:      ", test)

    def addError(self, test, err):
        TestResult.addError(self, test, err)
        self._report("ERROR:   ", test)

    def addFailure(self, test, err):
        TestResult.addFailure(self, test, err)
        self._report("FAILED:  ", test)

    def addExpectedFailure(self, test, err):
        # Overridden so expected failures are counted/reported like the
        # other outcomes instead of being silently skipped by countcall.
        TestResult.addExpectedFailure(self, test, err)
        self._report("XFAIL:   ", test)

    def addUnexpectedSuccess(self, test):
        TestResult.addUnexpectedSuccess(self, test)
        self._report("XPASS:   ", test)
94
95
class SimpleTestRunner(object):
    """Run a test or suite, collecting outcomes in a SimpleTestResult."""

    def __init__(self, stream=None, verbose=False):
        """Create the runner.

        :param stream: writable text stream handed to the result object;
            when None, the ``sys.stderr`` current at construction is used.
        :param verbose: forwarded to the result object.
        """
        self.stream = sys.stderr if stream is None else stream
        self.verbose = verbose

    def run(self, test, countcall):
        """Run *test* and return the populated result.

        :param test: a TestCase/TestSuite, or any callable accepting a
            result object.
        :param countcall: callback forwarded to SimpleTestResult.
        :return: the SimpleTestResult, with ``duration`` set to the elapsed
            run time in seconds.
        """
        result = SimpleTestResult(self.stream, self.verbose, countcall)
        # Use a monotonic clock for intervals: time.time() follows the
        # system clock and can jump, giving wrong or negative durations.
        # Fall back to time.time() where perf_counter is unavailable.
        timer = getattr(time, "perf_counter", time.time)
        starttime = timer()
        test(result)
        result.duration = timer() - starttime
        return result
109