1"""Evaluating SQL expressions on ORM objects"""
2
3from sqlalchemy import and_
4from sqlalchemy import bindparam
5from sqlalchemy import ForeignKey
6from sqlalchemy import inspect
7from sqlalchemy import Integer
8from sqlalchemy import not_
9from sqlalchemy import or_
10from sqlalchemy import String
11from sqlalchemy import tuple_
12from sqlalchemy.orm import evaluator
13from sqlalchemy.orm import exc as orm_exc
14from sqlalchemy.orm import relationship
15from sqlalchemy.testing import assert_raises
16from sqlalchemy.testing import assert_raises_message
17from sqlalchemy.testing import expect_warnings
18from sqlalchemy.testing import fixtures
19from sqlalchemy.testing import is_
20from sqlalchemy.testing.fixtures import fixture_session
21from sqlalchemy.testing.schema import Column
22from sqlalchemy.testing.schema import Table
23
# Module-level evaluator compiler, constructed with no arguments; eval_eq()
# below uses it to turn a SQL clause into a Python callable.
compiler = evaluator.EvaluatorCompiler()
25
26
def eval_eq(clause, testcases=None):
    """Compile ``clause`` with the module-level compiler and check results.

    :param clause: SQL expression passed to ``compiler.process()``.
    :param testcases: optional iterable of ``(obj, expected_result)`` pairs;
     when given (and non-empty), each pair is checked immediately.
    :return: the ``testeval(obj=None, expected_result=None)`` callable, which
     asserts that evaluating the compiled clause against ``obj`` equals
     ``expected_result``.
    """
    # Named ``evaluate`` rather than ``evaluator`` so the imported
    # ``evaluator`` module is not shadowed inside this function.
    evaluate = compiler.process(clause)

    def testeval(obj=None, expected_result=None):
        # Evaluate exactly once so that the failure message reports the
        # same value that was compared.
        actual = evaluate(obj)
        assert actual == expected_result, "%s != %r for %s with %r" % (
            actual,
            expected_result,
            clause,
            obj,
        )

    if testcases:
        for an_obj, result in testcases:
            testeval(an_obj, result)
    return testeval
