1import re
2
3from sqlalchemy import and_
4from sqlalchemy import bindparam
5from sqlalchemy import case
6from sqlalchemy import Column
7from sqlalchemy import exc
8from sqlalchemy import extract
9from sqlalchemy import ForeignKey
10from sqlalchemy import func
11from sqlalchemy import Integer
12from sqlalchemy import join
13from sqlalchemy import literal
14from sqlalchemy import literal_column
15from sqlalchemy import MetaData
16from sqlalchemy import null
17from sqlalchemy import select
18from sqlalchemy import String
19from sqlalchemy import Table
20from sqlalchemy import testing
21from sqlalchemy import text
22from sqlalchemy import true
23from sqlalchemy import tuple_
24from sqlalchemy import union
25from sqlalchemy.sql import ClauseElement
26from sqlalchemy.sql import column
27from sqlalchemy.sql import LABEL_STYLE_TABLENAME_PLUS_COL
28from sqlalchemy.sql import operators
29from sqlalchemy.sql import table
30from sqlalchemy.sql import util as sql_util
31from sqlalchemy.sql import visitors
32from sqlalchemy.sql.elements import _clone
33from sqlalchemy.sql.expression import _from_objects
34from sqlalchemy.sql.visitors import ClauseVisitor
35from sqlalchemy.sql.visitors import cloned_traverse
36from sqlalchemy.sql.visitors import CloningVisitor
37from sqlalchemy.sql.visitors import ReplacingCloningVisitor
38from sqlalchemy.testing import assert_raises
39from sqlalchemy.testing import assert_raises_message
40from sqlalchemy.testing import AssertsCompiledSQL
41from sqlalchemy.testing import AssertsExecutionResults
42from sqlalchemy.testing import eq_
43from sqlalchemy.testing import fixtures
44from sqlalchemy.testing import is_
45from sqlalchemy.testing import is_not
46from sqlalchemy.testing.schema import eq_clause_element
47from sqlalchemy.util import pickle
48
# Module-level placeholders.  The test classes' setup_test_class() hooks
# rebind some of these through ``global`` (e.g. A, B, t1, t2, t3).
A = None
B = None
t1 = None
t2 = None
t3 = None
table1 = None
table2 = None
table3 = None
table4 = None
50
51
class TraversalTest(
    fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL
):

    """test ClauseVisitor's traversal, particularly its
    ability to copy and modify a ClauseElement in place."""

    @classmethod
    def setup_test_class(cls):
        """Define the fictitious ``A`` and ``B`` ClauseElement classes and
        publish them as the module-level globals ``A`` and ``B``."""
        global A, B

        # establish two fictitious ClauseElements.
        # define deep equality semantics as well as deep
        # identity semantics.
        class A(ClauseElement):
            """Leaf element wrapping a single ``expr`` value."""

            __visit_name__ = "a"
            _traverse_internals = []

            def __init__(self, expr):
                self.expr = expr

            def is_other(self, other):
                # identity semantics
                return other is self

            __hash__ = ClauseElement.__hash__

            def __eq__(self, other):
                return other.expr == self.expr

            def __ne__(self, other):
                return other.expr != self.expr

            def __str__(self):
                return "A(%s)" % repr(self.expr)

        class B(ClauseElement):
            """Composite element holding an ordered sequence of children."""

            __visit_name__ = "b"

            def __init__(self, *items):
                self.items = items

            def is_other(self, other):
                # identity semantics; an object's children are trivially
                # identical to its own children, so only ``is`` matters
                return other is self

            __hash__ = ClauseElement.__hash__

            def __eq__(self, other):
                # deep equality.  Compare lengths first: zip() stops at the
                # shorter sequence and would otherwise ignore extra children
                if len(self.items) != len(other.items):
                    return False
                return all(
                    i1 == i2 for i1, i2 in zip(self.items, other.items)
                )

            def __ne__(self, other):
                return not self.__eq__(other)

            def _copy_internals(self, clone=_clone, **kw):
                # replace each child with the result of ``clone``
                self.items = [clone(i, **kw) for i in self.items]

            def get_children(self, **kwargs):
                return self.items

            def __str__(self):
                return "B(%s)" % repr([str(i) for i in self.items])

    def test_test_classes(self):
        """Sanity-check the equality and identity semantics of A / B."""
        a1 = A("expr1")
        struct = B(a1, A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
        struct2 = B(a1, A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
        struct3 = B(
            a1, A("expr2"), B(A("expr1b"), A("expr2bmodified")), A("expr3")
        )

        assert a1.is_other(a1)
        assert struct.is_other(struct)
        assert struct == struct2
        assert struct != struct3
        assert not struct.is_other(struct2)
        assert not struct.is_other(struct3)

    def test_clone(self):
        """CloningVisitor returns an equal but distinct structure."""
        struct = B(
            A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")
        )

        class Vis(CloningVisitor):
            def visit_a(self, a):
                pass

            def visit_b(self, b):
                pass

        vis = Vis()
        s2 = vis.traverse(struct)
        assert struct == s2
        assert not struct.is_other(s2)

    def test_no_clone(self):
        """ClauseVisitor returns the very same structure."""
        struct = B(
            A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")
        )

        class Vis(ClauseVisitor):
            def visit_a(self, a):
                pass

            def visit_b(self, b):
                pass

        vis = Vis()
        s2 = vis.traverse(struct)
        assert struct == s2
        assert struct.is_other(s2)

    def test_clone_anon_label(self):
        """A cloned Grouping keeps the anonymous label of the original."""
        from sqlalchemy.sql.elements import Grouping

        c1 = Grouping(literal_column("q"))
        s1 = select(c1)

        class Vis(CloningVisitor):
            def visit_grouping(self, elem):
                pass

        vis = Vis()
        s2 = vis.traverse(s1)
        eq_(list(s2.selected_columns)[0]._anon_name_label, c1._anon_name_label)

    @testing.combinations(
        ("clone",),
        ("pickle",),
        ("conv_to_unique",),
        ("none",),
        argnames="meth",
    )
    @testing.combinations(
        ("name with space",),
        ("name with [brackets]",),
        ("name with~~tildes~~",),
        argnames="name",
    )
    def test_bindparam_key_proc_for_copies(self, meth, name):
        r"""test :ticket:`6249`.

        The key of the bindparam needs spaces and other characters
        escaped out for the POSTCOMPILE regex to work correctly.


        Currently, the bind key regex is::

            re.sub(r"[%\(\) \$\[\]]", "_", name)

        and the compiler postcompile regex is::

            re.sub(r"__\[POSTCOMPILE_(\S+)\]", process_expanding, self.string)

        Interestingly, brackets in the name seem to work out.

        """
        expr = column(name).in_([1, 2, 3])

        if meth == "clone":
            expr = visitors.cloned_traverse(expr, {}, {})
        elif meth == "pickle":
            expr = pickle.loads(pickle.dumps(expr))
        elif meth == "conv_to_unique":
            expr.right.unique = False
            expr.right._convert_to_unique()

        token = re.sub(r"[%\(\) \$\[\]]", "_", name)

        self.assert_compile(
            expr,
            '"%(name)s" IN (:%(token)s_1_1, '
            ":%(token)s_1_2, :%(token)s_1_3)" % {"name": name, "token": token},
            render_postcompile=True,
            dialect="default",
        )

    def test_expanding_in_bindparam_safe_to_clone(self):
        """A shallow clone of an expanding IN shares its bindparam and still
        compiles correctly."""
        expr = column("x").in_([1, 2, 3])

        expr2 = expr._clone()

        # shallow copy, bind is used twice
        is_(expr.right, expr2.right)

        stmt = and_(expr, expr2)
        self.assert_compile(
            stmt, "x IN (__[POSTCOMPILE_x_1]) AND x IN (__[POSTCOMPILE_x_1])"
        )
        self.assert_compile(
            stmt, "x IN (1, 2, 3) AND x IN (1, 2, 3)", literal_binds=True
        )

    def test_traversal_size(self):
        """Test :ticket:`6304`.

        Testing that _iterate_from_elements returns only unique FROM
        clauses; overall traversal should be short and all items unique.

        """

        t = table("t", *[column(x) for x in "pqrxyz"])

        s1 = select(t.c.p, t.c.q, t.c.r, t.c.x, t.c.y, t.c.z).subquery()

        s2 = (
            select(s1.c.p, s1.c.q, s1.c.r, s1.c.x, s1.c.y, s1.c.z)
            .select_from(s1)
            .subquery()
        )

        s3 = (
            select(s2.c.p, s2.c.q, s2.c.r, s2.c.x, s2.c.y, s2.c.z)
            .select_from(s2)
            .subquery()
        )

        tt = list(s3.element._iterate_from_elements())
        eq_(tt, [s2])

        total = list(visitors.iterate(s3))
        # before the bug was fixed, this was 750
        eq_(len(total), 25)

        seen = set()
        for elem in visitors.iterate(s3):
            assert elem not in seen
            seen.add(elem)

        eq_(len(seen), 25)

    def test_change_in_place(self):
        """Visitors modify only the cloned copy, never the original."""
        struct = B(
            A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")
        )
        struct2 = B(
            A("expr1"),
            A("expr2modified"),
            B(A("expr1b"), A("expr2b")),
            A("expr3"),
        )
        struct3 = B(
            A("expr1"),
            A("expr2"),
            B(A("expr1b"), A("expr2bmodified")),
            A("expr3"),
        )

        class Vis(CloningVisitor):
            def visit_a(self, a):
                if a.expr == "expr2":
                    a.expr = "expr2modified"

            def visit_b(self, b):
                pass

        vis = Vis()
        s2 = vis.traverse(struct)
        assert struct != s2
        assert not struct.is_other(s2)
        assert struct2 == s2

        class Vis2(CloningVisitor):
            def visit_a(self, a):
                if a.expr == "expr2b":
                    a.expr = "expr2bmodified"

            def visit_b(self, b):
                pass

        vis2 = Vis2()
        s3 = vis2.traverse(struct)
        assert struct != s3
        assert struct3 == s3

    def test_visit_name(self):
        """A Column subclass inherits the "column" visit name and is yielded
        by ClauseVisitor.iterate()."""
        # override fns in testlib/schema.py: use sqlalchemy's own Column
        from sqlalchemy import Column

        class CustomObj(Column):
            pass

        assert CustomObj.__visit_name__ == Column.__visit_name__ == "column"

        foo, bar = CustomObj("foo", String), CustomObj("bar", String)
        bin_ = foo == bar
        assert set(ClauseVisitor().iterate(bin_)) == set([foo, bar, bin_])
345
346
class BinaryEndpointTraversalTest(fixtures.TestBase):

    """test the special binary product visit"""

    def _assert_traversal(self, expr, expected):
        """Assert ``sql_util.visit_binary_product`` reports ``expected``.

        :param expr: the SQL expression to walk.
        :param expected: list of ``(operator, left, right)`` tuples in the
         order the visit callback is expected to receive them.
        """
        canary = []

        def visit(binary, left, right):
            canary.append((binary.operator, left, right))

        sql_util.visit_binary_product(visit, expr)
        eq_(canary, expected)

    def test_basic(self):
        a, b = column("a"), column("b")
        self._assert_traversal(a == b, [(operators.eq, a, b)])

    def test_with_tuples(self):
        a, b, c, d, b1, b1a, b1b, e, f = (
            column("a"),
            column("b"),
            column("c"),
            column("d"),
            column("b1"),
            column("b1a"),
            column("b1b"),
            column("e"),
            column("f"),
        )
        expr = tuple_(a, b, b1 == tuple_(b1a, b1b == d), c) > tuple_(
            func.go(e + f)
        )
        self._assert_traversal(
            expr,
            [
                (operators.gt, a, e),
                (operators.gt, a, f),
                (operators.gt, b, e),
                (operators.gt, b, f),
                (operators.eq, b1, b1a),
                (operators.eq, b1b, d),
                (operators.gt, c, e),
                (operators.gt, c, f),
            ],
        )

    def test_composed(self):
        a, b, e, f, q, j, r = (
            column("a"),
            column("b"),
            column("e"),
            column("f"),
            column("q"),
            column("j"),
            column("r"),
        )
        expr = and_((a + b) == q + func.sum(e + f), and_(j == r, f == q))
        self._assert_traversal(
            expr,
            [
                (operators.eq, a, q),
                (operators.eq, a, e),
                (operators.eq, a, f),
                (operators.eq, b, q),
                (operators.eq, b, e),
                (operators.eq, b, f),
                (operators.eq, j, r),
                (operators.eq, f, q),
            ],
        )

    def test_subquery(self):
        a, b, c = column("a"), column("b"), column("c")
        subq = select(c).where(c == a).scalar_subquery()
        expr = and_(a == b, b == subq)
        self._assert_traversal(
            expr, [(operators.eq, a, b), (operators.eq, b, subq)]
        )
426
427
428class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
429
430    """test copy-in-place behavior of various ClauseElements."""
431
432    __dialect__ = "default"
433
434    @classmethod
435    def setup_test_class(cls):
436        global t1, t2, t3
437        t1 = table("table1", column("col1"), column("col2"), column("col3"))
438        t2 = table("table2", column("col1"), column("col2"), column("col3"))
439        t3 = Table(
440            "table3",
441            MetaData(),
442            Column("col1", Integer),
443            Column("col2", Integer),
444        )
445
446    def test_binary(self):
447        clause = t1.c.col2 == t2.c.col2
448        eq_(str(clause), str(CloningVisitor().traverse(clause)))
449
    def test_binary_anon_label_quirk(self):
        """Adapting a binary expression to an alias keeps its anonymous
        label rendering as ``anon_1``."""
        t = table("t1", column("col1"))

        f = t.c.col1 * 5
        self.assert_compile(
            select(f), "SELECT t1.col1 * :col1_1 AS anon_1 FROM t1"
        )

        # the result is deliberately unused: the anon label is accessed
        # before adapting, presumably so that it is already generated on
        # ``f`` when ClauseAdapter runs -- keep this statement in place
        f._anon_name_label

        a = t.alias()
        f = sql_util.ClauseAdapter(a).traverse(f)

        self.assert_compile(
            select(f), "SELECT t1_1.col1 * :col1_1 AS anon_1 FROM t1 AS t1_1"
        )
466
467    @testing.combinations(
468        (lambda t1: t1.c.col1, "t1_1.col1"),
469        (lambda t1: t1.c.col1 == "foo", "t1_1.col1 = :col1_1"),
470        (
471            lambda t1: case((t1.c.col1 == "foo", "bar"), else_=t1.c.col1),
472            "CASE WHEN (t1_1.col1 = :col1_1) THEN :param_1 ELSE t1_1.col1 END",
473        ),
474        argnames="case, expected",
475    )
476    @testing.combinations(False, True, argnames="label_")
477    @testing.combinations(False, True, argnames="annotate")
478    def test_annotated_label_cases(self, case, expected, label_, annotate):
479        """test #6550"""
480
481        t1 = table("t1", column("col1"))
482        a1 = t1.alias()
483
484        expr = case(t1=t1)
485
486        if label_:
487            expr = expr.label(None)
488        if annotate:
489            expr = expr._annotate({"foo": "bar"})
490
491        adapted = sql_util.ClauseAdapter(a1).traverse(expr)
492
493        self.assert_compile(adapted, expected)
494
495    @testing.combinations((null(),), (true(),))
496    def test_dont_adapt_singleton_elements(self, elem):
497        """test :ticket:`6259`"""
498        t1 = table("t1", column("c1"))
499
500        stmt = select(t1.c.c1, elem)
501
502        wherecond = t1.c.c1.is_(elem)
503
504        subq = stmt.subquery()
505
506        adapted_wherecond = sql_util.ClauseAdapter(subq).traverse(wherecond)
507        stmt = select(subq).where(adapted_wherecond)
508
509        self.assert_compile(
510            stmt,
511            "SELECT anon_1.c1, anon_1.anon_2 FROM (SELECT t1.c1 AS c1, "
512            "%s AS anon_2 FROM t1) AS anon_1 WHERE anon_1.c1 IS %s"
513            % (str(elem), str(elem)),
514            dialect="default_enhanced",
515        )
516
517    def test_adapt_funcs_etc_on_identity_one(self):
518        """Adapting to a function etc. will adapt if its on identity"""
519        t1 = table("t1", column("c1"))
520
521        elem = func.foobar()
522
523        stmt = select(t1.c.c1, elem)
524
525        wherecond = t1.c.c1 == elem
526
527        subq = stmt.subquery()
528
529        adapted_wherecond = sql_util.ClauseAdapter(subq).traverse(wherecond)
530        stmt = select(subq).where(adapted_wherecond)
531
532        self.assert_compile(
533            stmt,
534            "SELECT anon_1.c1, anon_1.foobar_1 FROM (SELECT t1.c1 AS c1, "
535            "foobar() AS foobar_1 FROM t1) AS anon_1 "
536            "WHERE anon_1.c1 = anon_1.foobar_1",
537            dialect="default_enhanced",
538        )
539
540    def test_adapt_funcs_etc_on_identity_two(self):
541        """Adapting to a function etc. will not adapt if they are different"""
542        t1 = table("t1", column("c1"))
543
544        elem = func.foobar()
545        elem2 = func.foobar()
546
547        stmt = select(t1.c.c1, elem)
548
549        wherecond = t1.c.c1 == elem2
550
551        subq = stmt.subquery()
552
553        adapted_wherecond = sql_util.ClauseAdapter(subq).traverse(wherecond)
554        stmt = select(subq).where(adapted_wherecond)
555
556        self.assert_compile(
557            stmt,
558            "SELECT anon_1.c1, anon_1.foobar_1 FROM (SELECT t1.c1 AS c1, "
559            "foobar() AS foobar_1 FROM t1) AS anon_1 "
560            "WHERE anon_1.c1 = foobar()",
561            dialect="default_enhanced",
562        )
563
564    def test_join(self):
565        clause = t1.join(t2, t1.c.col2 == t2.c.col2)
566        c1 = str(clause)
567        assert str(clause) == str(CloningVisitor().traverse(clause))
568
569        class Vis(CloningVisitor):
570            def visit_binary(self, binary):
571                binary.right = t2.c.col3
572
573        clause2 = Vis().traverse(clause)
574        assert c1 == str(clause)
575        assert str(clause2) == str(t1.join(t2, t1.c.col2 == t2.c.col3))
576
577    def test_aliased_column_adapt(self):
578        t1.select()
579
580        aliased = t1.select().alias()
581        aliased2 = t1.alias()
582
583        adapter = sql_util.ColumnAdapter(aliased)
584
585        f = select(*[adapter.columns[c] for c in aliased2.c]).select_from(
586            aliased
587        )
588
589        s = select(aliased2).select_from(aliased)
590        eq_(str(s), str(f))
591
592        f = select(adapter.columns[func.count(aliased2.c.col1)]).select_from(
593            aliased
594        )
595        eq_(
596            str(select(func.count(aliased2.c.col1)).select_from(aliased)),
597            str(f),
598        )
599
    def test_aliased_cloned_column_adapt_inner(self):
        """Adapting the raw columns of a cloned SELECT to ``aliased1`` gives
        the same SQL as adapting the original SELECT's raw columns."""
        clause = select(t1.c.col1, func.foo(t1.c.col2).label("foo"))
        c_sub = clause.subquery()
        aliased1 = select(c_sub.c.col1, c_sub.c.foo).subquery()
        aliased2 = clause
        # result deliberately unused: touches the selected columns of the
        # original before it is cloned (presumably so they already exist at
        # clone time -- keep this statement)
        aliased2.selected_columns.col1, aliased2.selected_columns.foo
        aliased3 = cloned_traverse(aliased2, {}, {})

        # fixed by [ticket:2419].   the inside columns
        # on aliased3 have _is_clone_of pointers to those of
        # aliased2.  corresponding_column checks these
        # now.
        adapter = sql_util.ColumnAdapter(aliased1)
        f1 = select(*[adapter.columns[c] for c in aliased2._raw_columns])
        f2 = select(*[adapter.columns[c] for c in aliased3._raw_columns])
        eq_(str(f1), str(f2))
616
    def test_aliased_cloned_column_adapt_exported(self):
        """Adapting the exported ``.c`` columns of a cloned subquery to
        ``aliased1`` gives the same SQL as adapting the original's."""
        clause = select(t1.c.col1, func.foo(t1.c.col2).label("foo")).subquery()

        aliased1 = select(clause.c.col1, clause.c.foo).subquery()
        aliased2 = clause
        # result deliberately unused: touches the exported columns of the
        # original before it is cloned (presumably so they already exist at
        # clone time -- keep this statement)
        aliased2.c.col1, aliased2.c.foo
        aliased3 = cloned_traverse(aliased2, {}, {})

        # also fixed by [ticket:2419].  When we look at the
        # *outside* columns of aliased3, they previously did not
        # have an _is_clone_of pointer.   But we now modified _make_proxy
        # to assign this.
        adapter = sql_util.ColumnAdapter(aliased1)
        f1 = select(*[adapter.columns[c] for c in aliased2.c])
        f2 = select(*[adapter.columns[c] for c in aliased3.c])
        eq_(str(f1), str(f2))
633
    def test_aliased_cloned_schema_column_adapt_exported(self):
        """Same as test_aliased_cloned_column_adapt_exported, but built on
        the ``Table``-based fixture ``t3`` instead of a ``table()``."""
        clause = select(t3.c.col1, func.foo(t3.c.col2).label("foo")).subquery()

        aliased1 = select(clause.c.col1, clause.c.foo).subquery()
        aliased2 = clause
        # result deliberately unused: touches the exported columns of the
        # original before it is cloned (presumably so they already exist at
        # clone time -- keep this statement)
        aliased2.c.col1, aliased2.c.foo
        aliased3 = cloned_traverse(aliased2, {}, {})

        # also fixed by [ticket:2419].  When we look at the
        # *outside* columns of aliased3, they previously did not
        # have an _is_clone_of pointer.   But we now modified _make_proxy
        # to assign this.
        adapter = sql_util.ColumnAdapter(aliased1)
        f1 = select(*[adapter.columns[c] for c in aliased2.c])
        f2 = select(*[adapter.columns[c] for c in aliased3.c])
        eq_(str(f1), str(f2))
650
651    def test_labeled_expression_adapt(self):
652        lbl_x = (t3.c.col1 == 1).label("x")
653        t3_alias = t3.alias()
654
655        adapter = sql_util.ColumnAdapter(t3_alias)
656
657        lblx_adapted = adapter.traverse(lbl_x)
658        is_not(lblx_adapted._element, lbl_x._element)
659
660        lblx_adapted = adapter.traverse(lbl_x)
661        self.assert_compile(
662            select(lblx_adapted.self_group()),
663            "SELECT (table3_1.col1 = :col1_1) AS x FROM table3 AS table3_1",
664        )
665
666        self.assert_compile(
667            select(lblx_adapted.is_(True)),
668            "SELECT (table3_1.col1 = :col1_1) IS 1 AS anon_1 "
669            "FROM table3 AS table3_1",
670        )
671
672    def test_cte_w_union(self):
673        t = select(func.values(1).label("n")).cte("t", recursive=True)
674        t = t.union_all(select(t.c.n + 1).where(t.c.n < 100))
675        s = select(func.sum(t.c.n))
676
677        from sqlalchemy.sql.visitors import cloned_traverse
678
679        cloned = cloned_traverse(s, {}, {})
680
681        self.assert_compile(
682            cloned,
683            "WITH RECURSIVE t(n) AS "
684            "(SELECT values(:values_1) AS n "
685            "UNION ALL SELECT t.n + :n_1 AS anon_1 "
686            "FROM t "
687            "WHERE t.n < :n_2) "
688            "SELECT sum(t.n) AS sum_1 FROM t",
689        )
690
691    def test_aliased_cte_w_union(self):
692        t = (
693            select(func.values(1).label("n"))
694            .cte("t", recursive=True)
695            .alias("foo")
696        )
697        t = t.union_all(select(t.c.n + 1).where(t.c.n < 100))
698        s = select(func.sum(t.c.n))
699
700        from sqlalchemy.sql.visitors import cloned_traverse
701
702        cloned = cloned_traverse(s, {}, {})
703
704        self.assert_compile(
705            cloned,
706            "WITH RECURSIVE foo(n) AS (SELECT values(:values_1) AS n "
707            "UNION ALL SELECT foo.n + :n_1 AS anon_1 FROM foo "
708            "WHERE foo.n < :n_2) SELECT sum(foo.n) AS sum_1 FROM foo",
709        )
710
711    def test_text(self):
712        clause = text("select * from table where foo=:bar").bindparams(
713            bindparam("bar")
714        )
715        c1 = str(clause)
716
717        class Vis(CloningVisitor):
718            def visit_textclause(self, text):
719                text.text = text.text + " SOME MODIFIER=:lala"
720                text._bindparams["lala"] = bindparam("lala")
721
722        clause2 = Vis().traverse(clause)
723        assert c1 == str(clause)
724        assert str(clause2) == c1 + " SOME MODIFIER=:lala"
725        assert list(clause._bindparams.keys()) == ["bar"]
726        assert set(clause2._bindparams.keys()) == set(["bar", "lala"])
727
728    def test_select(self):
729        s2 = select(t1)
730        s2_assert = str(s2)
731        s3_assert = str(select(t1).where(t1.c.col2 == 7))
732
733        class Vis(CloningVisitor):
734            def visit_select(self, select):
735                select.where.non_generative(select, t1.c.col2 == 7)
736
737        s3 = Vis().traverse(s2)
738        assert str(s3) == s3_assert
739        assert str(s2) == s2_assert
740        print(str(s2))
741        print(str(s3))
742
743        class Vis(ClauseVisitor):
744            def visit_select(self, select):
745                select.where.non_generative(select, t1.c.col2 == 7)
746
747        Vis().traverse(s2)
748        assert str(s2) == s3_assert
749
750        s4_assert = str(select(t1).where(and_(t1.c.col2 == 7, t1.c.col3 == 9)))
751
752        class Vis(CloningVisitor):
753            def visit_select(self, select):
754                select.where.non_generative(select, t1.c.col3 == 9)
755
756        s4 = Vis().traverse(s3)
757        print(str(s3))
758        print(str(s4))
759        assert str(s4) == s4_assert
760        assert str(s3) == s3_assert
761
762        s5_assert = str(select(t1).where(and_(t1.c.col2 == 7, t1.c.col1 == 9)))
763
764        class Vis(CloningVisitor):
765            def visit_binary(self, binary):
766                if binary.left is t1.c.col3:
767                    binary.left = t1.c.col1
768                    binary.right = bindparam("col1", unique=True)
769
770        s5 = Vis().traverse(s4)
771        print(str(s4))
772        print(str(s5))
773        assert str(s5) == s5_assert
774        assert str(s4) == s4_assert
775
776    def test_union(self):
777        u = union(t1.select(), t2.select())
778        u2 = CloningVisitor().traverse(u)
779        eq_(str(u), str(u2))
780
781        eq_(
782            [str(c) for c in u2.selected_columns],
783            [str(c) for c in u.selected_columns],
784        )
785
786        u = union(t1.select(), t2.select())
787        cols = [str(c) for c in u.selected_columns]
788        u2 = CloningVisitor().traverse(u)
789        eq_(str(u), str(u2))
790        eq_([str(c) for c in u2.selected_columns], cols)
791
792        s1 = select(t1).where(t1.c.col1 == bindparam("id_param"))
793        s2 = select(t2)
794        u = union(s1, s2)
795
796        u2 = u.params(id_param=7)
797        u3 = u.params(id_param=10)
798
799        eq_(str(u), str(u2))
800        eq_(str(u2), str(u3))
801        eq_(u2.compile().params, {"id_param": 7})
802        eq_(u3.compile().params, {"id_param": 10})
803
804    def test_params_elements_in_setup_joins(self):
805        """test #7055"""
806
807        meta = MetaData()
808
809        X = Table("x", meta, Column("a", Integer), Column("b", Integer))
810        Y = Table("y", meta, Column("a", Integer), Column("b", Integer))
811        s1 = select(X.c.a).where(X.c.b == bindparam("xb")).alias("s1")
812        jj = (
813            select(Y)
814            .join(s1, Y.c.a == s1.c.a)
815            .where(Y.c.b == bindparam("yb"))
816            .alias("s2")
817        )
818
819        params = {"xb": 42, "yb": 33}
820        sel = select(Y).select_from(jj).params(params)
821
822        eq_(
823            [
824                eq_clause_element(bindparam("yb", value=33)),
825                eq_clause_element(bindparam("xb", value=42)),
826            ],
827            sel._generate_cache_key()[1],
828        )
829
830    def test_params_subqueries_in_joins_one(self):
831        """test #7055"""
832
833        meta = MetaData()
834
835        Pe = Table(
836            "pe",
837            meta,
838            Column("c", Integer),
839            Column("p", Integer),
840            Column("pid", Integer),
841        )
842        S = Table(
843            "s",
844            meta,
845            Column("c", Integer),
846            Column("p", Integer),
847            Column("sid", Integer),
848        )
849        Ps = Table("ps", meta, Column("c", Integer), Column("p", Integer))
850        params = {"pid": 42, "sid": 33}
851
852        pe_s = select(Pe).where(Pe.c.pid == bindparam("pid")).alias("pe_s")
853        s_s = select(S).where(S.c.sid == bindparam("sid")).alias("s_s")
854        jj = join(
855            Ps,
856            join(pe_s, s_s, and_(pe_s.c.c == s_s.c.c, pe_s.c.p == s_s.c.p)),
857            and_(Ps.c.c == pe_s.c.c, Ps.c.p == Ps.c.p),
858        ).params(params)
859
860        eq_(
861            [
862                eq_clause_element(bindparam("pid", value=42)),
863                eq_clause_element(bindparam("sid", value=33)),
864            ],
865            jj._generate_cache_key()[1],
866        )
867
868    def test_params_subqueries_in_joins_two(self):
869        """test #7055"""
870
871        meta = MetaData()
872
873        Pe = Table(
874            "pe",
875            meta,
876            Column("c", Integer),
877            Column("p", Integer),
878            Column("pid", Integer),
879        )
880        S = Table(
881            "s",
882            meta,
883            Column("c", Integer),
884            Column("p", Integer),
885            Column("sid", Integer),
886        )
887        Ps = Table("ps", meta, Column("c", Integer), Column("p", Integer))
888
889        params = {"pid": 42, "sid": 33}
890
891        pe_s = select(Pe).where(Pe.c.pid == bindparam("pid")).alias("pe_s")
892        s_s = select(S).where(S.c.sid == bindparam("sid")).alias("s_s")
893        jj = (
894            join(Ps, pe_s, and_(Ps.c.c == pe_s.c.c, Ps.c.p == Ps.c.p))
895            .join(s_s, and_(Ps.c.c == s_s.c.c, Ps.c.p == s_s.c.p))
896            .params(params)
897        )
898
899        eq_(
900            [
901                eq_clause_element(bindparam("pid", value=42)),
902                eq_clause_element(bindparam("sid", value=33)),
903            ],
904            jj._generate_cache_key()[1],
905        )
906
907    def test_in(self):
908        expr = t1.c.col1.in_(["foo", "bar"])
909        expr2 = CloningVisitor().traverse(expr)
910        assert str(expr) == str(expr2)
911
912    def test_over(self):
913        expr = func.row_number().over(order_by=t1.c.col1)
914        expr2 = CloningVisitor().traverse(expr)
915        assert str(expr) == str(expr2)
916
917        assert expr in visitors.iterate(expr, {})
918
919    def test_within_group(self):
920        expr = func.row_number().within_group(t1.c.col1)
921        expr2 = CloningVisitor().traverse(expr)
922        assert str(expr) == str(expr2)
923
924        assert expr in visitors.iterate(expr, {})
925
926    def test_funcfilter(self):
927        expr = func.count(1).filter(t1.c.col1 > 1)
928        expr2 = CloningVisitor().traverse(expr)
929        assert str(expr) == str(expr2)
930
931    def test_adapt_union(self):
932        u = union(
933            t1.select().where(t1.c.col1 == 4),
934            t1.select().where(t1.c.col1 == 5),
935        ).alias()
936
937        assert sql_util.ClauseAdapter(u).traverse(t1) is u
938
939    def test_bindparams(self):
940        """test that unique bindparams change their name upon clone()
941        to prevent conflicts"""
942
943        s = select(t1).where(t1.c.col1 == bindparam(None, unique=True)).alias()
944        s2 = CloningVisitor().traverse(s).alias()
945        s3 = select(s).where(s.c.col2 == s2.c.col2)
946
947        self.assert_compile(
948            s3,
949            "SELECT anon_1.col1, anon_1.col2, anon_1.col3 FROM "
950            "(SELECT table1.col1 AS col1, table1.col2 AS col2, "
951            "table1.col3 AS col3 FROM table1 WHERE table1.col1 = :param_1) "
952            "AS anon_1, "
953            "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 "
954            "AS col3 FROM table1 WHERE table1.col1 = :param_2) AS anon_2 "
955            "WHERE anon_1.col2 = anon_2.col2",
956        )
957
958        s = select(t1).where(t1.c.col1 == 4).alias()
959        s2 = CloningVisitor().traverse(s).alias()
960        s3 = select(s).where(s.c.col2 == s2.c.col2)
961        self.assert_compile(
962            s3,
963            "SELECT anon_1.col1, anon_1.col2, anon_1.col3 FROM "
964            "(SELECT table1.col1 AS col1, table1.col2 AS col2, "
965            "table1.col3 AS col3 FROM table1 WHERE table1.col1 = :col1_1) "
966            "AS anon_1, "
967            "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 "
968            "AS col3 FROM table1 WHERE table1.col1 = :col1_2) AS anon_2 "
969            "WHERE anon_1.col2 = anon_2.col2",
970        )
971
972    def test_extract(self):
973        s = select(extract("foo", t1.c.col1).label("col1"))
974        self.assert_compile(
975            s, "SELECT EXTRACT(foo FROM table1.col1) AS col1 FROM table1"
976        )
977
978        s2 = CloningVisitor().traverse(s).alias()
979        s3 = select(s2.c.col1)
980        self.assert_compile(
981            s, "SELECT EXTRACT(foo FROM table1.col1) AS col1 FROM table1"
982        )
983        self.assert_compile(
984            s3,
985            "SELECT anon_1.col1 FROM (SELECT EXTRACT(foo FROM "
986            "table1.col1) AS col1 FROM table1) AS anon_1",
987        )
988
    @testing.emits_warning(".*replaced by another column with the same key")
    def test_alias(self):
        """Clone / adapt a SELECT that selects from a named subquery and a
        join to it, checking that the source statement is not mutated.

        After every traversal the test re-renders ``s`` and compares it to
        ``orig`` (the string captured before any traversal), then checks
        the traversal result renders the same as ``s``.
        """
        # NOTE(review): the expected warning presumably comes from
        # t1.c.col1 and subq.c.col1 sharing the key "col1" — confirm.
        subq = t2.select().alias("subq")
        s = select(t1.c.col1, subq.c.col1).select_from(
            t1, subq, t1.join(subq, t1.c.col1 == subq.c.col2)
        )
        orig = str(s)
        s2 = CloningVisitor().traverse(s)
        eq_(orig, str(s))
        eq_(str(s), str(s2))

        # clone of a clone
        s4 = CloningVisitor().traverse(s2)
        eq_(orig, str(s))
        eq_(str(s), str(s2))
        eq_(str(s), str(s4))

        # adapting against an unrelated table("foo") is expected to leave
        # the rendered SQL unchanged (asserted below)
        s3 = sql_util.ClauseAdapter(table("foo")).traverse(s)
        eq_(orig, str(s))
        eq_(str(s), str(s3))

        # adapting the already-adapted result a second time
        s4 = sql_util.ClauseAdapter(table("foo")).traverse(s3)
        eq_(orig, str(s))
        eq_(str(s), str(s3))
        eq_(str(s), str(s4))

        # an alias of the alias, re-using the same name "subq"
        subq = subq.alias("subq")
        s = select(t1.c.col1, subq.c.col1).select_from(
            t1,
            subq,
            t1.join(subq, t1.c.col1 == subq.c.col2),
        )
        s5 = CloningVisitor().traverse(s)
        eq_(str(s), str(s5))
1022
1023    def test_correlated_select(self):
1024        s = (
1025            select(literal_column("*"))
1026            .where(t1.c.col1 == t2.c.col1)
1027            .select_from(t1, t2)
1028            .correlate(t2)
1029        )
1030
1031        class Vis(CloningVisitor):
1032            def visit_select(self, select):
1033                select.where.non_generative(select, t1.c.col2 == 7)
1034
1035        self.assert_compile(
1036            select(t2).where(t2.c.col1 == Vis().traverse(s).scalar_subquery()),
1037            "SELECT table2.col1, table2.col2, table2.col3 "
1038            "FROM table2 WHERE table2.col1 = "
1039            "(SELECT * FROM table1 WHERE table1.col1 = table2.col1 "
1040            "AND table1.col2 = :col2_1)",
1041        )
1042
1043    def test_this_thing(self):
1044        s = select(t1).where(t1.c.col1 == "foo").alias()
1045        s2 = select(s.c.col1)
1046
1047        self.assert_compile(
1048            s2,
1049            "SELECT anon_1.col1 FROM (SELECT "
1050            "table1.col1 AS col1, table1.col2 AS col2, "
1051            "table1.col3 AS col3 FROM table1 WHERE "
1052            "table1.col1 = :col1_1) AS anon_1",
1053        )
1054        t1a = t1.alias()
1055        s2 = sql_util.ClauseAdapter(t1a).traverse(s2)
1056        self.assert_compile(
1057            s2,
1058            "SELECT anon_1.col1 FROM (SELECT "
1059            "table1_1.col1 AS col1, table1_1.col2 AS "
1060            "col2, table1_1.col3 AS col3 FROM table1 "
1061            "AS table1_1 WHERE table1_1.col1 = "
1062            ":col1_1) AS anon_1",
1063        )
1064
1065    def test_this_thing_using_setup_joins_one(self):
1066        s = select(t1).join_from(t1, t2, t1.c.col1 == t2.c.col2).subquery()
1067        s2 = select(s.c.col1).join_from(t3, s, t3.c.col2 == s.c.col1)
1068
1069        self.assert_compile(
1070            s2,
1071            "SELECT anon_1.col1 FROM table3 JOIN (SELECT table1.col1 AS "
1072            "col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1 "
1073            "JOIN table2 ON table1.col1 = table2.col2) AS anon_1 "
1074            "ON table3.col2 = anon_1.col1",
1075        )
1076        t1a = t1.alias()
1077        s2 = sql_util.ClauseAdapter(t1a).traverse(s2)
1078        self.assert_compile(
1079            s2,
1080            "SELECT anon_1.col1 FROM table3 JOIN (SELECT table1_1.col1 AS "
1081            "col1, table1_1.col2 AS col2, table1_1.col3 AS col3 "
1082            "FROM table1 AS table1_1 JOIN table2 ON table1_1.col1 = "
1083            "table2.col2) AS anon_1 ON table3.col2 = anon_1.col1",
1084        )
1085
1086    def test_this_thing_using_setup_joins_two(self):
1087        s = select(t1.c.col1).join(t2, t1.c.col1 == t2.c.col2).subquery()
1088        s2 = select(s.c.col1)
1089
1090        self.assert_compile(
1091            s2,
1092            "SELECT anon_1.col1 FROM (SELECT table1.col1 AS col1 "
1093            "FROM table1 JOIN table2 ON table1.col1 = table2.col2) AS anon_1",
1094        )
1095
1096        t1alias = t1.alias("t1alias")
1097        j = t1.join(t1alias, t1.c.col1 == t1alias.c.col2)
1098
1099        vis = sql_util.ClauseAdapter(j)
1100
1101        s2 = vis.traverse(s2)
1102        self.assert_compile(
1103            s2,
1104            "SELECT anon_1.col1 FROM (SELECT table1.col1 AS col1 "
1105            "FROM table1 JOIN table1 AS t1alias "
1106            "ON table1.col1 = t1alias.col2 "
1107            "JOIN table2 ON table1.col1 = table2.col2) AS anon_1",
1108        )
1109
1110    def test_this_thing_using_setup_joins_three(self):
1111
1112        j = t1.join(t2, t1.c.col1 == t2.c.col2)
1113
1114        s1 = select(j)
1115
1116        s2 = s1.join(t3, t1.c.col1 == t3.c.col1)
1117
1118        self.assert_compile(
1119            s2,
1120            "SELECT table1.col1, table1.col2, table1.col3, "
1121            "table2.col1 AS col1_1, table2.col2 AS col2_1, "
1122            "table2.col3 AS col3_1 FROM table1 "
1123            "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
1124            "ON table3.col1 = table1.col1",
1125        )
1126
1127        vis = sql_util.ClauseAdapter(j)
1128
1129        s3 = vis.traverse(s1)
1130
1131        s4 = s3.join(t3, t1.c.col1 == t3.c.col1)
1132
1133        self.assert_compile(
1134            s4,
1135            "SELECT table1.col1, table1.col2, table1.col3, "
1136            "table2.col1 AS col1_1, table2.col2 AS col2_1, "
1137            "table2.col3 AS col3_1 FROM table1 "
1138            "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
1139            "ON table3.col1 = table1.col1",
1140        )
1141
1142        s5 = vis.traverse(s3)
1143
1144        s6 = s5.join(t3, t1.c.col1 == t3.c.col1)
1145
1146        self.assert_compile(
1147            s6,
1148            "SELECT table1.col1, table1.col2, table1.col3, "
1149            "table2.col1 AS col1_1, table2.col2 AS col2_1, "
1150            "table2.col3 AS col3_1 FROM table1 "
1151            "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
1152            "ON table3.col1 = table1.col1",
1153        )
1154
1155    def test_this_thing_using_setup_joins_four(self):
1156
1157        j = t1.join(t2, t1.c.col1 == t2.c.col2)
1158
1159        s1 = select(j)
1160
1161        assert not s1._from_obj
1162
1163        s2 = s1.join(t3, t1.c.col1 == t3.c.col1)
1164
1165        self.assert_compile(
1166            s2,
1167            "SELECT table1.col1, table1.col2, table1.col3, "
1168            "table2.col1 AS col1_1, table2.col2 AS col2_1, "
1169            "table2.col3 AS col3_1 FROM table1 "
1170            "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
1171            "ON table3.col1 = table1.col1",
1172        )
1173
1174        s3 = visitors.replacement_traverse(s1, {}, lambda elem: None)
1175
1176        s4 = s3.join(t3, t1.c.col1 == t3.c.col1)
1177
1178        self.assert_compile(
1179            s4,
1180            "SELECT table1.col1, table1.col2, table1.col3, "
1181            "table2.col1 AS col1_1, table2.col2 AS col2_1, "
1182            "table2.col3 AS col3_1 FROM table1 "
1183            "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
1184            "ON table3.col1 = table1.col1",
1185        )
1186
1187        s5 = visitors.replacement_traverse(s3, {}, lambda elem: None)
1188
1189        s6 = s5.join(t3, t1.c.col1 == t3.c.col1)
1190
1191        self.assert_compile(
1192            s6,
1193            "SELECT table1.col1, table1.col2, table1.col3, "
1194            "table2.col1 AS col1_1, table2.col2 AS col2_1, "
1195            "table2.col3 AS col3_1 FROM table1 "
1196            "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
1197            "ON table3.col1 = table1.col1",
1198        )
1199
1200    def test_select_fromtwice_one(self):
1201        t1a = t1.alias()
1202
1203        s = (
1204            select(1)
1205            .where(t1.c.col1 == t1a.c.col1)
1206            .select_from(t1a)
1207            .correlate(t1a)
1208        )
1209        s = select(t1).where(t1.c.col1 == s.scalar_subquery())
1210        self.assert_compile(
1211            s,
1212            "SELECT table1.col1, table1.col2, table1.col3 FROM table1 "
1213            "WHERE table1.col1 = "
1214            "(SELECT 1 FROM table1, table1 AS table1_1 "
1215            "WHERE table1.col1 = table1_1.col1)",
1216        )
1217        s = CloningVisitor().traverse(s)
1218        self.assert_compile(
1219            s,
1220            "SELECT table1.col1, table1.col2, table1.col3 FROM table1 "
1221            "WHERE table1.col1 = "
1222            "(SELECT 1 FROM table1, table1 AS table1_1 "
1223            "WHERE table1.col1 = table1_1.col1)",
1224        )
1225
1226    def test_select_fromtwice_two(self):
1227        s = select(t1).where(t1.c.col1 == "foo").alias()
1228
1229        s2 = (
1230            select(1).where(t1.c.col1 == s.c.col1).select_from(s).correlate(t1)
1231        )
1232        s3 = select(t1).where(t1.c.col1 == s2.scalar_subquery())
1233        self.assert_compile(
1234            s3,
1235            "SELECT table1.col1, table1.col2, table1.col3 "
1236            "FROM table1 WHERE table1.col1 = "
1237            "(SELECT 1 FROM "
1238            "(SELECT table1.col1 AS col1, table1.col2 AS col2, "
1239            "table1.col3 AS col3 FROM table1 "
1240            "WHERE table1.col1 = :col1_1) "
1241            "AS anon_1 WHERE table1.col1 = anon_1.col1)",
1242        )
1243
1244        s4 = ReplacingCloningVisitor().traverse(s3)
1245        self.assert_compile(
1246            s4,
1247            "SELECT table1.col1, table1.col2, table1.col3 "
1248            "FROM table1 WHERE table1.col1 = "
1249            "(SELECT 1 FROM "
1250            "(SELECT table1.col1 AS col1, table1.col2 AS col2, "
1251            "table1.col3 AS col3 FROM table1 "
1252            "WHERE table1.col1 = :col1_1) "
1253            "AS anon_1 WHERE table1.col1 = anon_1.col1)",
1254        )
1255
1256    def test_select_setup_joins_adapt_element_one(self):
1257        s = select(t1).join(t2, t1.c.col1 == t2.c.col2)
1258
1259        t1a = t1.alias()
1260
1261        s2 = sql_util.ClauseAdapter(t1a).traverse(s)
1262
1263        self.assert_compile(
1264            s,
1265            "SELECT table1.col1, table1.col2, table1.col3 "
1266            "FROM table1 JOIN table2 ON table1.col1 = table2.col2",
1267        )
1268        self.assert_compile(
1269            s2,
1270            "SELECT table1_1.col1, table1_1.col2, table1_1.col3 "
1271            "FROM table1 AS table1_1 JOIN table2 "
1272            "ON table1_1.col1 = table2.col2",
1273        )
1274
1275    def test_select_setup_joins_adapt_element_two(self):
1276        s = select(literal_column("1")).join_from(
1277            t1, t2, t1.c.col1 == t2.c.col2
1278        )
1279
1280        t1a = t1.alias()
1281
1282        s2 = sql_util.ClauseAdapter(t1a).traverse(s)
1283
1284        self.assert_compile(
1285            s, "SELECT 1 FROM table1 JOIN table2 ON table1.col1 = table2.col2"
1286        )
1287        self.assert_compile(
1288            s2,
1289            "SELECT 1 FROM table1 AS table1_1 "
1290            "JOIN table2 ON table1_1.col1 = table2.col2",
1291        )
1292
1293    def test_select_setup_joins_adapt_element_three(self):
1294        s = select(literal_column("1")).join_from(
1295            t1, t2, t1.c.col1 == t2.c.col2
1296        )
1297
1298        t2a = t2.alias()
1299
1300        s2 = sql_util.ClauseAdapter(t2a).traverse(s)
1301
1302        self.assert_compile(
1303            s, "SELECT 1 FROM table1 JOIN table2 ON table1.col1 = table2.col2"
1304        )
1305        self.assert_compile(
1306            s2,
1307            "SELECT 1 FROM table1 "
1308            "JOIN table2 AS table2_1 ON table1.col1 = table2_1.col2",
1309        )
1310
1311    def test_select_setup_joins_straight_clone(self):
1312        s = select(t1).join(t2, t1.c.col1 == t2.c.col2)
1313
1314        s2 = CloningVisitor().traverse(s)
1315
1316        self.assert_compile(
1317            s,
1318            "SELECT table1.col1, table1.col2, table1.col3 "
1319            "FROM table1 JOIN table2 ON table1.col1 = table2.col2",
1320        )
1321        self.assert_compile(
1322            s2,
1323            "SELECT table1.col1, table1.col2, table1.col3 "
1324            "FROM table1 JOIN table2 ON table1.col1 = table2.col2",
1325        )
1326
1327
class ColumnAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
    """Tests for sql_util.ColumnAdapter: memoization of adapted
    expressions, and how two adapters combine via wrap() and chain()."""

    __dialect__ = "default"

    @classmethod
    def setup_test_class(cls):
        # rebinds the module-level t1 / t2 fixtures used by this class;
        # here t1 has four columns and t2 has three
        global t1, t2
        t1 = table(
            "table1",
            column("col1"),
            column("col2"),
            column("col3"),
            column("col4"),
        )
        t2 = table("table2", column("col1"), column("col2"), column("col3"))

    def test_traverse_memoizes_w_columns(self):
        """traverse() returns a new object for a labeled scalar SELECT, and
        a later adapter.columns lookup returns that same adapted object."""
        t1a = t1.alias()
        adapter = sql_util.ColumnAdapter(t1a, anonymize_labels=True)

        expr = select(t1a.c.col1).label("x")
        expr_adapted = adapter.traverse(expr)
        is_not(expr, expr_adapted)
        is_(adapter.columns[expr], expr_adapted)

    def test_traverse_memoizes_w_itself(self):
        """Calling traverse() twice on the same expression returns the
        identical adapted object."""
        t1a = t1.alias()
        adapter = sql_util.ColumnAdapter(t1a, anonymize_labels=True)

        expr = select(t1a.c.col1).label("x")
        expr_adapted = adapter.traverse(expr)
        is_not(expr, expr_adapted)
        is_(adapter.traverse(expr), expr_adapted)

    def test_columns_memoizes_w_itself(self):
        """Looking the same expression up twice via adapter.columns returns
        the identical adapted object."""
        t1a = t1.alias()
        adapter = sql_util.ColumnAdapter(t1a, anonymize_labels=True)

        expr = select(t1a.c.col1).label("x")
        expr_adapted = adapter.columns[expr]
        is_not(expr, expr_adapted)
        is_(adapter.columns[expr], expr_adapted)

    def test_wrapping_fallthrough(self):
        """Compare wrap() in both orders against chain() for two adapters:
        a1 targets alias t1a; a2 targets a labeled SELECT of t1a.col1 and
        t2a.col1.

        NOTE(review): statements here are order-sensitive; a5 (chain) is
        created after a3/a4 — confirm whether chain() alters a1 in place
        before reordering anything.
        """
        t1a = t1.alias(name="t1a")
        t2a = t2.alias(name="t2a")
        a1 = sql_util.ColumnAdapter(t1a)

        s1 = (
            select(t1a.c.col1, t2a.c.col1)
            .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
            .alias()
        )
        a2 = sql_util.ColumnAdapter(s1)
        a3 = a2.wrap(a1)
        a4 = a1.wrap(a2)
        a5 = a1.chain(a2)

        # t1.c.col1 -> s1.c.t1a_col1

        # adapted by a2
        is_(a3.columns[t1.c.col1], s1.c.t1a_col1)
        is_(a4.columns[t1.c.col1], s1.c.t1a_col1)

        # chaining can't fall through because a1 grabs it
        # first
        is_(a5.columns[t1.c.col1], t1a.c.col1)

        # t2.c.col1 -> s1.c.t2a_col1

        # adapted by a2
        is_(a3.columns[t2.c.col1], s1.c.t2a_col1)
        is_(a4.columns[t2.c.col1], s1.c.t2a_col1)
        # chaining, t2 hits s1
        is_(a5.columns[t2.c.col1], s1.c.t2a_col1)

        # t1.c.col2 -> t1a.c.col2

        # fallthrough to a1
        is_(a3.columns[t1.c.col2], t1a.c.col2)
        is_(a4.columns[t1.c.col2], t1a.c.col2)

        # chaining hits a1
        is_(a5.columns[t1.c.col2], t1a.c.col2)

        # t2.c.col2 -> t2.c.col2

        # fallthrough to no adaption; neither adapter covers t2.col2
        # NOTE(review): a5 (chain) is not asserted for this column
        is_(a3.columns[t2.c.col2], t2.c.col2)
        is_(a4.columns[t2.c.col2], t2.c.col2)

    def test_wrapping_ordering(self):
        """illustrate an example where order of wrappers matters.

        This test illustrates both the ordering being significant
        as well as a scenario where multiple translations are needed
        (e.g. wrapping vs. chaining).

        """

        # stmt: labeled subquery of t1.col1 / t2.col1 (table1_col1,
        # table2_col1); stmt2: subquery of all of t2 plus an alias of stmt
        stmt = (
            select(t1.c.col1, t2.c.col1)
            .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
            .subquery()
        )

        sa = stmt.alias()
        stmt2 = select(t2, sa).subquery()

        a1 = sql_util.ColumnAdapter(stmt)
        a2 = sql_util.ColumnAdapter(stmt2)

        a2_to_a1 = a2.wrap(a1)
        a1_to_a2 = a1.wrap(a2)

        # when stmt2 and stmt represent the same column
        # in different contexts, order of wrapping matters

        # t2.c.col1 via a2 is stmt2.c.col1; then ignored by a1
        is_(a2_to_a1.columns[t2.c.col1], stmt2.c.col1)
        # t2.c.col1 via a1 is stmt.c.table2_col1; a2 then
        # sends this to stmt2.c.table2_col1
        is_(a1_to_a2.columns[t2.c.col1], stmt2.c.table2_col1)

        # check that these aren't the same column
        is_not(stmt2.c.col1, stmt2.c.table2_col1)

        # for mutually exclusive columns, order doesn't matter
        is_(a2_to_a1.columns[t1.c.col1], stmt2.c.table1_col1)
        is_(a1_to_a2.columns[t1.c.col1], stmt2.c.table1_col1)
        is_(a2_to_a1.columns[t2.c.col2], stmt2.c.col2)

    def test_wrapping_multiple(self):
        """illustrate that wrapping runs both adapters"""

        t1a = t1.alias(name="t1a")
        t2a = t2.alias(name="t2a")
        a1 = sql_util.ColumnAdapter(t1a)
        a2 = sql_util.ColumnAdapter(t2a)
        a3 = a2.wrap(a1)

        stmt = select(t1.c.col1, t2.c.col2)

        # t1 is adapted by a1 and t2 by a2 in the same traversal
        self.assert_compile(
            a3.traverse(stmt),
            "SELECT t1a.col1, t2a.col2 FROM table1 AS t1a, table2 AS t2a",
        )

        # chaining does too because these adapters don't share any
        # columns
        a4 = a2.chain(a1)
        self.assert_compile(
            a4.traverse(stmt),
            "SELECT t1a.col1, t2a.col2 FROM table1 AS t1a, table2 AS t2a",
        )

    def test_wrapping_inclusions(self):
        """test wrapping and inclusion rules together,
        taking into account multiple objects with equivalent hash identity."""

        # a1 only accepts columns annotated with "a1"; a2 only those
        # annotated with "a2"
        t1a = t1.alias(name="t1a")
        t2a = t2.alias(name="t2a")
        a1 = sql_util.ColumnAdapter(
            t1a, include_fn=lambda col: "a1" in col._annotations
        )

        s1 = (
            select(t1a, t2a)
            .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
            .alias()
        )
        a2 = sql_util.ColumnAdapter(
            s1, include_fn=lambda col: "a2" in col._annotations
        )
        a3 = a2.wrap(a1)

        # annotated variants of the same underlying column, each carrying
        # a different set of inclusion markers
        c1a1 = t1.c.col1._annotate(dict(a1=True))
        c1a2 = t1.c.col1._annotate(dict(a2=True))
        c1aa = t1.c.col1._annotate(dict(a1=True, a2=True))

        c2a1 = t2.c.col1._annotate(dict(a1=True))
        c2a2 = t2.c.col1._annotate(dict(a2=True))
        c2aa = t2.c.col1._annotate(dict(a1=True, a2=True))

        is_(a3.columns[c1a1], t1a.c.col1)
        is_(a3.columns[c1a2], s1.c.t1a_col1)
        is_(a3.columns[c1aa], s1.c.t1a_col1)

        # not covered by a1, accepted by a2
        is_(a3.columns[c2aa], s1.c.t2a_col1)

        # not covered by a1, accepted by a2
        is_(a3.columns[c2a2], s1.c.t2a_col1)
        # not covered by a1, rejected by a2
        is_(a3.columns[c2a1], c2a1)
1522
1523
1524class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
1525    __dialect__ = "default"
1526
1527    @classmethod
1528    def setup_test_class(cls):
1529        global t1, t2
1530        t1 = table("table1", column("col1"), column("col2"), column("col3"))
1531        t2 = table("table2", column("col1"), column("col2"), column("col3"))
1532
    def test_correlation_on_clone(self):
        """Correlation of scalar subqueries is preserved when the subquery
        is traversed by a ClauseAdapter (targeting alias "t1alias") and
        then cloned with CloningVisitor.

        Covers both implicit correlation and explicit ``correlate()``, for
        subqueries selecting from aliases and from the base tables.
        """
        t1alias = t1.alias("t1alias")
        t2alias = t2.alias("t2alias")
        vis = sql_util.ClauseAdapter(t1alias)

        # case 1: subquery over t1alias and t2alias with no explicit
        # correlate(); the enclosing query's FROM t2alias is correlated
        # away (see the rendered SQL below)
        s = (
            select(literal_column("*"))
            .select_from(t1alias, t2alias)
            .scalar_subquery()
        )

        froms = list(s._iterate_from_elements())
        assert t2alias in froms
        assert t1alias in froms

        self.assert_compile(
            select(literal_column("*")).where(t2alias.c.col1 == s),
            "SELECT * FROM table2 AS t2alias WHERE "
            "t2alias.col1 = (SELECT * FROM table1 AS "
            "t1alias)",
        )
        s = vis.traverse(s)

        froms = list(s._iterate_from_elements())
        assert t2alias in froms  # present because it was not cloned
        assert t1alias in froms  # present because the adapter placed
        # it there and was also not cloned

        # correlate list on "s" needs to take into account the full
        # _cloned_set for each element in _froms when correlating

        self.assert_compile(
            select(literal_column("*")).where(t2alias.c.col1 == s),
            "SELECT * FROM table2 AS t2alias WHERE "
            "t2alias.col1 = (SELECT * FROM table1 AS "
            "t1alias)",
        )

        # case 2: same subquery with an explicit correlate(t2alias),
        # checked as built, after adaption, and after a further clone
        s = (
            select(literal_column("*"))
            .select_from(t1alias, t2alias)
            .correlate(t2alias)
            .scalar_subquery()
        )
        self.assert_compile(
            select(literal_column("*")).where(t2alias.c.col1 == s),
            "SELECT * FROM table2 AS t2alias WHERE "
            "t2alias.col1 = (SELECT * FROM table1 AS "
            "t1alias)",
        )
        s = vis.traverse(s)
        self.assert_compile(
            select(literal_column("*")).where(t2alias.c.col1 == s),
            "SELECT * FROM table2 AS t2alias WHERE "
            "t2alias.col1 = (SELECT * FROM table1 AS "
            "t1alias)",
        )
        s = CloningVisitor().traverse(s)
        self.assert_compile(
            select(literal_column("*")).where(t2alias.c.col1 == s),
            "SELECT * FROM table2 AS t2alias WHERE "
            "t2alias.col1 = (SELECT * FROM table1 AS "
            "t1alias)",
        )

        # case 3: subquery against the base tables with implicit
        # correlation; after adapting t1 -> t1alias, the subquery
        # correlates to t1alias in the enclosing query
        s = (
            select(literal_column("*"))
            .where(t1.c.col1 == t2.c.col1)
            .scalar_subquery()
        )
        self.assert_compile(
            select(t1.c.col1, s),
            "SELECT table1.col1, (SELECT * FROM table2 "
            "WHERE table1.col1 = table2.col1) AS "
            "anon_1 FROM table1",
        )
        # NOTE(review): re-creates an adapter equivalent to the one above
        vis = sql_util.ClauseAdapter(t1alias)
        s = vis.traverse(s)
        self.assert_compile(
            select(t1alias.c.col1, s),
            "SELECT t1alias.col1, (SELECT * FROM "
            "table2 WHERE t1alias.col1 = table2.col1) "
            "AS anon_1 FROM table1 AS t1alias",
        )
        s = CloningVisitor().traverse(s)
        self.assert_compile(
            select(t1alias.c.col1, s),
            "SELECT t1alias.col1, (SELECT * FROM "
            "table2 WHERE t1alias.col1 = table2.col1) "
            "AS anon_1 FROM table1 AS t1alias",
        )

        # case 4: as case 3 but with an explicit correlate(t1); after
        # adaption the rendered SQL correlates t1alias instead
        s = (
            select(literal_column("*"))
            .where(t1.c.col1 == t2.c.col1)
            .correlate(t1)
            .scalar_subquery()
        )
        self.assert_compile(
            select(t1.c.col1, s),
            "SELECT table1.col1, (SELECT * FROM table2 "
            "WHERE table1.col1 = table2.col1) AS "
            "anon_1 FROM table1",
        )
        # NOTE(review): re-creates an adapter equivalent to the one above
        vis = sql_util.ClauseAdapter(t1alias)
        s = vis.traverse(s)
        self.assert_compile(
            select(t1alias.c.col1, s),
            "SELECT t1alias.col1, (SELECT * FROM "
            "table2 WHERE t1alias.col1 = table2.col1) "
            "AS anon_1 FROM table1 AS t1alias",
        )
        s = CloningVisitor().traverse(s)
        self.assert_compile(
            select(t1alias.c.col1, s),
            "SELECT t1alias.col1, (SELECT * FROM "
            "table2 WHERE t1alias.col1 = table2.col1) "
            "AS anon_1 FROM table1 AS t1alias",
        )
1650
1651    def test_adapt_select_w_unlabeled_fn(self):
1652
1653        expr = func.count(t1.c.col1)
1654        stmt = select(t1, expr)
1655
1656        self.assert_compile(
1657            stmt,
1658            "SELECT table1.col1, table1.col2, table1.col3, "
1659            "count(table1.col1) AS count_1 FROM table1",
1660        )
1661
1662        stmt2 = select(stmt.subquery())
1663
1664        self.assert_compile(
1665            stmt2,
1666            "SELECT anon_1.col1, anon_1.col2, anon_1.col3, anon_1.count_1 "
1667            "FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, "
1668            "table1.col3 AS col3, count(table1.col1) AS count_1 "
1669            "FROM table1) AS anon_1",
1670        )
1671
1672        is_(
1673            stmt2.selected_columns[3],
1674            stmt2.selected_columns.corresponding_column(expr),
1675        )
1676
1677        is_(
1678            sql_util.ClauseAdapter(stmt2).replace(expr),
1679            stmt2.selected_columns[3],
1680        )
1681
1682        column_adapter = sql_util.ColumnAdapter(stmt2)
1683        is_(column_adapter.columns[expr], stmt2.selected_columns[3])
1684
1685    @testing.combinations((True,), (False,), argnames="use_adapt_from")
1686    def test_correlate_except_on_clone(self, use_adapt_from):
1687        # test [ticket:4537]'s issue
1688
1689        t1alias = t1.alias("t1alias")
1690        j = t1.join(t1alias, t1.c.col1 == t1alias.c.col2)
1691
1692        if use_adapt_from:
1693            vis = sql_util.ClauseAdapter(j, adapt_from_selectables=[t1])
1694        else:
1695            vis = sql_util.ClauseAdapter(j)
1696
1697        # "control" subquery - uses correlate which has worked w/ adaption
1698        # for a long time
1699        control_s = (
1700            select(t2.c.col1)
1701            .where(t2.c.col1 == t1.c.col1)
1702            .correlate(t2)
1703            .scalar_subquery()
1704        )
1705
1706        # test subquery - given only t1 and t2 in the enclosing selectable,
1707        # will do the same thing as the "control" query since the correlation
1708        # works out the same
1709        s = (
1710            select(t2.c.col1)
1711            .where(t2.c.col1 == t1.c.col1)
1712            .correlate_except(t1)
1713            .scalar_subquery()
1714        )
1715
1716        # use both subqueries in statements
1717        control_stmt = select(control_s, t1.c.col1, t2.c.col1).select_from(
1718            t1.join(t2, t1.c.col1 == t2.c.col1)
1719        )
1720
1721        stmt = select(s, t1.c.col1, t2.c.col1).select_from(
1722            t1.join(t2, t1.c.col1 == t2.c.col1)
1723        )
1724        # they are the same
1725        self.assert_compile(
1726            control_stmt,
1727            "SELECT "
1728            "(SELECT table2.col1 FROM table1 "
1729            "WHERE table2.col1 = table1.col1) AS anon_1, "
1730            "table1.col1, table2.col1 AS col1_1 "
1731            "FROM table1 "
1732            "JOIN table2 ON table1.col1 = table2.col1",
1733        )
1734        self.assert_compile(
1735            stmt,
1736            "SELECT "
1737            "(SELECT table2.col1 FROM table1 "
1738            "WHERE table2.col1 = table1.col1) AS anon_1, "
1739            "table1.col1, table2.col1 AS col1_1 "
1740            "FROM table1 "
1741            "JOIN table2 ON table1.col1 = table2.col1",
1742        )
1743
1744        # now test against the adaption of "t1" into "t1 JOIN t1alias".
1745        # note in the control case, we aren't actually testing that
1746        # Select is processing the "correlate" list during the adaption
1747        # since we aren't adapting the "correlate"
1748        self.assert_compile(
1749            vis.traverse(control_stmt),
1750            "SELECT "
1751            "(SELECT table2.col1 FROM "
1752            "table1 JOIN table1 AS t1alias ON table1.col1 = t1alias.col2 "
1753            "WHERE table2.col1 = table1.col1) AS anon_1, "
1754            "table1.col1, table2.col1 AS col1_1 "
1755            "FROM table1 JOIN table1 AS t1alias ON table1.col1 = t1alias.col2 "
1756            "JOIN table2 ON table1.col1 = table2.col1",
1757        )
1758
1759        # but here, correlate_except() does have the thing we're adapting
1760        # so whatever is in there has to be expanded out to include
1761        # the adaptation target, in this case "t1 JOIN t1alias".
1762        self.assert_compile(
1763            vis.traverse(stmt),
1764            "SELECT "
1765            "(SELECT table2.col1 FROM "
1766            "table1 JOIN table1 AS t1alias ON table1.col1 = t1alias.col2 "
1767            "WHERE table2.col1 = table1.col1) AS anon_1, "
1768            "table1.col1, table2.col1 AS col1_1 "
1769            "FROM table1 JOIN table1 AS t1alias ON table1.col1 = t1alias.col2 "
1770            "JOIN table2 ON table1.col1 = table2.col1",
1771        )
1772
1773    @testing.combinations((True,), (False,), argnames="use_adapt_from")
1774    def test_correlate_except_with_mixed_tables(self, use_adapt_from):
1775        # test [ticket:6060]'s issue
1776
1777        stmt = select(
1778            t1.c.col1,
1779            select(func.count(t2.c.col1))
1780            .where(t2.c.col1 == t1.c.col1)
1781            .correlate_except(t2)
1782            .scalar_subquery(),
1783        )
1784        self.assert_compile(
1785            stmt,
1786            "SELECT table1.col1, "
1787            "(SELECT count(table2.col1) AS count_1 FROM table2 "
1788            "WHERE table2.col1 = table1.col1) AS anon_1 "
1789            "FROM table1",
1790        )
1791
1792        subq = (
1793            select(t1)
1794            .join(t2, t1.c.col1 == t2.c.col1)
1795            .where(t2.c.col2 == "x")
1796            .subquery()
1797        )
1798
1799        if use_adapt_from:
1800            vis = sql_util.ClauseAdapter(subq, adapt_from_selectables=[t1])
1801        else:
1802            vis = sql_util.ClauseAdapter(subq)
1803
1804        if use_adapt_from:
1805            self.assert_compile(
1806                vis.traverse(stmt),
1807                "SELECT anon_1.col1, "
1808                "(SELECT count(table2.col1) AS count_1 FROM table2 WHERE "
1809                "table2.col1 = anon_1.col1) AS anon_2 "
1810                "FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, "
1811                "table1.col3 AS col3 FROM table1 JOIN table2 ON table1.col1 = "
1812                "table2.col1 WHERE table2.col2 = :col2_1) AS anon_1",
1813            )
1814        else:
1815            # here's the buggy version.  table2 gets yanked out of the
1816            # correlated subquery also.  AliasedClass now uses
1817            # adapt_from_selectables in all cases
1818            self.assert_compile(
1819                vis.traverse(stmt),
1820                "SELECT anon_1.col1, "
1821                "(SELECT count(table2.col1) AS count_1 FROM table2, "
1822                "(SELECT table1.col1 AS col1, table1.col2 AS col2, "
1823                "table1.col3 AS col3 FROM table1 JOIN table2 ON "
1824                "table1.col1 = table2.col1 WHERE table2.col2 = :col2_1) AS "
1825                "anon_1 WHERE table2.col1 = anon_1.col1) AS anon_2 "
1826                "FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, "
1827                "table1.col3 AS col3 FROM table1 JOIN table2 "
1828                "ON table1.col1 = table2.col1 "
1829                "WHERE table2.col2 = :col2_1) AS anon_1",
1830            )
1831
    @testing.fails_on_everything_except()
    def test_joins_dont_adapt(self):
        """Adapting to a plain Join should leave a statement unchanged.

        NOTE(review): ``fails_on_everything_except()`` is applied with no
        backends listed, which presumably marks this test as an expected
        failure everywhere; the assertion then documents desired behavior
        rather than what ClauseAdapter currently renders -- confirm.
        """
        # adapting to a join, i.e. ClauseAdapter(t1.join(t2)), doesn't
        # make much sense. ClauseAdapter doesn't make any changes if
        # it's against a straight join.

        users = table("users", column("id"))
        addresses = table("addresses", column("id"), column("user_id"))

        ualias = users.alias()

        # a subquery correlated to "users", first adapted so that it refers
        # to the anonymous alias (rendered as users_1) instead
        s = (
            select(func.count(addresses.c.id))
            .where(users.c.id == addresses.c.user_id)
            .correlate(users)
        )
        s = sql_util.ClauseAdapter(ualias).traverse(s)

        j1 = addresses.join(ualias, addresses.c.user_id == ualias.c.id)

        # adapting the already-adapted statement to the join j1 is expected
        # to leave it as it was
        self.assert_compile(
            sql_util.ClauseAdapter(j1).traverse(s),
            "SELECT count(addresses.id) AS count_1 "
            "FROM addresses WHERE users_1.id = "
            "addresses.user_id",
        )
1858
1859    def test_prev_entities_adapt(self):
1860        """test #6503"""
1861
1862        m = MetaData()
1863        users = Table("users", m, Column("id", Integer, primary_key=True))
1864        addresses = Table(
1865            "addresses",
1866            m,
1867            Column("id", Integer, primary_key=True),
1868            Column("user_id", ForeignKey("users.id")),
1869        )
1870
1871        ualias = users.alias()
1872
1873        s = select(users).join(addresses).with_only_columns(addresses.c.id)
1874        s = sql_util.ClauseAdapter(ualias).traverse(s)
1875
1876        self.assert_compile(
1877            s,
1878            "SELECT addresses.id FROM users AS users_1 "
1879            "JOIN addresses ON users_1.id = addresses.user_id",
1880        )
1881
1882    @testing.combinations((True,), (False,), argnames="use_adapt_from")
1883    def test_table_to_alias_1(self, use_adapt_from):
1884        t1alias = t1.alias("t1alias")
1885
1886        if use_adapt_from:
1887            vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1])
1888        else:
1889            vis = sql_util.ClauseAdapter(t1alias)
1890        ff = vis.traverse(func.count(t1.c.col1).label("foo"))
1891        assert list(_from_objects(ff)) == [t1alias]
1892
1893    @testing.combinations((True,), (False,), argnames="use_adapt_from")
1894    def test_table_to_alias_2(self, use_adapt_from):
1895        t1alias = t1.alias("t1alias")
1896        if use_adapt_from:
1897            vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1])
1898        else:
1899            vis = sql_util.ClauseAdapter(t1alias)
1900        self.assert_compile(
1901            vis.traverse(select(literal_column("*")).select_from(t1)),
1902            "SELECT * FROM table1 AS t1alias",
1903        )
1904
1905    @testing.combinations((True,), (False,), argnames="use_adapt_from")
1906    def test_table_to_alias_3(self, use_adapt_from):
1907        t1alias = t1.alias("t1alias")
1908        if use_adapt_from:
1909            vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1])
1910        else:
1911            vis = sql_util.ClauseAdapter(t1alias)
1912        self.assert_compile(
1913            vis.traverse(
1914                select(literal_column("*")).where(t1.c.col1 == t2.c.col2)
1915            ),
1916            "SELECT * FROM table1 AS t1alias, table2 "
1917            "WHERE t1alias.col1 = table2.col2",
1918        )
1919
1920    @testing.combinations((True,), (False,), argnames="use_adapt_from")
1921    def test_table_to_alias_4(self, use_adapt_from):
1922        t1alias = t1.alias("t1alias")
1923        if use_adapt_from:
1924            vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1])
1925        else:
1926            vis = sql_util.ClauseAdapter(t1alias)
1927        self.assert_compile(
1928            vis.traverse(
1929                select(literal_column("*"))
1930                .where(t1.c.col1 == t2.c.col2)
1931                .select_from(t1, t2)
1932            ),
1933            "SELECT * FROM table1 AS t1alias, table2 "
1934            "WHERE t1alias.col1 = table2.col2",
1935        )
1936
1937    @testing.combinations((True,), (False,), argnames="use_adapt_from")
1938    def test_table_to_alias_5(self, use_adapt_from):
1939        t1alias = t1.alias("t1alias")
1940        if use_adapt_from:
1941            vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1])
1942        else:
1943            vis = sql_util.ClauseAdapter(t1alias)
1944        self.assert_compile(
1945            select(t1alias, t2).where(
1946                t1alias.c.col1
1947                == vis.traverse(
1948                    select(literal_column("*"))
1949                    .where(t1.c.col1 == t2.c.col2)
1950                    .select_from(t1, t2)
1951                    .correlate(t1)
1952                    .scalar_subquery()
1953                )
1954            ),
1955            "SELECT t1alias.col1, t1alias.col2, t1alias.col3, "
1956            "table2.col1 AS col1_1, table2.col2 AS col2_1, "
1957            "table2.col3 AS col3_1 "
1958            "FROM table1 AS t1alias, table2 WHERE t1alias.col1 = "
1959            "(SELECT * FROM table2 WHERE t1alias.col1 = table2.col2)",
1960        )
1961
1962    @testing.combinations((True,), (False,), argnames="use_adapt_from")
1963    def test_table_to_alias_6(self, use_adapt_from):
1964        t1alias = t1.alias("t1alias")
1965        if use_adapt_from:
1966            vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1])
1967        else:
1968            vis = sql_util.ClauseAdapter(t1alias)
1969        self.assert_compile(
1970            select(t1alias, t2).where(
1971                t1alias.c.col1
1972                == vis.traverse(
1973                    select(literal_column("*"))
1974                    .where(t1.c.col1 == t2.c.col2)
1975                    .select_from(t1, t2)
1976                    .correlate(t2)
1977                    .scalar_subquery()
1978                )
1979            ),
1980            "SELECT t1alias.col1, t1alias.col2, t1alias.col3, "
1981            "table2.col1 AS col1_1, table2.col2 AS col2_1, "
1982            "table2.col3 AS col3_1 "
1983            "FROM table1 AS t1alias, table2 "
1984            "WHERE t1alias.col1 = "
1985            "(SELECT * FROM table1 AS t1alias "
1986            "WHERE t1alias.col1 = table2.col2)",
1987        )
1988
1989    def test_table_to_alias_7(self):
1990        t1alias = t1.alias("t1alias")
1991        vis = sql_util.ClauseAdapter(t1alias)
1992        self.assert_compile(
1993            vis.traverse(case((t1.c.col1 == 5, t1.c.col2), else_=t1.c.col1)),
1994            "CASE WHEN (t1alias.col1 = :col1_1) THEN "
1995            "t1alias.col2 ELSE t1alias.col1 END",
1996        )
1997
1998    def test_table_to_alias_8(self):
1999        t1alias = t1.alias("t1alias")
2000        vis = sql_util.ClauseAdapter(t1alias)
2001        self.assert_compile(
2002            vis.traverse(
2003                case((5, t1.c.col2), value=t1.c.col1, else_=t1.c.col1)
2004            ),
2005            "CASE t1alias.col1 WHEN :param_1 THEN "
2006            "t1alias.col2 ELSE t1alias.col1 END",
2007        )
2008
2009    def test_table_to_alias_9(self):
2010        s = select(literal_column("*")).select_from(t1).alias("foo")
2011        self.assert_compile(
2012            s.select(), "SELECT foo.* FROM (SELECT * FROM table1) " "AS foo"
2013        )
2014
2015    def test_table_to_alias_10(self):
2016        s = select(literal_column("*")).select_from(t1).alias("foo")
2017        t1alias = t1.alias("t1alias")
2018        vis = sql_util.ClauseAdapter(t1alias)
2019        self.assert_compile(
2020            vis.traverse(s.select()),
2021            "SELECT foo.* FROM (SELECT * FROM table1 " "AS t1alias) AS foo",
2022        )
2023
2024    def test_table_to_alias_11(self):
2025        s = select(literal_column("*")).select_from(t1).alias("foo")
2026        self.assert_compile(
2027            s.select(), "SELECT foo.* FROM (SELECT * FROM table1) " "AS foo"
2028        )
2029
2030    def test_table_to_alias_12(self):
2031        t1alias = t1.alias("t1alias")
2032        vis = sql_util.ClauseAdapter(t1alias)
2033        ff = vis.traverse(func.count(t1.c.col1).label("foo"))
2034        self.assert_compile(
2035            select(ff),
2036            "SELECT count(t1alias.col1) AS foo FROM " "table1 AS t1alias",
2037        )
2038        assert list(_from_objects(ff)) == [t1alias]
2039
    # TODO: disabled check, apparently an earlier variant of
    # test_table_to_alias_2; restore or remove:
    # self.assert_compile(
    #     vis.traverse(
    #         select(func.count(t1.c.col1).label("foo")), clone=True
    #     ),
    #     "SELECT count(t1alias.col1) AS foo FROM table1 AS t1alias",
    # )
