1from sqlalchemy import and_
2from sqlalchemy import Column
3from sqlalchemy import exc
4from sqlalchemy import ForeignKey
5from sqlalchemy import func
6from sqlalchemy import Integer
7from sqlalchemy import join
8from sqlalchemy import select
9from sqlalchemy import testing
10from sqlalchemy.orm import aliased
11from sqlalchemy.orm import joinedload
12from sqlalchemy.orm import noload
13from sqlalchemy.orm import relationship
14from sqlalchemy.orm import selectinload
15from sqlalchemy.orm import Session
16from sqlalchemy.testing import eq_
17from sqlalchemy.testing import fixtures
18from sqlalchemy.testing.assertions import expect_raises_message
19from sqlalchemy.testing.assertsql import CompiledSQL
20from sqlalchemy.testing.fixtures import ComparableEntity
21from sqlalchemy.testing.fixtures import fixture_session
22
23
class PartitionByFixture(fixtures.DeclarativeMappedTest):
    """Fixture mapping ``A.partitioned_bs`` to a row-limited aliased ``B``.

    ``B`` rows are numbered within each ``A`` by
    ``row_number() OVER (PARTITION BY b.a_id ORDER BY b.id)`` in a
    subquery; ``A.partitioned_bs`` relates ``A`` to an aliased class of
    ``B`` over that subquery and keeps only rows numbered below 10.

    ``insert_data`` creates three ``A`` rows, twenty ``B`` rows per ``A``
    and two ``C`` rows per ``B``.
    """

    @classmethod
    def setup_classes(cls):
        Base = cls.DeclarativeBasic

        class A(Base):
            __tablename__ = "a"

            id = Column(Integer, primary_key=True)

        class B(Base):
            __tablename__ = "b"
            id = Column(Integer, primary_key=True)
            a_id = Column(ForeignKey("a.id"))
            cs = relationship("C")

        class C(Base):
            __tablename__ = "c"
            id = Column(Integer, primary_key=True)
            b_id = Column(ForeignKey("b.id"))

        # number the B rows within each A, ordered by B's primary key; the
        # number is exposed on the subquery as partition.c.index
        partition = select(
            B,
            func.row_number()
            .over(order_by=B.id, partition_by=B.a_id)
            .label("index"),
        ).alias()

        # map B against the numbered subquery; kept on the class so tests
        # can build loader options against this same aliased entity
        cls.partitioned_b = partitioned_b = aliased(B, alias=partition)

        # row_number() starts at 1, so "index < 10" keeps nine B rows per A
        A.partitioned_bs = relationship(
            partitioned_b,
            primaryjoin=and_(
                partitioned_b.a_id == A.id, partition.c.index < 10
            ),
        )

    @classmethod
    def insert_data(cls, connection):
        A, B, C = cls.classes("A", "B", "C")

        # the context manager closes the session once the data is committed
        with Session(connection) as s:
            s.add_all([A(id=i) for i in range(1, 4)])
            # INSERT the A rows before the B rows that refer to them by a_id
            s.flush()
            s.add_all(
                [
                    B(a_id=i, cs=[C(), C()])
                    for i in range(1, 4)
                    for _ in range(1, 21)
                ]
            )
            s.commit()
