1"""Evaluating SQL expressions on ORM objects"""
2
3from sqlalchemy import and_
4from sqlalchemy import bindparam
5from sqlalchemy import ForeignKey
6from sqlalchemy import inspect
7from sqlalchemy import Integer
8from sqlalchemy import not_
9from sqlalchemy import or_
10from sqlalchemy import String
11from sqlalchemy.orm import evaluator
12from sqlalchemy.orm import exc as orm_exc
13from sqlalchemy.orm import mapper
14from sqlalchemy.orm import relationship
15from sqlalchemy.orm import Session
16from sqlalchemy.testing import assert_raises
17from sqlalchemy.testing import assert_raises_message
18from sqlalchemy.testing import expect_warnings
19from sqlalchemy.testing import fixtures
20from sqlalchemy.testing import is_
21from sqlalchemy.testing.schema import Column
22from sqlalchemy.testing.schema import Table
23
24
# Module-level compiler shared by eval_eq(); it is built without a target
# class argument, unlike the EvaluatorCompiler(User) instances that some
# individual tests construct for themselves.
compiler = evaluator.EvaluatorCompiler()
26
27
def eval_eq(clause, testcases=None):
    """Compile *clause* with the module-level ``compiler`` and check results.

    :param clause: SQL expression to compile into a Python callable via
        ``compiler.process``.
    :param testcases: optional iterable of ``(obj, expected_result)`` pairs;
        each pair is asserted immediately.
    :return: a ``testeval(obj=None, expected_result=None)`` function that
        asserts the compiled clause, applied to *obj*, equals
        *expected_result*.
    :raises AssertionError: when any evaluated result differs from the
        expected one.
    """
    # named ``evaluate`` so the imported ``evaluator`` module is not
    # shadowed inside this function
    evaluate = compiler.process(clause)

    def testeval(obj=None, expected_result=None):
        # evaluate exactly once so the failure message reports the very
        # value that was compared
        actual = evaluate(obj)
        assert actual == expected_result, "%s != %r for %s with %r" % (
            actual,
            expected_result,
            clause,
            obj,
        )

    if testcases:
        for an_obj, result in testcases:
            testeval(an_obj, result)
    return testeval
