1from sqlalchemy import and_
2from sqlalchemy import bindparam
3from sqlalchemy import case
4from sqlalchemy import Column
5from sqlalchemy import exc
6from sqlalchemy import extract
7from sqlalchemy import ForeignKey
8from sqlalchemy import func
9from sqlalchemy import Integer
10from sqlalchemy import literal_column
11from sqlalchemy import MetaData
12from sqlalchemy import select
13from sqlalchemy import String
14from sqlalchemy import Table
15from sqlalchemy import testing
16from sqlalchemy import text
17from sqlalchemy import tuple_
18from sqlalchemy import union
19from sqlalchemy.sql import ClauseElement
20from sqlalchemy.sql import column
21from sqlalchemy.sql import operators
22from sqlalchemy.sql import table
23from sqlalchemy.sql import util as sql_util
24from sqlalchemy.sql import visitors
25from sqlalchemy.sql.expression import _clone
26from sqlalchemy.sql.expression import _from_objects
27from sqlalchemy.sql.visitors import ClauseVisitor
28from sqlalchemy.sql.visitors import cloned_traverse
29from sqlalchemy.sql.visitors import CloningVisitor
30from sqlalchemy.sql.visitors import ReplacingCloningVisitor
31from sqlalchemy.testing import assert_raises
32from sqlalchemy.testing import assert_raises_message
33from sqlalchemy.testing import AssertsCompiledSQL
34from sqlalchemy.testing import AssertsExecutionResults
35from sqlalchemy.testing import eq_
36from sqlalchemy.testing import fixtures
37from sqlalchemy.testing import is_
38from sqlalchemy.testing import is_not
39
40
# Module-level placeholders, rebound by the test classes' setup_class
# hooks (A/B by TraversalTest; t1/t2/t3 by ClauseTest and
# ColumnAdapterTest).  table1..table4 are presumably assigned by tests
# later in this file.
A = None
B = None
t1 = t2 = t3 = None
table1 = table2 = table3 = table4 = None
42
43
class TraversalTest(fixtures.TestBase, AssertsExecutionResults):

    """test ClauseVisitor's traversal, particularly its
    ability to copy and modify a ClauseElement in place."""

    @classmethod
    def setup_class(cls):
        global A, B

        # establish two fictitious ClauseElements.
        # define deep equality semantics as well as deep
        # identity semantics.
        class A(ClauseElement):
            """Leaf element wrapping a single value, ``expr``."""

            __visit_name__ = "a"

            def __init__(self, expr):
                self.expr = expr

            def is_other(self, other):
                """Return True only if ``other`` is this very object."""
                return other is self

            # defining __eq__ would otherwise set __hash__ to None; keep
            # ClauseElement's hash so instances stay hashable.
            __hash__ = ClauseElement.__hash__

            def __eq__(self, other):
                # value equality; objects that are not A compare unequal
                # instead of raising AttributeError on ``other.expr``
                return isinstance(other, A) and other.expr == self.expr

            def __ne__(self, other):
                return not self.__eq__(other)

            def __str__(self):
                return "A(%s)" % repr(self.expr)

        class B(ClauseElement):
            """Composite element holding an ordered sequence of children."""

            __visit_name__ = "b"

            def __init__(self, *items):
                self.items = items

            def is_other(self, other):
                """Return True only if ``other`` is this very object.

                Identity of the container implies identity of every
                child, so no per-item check is needed.
                """
                return other is self

            # defining __eq__ would otherwise set __hash__ to None; keep
            # ClauseElement's hash so instances stay hashable.
            __hash__ = ClauseElement.__hash__

            def __eq__(self, other):
                # deep value equality.  Lengths are compared explicitly
                # because zip() stops at the shorter sequence, which would
                # otherwise make B(x) == B(x, y).
                if not isinstance(other, B):
                    return False
                if len(self.items) != len(other.items):
                    return False
                return all(i1 == i2 for i1, i2 in zip(self.items, other.items))

            def __ne__(self, other):
                return not self.__eq__(other)

            def _copy_internals(self, clone=_clone, **kw):
                # replace children with clones; used by CloningVisitor
                self.items = [clone(i, **kw) for i in self.items]

            def get_children(self, **kwargs):
                return self.items

            def __str__(self):
                return "B(%s)" % repr([str(i) for i in self.items])

    def test_test_classes(self):
        a1 = A("expr1")
        struct = B(a1, A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
        struct2 = B(a1, A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
        struct3 = B(
            a1, A("expr2"), B(A("expr1b"), A("expr2bmodified")), A("expr3")
        )

        assert a1.is_other(a1)
        assert struct.is_other(struct)
        assert struct == struct2
        assert struct != struct3
        assert not struct.is_other(struct2)
        assert not struct.is_other(struct3)

        # structures differing only in length must not compare equal
        assert B(a1) != B(a1, A("expr2"))
        assert not (B(a1) == B(a1, A("expr2")))

        # cross-type comparison is simply unequal
        assert A("expr1") != B(a1)

    def test_clone(self):
        struct = B(
            A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")
        )

        class Vis(CloningVisitor):
            def visit_a(self, a):
                pass

            def visit_b(self, b):
                pass

        vis = Vis()
        s2 = vis.traverse(struct)
        assert struct == s2
        assert not struct.is_other(s2)

    def test_no_clone(self):
        struct = B(
            A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")
        )

        class Vis(ClauseVisitor):
            def visit_a(self, a):
                pass

            def visit_b(self, b):
                pass

        vis = Vis()
        s2 = vis.traverse(struct)
        assert struct == s2
        assert struct.is_other(s2)

    def test_clone_anon_label(self):
        from sqlalchemy.sql.elements import Grouping

        c1 = Grouping(literal_column("q"))
        s1 = select([c1])

        class Vis(CloningVisitor):
            def visit_grouping(self, elem):
                pass

        vis = Vis()
        s2 = vis.traverse(s1)
        eq_(list(s2.inner_columns)[0].anon_label, c1.anon_label)

    def test_change_in_place(self):
        struct = B(
            A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")
        )
        struct2 = B(
            A("expr1"),
            A("expr2modified"),
            B(A("expr1b"), A("expr2b")),
            A("expr3"),
        )
        struct3 = B(
            A("expr1"),
            A("expr2"),
            B(A("expr1b"), A("expr2bmodified")),
            A("expr3"),
        )

        class Vis(CloningVisitor):
            def visit_a(self, a):
                if a.expr == "expr2":
                    a.expr = "expr2modified"

            def visit_b(self, b):
                pass

        vis = Vis()
        s2 = vis.traverse(struct)
        assert struct != s2
        assert not struct.is_other(s2)
        assert struct2 == s2

        class Vis2(CloningVisitor):
            def visit_a(self, a):
                if a.expr == "expr2b":
                    a.expr = "expr2bmodified"

            def visit_b(self, b):
                pass

        vis2 = Vis2()
        s3 = vis2.traverse(struct)
        assert struct != s3
        assert struct3 == s3

    def test_visit_name(self):
        # a Column subclass inherits Column's __visit_name__, so visitor
        # dispatch treats it as a plain column
        class CustomObj(Column):
            pass

        assert CustomObj.__visit_name__ == Column.__visit_name__ == "column"

        foo, bar = CustomObj("foo", String), CustomObj("bar", String)
        bin_ = foo == bar
        assert set(ClauseVisitor().iterate(bin_)) == set([foo, bar, bin_])
233
234
class BinaryEndpointTraversalTest(fixtures.TestBase):

    """test the special binary product visit"""

    def _assert_traversal(self, expr, expected):
        """Assert the endpoint pairs reported by visit_binary_product.

        Every ``(binary, left, right)`` callback made by
        ``sql_util.visit_binary_product`` for ``expr`` is recorded as
        ``(binary.operator, left, right)``.  The recorded list is then
        compared in order against ``expected``.
        """
        canary = []

        def visit(binary, left, right):
            canary.append((binary.operator, left, right))

        sql_util.visit_binary_product(visit, expr)
        eq_(canary, expected)

    def test_basic(self):
        a, b = column("a"), column("b")
        self._assert_traversal(a == b, [(operators.eq, a, b)])

    def test_with_tuples(self):
        a, b, c, d, b1, b1a, b1b, e, f = (
            column("a"),
            column("b"),
            column("c"),
            column("d"),
            column("b1"),
            column("b1a"),
            column("b1b"),
            column("e"),
            column("f"),
        )
        expr = tuple_(a, b, b1 == tuple_(b1a, b1b == d), c) > tuple_(
            func.go(e + f)
        )
        self._assert_traversal(
            expr,
            [
                (operators.gt, a, e),
                (operators.gt, a, f),
                (operators.gt, b, e),
                (operators.gt, b, f),
                (operators.eq, b1, b1a),
                (operators.eq, b1b, d),
                (operators.gt, c, e),
                (operators.gt, c, f),
            ],
        )

    def test_composed(self):
        a, b, e, f, q, j, r = (
            column("a"),
            column("b"),
            column("e"),
            column("f"),
            column("q"),
            column("j"),
            column("r"),
        )
        expr = and_((a + b) == q + func.sum(e + f), and_(j == r, f == q))
        self._assert_traversal(
            expr,
            [
                (operators.eq, a, q),
                (operators.eq, a, e),
                (operators.eq, a, f),
                (operators.eq, b, q),
                (operators.eq, b, e),
                (operators.eq, b, f),
                (operators.eq, j, r),
                (operators.eq, f, q),
            ],
        )

    def test_subquery(self):
        a, b, c = column("a"), column("b"), column("c")
        subq = select([c]).where(c == a).as_scalar()
        expr = and_(a == b, b == subq)
        # the scalar subquery is reported as an endpoint itself rather
        # than being descended into
        self._assert_traversal(
            expr, [(operators.eq, a, b), (operators.eq, b, subq)]
        )
314
315
class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):

    """test copy-in-place behavior of various ClauseElements."""

    __dialect__ = "default"

    @classmethod
    def setup_class(cls):
        global t1, t2, t3
        # t1/t2 are lightweight table() constructs; t3 is a Table with
        # typed Columns bound to a MetaData.
        t1 = table("table1", column("col1"), column("col2"), column("col3"))
        t2 = table("table2", column("col1"), column("col2"), column("col3"))
        t3 = Table(
            "table3",
            MetaData(),
            Column("col1", Integer),
            Column("col2", Integer),
        )

    def test_binary(self):
        clause = t1.c.col2 == t2.c.col2
        eq_(str(clause), str(CloningVisitor().traverse(clause)))

    def test_binary_anon_label_quirk(self):
        t = table("t1", column("col1"))

        f = t.c.col1 * 5
        self.assert_compile(
            select([f]), "SELECT t1.col1 * :col1_1 AS anon_1 FROM t1"
        )

        # deliberately evaluate anon_label on the original expression
        # before adapting it (the "quirk" under test); the adapted copy
        # must still compile with the anon_1 label below.
        f.anon_label

        a = t.alias()
        f = sql_util.ClauseAdapter(a).traverse(f)

        self.assert_compile(
            select([f]), "SELECT t1_1.col1 * :col1_1 AS anon_1 FROM t1 AS t1_1"
        )

    def test_join(self):
        clause = t1.join(t2, t1.c.col2 == t2.c.col2)
        c1 = str(clause)
        assert str(clause) == str(CloningVisitor().traverse(clause))

        class Vis(CloningVisitor):
            def visit_binary(self, binary):
                binary.right = t2.c.col3

        clause2 = Vis().traverse(clause)
        # the original join is unchanged; only the clone was modified
        assert c1 == str(clause)
        assert str(clause2) == str(t1.join(t2, t1.c.col2 == t2.c.col3))

    def test_aliased_column_adapt(self):
        aliased = t1.select().alias()
        aliased2 = t1.alias()

        adapter = sql_util.ColumnAdapter(aliased)

        f = select([adapter.columns[c] for c in aliased2.c]).select_from(
            aliased
        )

        s = select([aliased2]).select_from(aliased)
        eq_(str(s), str(f))

        f = select([adapter.columns[func.count(aliased2.c.col1)]]).select_from(
            aliased
        )
        eq_(
            str(select([func.count(aliased2.c.col1)]).select_from(aliased)),
            str(f),
        )

    def test_aliased_cloned_column_adapt_inner(self):
        clause = select([t1.c.col1, func.foo(t1.c.col2).label("foo")])

        aliased1 = select([clause.c.col1, clause.c.foo])
        aliased2 = clause
        # deliberately access the exported columns before cloning
        aliased2.c.col1, aliased2.c.foo
        aliased3 = cloned_traverse(aliased2, {}, {})

        # fixed by [ticket:2419].   the inside columns
        # on aliased3 have _is_clone_of pointers to those of
        # aliased2.  corresponding_column checks these
        # now.
        adapter = sql_util.ColumnAdapter(aliased1)
        f1 = select([adapter.columns[c] for c in aliased2._raw_columns])
        f2 = select([adapter.columns[c] for c in aliased3._raw_columns])
        eq_(str(f1), str(f2))

    def test_aliased_cloned_column_adapt_exported(self):
        clause = select([t1.c.col1, func.foo(t1.c.col2).label("foo")])

        aliased1 = select([clause.c.col1, clause.c.foo])
        aliased2 = clause
        # deliberately access the exported columns before cloning
        aliased2.c.col1, aliased2.c.foo
        aliased3 = cloned_traverse(aliased2, {}, {})

        # also fixed by [ticket:2419].  When we look at the
        # *outside* columns of aliased3, they previously did not
        # have an _is_clone_of pointer.   But we now modified _make_proxy
        # to assign this.
        adapter = sql_util.ColumnAdapter(aliased1)
        f1 = select([adapter.columns[c] for c in aliased2.c])
        f2 = select([adapter.columns[c] for c in aliased3.c])
        eq_(str(f1), str(f2))

    def test_aliased_cloned_schema_column_adapt_exported(self):
        clause = select([t3.c.col1, func.foo(t3.c.col2).label("foo")])

        aliased1 = select([clause.c.col1, clause.c.foo])
        aliased2 = clause
        # deliberately access the exported columns before cloning
        aliased2.c.col1, aliased2.c.foo
        aliased3 = cloned_traverse(aliased2, {}, {})

        # also fixed by [ticket:2419].  When we look at the
        # *outside* columns of aliased3, they previously did not
        # have an _is_clone_of pointer.   But we now modified _make_proxy
        # to assign this.
        adapter = sql_util.ColumnAdapter(aliased1)
        f1 = select([adapter.columns[c] for c in aliased2.c])
        f2 = select([adapter.columns[c] for c in aliased3.c])
        eq_(str(f1), str(f2))

    def test_labeled_expression_adapt(self):
        lbl_x = (t3.c.col1 == 1).label("x")
        t3_alias = t3.alias()

        adapter = sql_util.ColumnAdapter(t3_alias)

        lblx_adapted = adapter.traverse(lbl_x)
        is_not(lblx_adapted._element, lbl_x._element)

        lblx_adapted = adapter.traverse(lbl_x)
        self.assert_compile(
            select([lblx_adapted.self_group()]),
            "SELECT (table3_1.col1 = :col1_1) AS x FROM table3 AS table3_1",
        )

        self.assert_compile(
            select([lblx_adapted.is_(True)]),
            "SELECT (table3_1.col1 = :col1_1) IS 1 AS anon_1 "
            "FROM table3 AS table3_1",
        )

    def test_cte_w_union(self):
        t = select([func.values(1).label("n")]).cte("t", recursive=True)
        t = t.union_all(select([t.c.n + 1]).where(t.c.n < 100))
        s = select([func.sum(t.c.n)])

        cloned = cloned_traverse(s, {}, {})

        self.assert_compile(
            cloned,
            "WITH RECURSIVE t(n) AS "
            "(SELECT values(:values_1) AS n "
            "UNION ALL SELECT t.n + :n_1 AS anon_1 "
            "FROM t "
            "WHERE t.n < :n_2) "
            "SELECT sum(t.n) AS sum_1 FROM t",
        )

    def test_aliased_cte_w_union(self):
        t = (
            select([func.values(1).label("n")])
            .cte("t", recursive=True)
            .alias("foo")
        )
        t = t.union_all(select([t.c.n + 1]).where(t.c.n < 100))
        s = select([func.sum(t.c.n)])

        cloned = cloned_traverse(s, {}, {})

        self.assert_compile(
            cloned,
            "WITH RECURSIVE foo(n) AS (SELECT values(:values_1) AS n "
            "UNION ALL SELECT foo.n + :n_1 AS anon_1 FROM foo "
            "WHERE foo.n < :n_2) SELECT sum(foo.n) AS sum_1 FROM foo",
        )

    def test_text(self):
        clause = text("select * from table where foo=:bar").bindparams(
            bindparam("bar")
        )
        c1 = str(clause)

        class Vis(CloningVisitor):
            def visit_textclause(self, textclause):
                textclause.text = textclause.text + " SOME MODIFIER=:lala"
                textclause._bindparams["lala"] = bindparam("lala")

        clause2 = Vis().traverse(clause)
        # the clone receives the new text and bind; the original does not
        assert c1 == str(clause)
        assert str(clause2) == c1 + " SOME MODIFIER=:lala"
        assert list(clause._bindparams.keys()) == ["bar"]
        assert set(clause2._bindparams.keys()) == set(["bar", "lala"])

    def test_select(self):
        s2 = select([t1])
        s2_assert = str(s2)
        s3_assert = str(select([t1], t1.c.col2 == 7))

        # a CloningVisitor modifies a copy and leaves s2 untouched
        class Vis(CloningVisitor):
            def visit_select(self, stmt):
                stmt.append_whereclause(t1.c.col2 == 7)

        s3 = Vis().traverse(s2)
        assert str(s3) == s3_assert
        assert str(s2) == s2_assert

        # a plain ClauseVisitor modifies s2 in place
        class Vis(ClauseVisitor):
            def visit_select(self, stmt):
                stmt.append_whereclause(t1.c.col2 == 7)

        Vis().traverse(s2)
        assert str(s2) == s3_assert

        s4_assert = str(select([t1], and_(t1.c.col2 == 7, t1.c.col3 == 9)))

        class Vis(CloningVisitor):
            def visit_select(self, stmt):
                stmt.append_whereclause(t1.c.col3 == 9)

        s4 = Vis().traverse(s3)
        assert str(s4) == s4_assert
        assert str(s3) == s3_assert

        s5_assert = str(select([t1], and_(t1.c.col2 == 7, t1.c.col1 == 9)))

        class Vis(CloningVisitor):
            def visit_binary(self, binary):
                if binary.left is t1.c.col3:
                    binary.left = t1.c.col1
                    binary.right = bindparam("col1", unique=True)

        s5 = Vis().traverse(s4)
        assert str(s5) == s5_assert
        assert str(s4) == s4_assert

    def test_union(self):
        u = union(t1.select(), t2.select())
        u2 = CloningVisitor().traverse(u)
        assert str(u) == str(u2)
        assert [str(c) for c in u2.c] == [str(c) for c in u.c]

        u = union(t1.select(), t2.select())
        cols = [str(c) for c in u.c]
        u2 = CloningVisitor().traverse(u)
        assert str(u) == str(u2)
        assert [str(c) for c in u2.c] == cols

        s1 = select([t1], t1.c.col1 == bindparam("id_param"))
        s2 = select([t2])
        u = union(s1, s2)

        u2 = u.params(id_param=7)
        u3 = u.params(id_param=10)
        assert str(u) == str(u2) == str(u3)
        assert u2.compile().params == {"id_param": 7}
        assert u3.compile().params == {"id_param": 10}

    def test_in(self):
        expr = t1.c.col1.in_(["foo", "bar"])
        expr2 = CloningVisitor().traverse(expr)
        assert str(expr) == str(expr2)

    def test_over(self):
        expr = func.row_number().over(order_by=t1.c.col1)
        expr2 = CloningVisitor().traverse(expr)
        assert str(expr) == str(expr2)

        assert expr in visitors.iterate(expr, {})

    def test_within_group(self):
        expr = func.row_number().within_group(t1.c.col1)
        expr2 = CloningVisitor().traverse(expr)
        assert str(expr) == str(expr2)

        assert expr in visitors.iterate(expr, {})

    def test_funcfilter(self):
        expr = func.count(1).filter(t1.c.col1 > 1)
        expr2 = CloningVisitor().traverse(expr)
        assert str(expr) == str(expr2)

    def test_adapt_union(self):
        u = union(
            t1.select().where(t1.c.col1 == 4),
            t1.select().where(t1.c.col1 == 5),
        ).alias()

        assert sql_util.ClauseAdapter(u).traverse(t1) is u

    def test_binds(self):
        """test that unique bindparams change their name upon clone()
        to prevent conflicts"""

        s = select([t1], t1.c.col1 == bindparam(None, unique=True)).alias()
        s2 = CloningVisitor().traverse(s).alias()
        s3 = select([s], s.c.col2 == s2.c.col2)

        self.assert_compile(
            s3,
            "SELECT anon_1.col1, anon_1.col2, anon_1.col3 FROM "
            "(SELECT table1.col1 AS col1, table1.col2 AS col2, "
            "table1.col3 AS col3 FROM table1 WHERE table1.col1 = :param_1) "
            "AS anon_1, "
            "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 "
            "AS col3 FROM table1 WHERE table1.col1 = :param_2) AS anon_2 "
            "WHERE anon_1.col2 = anon_2.col2",
        )

        s = select([t1], t1.c.col1 == 4).alias()
        s2 = CloningVisitor().traverse(s).alias()
        s3 = select([s], s.c.col2 == s2.c.col2)
        self.assert_compile(
            s3,
            "SELECT anon_1.col1, anon_1.col2, anon_1.col3 FROM "
            "(SELECT table1.col1 AS col1, table1.col2 AS col2, "
            "table1.col3 AS col3 FROM table1 WHERE table1.col1 = :col1_1) "
            "AS anon_1, "
            "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 "
            "AS col3 FROM table1 WHERE table1.col1 = :col1_2) AS anon_2 "
            "WHERE anon_1.col2 = anon_2.col2",
        )

    def test_extract(self):
        s = select([extract("foo", t1.c.col1).label("col1")])
        self.assert_compile(
            s, "SELECT EXTRACT(foo FROM table1.col1) AS col1 FROM table1"
        )

        s2 = CloningVisitor().traverse(s).alias()
        s3 = select([s2.c.col1])
        self.assert_compile(
            s, "SELECT EXTRACT(foo FROM table1.col1) AS col1 FROM table1"
        )
        self.assert_compile(
            s3,
            "SELECT anon_1.col1 FROM (SELECT EXTRACT(foo FROM "
            "table1.col1) AS col1 FROM table1) AS anon_1",
        )

    @testing.emits_warning(".*replaced by another column with the same key")
    def test_alias(self):
        subq = t2.select().alias("subq")
        s = select(
            [t1.c.col1, subq.c.col1],
            from_obj=[t1, subq, t1.join(subq, t1.c.col1 == subq.c.col2)],
        )
        orig = str(s)
        s2 = CloningVisitor().traverse(s)
        assert orig == str(s) == str(s2)

        s4 = CloningVisitor().traverse(s2)
        assert orig == str(s) == str(s2) == str(s4)

        s3 = sql_util.ClauseAdapter(table("foo")).traverse(s)
        assert orig == str(s) == str(s3)

        s4 = sql_util.ClauseAdapter(table("foo")).traverse(s3)
        assert orig == str(s) == str(s3) == str(s4)

        subq = subq.alias("subq")
        s = select(
            [t1.c.col1, subq.c.col1],
            from_obj=[t1, subq, t1.join(subq, t1.c.col1 == subq.c.col2)],
        )
        s5 = CloningVisitor().traverse(s)
        assert orig == str(s) == str(s5)

    def test_correlated_select(self):
        s = select(
            [literal_column("*")], t1.c.col1 == t2.c.col1, from_obj=[t1, t2]
        ).correlate(t2)

        class Vis(CloningVisitor):
            def visit_select(self, stmt):
                stmt.append_whereclause(t1.c.col2 == 7)

        self.assert_compile(
            select([t2]).where(t2.c.col1 == Vis().traverse(s)),
            "SELECT table2.col1, table2.col2, table2.col3 "
            "FROM table2 WHERE table2.col1 = "
            "(SELECT * FROM table1 WHERE table1.col1 = table2.col1 "
            "AND table1.col2 = :col2_1)",
        )

    def test_this_thing(self):
        s = select([t1]).where(t1.c.col1 == "foo").alias()
        s2 = select([s.c.col1])

        self.assert_compile(
            s2,
            "SELECT anon_1.col1 FROM (SELECT "
            "table1.col1 AS col1, table1.col2 AS col2, "
            "table1.col3 AS col3 FROM table1 WHERE "
            "table1.col1 = :col1_1) AS anon_1",
        )
        t1a = t1.alias()
        s2 = sql_util.ClauseAdapter(t1a).traverse(s2)
        self.assert_compile(
            s2,
            "SELECT anon_1.col1 FROM (SELECT "
            "table1_1.col1 AS col1, table1_1.col2 AS "
            "col2, table1_1.col3 AS col3 FROM table1 "
            "AS table1_1 WHERE table1_1.col1 = "
            ":col1_1) AS anon_1",
        )

    def test_select_fromtwice_one(self):
        t1a = t1.alias()

        s = select([1], t1.c.col1 == t1a.c.col1, from_obj=t1a).correlate(t1a)
        s = select([t1]).where(t1.c.col1 == s)
        self.assert_compile(
            s,
            "SELECT table1.col1, table1.col2, table1.col3 FROM table1 "
            "WHERE table1.col1 = "
            "(SELECT 1 FROM table1, table1 AS table1_1 "
            "WHERE table1.col1 = table1_1.col1)",
        )
        s = CloningVisitor().traverse(s)
        self.assert_compile(
            s,
            "SELECT table1.col1, table1.col2, table1.col3 FROM table1 "
            "WHERE table1.col1 = "
            "(SELECT 1 FROM table1, table1 AS table1_1 "
            "WHERE table1.col1 = table1_1.col1)",
        )

    def test_select_fromtwice_two(self):
        s = select([t1]).where(t1.c.col1 == "foo").alias()

        s2 = select([1], t1.c.col1 == s.c.col1, from_obj=s).correlate(t1)
        s3 = select([t1]).where(t1.c.col1 == s2)
        self.assert_compile(
            s3,
            "SELECT table1.col1, table1.col2, table1.col3 "
            "FROM table1 WHERE table1.col1 = "
            "(SELECT 1 FROM "
            "(SELECT table1.col1 AS col1, table1.col2 AS col2, "
            "table1.col3 AS col3 FROM table1 "
            "WHERE table1.col1 = :col1_1) "
            "AS anon_1 WHERE table1.col1 = anon_1.col1)",
        )

        s4 = ReplacingCloningVisitor().traverse(s3)
        self.assert_compile(
            s4,
            "SELECT table1.col1, table1.col2, table1.col3 "
            "FROM table1 WHERE table1.col1 = "
            "(SELECT 1 FROM "
            "(SELECT table1.col1 AS col1, table1.col2 AS col2, "
            "table1.col3 AS col3 FROM table1 "
            "WHERE table1.col1 = :col1_1) "
            "AS anon_1 WHERE table1.col1 = anon_1.col1)",
        )
786
787
788class ColumnAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
789    __dialect__ = "default"
790
791    @classmethod
792    def setup_class(cls):
793        global t1, t2
794        t1 = table(
795            "table1",
796            column("col1"),
797            column("col2"),
798            column("col3"),
799            column("col4"),
800        )
801        t2 = table("table2", column("col1"), column("col2"), column("col3"))
802
803    def test_traverse_memoizes_w_columns(self):
804        t1a = t1.alias()
805        adapter = sql_util.ColumnAdapter(t1a, anonymize_labels=True)
806
807        expr = select([t1a.c.col1]).label("x")
808        expr_adapted = adapter.traverse(expr)
809        is_not(expr, expr_adapted)
810        is_(adapter.columns[expr], expr_adapted)
811
812    def test_traverse_memoizes_w_itself(self):
813        t1a = t1.alias()
814        adapter = sql_util.ColumnAdapter(t1a, anonymize_labels=True)
815
816        expr = select([t1a.c.col1]).label("x")
817        expr_adapted = adapter.traverse(expr)
818        is_not(expr, expr_adapted)
819        is_(adapter.traverse(expr), expr_adapted)
820
821    def test_columns_memoizes_w_itself(self):
822        t1a = t1.alias()
823        adapter = sql_util.ColumnAdapter(t1a, anonymize_labels=True)
824
825        expr = select([t1a.c.col1]).label("x")
826        expr_adapted = adapter.columns[expr]
827        is_not(expr, expr_adapted)
828        is_(adapter.columns[expr], expr_adapted)
829
    def test_wrapping_fallthrough(self):
        """contrast wrap() and chain() when one adapter can't handle a column.

        a1 adapts to the alias t1a.  a2 adapts to s1, a labeled select of
        only t1a.col1 and t2a.col1.  a3/a4 wrap the two adapters in both
        orders, and a5 chains them.
        """
        t1a = t1.alias(name="t1a")
        t2a = t2.alias(name="t2a")
        a1 = sql_util.ColumnAdapter(t1a)

        s1 = select([t1a.c.col1, t2a.c.col1]).apply_labels().alias()
        a2 = sql_util.ColumnAdapter(s1)
        a3 = a2.wrap(a1)
        a4 = a1.wrap(a2)
        a5 = a1.chain(a2)

        # t1.c.col1 -> s1.c.t1a_col1

        # adapted by a2
        is_(a3.columns[t1.c.col1], s1.c.t1a_col1)
        is_(a4.columns[t1.c.col1], s1.c.t1a_col1)

        # chaining can't fall through because a1 grabs it
        # first
        is_(a5.columns[t1.c.col1], t1a.c.col1)

        # t2.c.col1 -> s1.c.t2a_col1

        # adapted by a2
        is_(a3.columns[t2.c.col1], s1.c.t2a_col1)
        is_(a4.columns[t2.c.col1], s1.c.t2a_col1)
        # chaining, t2 hits s1
        is_(a5.columns[t2.c.col1], s1.c.t2a_col1)

        # t1.c.col2 -> t1a.c.col2

        # fallthrough to a1 (s1 does not export col2)
        is_(a3.columns[t1.c.col2], t1a.c.col2)
        is_(a4.columns[t1.c.col2], t1a.c.col2)

        # chaining hits a1
        is_(a5.columns[t1.c.col2], t1a.c.col2)

        # t2.c.col2 -> t2.c.col2

        # fallthrough to no adaption: neither t1a nor s1 provides it,
        # so the wrapped adapters return the column unchanged
        is_(a3.columns[t2.c.col2], t2.c.col2)
        is_(a4.columns[t2.c.col2], t2.c.col2)
873
    def test_wrapping_ordering(self):
        """illustrate an example where order of wrappers matters.

        This test illustrates both the ordering being significant
        as well as a scenario where multiple translations are needed
        (e.g. wrapping vs. chaining).

        """

        stmt = select([t1.c.col1, t2.c.col1]).apply_labels()

        # stmt2 selects t2 directly as well as an alias of stmt, so t2's
        # columns reach stmt2 by two different routes
        sa = stmt.alias()
        stmt2 = select([t2, sa])

        a1 = sql_util.ColumnAdapter(stmt)
        a2 = sql_util.ColumnAdapter(stmt2)

        # x.wrap(y): per the expectations below, x adapts first and y is
        # then applied to x's result
        a2_to_a1 = a2.wrap(a1)
        a1_to_a2 = a1.wrap(a2)

        # when stmt2 and stmt represent the same column
        # in different contexts, order of wrapping matters

        # t2.c.col1 via a2 is stmt2.c.col1; then ignored by a1
        is_(a2_to_a1.columns[t2.c.col1], stmt2.c.col1)
        # t2.c.col1 via a1 is stmt.c.table2_col1; a2 then
        # sends this to stmt2.c.table2_col1
        is_(a1_to_a2.columns[t2.c.col1], stmt2.c.table2_col1)

        # for mutually exclusive columns, order doesn't matter
        is_(a2_to_a1.columns[t1.c.col1], stmt2.c.table1_col1)
        is_(a1_to_a2.columns[t1.c.col1], stmt2.c.table1_col1)
        is_(a2_to_a1.columns[t2.c.col2], stmt2.c.col2)
907
908    def test_wrapping_multiple(self):
909        """illustrate that wrapping runs both adapters"""
910
911        t1a = t1.alias(name="t1a")
912        t2a = t2.alias(name="t2a")
913        a1 = sql_util.ColumnAdapter(t1a)
914        a2 = sql_util.ColumnAdapter(t2a)
915        a3 = a2.wrap(a1)
916
917        stmt = select([t1.c.col1, t2.c.col2])
918
919        self.assert_compile(
920            a3.traverse(stmt),
921            "SELECT t1a.col1, t2a.col2 FROM table1 AS t1a, table2 AS t2a",
922        )
923
924        # chaining does too because these adapters don't share any
925        # columns
926        a4 = a2.chain(a1)
927        self.assert_compile(
928            a4.traverse(stmt),
929            "SELECT t1a.col1, t2a.col2 FROM table1 AS t1a, table2 AS t2a",
930        )
931
932    def test_wrapping_inclusions(self):
933        """test wrapping and inclusion rules together,
934        taking into account multiple objects with equivalent hash identity."""
935
936        t1a = t1.alias(name="t1a")
937        t2a = t2.alias(name="t2a")
938        a1 = sql_util.ColumnAdapter(
939            t1a, include_fn=lambda col: "a1" in col._annotations
940        )
941
942        s1 = select([t1a, t2a]).apply_labels().alias()
943        a2 = sql_util.ColumnAdapter(
944            s1, include_fn=lambda col: "a2" in col._annotations
945        )
946        a3 = a2.wrap(a1)
947
948        c1a1 = t1.c.col1._annotate(dict(a1=True))
949        c1a2 = t1.c.col1._annotate(dict(a2=True))
950        c1aa = t1.c.col1._annotate(dict(a1=True, a2=True))
951
952        c2a1 = t2.c.col1._annotate(dict(a1=True))
953        c2a2 = t2.c.col1._annotate(dict(a2=True))
954        c2aa = t2.c.col1._annotate(dict(a1=True, a2=True))
955
956        is_(a3.columns[c1a1], t1a.c.col1)
957        is_(a3.columns[c1a2], s1.c.t1a_col1)
958        is_(a3.columns[c1aa], s1.c.t1a_col1)
959
960        # not covered by a1, accepted by a2
961        is_(a3.columns[c2aa], s1.c.t2a_col1)
962
963        # not covered by a1, accepted by a2
964        is_(a3.columns[c2a2], s1.c.t2a_col1)
965        # not covered by a1, rejected by a2
966        is_(a3.columns[c2a1], c2a1)
967
968
class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):

    """test ClauseAdapter's replacement of tables and columns with an
    alias, join or other selectable, including correlation,
    include/exclude rules, equivalents and label anonymization."""

    __dialect__ = "default"

    @classmethod
    def setup_class(cls):
        global t1, t2
        t1 = table("table1", column("col1"), column("col2"), column("col3"))
        t2 = table("table2", column("col1"), column("col2"), column("col3"))

    def test_correlation_on_clone(self):
        t1alias = t1.alias("t1alias")
        t2alias = t2.alias("t2alias")
        vis = sql_util.ClauseAdapter(t1alias)

        s = select(
            [literal_column("*")], from_obj=[t1alias, t2alias]
        ).as_scalar()
        assert t2alias in s._froms
        assert t1alias in s._froms

        self.assert_compile(
            select([literal_column("*")], t2alias.c.col1 == s),
            "SELECT * FROM table2 AS t2alias WHERE "
            "t2alias.col1 = (SELECT * FROM table1 AS "
            "t1alias)",
        )
        s = vis.traverse(s)

        # not present because it's been cloned
        assert t2alias not in s._froms
        # present because the adapter placed it there
        assert t1alias in s._froms

        # correlate list on "s" needs to take into account the full
        # _cloned_set for each element in _froms when correlating

        self.assert_compile(
            select([literal_column("*")], t2alias.c.col1 == s),
            "SELECT * FROM table2 AS t2alias WHERE "
            "t2alias.col1 = (SELECT * FROM table1 AS "
            "t1alias)",
        )
        s = (
            select([literal_column("*")], from_obj=[t1alias, t2alias])
            .correlate(t2alias)
            .as_scalar()
        )
        self.assert_compile(
            select([literal_column("*")], t2alias.c.col1 == s),
            "SELECT * FROM table2 AS t2alias WHERE "
            "t2alias.col1 = (SELECT * FROM table1 AS "
            "t1alias)",
        )
        s = vis.traverse(s)
        self.assert_compile(
            select([literal_column("*")], t2alias.c.col1 == s),
            "SELECT * FROM table2 AS t2alias WHERE "
            "t2alias.col1 = (SELECT * FROM table1 AS "
            "t1alias)",
        )
        s = CloningVisitor().traverse(s)
        self.assert_compile(
            select([literal_column("*")], t2alias.c.col1 == s),
            "SELECT * FROM table2 AS t2alias WHERE "
            "t2alias.col1 = (SELECT * FROM table1 AS "
            "t1alias)",
        )

        s = (
            select([literal_column("*")])
            .where(t1.c.col1 == t2.c.col1)
            .as_scalar()
        )
        self.assert_compile(
            select([t1.c.col1, s]),
            "SELECT table1.col1, (SELECT * FROM table2 "
            "WHERE table1.col1 = table2.col1) AS "
            "anon_1 FROM table1",
        )
        vis = sql_util.ClauseAdapter(t1alias)
        s = vis.traverse(s)
        self.assert_compile(
            select([t1alias.c.col1, s]),
            "SELECT t1alias.col1, (SELECT * FROM "
            "table2 WHERE t1alias.col1 = table2.col1) "
            "AS anon_1 FROM table1 AS t1alias",
        )
        s = CloningVisitor().traverse(s)
        self.assert_compile(
            select([t1alias.c.col1, s]),
            "SELECT t1alias.col1, (SELECT * FROM "
            "table2 WHERE t1alias.col1 = table2.col1) "
            "AS anon_1 FROM table1 AS t1alias",
        )
        s = (
            select([literal_column("*")])
            .where(t1.c.col1 == t2.c.col1)
            .correlate(t1)
            .as_scalar()
        )
        self.assert_compile(
            select([t1.c.col1, s]),
            "SELECT table1.col1, (SELECT * FROM table2 "
            "WHERE table1.col1 = table2.col1) AS "
            "anon_1 FROM table1",
        )
        vis = sql_util.ClauseAdapter(t1alias)
        s = vis.traverse(s)
        self.assert_compile(
            select([t1alias.c.col1, s]),
            "SELECT t1alias.col1, (SELECT * FROM "
            "table2 WHERE t1alias.col1 = table2.col1) "
            "AS anon_1 FROM table1 AS t1alias",
        )
        s = CloningVisitor().traverse(s)
        self.assert_compile(
            select([t1alias.c.col1, s]),
            "SELECT t1alias.col1, (SELECT * FROM "
            "table2 WHERE t1alias.col1 = table2.col1) "
            "AS anon_1 FROM table1 AS t1alias",
        )

    def test_correlate_except_on_clone(self):
        # test [ticket:4537]'s issue

        t1alias = t1.alias("t1alias")
        j = t1.join(t1alias, t1.c.col1 == t1alias.c.col2)

        vis = sql_util.ClauseAdapter(j)

        # "control" subquery - uses correlate which has worked w/ adaption
        # for a long time
        control_s = (
            select([t2.c.col1])
            .where(t2.c.col1 == t1.c.col1)
            .correlate(t2)
            .as_scalar()
        )

        # test subquery - given only t1 and t2 in the enclosing selectable,
        # will do the same thing as the "control" query since the correlation
        # works out the same
        s = (
            select([t2.c.col1])
            .where(t2.c.col1 == t1.c.col1)
            .correlate_except(t1)
            .as_scalar()
        )

        # use both subqueries in statements
        control_stmt = select([control_s, t1.c.col1, t2.c.col1]).select_from(
            t1.join(t2, t1.c.col1 == t2.c.col1)
        )

        stmt = select([s, t1.c.col1, t2.c.col1]).select_from(
            t1.join(t2, t1.c.col1 == t2.c.col1)
        )
        # they are the same
        self.assert_compile(
            control_stmt,
            "SELECT "
            "(SELECT table2.col1 FROM table1 "
            "WHERE table2.col1 = table1.col1) AS anon_1, "
            "table1.col1, table2.col1 "
            "FROM table1 "
            "JOIN table2 ON table1.col1 = table2.col1",
        )
        self.assert_compile(
            stmt,
            "SELECT "
            "(SELECT table2.col1 FROM table1 "
            "WHERE table2.col1 = table1.col1) AS anon_1, "
            "table1.col1, table2.col1 "
            "FROM table1 "
            "JOIN table2 ON table1.col1 = table2.col1",
        )

        # now test against the adaption of "t1" into "t1 JOIN t1alias".
        # note in the control case, we aren't actually testing that
        # Select is processing the "correlate" list during the adaption
        # since we aren't adapting the "correlate"
        self.assert_compile(
            vis.traverse(control_stmt),
            "SELECT "
            "(SELECT table2.col1 FROM "
            "table1 JOIN table1 AS t1alias ON table1.col1 = t1alias.col2 "
            "WHERE table2.col1 = table1.col1) AS anon_1, "
            "table1.col1, table2.col1 "
            "FROM table1 JOIN table1 AS t1alias ON table1.col1 = t1alias.col2 "
            "JOIN table2 ON table1.col1 = table2.col1",
        )

        # but here, correlate_except() does have the thing we're adapting
        # so whatever is in there has to be expanded out to include
        # the adaptation target, in this case "t1 JOIN t1alias".
        self.assert_compile(
            vis.traverse(stmt),
            "SELECT "
            "(SELECT table2.col1 FROM "
            "table1 JOIN table1 AS t1alias ON table1.col1 = t1alias.col2 "
            "WHERE table2.col1 = table1.col1) AS anon_1, "
            "table1.col1, table2.col1 "
            "FROM table1 JOIN table1 AS t1alias ON table1.col1 = t1alias.col2 "
            "JOIN table2 ON table1.col1 = table2.col1",
        )

    # NOTE(review): fails_on_everything_except() with no backends listed
    # appears to mark this test as an expected failure everywhere;
    # presumably it records a known limitation rather than working
    # behavior - confirm.
    @testing.fails_on_everything_except()
    def test_joins_dont_adapt(self):
        # adapting to a join, i.e. ClauseAdapter(t1.join(t2)), doesn't
        # make much sense. ClauseAdapter doesn't make any changes if
        # it's against a straight join.

        users = table("users", column("id"))
        addresses = table("addresses", column("id"), column("user_id"))

        ualias = users.alias()

        s = select(
            [func.count(addresses.c.id)], users.c.id == addresses.c.user_id
        ).correlate(users)
        s = sql_util.ClauseAdapter(ualias).traverse(s)

        j1 = addresses.join(ualias, addresses.c.user_id == ualias.c.id)

        self.assert_compile(
            sql_util.ClauseAdapter(j1).traverse(s),
            "SELECT count(addresses.id) AS count_1 "
            "FROM addresses WHERE users_1.id = "
            "addresses.user_id",
        )

    def test_table_to_alias_1(self):
        """adapting a labeled function moves its FROM object to the alias."""
        t1alias = t1.alias("t1alias")

        vis = sql_util.ClauseAdapter(t1alias)
        ff = vis.traverse(func.count(t1.c.col1).label("foo"))
        assert list(_from_objects(ff)) == [t1alias]

    def test_table_to_alias_2(self):
        """an explicit FROM of the table is replaced by the alias."""
        t1alias = t1.alias("t1alias")
        vis = sql_util.ClauseAdapter(t1alias)
        self.assert_compile(
            vis.traverse(select([literal_column("*")], from_obj=[t1])),
            "SELECT * FROM table1 AS t1alias",
        )

    def test_table_to_alias_3(self):
        """only the adapted table is replaced; table2 is left alone."""
        t1alias = t1.alias("t1alias")
        vis = sql_util.ClauseAdapter(t1alias)
        self.assert_compile(
            vis.traverse(
                select([literal_column("*")], t1.c.col1 == t2.c.col2)
            ),
            "SELECT * FROM table1 AS t1alias, table2 "
            "WHERE t1alias.col1 = table2.col2",
        )

    def test_table_to_alias_4(self):
        """same as the previous test, with an explicit FROM list."""
        t1alias = t1.alias("t1alias")
        vis = sql_util.ClauseAdapter(t1alias)
        self.assert_compile(
            vis.traverse(
                select(
                    [literal_column("*")],
                    t1.c.col1 == t2.c.col2,
                    from_obj=[t1, t2],
                )
            ),
            "SELECT * FROM table1 AS t1alias, table2 "
            "WHERE t1alias.col1 = table2.col2",
        )

    def test_table_to_alias_5(self):
        """a table named in correlate() is adapted and stays correlated."""
        t1alias = t1.alias("t1alias")
        vis = sql_util.ClauseAdapter(t1alias)
        self.assert_compile(
            select([t1alias, t2]).where(
                t1alias.c.col1
                == vis.traverse(
                    select(
                        [literal_column("*")],
                        t1.c.col1 == t2.c.col2,
                        from_obj=[t1, t2],
                    ).correlate(t1)
                )
            ),
            "SELECT t1alias.col1, t1alias.col2, t1alias.col3, "
            "table2.col1, table2.col2, table2.col3 "
            "FROM table1 AS t1alias, table2 WHERE t1alias.col1 = "
            "(SELECT * FROM table2 WHERE t1alias.col1 = table2.col2)",
        )

    def test_table_to_alias_6(self):
        """correlation to a table that isn't adapted is preserved."""
        t1alias = t1.alias("t1alias")
        vis = sql_util.ClauseAdapter(t1alias)
        self.assert_compile(
            select([t1alias, t2]).where(
                t1alias.c.col1
                == vis.traverse(
                    select(
                        [literal_column("*")],
                        t1.c.col1 == t2.c.col2,
                        from_obj=[t1, t2],
                    ).correlate(t2)
                )
            ),
            "SELECT t1alias.col1, t1alias.col2, t1alias.col3, "
            "table2.col1, table2.col2, table2.col3 "
            "FROM table1 AS t1alias, table2 "
            "WHERE t1alias.col1 = "
            "(SELECT * FROM table1 AS t1alias "
            "WHERE t1alias.col1 = table2.col2)",
        )

    def test_table_to_alias_7(self):
        """columns inside a searched CASE expression are adapted."""
        t1alias = t1.alias("t1alias")
        vis = sql_util.ClauseAdapter(t1alias)
        self.assert_compile(
            vis.traverse(case([(t1.c.col1 == 5, t1.c.col2)], else_=t1.c.col1)),
            "CASE WHEN (t1alias.col1 = :col1_1) THEN "
            "t1alias.col2 ELSE t1alias.col1 END",
        )

    def test_table_to_alias_8(self):
        """the value expression of a simple CASE is adapted as well."""
        t1alias = t1.alias("t1alias")
        vis = sql_util.ClauseAdapter(t1alias)
        self.assert_compile(
            vis.traverse(
                case([(5, t1.c.col2)], value=t1.c.col1, else_=t1.c.col1)
            ),
            "CASE t1alias.col1 WHEN :param_1 THEN "
            "t1alias.col2 ELSE t1alias.col1 END",
        )

    def test_table_to_alias_9(self):
        """baseline: an aliased select renders as a subquery."""
        s = select([literal_column("*")], from_obj=[t1]).alias("foo")
        self.assert_compile(
            s.select(), "SELECT foo.* FROM (SELECT * FROM table1) AS foo"
        )

    def test_table_to_alias_10(self):
        """adaption reaches into the select wrapped by an alias."""
        s = select([literal_column("*")], from_obj=[t1]).alias("foo")
        t1alias = t1.alias("t1alias")
        vis = sql_util.ClauseAdapter(t1alias)
        self.assert_compile(
            vis.traverse(s.select()),
            "SELECT foo.* FROM (SELECT * FROM table1 AS t1alias) AS foo",
        )

    def test_table_to_alias_11(self):
        """adapting a select of an alias leaves the original untouched."""
        s = select([literal_column("*")], from_obj=[t1]).alias("foo")
        t1alias = t1.alias("t1alias")
        vis = sql_util.ClauseAdapter(t1alias)
        vis.traverse(s.select())

        # the traversal works on copies, so the original alias still
        # selects from the plain table
        self.assert_compile(
            s.select(), "SELECT foo.* FROM (SELECT * FROM table1) AS foo"
        )

    def test_table_to_alias_12(self):
        """an adapted label renders against the alias when selected."""
        t1alias = t1.alias("t1alias")
        vis = sql_util.ClauseAdapter(t1alias)
        ff = vis.traverse(func.count(t1.c.col1).label("foo"))
        self.assert_compile(
            select([ff]),
            "SELECT count(t1alias.col1) AS foo FROM table1 AS t1alias",
        )
        assert list(_from_objects(ff)) == [t1alias]

    def test_table_to_alias_select_of_label(self):
        """adapting a whole select containing a label adapts the label."""
        t1alias = t1.alias("t1alias")
        vis = sql_util.ClauseAdapter(t1alias)
        self.assert_compile(
            vis.traverse(select([func.count(t1.c.col1).label("foo")])),
            "SELECT count(t1alias.col1) AS foo FROM table1 AS t1alias",
        )

    def test_table_to_alias_13(self):
        """chained adapters adapt both tables."""
        t1alias = t1.alias("t1alias")
        vis = sql_util.ClauseAdapter(t1alias)
        t2alias = t2.alias("t2alias")
        vis.chain(sql_util.ClauseAdapter(t2alias))
        self.assert_compile(
            vis.traverse(
                select([literal_column("*")], t1.c.col1 == t2.c.col2)
            ),
            "SELECT * FROM table1 AS t1alias, table2 "
            "AS t2alias WHERE t1alias.col1 = "
            "t2alias.col2",
        )

    def test_table_to_alias_14(self):
        """chained adapters, with an explicit FROM list."""
        t1alias = t1.alias("t1alias")
        vis = sql_util.ClauseAdapter(t1alias)
        t2alias = t2.alias("t2alias")
        vis.chain(sql_util.ClauseAdapter(t2alias))
        self.assert_compile(
            vis.traverse(
                select(
                    [literal_column("*")],
                    t1.c.col1 == t2.c.col2,
                    from_obj=[t1, t2],
                )
            ),
            "SELECT * FROM table1 AS t1alias, table2 "
            "AS t2alias WHERE t1alias.col1 = "
            "t2alias.col2",
        )

    def test_table_to_alias_15(self):
        """chained adapters keep correlate() of the first table."""
        t1alias = t1.alias("t1alias")
        vis = sql_util.ClauseAdapter(t1alias)
        t2alias = t2.alias("t2alias")
        vis.chain(sql_util.ClauseAdapter(t2alias))
        self.assert_compile(
            select([t1alias, t2alias]).where(
                t1alias.c.col1
                == vis.traverse(
                    select(
                        [literal_column("*")],
                        t1.c.col1 == t2.c.col2,
                        from_obj=[t1, t2],
                    ).correlate(t1)
                )
            ),
            "SELECT t1alias.col1, t1alias.col2, t1alias.col3, "
            "t2alias.col1, t2alias.col2, t2alias.col3 "
            "FROM table1 AS t1alias, table2 AS t2alias "
            "WHERE t1alias.col1 = "
            "(SELECT * FROM table2 AS t2alias "
            "WHERE t1alias.col1 = t2alias.col2)",
        )

    def test_table_to_alias_16(self):
        """chained adapters keep correlate() of the second table."""
        t1alias = t1.alias("t1alias")
        vis = sql_util.ClauseAdapter(t1alias)
        t2alias = t2.alias("t2alias")
        vis.chain(sql_util.ClauseAdapter(t2alias))
        self.assert_compile(
            t2alias.select().where(
                t2alias.c.col2
                == vis.traverse(
                    select(
                        [literal_column("*")],
                        t1.c.col1 == t2.c.col2,
                        from_obj=[t1, t2],
                    ).correlate(t2)
                )
            ),
            "SELECT t2alias.col1, t2alias.col2, t2alias.col3 "
            "FROM table2 AS t2alias WHERE t2alias.col2 = "
            "(SELECT * FROM table1 AS t1alias WHERE "
            "t1alias.col1 = t2alias.col2)",
        )

    def test_include_exclude(self):
        m = MetaData()
        a = Table(
            "a",
            m,
            Column("id", Integer, primary_key=True),
            Column(
                "xxx_id",
                Integer,
                ForeignKey("a.id", name="adf", use_alter=True),
            ),
        )

        e = a.c.id == a.c.xxx_id
        eq_(str(e), "a.id = a.xxx_id")
        b = a.alias()

        # only a.c.id passes include_fn, so a.c.xxx_id stays unadapted
        e = sql_util.ClauseAdapter(
            b,
            include_fn=lambda x: x in {a.c.id},
            equivalents={a.c.id: {a.c.id}},
        ).traverse(e)

        eq_(str(e), "a_1.id = a.xxx_id")

    def test_recursive_equivalents(self):
        m = MetaData()
        a = Table("a", m, Column("x", Integer), Column("y", Integer))
        b = Table("b", m, Column("x", Integer), Column("y", Integer))
        c = Table("c", m, Column("x", Integer), Column("y", Integer))

        # force a recursion overflow, by linking a.c.x<->c.c.x, and
        # asking for a nonexistent col.  corresponding_column should prevent
        # endless depth.
        adapt = sql_util.ClauseAdapter(
            b, equivalents={a.c.x: {c.c.x}, c.c.x: {a.c.x}}
        )
        assert adapt._corresponding_column(a.c.x, False) is None

    def test_multilevel_equivalents(self):
        m = MetaData()
        a = Table("a", m, Column("x", Integer), Column("y", Integer))
        b = Table("b", m, Column("x", Integer), Column("y", Integer))
        c = Table("c", m, Column("x", Integer), Column("y", Integer))

        alias = select([a]).select_from(a.join(b, a.c.x == b.c.x)).alias()

        # two levels of indirection from c.x->b.x->a.x, requires recursive
        # corresponding_column call
        adapt = sql_util.ClauseAdapter(
            alias, equivalents={b.c.x: {a.c.x}, c.c.x: {b.c.x}}
        )
        assert adapt._corresponding_column(a.c.x, False) is alias.c.x
        assert adapt._corresponding_column(c.c.x, False) is alias.c.x

    def test_join_to_alias(self):
        metadata = MetaData()
        a = Table("a", metadata, Column("id", Integer, primary_key=True))
        b = Table(
            "b",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("aid", Integer, ForeignKey("a.id")),
        )
        c = Table(
            "c",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("bid", Integer, ForeignKey("b.id")),
        )

        d = Table(
            "d",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("aid", Integer, ForeignKey("a.id")),
        )

        j1 = a.outerjoin(b)
        j2 = select([j1], use_labels=True)

        j3 = c.join(j2, j2.c.b_id == c.c.bid)

        j4 = j3.outerjoin(d)
        self.assert_compile(
            j4,
            "c JOIN (SELECT a.id AS a_id, b.id AS "
            "b_id, b.aid AS b_aid FROM a LEFT OUTER "
            "JOIN b ON a.id = b.aid) ON b_id = c.bid "
            "LEFT OUTER JOIN d ON a_id = d.aid",
        )
        j5 = j3.alias("foo")
        j6 = sql_util.ClauseAdapter(j5).copy_and_process([j4])[0]

        # this statement takes c join(a join b), wraps it inside an
        # aliased "select * from c join(a join b) AS foo". the outermost
        # right side "left outer join d" stays the same, except "d"
        # joins against foo.a_id instead of plain "a_id"

        self.assert_compile(
            j6,
            "(SELECT c.id AS c_id, c.bid AS c_bid, "
            "a_id AS a_id, b_id AS b_id, b_aid AS "
            "b_aid FROM c JOIN (SELECT a.id AS a_id, "
            "b.id AS b_id, b.aid AS b_aid FROM a LEFT "
            "OUTER JOIN b ON a.id = b.aid) ON b_id = "
            "c.bid) AS foo LEFT OUTER JOIN d ON "
            "foo.a_id = d.aid",
        )

    def test_derived_from(self):
        assert select([t1]).is_derived_from(t1)
        assert not select([t2]).is_derived_from(t1)
        assert not t1.is_derived_from(select([t1]))
        assert t1.alias().is_derived_from(t1)

        s1 = select([t1, t2]).alias("foo")
        s2 = select([s1]).limit(5).offset(10).alias()
        assert s2.is_derived_from(s1)
        s2 = s2._clone()
        assert s2.is_derived_from(s1)

    def test_aliasedselect_to_aliasedselect_straight(self):

        # original issue from ticket #904

        s1 = select([t1]).alias("foo")
        s2 = select([s1]).limit(5).offset(10).alias()
        self.assert_compile(
            sql_util.ClauseAdapter(s2).traverse(s1),
            "SELECT foo.col1, foo.col2, foo.col3 FROM "
            "(SELECT table1.col1 AS col1, table1.col2 "
            "AS col2, table1.col3 AS col3 FROM table1) "
            "AS foo LIMIT :param_1 OFFSET :param_2",
            {"param_1": 5, "param_2": 10},
        )

    def test_aliasedselect_to_aliasedselect_join(self):
        s1 = select([t1]).alias("foo")
        s2 = select([s1]).limit(5).offset(10).alias()
        j = s1.outerjoin(t2, s1.c.col1 == t2.c.col1)
        self.assert_compile(
            sql_util.ClauseAdapter(s2).traverse(j).select(),
            "SELECT anon_1.col1, anon_1.col2, "
            "anon_1.col3, table2.col1, table2.col2, "
            "table2.col3 FROM (SELECT foo.col1 AS "
            "col1, foo.col2 AS col2, foo.col3 AS col3 "
            "FROM (SELECT table1.col1 AS col1, "
            "table1.col2 AS col2, table1.col3 AS col3 "
            "FROM table1) AS foo LIMIT :param_1 OFFSET "
            ":param_2) AS anon_1 LEFT OUTER JOIN "
            "table2 ON anon_1.col1 = table2.col1",
            {"param_1": 5, "param_2": 10},
        )

    def test_aliasedselect_to_aliasedselect_join_nested_table(self):
        s1 = select([t1]).alias("foo")
        s2 = select([s1]).limit(5).offset(10).alias()
        talias = t1.alias("bar")

        assert not s2.is_derived_from(talias)
        j = s1.outerjoin(talias, s1.c.col1 == talias.c.col1)

        self.assert_compile(
            sql_util.ClauseAdapter(s2).traverse(j).select(),
            "SELECT anon_1.col1, anon_1.col2, "
            "anon_1.col3, bar.col1, bar.col2, bar.col3 "
            "FROM (SELECT foo.col1 AS col1, foo.col2 "
            "AS col2, foo.col3 AS col3 FROM (SELECT "
            "table1.col1 AS col1, table1.col2 AS col2, "
            "table1.col3 AS col3 FROM table1) AS foo "
            "LIMIT :param_1 OFFSET :param_2) AS anon_1 "
            "LEFT OUTER JOIN table1 AS bar ON "
            "anon_1.col1 = bar.col1",
            {"param_1": 5, "param_2": 10},
        )

    def test_functions(self):
        self.assert_compile(
            sql_util.ClauseAdapter(t1.alias()).traverse(func.count(t1.c.col1)),
            "count(table1_1.col1)",
        )
        s = select([func.count(t1.c.col1)])
        self.assert_compile(
            sql_util.ClauseAdapter(t1.alias()).traverse(s),
            "SELECT count(table1_1.col1) AS count_1 "
            "FROM table1 AS table1_1",
        )

    def test_recursive(self):
        metadata = MetaData()
        a = Table("a", metadata, Column("id", Integer, primary_key=True))
        b = Table(
            "b",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("aid", Integer, ForeignKey("a.id")),
        )
        c = Table(
            "c",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("bid", Integer, ForeignKey("b.id")),
        )

        d = Table(
            "d",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("aid", Integer, ForeignKey("a.id")),
        )

        u = union(
            a.join(b).select().apply_labels(),
            a.join(d).select().apply_labels(),
        ).alias()

        self.assert_compile(
            sql_util.ClauseAdapter(u).traverse(
                select([c.c.bid]).where(c.c.bid == u.c.b_aid)
            ),
            "SELECT c.bid "
            "FROM c, (SELECT a.id AS a_id, b.id AS b_id, b.aid AS b_aid "
            "FROM a JOIN b ON a.id = b.aid UNION SELECT a.id AS a_id, d.id "
            "AS d_id, d.aid AS d_aid "
            "FROM a JOIN d ON a.id = d.aid) AS anon_1 "
            "WHERE c.bid = anon_1.b_aid",
        )

    def test_label_anonymize_one(self):
        t1a = t1.alias()
        adapter = sql_util.ClauseAdapter(t1a, anonymize_labels=True)

        expr = select([t1.c.col2]).where(t1.c.col3 == 5).label("expr")
        expr_adapted = adapter.traverse(expr)

        stmt = select([expr, expr_adapted]).order_by(expr, expr_adapted)
        self.assert_compile(
            stmt,
            "SELECT "
            "(SELECT table1.col2 FROM table1 WHERE table1.col3 = :col3_1) "
            "AS expr, "
            "(SELECT table1_1.col2 FROM table1 AS table1_1 "
            "WHERE table1_1.col3 = :col3_2) AS anon_1 "
            "ORDER BY expr, anon_1",
        )

    def test_label_anonymize_two(self):
        t1a = t1.alias()
        adapter = sql_util.ClauseAdapter(t1a, anonymize_labels=True)

        expr = select([t1.c.col2]).where(t1.c.col3 == 5).label(None)
        expr_adapted = adapter.traverse(expr)

        stmt = select([expr, expr_adapted]).order_by(expr, expr_adapted)
        self.assert_compile(
            stmt,
            "SELECT "
            "(SELECT table1.col2 FROM table1 WHERE table1.col3 = :col3_1) "
            "AS anon_1, "
            "(SELECT table1_1.col2 FROM table1 AS table1_1 "
            "WHERE table1_1.col3 = :col3_2) AS anon_2 "
            "ORDER BY anon_1, anon_2",
        )

    def test_label_anonymize_three(self):
        t1a = t1.alias()
        adapter = sql_util.ColumnAdapter(
            t1a, anonymize_labels=True, allow_label_resolve=False
        )

        expr = select([t1.c.col2]).where(t1.c.col3 == 5).label(None)
        l1 = expr
        is_(l1._order_by_label_element, l1)
        eq_(l1._allow_label_resolve, True)

        expr_adapted = adapter.traverse(expr)
        l2 = expr_adapted
        is_(l2._order_by_label_element, l2)
        eq_(l2._allow_label_resolve, False)

        # a second traversal of the same expression gives the same result
        l3 = adapter.traverse(expr)
        is_(l3._order_by_label_element, l3)
        eq_(l3._allow_label_resolve, False)
