1# -*- coding: utf-8 -*-
2from __future__ import print_function
3import re
4import sys
5import os
6import codecs
7import doctest
8from nose.util import tolist, anyp
9from nose.plugins.base import Plugin
10from nose.suite import ContextList
11from nose.plugins.doctests import Doctest, log, DocFileCase
12
# Doctest option flag; once registered it can be enabled per example with a
# "# doctest: +ALLOW_UNICODE" directive or by name in the doctest options.
ALLOW_UNICODE = doctest.register_optionflag("ALLOW_UNICODE")


class _UnicodeOutputChecker(doctest.OutputChecker):
    """Output checker that can ignore ``u``/``U`` string-literal prefixes.

    The normal doctest comparison is tried first. Only if it fails and the
    ``ALLOW_UNICODE`` flag is active is the comparison repeated with the
    ``u`` prefix stripped from literals in both expected and actual output.
    """

    # Matches a u/U (optionally followed by r/R) that opens a quoted literal.
    # Group 1 captures the preceding non-word character (or the start of the
    # text), so identifiers that merely end in "u" are not touched.
    _literal_re = re.compile(r"(\W|^)[uU]([rR]?[\'\"])", re.UNICODE)

    def _remove_u_prefixes(self, txt):
        """Return *txt* with the ``u``/``U`` prefix dropped from literals."""
        return self._literal_re.sub(r'\1\2', txt)

    def check_output(self, want, got, optionflags):
        """Return True when *got* matches *want* under *optionflags*.

        With ``ALLOW_UNICODE`` set, outputs that differ only in ``u``
        prefixes on string literals are also considered a match.
        """
        base_check = doctest.OutputChecker.check_output
        if base_check(self, want, got, optionflags):
            return True
        if optionflags & ALLOW_UNICODE:
            # Plain comparison failed; retry with the prefixes removed.
            return base_check(
                self,
                self._remove_u_prefixes(want),
                self._remove_u_prefixes(got),
                optionflags,
            )
        return False