2044
2045    def test_table_to_alias_13(self):
2046        t1alias = t1.alias("t1alias")
2047        vis = sql_util.ClauseAdapter(t1alias)
2048        t2alias = t2.alias("t2alias")
2049        vis.chain(sql_util.ClauseAdapter(t2alias))
2050        self.assert_compile(
2051            vis.traverse(
2052                select(literal_column("*")).where(t1.c.col1 == t2.c.col2)
2053            ),
2054            "SELECT * FROM table1 AS t1alias, table2 "
2055            "AS t2alias WHERE t1alias.col1 = "
2056            "t2alias.col2",
2057        )
2058
2059    def test_table_to_alias_14(self):
2060        t1alias = t1.alias("t1alias")
2061        vis = sql_util.ClauseAdapter(t1alias)
2062        t2alias = t2.alias("t2alias")
2063        vis.chain(sql_util.ClauseAdapter(t2alias))
2064        self.assert_compile(
2065            vis.traverse(
2066                select("*").where(t1.c.col1 == t2.c.col2).select_from(t1, t2)
2067            ),
2068            "SELECT * FROM table1 AS t1alias, table2 "
2069            "AS t2alias WHERE t1alias.col1 = "
2070            "t2alias.col2",
2071        )
2072
2073    def test_table_to_alias_15(self):
2074        t1alias = t1.alias("t1alias")
2075        vis = sql_util.ClauseAdapter(t1alias)
2076        t2alias = t2.alias("t2alias")
2077        vis.chain(sql_util.ClauseAdapter(t2alias))
2078        self.assert_compile(
2079            select(t1alias, t2alias).where(
2080                t1alias.c.col1
2081                == vis.traverse(
2082                    select("*")
2083                    .where(t1.c.col1 == t2.c.col2)
2084                    .select_from(t1, t2)
2085                    .correlate(t1)
2086                    .scalar_subquery()
2087                )
2088            ),
2089            "SELECT t1alias.col1, t1alias.col2, t1alias.col3, "
2090            "t2alias.col1 AS col1_1, t2alias.col2 AS col2_1, "
2091            "t2alias.col3 AS col3_1 "
2092            "FROM table1 AS t1alias, table2 AS t2alias "
2093            "WHERE t1alias.col1 = "
2094            "(SELECT * FROM table2 AS t2alias "
2095            "WHERE t1alias.col1 = t2alias.col2)",
2096        )
2097
2098    def test_table_to_alias_16(self):
2099        t1alias = t1.alias("t1alias")
2100        vis = sql_util.ClauseAdapter(t1alias)
2101        t2alias = t2.alias("t2alias")
2102        vis.chain(sql_util.ClauseAdapter(t2alias))
2103        self.assert_compile(
2104            t2alias.select().where(
2105                t2alias.c.col2
2106                == vis.traverse(
2107                    select("*")
2108                    .where(t1.c.col1 == t2.c.col2)
2109                    .select_from(t1, t2)
2110                    .correlate(t2)
2111                    .scalar_subquery()
2112                )
2113            ),
2114            "SELECT t2alias.col1, t2alias.col2, t2alias.col3 "
2115            "FROM table2 AS t2alias WHERE t2alias.col2 = "
2116            "(SELECT * FROM table1 AS t1alias WHERE "
2117            "t1alias.col1 = t2alias.col2)",
2118        )
2119
2120    def test_include_exclude(self):
2121        m = MetaData()
2122        a = Table(
2123            "a",
2124            m,
2125            Column("id", Integer, primary_key=True),
2126            Column(
2127                "xxx_id",
2128                Integer,
2129                ForeignKey("a.id", name="adf", use_alter=True),
2130            ),
2131        )
2132
2133        e = a.c.id == a.c.xxx_id
2134        assert str(e) == "a.id = a.xxx_id"
2135        b = a.alias()
2136
2137        e = sql_util.ClauseAdapter(
2138            b,
2139            include_fn=lambda x: x in set([a.c.id]),
2140            equivalents={a.c.id: set([a.c.id])},
2141        ).traverse(e)
2142
2143        assert str(e) == "a_1.id = a.xxx_id"
2144
2145    def test_recursive_equivalents(self):
2146        m = MetaData()
2147        a = Table("a", m, Column("x", Integer), Column("y", Integer))
2148        b = Table("b", m, Column("x", Integer), Column("y", Integer))
2149        c = Table("c", m, Column("x", Integer), Column("y", Integer))
2150
2151        # force a recursion overflow, by linking a.c.x<->c.c.x, and
2152        # asking for a nonexistent col.  corresponding_column should prevent
2153        # endless depth.
2154        adapt = sql_util.ClauseAdapter(
2155            b, equivalents={a.c.x: set([c.c.x]), c.c.x: set([a.c.x])}
2156        )
2157        assert adapt._corresponding_column(a.c.x, False) is None
2158
2159    def test_multilevel_equivalents(self):
2160        m = MetaData()
2161        a = Table("a", m, Column("x", Integer), Column("y", Integer))
2162        b = Table("b", m, Column("x", Integer), Column("y", Integer))
2163        c = Table("c", m, Column("x", Integer), Column("y", Integer))
2164
2165        alias = select(a).select_from(a.join(b, a.c.x == b.c.x)).alias()
2166
2167        # two levels of indirection from c.x->b.x->a.x, requires recursive
2168        # corresponding_column call
2169        adapt = sql_util.ClauseAdapter(
2170            alias, equivalents={b.c.x: set([a.c.x]), c.c.x: set([b.c.x])}
2171        )
2172        assert adapt._corresponding_column(a.c.x, False) is alias.c.x
2173        assert adapt._corresponding_column(c.c.x, False) is alias.c.x
2174
    def test_join_to_alias(self):
        """Adapt a join to a labeled subquery that wraps its left side.

        ``j3`` (c JOIN anon_1) is wrapped as the subquery ``foo`` (``j5``).
        Adapting ``j4`` (j3 LEFT OUTER JOIN d) to ``foo`` replaces the left
        side of the outer join with ``foo`` and rewrites the ON clause to
        use ``foo.anon_1_a_id``.
        """
        metadata = MetaData()
        a = Table("a", metadata, Column("id", Integer, primary_key=True))
        b = Table(
            "b",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("aid", Integer, ForeignKey("a.id")),
        )
        c = Table(
            "c",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("bid", Integer, ForeignKey("b.id")),
        )

        d = Table(
            "d",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("aid", Integer, ForeignKey("a.id")),
        )

        j1 = a.outerjoin(b)
        # tablename-labeled subquery of "a LEFT OUTER JOIN b"; its columns
        # are a_id, b_id and b_aid (rendered as anon_1 below)
        j2 = (
            select(j1)
            .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
            .subquery()
        )

        j3 = c.join(j2, j2.c.b_id == c.c.bid)

        # no explicit ON clause: it is derived from the foreign keys and,
        # per the expected SQL, renders as anon_1.a_id = d.aid
        j4 = j3.outerjoin(d)
        self.assert_compile(
            j4,
            "c JOIN (SELECT a.id AS a_id, b.id AS "
            "b_id, b.aid AS b_aid FROM a LEFT OUTER "
            "JOIN b ON a.id = b.aid) AS anon_1 ON anon_1.b_id = c.bid "
            "LEFT OUTER JOIN d ON anon_1.a_id = d.aid",
        )
        j5 = (
            j3.select()
            .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
            .subquery("foo")
        )
        # copy_and_process() is given a list of elements; the single
        # adapted copy of j4 is taken from the result
        j6 = sql_util.ClauseAdapter(j5).copy_and_process([j4])[0]

        # this statement takes c join(a join b), wraps it inside an
        # aliased "select * from c join(a join b) AS foo". the outermost
        # right side "left outer join d" stays the same, except "d"
        # joins against foo.anon_1_a_id instead of plain "anon_1.a_id"

        self.assert_compile(
            j6,
            "(SELECT c.id AS c_id, c.bid AS c_bid, "
            "anon_1.a_id AS anon_1_a_id, anon_1.b_id AS anon_1_b_id, "
            "anon_1.b_aid AS "
            "anon_1_b_aid FROM c JOIN (SELECT a.id AS a_id, "
            "b.id AS b_id, b.aid AS b_aid FROM a LEFT "
            "OUTER JOIN b ON a.id = b.aid) AS anon_1 ON anon_1.b_id = "
            "c.bid) AS foo LEFT OUTER JOIN d ON "
            "foo.anon_1_a_id = d.aid",
        )