42
43
class EvaluateTest(fixtures.MappedTest):
    """Evaluate SQL expressions against instances of a mapped ``User``."""

    @classmethod
    def define_tables(cls, metadata):
        """Create the ``users`` table on the given metadata."""
        user_columns = [
            Column("id", Integer, primary_key=True),
            Column("name", String(64)),
            Column("othername", String(64)),
        ]
        Table("users", metadata, *user_columns)

    @classmethod
    def setup_classes(cls):
        """Declare the plain ``User`` class used by the tests."""

        class User(cls.Basic):
            pass

    @classmethod
    def setup_mappers(cls):
        """Map ``User`` imperatively onto the ``users`` table."""
        users_table = cls.tables.users
        user_cls = cls.classes.User

        cls.mapper_registry.map_imperatively(user_cls, users_table)

    def test_compare_to_value(self):
        """Compare attributes to literal values, including None attributes."""
        User = self.classes.User

        name_criterion = User.name == "foo"
        name_cases = [
            (User(name="foo"), True),
            (User(name="bar"), False),
            (User(name=None), None),
        ]
        eval_eq(name_criterion, testcases=name_cases)

        id_criterion = User.id < 5
        id_cases = [
            (User(id=3), True),
            (User(id=5), False),
            (User(id=None), None),
        ]
        eval_eq(id_criterion, testcases=id_cases)

    def test_compare_to_callable_bind(self):
        """Compare an attribute to a bindparam whose value is a callable."""
        User = self.classes.User

        criterion = User.name == bindparam("x", callable_=lambda: "foo")
        cases = [
            (User(name="foo"), True),
            (User(name="bar"), False),
            (User(name=None), None),
        ]
        eval_eq(criterion, testcases=cases)

    def test_compare_to_none(self):
        """Compare an attribute to None; a None object evaluates to None."""
        User = self.classes.User

        criterion = User.name == None  # noqa
        cases = [
            (User(name="foo"), False),
            (User(name=None), True),
            (None, None),
        ]
        eval_eq(criterion, testcases=cases)

    def test_warn_on_unannotated_matched_column(self):
        """A plain Column named like a mapped attribute emits a warning."""
        User = self.classes.User

        user_compiler = evaluator.EvaluatorCompiler(User)

        with expect_warnings(
            r"Evaluating non-mapped column expression 'othername' "
            "onto ORM instances; this is a deprecated use case."
        ):
            criterion = User.name == Column("othername", String)
            evaluate = user_compiler.process(criterion)

        # the compiled callable is still invoked against an instance
        evaluate(User(id=5))

    def test_raise_on_unannotated_unmatched_column(self):
        """A plain Column with no matching attribute fails at compile time."""
        User = self.classes.User

        user_compiler = evaluator.EvaluatorCompiler(User)
        criterion = User.id == Column("foo", Integer)

        assert_raises_message(
            evaluator.UnevaluatableError,
            "Cannot evaluate column: foo",
            user_compiler.process,
            criterion,
        )

        # Prior to [ticket:3366] the compile step above was let through,
        # and calling the resulting method on a User(id=5) would instead
        # fail with:
        # AttributeError: 'User' object has no attribute 'foo'

    def test_true_false(self):
        """Compare an attribute to the boolean constants False and True."""
        User = self.classes.User

        is_false = User.name == False  # noqa
        false_cases = [
            (User(name="foo"), False),
            (User(name=True), False),
            (User(name=False), True),
            (None, None),
        ]
        eval_eq(is_false, testcases=false_cases)

        is_true = User.name == True  # noqa
        true_cases = [
            (User(name="foo"), False),
            (User(name=True), True),
            (User(name=False), False),
            (None, None),
        ]
        eval_eq(is_true, testcases=true_cases)

    def test_boolean_ops(self):
        """Evaluate and_(), or_() and not_(), including None operands."""
        User = self.classes.User

        conjunction = and_(User.name == "foo", User.id == 1)
        conjunction_cases = [
            (User(id=1, name="foo"), True),
            (User(id=2, name="foo"), False),
            (User(id=1, name="bar"), False),
            (User(id=2, name="bar"), False),
            (User(id=1, name=None), None),
            (None, None),
        ]
        eval_eq(conjunction, testcases=conjunction_cases)

        disjunction = or_(User.name == "foo", User.id == 1)
        disjunction_cases = [
            (User(id=1, name="foo"), True),
            (User(id=2, name="foo"), True),
            (User(id=1, name="bar"), True),
            (User(id=2, name="bar"), False),
            (User(id=1, name=None), True),
            (User(id=2, name=None), None),
            (None, None),
        ]
        eval_eq(disjunction, testcases=disjunction_cases)

        negation = not_(User.id == 1)
        negation_cases = [
            (User(id=1), False),
            (User(id=2), True),
            (User(id=None), None),
        ]
        eval_eq(negation, testcases=negation_cases)

    def test_in(self):
        """Evaluate in_() and not_in() against a list of scalar values."""
        User = self.classes.User

        membership = User.name.in_(["foo", "bar"])
        membership_cases = [
            (User(id=1, name="foo"), True),
            (User(id=2, name="bat"), False),
            (User(id=1, name="bar"), True),
            (User(id=1, name=None), None),
            (None, None),
        ]
        eval_eq(membership, testcases=membership_cases)

        exclusion = User.name.not_in(["foo", "bar"])
        exclusion_cases = [
            (User(id=1, name="foo"), False),
            (User(id=2, name="bat"), True),
            (User(id=1, name="bar"), False),
            (User(id=1, name=None), None),
            (None, None),
        ]
        eval_eq(exclusion, testcases=exclusion_cases)

    def test_in_tuples(self):
        """Evaluate in_() and not_in() on a tuple_() of two attributes."""
        User = self.classes.User

        membership = tuple_(User.id, User.name).in_([(1, "foo"), (2, "bar")])
        membership_cases = [
            (User(id=1, name="foo"), True),
            (User(id=2, name="bat"), False),
            (User(id=1, name="bar"), False),
            (User(id=2, name="bar"), True),
            (User(id=1, name=None), None),
            (None, None),
        ]
        eval_eq(membership, testcases=membership_cases)

        exclusion = tuple_(User.id, User.name).not_in(
            [(1, "foo"), (2, "bar")]
        )
        exclusion_cases = [
            (User(id=1, name="foo"), False),
            (User(id=2, name="bat"), True),
            (User(id=1, name="bar"), True),
            (User(id=2, name="bar"), False),
            (User(id=1, name=None), None),
            (None, None),
        ]
        eval_eq(exclusion, testcases=exclusion_cases)

    def test_null_propagation(self):
        """Compare two comparisons; a None on either side yields None."""
        User = self.classes.User

        criterion = (User.name == "foo") == (User.id == 1)
        cases = [
            (User(id=1, name="foo"), True),
            (User(id=2, name="foo"), False),
            (User(id=1, name="bar"), False),
            (User(id=2, name="bar"), True),
            (User(id=None, name="foo"), None),
            (User(id=None, name=None), None),
            (None, None),
        ]
        eval_eq(criterion, testcases=cases)
270
271
class M2OEvaluateTest(fixtures.DeclarativeMappedTest):
    """Run "evaluate" deletes filtered on the ``Child.parent`` relationship."""

    @classmethod
    def setup_classes(cls):
        """Declare ``Parent`` and a ``Child`` that refers to it."""
        decl_base = cls.DeclarativeBasic

        class Parent(decl_base):
            __tablename__ = "parent"
            id = Column(Integer, primary_key=True)

        class Child(decl_base):
            __tablename__ = "child"
            # the attribute name differs from the column name "id_parent"
            _id_parent = Column(
                "id_parent",
                Integer,
                ForeignKey(Parent.id),
                primary_key=True,
            )
            name = Column(String(50), primary_key=True)
            parent = relationship(Parent)

    def test_delete_not_expired(self):
        """With expire_on_commit=False the child is flagged as deleted."""
        Parent, Child = self.classes("Parent", "Child")

        sess = fixture_session(expire_on_commit=False)

        parent = Parent(id=1)
        sess.add(parent)
        sess.commit()

        child = Child(name="foo", parent=parent)
        sess.add(child)
        sess.commit()

        query = sess.query(Child).filter(Child.parent == parent)
        query.delete("evaluate")

        child_state = inspect(child)
        is_(child_state.deleted, True)

    def test_delete_expired(self):
        """With a default session the child is expired, not flagged."""
        Parent, Child = self.classes("Parent", "Child")

        sess = fixture_session()

        parent = Parent(id=1)
        sess.add(parent)
        sess.commit()

        child = Child(name="foo", parent=parent)
        sess.add(child)
        sess.commit()

        query = sess.query(Child).filter(Child.parent == parent)
        query.delete("evaluate")

        # because it's expired
        child_state = inspect(child)
        is_(child_state.deleted, False)

        # but it's gone
        assert_raises(orm_exc.ObjectDeletedError, getattr, child, "name")
326