1693
1694
class SpliceJoinsTest(fixtures.TestBase, AssertsCompiledSQL):
    """Compiled-SQL checks for ``sql_util.splice_joins``, which re-roots an
    existing join chain onto another selectable."""

    __dialect__ = "default"

    @classmethod
    def setup_class(cls):
        global table1, table2, table3, table4

        # four lightweight tables sharing the same three columns
        table1, table2, table3, table4 = (
            table(tname, column("col1"), column("col2"), column("col3"))
            for tname in ("table1", "table2", "table3", "table4")
        )

    def test_splice(self):
        left, middle = table1, table2
        left_alias, middle_alias = table1.alias(), table2.alias()

        chain = left.join(middle, left.c.col1 == middle.c.col1)
        chain = chain.join(left_alias, middle.c.col1 == left_alias.c.col1)
        chain = chain.join(middle_alias, middle_alias.c.col1 == left.c.col1)

        filtered = select([left]).where(left.c.col2 < 5).alias()
        self.assert_compile(
            sql_util.splice_joins(filtered, chain),
            "(SELECT table1.col1 AS col1, table1.col2 AS col2, "
            "table1.col3 AS col3 FROM table1 "
            "WHERE table1.col2 < :col2_1) AS anon_1 "
            "JOIN table2 ON anon_1.col1 = table2.col1 "
            "JOIN table1 AS table1_1 ON table2.col1 = table1_1.col1 "
            "JOIN table2 AS table2_1 ON table2_1.col1 = anon_1.col1",
        )

    def test_stop_on(self):
        first_join = table1.join(table2, table1.c.col1 == table2.c.col1)
        full_join = first_join.join(table3, table2.c.col1 == table3.c.col1)
        subquery = select([table1]).select_from(first_join).alias()

        subquery_sql = (
            "(SELECT table1.col1 AS col1, table1.col2 AS col2, "
            "table1.col3 AS col3 FROM table1 "
            "JOIN table2 ON table1.col1 = table2.col1) AS anon_1 "
        )

        # no stopping point: every join in the chain is spliced on
        self.assert_compile(
            sql_util.splice_joins(subquery, full_join),
            subquery_sql
            + "JOIN table2 ON anon_1.col1 = table2.col1 "
            + "JOIN table3 ON table2.col1 = table3.col1",
        )
        # first_join given as the stopping point: only the join beyond it
        # is spliced on
        self.assert_compile(
            sql_util.splice_joins(subquery, full_join, first_join),
            subquery_sql + "JOIN table3 ON table2.col1 = table3.col1",
        )

    def test_splice_2(self):
        t2_alias = table2.alias()
        t3_alias = table3.alias()
        two_hop = table1.join(
            t2_alias, table1.c.col1 == t2_alias.c.col1
        ).join(t3_alias, t2_alias.c.col2 == t3_alias.c.col2)

        t4_alias = table4.alias()
        one_hop = table1.join(t4_alias, table1.c.col3 == t4_alias.c.col3)

        two_hop_sql = (
            "table1 JOIN table2 AS table2_1 ON table1.col1 = table2_1.col1 "
            "JOIN table3 AS table3_1 ON table2_1.col2 = table3_1.col2"
        )
        one_hop_target = "table4 AS table4_1 ON table1.col3 = table4_1.col3"

        self.assert_compile(
            sql_util.splice_joins(table1, two_hop), two_hop_sql
        )
        self.assert_compile(
            sql_util.splice_joins(table1, one_hop),
            "table1 JOIN " + one_hop_target,
        )

        # splicing is composable: apply the two-hop chain, then the one-hop
        spliced_once = sql_util.splice_joins(table1, two_hop)
        self.assert_compile(
            sql_util.splice_joins(spliced_once, one_hop),
            two_hop_sql + " JOIN " + one_hop_target,
        )
