1from sqlalchemy.dialects import mssql
2from sqlalchemy.engine import default
3from sqlalchemy.exc import CompileError
4from sqlalchemy.sql import and_
5from sqlalchemy.sql import bindparam
6from sqlalchemy.sql import column
7from sqlalchemy.sql import exists
8from sqlalchemy.sql import func
9from sqlalchemy.sql import literal
10from sqlalchemy.sql import select
11from sqlalchemy.sql import table
12from sqlalchemy.sql.elements import quoted_name
13from sqlalchemy.sql.visitors import cloned_traverse
14from sqlalchemy.testing import assert_raises_message
15from sqlalchemy.testing import AssertsCompiledSQL
16from sqlalchemy.testing import eq_
17from sqlalchemy.testing import fixtures
18
19
20class CTETest(fixtures.TestBase, AssertsCompiledSQL):
21
22    __dialect__ = "default_enhanced"
23
24    def test_nonrecursive(self):
25        orders = table(
26            "orders",
27            column("region"),
28            column("amount"),
29            column("product"),
30            column("quantity"),
31        )
32
33        regional_sales = (
34            select(
35                [
36                    orders.c.region,
37                    func.sum(orders.c.amount).label("total_sales"),
38                ]
39            )
40            .group_by(orders.c.region)
41            .cte("regional_sales")
42        )
43
44        top_regions = (
45            select([regional_sales.c.region])
46            .where(
47                regional_sales.c.total_sales
48                > select([func.sum(regional_sales.c.total_sales) / 10])
49            )
50            .cte("top_regions")
51        )
52
53        s = (
54            select(
55                [
56                    orders.c.region,
57                    orders.c.product,
58                    func.sum(orders.c.quantity).label("product_units"),
59                    func.sum(orders.c.amount).label("product_sales"),
60                ]
61            )
62            .where(orders.c.region.in_(select([top_regions.c.region])))
63            .group_by(orders.c.region, orders.c.product)
64        )
65
66        # needs to render regional_sales first as top_regions
67        # refers to it
68        self.assert_compile(
69            s,
70            "WITH regional_sales AS (SELECT orders.region AS region, "
71            "sum(orders.amount) AS total_sales FROM orders "
72            "GROUP BY orders.region), "
73            "top_regions AS (SELECT "
74            "regional_sales.region AS region FROM regional_sales "
75            "WHERE regional_sales.total_sales > "
76            "(SELECT sum(regional_sales.total_sales) / :sum_1 AS "
77            "anon_1 FROM regional_sales)) "
78            "SELECT orders.region, orders.product, "
79            "sum(orders.quantity) AS product_units, "
80            "sum(orders.amount) AS product_sales "
81            "FROM orders WHERE orders.region "
82            "IN (SELECT top_regions.region FROM top_regions) "
83            "GROUP BY orders.region, orders.product",
84        )
85
86    def test_recursive(self):
87        parts = table(
88            "parts", column("part"), column("sub_part"), column("quantity")
89        )
90
91        included_parts = (
92            select([parts.c.sub_part, parts.c.part, parts.c.quantity])
93            .where(parts.c.part == "our part")
94            .cte(recursive=True)
95        )
96
97        incl_alias = included_parts.alias()
98        parts_alias = parts.alias()
99        included_parts = included_parts.union(
100            select(
101                [
102                    parts_alias.c.sub_part,
103                    parts_alias.c.part,
104                    parts_alias.c.quantity,
105                ]
106            ).where(parts_alias.c.part == incl_alias.c.sub_part)
107        )
108
109        s = (
110            select(
111                [
112                    included_parts.c.sub_part,
113                    func.sum(included_parts.c.quantity).label(
114                        "total_quantity"
115                    ),
116                ]
117            )
118            .select_from(
119                included_parts.join(
120                    parts, included_parts.c.part == parts.c.part
121                )
122            )
123            .group_by(included_parts.c.sub_part)
124        )
125        self.assert_compile(
126            s,
127            "WITH RECURSIVE anon_1(sub_part, part, quantity) "
128            "AS (SELECT parts.sub_part AS sub_part, parts.part "
129            "AS part, parts.quantity AS quantity FROM parts "
130            "WHERE parts.part = :part_1 UNION "
131            "SELECT parts_1.sub_part AS sub_part, "
132            "parts_1.part AS part, parts_1.quantity "
133            "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 "
134            "WHERE parts_1.part = anon_2.sub_part) "
135            "SELECT anon_1.sub_part, "
136            "sum(anon_1.quantity) AS total_quantity FROM anon_1 "
137            "JOIN parts ON anon_1.part = parts.part "
138            "GROUP BY anon_1.sub_part",
139        )
140
141        # quick check that the "WITH RECURSIVE" varies per
142        # dialect
143        self.assert_compile(
144            s,
145            "WITH anon_1(sub_part, part, quantity) "
146            "AS (SELECT parts.sub_part AS sub_part, parts.part "
147            "AS part, parts.quantity AS quantity FROM parts "
148            "WHERE parts.part = :part_1 UNION "
149            "SELECT parts_1.sub_part AS sub_part, "
150            "parts_1.part AS part, parts_1.quantity "
151            "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 "
152            "WHERE parts_1.part = anon_2.sub_part) "
153            "SELECT anon_1.sub_part, "
154            "sum(anon_1.quantity) AS total_quantity FROM anon_1 "
155            "JOIN parts ON anon_1.part = parts.part "
156            "GROUP BY anon_1.sub_part",
157            dialect=mssql.dialect(),
158        )
159
160    def test_recursive_inner_cte_unioned_to_alias(self):
161        parts = table(
162            "parts", column("part"), column("sub_part"), column("quantity")
163        )
164
165        included_parts = (
166            select([parts.c.sub_part, parts.c.part, parts.c.quantity])
167            .where(parts.c.part == "our part")
168            .cte(recursive=True)
169        )
170
171        incl_alias = included_parts.alias("incl")
172        parts_alias = parts.alias()
173        included_parts = incl_alias.union(
174            select(
175                [
176                    parts_alias.c.sub_part,
177                    parts_alias.c.part,
178                    parts_alias.c.quantity,
179                ]
180            ).where(parts_alias.c.part == incl_alias.c.sub_part)
181        )
182
183        s = (
184            select(
185                [
186                    included_parts.c.sub_part,
187                    func.sum(included_parts.c.quantity).label(
188                        "total_quantity"
189                    ),
190                ]
191            )
192            .select_from(
193                included_parts.join(
194                    parts, included_parts.c.part == parts.c.part
195                )
196            )
197            .group_by(included_parts.c.sub_part)
198        )
199        self.assert_compile(
200            s,
201            "WITH RECURSIVE incl(sub_part, part, quantity) "
202            "AS (SELECT parts.sub_part AS sub_part, parts.part "
203            "AS part, parts.quantity AS quantity FROM parts "
204            "WHERE parts.part = :part_1 UNION "
205            "SELECT parts_1.sub_part AS sub_part, "
206            "parts_1.part AS part, parts_1.quantity "
207            "AS quantity FROM parts AS parts_1, incl "
208            "WHERE parts_1.part = incl.sub_part) "
209            "SELECT incl.sub_part, "
210            "sum(incl.quantity) AS total_quantity FROM incl "
211            "JOIN parts ON incl.part = parts.part "
212            "GROUP BY incl.sub_part",
213        )
214
215    def test_recursive_union_no_alias_one(self):
216        s1 = select([literal(0).label("x")])
217        cte = s1.cte(name="cte", recursive=True)
218        cte = cte.union_all(select([cte.c.x + 1]).where(cte.c.x < 10))
219        s2 = select([cte])
220        self.assert_compile(
221            s2,
222            "WITH RECURSIVE cte(x) AS "
223            "(SELECT :param_1 AS x UNION ALL "
224            "SELECT cte.x + :x_1 AS anon_1 "
225            "FROM cte WHERE cte.x < :x_2) "
226            "SELECT cte.x FROM cte",
227        )
228
229    def test_recursive_union_alias_one(self):
230        s1 = select([literal(0).label("x")])
231        cte = s1.cte(name="cte", recursive=True)
232        cte = cte.union_all(select([cte.c.x + 1]).where(cte.c.x < 10)).alias(
233            "cr1"
234        )
235        s2 = select([cte])
236        self.assert_compile(
237            s2,
238            "WITH RECURSIVE cte(x) AS "
239            "(SELECT :param_1 AS x UNION ALL "
240            "SELECT cte.x + :x_1 AS anon_1 "
241            "FROM cte WHERE cte.x < :x_2) "
242            "SELECT cr1.x FROM cte AS cr1",
243        )
244
245    def test_recursive_union_no_alias_two(self):
246        """
247
248        pg's example::
249
250            WITH RECURSIVE t(n) AS (
251                VALUES (1)
252              UNION ALL
253                SELECT n+1 FROM t WHERE n < 100
254            )
255            SELECT sum(n) FROM t;
256
257        """
258
259        # I know, this is the PG VALUES keyword,
260        # we're cheating here.  also yes we need the SELECT,
261        # sorry PG.
262        t = select([func.values(1).label("n")]).cte("t", recursive=True)
263        t = t.union_all(select([t.c.n + 1]).where(t.c.n < 100))
264        s = select([func.sum(t.c.n)])
265        self.assert_compile(
266            s,
267            "WITH RECURSIVE t(n) AS "
268            "(SELECT values(:values_1) AS n "
269            "UNION ALL SELECT t.n + :n_1 AS anon_1 "
270            "FROM t "
271            "WHERE t.n < :n_2) "
272            "SELECT sum(t.n) AS sum_1 FROM t",
273        )
274
275    def test_recursive_union_alias_two(self):
276        """
277
278        """
279
280        # I know, this is the PG VALUES keyword,
281        # we're cheating here.  also yes we need the SELECT,
282        # sorry PG.
283        t = select([func.values(1).label("n")]).cte("t", recursive=True)
284        t = t.union_all(select([t.c.n + 1]).where(t.c.n < 100)).alias("ta")
285        s = select([func.sum(t.c.n)])
286        self.assert_compile(
287            s,
288            "WITH RECURSIVE t(n) AS "
289            "(SELECT values(:values_1) AS n "
290            "UNION ALL SELECT t.n + :n_1 AS anon_1 "
291            "FROM t "
292            "WHERE t.n < :n_2) "
293            "SELECT sum(ta.n) AS sum_1 FROM t AS ta",
294        )
295
296    def test_recursive_union_no_alias_three(self):
297        # like test one, but let's refer to the CTE
298        # in a sibling CTE.
299
300        s1 = select([literal(0).label("x")])
301        cte = s1.cte(name="cte", recursive=True)
302
303        # can't do it here...
304        # bar = select([cte]).cte('bar')
305        cte = cte.union_all(select([cte.c.x + 1]).where(cte.c.x < 10))
306        bar = select([cte]).cte("bar")
307
308        s2 = select([cte, bar])
309        self.assert_compile(
310            s2,
311            "WITH RECURSIVE cte(x) AS "
312            "(SELECT :param_1 AS x UNION ALL "
313            "SELECT cte.x + :x_1 AS anon_1 "
314            "FROM cte WHERE cte.x < :x_2), "
315            "bar AS (SELECT cte.x AS x FROM cte) "
316            "SELECT cte.x, bar.x FROM cte, bar",
317        )
318
319    def test_recursive_union_alias_three(self):
320        # like test one, but let's refer to the CTE
321        # in a sibling CTE.
322
323        s1 = select([literal(0).label("x")])
324        cte = s1.cte(name="cte", recursive=True)
325
326        # can't do it here...
327        # bar = select([cte]).cte('bar')
328        cte = cte.union_all(select([cte.c.x + 1]).where(cte.c.x < 10)).alias(
329            "cs1"
330        )
331        bar = select([cte]).cte("bar").alias("cs2")
332
333        s2 = select([cte, bar])
334        self.assert_compile(
335            s2,
336            "WITH RECURSIVE cte(x) AS "
337            "(SELECT :param_1 AS x UNION ALL "
338            "SELECT cte.x + :x_1 AS anon_1 "
339            "FROM cte WHERE cte.x < :x_2), "
340            "bar AS (SELECT cs1.x AS x FROM cte AS cs1) "
341            "SELECT cs1.x, cs2.x FROM cte AS cs1, bar AS cs2",
342        )
343
    def test_recursive_union_no_alias_four(self):
        """Two versions of the same named recursive CTE in one statement.

        ``bar`` is built from the anchor-only ``cte`` *before* the union,
        so it refers to a different ``cte`` object than the unioned one.
        The rendered WITH clause depends on which version is reached first.
        """
        # like test one and three, but let's refer to the
        # previous version of "cte".  here we test
        # how the compiler resolves multiple instances
        # of "cte".

        s1 = select([literal(0).label("x")])
        cte = s1.cte(name="cte", recursive=True)

        # order matters: bar captures the anchor-only "cte" before
        # the name is rebound to the unioned version
        bar = select([cte]).cte("bar")
        cte = cte.union_all(select([cte.c.x + 1]).where(cte.c.x < 10))

        # outer cte rendered first, then bar, which
        # includes "inner" cte
        s2 = select([cte, bar])
        self.assert_compile(
            s2,
            "WITH RECURSIVE cte(x) AS "
            "(SELECT :param_1 AS x UNION ALL "
            "SELECT cte.x + :x_1 AS anon_1 "
            "FROM cte WHERE cte.x < :x_2), "
            "bar AS (SELECT cte.x AS x FROM cte) "
            "SELECT cte.x, bar.x FROM cte, bar",
        )

        # bar rendered, only includes "inner" cte,
        # "outer" cte isn't present
        s2 = select([bar])
        self.assert_compile(
            s2,
            "WITH RECURSIVE cte(x) AS "
            "(SELECT :param_1 AS x), "
            "bar AS (SELECT cte.x AS x FROM cte) "
            "SELECT bar.x FROM bar",
        )

        # bar rendered, but then the "outer"
        # cte is rendered.
        s2 = select([bar, cte])
        self.assert_compile(
            s2,
            "WITH RECURSIVE bar AS (SELECT cte.x AS x FROM cte), "
            "cte(x) AS "
            "(SELECT :param_1 AS x UNION ALL "
            "SELECT cte.x + :x_1 AS anon_1 "
            "FROM cte WHERE cte.x < :x_2) "
            "SELECT bar.x, cte.x FROM bar, cte",
        )
