1import sqlalchemy as sa
2from sqlalchemy import ForeignKey
3from sqlalchemy import Integer
4from sqlalchemy import String
5from sqlalchemy import testing
6from sqlalchemy.orm import query
7from sqlalchemy.orm import relationship
8from sqlalchemy.orm import scoped_session
9from sqlalchemy.orm import Session
10from sqlalchemy.orm import sessionmaker
11from sqlalchemy.testing import assert_raises_message
12from sqlalchemy.testing import eq_
13from sqlalchemy.testing import fixtures
14from sqlalchemy.testing import is_
15from sqlalchemy.testing import mock
16from sqlalchemy.testing.mock import Mock
17from sqlalchemy.testing.schema import Column
18from sqlalchemy.testing.schema import Table
19
20
class ScopedSessionTest(fixtures.MappedTest):
    """Tests for the ``scoped_session`` registry and its proxy behavior."""

    @classmethod
    def define_tables(cls, metadata):
        # table2.someid references table1.id; test_basic maps this as the
        # SomeObject.options relationship.
        Table(
            "table1",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(30)),
        )
        Table(
            "table2",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("someid", None, ForeignKey("table1.id")),
        )

    def test_basic(self):
        """Round-trip objects through a scoped session and query_property()."""
        table2, table1 = self.tables.table2, self.tables.table1

        # named so as not to shadow the module-level ``Session`` class
        ScopedSession = scoped_session(sa.orm.sessionmaker(testing.db))

        class CustomQuery(query.Query):
            pass

        class SomeObject(fixtures.ComparableEntity):
            query = ScopedSession.query_property()

        class SomeOtherObject(fixtures.ComparableEntity):
            query = ScopedSession.query_property()
            custom_query = ScopedSession.query_property(query_cls=CustomQuery)

        self.mapper_registry.map_imperatively(
            SomeObject,
            table1,
            properties={"options": relationship(SomeOtherObject)},
        )
        self.mapper_registry.map_imperatively(SomeOtherObject, table2)

        s = SomeObject(id=1, data="hello")
        sso = SomeOtherObject()
        s.options.append(sso)
        ScopedSession.add(s)
        ScopedSession.commit()
        # load sso's attributes so sso.someid is still readable after the
        # session is removed below
        ScopedSession.refresh(sso)
        # discard the current session so the queries below start fresh
        ScopedSession.remove()

        eq_(
            SomeObject(
                id=1, data="hello", options=[SomeOtherObject(someid=1)]
            ),
            ScopedSession.query(SomeObject).one(),
        )
        eq_(
            SomeObject(
                id=1, data="hello", options=[SomeOtherObject(someid=1)]
            ),
            SomeObject.query.one(),
        )
        eq_(
            SomeOtherObject(someid=1),
            SomeOtherObject.query.filter(
                SomeOtherObject.someid == sso.someid
            ).one(),
        )
        assert isinstance(SomeOtherObject.query, query.Query)
        assert not isinstance(SomeOtherObject.query, CustomQuery)
        # query_cls must actually be honored; checking for a plain Query
        # would also pass if query_cls were ignored
        assert isinstance(SomeOtherObject.custom_query, CustomQuery)

        # don't leave this test's session registered after it finishes
        ScopedSession.remove()

    def test_config_errors(self):
        """Reconfiguring once a scoped session exists raises or warns."""
        ScopedSession = scoped_session(sa.orm.sessionmaker())

        s = ScopedSession()  # noqa
        # kwargs are rejected when the registry already holds a session
        assert_raises_message(
            sa.exc.InvalidRequestError,
            "Scoped session is already present",
            ScopedSession,
            bind=testing.db,
        )

        # configure() emits an SAWarning here; catching it as a raised
        # exception presumably relies on the test harness escalating
        # warnings to errors
        assert_raises_message(
            sa.exc.SAWarning,
            "At least one scoped session is already present. ",
            ScopedSession.configure,
            bind=testing.db,
        )

    def test_call_with_kwargs(self):
        """Kwargs apply only when the scope key has no session yet."""
        mock_scope_func = Mock()
        SessionMaker = sa.orm.sessionmaker()
        ScopedSession = scoped_session(sa.orm.sessionmaker(), mock_scope_func)

        s0 = SessionMaker()
        assert s0.autoflush is True

        # scope key 0: first call creates a session with default settings
        mock_scope_func.return_value = 0
        s1 = ScopedSession()
        assert s1.autoflush is True

        # same scope key: a session already exists, so kwargs are an error
        assert_raises_message(
            sa.exc.InvalidRequestError,
            "Scoped session is already present",
            ScopedSession,
            autoflush=False,
        )

        # new scope key: kwargs are applied to the newly created session
        mock_scope_func.return_value = 1
        s2 = ScopedSession(autoflush=False)
        assert s2.autoflush is False

    def test_methods_etc(self):
        """Proxied methods/attributes forward to the registry's session."""
        mock_session = Mock()
        mock_session.bind = "the bind"

        sess = scoped_session(lambda: mock_session)

        sess.add("add")
        sess.delete("delete")

        sess.get("Cls", 5)

        eq_(sess.bind, "the bind")

        # the proxies forward with these exact argument signatures
        eq_(
            mock_session.mock_calls,
            [
                mock.call.add("add", _warn=True),
                mock.call.delete("delete"),
                mock.call.get(
                    "Cls",
                    5,
                    options=None,
                    populate_existing=False,
                    with_for_update=None,
                    identity_token=None,
                ),
            ],
        )

        # object_session() delegates to the module-level function
        with mock.patch(
            "sqlalchemy.orm.session.object_session"
        ) as mock_object_session:
            sess.object_session("foo")

        eq_(mock_object_session.mock_calls, [mock.call("foo")])

    @testing.combinations(
        ("style1", testing.requires.python3),
        ("style2", testing.requires.python3),
        "style3",
        "style4",
    )
    def test_get_bind_custom_session_subclass(self, style):
        """test #6285"""

        class MySession(Session):
            if style == "style1":
                # zero-arg super(), mapper passed by keyword
                def get_bind(self, mapper=None, **kwargs):
                    return super().get_bind(mapper=mapper, **kwargs)

            elif style == "style2":
                # this was the workaround for #6285, ensure it continues
                # working as well
                def get_bind(self, mapper=None, *args, **kwargs):
                    return super().get_bind(mapper, *args, **kwargs)

            elif style == "style3":
                # py2k-style super(), mapper passed positionally
                def get_bind(self, mapper=None, *args, **kwargs):
                    return super(MySession, self).get_bind(
                        mapper, *args, **kwargs
                    )

            elif style == "style4":
                # py2k-style super(), mapper passed by keyword
                def get_bind(self, mapper=None, **kwargs):
                    return super(MySession, self).get_bind(
                        mapper=mapper, **kwargs
                    )

        s1 = MySession(testing.db)
        is_(s1.get_bind(), testing.db)

        ss = scoped_session(sessionmaker(testing.db, class_=MySession))

        is_(ss.get_bind(), testing.db)

    def test_attributes(self):
        """scoped_session exposes every public attribute of Session."""
        # public names defined on Session and its bases; a set, since a
        # name defined on several classes in the MRO is checked only once
        expected = {
            name
            for cls in Session.mro()
            for name in vars(cls)
            if not name.startswith("_")
        }

        # Session attributes scoped_session is not expected to proxy
        ignore_list = {
            "connection_callable",
            "transaction",
            "in_transaction",
            "in_nested_transaction",
            "get_transaction",
            "get_nested_transaction",
            "prepare",
            "invalidate",
            "bind_mapper",
            "bind_table",
            "enable_relationship_loading",
            "dispatch",
        }

        SM = scoped_session(sa.orm.sessionmaker(testing.db))

        # sorted so a failure reports a stable, readable list
        missing = sorted(
            name for name in expected - ignore_list if not hasattr(SM, name)
        )
        eq_(missing, [])
243