76
77
class AliasedClassRelationshipTest(
    PartitionByFixture, testing.AssertsCompiledSQL
):
    """Load and join along ``A.partitioned_bs``, whose target is an aliased
    class of ``B`` selected from a ``row_number()`` window-function subquery.
    """

    # TODO: maybe make this more backend agnostic
    __requires__ = ("window_functions",)
    __dialect__ = "default"

    def test_lazyload(self):
        """Lazy loading returns at most nine ``B`` rows per ``A``."""
        A = self.classes.A

        s = fixture_session()

        def go():
            for a1 in s.query(A):  # 1 query
                eq_(len(a1.partitioned_bs), 9)  # 3 queries, one per A
                for b in a1.partitioned_bs:
                    eq_(len(b.cs), 2)  # 9 * 3 = 27 queries

        self.assert_sql_count(testing.db, go, 31)

    def test_join_one(self):
        """Joining along the relationship renders the numbered subquery and
        both primaryjoin criteria in the ON clause."""
        A = self.classes.A

        s = fixture_session()

        q = s.query(A).join(A.partitioned_bs)
        self.assert_compile(
            q,
            "SELECT a.id AS a_id FROM a JOIN "
            "(SELECT b.id AS id, b.a_id AS a_id, row_number() "
            "OVER (PARTITION BY b.a_id ORDER BY b.id) "
            "AS index FROM b) AS anon_1 "
            "ON anon_1.a_id = a.id AND anon_1.index < :index_1",
        )

    def test_join_two(self):
        """Selecting the relationship's target entity alongside ``A`` takes
        its columns from the same ``anon_1`` subquery that is joined."""
        A = self.classes.A

        s = fixture_session()

        q = s.query(A, A.partitioned_bs.entity).join(A.partitioned_bs)
        self.assert_compile(
            q,
            "SELECT a.id AS a_id, anon_1.id AS anon_1_id, "
            "anon_1.a_id AS anon_1_a_id "
            "FROM a JOIN "
            "(SELECT b.id AS id, b.a_id AS a_id, row_number() "
            "OVER (PARTITION BY b.a_id ORDER BY b.id) "
            "AS index FROM b) AS anon_1 "
            "ON anon_1.a_id = a.id AND anon_1.index < :index_1",
        )

    def test_selectinload_w_noload_after(self):
        """With ``noload("*")`` only ``partitioned_bs`` is loaded (two
        queries total) and every ``B.cs`` collection stays empty."""
        A = self.classes.A

        s = fixture_session()

        def go():
            for a1 in s.query(A).options(
                noload("*"), selectinload(A.partitioned_bs)
            ):
                for b in a1.partitioned_bs:
                    eq_(b.cs, [])

        self.assert_sql_count(testing.db, go, 2)

    @testing.combinations("ac_attribute", "ac_attr_w_of_type")
    def test_selectinload_w_joinedload_after(self, calling_style):
        """selectinload of ``A.partitioned_bs`` chained with a joinedload of
        ``partitioned_b.cs`` loads everything in two queries.

        test has been enhanced to also test #7224
        """

        A = self.classes.A

        s = fixture_session()

        partitioned_b = self.partitioned_b

        if calling_style == "ac_attribute":
            opt = selectinload(A.partitioned_bs).joinedload(partitioned_b.cs)
        elif calling_style == "ac_attr_w_of_type":
            # this would have been a workaround for people who encountered
            # #7224. The exception that was raised for "ac_attribute"
            # actually suggested using of_type(), so we can assume this
            # pattern is probably in use
            opt = selectinload(
                A.partitioned_bs.of_type(partitioned_b)
            ).joinedload(partitioned_b.cs)
        else:
            assert False

        def go():
            for a1 in s.query(A).options(opt):
                for b in a1.partitioned_bs:
                    eq_(len(b.cs), 2)

        self.assert_sql_count(testing.db, go, 2)

    @testing.combinations(True, False)
    def test_selectinload_w_joinedload_after_base_target_fails(
        self, use_of_type
    ):
        """Chaining ``joinedload(B.cs)`` (the plain ``B`` attribute rather
        than ``partitioned_b.cs``) after the aliased target raises
        ``ArgumentError`` when the query is compiled, with or without
        ``of_type()``."""
        A, B = self.classes("A", "B")

        s = fixture_session()
        partitioned_b = self.partitioned_b

        if use_of_type:
            opt = selectinload(
                A.partitioned_bs.of_type(partitioned_b)
            ).joinedload(B.cs)
        else:
            opt = selectinload(A.partitioned_bs).joinedload(B.cs)

        q = s.query(A).options(opt)

        with expect_raises_message(
            exc.ArgumentError,
            r'Attribute "B.cs" does not link from element "aliased\(B\)"',
        ):
            q._compile_context()