392
    def test_recursive_union_alias_four(self):
        """Aliased variant: two versions of one recursive CTE together.

        ``bar`` (aliased ``cs1``) is built from the anchor-only ``cte``
        before the union; the unioned ``cte`` is aliased ``cs2``.  The
        rendered WITH clause depends on which version is reached first.
        """
        # like test one and three, but let's refer to the
        # previous version of "cte".  here we test
        # how the compiler resolves multiple instances
        # of "cte".

        s1 = select([literal(0).label("x")])
        cte = s1.cte(name="cte", recursive=True)

        # order matters: bar captures the anchor-only "cte" before
        # the name is rebound to the unioned, aliased version
        bar = select([cte]).cte("bar").alias("cs1")
        cte = cte.union_all(select([cte.c.x + 1]).where(cte.c.x < 10)).alias(
            "cs2"
        )

        # outer cte rendered first, then bar, which
        # includes "inner" cte
        s2 = select([cte, bar])
        self.assert_compile(
            s2,
            "WITH RECURSIVE cte(x) AS "
            "(SELECT :param_1 AS x UNION ALL "
            "SELECT cte.x + :x_1 AS anon_1 "
            "FROM cte WHERE cte.x < :x_2), "
            "bar AS (SELECT cte.x AS x FROM cte) "
            "SELECT cs2.x, cs1.x FROM cte AS cs2, bar AS cs1",
        )

        # bar rendered, only includes "inner" cte,
        # "outer" cte isn't present
        s2 = select([bar])
        self.assert_compile(
            s2,
            "WITH RECURSIVE cte(x) AS "
            "(SELECT :param_1 AS x), "
            "bar AS (SELECT cte.x AS x FROM cte) "
            "SELECT cs1.x FROM bar AS cs1",
        )

        # bar rendered, but then the "outer"
        # cte is rendered.
        s2 = select([bar, cte])
        self.assert_compile(
            s2,
            "WITH RECURSIVE bar AS (SELECT cte.x AS x FROM cte), "
            "cte(x) AS "
            "(SELECT :param_1 AS x UNION ALL "
            "SELECT cte.x + :x_1 AS anon_1 "
            "FROM cte WHERE cte.x < :x_2) "
            "SELECT cs1.x, cs2.x FROM bar AS cs1, cte AS cs2",
        )
