"""Evaluating SQL expressions on ORM objects"""

from sqlalchemy import and_
from sqlalchemy import bindparam
from sqlalchemy import ForeignKey
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import not_
from sqlalchemy import or_
from sqlalchemy import String
from sqlalchemy.orm import evaluator
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm import mapper
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import expect_warnings
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table


# Shared compiler with no target class, used by eval_eq().
compiler = evaluator.EvaluatorCompiler()


def eval_eq(clause, testcases=None):
    """Compile ``clause`` with the shared compiler and check its results.

    :param clause: a SQL expression built from ORM-mapped attributes.
    :param testcases: optional iterable of ``(obj, expected_result)`` pairs;
        each pair is checked immediately.
    :return: a ``testeval(obj=None, expected_result=None)`` callable that
        performs the same equality check for additional objects.
    """
    # Named ``evaluate`` so the imported ``evaluator`` module is not shadowed.
    evaluate = compiler.process(clause)

    def testeval(obj=None, expected_result=None):
        # Evaluate once and reuse the value in the failure message rather
        # than invoking the compiled clause a second time.
        result = evaluate(obj)
        assert result == expected_result, "%s != %r for %s with %r" % (
            result,
            expected_result,
            clause,
            obj,
        )

    if testcases:
        for an_obj, result in testcases:
            testeval(an_obj, result)
    return testeval


class EvaluateTest(fixtures.MappedTest):
    """Exercise EvaluatorCompiler against a single mapped ``users`` table."""

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "users",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("name", String(64)),
            Column("othername", String(64)),
        )

    @classmethod
    def setup_classes(cls):
        class User(cls.Basic):
            pass

    @classmethod
    def setup_mappers(cls):
        users, User = cls.tables.users, cls.classes.User

        mapper(User, users)

    def test_compare_to_value(self):
        """Comparisons against literals yield True/False, or None for NULL."""
        User = self.classes.User

        eval_eq(
            User.name == "foo",
            testcases=[
                (User(name="foo"), True),
                (User(name="bar"), False),
                (User(name=None), None),
            ],
        )

        eval_eq(
            User.id < 5,
            testcases=[
                (User(id=3), True),
                (User(id=5), False),
                (User(id=None), None),
            ],
        )

    def test_compare_to_callable_bind(self):
        """A bindparam with a ``callable_`` is resolved when evaluated."""
        User = self.classes.User

        eval_eq(
            User.name == bindparam("x", callable_=lambda: "foo"),
            testcases=[
                (User(name="foo"), True),
                (User(name="bar"), False),
                (User(name=None), None),
            ],
        )

    def test_compare_to_none(self):
        """``== None`` yields a plain boolean, never None."""
        User = self.classes.User

        eval_eq(
            User.name == None,  # noqa
            testcases=[(User(name="foo"), False), (User(name=None), True)],
        )

    def test_warn_on_unannotated_matched_column(self):
        """A plain Column whose key matches a mapped attribute warns but works.

        The compiled function must read the same-named attribute from the
        instance.
        """
        User = self.classes.User

        # Local name chosen so the module-level ``compiler`` is not shadowed.
        eval_compiler = evaluator.EvaluatorCompiler(User)

        with expect_warnings(
            r"Evaluating non-mapped column expression 'othername' "
            "onto ORM instances; this is a deprecated use case."
        ):
            meth = eval_compiler.process(
                User.name == Column("othername", String)
            )

        # Both attributes are unset on this transient object, so the
        # comparison propagates NULL.
        u1 = User(id=5)
        is_(meth(u1), None)

        # The unannotated column resolves to the ``othername`` attribute.
        is_(meth(User(id=6, name="foo", othername="foo")), True)
        is_(meth(User(id=7, name="foo", othername="bar")), False)

    def test_raise_on_unannotated_unmatched_column(self):
        """A plain Column matching no mapped attribute fails at compile time."""
        User = self.classes.User

        # Local name chosen so the module-level ``compiler`` is not shadowed.
        eval_compiler = evaluator.EvaluatorCompiler(User)

        assert_raises_message(
            evaluator.UnevaluatableError,
            "Cannot evaluate column: foo",
            eval_compiler.process,
            User.id == Column("foo", Integer),
        )

        # Before [ticket:3366] this expression compiled successfully.
        # Evaluating it against an instance then failed late with
        # "AttributeError: 'User' object has no attribute 'foo'".

    def test_true_false(self):
        """Comparisons against True/False use plain equality."""
        User = self.classes.User

        eval_eq(
            User.name == False,  # noqa
            testcases=[
                (User(name="foo"), False),
                (User(name=True), False),
                (User(name=False), True),
            ],
        )

        eval_eq(
            User.name == True,  # noqa
            testcases=[
                (User(name="foo"), False),
                (User(name=True), True),
                (User(name=False), False),
            ],
        )

    def test_boolean_ops(self):
        """and_/or_/not_ follow SQL three-valued logic."""
        User = self.classes.User

        eval_eq(
            and_(User.name == "foo", User.id == 1),
            testcases=[
                (User(id=1, name="foo"), True),
                (User(id=2, name="foo"), False),
                (User(id=1, name="bar"), False),
                (User(id=2, name="bar"), False),
                (User(id=1, name=None), None),
            ],
        )

        eval_eq(
            or_(User.name == "foo", User.id == 1),
            testcases=[
                (User(id=1, name="foo"), True),
                (User(id=2, name="foo"), True),
                (User(id=1, name="bar"), True),
                (User(id=2, name="bar"), False),
                (User(id=1, name=None), True),
                (User(id=2, name=None), None),
            ],
        )

        eval_eq(
            not_(User.id == 1),
            testcases=[
                (User(id=1), False),
                (User(id=2), True),
                (User(id=None), None),
            ],
        )

    def test_null_propagation(self):
        """A NULL operand in a nested comparison makes the whole result None."""
        User = self.classes.User

        eval_eq(
            (User.name == "foo") == (User.id == 1),
            testcases=[
                (User(id=1, name="foo"), True),
                (User(id=2, name="foo"), False),
                (User(id=1, name="bar"), False),
                (User(id=2, name="bar"), True),
                (User(id=None, name="foo"), None),
                (User(id=None, name=None), None),
            ],
        )