2238
2239    def test_derived_from(self):
2240        assert select(t1).is_derived_from(t1)
2241        assert not select(t2).is_derived_from(t1)
2242        assert not t1.is_derived_from(select(t1))
2243        assert t1.alias().is_derived_from(t1)
2244
2245        s1 = select(t1, t2).alias("foo")
2246        s2 = select(s1).limit(5).offset(10).alias()
2247        assert s2.is_derived_from(s1)
2248        s2 = s2._clone()
2249        assert s2.is_derived_from(s1)
2250
2251    def test_aliasedselect_to_aliasedselect_straight(self):
2252
2253        # original issue from ticket #904
2254
2255        s1 = select(t1).alias("foo")
2256        s2 = select(s1).limit(5).offset(10).alias()
2257        self.assert_compile(
2258            sql_util.ClauseAdapter(s2).traverse(s1),
2259            "SELECT foo.col1, foo.col2, foo.col3 FROM "
2260            "(SELECT table1.col1 AS col1, table1.col2 "
2261            "AS col2, table1.col3 AS col3 FROM table1) "
2262            "AS foo LIMIT :param_1 OFFSET :param_2",
2263            {"param_1": 5, "param_2": 10},
2264        )
2265
2266    def test_aliasedselect_to_aliasedselect_join(self):
2267        s1 = select(t1).alias("foo")
2268        s2 = select(s1).limit(5).offset(10).alias()
2269        j = s1.outerjoin(t2, s1.c.col1 == t2.c.col1)
2270        self.assert_compile(
2271            sql_util.ClauseAdapter(s2).traverse(j).select(),
2272            "SELECT anon_1.col1, anon_1.col2, "
2273            "anon_1.col3, table2.col1 AS col1_1, table2.col2 AS col2_1, "
2274            "table2.col3 AS col3_1 FROM (SELECT foo.col1 AS "
2275            "col1, foo.col2 AS col2, foo.col3 AS col3 "
2276            "FROM (SELECT table1.col1 AS col1, "
2277            "table1.col2 AS col2, table1.col3 AS col3 "
2278            "FROM table1) AS foo LIMIT :param_1 OFFSET "
2279            ":param_2) AS anon_1 LEFT OUTER JOIN "
2280            "table2 ON anon_1.col1 = table2.col1",
2281            {"param_1": 5, "param_2": 10},
2282        )
2283
    @testing.combinations((True,), (False,), argnames="use_adapt_from")
    def test_aliasedselect_to_aliasedselect_join_nested_table(
        self, use_adapt_from
    ):
        """Test that ClauseAdapter does not traverse into an unrelated alias.

        ``talias`` ("bar") is an alias of ``table1``, which ``s2`` is
        indirectly derived from; the adapter must leave ``talias`` alone
        rather than replacing the table inside it.

        The adapt_from_selectables case was added for #6762, a regression
        introduced by #6060.
        """
        s1 = select(t1).alias("foo")
        s2 = select(s1).limit(5).offset(10).alias()
        talias = t1.alias("bar")

        # here is the problem.   s2 is derived from s1 which is derived
        # from t1
        assert s2.is_derived_from(t1)

        # however, s2 is not derived from talias, which *is* derived from t1
        assert not s2.is_derived_from(talias)

        # therefore, talias gets its table replaced, except for a rule
        # we added to ClauseAdapter to stop traversal if the selectable is
        # not derived from an alias of a table.  This rule was previously
        # in Alias._copy_internals().

        j = s1.outerjoin(talias, s1.c.col1 == talias.c.col1)

        if use_adapt_from:
            vis = sql_util.ClauseAdapter(s2, adapt_from_selectables=[s1])
        else:
            vis = sql_util.ClauseAdapter(s2)
        # expected: s1 ("foo") becomes s2 (anon_1); "bar" is untouched
        self.assert_compile(
            vis.traverse(j).select(),
            "SELECT anon_1.col1, anon_1.col2, "
            "anon_1.col3, bar.col1 AS col1_1, bar.col2 AS col2_1, "
            "bar.col3 AS col3_1 "
            "FROM (SELECT foo.col1 AS col1, foo.col2 "
            "AS col2, foo.col3 AS col3 FROM (SELECT "
            "table1.col1 AS col1, table1.col2 AS col2, "
            "table1.col3 AS col3 FROM table1) AS foo "
            "LIMIT :param_1 OFFSET :param_2) AS anon_1 "
            "LEFT OUTER JOIN table1 AS bar ON "
            "anon_1.col1 = bar.col1",
            {"param_1": 5, "param_2": 10},
        )
