1# testing/exclusions.py
2# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8
9import operator
10from ..util import decorator
11from . import config
12from .. import util
13import inspect
14import contextlib
15from sqlalchemy.util.compat import inspect_getargspec
16
17
def skip_if(predicate, reason=None):
    """Return a :class:`compound` that skips the test when ``predicate`` holds.

    ``predicate`` is anything accepted by ``Predicate.as_predicate``;
    ``reason`` is applied as its description when it has none.
    """
    rule = compound()
    rule.skips.add(_as_predicate(predicate, reason))
    return rule
23
24
def fails_if(predicate, reason=None):
    """Return a :class:`compound` expecting failure when ``predicate`` holds.

    ``predicate`` is anything accepted by ``Predicate.as_predicate``;
    ``reason`` is applied as its description when it has none.
    """
    rule = compound()
    rule.fails.add(_as_predicate(predicate, reason))
    return rule
30
31
class compound(object):
    """A composable collection of skip rules, fail rules and tags.

    * ``skips`` -- predicates that, when true for a config, skip the test.
    * ``fails`` -- predicates that, when true for a config, mean the test
      is expected to fail.
    * ``tags`` -- plain tag names consulted by :meth:`include_test`.

    Instances decorate test functions (:meth:`__call__`) and can guard an
    arbitrary block as a context manager (:meth:`fail_if`).
    """

    def __init__(self):
        self.fails = set()
        self.skips = set()
        self.tags = set()

    def __add__(self, other):
        return self.add(other)

    def add(self, *others):
        """Return a new compound merging this one with every one of ``others``.

        Neither ``self`` nor any of ``others`` is modified.
        """
        copy = compound()
        for source in (self,) + others:
            copy.fails.update(source.fails)
            copy.skips.update(source.skips)
            copy.tags.update(source.tags)
        return copy

    def not_(self):
        """Return a new compound with every skip/fail predicate negated.

        Tags are copied unchanged.
        """
        copy = compound()
        copy.fails.update(NotPredicate(fail) for fail in self.fails)
        copy.skips.update(NotPredicate(skip) for skip in self.skips)
        copy.tags.update(self.tags)
        return copy

    @property
    def enabled(self):
        """True when no skip or fail predicate matches the current config."""
        return self.enabled_for_config(config._current)

    def enabled_for_config(self, config):
        """Return False if any skip or fail predicate is true for ``config``."""
        return not any(
            predicate(config)
            for predicate in self.skips.union(self.fails)
        )

    def matching_config_reasons(self, config):
        """Describe every skip/fail predicate that is true for ``config``."""
        return [
            predicate._as_string(config) for predicate
            in self.skips.union(self.fails)
            if predicate(config)
        ]

    def include_test(self, include_tags, exclude_tags):
        """Return True unless excluded by tag, or not selected by ``include_tags``.

        An empty ``include_tags`` selects everything not excluded.
        """
        return bool(
            not self.tags.intersection(exclude_tags) and
            (not include_tags or self.tags.intersection(include_tags))
        )

    def _extend(self, other):
        """Merge ``other``'s rules and tags into this compound in place."""
        self.skips.update(other.skips)
        self.fails.update(other.fails)
        self.tags.update(other.tags)

    def __call__(self, fn):
        """Decorate test function ``fn`` with these rules.

        If ``fn`` was already decorated by a compound, that compound is
        extended in place instead of wrapping ``fn`` a second time.
        """
        if hasattr(fn, '_sa_exclusion_extend'):
            fn._sa_exclusion_extend._extend(self)
            return fn

        @decorator
        def decorate(fn, *args, **kw):
            return self._do(config._current, fn, *args, **kw)
        decorated = decorate(fn)
        # remembered so later decorators can extend this rule set
        decorated._sa_exclusion_extend = self
        return decorated

    @contextlib.contextmanager
    def fail_if(self):
        """Run the ``with`` block expecting failure under any skip/fail rule.

        Inside the block, both skip and fail predicates are treated as
        "expected to fail" predicates.
        """
        all_fails = compound()
        all_fails.fails.update(self.skips.union(self.fails))

        try:
            yield
        except Exception as ex:
            all_fails._expect_failure(config._current, ex)
        else:
            all_fails._expect_success(config._current)

    def _do(self, config, fn, *args, **kw):
        """Run ``fn`` under these rules for ``config``.

        If a skip predicate matches, ``config.skip_test`` is called with a
        message.  Otherwise ``fn`` runs, and its outcome is checked against
        the fail predicates.
        """
        for skip in self.skips:
            if skip(config):
                msg = "'%s' : %s" % (
                    fn.__name__,
                    skip._as_string(config)
                )
                config.skip_test(msg)

        try:
            return_value = fn(*args, **kw)
        except Exception as ex:
            self._expect_failure(config, ex, name=fn.__name__)
        else:
            self._expect_success(config, name=fn.__name__)
            return return_value

    def _expect_failure(self, config, ex, name='block'):
        """Accept ``ex`` if any fail predicate matches, else re-raise it."""
        for fail in self.fails:
            if fail(config):
                print(("%s failed as expected (%s): %s " % (
                    name, fail._as_string(config), str(ex))))
                break
        else:
            util.raise_from_cause(ex)

    def _expect_success(self, config, name='block'):
        """Raise AssertionError if a success was expected to be a failure.

        This mirrors :meth:`_expect_failure`: a failure is expected when
        *any* fail predicate matches ``config``.  A success is therefore
        unexpected as soon as one predicate matches.  The message lists the
        matching predicates.
        """
        matched = [fail for fail in self.fails if fail(config)]
        if matched:
            raise AssertionError(
                "Unexpected success for '%s' (%s)" %
                (
                    name,
                    " and ".join(
                        fail._as_string(config)
                        for fail in matched
                    )
                )
            )