43
44
class EvaluateTest(fixtures.MappedTest):
    """Compile SQL expressions against a mapped ``User`` class and check
    their evaluation on in-memory instances."""

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "users",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("name", String(64)),
            Column("othername", String(64)),
        )

    @classmethod
    def setup_classes(cls):
        class User(cls.Basic):
            pass

    @classmethod
    def setup_mappers(cls):
        users, User = cls.tables.users, cls.classes.User

        mapper(User, users)

    def test_compare_to_value(self):
        """Comparison to a literal; a None attribute evaluates to None."""
        User = self.classes.User

        eval_eq(
            User.name == "foo",
            testcases=[
                (User(name="foo"), True),
                (User(name="bar"), False),
                (User(name=None), None),
            ],
        )

        eval_eq(
            User.id < 5,
            testcases=[
                (User(id=3), True),
                (User(id=5), False),
                (User(id=None), None),
            ],
        )

    def test_compare_to_callable_bind(self):
        """A bindparam with ``callable_`` is evaluated via that callable."""
        User = self.classes.User

        eval_eq(
            User.name == bindparam("x", callable_=lambda: "foo"),
            testcases=[
                (User(name="foo"), True),
                (User(name="bar"), False),
                (User(name=None), None),
            ],
        )

    def test_compare_to_none(self):
        """``== None`` yields a plain True/False rather than None."""
        User = self.classes.User

        eval_eq(
            User.name == None,  # noqa
            testcases=[(User(name="foo"), False), (User(name=None), True)],
        )

    def test_warn_on_unannotated_matched_column(self):
        """A plain Column whose name matches a mapped attribute emits a
        deprecation warning and is then read from that attribute."""
        User = self.classes.User

        # local name avoids shadowing the module-level ``compiler``
        user_compiler = evaluator.EvaluatorCompiler(User)

        with expect_warnings(
            r"Evaluating non-mapped column expression 'othername' "
            "onto ORM instances; this is a deprecated use case."
        ):
            meth = user_compiler.process(
                User.name == Column("othername", String)
            )

        # the unannotated column must resolve to ``othername``; checking the
        # results (not merely calling meth) proves which attribute is read
        is_(meth(User(id=5, name="foo", othername="foo")), True)
        is_(meth(User(id=5, name="foo", othername="bar")), False)
        # neither attribute set: the comparison propagates None
        is_(meth(User(id=5)), None)

    def test_raise_on_unannotated_unmatched_column(self):
        """A plain Column with no matching mapped attribute is rejected at
        compile time."""
        User = self.classes.User

        # local name avoids shadowing the module-level ``compiler``
        user_compiler = evaluator.EvaluatorCompiler(User)

        assert_raises_message(
            evaluator.UnevaluatableError,
            "Cannot evaluate column: foo",
            user_compiler.process,
            User.id == Column("foo", Integer),
        )

        # Prior to [ticket:3366] this expression compiled successfully and
        # only failed later, when the resulting function was applied to an
        # instance such as User(id=5), with
        # AttributeError: 'User' object has no attribute 'foo'

    def test_true_false(self):
        """Comparison to the True/False literals."""
        User = self.classes.User

        eval_eq(
            User.name == False,  # noqa
            testcases=[
                (User(name="foo"), False),
                (User(name=True), False),
                (User(name=False), True),
            ],
        )

        eval_eq(
            User.name == True,  # noqa
            testcases=[
                (User(name="foo"), False),
                (User(name=True), True),
                (User(name=False), False),
            ],
        )

    def test_boolean_ops(self):
        """and_/or_/not_ follow the None-propagation shown in the cases."""
        User = self.classes.User

        eval_eq(
            and_(User.name == "foo", User.id == 1),
            testcases=[
                (User(id=1, name="foo"), True),
                (User(id=2, name="foo"), False),
                (User(id=1, name="bar"), False),
                (User(id=2, name="bar"), False),
                (User(id=1, name=None), None),
            ],
        )

        eval_eq(
            or_(User.name == "foo", User.id == 1),
            testcases=[
                (User(id=1, name="foo"), True),
                (User(id=2, name="foo"), True),
                (User(id=1, name="bar"), True),
                (User(id=2, name="bar"), False),
                (User(id=1, name=None), True),
                (User(id=2, name=None), None),
            ],
        )

        eval_eq(
            not_(User.id == 1),
            testcases=[
                (User(id=1), False),
                (User(id=2), True),
                (User(id=None), None),
            ],
        )

    def test_null_propagation(self):
        """Comparing two boolean sub-expressions propagates None."""
        User = self.classes.User

        eval_eq(
            (User.name == "foo") == (User.id == 1),
            testcases=[
                (User(id=1, name="foo"), True),
                (User(id=2, name="foo"), False),
                (User(id=1, name="bar"), False),
                (User(id=2, name="bar"), True),
                (User(id=None, name="foo"), None),
                (User(id=None, name=None), None),
            ],
        )
210
211
class M2OEvaluateTest(fixtures.DeclarativeMappedTest):
    """Bulk ``delete("evaluate")`` filtered on a many-to-one relationship."""

    @classmethod
    def setup_classes(cls):
        Base = cls.DeclarativeBasic

        class Parent(Base):
            __tablename__ = "parent"
            id = Column(Integer, primary_key=True)

        class Child(Base):
            __tablename__ = "child"
            _id_parent = Column(
                "id_parent", Integer, ForeignKey(Parent.id), primary_key=True
            )
            name = Column(String(50), primary_key=True)
            parent = relationship(Parent)

    def _persist_parent_with_child(self, session):
        """Commit a Parent(id=1), then a Child named "foo" pointing at it.

        Each object is committed in its own transaction on *session*.
        Returns the ``(parent, child)`` pair.
        """
        Parent, Child = self.classes("Parent", "Child")

        parent = Parent(id=1)
        session.add(parent)
        session.commit()

        child = Child(name="foo", parent=parent)
        session.add(child)
        session.commit()

        return parent, child

    def test_delete_not_expired(self):
        """With expire_on_commit=False the loaded child is marked deleted."""
        Child = self.classes.Child

        sess = Session(expire_on_commit=False)
        parent, child = self._persist_parent_with_child(sess)

        sess.query(Child).filter(Child.parent == parent).delete("evaluate")

        is_(inspect(child).deleted, True)

    def test_delete_expired(self):
        """An expired child is not marked deleted, yet its row is gone."""
        Child = self.classes.Child

        sess = Session()
        parent, child = self._persist_parent_with_child(sess)

        sess.query(Child).filter(Child.parent == parent).delete("evaluate")

        # the commit expired the child, so evaluation cannot flag it
        is_(inspect(child).deleted, False)

        # refreshing the expired attribute finds the row deleted
        assert_raises(orm_exc.ObjectDeletedError, getattr, child, "name")
266