1import os
2import traceback, sys
3from unittest import TestResult
4import datetime
5
6from tcmessages import TeamcityServiceMessages
7
# Major version number of the running interpreter (2 or 3).
PYTHON_VERSION_MAJOR = tuple(sys.version_info)[0]
9
10
def strclass(cls):
  """
  Return the dotted ``module.Name`` path of ``cls``.

  When ``cls.__name__`` is empty, only ``cls.__module__`` is returned.
  """
  name = cls.__name__
  if name:
    return "%s.%s" % (cls.__module__, name)
  return cls.__module__
15
16
def smart_str(s):
  """
  Return ``s`` converted to the interpreter's native ``str`` type.

  * A native ``str`` is returned unchanged.
  * On Python 2, a ``unicode`` object is UTF-8 encoded to ``str``.
  * Anything else goes through ``str()``. If that raises UnicodeEncodeError,
    an Exception is rendered as its space-joined ``args`` (each converted with
    smart_str). Any other object falls back to UTF-8 encoding its text form.
  """
  encoding = 'utf-8'
  errors = 'strict'
  # ``unicode`` on Python 2, ``str`` on Python 3. This avoids referencing the
  # Python-2-only ``unicode``/``basestring`` builtins, which raise NameError
  # on Python 3.
  text_type = type(u'')
  if isinstance(s, str):
    return s
  if isinstance(s, text_type):
    # Only reachable on Python 2, where text_type is ``unicode``.
    return s.encode(encoding, errors)
  try:
    return str(s)
  except UnicodeEncodeError:
    if isinstance(s, Exception):
      # An Exception subclass containing non-ASCII data that doesn't
      # know how to print itself properly. We shouldn't raise a
      # further exception. Iterate ``args``: exception objects are not
      # iterable on Python 3.
      return ' '.join([smart_str(arg) for arg in s.args])
    return text_type(s).encode(encoding, errors)