2330
2331    def test_functions(self):
2332        self.assert_compile(
2333            sql_util.ClauseAdapter(t1.alias()).traverse(func.count(t1.c.col1)),
2334            "count(table1_1.col1)",
2335        )
2336        s = select(func.count(t1.c.col1))
2337        self.assert_compile(
2338            sql_util.ClauseAdapter(t1.alias()).traverse(s),
2339            "SELECT count(table1_1.col1) AS count_1 "
2340            "FROM table1 AS table1_1",
2341        )
2342
2343    def test_table_valued_column(self):
2344        """test #6775"""
2345        stmt = select(func.some_json_func(t1.table_valued()))
2346
2347        self.assert_compile(
2348            stmt,
2349            "SELECT some_json_func(table1) AS some_json_func_1 FROM table1",
2350        )
2351
2352        self.assert_compile(
2353            sql_util.ClauseAdapter(t1.alias()).traverse(stmt),
2354            "SELECT some_json_func(table1_1) AS some_json_func_1 "
2355            "FROM table1 AS table1_1",
2356        )
2357
    def test_recursive(self):
        """Adapt a statement to an aliased UNION of two joins.

        The statement references ``u.c.b_aid`` directly and table ``c``,
        which the union does not select from; the expected SQL renders
        ``c`` as-is alongside the union subquery (anon_1).
        """
        metadata = MetaData()
        a = Table("a", metadata, Column("id", Integer, primary_key=True))
        b = Table(
            "b",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("aid", Integer, ForeignKey("a.id")),
        )
        c = Table(
            "c",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("bid", Integer, ForeignKey("b.id")),
        )

        d = Table(
            "d",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("aid", Integer, ForeignKey("a.id")),
        )

        # tablename-labeled union: column names of the alias come from the
        # first SELECT (a_id, b_id, b_aid)
        u = union(
            a.join(b).select().set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL),
            a.join(d).select().set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL),
        ).alias()

        self.assert_compile(
            sql_util.ClauseAdapter(u).traverse(
                select(c.c.bid).where(c.c.bid == u.c.b_aid)
            ),
            "SELECT c.bid "
            "FROM c, (SELECT a.id AS a_id, b.id AS b_id, b.aid AS b_aid "
            "FROM a JOIN b ON a.id = b.aid UNION SELECT a.id AS a_id, d.id "
            "AS d_id, d.aid AS d_aid "
            "FROM a JOIN d ON a.id = d.aid) AS anon_1 "
            "WHERE c.bid = anon_1.b_aid",
        )