197
198
class AltSelectableTest(
    fixtures.DeclarativeMappedTest, testing.AssertsCompiledSQL
):
    """Many-to-one ``A.b`` whose target is an aliased class of ``B``
    selecting from the join ``b JOIN d JOIN c``."""

    __dialect__ = "default"

    @classmethod
    def setup_classes(cls):
        Base = cls.DeclarativeBasic

        class A(ComparableEntity, Base):
            __tablename__ = "a"

            id = Column(Integer, primary_key=True)
            b_id = Column(ForeignKey("b.id"))

        class B(ComparableEntity, Base):
            __tablename__ = "b"

            id = Column(Integer, primary_key=True)

        class C(ComparableEntity, Base):
            __tablename__ = "c"

            id = Column(Integer, primary_key=True)
            a_id = Column(ForeignKey("a.id"))

        class D(ComparableEntity, Base):
            __tablename__ = "d"

            id = Column(Integer, primary_key=True)
            c_id = Column(ForeignKey("c.id"))
            b_id = Column(ForeignKey("b.id"))

        # 1. set up the join() as a variable, so we can refer
        # to it in the mapping multiple times.
        j = join(B, D, D.b_id == B.id).join(C, C.id == D.c_id)

        # 2. Create an AliasedClass of B that selects from that join.
        B_viacd = aliased(B, j, flat=True)

        # 3. Relate A to it.  j.c.b_id is the join's column for b.id
        # (join columns are keyed as <tablename>_<columnname>).
        A.b = relationship(B_viacd, primaryjoin=A.b_id == j.c.b_id)

    @classmethod
    def insert_data(cls, connection):
        A, B, C, D = cls.classes("A", "B", "C", "D")

        # the context manager closes the session once the data is committed
        with Session(connection) as sess:
            # flush after each add so the rows are INSERTed in list order,
            # each one after the rows its foreign key columns point to
            for obj in [
                B(id=1),
                A(id=1, b_id=1),
                C(id=1, a_id=1),
                D(id=1, c_id=1, b_id=1),
            ]:
                sess.add(obj)
                sess.flush()
            sess.commit()

    def test_lazyload(self):
        """Lazy loading ``A.b`` selects ``B`` through the b/d/c join."""
        A, B = self.classes("A", "B")

        sess = fixture_session()
        a1 = sess.query(A).first()

        with self.sql_execution_asserter() as asserter:
            # note this is many-to-one.  use_get is unconditionally turned
            # off for relationship to aliased class for now.
            eq_(a1.b, B(id=1))

        asserter.assert_(
            CompiledSQL(
                "SELECT b.id AS b_id FROM b JOIN d ON d.b_id = b.id "
                "JOIN c ON c.id = d.c_id WHERE :param_1 = b.id",
                [{"param_1": 1}],
            )
        )

    def test_joinedload(self):
        """Joined eager loading of ``A.b`` outer-joins to an aliased copy of
        the whole b/d/c join."""
        A, B = self.classes("A", "B")

        sess = fixture_session()

        with self.sql_execution_asserter() as asserter:
            # the many-to-one is loaded in the same statement; the LEFT
            # OUTER JOIN must render the complete b/d/c join
            a1 = sess.query(A).options(joinedload(A.b)).first()
            eq_(a1.b, B(id=1))

        asserter.assert_(
            CompiledSQL(
                "SELECT a.id AS a_id, a.b_id AS a_b_id, b_1.id AS b_1_id "
                "FROM a LEFT OUTER JOIN (b AS b_1 "
                "JOIN d AS d_1 ON d_1.b_id = b_1.id "
                "JOIN c AS c_1 ON c_1.id = d_1.c_id) ON a.b_id = b_1.id "
                "LIMIT :param_1",
                [{"param_1": 1}],
            )
        )

    def test_selectinload(self):
        """selectin loading of ``A.b`` emits a second SELECT that joins from
        ``A`` to the b/d/c join."""
        A, B = self.classes("A", "B")

        sess = fixture_session()

        with self.sql_execution_asserter() as asserter:
            # the many-to-one is loaded by a second SELECT keyed on the
            # parent A primary keys, joining through the b/d/c join
            a1 = sess.query(A).options(selectinload(A.b)).first()
            eq_(a1.b, B(id=1))

        asserter.assert_(
            CompiledSQL(
                "SELECT a.id AS a_id, a.b_id AS a_b_id "
                "FROM a LIMIT :param_1",
                [{"param_1": 1}],
            ),
            CompiledSQL(
                "SELECT a_1.id AS a_1_id, b.id AS b_id FROM a AS a_1 "
                "JOIN (b JOIN d ON d.b_id = b.id JOIN c ON c.id = d.c_id) "
                "ON a_1.b_id = b.id WHERE a_1.id "
                "IN (__[POSTCOMPILE_primary_keys])",
                [{"primary_keys": [1]}],
            ),
        )

    def test_join(self):
        """Joining along ``A.b`` renders the b/d/c join in the ON target."""
        A = self.classes.A

        sess = fixture_session()

        self.assert_compile(
            sess.query(A).join(A.b),
            "SELECT a.id AS a_id, a.b_id AS a_b_id "
            "FROM a JOIN (b JOIN d ON d.b_id = b.id "
            "JOIN c ON c.id = d.c_id) ON a.b_id = b.id",
        )
334