1# plugin/plugin_base.py
2# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8"""Testing extensions.
9
This module is designed to work as a testing-framework-agnostic library,
so that we can continue to support nose while also adding new
functionality via py.test.
13
14"""
15
16from __future__ import absolute_import
17
18import sys
19import re
20
21py3k = sys.version_info >= (3, 0)
22
23if py3k:
24    import configparser
25else:
26    import ConfigParser as configparser
27
# Late imports: placeholders for modules that are only imported once
# configuration is complete (post_begin() rebinds most of these).
fixtures = engines = exclusions = warnings = profiling = None
assertions = requirements = config = testing = util = None

# configparser.ConfigParser instance created by read_config()
file_config = None

# the stdlib logging module, imported lazily by _log()
logging = None

# tag filters filled by the --include-tag / --exclude-tag callbacks;
# these must stay two distinct set objects.
include_tags = set()
exclude_tags = set()

# options object handed to pre_begin()
options = None
46
47
def setup_options(make_option):
    """Register this plugin's command-line options.

    :param make_option: callable with an optparse-style
        ``make_option(*flags, **kwargs)`` signature supplied by the
        test-runner integration.  It is called once per option, always in
        the same fixed order.
    """
    add = make_option

    # logging: --log-info / --log-debug
    for level in ("info", "debug"):
        add("--log-%s" % level, action="callback", type="string",
            callback=_log,
            help="turn on %s logging for <LOG> (multiple OK)" % level)

    # database selection
    add("--db", action="append", type="string", dest="db",
        help="Use prefab database uri. Multiple OK, "
             "first one is run by default.")
    add('--dbs', action='callback', callback=_list_dbs,
        help="List available prefab dbs")
    add("--dburi", action="append", type="string", dest="dburi",
        help="Database uri.  Multiple OK, "
             "first one is run by default.")
    add("--dropfirst", action="store_true", dest="dropfirst",
        help="Drop all tables in the target database first")

    # test selection / connection behavior
    add("--backend-only", action="store_true", dest="backend_only",
        help="Run only tests marked with __backend__")
    add("--low-connections", action="store_true", dest="low_connections",
        help="Use a low number of distinct connections - "
             "i.e. for Oracle TNS")
    add("--write-idents", type="string", dest="write_idents",
        help="write out generated follower idents to <file>, "
             "when -n<num> is used")
    add("--reversetop", action="store_true", dest="reversetop",
        default=False,
        help="Use a random-ordering set implementation in the ORM "
             "(helps reveal dependency issues)")
    add("--requirements", action="callback", type="string",
        callback=_requirements_opt,
        help="requirements class for testing, overrides setup.cfg")
    add("--with-cdecimal", action="store_true", dest="cdecimal",
        default=False,
        help="Monkeypatch the cdecimal library into Python 'decimal' "
             "for all tests")

    # tag filtering
    add("--include-tag", action="callback", callback=_include_tag,
        type="string", help="Include tests with tag <tag>")
    add("--exclude-tag", action="callback", callback=_exclude_tag,
        type="string", help="Exclude tests with tag <tag>")

    # profiling
    add("--write-profiles", action="store_true", dest="write_profiles",
        default=False, help="Write/update failing profiling data.")
    add("--force-write-profiles", action="store_true",
        dest="force_write_profiles", default=False,
        help="Unconditionally write/update profiling data.")
96
97
def configure_follower(follower_ident):
    """Record *follower_ident* for the provisioning layer.

    The value is stored as ``sqlalchemy.testing.provision.FOLLOWER_IDENT``;
    _engine_uri() later passes that value to provision.setup_config().

    NOTE(review): an earlier docstring said this runs in the parent process
    and includes database creation, but the body only records the ident --
    confirm against the caller.
    """
    from sqlalchemy.testing import provision as provision_mod
    provision_mod.FOLLOWER_IDENT = follower_ident
107
108
def memoize_important_follower_config(dict_):
    """Stash the tag filters in *dict_* so a follower can restore them.

    Stores the module-level include/exclude tag sets under
    ``dict_['memoized_config']``; restore_important_follower_config() reads
    them back.  Per the original notes, py.test followers start with
    nothing in memory and do not run our option callbacks, so this state
    has to be copied across explicitly.  The set objects themselves are
    stored, not copies.
    """
    dict_['memoized_config'] = dict(
        include_tags=include_tags,
        exclude_tags=exclude_tags,
    )
123
124
def restore_important_follower_config(dict_):
    """Merge the tag filters memoized by the parent into this process.

    Reads ``dict_['memoized_config']`` as written by
    memoize_important_follower_config() and adds its entries, in place, to
    the module-level include/exclude tag sets.  No ``global`` declaration is
    needed because the sets are only mutated, never rebound.
    """
    memoized = dict_['memoized_config']
    include_tags.update(memoized['include_tags'])
    exclude_tags.update(memoized['exclude_tags'])
134
135
def read_config():
    """Load ``setup.cfg`` and then ``test.cfg`` into the global ``file_config``.

    ConfigParser.read() silently skips files that do not exist, and values
    from later files override earlier ones.
    """
    global file_config
    file_config = configparser.ConfigParser()
    file_config.read(('setup.cfg', 'test.cfg'))
140
141
def pre_begin(opt):
    """Run the early (@pre) setup hooks, before coverage might be started.

    Stores *opt* as the module-level ``options``, then calls every
    registered pre-configure hook, in registration order, as
    ``hook(options, file_config)``.
    """
    global options
    options = opt
    for hook in pre_configure:
        hook(options, file_config)
148
149
def set_coverage_flag(value):
    """Set ``options.has_coverage`` to *value* (presumably whether coverage
    measurement is running -- set by the runner integration)."""
    setattr(options, 'has_coverage', value)
152
# Skip exception supplied by the test-runner integration via set_skip_test();
# _init_skiptest() later copies it onto sqlalchemy.testing.config.
_skip_test_exception = None


def set_skip_test(exc):
    """Record *exc* as the exception used to skip tests."""
    global _skip_test_exception
    _skip_test_exception = exc
159
160
def post_begin():
    """Run the late (@post) setup hooks, then perform the deferred imports.

    Meant to be called once coverage (if any) is running.  Every registered
    post-configure hook is invoked, in registration order, as
    ``hook(options, file_config)``.  After that, the sqlalchemy testing
    modules are imported into this module's globals and
    ``warnings.setup_filters()`` is called.
    """
    for hook in post_configure:
        hook(options, file_config)

    # These imports must wait until configuration (and plugins such as
    # coverage) are in place; they rebind the module-level placeholders.
    global util, fixtures, engines, exclusions, \
        assertions, warnings, profiling, \
        config, testing
    from sqlalchemy import testing  # noqa
    from sqlalchemy.testing import (  # noqa
        fixtures, engines, exclusions, assertions, warnings, profiling,
        config)
    from sqlalchemy import util  # noqa
    warnings.setup_filters()
178
179
180def _log(opt_str, value, parser):
181    global logging
182    if not logging:
183        import logging
184        logging.basicConfig()
185
186    if opt_str.endswith('-info'):
187        logging.getLogger(value).setLevel(logging.INFO)
188    elif opt_str.endswith('-debug'):
189        logging.getLogger(value).setLevel(logging.DEBUG)
190
191
def _list_dbs(*args):
    """Option callback for ``--dbs``: print every entry of the ``[db]``
    section of ``file_config`` (name, tab, URL), sorted by name, and then
    exit the process with status 0."""
    section = 'db'
    print("Available --db options (use --dburi to override)")
    for name in sorted(file_config.options(section)):
        url = file_config.get(section, name)
        print("%20s\t%s" % (name, url))
    sys.exit(0)
197
198
def _requirements_opt(opt_str, value, parser):
    """Option callback for ``--requirements``: install the requirements
    class named by *value* (``"module:ClassName"``)."""
    return _setup_requirements(value)
201
202
def _normalize_tag(value):
    """Return tag *value* with dashes replaced by underscores."""
    return value.replace('-', '_')


def _exclude_tag(opt_str, value, parser):
    """Option callback for ``--exclude-tag``: add the normalized tag to the
    module-level exclude set."""
    exclude_tags.add(_normalize_tag(value))


def _include_tag(opt_str, value, parser):
    """Option callback for ``--include-tag``: add the normalized tag to the
    module-level include set."""
    include_tags.add(_normalize_tag(value))
209
# Hook registries: functions decorated with @pre / @post are appended here
# and called in registration order by pre_begin() / post_begin().
pre_configure = []
post_configure = []


def pre(fn):
    """Decorator: register *fn* as a pre-configure hook; return it as-is."""
    pre_configure.extend([fn])
    return fn


def post(fn):
    """Decorator: register *fn* as a post-configure hook; return it as-is."""
    post_configure.extend([fn])
    return fn
222
223
@pre
def _setup_options(opt, file_config):
    """Pre-configure hook: publish *opt* as the module-level ``options``."""
    global options
    options = opt
228
229
@pre
def _monkeypatch_cdecimal(options, file_config):
    """Pre-configure hook for ``--with-cdecimal``: register the ``cdecimal``
    module under the name ``decimal`` in ``sys.modules`` so that subsequent
    ``import decimal`` statements receive it."""
    if not options.cdecimal:
        return
    import cdecimal
    sys.modules['decimal'] = cdecimal
235
236
@post
def _init_skiptest(options, file_config):
    """Post-configure hook: hand the skip exception recorded by
    set_skip_test() to ``sqlalchemy.testing.config``."""
    from sqlalchemy.testing import config as testing_config

    testing_config._skip_test_exception = _skip_test_exception
242
243
@post
def _engine_uri(options, file_config):
    """Post-configure hook: create a testing config for each requested URL.

    URLs are collected in this order:

    * every ``--dburi`` value;
    * every name given via ``--db``, looked up in the ``[db]`` section of
      *file_config*.  One ``--db`` value may hold several names separated
      by commas and/or whitespace.

    If neither option yields a URL, ``[db] default`` is used.  The first
    config created becomes current if no config is current yet.

    :raises RuntimeError: a ``--db`` name is not listed in ``[db]``.
    """
    from sqlalchemy.testing import config
    from sqlalchemy import testing
    from sqlalchemy.testing import provision

    db_urls = list(options.dburi) if options.dburi else []

    if options.db:
        known_dbs = file_config.options('db')
        for db_token in options.db:
            for db in re.split(r'[,\s]+', db_token):
                if not db:
                    # re.split() yields '' when the token starts or ends
                    # with a separator (e.g. "sqlite," or " postgresql");
                    # that is not a real name, so skip it instead of
                    # failing with "Unknown URI specifier ''".
                    continue
                if db not in known_dbs:
                    raise RuntimeError(
                        "Unknown URI specifier '%s'.  "
                        "Specify --dbs for known uris."
                        % db)
                db_urls.append(file_config.get('db', db))

    if not db_urls:
        db_urls.append(file_config.get('db', 'default'))

    for db_url in db_urls:
        cfg = provision.setup_config(
            db_url, options, file_config, provision.FOLLOWER_IDENT)

        if not config._current:
            cfg.set_as_current(cfg, testing)
275
276
@post
def _requirements(options, file_config):
    """Post-configure hook: install the requirements class named by
    ``requirement_cls`` in the ``[sqla_testing]`` config section.  This is a
    no-op if one was already installed (e.g. via ``--requirements``)."""
    _setup_requirements(file_config.get('sqla_testing', "requirement_cls"))
282
283
def _setup_requirements(argument):
    """Instantiate the requirements class named by *argument* and install it.

    *argument* has the form ``"package.module:ClassName"``.  The instance is
    stored as both ``sqlalchemy.testing.config.requirements`` and
    ``sqlalchemy.testing.requires``.  Only the first installation takes
    effect: once ``config.requirements`` is set, later calls return
    immediately, which lets ``--requirements`` override setup.cfg.

    :raises ValueError: *argument* is not of the ``module:ClassName`` form.
    """
    from sqlalchemy.testing import config
    from sqlalchemy import testing

    if config.requirements is not None:
        return

    modname, sep, clsname = argument.partition(":")
    if not (sep and modname and clsname) or ":" in clsname:
        # previously this surfaced as an opaque tuple-unpacking ValueError
        raise ValueError(
            "requirements class must be given as "
            "'package.module:ClassName', got %r" % (argument,))

    # importlib.import_module() only introduced in 2.7, a little
    # late
    mod = __import__(modname)
    # __import__("a.b.c") returns the top-level package "a"; walk down to
    # the leaf module.
    for component in modname.split(".")[1:]:
        mod = getattr(mod, component)
    req_cls = getattr(mod, clsname)

    config.requirements = testing.requires = req_cls()
301
302
def _drop_views(e, inspector, schema_name=None):
    """Drop every view that *inspector* reports in *schema_name*, via *e*.

    A dialect whose inspector raises NotImplementedError for view
    reflection is treated as having no views to drop.
    """
    from sqlalchemy import schema

    try:
        view_names = inspector.get_view_names(schema=schema_name)
    except NotImplementedError:
        return
    for vname in view_names:
        e.execute(schema._DropView(
            schema.Table(vname, schema.MetaData(), schema=schema_name)
        ))


@post
def _prep_testing_database(options, file_config):
    """Post-configure hook for ``--dropfirst``: empty each configured DB.

    For every config this drops views, then tables, in the default schema
    and, when the ``schemas`` requirement is enabled for that config, also
    in ``cfg.test_schema``.  On PostgreSQL it finally drops all enum types
    reported by the inspector.
    """
    from sqlalchemy.testing import config, util
    from sqlalchemy.testing.exclusions import against
    from sqlalchemy import inspect

    if not options.dropfirst:
        return

    for cfg in config.Config.all_configs():
        e = cfg.db
        inspector = inspect(e)

        _drop_views(e, inspector)

        use_test_schema = config.requirements.schemas.enabled_for_config(cfg)
        if use_test_schema:
            # use the same schema name for views as for tables below
            # (previously views used a hard-coded "test_schema")
            _drop_views(e, inspector, schema_name=cfg.test_schema)

        util.drop_all_tables(e, inspector)

        if use_test_schema:
            util.drop_all_tables(e, inspector, schema=cfg.test_schema)

        if against(cfg, "postgresql"):
            from sqlalchemy.dialects import postgresql
            for enum in inspector.get_enums("*"):
                e.execute(postgresql.DropEnumType(
                    postgresql.ENUM(
                        name=enum['name'],
                        schema=enum['schema'])))
348
349
@post
def _reverse_topological(options, file_config):
    """Post-configure hook for ``--reversetop``: call
    ``sqlalchemy.orm.util.randomize_unitofwork()`` so the ORM uses a
    random-ordering set implementation."""
    if not options.reversetop:
        return
    from sqlalchemy.orm.util import randomize_unitofwork
    randomize_unitofwork()
355
356
@post
def _post_setup_options(opt, file_config):
    """Post-configure hook: publish the options object and loaded config
    file on ``sqlalchemy.testing.config``."""
    from sqlalchemy.testing import config
    # Use the hook's own argument; the body previously read the module-level
    # ``options`` global and ignored *opt*.
    config.options = opt
    config.file_config = file_config
362
363
@post
def _setup_profiling(options, file_config):
    """Post-configure hook: create a ProfileStatsFile for the path stored as
    ``profile_file`` in ``[sqla_testing]`` and install it as
    ``profiling._profile_stats``."""
    from sqlalchemy.testing import profiling
    stats_path = file_config.get('sqla_testing', 'profile_file')
    profiling._profile_stats = profiling.ProfileStatsFile(stats_path)
369
370
def want_class(cls):
    """Return True if *cls* should be collected as a test class.

    The class must subclass ``fixtures.TestBase`` and must not have a name
    starting with an underscore.  Under ``--backend-only`` it must also set
    a truthy ``__backend__``.
    """
    if not issubclass(cls, fixtures.TestBase) or cls.__name__.startswith('_'):
        return False
    if config.options.backend_only:
        return bool(getattr(cls, '__backend__', False))
    return True
381
382
def want_method(cls, fn):
    """Decide whether *fn*, found on test class *cls*, should be collected.

    Only callables named ``test_*`` with a non-None ``__module__`` qualify.
    When tag filters are active, the decision is delegated to
    ``exclusions.tags(cls.__tags__).include_test`` and/or
    ``fn._sa_exclusion_extend.include_test``:

    * with include tags set, the test is wanted if either the class tags
      or the function's exclusion extension accepts it;
    * with only exclude tags set, the class tags are consulted if present,
      otherwise the function's extension if present.

    Without any tag filters the method is wanted.
    """
    if not fn.__name__.startswith("test_") or fn.__module__ is None:
        return False

    if include_tags:
        by_class = hasattr(cls, '__tags__') and \
            exclusions.tags(cls.__tags__).include_test(
                include_tags, exclude_tags)
        return by_class or (
            hasattr(fn, '_sa_exclusion_extend') and
            fn._sa_exclusion_extend.include_test(include_tags, exclude_tags)
        )

    if exclude_tags:
        if hasattr(cls, '__tags__'):
            return exclusions.tags(cls.__tags__).include_test(
                include_tags, exclude_tags)
        if hasattr(fn, '_sa_exclusion_extend'):
            return fn._sa_exclusion_extend.include_test(
                include_tags, exclude_tags)

    return True
405
406
def generate_sub_tests(cls, module):
    """Yield the concrete test classes to run for *cls*.

    A class without a truthy ``__backend__`` is yielded unchanged.
    Otherwise, one subclass is generated per config returned by
    _possible_configs_for_cls().  Each subclass is named
    ``<cls>_<dbname>_<driver>``, pinned to that backend through
    ``__only_on__ = "<dbname>+<driver>"``, and set as an attribute on
    *module* before being yielded.
    """
    if not getattr(cls, '__backend__', False):
        yield cls
        return

    for cfg in _possible_configs_for_cls(cls):
        db_name, driver = cfg.db.name, cfg.db.driver
        sub_name = "%s_%s_%s" % (cls.__name__, db_name, driver)
        subcls = type(
            sub_name,
            (cls, ),
            {"__only_on__": "%s+%s" % (db_name, driver)},
        )
        setattr(module, sub_name, subcls)
        yield subcls
422
423
def start_test_class(cls):
    """Prepare to run *cls*.

    First apply class-level skips and config selection via _do_skips(),
    then push a class-specific engine, if one is requested, via
    _setup_engine().
    """
    for step in (_do_skips, _setup_engine):
        step(cls)
427
428
def stop_test_class(cls):
    """Tear down after running *cls*.

    Stops the engine reaper's test context and runs the global cleanup
    assertions, which are skipped under ``--low-connections``.
    _restore_engine() is always called, even if the assertions fail.
    """
    engines.testing_reaper._stop_test_ctx()
    try:
        if options.low_connections:
            return
        assertions.global_cleanup_assertions()
    finally:
        _restore_engine()
438
439
def _restore_engine():
    """Call ``config._current.reset(testing)``, presumably undoing engine and
    config pushes made for the current test class -- confirm in
    sqlalchemy.testing.config."""
    current = config._current
    current.reset(testing)
442
443
def _setup_engine(cls):
    """If *cls* defines a truthy ``__engine_options__``, build a testing
    engine with those options and push it onto the current config."""
    engine_opts = getattr(cls, '__engine_options__', None)
    if not engine_opts:
        return
    eng = engines.testing_engine(options=engine_opts)
    config._current.push_engine(eng, testing)
448
449
450def before_test(test, test_module_name, test_class, test_name):
451
452    # like a nose id, e.g.:
453    # "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause"
454    name = test_class.__name__
455
456    suffix = "_%s_%s" % (config.db.name, config.db.driver)
457    if name.endswith(suffix):
458        name = name[0:-(len(suffix))]
459
460    id_ = "%s.%s.%s" % (test_module_name, name, test_name)
461
462    profiling._current_test = id_
463
464
def after_test(test):
    """Notify the engine reaper that a single test has finished."""
    reaper = engines.testing_reaper
    reaper._after_test_ctx()
467
468
def _possible_configs_for_cls(cls, reasons=None):
    """Return the set of database configs that test class *cls* may run on.

    Starts from every config in ``config.Config.all_configs()`` and removes:

    * configs matching ``cls.__unsupported_on__``;
    * configs not matching ``cls.__only_on__``, when that is set;
    * configs for which any requirement named in ``cls.__requires__``
      reports skip reasons.  Those reasons are appended to *reasons* when a
      list is supplied.

    Finally, if ``cls.__prefer_requires__`` is present, configs that fail
    any of those requirements are dropped too, but only when at least one
    config would remain.

    The result is a mutable set.  _do_skips() may narrow it further in
    place and then pop() an arbitrary member from it.
    """
    all_configs = set(config.Config.all_configs())

    if cls.__unsupported_on__:
        spec = exclusions.db_spec(*cls.__unsupported_on__)
        # iterate over a snapshot because the set is modified in the loop
        for config_obj in list(all_configs):
            if spec(config_obj):
                all_configs.remove(config_obj)

    if getattr(cls, '__only_on__', None):
        spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
        for config_obj in list(all_configs):
            if not spec(config_obj):
                all_configs.remove(config_obj)

    if hasattr(cls, '__requires__'):
        requirements = config.requirements
        for config_obj in list(all_configs):
            for requirement in cls.__requires__:
                check = getattr(requirements, requirement)

                skip_reasons = check.matching_config_reasons(config_obj)
                if skip_reasons:
                    all_configs.remove(config_obj)
                    if reasons is not None:
                        reasons.extend(skip_reasons)
                    # config already removed; stop checking requirements
                    break

    if hasattr(cls, '__prefer_requires__'):
        non_preferred = set()
        requirements = config.requirements
        for config_obj in list(all_configs):
            for requirement in cls.__prefer_requires__:
                check = getattr(requirements, requirement)

                if not check.enabled_for_config(config_obj):
                    non_preferred.add(config_obj)
        # "preferred" is only a preference: never filter down to nothing
        if all_configs.difference(non_preferred):
            all_configs.difference_update(non_preferred)

    return all_configs
510
511
def _describe_config(config_obj):
    """Return ``'name(version)+driver'`` for *config_obj*'s engine, for use
    in skip messages.

    ``dialect.server_version_info`` may be None or missing (presumably
    before the dialect has connected, or when the version could not be
    determined).  It is rendered as an empty version instead of raising
    TypeError.
    """
    db = config_obj.db
    version = getattr(db.dialect, 'server_version_info', None) or ()
    return "'%s(%s)+%s'" % (
        db.name,
        ".".join(str(dig) for dig in version),
        db.driver,
    )


def _do_skips(cls):
    """Skip *cls* when it cannot run, otherwise select a config for it.

    * Each callable in ``cls.__skip_if__`` is called; the first one that
      returns true triggers ``config.skip_test``.
    * If no config is possible (see _possible_configs_for_cls()),
      ``config.skip_test`` is called with a message listing the available
      configs and the collected reasons.
    * If ``cls.__prefer_backends__`` is present, configs that do not match it
      are discarded, provided at least one config remains.
    * If the current config is not among the remaining ones, an arbitrary
      remaining config is pushed as current via _setup_config().
    """
    reasons = []
    all_configs = _possible_configs_for_cls(cls, reasons)

    if getattr(cls, '__skip_if__', False):
        for c in getattr(cls, '__skip_if__'):
            if c():
                config.skip_test("'%s' skipped by %s" % (
                    cls.__name__, c.__name__)
                )

    if not all_configs:
        if getattr(cls, '__backend__', False):
            msg = "'%s' unsupported for implementation '%s'" % (
                cls.__name__, cls.__only_on__)
        else:
            msg = "'%s' unsupported on any DB implementation %s%s" % (
                cls.__name__,
                ", ".join(
                    _describe_config(config_obj)
                    for config_obj in config.Config.all_configs()
                ),
                ", ".join(reasons)
            )
        config.skip_test(msg)
    elif hasattr(cls, '__prefer_backends__'):
        spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
        non_preferred = set(
            config_obj for config_obj in all_configs
            if not spec(config_obj)
        )
        # only a preference: never filter down to nothing
        if all_configs.difference(non_preferred):
            all_configs.difference_update(non_preferred)

    if config._current not in all_configs:
        _setup_config(all_configs.pop(), cls)
554
555
def _setup_config(config_obj, ctx):
    """Make *config_obj* the current testing config by pushing it onto
    ``config._current``.  *ctx* is currently unused."""
    current = config._current
    current.push(config_obj, testing)
558