"""Evaluating SQL expressions on ORM objects"""

from sqlalchemy import and_
from sqlalchemy import bindparam
from sqlalchemy import ForeignKey
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import not_
from sqlalchemy import or_
from sqlalchemy import String
from sqlalchemy import tuple_
from sqlalchemy.orm import evaluator
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm import relationship
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import expect_warnings
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table

compiler = evaluator.EvaluatorCompiler()


def eval_eq(clause, testcases=None):
    evaluator = compiler.process(clause)

    def testeval(obj=None, expected_result=None):
        assert evaluator(obj) == expected_result, "%s != %r for %s with %r" % (
            evaluator(obj),
            expected_result,
            clause,
            obj,
        )

    if testcases:
        for an_obj, result in testcases:
            testeval(an_obj, result)
    return testeval


class EvaluateTest(fixtures.MappedTest):
    @classmethod
    def define_tables(cls, metadata):
        Table(
            "users",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("name", String(64)),
            Column("othername", String(64)),
        )

    @classmethod
    def setup_classes(cls):
        class User(cls.Basic):
            pass

    @classmethod
    def setup_mappers(cls):
        users, User = cls.tables.users, cls.classes.User

        cls.mapper_registry.map_imperatively(User, users)

    def test_compare_to_value(self):
        User = self.classes.User

        eval_eq(
            User.name == "foo",
            testcases=[
                (User(name="foo"), True),
                (User(name="bar"), False),
                (User(name=None), None),
            ],
        )

        eval_eq(
            User.id < 5,
            testcases=[
                (User(id=3), True),
                (User(id=5), False),
                (User(id=None), None),
            ],
        )

    def test_compare_to_callable_bind(self):
        User = self.classes.User

        eval_eq(
            User.name == bindparam("x", callable_=lambda: "foo"),
            testcases=[
                (User(name="foo"), True),
                (User(name="bar"), False),
                (User(name=None), None),
            ],
        )

    def test_compare_to_none(self):
        User = self.classes.User

        eval_eq(
            User.name == None,  # noqa
            testcases=[
                (User(name="foo"), False),
                (User(name=None), True),
                (None, None),
            ],
        )

    def test_warn_on_unannotated_matched_column(self):
        User = self.classes.User

        compiler = evaluator.EvaluatorCompiler(User)

        with expect_warnings(
            r"Evaluating non-mapped column expression 'othername' "
            "onto ORM instances; this is a deprecated use case."
        ):
            meth = compiler.process(User.name == Column("othername", String))

        u1 = User(id=5)
        meth(u1)

    def test_raise_on_unannotated_unmatched_column(self):
        User = self.classes.User

        compiler = evaluator.EvaluatorCompiler(User)

        assert_raises_message(
            evaluator.UnevaluatableError,
            "Cannot evaluate column: foo",
            compiler.process,
            User.id == Column("foo", Integer),
        )

        # if we let the above method through as we did
        # prior to [ticket:3366], we would get
        # AttributeError: 'User' object has no attribute 'foo'
        # u1 = User(id=5)
        # meth(u1)

    def test_true_false(self):
        User = self.classes.User

        eval_eq(
            User.name == False,  # noqa
            testcases=[
                (User(name="foo"), False),
                (User(name=True), False),
                (User(name=False), True),
                (None, None),
            ],
        )

        eval_eq(
            User.name == True,  # noqa
            testcases=[
                (User(name="foo"), False),
                (User(name=True), True),
                (User(name=False), False),
                (None, None),
            ],
        )

    def test_boolean_ops(self):
        User = self.classes.User

        eval_eq(
            and_(User.name == "foo", User.id == 1),
            testcases=[
                (User(id=1, name="foo"), True),
                (User(id=2, name="foo"), False),
                (User(id=1, name="bar"), False),
                (User(id=2, name="bar"), False),
                (User(id=1, name=None), None),
                (None, None),
            ],
        )

        eval_eq(
            or_(User.name == "foo", User.id == 1),
            testcases=[
                (User(id=1, name="foo"), True),
                (User(id=2, name="foo"), True),
                (User(id=1, name="bar"), True),
                (User(id=2, name="bar"), False),
                (User(id=1, name=None), True),
                (User(id=2, name=None), None),
                (None, None),
            ],
        )

        eval_eq(
            not_(User.id == 1),
            testcases=[
                (User(id=1), False),
                (User(id=2), True),
                (User(id=None), None),
            ],
        )

    def test_in(self):
        User = self.classes.User

        eval_eq(
            User.name.in_(["foo", "bar"]),
            testcases=[
                (User(id=1, name="foo"), True),
                (User(id=2, name="bat"), False),
                (User(id=1, name="bar"), True),
                (User(id=1, name=None), None),
                (None, None),
            ],
        )

        eval_eq(
            User.name.not_in(["foo", "bar"]),
            testcases=[
                (User(id=1, name="foo"), False),
                (User(id=2, name="bat"), True),
                (User(id=1, name="bar"), False),
                (User(id=1, name=None), None),
                (None, None),
            ],
        )

    def test_in_tuples(self):
        User = self.classes.User

        eval_eq(
            tuple_(User.id, User.name).in_([(1, "foo"), (2, "bar")]),
            testcases=[
                (User(id=1, name="foo"), True),
                (User(id=2, name="bat"), False),
                (User(id=1, name="bar"), False),
                (User(id=2, name="bar"), True),
                (User(id=1, name=None), None),
                (None, None),
            ],
        )

        eval_eq(
            tuple_(User.id, User.name).not_in([(1, "foo"), (2, "bar")]),
            testcases=[
                (User(id=1, name="foo"), False),
                (User(id=2, name="bat"), True),
                (User(id=1, name="bar"), True),
                (User(id=2, name="bar"), False),
                (User(id=1, name=None), None),
                (None, None),
            ],
        )

    def test_null_propagation(self):
        User = self.classes.User

        eval_eq(
            (User.name == "foo") == (User.id == 1),
            testcases=[
                (User(id=1, name="foo"), True),
                (User(id=2, name="foo"), False),
                (User(id=1, name="bar"), False),
                (User(id=2, name="bar"), True),
                (User(id=None, name="foo"), None),
                (User(id=None, name=None), None),
                (None, None),
            ],
        )


class M2OEvaluateTest(fixtures.DeclarativeMappedTest):
    @classmethod
    def setup_classes(cls):
        Base = cls.DeclarativeBasic

        class Parent(Base):
            __tablename__ = "parent"
            id = Column(Integer, primary_key=True)

        class Child(Base):
            __tablename__ = "child"
            _id_parent = Column(
                "id_parent", Integer, ForeignKey(Parent.id), primary_key=True
            )
            name = Column(String(50), primary_key=True)
            parent = relationship(Parent)

    def test_delete_not_expired(self):
        Parent, Child = self.classes("Parent", "Child")

        session = fixture_session(expire_on_commit=False)

        p = Parent(id=1)
        session.add(p)
        session.commit()

        c = Child(name="foo", parent=p)
        session.add(c)
        session.commit()

        session.query(Child).filter(Child.parent == p).delete("evaluate")

        is_(inspect(c).deleted, True)

    def test_delete_expired(self):
        Parent, Child = self.classes("Parent", "Child")

        session = fixture_session()

        p = Parent(id=1)
        session.add(p)
        session.commit()

        c = Child(name="foo", parent=p)
        session.add(c)
        session.commit()

        session.query(Child).filter(Child.parent == p).delete("evaluate")

        # because it's expired
        is_(inspect(c).deleted, False)

        # but it's gone
        assert_raises(orm_exc.ObjectDeletedError, lambda: c.name)