1from __future__ import print_function
2from sqlalchemy.testing import fixtures
3from sqlalchemy.testing import config
4import doctest
5import logging
6import sys
7import re
8import os
9
10
class DocTest(fixtures.TestBase):
    """Run the ORM and Core tutorial ``.rst`` files under ``doc/build`` as
    doctests.

    While a test runs, messages logged by ``sqlalchemy.engine.base.Engine``
    are written to ``sys.stdout``. ``ddl.sort_tables_and_constraints`` is
    wrapped so the tables it receives are pre-sorted by ``key``.
    """

    def _setup_logger(self):
        """Attach a handler that writes engine log messages to stdout."""
        rootlogger = logging.getLogger('sqlalchemy.engine.base.Engine')

        class MyStream(object):
            # sys.stdout is looked up on every write rather than captured
            # once. Messages therefore go to whatever stream is current at
            # that moment; doctest substitutes its own sys.stdout while it
            # executes examples, so logged SQL becomes part of the example
            # output that gets checked.
            def write(self, string):
                sys.stdout.write(string)
                sys.stdout.flush()

            def flush(self):
                pass

        self._handler = handler = logging.StreamHandler(MyStream())
        handler.setFormatter(logging.Formatter('%(message)s'))
        rootlogger.addHandler(handler)

    def _teardown_logger(self):
        """Detach the handler installed by ``_setup_logger``."""
        rootlogger = logging.getLogger('sqlalchemy.engine.base.Engine')
        rootlogger.removeHandler(self._handler)

    def _setup_create_table_patcher(self):
        """Wrap ``ddl.sort_tables_and_constraints`` so its input is first
        sorted by table ``key``; the original is saved in ``self.orig_sort``.
        """
        from sqlalchemy.sql import ddl
        self.orig_sort = ddl.sort_tables_and_constraints

        def our_sort(tables, **kw):
            # Feeding the tables in a fixed (key) order makes the result
            # independent of the caller's iteration order -- presumably so
            # emitted DDL matches the order shown in the documents.
            return self.orig_sort(
                sorted(tables, key=lambda t: t.key), **kw
            )
        ddl.sort_tables_and_constraints = our_sort

    def _teardown_create_table_patcher(self):
        """Restore the original ``ddl.sort_tables_and_constraints``."""
        from sqlalchemy.sql import ddl
        ddl.sort_tables_and_constraints = self.orig_sort

    def setup(self):
        self._setup_logger()
        self._setup_create_table_patcher()

    def teardown(self):
        # Undo in reverse order of setup. The finally guarantees the log
        # handler is detached even if restoring the DDL patch raises;
        # otherwise it would stay on the engine logger for later tests.
        try:
            self._teardown_create_table_patcher()
        finally:
            self._teardown_logger()

    def _run_doctest_for_content(self, name, content):
        """Parse ``content`` as a doctest called ``name`` and run it.

        Examples are compared with ELLIPSIS, NORMALIZE_WHITESPACE,
        IGNORE_EXCEPTION_DETAIL and ALLOW_UNICODE enabled, using the
        u''-prefix-tolerant checker. A summary is printed, and the test
        fails if any example failed.
        """
        optionflags = (
            doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE |
            doctest.IGNORE_EXCEPTION_DETAIL |
            _get_allow_unicode_flag()
        )
        runner = doctest.DocTestRunner(
            verbose=None, optionflags=optionflags,
            checker=_get_unicode_checker())
        globs = {
            'print_function': print_function}
        parser = doctest.DocTestParser()
        test = parser.get_doctest(content, globs, name, name, 0)
        runner.run(test)
        runner.summarize()
        assert not runner.failures

    def _run_doctest(self, fname):
        """Run ``doc/build/<fname>`` as a doctest.

        The path is resolved relative to the directory two levels above
        this file. The test is skipped if the file does not exist.
        """
        here = os.path.dirname(__file__)
        sqla_base = os.path.normpath(os.path.join(here, "..", ".."))
        path = os.path.join(sqla_base, "doc/build", fname)
        if not os.path.exists(path):
            config.skip_test("Can't find documentation file %r" % path)
        # Read everything and close the file before running the doctest,
        # instead of holding it open for the entire run.
        # NOTE(review): open() uses the locale's default encoding on
        # Python 3; confirm the .rst sources decode under that encoding.
        with open(path) as file_:
            content = file_.read()
        # Remove the {stop}, {sql} and {opensql} markers embedded in the
        # .rst sources so they never reach the doctest parser.
        content = re.sub(r'\{(?:stop|sql|opensql)\}', '', content)
        self._run_doctest_for_content(fname, content)

    def test_orm(self):
        self._run_doctest("orm/tutorial.rst")

    def test_core(self):
        self._run_doctest("core/tutorial.rst")
86
87
88# unicode checker courtesy py.test
89
90
def _get_unicode_checker():
    """
    Return a new instance of a doctest.OutputChecker subclass that, when the
    ALLOW_UNICODE option flag is active, also accepts output that differs
    from the expectation only by u''/U'' string-literal prefixes. This lets
    one doctest pass under both Python 2 and Python 3.

    The subclass is built on the first call and cached on this function's
    ``UnicodeOutputChecker`` attribute; later calls reuse that class and
    return a fresh instance of it.
    """
    if not hasattr(_get_unicode_checker, 'UnicodeOutputChecker'):
        import doctest
        import re

        class UnicodeOutputChecker(doctest.OutputChecker):
            """
            Copied from doctest_nose_plugin.py from the nltk project:
                https://github.com/nltk/nltk
            """

            # Captures the character before the prefix (or start of text)
            # and the optional raw marker plus opening quote, so the
            # substitution can drop just the u/U.
            _literal_re = re.compile(r"(\W|^)[uU]([rR]?[\'\"])", re.UNICODE)

            def check_output(self, want, got, optionflags):
                compare = doctest.OutputChecker.check_output
                if compare(self, want, got, optionflags):
                    return True

                if optionflags & _get_allow_unicode_flag():  # pragma: no cover
                    # Only Python 2 reaches this in our tests, and coverage
                    # is measured on Python 3 only.
                    def strip_u(text):
                        return re.sub(self._literal_re, r'\1\2', text)

                    return compare(self, strip_u(want), strip_u(got),
                                   optionflags)

                return False

        _get_unicode_checker.UnicodeOutputChecker = UnicodeOutputChecker

    return _get_unicode_checker.UnicodeOutputChecker()
137
138
def _get_allow_unicode_flag():
    """
    Return the doctest option-flag bit named ``ALLOW_UNICODE``, registering
    the name with doctest if needed.

    This file calls it more than once and combines or tests the results
    against each other. That works because ``register_optionflag`` returns
    the bit already assigned to a name that was registered before.
    """
    import doctest as _doctest
    flag_name = 'ALLOW_UNICODE'
    return _doctest.register_optionflag(flag_name)
145