2397
    def test_label_anonymize_one(self):
        """anonymize_labels=True gives an adapted named label an anon name.

        The original label keeps the name "expr"; its adapted copy renders
        as anon_1.  Both can be selected and ordered by in one statement.
        """
        t1a = t1.alias()
        adapter = sql_util.ClauseAdapter(t1a, anonymize_labels=True)

        expr = select(t1.c.col2).where(t1.c.col3 == 5).label("expr")
        expr_adapted = adapter.traverse(expr)

        # ORDER BY refers to each label by its own rendered name
        stmt = select(expr, expr_adapted).order_by(expr, expr_adapted)
        self.assert_compile(
            stmt,
            "SELECT "
            "(SELECT table1.col2 FROM table1 WHERE table1.col3 = :col3_1) "
            "AS expr, "
            "(SELECT table1_1.col2 FROM table1 AS table1_1 "
            "WHERE table1_1.col3 = :col3_2) AS anon_1 "
            "ORDER BY expr, anon_1",
        )
2415
    def test_label_anonymize_two(self):
        """An already-anonymous label and its adapted copy get distinct names.

        Per the expected SQL, the original renders as anon_1 and the
        adapted copy as anon_2.
        """
        t1a = t1.alias()
        adapter = sql_util.ClauseAdapter(t1a, anonymize_labels=True)

        # label(None) produces an anonymous label
        expr = select(t1.c.col2).where(t1.c.col3 == 5).label(None)
        expr_adapted = adapter.traverse(expr)

        stmt = select(expr, expr_adapted).order_by(expr, expr_adapted)
        self.assert_compile(
            stmt,
            "SELECT "
            "(SELECT table1.col2 FROM table1 WHERE table1.col3 = :col3_1) "
            "AS anon_1, "
            "(SELECT table1_1.col2 FROM table1 AS table1_1 "
            "WHERE table1_1.col3 = :col3_2) AS anon_2 "
            "ORDER BY anon_1, anon_2",
        )