155
156
def requires_tag(tagname):
    """Return a predicate-free :class:`compound` tagged with ``tagname`` only."""
    single_tag = [tagname]
    return tags(single_tag)
159
160
def tags(tagnames):
    """Return a predicate-free :class:`compound` carrying ``tagnames``."""
    result = compound()
    for name in tagnames:
        result.tags.add(name)
    return result
165
166
def only_if(predicate, reason=None):
    """Skip the test unless ``predicate`` holds for the current config.

    ``reason`` becomes the description of the negated skip predicate.
    """
    return skip_if(NotPredicate(_as_predicate(predicate)), reason)
170
171
def succeeds_if(predicate, reason=None):
    """Expect the test to fail unless ``predicate`` holds for the config.

    ``reason`` becomes the description of the negated fail predicate.
    """
    return fails_if(NotPredicate(_as_predicate(predicate)), reason)
175
176
class Predicate(object):
    """Base class for config predicates used by skip/fail rules.

    Subclasses are called with a test config and return a truth value.
    They implement :meth:`_as_string` to describe themselves.
    """

    @classmethod
    def as_predicate(cls, predicate, description=None):
        """Coerce ``predicate`` into a :class:`Predicate`.

        Accepted forms:

        * :class:`compound` -- wraps its ``enabled_for_config`` method;
        * :class:`Predicate` -- returned as-is; ``description`` is applied
          when it has none;
        * list or set -- each element is coerced, then combined with
          :class:`OrPredicate`;
        * tuple -- positional arguments for :class:`SpecPredicate`;
          ``description`` is applied when the tuple did not supply one;
        * string ``"<db> [<op> <dotted.version>]"`` -- parsed into a
          :class:`SpecPredicate`;
        * any other callable -- wrapped in :class:`LambdaPredicate`.

        :raises AssertionError: for any other type.
        """
        if isinstance(predicate, compound):
            return cls.as_predicate(predicate.enabled_for_config, description)
        elif isinstance(predicate, Predicate):
            if description and predicate.description is None:
                predicate.description = description
            return predicate
        elif isinstance(predicate, (list, set)):
            return OrPredicate(
                [cls.as_predicate(pred) for pred in predicate],
                description)
        elif isinstance(predicate, tuple):
            spec_predicate = SpecPredicate(*predicate)
            # same treatment as the string / Predicate branches; a 4-tuple
            # may already carry its own description, which wins
            if description and spec_predicate.description is None:
                spec_predicate.description = description
            return spec_predicate
        elif isinstance(predicate, util.string_types):
            tokens = predicate.split(" ", 2)
            op = spec = None
            db = tokens.pop(0)
            if tokens:
                op = tokens.pop(0)
            if tokens:
                spec = tuple(int(d) for d in tokens.pop(0).split("."))
            return SpecPredicate(db, op, spec, description=description)
        elif util.callable(predicate):
            return LambdaPredicate(predicate, description)
        else:
            # explicit raise (not ``assert``) so the check survives ``python -O``
            raise AssertionError("unknown predicate type: %s" % predicate)

    def _format_description(self, config, negate=False):
        """Interpolate ``self.description`` for ``config``.

        Available placeholders are ``%(driver)s``, ``%(database)s``,
        ``%(doesnt_support)s`` and ``%(does_support)s``.  The two "support"
        phrases are chosen from this predicate's value for ``config``.
        That value is inverted when ``negate`` is true.
        """
        bool_ = self(config)
        if negate:
            # invert the evaluated value, not the flag itself
            bool_ = not bool_
        return self.description % {
            "driver": config.db.url.get_driver_name(),
            "database": config.db.url.get_backend_name(),
            "doesnt_support": "doesn't support" if bool_ else "does support",
            "does_support": "does support" if bool_ else "doesn't support"
        }

    def _as_string(self, config=None, negate=False):
        """Describe this predicate for ``config``; subclasses must override."""
        raise NotImplementedError()