443
444    def test_conflicting_names(self):
445        """test a flat out name conflict."""
446
447        s1 = select([1])
448        c1 = s1.cte(name="cte1", recursive=True)
449        s2 = select([1])
450        c2 = s2.cte(name="cte1", recursive=True)
451
452        s = select([c1, c2])
453        assert_raises_message(
454            CompileError,
455            "Multiple, unrelated CTEs found " "with the same name: 'cte1'",
456            s.compile,
457        )
458
459    def test_union(self):
460        orders = table("orders", column("region"), column("amount"))
461
462        regional_sales = select([orders.c.region, orders.c.amount]).cte(
463            "regional_sales"
464        )
465
466        s = select([regional_sales.c.region]).where(
467            regional_sales.c.amount > 500
468        )
469
470        self.assert_compile(
471            s,
472            "WITH regional_sales AS "
473            "(SELECT orders.region AS region, "
474            "orders.amount AS amount FROM orders) "
475            "SELECT regional_sales.region "
476            "FROM regional_sales WHERE "
477            "regional_sales.amount > :amount_1",
478        )
479
480        s = s.union_all(
481            select([regional_sales.c.region]).where(
482                regional_sales.c.amount < 300
483            )
484        )
485        self.assert_compile(
486            s,
487            "WITH regional_sales AS "
488            "(SELECT orders.region AS region, "
489            "orders.amount AS amount FROM orders) "
490            "SELECT regional_sales.region FROM regional_sales "
491            "WHERE regional_sales.amount > :amount_1 "
492            "UNION ALL SELECT regional_sales.region "
493            "FROM regional_sales WHERE "
494            "regional_sales.amount < :amount_2",
495        )
496
497    def test_union_cte_aliases(self):
498        orders = table("orders", column("region"), column("amount"))
499
500        regional_sales = (
501            select([orders.c.region, orders.c.amount])
502            .cte("regional_sales")
503            .alias("rs")
504        )
505
506        s = select([regional_sales.c.region]).where(
507            regional_sales.c.amount > 500
508        )
509
510        self.assert_compile(
511            s,
512            "WITH regional_sales AS "
513            "(SELECT orders.region AS region, "
514            "orders.amount AS amount FROM orders) "
515            "SELECT rs.region "
516            "FROM regional_sales AS rs WHERE "
517            "rs.amount > :amount_1",
518        )
519
520        s = s.union_all(
521            select([regional_sales.c.region]).where(
522                regional_sales.c.amount < 300
523            )
524        )
525        self.assert_compile(
526            s,
527            "WITH regional_sales AS "
528            "(SELECT orders.region AS region, "
529            "orders.amount AS amount FROM orders) "
530            "SELECT rs.region FROM regional_sales AS rs "
531            "WHERE rs.amount > :amount_1 "
532            "UNION ALL SELECT rs.region "
533            "FROM regional_sales AS rs WHERE "
534            "rs.amount < :amount_2",
535        )
536
537        cloned = cloned_traverse(s, {}, {})
538        self.assert_compile(
539            cloned,
540            "WITH regional_sales AS "
541            "(SELECT orders.region AS region, "
542            "orders.amount AS amount FROM orders) "
543            "SELECT rs.region FROM regional_sales AS rs "
544            "WHERE rs.amount > :amount_1 "
545            "UNION ALL SELECT rs.region "
546            "FROM regional_sales AS rs WHERE "
547            "rs.amount < :amount_2",
548        )
549
550    def test_cloned_alias(self):
551        entity = table(
552            "entity", column("id"), column("employer_id"), column("name")
553        )
554        tag = table("tag", column("tag"), column("entity_id"))
555
556        tags = (
557            select([tag.c.entity_id, func.array_agg(tag.c.tag).label("tags")])
558            .group_by(tag.c.entity_id)
559            .cte("unaliased_tags")
560        )
561
562        entity_tags = tags.alias(name="entity_tags")
563        employer_tags = tags.alias(name="employer_tags")
564
565        q = (
566            select([entity.c.name])
567            .select_from(
568                entity.outerjoin(
569                    entity_tags, tags.c.entity_id == entity.c.id
570                ).outerjoin(
571                    employer_tags, tags.c.entity_id == entity.c.employer_id
572                )
573            )
574            .where(entity_tags.c.tags.op("@>")(bindparam("tags")))
575            .where(employer_tags.c.tags.op("@>")(bindparam("tags")))
576        )
577
578        self.assert_compile(
579            q,
580            "WITH unaliased_tags AS "
581            "(SELECT tag.entity_id AS entity_id, array_agg(tag.tag) AS tags "
582            "FROM tag GROUP BY tag.entity_id)"
583            " SELECT entity.name "
584            "FROM entity "
585            "LEFT OUTER JOIN unaliased_tags AS entity_tags ON "
586            "unaliased_tags.entity_id = entity.id "
587            "LEFT OUTER JOIN unaliased_tags AS employer_tags ON "
588            "unaliased_tags.entity_id = entity.employer_id "
589            "WHERE (entity_tags.tags @> :tags) AND "
590            "(employer_tags.tags @> :tags)",
591        )
592
593        cloned = q.params(tags=["tag1", "tag2"])
594        self.assert_compile(
595            cloned,
596            "WITH unaliased_tags AS "
597            "(SELECT tag.entity_id AS entity_id, array_agg(tag.tag) AS tags "
598            "FROM tag GROUP BY tag.entity_id)"
599            " SELECT entity.name "
600            "FROM entity "
601            "LEFT OUTER JOIN unaliased_tags AS entity_tags ON "
602            "unaliased_tags.entity_id = entity.id "
603            "LEFT OUTER JOIN unaliased_tags AS employer_tags ON "
604            "unaliased_tags.entity_id = entity.employer_id "
605            "WHERE (entity_tags.tags @> :tags) AND "
606            "(employer_tags.tags @> :tags)",
607        )
608
609    def test_reserved_quote(self):
610        orders = table("orders", column("order"))
611        s = select([orders.c.order]).cte("regional_sales", recursive=True)
612        s = select([s.c.order])
613        self.assert_compile(
614            s,
615            'WITH RECURSIVE regional_sales("order") AS '
616            '(SELECT orders."order" AS "order" '
617            "FROM orders)"
618            ' SELECT regional_sales."order" '
619            "FROM regional_sales",
620        )
621
622    def test_multi_subq_quote(self):
623        cte = select([literal(1).label("id")]).cte(name="CTE")
624
625        s1 = select([cte.c.id]).alias()
626        s2 = select([cte.c.id]).alias()
627
628        s = select([s1, s2])
629        self.assert_compile(
630            s,
631            'WITH "CTE" AS (SELECT :param_1 AS id) '
632            "SELECT anon_1.id, anon_2.id FROM "
633            '(SELECT "CTE".id AS id FROM "CTE") AS anon_1, '
634            '(SELECT "CTE".id AS id FROM "CTE") AS anon_2',
635        )
636
637    def test_multi_subq_alias(self):
638        cte = select([literal(1).label("id")]).cte(name="cte1").alias("aa")
639
640        s1 = select([cte.c.id]).alias()
641        s2 = select([cte.c.id]).alias()
642
643        s = select([s1, s2])
644        self.assert_compile(
645            s,
646            "WITH cte1 AS (SELECT :param_1 AS id) "
647            "SELECT anon_1.id, anon_2.id FROM "
648            "(SELECT aa.id AS id FROM cte1 AS aa) AS anon_1, "
649            "(SELECT aa.id AS id FROM cte1 AS aa) AS anon_2",
650        )
651
652    def test_cte_refers_to_aliased_cte_twice(self):
653        # test issue #4204
654        a = table("a", column("id"))
655        b = table("b", column("id"), column("fid"))
656        c = table("c", column("id"), column("fid"))
657
658        cte1 = select([a.c.id]).cte(name="cte1")
659
660        aa = cte1.alias("aa")
661
662        cte2 = (
663            select([b.c.id])
664            .select_from(b.join(aa, b.c.fid == aa.c.id))
665            .cte(name="cte2")
666        )
667
668        cte3 = (
669            select([c.c.id])
670            .select_from(c.join(aa, c.c.fid == aa.c.id))
671            .cte(name="cte3")
672        )
673
674        stmt = select([cte3.c.id, cte2.c.id]).select_from(
675            cte2.join(cte3, cte2.c.id == cte3.c.id)
676        )
677        self.assert_compile(
678            stmt,
679            "WITH cte1 AS (SELECT a.id AS id FROM a), "
680            "cte2 AS (SELECT b.id AS id FROM b "
681            "JOIN cte1 AS aa ON b.fid = aa.id), "
682            "cte3 AS (SELECT c.id AS id FROM c "
683            "JOIN cte1 AS aa ON c.fid = aa.id) "
684            "SELECT cte3.id, cte2.id FROM cte2 JOIN cte3 ON cte2.id = cte3.id",
685        )
686
687    def test_named_alias_no_quote(self):
688        cte = select([literal(1).label("id")]).cte(name="CTE")
689
690        s1 = select([cte.c.id]).alias(name="no_quotes")
691
692        s = select([s1])
693        self.assert_compile(
694            s,
695            'WITH "CTE" AS (SELECT :param_1 AS id) '
696            "SELECT no_quotes.id FROM "
697            '(SELECT "CTE".id AS id FROM "CTE") AS no_quotes',
698        )
699
700    def test_named_alias_quote(self):
701        cte = select([literal(1).label("id")]).cte(name="CTE")
702
703        s1 = select([cte.c.id]).alias(name="Quotes Required")
704
705        s = select([s1])
706        self.assert_compile(
707            s,
708            'WITH "CTE" AS (SELECT :param_1 AS id) '
709            'SELECT "Quotes Required".id FROM '
710            '(SELECT "CTE".id AS id FROM "CTE") AS "Quotes Required"',
711        )
712
713    def test_named_alias_disable_quote(self):
714        cte = select([literal(1).label("id")]).cte(
715            name=quoted_name("CTE", quote=False)
716        )
717
718        s1 = select([cte.c.id]).alias(
719            name=quoted_name("DontQuote", quote=False)
720        )
721
722        s = select([s1])
723        self.assert_compile(
724            s,
725            "WITH CTE AS (SELECT :param_1 AS id) "
726            "SELECT DontQuote.id FROM "
727            "(SELECT CTE.id AS id FROM CTE) AS DontQuote",
728        )
729
    def test_positional_binds(self):
        """Numeric positional binds are numbered in rendered order.

        The bind inside the CTE body (``"x"``) is rendered before the
        binds of the enclosing SELECT, so it receives ``:1``.
        """
        orders = table("orders", column("order"))
        s = select([orders.c.order, literal("x")]).cte("regional_sales")
        s = select([s.c.order, literal("y")])
        # positional "numeric" paramstyle makes the bind order visible
        # as :1, :2, ... and in checkpositional
        dialect = default.DefaultDialect()
        dialect.positional = True
        dialect.paramstyle = "numeric"
        self.assert_compile(
            s,
            'WITH regional_sales AS (SELECT orders."order" '
            'AS "order", :1 AS anon_2 FROM orders) SELECT '
            'regional_sales."order", :2 AS anon_1 FROM regional_sales',
            checkpositional=("x", "y"),
            dialect=dialect,
        )

        # UNION of the statement with itself: the CTE (and its "x" bind)
        # is rendered once, while each SELECT contributes its own "y"
        self.assert_compile(
            s.union(s),
            'WITH regional_sales AS (SELECT orders."order" '
            'AS "order", :1 AS anon_2 FROM orders) SELECT '
            'regional_sales."order", :2 AS anon_1 FROM regional_sales '
            'UNION SELECT regional_sales."order", :3 AS anon_1 '
            "FROM regional_sales",
            checkpositional=("x", "y", "y"),
            dialect=dialect,
        )

        # same ordering for binds that appear in WHERE clauses
        s = (
            select([orders.c.order])
            .where(orders.c.order == "x")
            .cte("regional_sales")
        )
        s = select([s.c.order]).where(s.c.order == "y")
        self.assert_compile(
            s,
            'WITH regional_sales AS (SELECT orders."order" AS '
            '"order" FROM orders WHERE orders."order" = :1) '
            'SELECT regional_sales."order" FROM regional_sales '
            'WHERE regional_sales."order" = :2',
            checkpositional=("x", "y"),
            dialect=dialect,
        )
