1from __future__ import print_function
2
3import doctest
4import logging
5import os
6import re
7import sys
8
9from sqlalchemy import testing
10from sqlalchemy.testing import config
11from sqlalchemy.testing import fixtures
12
13
class DocTest(fixtures.TestBase):
    """Run the ``.rst`` documentation under ``doc/build`` through doctest.

    Each test method executes the examples of one or more documentation
    files and fails if any example's output does not match the output
    written in the file.
    """

    __requires__ = ("python3",)

    def _setup_logger(self):
        """Echo ``sqlalchemy.engine.Engine`` log records to ``sys.stdout``.

        The stream below resolves ``sys.stdout`` on every write instead of
        binding it once. Records therefore go to whichever stream is
        current at that moment, including the stand-in doctest installs
        while running an example. This lets logged output be compared
        against the documented output.
        """
        rootlogger = logging.getLogger("sqlalchemy.engine.Engine")

        class MyStream(object):
            def write(self, string):
                # looked up at call time on purpose; see docstring
                sys.stdout.write(string)
                sys.stdout.flush()

            def flush(self):
                pass

        self._handler = handler = logging.StreamHandler(MyStream())
        # emit only the bare message, with no level / logger-name prefix
        handler.setFormatter(logging.Formatter("%(message)s"))
        rootlogger.addHandler(handler)

    def _teardown_logger(self):
        """Detach the handler added by ``_setup_logger``."""
        rootlogger = logging.getLogger("sqlalchemy.engine.Engine")
        rootlogger.removeHandler(self._handler)

    def _setup_create_table_patcher(self):
        """Patch ``ddl.sort_tables_and_constraints`` to pre-sort its input.

        Tables are sorted by ``key`` before being passed to the original
        function, which is saved on ``self.orig_sort`` for restoration.
        Presumably this keeps the emitted DDL order stable so it matches
        the documented output.
        """
        from sqlalchemy.sql import ddl

        self.orig_sort = ddl.sort_tables_and_constraints

        def our_sort(tables, **kw):
            return self.orig_sort(sorted(tables, key=lambda t: t.key), **kw)

        ddl.sort_tables_and_constraints = our_sort

    def _teardown_create_table_patcher(self):
        """Restore the function replaced by ``_setup_create_table_patcher``."""
        from sqlalchemy.sql import ddl

        ddl.sort_tables_and_constraints = self.orig_sort

    def setup_test(self):
        self._setup_logger()
        self._setup_create_table_patcher()

    def teardown_test(self):
        # Restore the patched function first. The try/finally ensures the
        # log handler is detached even if that restore fails, so it cannot
        # leak into subsequent tests.
        try:
            self._teardown_create_table_patcher()
        finally:
            self._teardown_logger()

    def _run_doctest(self, *fnames):
        """Run the examples of each file in *fnames* as one doctest session.

        :param fnames: paths relative to ``doc/build`` under the directory
         two levels above this file. Files run in the order given and share
         one globals namespace, so later files can use names defined by
         earlier ones.

        The test is skipped if a file is missing. It fails on the first
        file that has a failing example.
        """
        here = os.path.dirname(__file__)
        sqla_base = os.path.normpath(os.path.join(here, "..", ".."))

        optionflags = (
            doctest.ELLIPSIS
            | doctest.NORMALIZE_WHITESPACE
            | doctest.IGNORE_EXCEPTION_DETAIL
            | _get_allow_unicode_flag()
        )
        runner = doctest.DocTestRunner(
            verbose=None,
            optionflags=optionflags,
            checker=_get_unicode_checker(),
        )
        parser = doctest.DocTestParser()
        globs = {"print_function": print_function}
        # {stop} / {sql} / {opensql} markers embedded in the .rst examples;
        # they are removed before parsing so they don't affect the comparison
        markers = re.compile(r"{(?:stop|sql|opensql)}")

        for fname in fnames:
            path = os.path.join(sqla_base, "doc/build", fname)
            if not os.path.exists(path):
                config.skip_test("Can't find documentation file %r" % path)
            # read everything up front so the file is not held open while
            # its examples execute
            with open(path, encoding="utf-8") as file_:
                content = file_.read()
            content = markers.sub("", content)

            test = parser.get_doctest(content, globs, fname, fname, 0)
            runner.run(test, clear_globs=False)
            runner.summarize()
            # the DocTest ran against its own copy of globs; carry the names
            # its examples defined forward to the next file
            globs.update(test.globs)
            # runner.failures accumulates across files, but since we stop at
            # the first file with failures the count belongs to fname
            assert not runner.failures, "%d doctest failure(s) in %s" % (
                runner.failures,
                fname,
            )

    def test_20_style(self):
        self._run_doctest(
            "tutorial/index.rst",
            "tutorial/engine.rst",
            "tutorial/dbapi_transactions.rst",
            "tutorial/metadata.rst",
            "tutorial/data.rst",
            "tutorial/data_insert.rst",
            "tutorial/data_select.rst",
            "tutorial/data_update.rst",
            "tutorial/orm_data_manipulation.rst",
            "tutorial/orm_related_objects.rst",
        )

    def test_orm(self):
        self._run_doctest("orm/tutorial.rst")

    @testing.emits_warning()
    def test_core(self):
        self._run_doctest("core/tutorial.rst")

    def test_core_operators(self):
        self._run_doctest("core/operators.rst")

    def test_orm_queryguide(self):
        self._run_doctest("orm/queryguide.rst")
117
118
# Unicode-aware doctest output checker, adapted from pytest
120
121
122def _get_unicode_checker():
123    """
124    Returns a doctest.OutputChecker subclass that takes in account the
125    ALLOW_UNICODE option to ignore u'' prefixes in strings. Useful
126    when the same doctest should run in Python 2 and Python 3.
127
128    An inner class is used to avoid importing "doctest" at the module
129    level.
130    """
131    if hasattr(_get_unicode_checker, "UnicodeOutputChecker"):
132        return _get_unicode_checker.UnicodeOutputChecker()
133
134    import doctest
135    import re
136
137    class UnicodeOutputChecker(doctest.OutputChecker):
138        """
139        Copied from doctest_nose_plugin.py from the nltk project:
140            https://github.com/nltk/nltk
141        """
142
143        _literal_re = re.compile(r"(\W|^)[uU]([rR]?[\'\"])", re.UNICODE)
144
145        def check_output(self, want, got, optionflags):
146            res = doctest.OutputChecker.check_output(
147                self, want, got, optionflags
148            )
149            if res:
150                return True
151
152            if not (optionflags & _get_allow_unicode_flag()):
153                return False
154
155            else:  # pragma: no cover
156                # the code below will end up executed only in Python 2 in
157                # our tests, and our coverage check runs in Python 3 only
158                def remove_u_prefixes(txt):
159                    return re.sub(self._literal_re, r"\1\2", txt)
160
161                want = remove_u_prefixes(want)
162                got = remove_u_prefixes(got)
163                res = doctest.OutputChecker.check_output(
164                    self, want, got, optionflags
165                )
166                return res
167
168    _get_unicode_checker.UnicodeOutputChecker = UnicodeOutputChecker
169    return _get_unicode_checker.UnicodeOutputChecker()
170
171
172def _get_allow_unicode_flag():
173    """
174    Registers and returns the ALLOW_UNICODE flag.
175    """
176    import doctest
177
178    return doctest.register_optionflag("ALLOW_UNICODE")
179
180
181# increase number to force pipeline run. 1
182