1# plugin/plugin_base.py
2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""Testing extensions.
9
This module is designed to work as a testing-framework-agnostic library,
11created so that multiple test frameworks can be supported at once
12(mostly so that we can migrate to new ones). The current target
13is pytest.
14
15"""
16
17from __future__ import absolute_import
18
19import abc
20import logging
21import re
22import sys
23
# flag which indicates we are in the SQLAlchemy testing suite,
# and not that of Alembic or a third party dialect.
# (copied onto sqlalchemy.testing.config by _setup_requirements())
bootstrapped_as_sqlalchemy = False

log = logging.getLogger("sqlalchemy.testing.plugin_base")


# True on Python 3; selects the configparser / ABC compatibility shims below
py3k = sys.version_info >= (3, 0)

if py3k:
    import configparser

    ABC = abc.ABC
else:
    import ConfigParser as configparser
    import collections as collections_abc  # noqa

    # Python 2 spelling of an abstract base class (abc.ABC does not exist)
    class ABC(object):
        __metaclass__ = abc.ABCMeta


# late imports
# placeholders for sqlalchemy.testing modules; post_begin() rebinds most of
# them once configuration has completed
fixtures = None
engines = None
exclusions = None
warnings = None
profiling = None
provision = None
assertions = None
requirements = None
config = None
testing = None
util = None
# ConfigParser instance, assigned by read_config()
file_config = None

# NOTE: this rebinds the stdlib ``logging`` module imported at the top of the
# file to None.  ``log`` above was created before this line and is
# unaffected; _log() re-imports and configures logging on first use.
logging = None
# tag names from --include-tag / --exclude-tag (dashes converted to
# underscores), plus tags excluded by --nomemory / --notimingintensive
include_tags = set()
exclude_tags = set()
# parsed command-line options; assigned by pre_begin()
options = None
63
64
def setup_options(make_option):
    """Declare this plugin's command-line options.

    :param make_option: callable invoked once per option as
        ``make_option(flag, **kwargs)``.  The keyword arguments follow
        optparse/argparse conventions (``action``, ``type``, ``dest``,
        ``default``, ``help``).  The ``action="callback"`` / ``callback=``
        and ``zeroarg_callback=`` keywords are presumably translated by the
        framework-specific caller into something its parser understands --
        TODO confirm against the pytest plugin.
    """
    make_option(
        "--log-info",
        action="callback",
        type=str,
        callback=_log,
        help="turn on info logging for <LOG> (multiple OK)",
    )
    make_option(
        "--log-debug",
        action="callback",
        type=str,
        callback=_log,
        help="turn on debug logging for <LOG> (multiple OK)",
    )
    make_option(
        "--db",
        action="append",
        type=str,
        dest="db",
        help="Use prefab database uri. Multiple OK, "
        "first one is run by default.",
    )
    make_option(
        "--dbs",
        action="callback",
        zeroarg_callback=_list_dbs,
        help="List available prefab dbs",
    )
    make_option(
        "--dburi",
        action="append",
        type=str,
        dest="dburi",
        help="Database uri.  Multiple OK, first one is run by default.",
    )
    make_option(
        "--dbdriver",
        action="append",
        # a real type object, as for every other option here, rather than
        # the optparse-style type name "string"
        type=str,
        dest="dbdriver",
        help="Additional database drivers to include in tests.  "
        "These are linked to the existing database URLs by the "
        "provisioning system.",
    )
    make_option(
        "--dropfirst",
        action="store_true",
        dest="dropfirst",
        help="Drop all tables in the target database first",
    )
    make_option(
        "--disable-asyncio",
        action="store_true",
        help="disable test / fixtures / provisioning running in asyncio",
    )
    make_option(
        "--backend-only",
        action="store_true",
        dest="backend_only",
        help="Run only tests marked with __backend__ or __sparse_backend__",
    )
    make_option(
        "--nomemory",
        action="store_true",
        dest="nomemory",
        help="Don't run memory profiling tests",
    )
    make_option(
        "--notimingintensive",
        action="store_true",
        dest="notimingintensive",
        help="Don't run timing intensive tests",
    )
    make_option(
        "--profile-sort",
        type=str,
        default="cumulative",
        dest="profilesort",
        help="Type of sort for profiling standard output",
    )
    make_option(
        "--profile-dump",
        type=str,
        dest="profiledump",
        help="Filename where a single profile run will be dumped",
    )
    make_option(
        "--postgresql-templatedb",
        type=str,
        help="name of template database to use for PostgreSQL "
        "CREATE DATABASE (defaults to current database)",
    )
    make_option(
        "--low-connections",
        action="store_true",
        dest="low_connections",
        help="Use a low number of distinct connections - "
        "i.e. for Oracle TNS",
    )
    make_option(
        "--write-idents",
        type=str,
        dest="write_idents",
        help="write out generated follower idents to <file>, "
        "when -n<num> is used",
    )
    make_option(
        "--reversetop",
        action="store_true",
        dest="reversetop",
        default=False,
        help="Use a random-ordering set implementation in the ORM "
        "(helps reveal dependency issues)",
    )
    make_option(
        "--requirements",
        action="callback",
        type=str,
        callback=_requirements_opt,
        help="requirements class for testing, overrides setup.cfg",
    )
    make_option(
        "--with-cdecimal",
        action="store_true",
        dest="cdecimal",
        default=False,
        help="Monkeypatch the cdecimal library into Python 'decimal' "
        "for all tests",
    )
    make_option(
        "--include-tag",
        action="callback",
        callback=_include_tag,
        type=str,
        help="Include tests with tag <tag>",
    )
    make_option(
        "--exclude-tag",
        action="callback",
        callback=_exclude_tag,
        type=str,
        help="Exclude tests with tag <tag>",
    )
    make_option(
        "--write-profiles",
        action="store_true",
        dest="write_profiles",
        default=False,
        help="Write/update failing profiling data.",
    )
    make_option(
        "--force-write-profiles",
        action="store_true",
        dest="force_write_profiles",
        default=False,
        help="Unconditionally write/update profiling data.",
    )
    make_option(
        "--dump-pyannotate",
        type=str,
        dest="dump_pyannotate",
        help="Run pyannotate and dump json info to given file",
    )
    make_option(
        "--mypy-extra-test-path",
        type=str,
        action="append",
        default=[],
        dest="mypy_extra_test_paths",
        help="Additional test directories to add to the mypy tests. "
        "This is used only when running mypy tests. Multiple OK",
    )
238
239
def configure_follower(follower_ident):
    """Record *follower_ident* as the provisioning follower identity.

    Runs in the parent process.  Here it only stores the identifier on
    ``sqlalchemy.testing.provision.FOLLOWER_IDENT``; per-follower work such
    as database creation presumably happens in code that reads it.
    """
    from sqlalchemy.testing import provision as provision_mod

    provision_mod.FOLLOWER_IDENT = follower_ident
250
251
def memoize_important_follower_config(dict_):
    """Stash the tag filters in *dict_* for a follower process to restore.

    Runs in the parent process once normal configuration is complete.
    pytest followers start with a fresh interpreter and never run our
    option callbacks, so the collected include/exclude tag sets are stored
    under ``dict_["memoized_config"]`` to be copied over.
    """
    memoized = dict(include_tags=include_tags, exclude_tags=exclude_tags)
    dict_["memoized_config"] = memoized
266
267
def restore_important_follower_config(dict_):
    """Merge the tag filters memoized by the parent into this process.

    Runs in the follower process.  The module-level sets are updated in
    place, so no ``global`` declaration is required.
    """
    memoized = dict_["memoized_config"]
    include_tags.update(memoized["include_tags"])
    exclude_tags.update(memoized["exclude_tags"])
277
278
def read_config():
    """Load ``setup.cfg`` and ``test.cfg`` into the module-level parser.

    Relative paths are resolved against the current working directory.
    Missing files are skipped by ConfigParser.read(), and later files
    override earlier ones.
    """
    global file_config
    parser = file_config = configparser.ConfigParser()
    parser.read(["setup.cfg", "test.cfg"])
283
284
def pre_begin(opt):
    """Store the parsed options and run every ``@pre`` hook.

    This covers things to set up early, before coverage might be set up.
    Each hook receives the module-level ``options`` and ``file_config``.
    """
    global options
    options = opt
    for setup_hook in pre_configure:
        setup_hook(options, file_config)
291
292
def set_coverage_flag(value):
    """Record on the parsed options whether coverage is running."""
    setattr(options, "has_coverage", value)
295
296
def post_begin():
    """Run the ``@post`` hooks, then perform the late imports.

    These are things to set up later, once we know coverage is running.
    Every function registered with ``@post`` is called with the
    module-level ``options`` and ``file_config``. The sqlalchemy testing
    modules are then imported into this module's globals, and the testing
    warning filters are installed.
    """
    # Lazy setup of other options (post coverage)
    for fn in post_configure:
        fn(options, file_config)

    # late imports, has to happen after config.
    # the global declarations make the imports below rebind the module-level
    # placeholders (initialized to None) instead of creating locals
    global util, fixtures, engines, exclusions, assertions, provision
    global warnings, profiling, config, testing
    from sqlalchemy import testing  # noqa
    from sqlalchemy.testing import fixtures, engines, exclusions  # noqa
    from sqlalchemy.testing import assertions, warnings, profiling  # noqa
    from sqlalchemy.testing import config, provision  # noqa
    from sqlalchemy import util  # noqa

    warnings.setup_filters()
313
314
315def _log(opt_str, value, parser):
316    global logging
317    if not logging:
318        import logging
319
320        logging.basicConfig()
321
322    if opt_str.endswith("-info"):
323        logging.getLogger(value).setLevel(logging.INFO)
324    elif opt_str.endswith("-debug"):
325        logging.getLogger(value).setLevel(logging.DEBUG)
326
327
328def _list_dbs(*args):
329    print("Available --db options (use --dburi to override)")
330    for macro in sorted(file_config.options("db")):
331        print("%20s\t%s" % (macro, file_config.get("db", macro)))
332    sys.exit(0)
333
334
def _requirements_opt(opt_str, value, parser):
    """Callback for ``--requirements``: install the class named by *value*.

    *value* is a ``module:ClassName`` string handed to _setup_requirements().
    *opt_str* and *parser* belong to the callback signature and are unused.
    """
    _setup_requirements(value)
337
338
339def _exclude_tag(opt_str, value, parser):
340    exclude_tags.add(value.replace("-", "_"))
341
342
343def _include_tag(opt_str, value, parser):
344    include_tags.add(value.replace("-", "_"))
345
346
# hook registries filled by the @pre / @post decorators below; pre_begin()
# and post_begin() call each hook as hook(options, file_config)
pre_configure = []
post_configure = []


def pre(fn):
    """Decorator: register *fn* as a pre-configuration hook.

    The function is appended to ``pre_configure`` and returned unchanged.
    """
    registry = pre_configure
    registry.append(fn)
    return fn


def post(fn):
    """Decorator: register *fn* as a post-configuration hook.

    The function is appended to ``post_configure`` and returned unchanged.
    """
    registry = post_configure
    registry.append(fn)
    return fn
359
360
@pre
def _setup_options(opt, file_config):
    """Pre-configure hook: store the parsed options in the module global."""
    global options
    options = opt
365
366
@pre
def _set_nomemory(opt, file_config):
    """With ``--nomemory``, exclude tests tagged ``memory_intensive``."""
    if not opt.nomemory:
        return
    exclude_tags.add("memory_intensive")
371
372
@pre
def _set_notimingintensive(opt, file_config):
    """With ``--notimingintensive``, exclude tests tagged ``timing_intensive``."""
    if not opt.notimingintensive:
        return
    exclude_tags.add("timing_intensive")
377
378
@pre
def _monkeypatch_cdecimal(options, file_config):
    """With ``--with-cdecimal``, install cdecimal as ``decimal``.

    The module is placed in ``sys.modules["decimal"]``, so any later
    ``import decimal`` receives cdecimal.
    """
    if not options.cdecimal:
        return
    import cdecimal

    sys.modules["decimal"] = cdecimal
385
386
@post
def _init_symbols(options, file_config):
    """Instantiate the registered FixtureFunctions class onto the config.

    The class comes from set_fixture_functions(). The instance is stored
    as ``sqlalchemy.testing.config._fixture_functions``.
    """
    from sqlalchemy.testing import config as testing_config

    fixture_functions = _fixture_fn_class()
    testing_config._fixture_functions = fixture_functions
392
393
@post
def _set_disable_asyncio(opt, file_config):
    """Turn off asyncio-based running with ``--disable-asyncio``.

    Asyncio is also always turned off on Python 2.
    """
    if not opt.disable_asyncio and py3k:
        return
    from sqlalchemy.testing import asyncio

    asyncio.ENABLE_ASYNCIO = False
400
401
@post
def _engine_uri(options, file_config):
    """Create a testing config for every requested database URL.

    URLs come from two options:

    * ``--dburi`` values are used verbatim.
    * ``--db`` values are names looked up in the ``[db]`` section of the
      config file.  One argument may hold several names separated by commas
      and/or whitespace.

    If neither option is given, the ``[db]`` entry ``default`` is used.
    The URLs are expanded by provision.generate_db_urls() together with any
    ``--dbdriver`` extras, and each expanded URL gets a config from
    provision.setup_config().  The first one becomes current, presumably
    because set_as_current() populates ``config._current`` -- confirm.
    With ``--write-idents`` and a follower ident set, each URL is also
    appended to that file as ``"<ident> <url>"``.

    :raises RuntimeError: if a ``--db`` name is not present in ``[db]``.
    """
    from sqlalchemy import testing
    from sqlalchemy.testing import config
    from sqlalchemy.testing import provision

    db_urls = list(options.dburi or [])

    extra_drivers = options.dbdriver or []

    if options.db:
        known_dbs = file_config.options("db")
        for db_token in options.db:
            for db in re.split(r"[,\s]+", db_token):
                if not db:
                    # leading/trailing separators (e.g. "--db 'sqlite,'")
                    # yield empty fragments; they name no database
                    continue
                if db not in known_dbs:
                    raise RuntimeError(
                        "Unknown URI specifier '%s'.  "
                        "Specify --dbs for known uris." % db
                    )
                db_urls.append(file_config.get("db", db))

    if not db_urls:
        db_urls.append(file_config.get("db", "default"))

    config._current = None

    expanded_urls = list(provision.generate_db_urls(db_urls, extra_drivers))

    for db_url in expanded_urls:
        log.info("Adding database URL: %s", db_url)

        if options.write_idents and provision.FOLLOWER_IDENT:
            with open(options.write_idents, "a") as file_:
                file_.write(provision.FOLLOWER_IDENT + " " + db_url + "\n")

        cfg = provision.setup_config(
            db_url, options, file_config, provision.FOLLOWER_IDENT
        )
        if not config._current:
            cfg.set_as_current(cfg, testing)
446
447
@post
def _requirements(options, file_config):
    """Post-configure hook: install the requirements class from the config.

    The class is named by ``requirement_cls`` in the ``[sqla_testing]``
    section.  This has no effect if a requirements object is already
    installed, e.g. via ``--requirements``.
    """
    _setup_requirements(file_config.get("sqla_testing", "requirement_cls"))
453
454
def _setup_requirements(argument):
    """Instantiate and install the requirements class named by *argument*.

    :param argument: a ``"dotted.module.path:ClassName"`` string.

    The class is instantiated and assigned to both
    ``sqlalchemy.testing.config.requirements`` and
    ``sqlalchemy.testing.requires``. ``config.bootstrapped_as_sqlalchemy``
    is also copied from this module.  If ``config.requirements`` is already
    set, nothing happens, so the first caller wins: the ``--requirements``
    option callback or the post-configure hook.

    :raises ValueError: if *argument* is not of the form ``module:Class``.
    """
    import importlib

    from sqlalchemy.testing import config
    from sqlalchemy import testing

    if config.requirements is not None:
        return

    try:
        modname, clsname = argument.split(":")
    except ValueError:
        raise ValueError(
            "requirements class must be given as "
            "'module.path:ClassName'; got %r" % (argument,)
        )

    # import_module() returns the leaf module directly, unlike __import__()
    # which returns the top-level package
    mod = importlib.import_module(modname)
    req_cls = getattr(mod, clsname)

    config.requirements = testing.requires = req_cls()

    config.bootstrapped_as_sqlalchemy = bootstrapped_as_sqlalchemy
474
475
@post
def _prep_testing_database(options, file_config):
    """With ``--dropfirst``, drop all schema objects in every test database."""
    from sqlalchemy.testing import config

    if not options.dropfirst:
        return

    from sqlalchemy.testing import provision

    for cfg in config.Config.all_configs():
        provision.drop_all_schema_objects(cfg, cfg.db)
485
486
@post
def _reverse_topological(options, file_config):
    """With ``--reversetop``, randomize the ORM unit-of-work ordering."""
    if not options.reversetop:
        return
    from sqlalchemy.orm.util import randomize_unitofwork

    randomize_unitofwork()
493
494
@post
def _post_setup_options(opt, file_config):
    """Publish the parsed options and config file on the testing config.

    Stores the *opt* argument instead of reading the module-level
    ``options`` global, so the hook depends only on what it is passed.
    post_begin() passes that same global, so the stored object is unchanged.
    """
    from sqlalchemy.testing import config

    config.options = opt
    config.file_config = file_config
501
502
@post
def _setup_profiling(options, file_config):
    """Create the profiling stats-file object used by profiling tests.

    The file path comes from ``[sqla_testing] profile_file``. The sort
    order and dump target come from ``--profile-sort`` / ``--profile-dump``.
    """
    from sqlalchemy.testing import profiling

    stats_path = file_config.get("sqla_testing", "profile_file")
    profiling._profile_stats = profiling.ProfileStatsFile(
        stats_path, sort=options.profilesort, dump=options.profiledump
    )
512
513
def want_class(name, cls):
    """Return True if test class *cls* (bound to *name*) should be collected.

    Only public subclasses of ``fixtures.TestBase`` qualify.  Under
    ``--backend-only``, the class must also set one of ``__backend__``,
    ``__sparse_backend__`` or ``__only_on__`` to a true value.
    """
    if not issubclass(cls, fixtures.TestBase) or name.startswith("_"):
        return False
    if not config.options.backend_only:
        return True
    return bool(
        getattr(cls, "__backend__", False)
        or getattr(cls, "__sparse_backend__", False)
        or getattr(cls, "__only_on__", False)
    )
528
529
def want_method(cls, fn):
    """Return True if test function *fn* on class *cls* should be collected.

    Only functions named ``test_*`` with a non-None ``__module__`` qualify.
    Tag filters are evaluated against the class-level ``__tags__`` and/or
    the function's ``_sa_exclusion_extend`` object.  The filters come from
    --include-tag / --exclude-tag and from the exclusions added by
    --nomemory / --notimingintensive.
    """
    if not fn.__name__.startswith("test_"):
        return False
    elif fn.__module__ is None:
        return False
    elif include_tags:
        # with include tags, the test is wanted if either the class tags or
        # the function's own exclusion object accepts it
        return (
            hasattr(cls, "__tags__")
            and exclusions.tags(cls.__tags__).include_test(
                include_tags, exclude_tags
            )
        ) or (
            hasattr(fn, "_sa_exclusion_extend")
            and fn._sa_exclusion_extend.include_test(
                include_tags, exclude_tags
            )
        )
    elif exclude_tags and hasattr(cls, "__tags__"):
        # NOTE(review): when the class has __tags__, a function-level
        # _sa_exclusion_extend is not consulted, so exclusions declared only
        # on the function appear to be ignored for tagged classes -- confirm
        # that this is intended
        return exclusions.tags(cls.__tags__).include_test(
            include_tags, exclude_tags
        )
    elif exclude_tags and hasattr(fn, "_sa_exclusion_extend"):
        return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags)
    else:
        return True
555
556
def generate_sub_tests(cls, module):
    """Yield the test classes to collect for *cls*.

    A class marked ``__backend__`` or ``__sparse_backend__`` is expanded into
    one subclass per eligible database config, chosen by
    _possible_configs_for_cls() with ``sparse`` taken from
    ``__sparse_backend__``.  Each subclass is pinned to its config via
    ``__only_on_config__`` and remembers the original class name in
    ``_sa_orig_cls_name``.  It is also set as an attribute of *module*.
    Any other class is yielded unchanged.
    """
    is_backend = getattr(cls, "__backend__", False)
    sparse = getattr(cls, "__sparse_backend__", False)
    if not (is_backend or sparse):
        yield cls
        return

    orig_name = cls.__name__
    for cfg in _possible_configs_for_cls(cls, sparse=sparse):
        # the pytest junit plugin is tripped up by brackets and periods in
        # names, so collapse runs of them (and underscores) to a single
        # underscore, then strip trailing underscores
        collapsed = re.sub(r"[_\[\]\.]+", "_", cfg.name)
        alpha_name = re.sub(r"_+$", "", collapsed)
        name = "%s_%s" % (orig_name, alpha_name)
        attrs = {"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg}
        subcls = type(name, (cls,), attrs)
        setattr(module, name, subcls)
        yield subcls
581
582
def start_test_class_outside_fixtures(cls):
    """Prepare test class *cls* before its fixtures run.

    First applies class-level skips and config selection (_do_skips).
    Then pushes a class-scoped engine if the class declares
    ``__engine_options__`` (_setup_engine).
    """
    _do_skips(cls)
    _setup_engine(cls)
586
587
def stop_test_class(cls):
    """Tear down per-class state for *cls* from inside the fixture scope."""
    # close sessions, immediate connections, etc.
    fixtures.stop_test_class_inside_fixtures(cls)

    # close outstanding connection pool connections, dispose of
    # additional engines
    engines.testing_reaper.stop_test_class_inside_fixtures()
595
596
def stop_test_class_outside_fixtures(cls):
    """Tear down per-class state for *cls* after its fixtures have ended.

    Global cleanup assertions are skipped under ``--low-connections``.
    The current config is reset (_restore_engine) even if those
    assertions fail.
    """
    # NOTE(review): these two calls sit outside the try block, so if either
    # raises, _restore_engine() is not reached -- confirm this is intended
    engines.testing_reaper.stop_test_class_outside_fixtures()
    provision.stop_test_class_outside_fixtures(config, config.db, cls)
    try:
        if not options.low_connections:
            assertions.global_cleanup_assertions()
    finally:
        _restore_engine()
605
606
def _restore_engine():
    """Reset the current testing config, if one is active."""
    current = config._current
    if current:
        current.reset(testing)
610
611
def final_process_cleanup():
    """Final teardown at the end of the test process.

    Unlike stop_test_class_outside_fixtures(), the global cleanup
    assertions always run here (no ``--low-connections`` check).  They are
    also not wrapped in try/finally, so a failing assertion skips the
    engine reset.
    """
    engines.testing_reaper.final_cleanup()
    assertions.global_cleanup_assertions()
    _restore_engine()
616
617
def _setup_engine(cls):
    """Push a class-scoped testing engine if *cls* sets ``__engine_options__``.

    The options are copied and given ``scope="class"`` before being passed
    to engines.testing_engine(). The resulting engine is pushed onto the
    current config.
    """
    engine_options = getattr(cls, "__engine_options__", None)
    if not engine_options:
        return
    opts = dict(engine_options, scope="class")
    eng = engines.testing_engine(options=opts)
    config._current.push_engine(eng, testing)
624
625
def before_test(test, test_module_name, test_class, test_name):
    """Register the test about to run with the profiling module.

    The id passed to profiling looks like
    ``"test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause"``.
    For per-backend subclasses made by generate_sub_tests(), the original
    class name (``_sa_orig_cls_name``) is used instead of the generated
    one.
    """
    class_name = getattr(test_class, "_sa_orig_cls_name", test_class.__name__)
    profiling._start_current_test(
        "%s.%s.%s" % (test_module_name, class_name, test_name)
    )
636
637
def after_test(test):
    """Per-test teardown: fixtures first, then the engine reaper.

    *test* is accepted but not used here.
    """
    fixtures.after_test()
    engines.testing_reaper.after_test()
641
642
def after_test_fixtures(test):
    """Pass *test* to the engine reaper's outside-fixtures hook.

    Presumably called by the framework plugin after the test's own fixtures
    have finished -- confirm against the caller.
    """
    engines.testing_reaper.after_test_outside_fixtures(test)
645
646
def _possible_configs_for_cls(cls, reasons=None, sparse=False):
    """Return the database configs that test class *cls* may run against.

    Starts from every registered config and removes, in order:

    * configs matching ``__unsupported_on__``;
    * configs not matching ``__only_on__``;
    * every config except ``__only_on_config__``, when that is set;
    * configs failing any ``__requires__`` requirement.

    ``__prefer_requires__`` then narrows the result to configs that satisfy
    the preferred requirements, but only if at least one such config exists.

    :param reasons: optional list; skip reasons from failed ``__requires__``
        checks are appended to it.
    :param sparse: if true, keep only one config per ``cfg.db.name``.
    :return: a ``set`` of configs. When *sparse* is true, a dict values view
        is returned instead; callers needing set operations (such as
        _do_skips) do not pass sparse.
    """
    all_configs = set(config.Config.all_configs())

    # accessed without getattr(); presumably defined (possibly empty) on the
    # TestBase base class -- confirm
    if cls.__unsupported_on__:
        spec = exclusions.db_spec(*cls.__unsupported_on__)
        for config_obj in list(all_configs):
            if spec(config_obj):
                all_configs.remove(config_obj)

    if getattr(cls, "__only_on__", None):
        spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
        for config_obj in list(all_configs):
            if not spec(config_obj):
                all_configs.remove(config_obj)

    if getattr(cls, "__only_on_config__", None):
        all_configs.intersection_update([cls.__only_on_config__])

    if hasattr(cls, "__requires__"):
        requirements = config.requirements
        for config_obj in list(all_configs):
            for requirement in cls.__requires__:
                check = getattr(requirements, requirement)

                skip_reasons = check.matching_config_reasons(config_obj)
                if skip_reasons:
                    all_configs.remove(config_obj)
                    if reasons is not None:
                        reasons.extend(skip_reasons)
                    # config already removed; skip its remaining requirements
                    break

    if hasattr(cls, "__prefer_requires__"):
        non_preferred = set()
        requirements = config.requirements
        for config_obj in list(all_configs):
            for requirement in cls.__prefer_requires__:
                check = getattr(requirements, requirement)

                if not check.enabled_for_config(config_obj):
                    non_preferred.add(config_obj)
        # only narrow if at least one preferred config would remain
        if all_configs.difference(non_preferred):
            all_configs.difference_update(non_preferred)

    if sparse:
        # pick only one config from each base dialect
        # sorted so we get the same backend each time selecting the highest
        # server version info.
        # Iterating in descending (name, driver, version) order, the first
        # config seen per name wins. The lexicographically greatest driver is
        # therefore preferred before server versions are compared.
        # NOTE(review): if server_version_info can be None, comparing it with
        # a tuple would raise TypeError on Python 3 -- confirm it never is
        per_dialect = {}
        for cfg in reversed(
            sorted(
                all_configs,
                key=lambda cfg: (
                    cfg.db.name,
                    cfg.db.driver,
                    cfg.db.dialect.server_version_info,
                ),
            )
        ):
            db = cfg.db.name
            if db not in per_dialect:
                per_dialect[db] = cfg
        return per_dialect.values()

    return all_configs
711
712
def _do_skips(cls):
    """Skip test class *cls* entirely, or switch to a config it can run on.

    config.skip_test() is called in two cases:

    * when a ``__skip_if__`` predicate returns true;
    * when no database config is eligible for the class.  The message then
      lists every configured database (name, server version, driver) and
      the collected skip reasons.

    Otherwise, ``__prefer_backends__`` may narrow the eligible set, and if
    the current config is not eligible, one of the eligible configs is
    pushed via _setup_config().
    """
    reasons = []
    all_configs = _possible_configs_for_cls(cls, reasons)

    if getattr(cls, "__skip_if__", False):
        for c in getattr(cls, "__skip_if__"):
            if c():
                # presumably raises the framework's skip exception -- confirm
                config.skip_test(
                    "'%s' skipped by %s" % (cls.__name__, c.__name__)
                )

    if not all_configs:
        msg = "'%s' unsupported on any DB implementation %s%s" % (
            cls.__name__,
            ", ".join(
                "'%s(%s)+%s'"
                % (
                    config_obj.db.name,
                    ".".join(
                        str(dig)
                        for dig in exclusions._server_version(config_obj.db)
                    ),
                    config_obj.db.driver,
                )
                for config_obj in config.Config.all_configs()
            ),
            ", ".join(reasons),
        )
        config.skip_test(msg)
    elif hasattr(cls, "__prefer_backends__"):
        non_preferred = set()
        spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
        for config_obj in all_configs:
            if not spec(config_obj):
                non_preferred.add(config_obj)
        # only narrow if at least one preferred backend would remain
        if all_configs.difference(non_preferred):
            all_configs.difference_update(non_preferred)

    if config._current not in all_configs:
        # set.pop() removes an arbitrary element, so which eligible config
        # is chosen here is not defined
        _setup_config(all_configs.pop(), cls)
753
754
def _setup_config(config_obj, ctx):
    """Push *config_obj* as the current testing config.

    *ctx* (the test class, when called from _do_skips) is accepted but not
    used.
    """
    config._current.push(config_obj, testing)
757
758
class FixtureFunctions(ABC):
    """Interface the test-framework plugin implements for fixture helpers.

    The implementation class is registered with set_fixture_functions().
    _init_symbols() instantiates it as ``config._fixture_functions``.
    """

    @abc.abstractmethod
    def skip_test_exception(self, *arg, **kw):
        raise NotImplementedError()

    @abc.abstractmethod
    def combinations(self, *args, **kw):
        raise NotImplementedError()

    @abc.abstractmethod
    def param_ident(self, *args, **kw):
        raise NotImplementedError()

    @abc.abstractmethod
    def fixture(self, *arg, **kw):
        raise NotImplementedError()

    # not marked abstract: implementations are not forced to override it,
    # and this default raises NotImplementedError when called
    def get_current_test_name(self):
        raise NotImplementedError()

    @abc.abstractmethod
    def mark_base_test_class(self):
        raise NotImplementedError()
782
783
# FixtureFunctions implementation class registered by the test-framework
# plugin; the _init_symbols() post-configure hook instantiates it
_fixture_fn_class = None


def set_fixture_functions(fixture_fn_class):
    """Register *fixture_fn_class* as the FixtureFunctions implementation.

    This only stores the class; it is instantiated later during
    post-configuration.
    """
    global _fixture_fn_class
    _fixture_fn_class = fixture_fn_class
790