2433
2434    def test_label_anonymize_three(self):
2435        t1a = t1.alias()
2436        adapter = sql_util.ColumnAdapter(
2437            t1a, anonymize_labels=True, allow_label_resolve=False
2438        )
2439
2440        expr = select(t1.c.col2).where(t1.c.col3 == 5).label(None)
2441        l1 = expr
2442        is_(l1._order_by_label_element, l1)
2443        eq_(l1._allow_label_resolve, True)
2444
2445        expr_adapted = adapter.traverse(expr)
2446        l2 = expr_adapted
2447        is_(l2._order_by_label_element, l2)
2448        eq_(l2._allow_label_resolve, False)
2449
2450        l3 = adapter.traverse(expr)
2451        is_(l3._order_by_label_element, l3)
2452        eq_(l3._allow_label_resolve, False)
2453
2454
class SpliceJoinsTest(fixtures.TestBase, AssertsCompiledSQL):
    """Tests for ``sql_util.splice_joins()``."""

    __dialect__ = "default"

    @classmethod
    def setup_test_class(cls):
        global table1, table2, table3, table4

        table1, table2, table3, table4 = (
            table(name, column("col1"), column("col2"), column("col3"))
            for name in ("table1", "table2", "table3", "table4")
        )

    def test_splice(self):
        """The subquery replaces table1 throughout the spliced join chain."""
        t1_alias, t2_alias = table1.alias(), table2.alias()
        chain = (
            table1.join(table2, table1.c.col1 == table2.c.col1)
            .join(t1_alias, table2.c.col1 == t1_alias.c.col1)
            .join(t2_alias, t2_alias.c.col1 == table1.c.col1)
        )
        subq = select(table1).where(table1.c.col2 < 5).alias()

        self.assert_compile(
            sql_util.splice_joins(subq, chain),
            "(SELECT table1.col1 AS col1, table1.col2 "
            "AS col2, table1.col3 AS col3 FROM table1 "
            "WHERE table1.col2 < :col2_1) AS anon_1 "
            "JOIN table2 ON anon_1.col1 = table2.col1 "
            "JOIN table1 AS table1_1 ON table2.col1 = "
            "table1_1.col1 JOIN table2 AS table2_1 ON "
            "table2_1.col1 = anon_1.col1",
        )

    def test_stop_on(self):
        """The stop_on argument halts splicing at the given join."""
        inner = table1.join(table2, table1.c.col1 == table2.c.col1)
        outer = inner.join(table3, table2.c.col1 == table3.c.col1)
        subq = select(table1).select_from(inner).alias()

        self.assert_compile(
            sql_util.splice_joins(subq, outer),
            "(SELECT table1.col1 AS col1, table1.col2 "
            "AS col2, table1.col3 AS col3 FROM table1 "
            "JOIN table2 ON table1.col1 = table2.col1) "
            "AS anon_1 JOIN table2 ON anon_1.col1 = "
            "table2.col1 JOIN table3 ON table2.col1 = "
            "table3.col1",
        )
        self.assert_compile(
            sql_util.splice_joins(subq, outer, inner),
            "(SELECT table1.col1 AS col1, table1.col2 "
            "AS col2, table1.col3 AS col3 FROM table1 "
            "JOIN table2 ON table1.col1 = table2.col1) "
            "AS anon_1 JOIN table3 ON table2.col1 = "
            "table3.col1",
        )

    def test_splice_2(self):
        """Splicing onto a plain table, individually and repeatedly."""
        t2a = table2.alias()
        t3a = table3.alias()
        t4a = table4.alias()

        via_2_and_3 = table1.join(t2a, table1.c.col1 == t2a.c.col1)
        via_2_and_3 = via_2_and_3.join(t3a, t2a.c.col2 == t3a.c.col2)
        via_4 = table1.join(t4a, table1.c.col3 == t4a.c.col3)

        self.assert_compile(
            sql_util.splice_joins(table1, via_2_and_3),
            "table1 JOIN table2 AS table2_1 ON "
            "table1.col1 = table2_1.col1 JOIN table3 "
            "AS table3_1 ON table2_1.col2 = "
            "table3_1.col2",
        )
        self.assert_compile(
            sql_util.splice_joins(table1, via_4),
            "table1 JOIN table4 AS table4_1 ON table1.col3 = table4_1.col3",
        )

        first_pass = sql_util.splice_joins(table1, via_2_and_3)
        self.assert_compile(
            sql_util.splice_joins(first_pass, via_4),
            "table1 JOIN table2 AS table2_1 ON "
            "table1.col1 = table2_1.col1 JOIN table3 "
            "AS table3_1 ON table2_1.col2 = "
            "table3_1.col2 JOIN table4 AS table4_1 ON "
            "table1.col3 = table4_1.col3",
        )