772
773    def test_positional_binds_2(self):
774        orders = table("orders", column("order"))
775        s = select([orders.c.order, literal("x")]).cte("regional_sales")
776        s = select([s.c.order, literal("y")])
777        dialect = default.DefaultDialect()
778        dialect.positional = True
779        dialect.paramstyle = "numeric"
780        s1 = (
781            select([orders.c.order])
782            .where(orders.c.order == "x")
783            .cte("regional_sales_1")
784        )
785
786        s1a = s1.alias()
787
788        s2 = (
789            select(
790                [
791                    orders.c.order == "y",
792                    s1a.c.order,
793                    orders.c.order,
794                    s1.c.order,
795                ]
796            )
797            .where(orders.c.order == "z")
798            .cte("regional_sales_2")
799        )
800
801        s3 = select([s2])
802
803        self.assert_compile(
804            s3,
805            'WITH regional_sales_1 AS (SELECT orders."order" AS "order" '
806            'FROM orders WHERE orders."order" = :1), regional_sales_2 AS '
807            '(SELECT orders."order" = :2 AS anon_1, '
808            'anon_2."order" AS "order", '
809            'orders."order" AS "order", '
810            'regional_sales_1."order" AS "order" FROM orders, '
811            "regional_sales_1 "
812            "AS anon_2, regional_sales_1 "
813            'WHERE orders."order" = :3) SELECT regional_sales_2.anon_1, '
814            'regional_sales_2."order" FROM regional_sales_2',
815            checkpositional=("x", "y", "z"),
816            dialect=dialect,
817        )
818
819    def test_positional_binds_2_asliteral(self):
820        orders = table("orders", column("order"))
821        s = select([orders.c.order, literal("x")]).cte("regional_sales")
822        s = select([s.c.order, literal("y")])
823        dialect = default.DefaultDialect()
824        dialect.positional = True
825        dialect.paramstyle = "numeric"
826        s1 = (
827            select([orders.c.order])
828            .where(orders.c.order == "x")
829            .cte("regional_sales_1")
830        )
831
832        s1a = s1.alias()
833
834        s2 = (
835            select(
836                [
837                    orders.c.order == "y",
838                    s1a.c.order,
839                    orders.c.order,
840                    s1.c.order,
841                ]
842            )
843            .where(orders.c.order == "z")
844            .cte("regional_sales_2")
845        )
846
847        s3 = select([s2])
848
849        self.assert_compile(
850            s3,
851            "WITH regional_sales_1 AS "
852            '(SELECT orders."order" AS "order" '
853            "FROM orders "
854            "WHERE orders.\"order\" = 'x'), "
855            "regional_sales_2 AS "
856            "(SELECT orders.\"order\" = 'y' AS anon_1, "
857            'anon_2."order" AS "order", orders."order" AS "order", '
858            'regional_sales_1."order" AS "order" '
859            "FROM orders, regional_sales_1 AS anon_2, regional_sales_1 "
860            "WHERE orders.\"order\" = 'z') "
861            'SELECT regional_sales_2.anon_1, regional_sales_2."order" '
862            "FROM regional_sales_2",
863            checkpositional=(),
864            dialect=dialect,
865            literal_binds=True,
866        )
867
868    def test_all_aliases(self):
869        orders = table("order", column("order"))
870        s = select([orders.c.order]).cte("regional_sales")
871
872        r1 = s.alias()
873        r2 = s.alias()
874
875        s2 = select([r1, r2]).where(r1.c.order > r2.c.order)
876
877        self.assert_compile(
878            s2,
879            'WITH regional_sales AS (SELECT "order"."order" '
880            'AS "order" FROM "order") '
881            'SELECT anon_1."order", anon_2."order" '
882            "FROM regional_sales AS anon_1, "
883            'regional_sales AS anon_2 WHERE anon_1."order" > anon_2."order"',
884        )
885
886        s3 = select([orders]).select_from(
887            orders.join(r1, r1.c.order == orders.c.order)
888        )
889
890        self.assert_compile(
891            s3,
892            "WITH regional_sales AS "
893            '(SELECT "order"."order" AS "order" '
894            'FROM "order")'
895            ' SELECT "order"."order" '
896            'FROM "order" JOIN regional_sales AS anon_1 '
897            'ON anon_1."order" = "order"."order"',
898        )
899
900    def test_suffixes(self):
901        orders = table("order", column("order"))
902        s = select([orders.c.order]).cte("regional_sales")
903        s = s.suffix_with("pg suffix", dialect="postgresql")
904        s = s.suffix_with("oracle suffix", dialect="oracle")
905        stmt = select([orders]).where(orders.c.order > s.c.order)
906
907        self.assert_compile(
908            stmt,
909            'WITH regional_sales AS (SELECT "order"."order" AS "order" '
910            'FROM "order")  SELECT "order"."order" FROM "order", '
911            'regional_sales WHERE "order"."order" > regional_sales."order"',
912        )
913
914        self.assert_compile(
915            stmt,
916            'WITH regional_sales AS (SELECT "order"."order" AS "order" '
917            'FROM "order") oracle suffix  '
918            'SELECT "order"."order" FROM "order", '
919            'regional_sales WHERE "order"."order" > regional_sales."order"',
920            dialect="oracle",
921        )
922
923        self.assert_compile(
924            stmt,
925            'WITH regional_sales AS (SELECT "order"."order" AS "order" '
926            'FROM "order") pg suffix  SELECT "order"."order" FROM "order", '
927            'regional_sales WHERE "order"."order" > regional_sales."order"',
928            dialect="postgresql",
929        )
930
931    def test_upsert_from_select(self):
932        orders = table(
933            "orders",
934            column("region"),
935            column("amount"),
936            column("product"),
937            column("quantity"),
938        )
939
940        upsert = (
941            orders.update()
942            .where(orders.c.region == "Region1")
943            .values(amount=1.0, product="Product1", quantity=1)
944            .returning(*(orders.c._all_columns))
945            .cte("upsert")
946        )
947
948        insert = orders.insert().from_select(
949            orders.c.keys(),
950            select(
951                [
952                    literal("Region1"),
953                    literal(1.0),
954                    literal("Product1"),
955                    literal(1),
956                ]
957            ).where(~exists(upsert.select())),
958        )
959
960        self.assert_compile(
961            insert,
962            "WITH upsert AS (UPDATE orders SET amount=:amount, "
963            "product=:product, quantity=:quantity "
964            "WHERE orders.region = :region_1 "
965            "RETURNING orders.region, orders.amount, "
966            "orders.product, orders.quantity) "
967            "INSERT INTO orders (region, amount, product, quantity) "
968            "SELECT :param_1 AS anon_1, :param_2 AS anon_2, "
969            ":param_3 AS anon_3, :param_4 AS anon_4 WHERE NOT (EXISTS "
970            "(SELECT upsert.region, upsert.amount, upsert.product, "
971            "upsert.quantity FROM upsert))",
972        )
973
974    def test_anon_update_cte(self):
975        orders = table("orders", column("region"))
976        stmt = (
977            orders.update()
978            .where(orders.c.region == "x")
979            .values(region="y")
980            .returning(orders.c.region)
981            .cte()
982        )
983
984        self.assert_compile(
985            stmt.select(),
986            "WITH anon_1 AS (UPDATE orders SET region=:region "
987            "WHERE orders.region = :region_1 RETURNING orders.region) "
988            "SELECT anon_1.region FROM anon_1",
989        )
990
991    def test_anon_insert_cte(self):
992        orders = table("orders", column("region"))
993        stmt = (
994            orders.insert().values(region="y").returning(orders.c.region).cte()
995        )
996
997        self.assert_compile(
998            stmt.select(),
999            "WITH anon_1 AS (INSERT INTO orders (region) "
1000            "VALUES (:region) RETURNING orders.region) "
1001            "SELECT anon_1.region FROM anon_1",
1002        )
1003
1004    def test_pg_example_one(self):
1005        products = table("products", column("id"), column("date"))
1006        products_log = table("products_log", column("id"), column("date"))
1007
1008        moved_rows = (
1009            products.delete()
1010            .where(
1011                and_(products.c.date >= "dateone", products.c.date < "datetwo")
1012            )
1013            .returning(*products.c)
1014            .cte("moved_rows")
1015        )
1016
1017        stmt = products_log.insert().from_select(
1018            products_log.c, moved_rows.select()
1019        )
1020        self.assert_compile(
1021            stmt,
1022            "WITH moved_rows AS "
1023            "(DELETE FROM products WHERE products.date >= :date_1 "
1024            "AND products.date < :date_2 "
1025            "RETURNING products.id, products.date) "
1026            "INSERT INTO products_log (id, date) "
1027            "SELECT moved_rows.id, moved_rows.date FROM moved_rows",
1028        )
1029
1030    def test_pg_example_two(self):
1031        products = table("products", column("id"), column("price"))
1032
1033        t = (
1034            products.update()
1035            .values(price="someprice")
1036            .returning(*products.c)
1037            .cte("t")
1038        )
1039        stmt = t.select()
1040        assert "autocommit" not in stmt._execution_options
1041        eq_(stmt.compile().execution_options["autocommit"], True)
1042
1043        self.assert_compile(
1044            stmt,
1045            "WITH t AS "
1046            "(UPDATE products SET price=:price "
1047            "RETURNING products.id, products.price) "
1048            "SELECT t.id, t.price "
1049            "FROM t",
1050        )
1051
1052    def test_pg_example_three(self):
1053
1054        parts = table("parts", column("part"), column("sub_part"))
1055
1056        included_parts = (
1057            select([parts.c.sub_part, parts.c.part])
1058            .where(parts.c.part == "our part")
1059            .cte("included_parts", recursive=True)
1060        )
1061
1062        pr = included_parts.alias("pr")
1063        p = parts.alias("p")
1064        included_parts = included_parts.union_all(
1065            select([p.c.sub_part, p.c.part]).where(p.c.part == pr.c.sub_part)
1066        )
1067        stmt = (
1068            parts.delete()
1069            .where(parts.c.part.in_(select([included_parts.c.part])))
1070            .returning(parts.c.part)
1071        )
1072
1073        # the outer RETURNING is a bonus over what PG's docs have
1074        self.assert_compile(
1075            stmt,
1076            "WITH RECURSIVE included_parts(sub_part, part) AS "
1077            "(SELECT parts.sub_part AS sub_part, parts.part AS part "
1078            "FROM parts "
1079            "WHERE parts.part = :part_1 "
1080            "UNION ALL SELECT p.sub_part AS sub_part, p.part AS part "
1081            "FROM parts AS p, included_parts AS pr "
1082            "WHERE p.part = pr.sub_part) "
1083            "DELETE FROM parts WHERE parts.part IN "
1084            "(SELECT included_parts.part FROM included_parts) "
1085            "RETURNING parts.part",
1086        )
1087
1088    def test_insert_in_the_cte(self):
1089        products = table("products", column("id"), column("price"))
1090
1091        cte = (
1092            products.insert()
1093            .values(id=1, price=27.0)
1094            .returning(*products.c)
1095            .cte("pd")
1096        )
1097
1098        stmt = select([cte])
1099
1100        assert "autocommit" not in stmt._execution_options
1101        eq_(stmt.compile().execution_options["autocommit"], True)
1102
1103        self.assert_compile(
1104            stmt,
1105            "WITH pd AS "
1106            "(INSERT INTO products (id, price) VALUES (:id, :price) "
1107            "RETURNING products.id, products.price) "
1108            "SELECT pd.id, pd.price "
1109            "FROM pd",
1110        )
1111
1112    def test_update_pulls_from_cte(self):
1113        products = table("products", column("id"), column("price"))
1114
1115        cte = products.select().cte("pd")
1116        assert "autocommit" not in cte._execution_options
1117
1118        stmt = products.update().where(products.c.price == cte.c.price)
1119        eq_(stmt.compile().execution_options["autocommit"], True)
1120
1121        self.assert_compile(
1122            stmt,
1123            "WITH pd AS "
1124            "(SELECT products.id AS id, products.price AS price "
1125            "FROM products) "
1126            "UPDATE products SET id=:id, price=:price FROM pd "
1127            "WHERE products.price = pd.price",
1128        )
1129