import sqlalchemy as sa
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy.orm import query
from sqlalchemy.orm import relationship
from sqlalchemy.orm import scoped_session
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing import mock
from sqlalchemy.testing.mock import Mock
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table


class ScopedSessionTest(fixtures.MappedTest):
    """Tests for ``scoped_session`` built on the ``MappedTest`` fixture."""

    @classmethod
    def define_tables(cls, metadata):
        """Define ``table1`` and ``table2`` (FK ``someid`` -> ``table1.id``)."""
        Table(
            "table1",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(30)),
        )
        Table(
            "table2",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("someid", None, ForeignKey("table1.id")),
        )

    def test_basic(self):
        """Persist objects via a scoped session and read them back.

        Covers ``Session.query()``, the ``query_property()`` class
        attribute, and ``query_property(query_cls=...)``.
        """
        table2, table1 = self.tables.table2, self.tables.table1

        Session = scoped_session(sa.orm.sessionmaker(testing.db))

        class CustomQuery(query.Query):
            pass

        class SomeObject(fixtures.ComparableEntity):
            query = Session.query_property()

        class SomeOtherObject(fixtures.ComparableEntity):
            query = Session.query_property()
            custom_query = Session.query_property(query_cls=CustomQuery)

        self.mapper_registry.map_imperatively(
            SomeObject,
            table1,
            properties={"options": relationship(SomeOtherObject)},
        )
        self.mapper_registry.map_imperatively(SomeOtherObject, table2)

        s = SomeObject(id=1, data="hello")
        sso = SomeOtherObject()
        s.options.append(sso)
        Session.add(s)
        Session.commit()
        # Reload ``sso`` so that ``sso.someid`` can be read after remove().
        Session.refresh(sso)
        Session.remove()

        eq_(
            SomeObject(
                id=1, data="hello", options=[SomeOtherObject(someid=1)]
            ),
            Session.query(SomeObject).one(),
        )
        eq_(
            SomeObject(
                id=1, data="hello", options=[SomeOtherObject(someid=1)]
            ),
            SomeObject.query.one(),
        )
        eq_(
            SomeOtherObject(someid=1),
            SomeOtherObject.query.filter(
                SomeOtherObject.someid == sso.someid
            ).one(),
        )
        assert isinstance(SomeOtherObject.query, query.Query)
        assert not isinstance(SomeOtherObject.query, CustomQuery)
        assert isinstance(SomeOtherObject.custom_query, query.Query)
        # The base-class check above would pass even if ``query_cls`` were
        # ignored; verify the custom class is actually the one produced.
        assert isinstance(SomeOtherObject.custom_query, CustomQuery)

    def test_config_errors(self):
        """Reconfiguring once a session exists in scope is rejected.

        Calling the registry with keyword arguments is expected to raise
        ``InvalidRequestError``; ``configure()`` is expected to surface
        ``SAWarning`` as a raised exception.
        """
        Session = scoped_session(sa.orm.sessionmaker())

        # Instantiate a session so that one is present in the current scope.
        s = Session()  # noqa
        assert_raises_message(
            sa.exc.InvalidRequestError,
            "Scoped session is already present",
            Session,
            bind=testing.db,
        )

        assert_raises_message(
            sa.exc.SAWarning,
            "At least one scoped session is already present. ",
            Session.configure,
            bind=testing.db,
        )

    def test_call_with_kwargs(self):
        """Keyword arguments are accepted only when creating a new session.

        The mocked scope function's return value is switched from 0 to 1
        between calls; the second value yields a session created with the
        supplied ``autoflush=False``.
        """
        mock_scope_func = Mock()
        SessionMaker = sa.orm.sessionmaker()
        Session = scoped_session(sa.orm.sessionmaker(), mock_scope_func)

        s0 = SessionMaker()
        is_(s0.autoflush, True)

        mock_scope_func.return_value = 0
        s1 = Session()
        is_(s1.autoflush, True)

        # A session already exists for scope value 0, so kwargs are refused.
        assert_raises_message(
            sa.exc.InvalidRequestError,
            "Scoped session is already present",
            Session,
            autoflush=False,
        )

        mock_scope_func.return_value = 1
        s2 = Session(autoflush=False)
        is_(s2.autoflush, False)

    def test_methods_etc(self):
        """Proxied methods and attributes are forwarded to the session.

        Uses a ``Mock`` as the underlying session and compares the
        recorded calls for ``add``, ``delete`` and ``get``; also checks the
        ``bind`` attribute and that ``object_session`` reaches
        ``sqlalchemy.orm.session.object_session``.
        """
        mock_session = Mock()
        mock_session.bind = "the bind"

        sess = scoped_session(lambda: mock_session)

        sess.add("add")
        sess.delete("delete")

        sess.get("Cls", 5)

        eq_(sess.bind, "the bind")

        eq_(
            mock_session.mock_calls,
            [
                mock.call.add("add", _warn=True),
                mock.call.delete("delete"),
                mock.call.get(
                    "Cls",
                    5,
                    options=None,
                    populate_existing=False,
                    with_for_update=None,
                    identity_token=None,
                ),
            ],
        )

        with mock.patch(
            "sqlalchemy.orm.session.object_session"
        ) as mock_object_session:
            sess.object_session("foo")

        eq_(mock_object_session.mock_calls, [mock.call("foo")])

    @testing.combinations(
        ("style1", testing.requires.python3),
        ("style2", testing.requires.python3),
        "style3",
        "style4",
    )
    def test_get_bind_custom_session_subclass(self, style):
        """test #6285

        A ``Session`` subclass overriding ``get_bind()`` in any of four
        signature/``super()`` styles returns ``testing.db`` both directly
        and through a ``scoped_session``.
        """

        class MySession(Session):
            if style == "style1":

                def get_bind(self, mapper=None, **kwargs):
                    return super().get_bind(mapper=mapper, **kwargs)

            elif style == "style2":
                # this was the workaround for #6285, ensure it continues
                # working as well
                def get_bind(self, mapper=None, *args, **kwargs):
                    return super().get_bind(mapper, *args, **kwargs)

            elif style == "style3":
                # py2k style
                def get_bind(self, mapper=None, *args, **kwargs):
                    return super(MySession, self).get_bind(
                        mapper, *args, **kwargs
                    )

            elif style == "style4":
                # py2k style
                def get_bind(self, mapper=None, **kwargs):
                    return super(MySession, self).get_bind(
                        mapper=mapper, **kwargs
                    )

        s1 = MySession(testing.db)
        is_(s1.get_bind(), testing.db)

        ss = scoped_session(sessionmaker(testing.db, class_=MySession))

        is_(ss.get_bind(), testing.db)

    def test_attributes(self):
        """Public names on ``Session``'s MRO are present on scoped_session.

        Names listed in ``ignore_list`` are exempt from the check.
        """
        expected = [
            name
            for cls in Session.mro()
            for name in vars(cls)
            if not name.startswith("_")
        ]

        ignore_list = {
            "connection_callable",
            "transaction",
            "in_transaction",
            "in_nested_transaction",
            "get_transaction",
            "get_nested_transaction",
            "prepare",
            "invalidate",
            "bind_mapper",
            "bind_table",
            "enable_relationship_loading",
            "dispatch",
        }

        SM = scoped_session(sa.orm.sessionmaker(testing.db))

        missing = [
            name
            for name in expected
            if not hasattr(SM, name) and name not in ignore_list
        ]
        eq_(missing, [])