2538
2539
class SelectTest(fixtures.TestBase, AssertsCompiledSQL):

    """Exercise the generative methods of Select.

    Each test derives a new statement from ``t1.select()`` and checks that
    the derived copy renders the change while the original statement still
    compiles exactly as it did before.
    """

    __dialect__ = "default"

    # rendering of the untouched ``t1.select()`` statement
    _plain_sql = "SELECT table1.col1, table1.col2, table1.col3 FROM table1"

    @classmethod
    def setup_test_class(cls):
        # rebind the module-level t1 / t2 tables used by the tests below
        global t1, t2
        col_names = ("col1", "col2", "col3")
        t1 = table("table1", *[column(name) for name in col_names])
        t2 = table("table2", *[column(name) for name in col_names])

    def test_columns(self):
        base = t1.select()
        self.assert_compile(base, self._plain_sql)

        widened = base.add_columns(column("yyy"))
        self.assert_compile(
            widened,
            "SELECT table1.col1, table1.col2, table1.col3, yyy FROM table1",
        )

        # the copy must not share its column collections with the parent
        is_not(base.selected_columns, widened.selected_columns)
        is_not(base._raw_columns, widened._raw_columns)

        # ... and the parent must be left unchanged
        self.assert_compile(base, self._plain_sql)

    def test_froms(self):
        base = t1.select()
        self.assert_compile(base, self._plain_sql)

        with_extra_from = base.select_from(t2)
        self.assert_compile(
            with_extra_from,
            "SELECT table1.col1, table1.col2, table1.col3 "
            "FROM table1, table2",
        )

        self.assert_compile(base, self._plain_sql)

    def test_prefixes(self):
        base = t1.select()
        self.assert_compile(base, self._plain_sql)

        prefixed = base.prefix_with("FOOBER")
        self.assert_compile(
            prefixed,
            "SELECT FOOBER table1.col1, table1.col2, table1.col3 FROM table1",
        )

        self.assert_compile(base, self._plain_sql)

    def test_execution_options(self):
        original = select().execution_options(foo="bar")
        extended = original.execution_options(bar="baz")
        overridden = original.execution_options(foo="not bar")

        # deriving new statements leaves the original's options intact
        eq_(original.get_execution_options(), {"foo": "bar"})
        # derived statements start from the original's options
        eq_(extended.get_execution_options(), {"foo": "bar", "bar": "baz"})
        eq_(overridden.get_execution_options(), {"foo": "not bar"})

    def test_invalid_options(self):
        # each of these option sets is rejected on a Select with ArgumentError
        rejected_option_sets = (
            {"compiled_cache": {}},
            {"isolation_level": "READ_COMMITTED"},
        )
        for rejected in rejected_option_sets:
            assert_raises(
                exc.ArgumentError, select().execution_options, **rejected
            )

    # feature not available yet; the _NOTYET_ prefix keeps this from
    # running as a test
    def _NOTYET_test_execution_options_in_kwargs(self):
        base = select(execution_options={"foo": "bar"})
        derived = base.execution_options(bar="baz")
        # the original select keeps only its own options
        assert base._execution_options == {"foo": "bar"}
        # the derived select layers its options over the original's
        assert derived._execution_options == {"foo": "bar", "bar": "baz"}

    # feature not available yet; the _NOTYET_ prefix keeps this from
    # running as a test
    def _NOTYET_test_execution_options_in_text(self):
        stmt = text("select 42", execution_options={"foo": "bar"})
        assert stmt._execution_options == {"foo": "bar"}
