1# orm/evaluator.py
2# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8import operator
9
10from .. import inspect
11from .. import util
12from ..sql import operators
13
14
class UnevaluatableError(Exception):
    """Signal that an expression has no Python-side evaluation.

    :class:`EvaluatorCompiler` raises this when it meets a clause type,
    operator, or column that it cannot turn into a Python callable.
    """
17
18
# Binary operators that EvaluatorCompiler.visit_binary applies directly to
# the two evaluated operands (yielding None when either operand is None).
_straight_ops = {
    getattr(operators, name)
    for name in (
        "add",
        "mul",
        "sub",
        "div",
        "mod",
        "truediv",
        "lt",
        "le",
        "ne",
        "gt",
        "ge",
        "eq",
    )
}
36
37
# Operators explicitly listed as having no Python-side evaluation.
# NOTE(review): this set is not referenced by EvaluatorCompiler; such
# operators already fall through to UnevaluatableError in visit_binary.
# Confirm whether code outside this module relies on it.
_notimplemented_ops = {
    getattr(operators, name)
    for name in (
        "like_op",
        "notlike_op",
        "ilike_op",
        "notilike_op",
        "between_op",
        "in_op",
        "notin_op",
        "endswith_op",
        "concat_op",
    )
}
52
53
class EvaluatorCompiler(object):
    """Compile SQL expression constructs into Python callables.

    :meth:`process` dispatches each clause to a ``visit_<name>`` method,
    which returns a function of a single argument ``obj`` that computes the
    clause's value against ``obj``.  SQL NULL is represented by ``None``,
    and the boolean combinators follow SQL three-valued logic.

    :param target_cls: optional class the compiled criteria will be
        evaluated against.  A column whose parent mapper's class is not a
        superclass of ``target_cls`` is rejected.
    """

    def __init__(self, target_cls=None):
        self.target_cls = target_cls

    def process(self, clause):
        """Return an evaluator callable for ``clause``.

        :raises UnevaluatableError: if no ``visit_<name>`` method exists for
            ``clause.__visit_name__``, or the visitor rejects the clause.
        """
        meth = getattr(self, "visit_%s" % clause.__visit_name__, None)
        if not meth:
            raise UnevaluatableError(
                "Cannot evaluate %s" % type(clause).__name__
            )
        return meth(clause)

    def visit_grouping(self, clause):
        """Parentheses do not affect the value; evaluate the inner element."""
        return self.process(clause.element)

    def visit_null(self, clause):
        return lambda obj: None

    def visit_false(self, clause):
        return lambda obj: False

    def visit_true(self, clause):
        return lambda obj: True

    def visit_column(self, clause):
        """Return a callable that reads the ORM attribute for a column.

        Columns annotated with a ``parentmapper`` are resolved to their
        mapped property key.  Columns without that annotation are accepted
        only when their ``key`` names a column attribute of ``target_cls``,
        and emit a deprecation warning.

        :raises UnevaluatableError: if the column's mapper belongs to a
            class unrelated to ``target_cls``, or the column cannot be
            matched to a mapped attribute.
        """
        if "parentmapper" in clause._annotations:
            parentmapper = clause._annotations["parentmapper"]
            if self.target_cls and not issubclass(
                self.target_cls, parentmapper.class_
            ):
                raise UnevaluatableError(
                    "Can't evaluate criteria against alternate class %s"
                    % parentmapper.class_
                )
            key = parentmapper._columntoproperty[clause].key
        else:
            key = clause.key
            if (
                self.target_cls
                and key in inspect(self.target_cls).column_attrs
            ):
                util.warn(
                    "Evaluating non-mapped column expression '%s' onto "
                    "ORM instances; this is a deprecated use case.  Please "
                    "make use of the actual mapped columns in ORM-evaluated "
                    "UPDATE / DELETE expressions." % clause
                )
            else:
                raise UnevaluatableError("Cannot evaluate column: %s" % clause)

        # attrgetter(key) is already a one-argument callable of ``obj``.
        return operator.attrgetter(key)

    @staticmethod
    def _three_valued_or(values):
        """Combine ``values`` with SQL OR semantics.

        Returns True at the first truthy value.  Otherwise it returns None
        if any value was None (NULL), else False.  ``values`` is consumed
        lazily, so later operands are not computed after a True.
        """
        has_null = False
        for value in values:
            if value:
                return True
            has_null = has_null or value is None
        return None if has_null else False

    @staticmethod
    def _three_valued_and(values):
        """Combine ``values`` with SQL AND semantics.

        Returns False at the first falsy, non-None value, because FALSE
        dominates NULL in SQL (``NULL AND FALSE`` is FALSE).  Otherwise it
        returns None if any value was None, else True.  ``values`` is
        consumed lazily.
        """
        has_null = False
        for value in values:
            if value is None:
                # Keep scanning: a later FALSE still makes the result FALSE.
                has_null = True
            elif not value:
                return False
        return None if has_null else True

    def visit_clauselist(self, clause):
        """Return a callable for an ``or_()`` / ``and_()`` clause list.

        :raises UnevaluatableError: for any other clause list operator, or
            if a member clause cannot be evaluated.
        """
        evaluators = [self.process(sub) for sub in clause.clauses]
        if clause.operator is operators.or_:
            combine = self._three_valued_or
        elif clause.operator is operators.and_:
            combine = self._three_valued_and
        else:
            raise UnevaluatableError(
                "Cannot evaluate clauselist with operator %s" % clause.operator
            )

        def evaluate(obj):
            # A generator keeps evaluation lazy so ``combine`` can stop early.
            return combine(sub_evaluate(obj) for sub_evaluate in evaluators)

        return evaluate

    def visit_binary(self, clause):
        """Return a callable for a binary expression.

        ``IS`` / ``IS NOT`` compare with ``==`` / ``!=``.  Operators in
        ``_straight_ops`` are applied to both operands, returning None when
        either operand is None.

        :raises UnevaluatableError: for any other operator.
        """
        eval_left = self.process(clause.left)
        eval_right = self.process(clause.right)
        # Named ``op`` so the stdlib ``operator`` module is not shadowed.
        op = clause.operator
        if op is operators.is_:

            def evaluate(obj):
                return eval_left(obj) == eval_right(obj)

        elif op is operators.isnot:

            def evaluate(obj):
                return eval_left(obj) != eval_right(obj)

        elif op in _straight_ops:

            def evaluate(obj):
                left_val = eval_left(obj)
                right_val = eval_right(obj)
                if left_val is None or right_val is None:
                    return None
                # Reuse the operands computed above instead of evaluating
                # each side a second time.
                return op(left_val, right_val)

        else:
            raise UnevaluatableError(
                "Cannot evaluate %s with operator %s"
                % (type(clause).__name__, clause.operator)
            )
        return evaluate

    def visit_unary(self, clause):
        """Return a callable for ``NOT``; NOT NULL stays NULL (None).

        :raises UnevaluatableError: for any unary operator other than ``inv``.
        """
        eval_inner = self.process(clause.element)
        if clause.operator is operators.inv:

            def evaluate(obj):
                value = eval_inner(obj)
                if value is None:
                    return None
                return not value

            return evaluate
        raise UnevaluatableError(
            "Cannot evaluate %s with operator %s"
            % (type(clause).__name__, clause.operator)
        )

    def visit_bindparam(self, clause):
        """Return a callable yielding the bound parameter's value.

        If ``clause.callable`` is set, it is invoked once, here at compile
        time, and its result is returned for every object.
        """
        if clause.callable:
            val = clause.callable()
        else:
            val = clause.value
        return lambda obj: val
193