# Single shared checker instance; test cases are pointed at it when patched.
_checker = _UnicodeOutputChecker()
39
40
class DoctestPluginHelper(object):
    """
    Mixin for nose's ``Doctest`` plugin that adjusts how doctests are run.

    Every loaded test case gets:

        the ``print_function`` future import enabled, and

        an output checker that honours the 'ALLOW_UNICODE' option
        (e.g. '# doctest: +ALLOW_UNICODE'), which makes DocTestCase
        think u'foo' == 'foo'.

    Besides '+FLAG' / '-FLAG' entries, the plugin's doctest options
    (``options.doctestOptions``) may contain 'doctestencoding=<codec>'
    (e.g. 'doctestencoding=utf-8'), which sets the encoding used to read
    doctest files.
    """

    # Keys accepted as ``key=value`` entries in the doctest options.
    OPTION_BY_NAME = ('doctestencoding',)

    def loadTestsFromFileUnicode(self, filename):
        """Yield doctest cases parsed from the text file *filename*.

        Nothing is yielded unless *filename* ends with one of the configured
        doctest extensions. The file is decoded with the ``doctestencoding``
        option when one was configured. If the file has no examples,
        ``False`` is yielded; otherwise a single ``DocFileCase`` is yielded,
        wrapped in a ``ContextList`` when a fixture module was imported.
        """
        if not (self.extension and anyp(filename.endswith, self.extension)):
            return

        name = os.path.basename(filename)
        # None unless a 'doctestencoding' option was configured; the getattr
        # default covers an instance on which configure() has not run.
        encoding = getattr(self, 'doctest_options', {}).get('doctestencoding')
        with codecs.open(filename, 'r', encoding) as dh:
            doc = dh.read()

        fixture_context = None
        globs = {'__file__': filename}
        if self.fixtures:
            fixture_context = self._import_fixture_module(filename)
            if hasattr(fixture_context, 'globs'):
                # The fixture module may extend or replace the globals.
                globs = fixture_context.globs(globs)

        parser = doctest.DocTestParser()
        test = parser.get_doctest(
            doc, globs=globs, name=name, filename=filename, lineno=0
        )
        if not test.examples:
            yield False  # no tests to load
            return

        case = DocFileCase(
            test,
            optionflags=self.optionflags,
            setUp=getattr(fixture_context, 'setup_test', None),
            tearDown=getattr(fixture_context, 'teardown_test', None),
            result_var=self.doctest_result_var,
        )
        if fixture_context:
            yield ContextList((case,), context=fixture_context)
        else:
            yield case

    def _import_fixture_module(self, filename):
        """Import the fixture module for *filename*; return it or None.

        The module name is the file's base name plus ``self.fixtures``, and
        the file's directory is added to ``sys.path`` so it can be found.
        """
        base = os.path.splitext(os.path.basename(filename))[0]
        dirname = os.path.dirname(filename)
        # Append only once: repeated entries would just grow sys.path
        # without changing how imports resolve.
        if dirname not in sys.path:
            sys.path.append(dirname)
        fixt_mod = base + self.fixtures
        fixture_context = None
        try:
            fixture_context = __import__(fixt_mod, globals(), locals(), ["nop"])
        except ImportError as e:
            log.debug("Could not import %s: %s (%s)", fixt_mod, e, sys.path)
        log.debug("Fixture module %s resolved to %s", fixt_mod, fixture_context)
        return fixture_context

    def loadTestsFromFile(self, filename):
        """Load doctests from *filename* and patch every resulting case.

        Cases grouped in a ``ContextList`` are patched one by one and
        regrouped under the same context.
        """
        for case in self.loadTestsFromFileUnicode(filename):
            if isinstance(case, ContextList):
                yield ContextList(
                    [self._patchTestCase(c) for c in case], case.context
                )
            else:
                yield self._patchTestCase(case)

    def loadTestsFromModule(self, module):
        """Load doctests from the module.

        Each suite produced by the next class in the MRO (nose's ``Doctest``
        for ``DoctestFix``) is rebuilt from its patched cases.
        """
        for suite in super(DoctestPluginHelper, self).loadTestsFromModule(module):
            # NOTE: relies on the suite's private _get_tests() to reach the
            # individual cases; _patchTestCase needs their ``_dt_test``.
            cases = [self._patchTestCase(case) for case in suite._get_tests()]
            yield self.suiteClass(cases, context=module, can_split=False)

    def _patchTestCase(self, case):
        """Enable ``print_function`` and install the unicode-aware checker.

        A falsy *case* (``loadTestsFromFileUnicode`` yields ``False`` for
        files without examples) is returned unchanged.
        """
        if case:
            # doctest derives default compiler flags from __future__ feature
            # objects present in the test globals, so binding the feature
            # here turns the future import on for every example.
            case._dt_test.globs['print_function'] = print_function
            case._dt_checker = _checker
        return case

    def configure(self, options, config):
        """Configure the plugin from nose's parsed command-line *options*.

        Overridden, rather than delegating to the base ``configure``, so that
        the doctest options may hold ``key=value`` entries (see
        ``OPTION_BY_NAME``) in addition to '+FLAG' / '-FLAG' entries.
        Entries are comma separated; surrounding whitespace and empty
        entries are ignored.

        Raises:
            ValueError: for an unknown flag or key, or for an entry that is
                neither a flag nor a key/value pair.
        """
        Plugin.configure(self, options, config)
        self.doctest_result_var = options.doctest_result_var
        self.doctest_tests = options.doctest_tests
        self.extension = tolist(options.doctestExtension)
        self.fixtures = options.doctestFixtures
        self.finder = doctest.DocTestFinder()

        self.optionflags = 0
        # Stored under a dedicated name: ``self.options`` is the inherited
        # nose Plugin.options(parser, env) hook, and shadowing it with a dict
        # breaks any later call that registers command-line options.
        self.doctest_options = {}

        if not options.doctestOptions:
            return
        for stroption in ",".join(options.doctestOptions).split(','):
            stroption = stroption.strip()
            if stroption:
                self._apply_doctest_option(stroption)

    def _apply_doctest_option(self, stroption):
        """Apply one '+FLAG', '-FLAG' or 'key=value' doctest option."""
        if stroption.startswith(('+', '-')):
            try:
                flag = doctest.OPTIONFLAGS_BY_NAME[stroption[1:]]
            except KeyError:
                raise ValueError("Unknown doctest option {}".format(stroption))
            if stroption[0] == '+':
                self.optionflags |= flag
            else:
                self.optionflags &= ~flag
            return

        key, sep, value = stroption.partition('=')
        if not sep:
            raise ValueError(
                "Doctest option is not a flag or a key/value pair: {} ".format(
                    stroption
                )
            )
        if key not in self.OPTION_BY_NAME:
            raise ValueError("Unknown doctest option {}".format(stroption))
        self.doctest_options[key] = value
161
162
class DoctestFix(DoctestPluginHelper, Doctest):
    """nose ``Doctest`` plugin combined with ``DoctestPluginHelper``.

    ``DoctestPluginHelper`` is listed first among the bases, so its methods
    take precedence over the same-named ``Doctest`` methods in the MRO.
    """
165