1778
1779
class SelectTest(fixtures.TestBase, AssertsCompiledSQL):

    """tests the generative capability of Select"""

    __dialect__ = "default"

    # compiled SQL of an unmodified ``t1.select()``; the generative tests
    # re-check it after each call to prove the original was not mutated
    _plain_t1_sql = "SELECT table1.col1, table1.col2, table1.col3 FROM table1"

    @classmethod
    def setup_class(cls):
        global t1, t2
        t1 = table("table1", column("col1"), column("col2"), column("col3"))
        t2 = table("table2", column("col1"), column("col2"), column("col3"))

    def test_columns(self):
        """Select.column() returns a copy; the original's columns are
        neither shared with the copy nor altered."""
        s = t1.select()
        self.assert_compile(s, self._plain_t1_sql)
        select_copy = s.column(column("yyy"))
        self.assert_compile(
            select_copy,
            "SELECT table1.col1, table1.col2, table1.col3, yyy FROM table1",
        )
        is_not(s.columns, select_copy.columns)
        is_not(s._columns, select_copy._columns)
        is_not(s._raw_columns, select_copy._raw_columns)
        self.assert_compile(s, self._plain_t1_sql)

    def test_froms(self):
        """Select.select_from() returns a copy with its own FROM list."""
        s = t1.select()
        self.assert_compile(s, self._plain_t1_sql)
        select_copy = s.select_from(t2)
        self.assert_compile(
            select_copy,
            "SELECT table1.col1, table1.col2, "
            "table1.col3 FROM table1, table2",
        )
        # use the shared helper (as test_columns does) for a clearer
        # failure message than a bare assert
        is_not(s._froms, select_copy._froms)
        self.assert_compile(s, self._plain_t1_sql)

    def test_prefixes(self):
        """Select.prefix_with() returns a copy; the original is unchanged."""
        s = t1.select()
        self.assert_compile(s, self._plain_t1_sql)
        select_copy = s.prefix_with("FOOBER")
        self.assert_compile(
            select_copy,
            "SELECT FOOBER table1.col1, table1.col2, "
            "table1.col3 FROM table1",
        )
        self.assert_compile(s, self._plain_t1_sql)

    def test_execution_options(self):
        s = select().execution_options(foo="bar")
        s2 = s.execution_options(bar="baz")
        s3 = s.execution_options(foo="not bar")
        # The original select should not be modified.
        eq_(s.get_execution_options(), dict(foo="bar"))
        # s2 should have its execution_options based on s, though.
        eq_(s2.get_execution_options(), dict(foo="bar", bar="baz"))
        eq_(s3.get_execution_options(), dict(foo="not bar"))

    def test_invalid_options(self):
        """compiled_cache and isolation_level are rejected as statement-level
        execution options with ArgumentError."""
        assert_raises(
            exc.ArgumentError, select().execution_options, compiled_cache={}
        )

        assert_raises(
            exc.ArgumentError,
            select().execution_options,
            isolation_level="READ_COMMITTED",
        )

    # this feature not available yet
    def _NOTYET_test_execution_options_in_kwargs(self):
        s = select(execution_options=dict(foo="bar"))
        s2 = s.execution_options(bar="baz")
        # The original select should not be modified.
        eq_(s._execution_options, dict(foo="bar"))
        # s2 should have its execution_options based on s, though.
        eq_(s2._execution_options, dict(foo="bar", bar="baz"))

    # this feature not available yet
    def _NOTYET_test_execution_options_in_text(self):
        s = text("select 42", execution_options=dict(foo="bar"))
        eq_(s._execution_options, dict(foo="bar"))