2633
2634
class ValuesBaseTest(fixtures.TestBase, AssertsCompiledSQL):

    """Tests the generative capability of Insert, Update.

    Most tests build the statement's compile state directly via
    ``_compile_state_factory()`` so that the parameter bookkeeping produced
    by successive ``values()`` calls can be inspected without rendering SQL.
    """

    __dialect__ = "default"

    # fixme: consolidate coverage from elsewhere here and expand

    @classmethod
    def setup_test_class(cls):
        # rebind the module-level t1 / t2 tables used by the tests below
        global t1, t2
        t1 = table("table1", column("col1"), column("col2"), column("col3"))
        t2 = table("table2", column("col1"), column("col2"), column("col3"))

    def test_prefixes(self):
        """prefix_with() yields a new statement; prefixes accumulate."""
        i = t1.insert()
        self.assert_compile(
            i,
            "INSERT INTO table1 (col1, col2, col3) "
            "VALUES (:col1, :col2, :col3)",
        )

        gen = i.prefix_with("foober")
        self.assert_compile(
            gen,
            "INSERT foober INTO table1 (col1, col2, col3) "
            "VALUES (:col1, :col2, :col3)",
        )

        # the original statement is unaffected by the derived one
        self.assert_compile(
            i,
            "INSERT INTO table1 (col1, col2, col3) "
            "VALUES (:col1, :col2, :col3)",
        )

        i2 = t1.insert().prefix_with("squiznart")
        self.assert_compile(
            i2,
            "INSERT squiznart INTO table1 (col1, col2, col3) "
            "VALUES (:col1, :col2, :col3)",
        )

        gen2 = i2.prefix_with("quux")
        self.assert_compile(
            gen2,
            "INSERT squiznart quux INTO "
            "table1 (col1, col2, col3) "
            "VALUES (:col1, :col2, :col3)",
        )

    def test_add_kwarg(self):
        """Successive keyword values() calls merge into one dict."""
        i = t1.insert()
        compile_state = i._compile_state_factory(i, None)
        eq_(compile_state._dict_parameters, None)
        i = i.values(col1=5)
        compile_state = i._compile_state_factory(i, None)
        self._compare_param_dict(compile_state._dict_parameters, {"col1": 5})
        i = i.values(col2=7)
        compile_state = i._compile_state_factory(i, None)
        self._compare_param_dict(
            compile_state._dict_parameters, {"col1": 5, "col2": 7}
        )

    def test_via_tuple_single(self):
        """A single positional tuple is mapped onto the table's columns."""
        i = t1.insert()

        compile_state = i._compile_state_factory(i, None)
        eq_(compile_state._dict_parameters, None)

        i = i.values((5, 6, 7))
        compile_state = i._compile_state_factory(i, None)

        self._compare_param_dict(
            compile_state._dict_parameters,
            {"col1": 5, "col2": 6, "col3": 7},
        )

    def test_kw_and_dict_simultaneously_single(self):
        """values() rejects a positional dict combined with keywords."""
        i = t1.insert()
        assert_raises_message(
            exc.ArgumentError,
            r"Can't pass positional and kwargs to values\(\) simultaneously",
            i.values,
            {"col1": 5},
            col2=7,
        )

    def test_via_tuple_multi(self):
        """A list of tuples becomes multiple parameter sets."""
        i = t1.insert()
        compile_state = i._compile_state_factory(i, None)
        eq_(compile_state._dict_parameters, None)

        i = i.values([(5, 6, 7), (8, 9, 10)])
        compile_state = i._compile_state_factory(i, None)
        # _dict_parameters mirrors the first parameter set
        eq_(
            compile_state._dict_parameters,
            {"col1": 5, "col2": 6, "col3": 7},
        )
        eq_(compile_state._has_multi_parameters, True)
        eq_(
            compile_state._multi_parameters,
            [
                {"col1": 5, "col2": 6, "col3": 7},
                {"col1": 8, "col2": 9, "col3": 10},
            ],
        )

    def test_inline_values_single(self):
        """A single dict passed to values() is a single parameter set."""
        i = t1.insert().values({"col1": 5})

        compile_state = i._compile_state_factory(i, None)

        self._compare_param_dict(compile_state._dict_parameters, {"col1": 5})
        is_(compile_state._has_multi_parameters, False)

    def test_inline_values_multi(self):
        """A list of dicts passed to values() is multiple parameter sets."""
        i = t1.insert().values([{"col1": 5}, {"col1": 6}])

        compile_state = i._compile_state_factory(i, None)

        # _dict_parameters is the first parameter set, holding the plain
        # values; multiparams are not converted to bound parameters
        eq_(compile_state._dict_parameters, {"col1": 5})

        # likewise every multi parameter set keeps its plain values
        eq_(compile_state._multi_parameters, [{"col1": 5}, {"col1": 6}])
        is_(compile_state._has_multi_parameters, True)

    def _compare_param_dict(self, a, b):
        """Assert that compile-state parameter dict ``a`` matches ``b``.

        :param a: the compile state's ``_dict_parameters``, whose values
         are bound-parameter elements.
        :param b: a plain dict mapping the same keys to the literal values
         those elements are expected to carry.
        :raises AssertionError: if the key sets differ or any value does not
         compare equal.
        """
        from sqlalchemy.types import NullType

        # assert rather than return a flag: callers never inspect a return
        # value, so a missing or extra key must fail the test right here
        eq_(set(a), set(b))

        for a_k, a_i in a.items():
            b_i = b[a_k]

            # compare BindParameter on the left to
            # literal value on the right
            assert a_i.compare(literal(b_i, type_=NullType()))

    def test_add_dictionary(self):
        """Dict values() calls merge; a repeated key replaces its value."""
        i = t1.insert()

        compile_state = i._compile_state_factory(i, None)

        eq_(compile_state._dict_parameters, None)
        i = i.values({"col1": 5})

        compile_state = i._compile_state_factory(i, None)

        self._compare_param_dict(compile_state._dict_parameters, {"col1": 5})
        is_(compile_state._has_multi_parameters, False)

        i = i.values({"col1": 6})
        # note replaces
        compile_state = i._compile_state_factory(i, None)

        self._compare_param_dict(compile_state._dict_parameters, {"col1": 6})
        is_(compile_state._has_multi_parameters, False)

        i = i.values({"col2": 7})
        compile_state = i._compile_state_factory(i, None)
        self._compare_param_dict(
            compile_state._dict_parameters, {"col1": 6, "col2": 7}
        )
        is_(compile_state._has_multi_parameters, False)

    def test_add_kwarg_disallowed_multi(self):
        """Keyword values() after a multi-values list fails at compile."""
        i = t1.insert()
        i = i.values([{"col1": 5}, {"col1": 7}])
        i = i.values(col2=7)
        assert_raises_message(
            exc.InvalidRequestError,
            "Can't mix single and multiple VALUES formats",
            i.compile,
        )

    def test_cant_mix_single_multi_formats_dict_to_list(self):
        """Single values followed by a list of dicts fails at compile."""
        i = t1.insert().values(col1=5)
        i = i.values([{"col1": 6}])
        assert_raises_message(
            exc.InvalidRequestError,
            "Can't mix single and multiple VALUES "
            "formats in one INSERT statement",
            i.compile,
        )

    def test_cant_mix_single_multi_formats_list_to_dict(self):
        """A list of dicts followed by a single dict fails at compile."""
        i = t1.insert().values([{"col1": 6}])
        i = i.values({"col1": 5})
        assert_raises_message(
            exc.InvalidRequestError,
            "Can't mix single and multiple VALUES "
            "formats in one INSERT statement",
            i.compile,
        )

    def test_erroneous_multi_args_dicts(self):
        """values() rejects more than one positional dict."""
        i = t1.insert()
        assert_raises_message(
            exc.ArgumentError,
            "Only a single dictionary/tuple or list of "
            "dictionaries/tuples is accepted positionally.",
            i.values,
            {"col1": 5},
            {"col1": 7},
        )

    def test_erroneous_multi_args_tuples(self):
        """values() rejects more than one positional tuple."""
        i = t1.insert()
        assert_raises_message(
            exc.ArgumentError,
            "Only a single dictionary/tuple or list of "
            "dictionaries/tuples is accepted positionally.",
            i.values,
            (5, 6, 7),
            (8, 9, 10),
        )

    def test_erroneous_multi_args_plus_kw(self):
        """values() rejects a positional list combined with keywords."""
        i = t1.insert()
        assert_raises_message(
            exc.ArgumentError,
            r"Can't pass positional and kwargs to values\(\) simultaneously",
            i.values,
            [{"col1": 5}],
            col2=7,
        )

    def test_update_no_support_multi_values(self):
        """UPDATE with multiple parameter sets fails at compile."""
        u = t1.update()
        u = u.values([{"col1": 5}, {"col1": 7}])
        assert_raises_message(
            exc.InvalidRequestError,
            "UPDATE construct does not support multiple parameter sets.",
            u.compile,
        )

    def test_update_no_support_multi_constructor(self):
        """Same as above, with values() chained onto update() directly."""
        stmt = t1.update().values([{"col1": 5}, {"col1": 7}])

        assert_raises_message(
            exc.InvalidRequestError,
            "UPDATE construct does not support multiple parameter sets.",
            stmt.compile,
        )
2881