219
220
class BooleanPredicate(Predicate):
    """Predicate with a fixed truth value that ignores the config."""

    def __init__(self, value, description=None):
        self.value = value
        if not description:
            description = "boolean %s" % value
        self.description = description

    def __call__(self, config):
        # the config is deliberately not consulted
        return self.value

    def _as_string(self, config, negate=False):
        formatted = self._format_description(config, negate=negate)
        return formatted
231
232
class SpecPredicate(Predicate):
    """Match the current engine by dialect/driver and optional server version.

    ``db`` is ``"<dialect>"`` or ``"<dialect>+<driver>"``; an empty dialect
    part matches any dialect.  When ``op`` is given, the server version
    tuple is compared against ``spec`` using ``op``.  ``op`` is either a key
    of ``_ops`` or a callable ``op(version, spec)``.
    """

    def __init__(self, db, op=None, spec=None, description=None):
        self.db = db
        self.op = op
        self.spec = spec
        self.description = description

    _ops = {
        '<': operator.lt,
        '>': operator.gt,
        '==': operator.eq,
        '!=': operator.ne,
        '<=': operator.le,
        '>=': operator.ge,
        # "version in spec".  operator.contains(version, spec) would test
        # the reverse, "spec in version".
        'in': lambda val, options: val in options,
        'between': lambda val, pair: val >= pair[0] and val <= pair[1],
    }

    def __call__(self, config):
        engine = config.db

        if "+" in self.db:
            dialect, driver = self.db.split('+')
        else:
            dialect, driver = self.db, None

        if dialect and engine.name != dialect:
            return False
        if driver is not None and engine.driver != driver:
            return False

        if self.op is None:
            return True

        assert driver is None, "DBAPI version specs not supported yet"

        version = _server_version(engine)
        oper = self.op if util.callable(self.op) else self._ops[self.op]
        return oper(version, self.spec)

    def _as_string(self, config, negate=False):
        """Describe this spec, prefixed with ``"not "`` when ``negate``.

        An explicit description is formatted instead, and ``negate`` is
        not applied to it.
        """
        if self.description is not None:
            return self._format_description(config)
        prefix = "not " if negate else ""
        if self.op is None:
            return "%s%s" % (prefix, self.db)
        return "%s%s %s %s" % (prefix, self.db, self.op, self.spec)
295
296
class LambdaPredicate(Predicate):
    """Predicate backed by an arbitrary callable.

    A callable declaring no positional parameters is adapted to accept, and
    ignore, the config argument.  The description falls back to the
    callable's docstring, then to ``"custom function"``.
    """

    def __init__(self, lambda_, description=None, args=None, kw=None):
        takes_config = bool(inspect_getargspec(lambda_)[0])
        self.lambda_ = lambda_ if takes_config else (lambda db: lambda_())
        self.args = args or ()
        self.kw = kw or {}
        self.description = (
            description or lambda_.__doc__ or "custom function"
        )

    def __call__(self, config):
        return self.lambda_(config)

    def _as_string(self, config, negate=False):
        # the description is formatted as-is; ``negate`` is not applied
        return self._format_description(config)
318
319
class NotPredicate(Predicate):
    """Logical negation of another predicate."""

    def __init__(self, predicate, description=None):
        self.predicate = predicate
        self.description = description

    def __call__(self, config):
        matched = self.predicate(config)
        return not matched

    def _as_string(self, config, negate=False):
        flipped = not negate
        if not self.description:
            # delegate to the wrapped predicate with the sense inverted
            return self.predicate._as_string(config, flipped)
        return self._format_description(config, flipped)