1874
1875
class ValuesBaseTest(fixtures.TestBase, AssertsCompiledSQL):

    """Tests the generative capability of Insert, Update"""

    __dialect__ = "default"

    # FIXME: consolidate coverage from elsewhere here and expand

    # everything after "INSERT [prefixes] " in a compiled ``t1.insert()``
    _insert_tail = (
        "INTO table1 (col1, col2, col3) VALUES (:col1, :col2, :col3)"
    )

    @classmethod
    def setup_class(cls):
        global t1, t2
        t1 = table("table1", column("col1"), column("col2"), column("col3"))
        t2 = table("table2", column("col1"), column("col2"), column("col3"))

    def test_prefixes(self):
        """prefix_with() returns a copy and appends to any existing
        prefixes, whether from an earlier call or the ``prefixes`` arg;
        the statement it was called on is left unchanged."""
        tail = self._insert_tail

        i = t1.insert()
        self.assert_compile(i, "INSERT " + tail)

        gen = i.prefix_with("foober")
        self.assert_compile(gen, "INSERT foober " + tail)

        # the original statement is not modified by prefix_with()
        self.assert_compile(i, "INSERT " + tail)

        i2 = t1.insert(prefixes=["squiznart"])
        self.assert_compile(i2, "INSERT squiznart " + tail)

        gen2 = i2.prefix_with("quux")
        self.assert_compile(gen2, "INSERT squiznart quux " + tail)

        # likewise, a statement built with prefixes= is left as-is
        self.assert_compile(i2, "INSERT squiznart " + tail)

    def test_add_kwarg(self):
        i = t1.insert()
        eq_(i.parameters, None)
        i = i.values(col1=5)
        eq_(i.parameters, {"col1": 5})
        i = i.values(col2=7)
        eq_(i.parameters, {"col1": 5, "col2": 7})
        is_(i._has_multi_parameters, False)

    def test_via_tuple_single(self):
        """A single positional tuple is mapped onto the table's columns
        in order and counts as a single parameter set."""
        i = t1.insert()
        eq_(i.parameters, None)
        i = i.values((5, 6, 7))
        eq_(i.parameters, {"col1": 5, "col2": 6, "col3": 7})
        is_(i._has_multi_parameters, False)

    def test_kw_and_dict_simultaneously_single(self):
        i = t1.insert()
        i = i.values({"col1": 5}, col2=7)
        eq_(i.parameters, {"col1": 5, "col2": 7})
        is_(i._has_multi_parameters, False)

    def test_via_tuple_multi(self):
        """A list of tuples becomes a list of per-row dicts and is
        treated as multiple parameter sets."""
        i = t1.insert()
        eq_(i.parameters, None)
        i = i.values([(5, 6, 7), (8, 9, 10)])
        eq_(
            i.parameters,
            [
                {"col1": 5, "col2": 6, "col3": 7},
                {"col1": 8, "col2": 9, "col3": 10},
            ],
        )
        is_(i._has_multi_parameters, True)

    def test_inline_values_single(self):
        i = t1.insert(values={"col1": 5})
        eq_(i.parameters, {"col1": 5})
        is_(i._has_multi_parameters, False)

    def test_inline_values_multi(self):
        i = t1.insert(values=[{"col1": 5}, {"col1": 6}])
        eq_(i.parameters, [{"col1": 5}, {"col1": 6}])
        is_(i._has_multi_parameters, True)

    def test_add_dictionary(self):
        i = t1.insert()
        eq_(i.parameters, None)
        i = i.values({"col1": 5})
        eq_(i.parameters, {"col1": 5})
        is_(i._has_multi_parameters, False)

        i = i.values({"col1": 6})
        # note replaces
        eq_(i.parameters, {"col1": 6})
        is_(i._has_multi_parameters, False)

        i = i.values({"col2": 7})
        eq_(i.parameters, {"col1": 6, "col2": 7})
        is_(i._has_multi_parameters, False)

    def test_add_kwarg_disallowed_multi(self):
        i = t1.insert()
        i = i.values([{"col1": 5}, {"col1": 7}])
        assert_raises_message(
            exc.InvalidRequestError,
            "This construct already has multiple parameter sets.",
            i.values,
            col2=7,
        )

    def test_cant_mix_single_multi_formats_dict_to_list(self):
        i = t1.insert().values(col1=5)
        assert_raises_message(
            exc.ArgumentError,
            "Can't mix single-values and multiple values "
            "formats in one statement",
            i.values,
            [{"col1": 6}],
        )

    def test_cant_mix_single_multi_formats_list_to_dict(self):
        i = t1.insert().values([{"col1": 6}])
        assert_raises_message(
            exc.ArgumentError,
            "Can't mix single-values and multiple values "
            "formats in one statement",
            i.values,
            {"col1": 5},
        )

    def test_erroneous_multi_args_dicts(self):
        i = t1.insert()
        assert_raises_message(
            exc.ArgumentError,
            "Only a single dictionary/tuple or list of "
            "dictionaries/tuples is accepted positionally.",
            i.values,
            {"col1": 5},
            {"col1": 7},
        )

    def test_erroneous_multi_args_tuples(self):
        i = t1.insert()
        assert_raises_message(
            exc.ArgumentError,
            "Only a single dictionary/tuple or list of "
            "dictionaries/tuples is accepted positionally.",
            i.values,
            (5, 6, 7),
            (8, 9, 10),
        )

    def test_erroneous_multi_args_plus_kw(self):
        i = t1.insert()
        assert_raises_message(
            exc.ArgumentError,
            "Can't pass kwargs and multiple parameter sets simultaneously",
            i.values,
            [{"col1": 5}],
            col2=7,
        )

    def test_update_no_support_multi_values(self):
        u = t1.update()
        assert_raises_message(
            exc.InvalidRequestError,
            "This construct does not support multiple parameter sets.",
            u.values,
            [{"col1": 5}, {"col1": 7}],
        )

    def test_update_no_support_multi_constructor(self):
        assert_raises_message(
            exc.InvalidRequestError,
            "This construct does not support multiple parameter sets.",
            t1.update,
            values=[{"col1": 5}, {"col1": 7}],
        )
2061