class M2OEvaluateTest(fixtures.DeclarativeMappedTest):
    """Check that ``delete("evaluate")`` works on a many-to-one comparison.

    The tests filter with ``Child.parent == p``.
    """

    @classmethod
    def setup_classes(cls):
        Base = cls.DeclarativeBasic

        class Parent(Base):
            __tablename__ = "parent"
            id = Column(Integer, primary_key=True)

        class Child(Base):
            __tablename__ = "child"
            _id_parent = Column(
                "id_parent", Integer, ForeignKey(Parent.id), primary_key=True
            )
            name = Column(String(50), primary_key=True)
            parent = relationship(Parent)

    def test_delete_not_expired(self):
        """A loaded (non-expired) matching object is marked deleted."""
        Parent, Child = self.classes("Parent", "Child")

        # Keep attributes loaded after commit so the evaluator can read them.
        session = Session(expire_on_commit=False)

        p = Parent(id=1)
        session.add(p)
        session.commit()

        c = Child(name="foo", parent=p)
        session.add(c)
        session.commit()

        session.query(Child).filter(Child.parent == p).delete("evaluate")

        is_(inspect(c).deleted, True)

    def test_delete_expired(self):
        """An expired matching object is not marked deleted, yet its row is gone."""
        Parent, Child = self.classes("Parent", "Child")

        session = Session()

        p = Parent(id=1)
        session.add(p)
        session.commit()

        c = Child(name="foo", parent=p)
        session.add(c)
        session.commit()

        session.query(Child).filter(Child.parent == p).delete("evaluate")

        # because it's expired
        is_(inspect(c).deleted, False)

        # but it's gone: refreshing the expired attribute finds no row
        assert_raises(orm_exc.ObjectDeletedError, lambda: c.name)