333
334
class OrPredicate(Predicate):
    """True when any of ``predicates`` is true for the config."""

    def __init__(self, predicates, description=None):
        self.predicates = predicates
        self.description = description

    def __call__(self, config):
        return any(pred(config) for pred in self.predicates)

    def _eval_str(self, config, negate=False):
        """Join the member descriptions with "or" (or "and" when negated)."""
        conjunction = " and " if negate else " or "
        parts = [p._as_string(config, negate=negate) for p in self.predicates]
        return conjunction.join(parts)

    def _negation_str(self, config):
        if self.description is None:
            return self._eval_str(config, negate=True)
        return "Not " + self._format_description(config)

    def _as_string(self, config, negate=False):
        if negate:
            return self._negation_str(config)
        if self.description is None:
            return self._eval_str(config)
        return self._format_description(config)
368
369
# Module-level shorthand for Predicate.as_predicate, used by skip_if,
# fails_if, only_if and succeeds_if to coerce user-supplied predicates.
_as_predicate = Predicate.as_predicate
371
372
def _is_excluded(db, op, spec):
    """Evaluate ``SpecPredicate(db, op, spec)`` against the current config."""
    predicate = SpecPredicate(db, op, spec)
    current = config._current
    return predicate(current)
375
376
377def _server_version(engine):
378    """Return a server_version_info tuple."""
379
380    # force metadata to be retrieved
381    conn = engine.connect()
382    version = getattr(engine.dialect, 'server_version_info', ())
383    conn.close()
384    return version
385
386
def db_spec(*dbs):
    """Return an :class:`OrPredicate` over ``dbs``, each coerced via ``as_predicate``."""
    members = list(map(Predicate.as_predicate, dbs))
    return OrPredicate(members)
391
392
def open():
    """Return a rule whose skip predicate is always false, so the test runs.

    NOTE: this name shadows the builtin ``open`` within this module.
    """
    never = BooleanPredicate(False, "mark as execute")
    return skip_if(never)
395
396
def closed():
    """Return a rule whose skip predicate is always true, so the test is skipped."""
    always = BooleanPredicate(True, "marked as skip")
    return skip_if(always)
399
400
def fails(reason=None):
    """Return a rule expecting the test to fail on every config."""
    description = reason or "expected to fail"
    return fails_if(BooleanPredicate(True, description))
403
404
def future(fn=None):
    """Mark a test as exercising a not-yet-implemented feature.

    The test body runs and is expected to fail on every config.  Because
    this is a ``fails_if`` rule with an always-true predicate, a success is
    reported as an unexpected success.

    Works bare (``@future``, ``fn`` is the test function) and called
    (``@future()``, which returns the rule to apply).
    """
    rule = fails_if(BooleanPredicate(True, "Future feature"))
    if fn is None:
        return rule
    return rule(fn)
408
409
def fails_on(db, reason=None):
    """Expect the test to fail when the backend matches the ``db`` spec string."""
    backend = SpecPredicate(db)
    return fails_if(backend, reason)
412
413
def fails_on_everything_except(*dbs):
    """Expect the test to fail on every backend except those in ``dbs``.

    Each entry is coerced with ``Predicate.as_predicate``, consistent with
    :func:`db_spec` and :func:`only_on`.  So besides plain names such as
    ``"sqlite"`` or ``"mysql+mysqldb"`` (which behave exactly as before),
    version strings like ``"mysql >= 5.6"``, ``(db, op, spec)`` tuples and
    callables are accepted.
    """
    return succeeds_if(
        OrPredicate([Predicate.as_predicate(db) for db in dbs])
    )
420
421
def skip(db, reason=None):
    """Skip the test when the backend matches the ``db`` spec string."""
    backend = SpecPredicate(db)
    return skip_if(backend, reason)
424
425
def only_on(dbs, reason=None):
    """Skip the test unless the current backend matches one of ``dbs``.

    :param dbs: a single spec or a list of specs; each is coerced with
     ``Predicate.as_predicate``.
    :param reason: optional description attached to the skip rule; it is
     %-interpolated like other predicate descriptions.
    """
    return only_if(
        OrPredicate([Predicate.as_predicate(db) for db in util.to_list(dbs)]),
        reason
    )
430
431
def exclude(db, op, spec, reason=None):
    """Skip the test when backend ``db``'s server version compares true via ``op``/``spec``."""
    version_rule = SpecPredicate(db, op, spec)
    return skip_if(version_rule, reason)
434
435
def against(config, *queries):
    """Return True if ``config`` matches any of ``queries``.

    Each query is coerced with ``Predicate.as_predicate``; at least one
    query is required.
    """
    assert queries, "no queries sent!"
    predicates = list(map(Predicate.as_predicate, queries))
    return OrPredicate(predicates)(config)
442