38
39
class TeamcityTestResult(TestResult):
  """
  unittest ``TestResult`` that reports test progress as TeamCity service messages.

  Set the ``_jb_do_not_call_enter_matrix`` environment variable to prevent it
  from running "enter matrix" (``testMatrixEntered``) on construction.
  """

  def __init__(self, stream=sys.stdout, *args, **kwargs):
    """
    :param stream: file-like object the service messages are written to.
    :param args: accepted for signature compatibility; ignored.
    :param kwargs: each key/value pair is set as an attribute on the result.
    """
    TestResult.__init__(self)
    for arg, value in kwargs.items():
      setattr(self, arg, value)
    self.output = stream
    self.messages = TeamcityServiceMessages(self.output, prepend_linebreak=True)
    if "_jb_do_not_call_enter_matrix" not in os.environ:
      self.messages.testMatrixEntered()
    # True once the current test was already reported (failure, error or
    # skip), so stopTest() must not report it again as passed.
    self.current_failed = False
    # Name of the currently open TeamCity suite (see init_suite).
    self.current_suite = None
    # Name of the open suite that groups the current test's subtests.
    self.subtest_suite = None

  def find_first(self, val):
    """
    Return the leading quoted literal of ``val``, e.g. ``'a'`` from ``'a' != 'b'``.

    The first character is taken as the quote character. Quote characters
    preceded by a backslash are skipped. Returns "" for empty input.
    """
    if not val:
      return ""
    quot = val[0]
    count = 1
    quote_ind = val[count:].find(quot)
    # Skip closing-quote candidates that are escaped with a backslash.
    while quote_ind != -1 and val[count + quote_ind - 1] == "\\":
      count = count + quote_ind + 1
      quote_ind = val[count:].find(quot)

    return val[0:quote_ind + count + 1]

  def find_second(self, val):
    """
    Return the trailing quoted literal of an assertion message.

    If ``val`` contains ``!=``, the literal that follows it is returned (same
    parsing as find_first). Otherwise the literal that ends at the last
    character of ``val`` is returned. Returns "" for empty input.
    """
    if not val:
      return ""
    val_index = val.find("!=")
    if val_index != -1:
      return self.find_first(val[val_index + 2:].strip())

    quot = val[-1]
    quote_ind = val[:len(val) - 1].rfind(quot)
    # Walk backwards past opening-quote candidates escaped with a backslash.
    while quote_ind != -1 and val[quote_ind - 1] == "\\":
      quote_ind = val[:quote_ind - 1].rfind(quot)
    return val[quote_ind:]

  def _is_quoted(self, text):
    """Return True if ``text`` has at least 2 chars wrapped in matching ' or " quotes."""
    return len(text) >= 2 and text[0] == text[-1] and text[0] in ("'", '"')

  def formatErr(self, err):
    """Format an ``(exctype, value, tb)`` triple as a traceback string."""
    exctype, value, tb = err
    return ''.join(traceback.format_exception(exctype, value, tb))

  def getTestName(self, test, is_subtest=False):
    """
    Return the display name reported for ``test``.

    * Subtest (``is_subtest=True``): the parent test's name followed by the
      subtest description.
    * Test with ``_testMethodName``: that method name, or ``str(test)`` when
      the method is ``runTest``.
    * Anything else: ``str(test)`` truncated at the first space.
    """
    if is_subtest:
      test_name = self.getTestName(test.test_case)
      return "{} {}".format(test_name, test._subDescription())
    if hasattr(test, '_testMethodName'):
      if test._testMethodName == "runTest":
        return str(test)
      return test._testMethodName
    test_name = str(test)
    # find() rather than index(): a name without a space is returned whole
    # instead of raising ValueError.
    whitespace_index = test_name.find(" ")
    if whitespace_index != -1:
      test_name = test_name[:whitespace_index]
    return test_name

  def getTestId(self, test):
    # NOTE(review): returns the ``id`` attribute itself. For a TestCase this is
    # the bound ``id`` method, not its string result. Kept as-is so existing
    # callers are unaffected.
    return test.id

  def addSuccess(self, test):
    """Record a success. The passed test is reported to TeamCity in stopTest()."""
    TestResult.addSuccess(self, test)

  def addError(self, test, err):
    """Record an error and report it as testStarted followed by testError."""
    location = self.init_suite(test)
    self.current_failed = True
    TestResult.addError(self, test, err)

    details = self._exc_info_to_string(err, test)
    test_name = self.getTestName(test)
    self.messages.testStarted(test_name, location=location)
    self.messages.testError(test_name,
                            message='Error', details=details, duration=self.__getDuration(test))

  def find_error_value(self, err):
    """
    Return the text after the last 'assert' on the innermost frame's source line.

    ``err`` is a traceback object. If the line has no 'assert', the whole
    stripped line is returned. Returns "" when the traceback is empty or the
    source line is unavailable.
    """
    frames = traceback.extract_tb(err)
    if not frames:
      return ""
    line = frames[-1][-1]
    if not line:
      return ""
    return line.split('assert')[-1].strip()

  def addFailure(self, test, err):
    """
    Record a failure and report it as testStarted followed by testFailed.

    When the failure message looks like ``'expected' != 'actual'`` (both sides
    quoted), the two literals are unescaped and sent as ``expected``/``actual``
    so the IDE can show a diff.
    """
    location = self.init_suite(test)
    self.current_failed = True
    TestResult.addFailure(self, test, err)

    error_value = smart_str(err[1])
    if not error_value:
      # means it's test function and we have to extract value from traceback
      error_value = self.find_error_value(err[2])

    first = second = ""
    if error_value:
      expected_literal = self.find_first(error_value)
      actual_literal = self.find_second(error_value)
      if self._is_quoted(expected_literal) and self._is_quoted(actual_literal):
        # let's unescape strings to show sexy multiline diff in PyCharm.
        # By default all caret return chars are escaped by testing framework
        first = self._unescape(expected_literal)
        second = self._unescape(actual_literal)
    details = self._exc_info_to_string(err, test)

    test_name = self.getTestName(test)
    self.messages.testStarted(test_name, location=location)
    duration = self.__getDuration(test)
    self.messages.testFailed(test_name,
                             message='Failure', details=details, expected=first, actual=second, duration=duration)

  def addSkip(self, test, reason):
    """Record a skip and report the test as ignored with ``reason`` as the message."""
    self.init_suite(test)
    # Mark as reported so stopTest() does not also report it as passed.
    self.current_failed = True
    TestResult.addSkip(self, test, reason)
    self.messages.testIgnored(self.getTestName(test), message=reason)

  def _getSuite(self, test):
    """
    Return ``(suite_name, test_location, suite_location)`` for ``test``.

    Tests exposing ``suite`` (with ``location``/``abs_location``) and a
    ``lineno`` (on the test or on ``test.test``) use those attributes. On any
    AttributeError, ``python_uttestid://`` locations are built from the test
    class's source directory, its dotted class path and ``test.id()``.
    """
    try:
      suite = strclass(test.suite)
      suite_location = test.suite.location
      location = test.suite.abs_location
      if hasattr(test, "lineno"):
        location = location + ":" + str(test.lineno)
      else:
        location = location + ":" + str(test.test.lineno)
    except AttributeError:
      import inspect

      try:
        source_file = inspect.getsourcefile(test.__class__)
        if source_file:
          # NOTE(review): splits on "/" only. A path using "\\" separators
          # yields source_dir == "/". Confirm what the IDE expects here.
          source_dir_splitted = source_file.split("/")[:-1]
          source_dir = "/".join(source_dir_splitted) + "/"
        else:
          source_dir = ""
      except TypeError:
        # inspect raises TypeError for objects it cannot locate (e.g. built-ins).
        source_dir = ""

      suite = strclass(test.__class__)
      suite_location = "python_uttestid://" + source_dir + suite
      location = "python_uttestid://" + source_dir + str(test.id())

    return (suite, location, suite_location)

  def startTest(self, test):
    """Count the test, reset per-test state and stamp ``test.startTime``."""
    # The base class increments testsRun and sets up output buffering.
    TestResult.startTest(self, test)
    self.current_failed = False
    setattr(test, "startTime", datetime.datetime.now())

  def init_suite(self, test):
    """
    Make sure the TeamCity suite for ``test`` is open and return the test's location.

    If the suite differs from the currently open one, the old suite is
    finished and the new one is started.
    """
    suite, location, suite_location = self._getSuite(test)
    if suite != self.current_suite:
      if self.current_suite:
        self.messages.testSuiteFinished(self.current_suite)
      self.current_suite = suite
      self.messages.testSuiteStarted(self.current_suite, location=suite_location)
    return location

  def stopTest(self, test):
    """
    Finish reporting ``test``.

    A test that had subtests closes its subtest suite. Otherwise a test not
    already reported (failure/error/skip) is reported as passed.
    """
    duration = self.__getDuration(test)
    if not self.subtest_suite:
      if not self.current_failed:
        location = self.init_suite(test)
        test_name = self.getTestName(test)
        self.messages.testStarted(test_name, location=location)
        self.messages.testFinished(test_name, duration=int(duration))
    else:
      self.messages.testSuiteFinished(self.subtest_suite)
      self.subtest_suite = None
    # The base class restores any buffered stdout/stderr.
    TestResult.stopTest(self, test)

  def __getDuration(self, test):
    """Return whole milliseconds elapsed since ``test.startTime`` (about 0 if unset)."""
    start = getattr(test, "startTime", datetime.datetime.now())
    assert isinstance(start, datetime.datetime), \
      "You testcase has property named 'startTime' (value {0}). Please, rename it".format(start)
    d = datetime.datetime.now() - start
    # Floor division keeps the result an int on Python 3 as on Python 2.
    return d.microseconds // 1000 + d.seconds * 1000 + d.days * 86400000

  def addSubTest(self, test, subtest, err):
    """
    Report one subtest of ``test``.

    All subtests of a test are grouped in a TeamCity suite named after the
    parent test. The suite is closed in stopTest(), or here when a different
    parent test starts reporting subtests.
    """
    location = self.init_suite(test)
    suite_name = self.getTestName(test)
    if not self.subtest_suite:
      self.subtest_suite = suite_name
      self.messages.testSuiteStarted(self.subtest_suite, location=location)
    else:
      if suite_name != self.subtest_suite:
        self.messages.testSuiteFinished(self.subtest_suite)
        self.subtest_suite = suite_name
        self.messages.testSuiteStarted(self.subtest_suite, location=location)

    name = self.getTestName(subtest, True)
    if err is not None:
      error = self._exc_info_to_string(err, test)
      self.messages.testStarted(name)
      self.messages.testFailed(name, message='Failure', details=error, duration=None)
    else:
      self.messages.testStarted(name)
      self.messages.testFinished(name)
    # Let the base class record failing subtests in failures/errors so that
    # wasSuccessful() reflects them.
    TestResult.addSubTest(self, test, subtest, err)

  def endLastSuite(self):
    """Finish the currently open TeamCity suite, if any."""
    if self.current_suite:
      self.messages.testSuiteFinished(self.current_suite)
      self.current_suite = None

  def _unescape(self, text):
    """Turn escaped ``\\n`` sequences in ``text`` back into real newlines."""
    # do not use text.decode('string_escape'), it leads to problems with different string encodings given
    return text.replace("\\n", "\n")
250
251
class TeamcityTestRunner(object):
  """Run a unittest test/suite, reporting results as TeamCity service messages."""

  def __init__(self, stream=sys.stdout):
    # Output stream handed to every result object this runner creates.
    # The default is the sys.stdout object at the time this module is loaded.
    self.stream = stream

  def _makeResult(self, **kwargs):
    """Create a new TeamcityTestResult writing to this runner's stream."""
    output = self.stream
    return TeamcityTestResult(output, **kwargs)

  def run(self, test, **kwargs):
    """
    Run ``test`` and return the populated result.

    The total test count is announced first. After the run, the last open
    suite is closed. ``kwargs`` are set as attributes on the result object.
    """
    result = self._makeResult(**kwargs)
    total_cases = test.countTestCases()
    result.messages.testCount(total_cases)
    test(result)
    result.endLastSuite()
    return result
265