1# -*- encoding: utf-8
2from sqlalchemy import bindparam
3from sqlalchemy import Column
4from sqlalchemy import Computed
5from sqlalchemy import delete
6from sqlalchemy import exc
7from sqlalchemy import extract
8from sqlalchemy import func
9from sqlalchemy import Identity
10from sqlalchemy import Index
11from sqlalchemy import insert
12from sqlalchemy import Integer
13from sqlalchemy import literal
14from sqlalchemy import literal_column
15from sqlalchemy import MetaData
16from sqlalchemy import PrimaryKeyConstraint
17from sqlalchemy import schema
18from sqlalchemy import select
19from sqlalchemy import sql
20from sqlalchemy import String
21from sqlalchemy import Table
22from sqlalchemy import testing
23from sqlalchemy import text
24from sqlalchemy import union
25from sqlalchemy import UniqueConstraint
26from sqlalchemy import update
27from sqlalchemy.dialects import mssql
28from sqlalchemy.dialects.mssql import base as mssql_base
29from sqlalchemy.dialects.mssql import mxodbc
30from sqlalchemy.dialects.mssql.base import try_cast
31from sqlalchemy.sql import column
32from sqlalchemy.sql import quoted_name
33from sqlalchemy.sql import table
34from sqlalchemy.testing import assert_raises_message
35from sqlalchemy.testing import AssertsCompiledSQL
36from sqlalchemy.testing import eq_
37from sqlalchemy.testing import fixtures
38from sqlalchemy.testing import is_
39from sqlalchemy.testing.assertions import eq_ignore_whitespace
40
# Module-level lightweight table "t" with a single column "a".  No test in
# this section references it (several tests bind a local ``tbl`` that
# shadows it); presumably it is used further down the file -- verify before
# removing.
tbl = table("t", column("a"))
42
43
44class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
45    __dialect__ = mssql.dialect()
46
47    @testing.fixture
48    def dialect_2012(self):
49        dialect = mssql.dialect()
50        dialect._supports_offset_fetch = True
51        return dialect
52
53    def test_true_false(self):
54        self.assert_compile(sql.false(), "0")
55        self.assert_compile(sql.true(), "1")
56
57    @testing.combinations(
58        ("plain", "sometable", "sometable"),
59        ("matched_square_brackets", "colo[u]r", "[colo[u]]r]"),
60        ("unmatched_left_square_bracket", "colo[ur", "[colo[ur]"),
61        ("unmatched_right_square_bracket", "colou]r", "[colou]]r]"),
62        ("double quotes", 'Edwin "Buzz" Aldrin', '[Edwin "Buzz" Aldrin]'),
63        ("dash", "Dash-8", "[Dash-8]"),
64        ("slash", "tl/dr", "[tl/dr]"),
65        ("space", "Red Deer", "[Red Deer]"),
66        ("question mark", "OK?", "[OK?]"),
67        ("percent", "GST%", "[GST%]"),
68        id_="iaa",
69    )
70    def test_identifier_rendering(self, table_name, rendered_name):
71        t = table(table_name, column("somecolumn"))
72        self.assert_compile(
73            t.select(), "SELECT {0}.somecolumn FROM {0}".format(rendered_name)
74        )
75
76    def test_select_with_nolock(self):
77        t = table("sometable", column("somecolumn"))
78        self.assert_compile(
79            t.select().with_hint(t, "WITH (NOLOCK)"),
80            "SELECT sometable.somecolumn FROM sometable WITH (NOLOCK)",
81        )
82
83    def test_select_with_nolock_schema(self):
84        m = MetaData()
85        t = Table(
86            "sometable", m, Column("somecolumn", Integer), schema="test_schema"
87        )
88        self.assert_compile(
89            t.select().with_hint(t, "WITH (NOLOCK)"),
90            "SELECT test_schema.sometable.somecolumn "
91            "FROM test_schema.sometable WITH (NOLOCK)",
92        )
93
94    def test_select_w_order_by_collate(self):
95        m = MetaData()
96        t = Table("sometable", m, Column("somecolumn", String))
97
98        self.assert_compile(
99            select(t).order_by(
100                t.c.somecolumn.collate("Latin1_General_CS_AS_KS_WS_CI").asc()
101            ),
102            "SELECT sometable.somecolumn FROM sometable "
103            "ORDER BY sometable.somecolumn COLLATE "
104            "Latin1_General_CS_AS_KS_WS_CI ASC",
105        )
106
107    def test_join_with_hint(self):
108        t1 = table(
109            "t1",
110            column("a", Integer),
111            column("b", String),
112            column("c", String),
113        )
114        t2 = table(
115            "t2",
116            column("a", Integer),
117            column("b", Integer),
118            column("c", Integer),
119        )
120        join = (
121            t1.join(t2, t1.c.a == t2.c.a)
122            .select()
123            .with_hint(t1, "WITH (NOLOCK)")
124        )
125        self.assert_compile(
126            join,
127            "SELECT t1.a, t1.b, t1.c, t2.a AS a_1, t2.b AS b_1, t2.c AS c_1 "
128            "FROM t1 WITH (NOLOCK) JOIN t2 ON t1.a = t2.a",
129        )
130
131    def test_insert(self):
132        t = table("sometable", column("somecolumn"))
133        self.assert_compile(
134            t.insert(),
135            "INSERT INTO sometable (somecolumn) VALUES " "(:somecolumn)",
136        )
137
138    def test_update(self):
139        t = table("sometable", column("somecolumn"))
140        self.assert_compile(
141            t.update().where(t.c.somecolumn == 7),
142            "UPDATE sometable SET somecolumn=:somecolum"
143            "n WHERE sometable.somecolumn = "
144            ":somecolumn_1",
145            dict(somecolumn=10),
146        )
147
148    def test_insert_hint(self):
149        t = table("sometable", column("somecolumn"))
150        for targ in (None, t):
151            for darg in ("*", "mssql"):
152                self.assert_compile(
153                    t.insert()
154                    .values(somecolumn="x")
155                    .with_hint(
156                        "WITH (PAGLOCK)", selectable=targ, dialect_name=darg
157                    ),
158                    "INSERT INTO sometable WITH (PAGLOCK) "
159                    "(somecolumn) VALUES (:somecolumn)",
160                )
161
162    def test_update_hint(self):
163        t = table("sometable", column("somecolumn"))
164        for targ in (None, t):
165            for darg in ("*", "mssql"):
166                self.assert_compile(
167                    t.update()
168                    .where(t.c.somecolumn == "q")
169                    .values(somecolumn="x")
170                    .with_hint(
171                        "WITH (PAGLOCK)", selectable=targ, dialect_name=darg
172                    ),
173                    "UPDATE sometable WITH (PAGLOCK) "
174                    "SET somecolumn=:somecolumn "
175                    "WHERE sometable.somecolumn = :somecolumn_1",
176                )
177
178    def test_update_exclude_hint(self):
179        t = table("sometable", column("somecolumn"))
180        self.assert_compile(
181            t.update()
182            .where(t.c.somecolumn == "q")
183            .values(somecolumn="x")
184            .with_hint("XYZ", "mysql"),
185            "UPDATE sometable SET somecolumn=:somecolumn "
186            "WHERE sometable.somecolumn = :somecolumn_1",
187        )
188
189    def test_delete_hint(self):
190        t = table("sometable", column("somecolumn"))
191        for targ in (None, t):
192            for darg in ("*", "mssql"):
193                self.assert_compile(
194                    t.delete()
195                    .where(t.c.somecolumn == "q")
196                    .with_hint(
197                        "WITH (PAGLOCK)", selectable=targ, dialect_name=darg
198                    ),
199                    "DELETE FROM sometable WITH (PAGLOCK) "
200                    "WHERE sometable.somecolumn = :somecolumn_1",
201                )
202
203    def test_delete_exclude_hint(self):
204        t = table("sometable", column("somecolumn"))
205        self.assert_compile(
206            t.delete()
207            .where(t.c.somecolumn == "q")
208            .with_hint("XYZ", dialect_name="mysql"),
209            "DELETE FROM sometable WHERE "
210            "sometable.somecolumn = :somecolumn_1",
211        )
212
213    def test_delete_extra_froms(self):
214        t1 = table("t1", column("c1"))
215        t2 = table("t2", column("c1"))
216        q = sql.delete(t1).where(t1.c.c1 == t2.c.c1)
217        self.assert_compile(
218            q, "DELETE FROM t1 FROM t1, t2 WHERE t1.c1 = t2.c1"
219        )
220
221    def test_delete_extra_froms_alias(self):
222        a1 = table("t1", column("c1")).alias("a1")
223        t2 = table("t2", column("c1"))
224        q = sql.delete(a1).where(a1.c.c1 == t2.c.c1)
225        self.assert_compile(
226            q, "DELETE FROM a1 FROM t1 AS a1, t2 WHERE a1.c1 = t2.c1"
227        )
228        self.assert_compile(sql.delete(a1), "DELETE FROM t1 AS a1")
229
230    def test_update_from(self):
231        metadata = MetaData()
232        table1 = Table(
233            "mytable",
234            metadata,
235            Column("myid", Integer),
236            Column("name", String(30)),
237            Column("description", String(50)),
238        )
239        table2 = Table(
240            "myothertable",
241            metadata,
242            Column("otherid", Integer),
243            Column("othername", String(30)),
244        )
245
246        mt = table1.alias()
247
248        u = (
249            table1.update()
250            .values(name="foo")
251            .where(table2.c.otherid == table1.c.myid)
252        )
253
254        # testing mssql.base.MSSQLCompiler.update_from_clause
255        self.assert_compile(
256            u,
257            "UPDATE mytable SET name=:name "
258            "FROM mytable, myothertable WHERE "
259            "myothertable.otherid = mytable.myid",
260        )
261
262        self.assert_compile(
263            u.where(table2.c.othername == mt.c.name),
264            "UPDATE mytable SET name=:name "
265            "FROM mytable, myothertable, mytable AS mytable_1 "
266            "WHERE myothertable.otherid = mytable.myid "
267            "AND myothertable.othername = mytable_1.name",
268        )
269
270    def test_update_from_hint(self):
271        t = table("sometable", column("somecolumn"))
272        t2 = table("othertable", column("somecolumn"))
273        for darg in ("*", "mssql"):
274            self.assert_compile(
275                t.update()
276                .where(t.c.somecolumn == t2.c.somecolumn)
277                .values(somecolumn="x")
278                .with_hint("WITH (PAGLOCK)", selectable=t2, dialect_name=darg),
279                "UPDATE sometable SET somecolumn=:somecolumn "
280                "FROM sometable, othertable WITH (PAGLOCK) "
281                "WHERE sometable.somecolumn = othertable.somecolumn",
282            )
283
284    def test_update_to_select_schema(self):
285        meta = MetaData()
286        table = Table(
287            "sometable",
288            meta,
289            Column("sym", String),
290            Column("val", Integer),
291            schema="schema",
292        )
293        other = Table(
294            "#other", meta, Column("sym", String), Column("newval", Integer)
295        )
296        stmt = table.update().values(
297            val=select(other.c.newval)
298            .where(table.c.sym == other.c.sym)
299            .scalar_subquery()
300        )
301
302        self.assert_compile(
303            stmt,
304            "UPDATE [schema].sometable SET val="
305            "(SELECT [#other].newval FROM [#other] "
306            "WHERE [schema].sometable.sym = [#other].sym)",
307        )
308
309        stmt = (
310            table.update()
311            .values(val=other.c.newval)
312            .where(table.c.sym == other.c.sym)
313        )
314        self.assert_compile(
315            stmt,
316            "UPDATE [schema].sometable SET val="
317            "[#other].newval FROM [schema].sometable, "
318            "[#other] WHERE [schema].sometable.sym = [#other].sym",
319        )
320
321    # TODO: not supported yet.
322    # def test_delete_from_hint(self):
323    #    t = table('sometable', column('somecolumn'))
324    #    t2 = table('othertable', column('somecolumn'))
325    #    for darg in ("*", "mssql"):
326    #        self.assert_compile(
327    #            t.delete().where(t.c.somecolumn==t2.c.somecolumn).
328    #                    with_hint("WITH (PAGLOCK)",
329    #                            selectable=t2,
330    #                            dialect_name=darg),
331    #            ""
332    #        )
333
334    @testing.combinations(
335        (
336            lambda: select(literal("x"), literal("y")),
337            "SELECT __[POSTCOMPILE_param_1] AS anon_1, "
338            "__[POSTCOMPILE_param_2] AS anon_2",
339            {
340                "check_literal_execute": {"param_1": "x", "param_2": "y"},
341                "check_post_param": {},
342            },
343        ),
344        (
345            lambda t: select(t).where(t.c.foo.in_(["x", "y", "z"])),
346            "SELECT sometable.foo FROM sometable WHERE sometable.foo "
347            "IN (__[POSTCOMPILE_foo_1])",
348            {
349                "check_literal_execute": {"foo_1": ["x", "y", "z"]},
350                "check_post_param": {},
351            },
352        ),
353        (lambda t: t.c.foo.in_([None]), "sometable.foo IN (NULL)", {}),
354    )
355    def test_strict_binds(self, expr, compiled, kw):
356        """test the 'strict' compiler binds."""
357
358        from sqlalchemy.dialects.mssql.base import MSSQLStrictCompiler
359
360        mxodbc_dialect = mxodbc.dialect()
361        mxodbc_dialect.statement_compiler = MSSQLStrictCompiler
362
363        t = table("sometable", column("foo"))
364
365        expr = testing.resolve_lambda(expr, t=t)
366        self.assert_compile(expr, compiled, dialect=mxodbc_dialect, **kw)
367
368    def test_in_with_subqueries(self):
369        """Test removal of legacy behavior that converted "x==subquery"
370        to use IN.
371
372        """
373
374        t = table("sometable", column("somecolumn"))
375        self.assert_compile(
376            t.select().where(t.c.somecolumn == t.select().scalar_subquery()),
377            "SELECT sometable.somecolumn FROM "
378            "sometable WHERE sometable.somecolumn = "
379            "(SELECT sometable.somecolumn FROM "
380            "sometable)",
381        )
382        self.assert_compile(
383            t.select().where(t.c.somecolumn != t.select().scalar_subquery()),
384            "SELECT sometable.somecolumn FROM "
385            "sometable WHERE sometable.somecolumn != "
386            "(SELECT sometable.somecolumn FROM "
387            "sometable)",
388        )
389
390    @testing.uses_deprecated
391    def test_count(self):
392        t = table("sometable", column("somecolumn"))
393        self.assert_compile(
394            t.count(),
395            "SELECT count(sometable.somecolumn) AS "
396            "tbl_row_count FROM sometable",
397        )
398
399    def test_noorderby_insubquery(self):
400        """test "no ORDER BY in subqueries unless TOP / LIMIT / OFFSET"
401        present"""
402
403        table1 = table(
404            "mytable",
405            column("myid", Integer),
406            column("name", String),
407            column("description", String),
408        )
409
410        q = select(table1.c.myid).order_by(table1.c.myid).alias("foo")
411        crit = q.c.myid == table1.c.myid
412        self.assert_compile(
413            select("*").where(crit),
414            "SELECT * FROM (SELECT mytable.myid AS "
415            "myid FROM mytable) AS foo, mytable WHERE "
416            "foo.myid = mytable.myid",
417        )
418
419    def test_noorderby_insubquery_limit(self):
420        """test "no ORDER BY in subqueries unless TOP / LIMIT / OFFSET"
421        present"""
422
423        table1 = table(
424            "mytable",
425            column("myid", Integer),
426            column("name", String),
427            column("description", String),
428        )
429
430        q = (
431            select(table1.c.myid)
432            .order_by(table1.c.myid)
433            .limit(10)
434            .alias("foo")
435        )
436        crit = q.c.myid == table1.c.myid
437        self.assert_compile(
438            select("*").where(crit),
439            "SELECT * FROM (SELECT TOP __[POSTCOMPILE_param_1] "
440            "mytable.myid AS "
441            "myid FROM mytable ORDER BY mytable.myid) AS foo, mytable WHERE "
442            "foo.myid = mytable.myid",
443        )
444
445    @testing.combinations(10, 0)
446    def test_noorderby_insubquery_offset_oldstyle(self, offset):
447        """test "no ORDER BY in subqueries unless TOP / LIMIT / OFFSET"
448        present"""
449
450        table1 = table(
451            "mytable",
452            column("myid", Integer),
453            column("name", String),
454            column("description", String),
455        )
456
457        q = (
458            select(table1.c.myid)
459            .order_by(table1.c.myid)
460            .offset(offset)
461            .alias("foo")
462        )
463        crit = q.c.myid == table1.c.myid
464        self.assert_compile(
465            select("*").where(crit),
466            "SELECT * FROM (SELECT anon_1.myid AS myid FROM "
467            "(SELECT mytable.myid AS myid, ROW_NUMBER() OVER (ORDER BY "
468            "mytable.myid) AS mssql_rn FROM mytable) AS anon_1 "
469            "WHERE mssql_rn > :param_1) AS foo, mytable WHERE "
470            "foo.myid = mytable.myid",
471        )
472
473    @testing.combinations(10, 0, argnames="offset")
474    def test_noorderby_insubquery_offset_newstyle(self, dialect_2012, offset):
475        """test "no ORDER BY in subqueries unless TOP / LIMIT / OFFSET"
476        present"""
477
478        table1 = table(
479            "mytable",
480            column("myid", Integer),
481            column("name", String),
482            column("description", String),
483        )
484
485        q = (
486            select(table1.c.myid)
487            .order_by(table1.c.myid)
488            .offset(offset)
489            .alias("foo")
490        )
491        crit = q.c.myid == table1.c.myid
492        self.assert_compile(
493            select("*").where(crit),
494            "SELECT * FROM (SELECT mytable.myid AS myid FROM mytable "
495            "ORDER BY mytable.myid OFFSET :param_1 ROWS) AS foo, "
496            "mytable WHERE foo.myid = mytable.myid",
497            dialect=dialect_2012,
498        )
499
500    def test_noorderby_insubquery_limit_offset_newstyle(self, dialect_2012):
501        """test "no ORDER BY in subqueries unless TOP / LIMIT / OFFSET"
502        present"""
503
504        table1 = table(
505            "mytable",
506            column("myid", Integer),
507            column("name", String),
508            column("description", String),
509        )
510
511        q = (
512            select(table1.c.myid)
513            .order_by(table1.c.myid)
514            .limit(10)
515            .offset(10)
516            .alias("foo")
517        )
518        crit = q.c.myid == table1.c.myid
519        self.assert_compile(
520            select("*").where(crit),
521            "SELECT * FROM (SELECT mytable.myid AS myid FROM mytable "
522            "ORDER BY mytable.myid OFFSET :param_1 ROWS "
523            "FETCH FIRST :param_2 ROWS ONLY) AS foo, "
524            "mytable WHERE foo.myid = mytable.myid",
525            dialect=dialect_2012,
526        )
527
528    def test_noorderby_parameters_insubquery(self):
529        """test that the ms-sql dialect does not include ORDER BY
530        positional parameters in subqueries"""
531
532        table1 = table(
533            "mytable",
534            column("myid", Integer),
535            column("name", String),
536            column("description", String),
537        )
538
539        q = (
540            select(table1.c.myid, sql.literal("bar").label("c1"))
541            .order_by(table1.c.name + "-")
542            .alias("foo")
543        )
544        crit = q.c.myid == table1.c.myid
545        dialect = mssql.dialect()
546        dialect.paramstyle = "qmark"
547        dialect.positional = True
548        self.assert_compile(
549            select("*").where(crit),
550            "SELECT * FROM (SELECT mytable.myid AS "
551            "myid, ? AS c1 FROM mytable) AS foo, mytable WHERE "
552            "foo.myid = mytable.myid",
553            dialect=dialect,
554            checkparams={"param_1": "bar"},
555            # if name_1 is included, too many parameters are passed to dbapi
556            checkpositional=("bar",),
557        )
558
559    def test_schema_many_tokens_one(self):
560        metadata = MetaData()
561        tbl = Table(
562            "test",
563            metadata,
564            Column("id", Integer, primary_key=True),
565            schema="abc.def.efg.hij",
566        )
567
568        # for now, we don't really know what the above means, at least
569        # don't lose the dot
570        self.assert_compile(
571            select(tbl),
572            "SELECT [abc.def.efg].hij.test.id FROM [abc.def.efg].hij.test",
573        )
574
575        dbname, owner = mssql_base._schema_elements("abc.def.efg.hij")
576        eq_(dbname, "abc.def.efg")
577        assert not isinstance(dbname, quoted_name)
578        eq_(owner, "hij")
579
580    def test_schema_many_tokens_two(self):
581        metadata = MetaData()
582        tbl = Table(
583            "test",
584            metadata,
585            Column("id", Integer, primary_key=True),
586            schema="[abc].[def].[efg].[hij]",
587        )
588
589        self.assert_compile(
590            select(tbl),
591            "SELECT [abc].[def].[efg].hij.test.id "
592            "FROM [abc].[def].[efg].hij.test",
593        )
594
595    def test_force_schema_quoted_name_w_dot_case_insensitive(self):
596        metadata = MetaData()
597        tbl = Table(
598            "test",
599            metadata,
600            Column("id", Integer, primary_key=True),
601            schema=quoted_name("foo.dbo", True),
602        )
603        self.assert_compile(
604            select(tbl), "SELECT [foo.dbo].test.id FROM [foo.dbo].test"
605        )
606
607    def test_force_schema_quoted_w_dot_case_insensitive(self):
608        metadata = MetaData()
609        tbl = Table(
610            "test",
611            metadata,
612            Column("id", Integer, primary_key=True),
613            schema=quoted_name("foo.dbo", True),
614        )
615        self.assert_compile(
616            select(tbl), "SELECT [foo.dbo].test.id FROM [foo.dbo].test"
617        )
618
619    @testing.combinations((True,), (False,), argnames="use_schema_translate")
620    def test_force_schema_quoted_name_w_dot_case_sensitive(
621        self, use_schema_translate
622    ):
623        metadata = MetaData()
624        tbl = Table(
625            "test",
626            metadata,
627            Column("id", Integer, primary_key=True),
628            schema=quoted_name("Foo.dbo", True)
629            if not use_schema_translate
630            else None,
631        )
632        self.assert_compile(
633            select(tbl),
634            "SELECT [Foo.dbo].test.id FROM [Foo.dbo].test",
635            schema_translate_map={None: quoted_name("Foo.dbo", True)}
636            if use_schema_translate
637            else None,
638            render_schema_translate=True if use_schema_translate else False,
639        )
640
641    @testing.combinations((True,), (False,), argnames="use_schema_translate")
642    def test_force_schema_quoted_w_dot_case_sensitive(
643        self, use_schema_translate
644    ):
645        metadata = MetaData()
646        tbl = Table(
647            "test",
648            metadata,
649            Column("id", Integer, primary_key=True),
650            schema="[Foo.dbo]" if not use_schema_translate else None,
651        )
652        self.assert_compile(
653            select(tbl),
654            "SELECT [Foo.dbo].test.id FROM [Foo.dbo].test",
655            schema_translate_map={None: "[Foo.dbo]"}
656            if use_schema_translate
657            else None,
658            render_schema_translate=True if use_schema_translate else False,
659        )
660
661    @testing.combinations((True,), (False,), argnames="use_schema_translate")
662    def test_schema_autosplit_w_dot_case_insensitive(
663        self, use_schema_translate
664    ):
665        metadata = MetaData()
666        tbl = Table(
667            "test",
668            metadata,
669            Column("id", Integer, primary_key=True),
670            schema="foo.dbo" if not use_schema_translate else None,
671        )
672        self.assert_compile(
673            select(tbl),
674            "SELECT foo.dbo.test.id FROM foo.dbo.test",
675            schema_translate_map={None: "foo.dbo"}
676            if use_schema_translate
677            else None,
678            render_schema_translate=True if use_schema_translate else False,
679        )
680
681    @testing.combinations((True,), (False,), argnames="use_schema_translate")
682    def test_schema_autosplit_w_dot_case_sensitive(self, use_schema_translate):
683        metadata = MetaData()
684        tbl = Table(
685            "test",
686            metadata,
687            Column("id", Integer, primary_key=True),
688            schema="Foo.dbo" if not use_schema_translate else None,
689        )
690        self.assert_compile(
691            select(tbl),
692            "SELECT [Foo].dbo.test.id FROM [Foo].dbo.test",
693            schema_translate_map={None: "Foo.dbo"}
694            if use_schema_translate
695            else None,
696            render_schema_translate=True if use_schema_translate else False,
697        )
698
699    def test_delete_schema(self):
700        metadata = MetaData()
701        tbl = Table(
702            "test",
703            metadata,
704            Column("id", Integer, primary_key=True),
705            schema="paj",
706        )
707        self.assert_compile(
708            tbl.delete().where(tbl.c.id == 1),
709            "DELETE FROM paj.test WHERE paj.test.id = " ":id_1",
710        )
711        s = select(tbl.c.id).where(tbl.c.id == 1)
712        self.assert_compile(
713            tbl.delete().where(tbl.c.id.in_(s)),
714            "DELETE FROM paj.test WHERE paj.test.id IN "
715            "(SELECT paj.test.id FROM paj.test "
716            "WHERE paj.test.id = :id_1)",
717        )
718
719    def test_delete_schema_multipart(self):
720        metadata = MetaData()
721        tbl = Table(
722            "test",
723            metadata,
724            Column("id", Integer, primary_key=True),
725            schema="banana.paj",
726        )
727        self.assert_compile(
728            tbl.delete().where(tbl.c.id == 1),
729            "DELETE FROM banana.paj.test WHERE " "banana.paj.test.id = :id_1",
730        )
731        s = select(tbl.c.id).where(tbl.c.id == 1)
732        self.assert_compile(
733            tbl.delete().where(tbl.c.id.in_(s)),
734            "DELETE FROM banana.paj.test WHERE "
735            "banana.paj.test.id IN (SELECT banana.paj.test.id "
736            "FROM banana.paj.test WHERE "
737            "banana.paj.test.id = :id_1)",
738        )
739
740    def test_delete_schema_multipart_needs_quoting(self):
741        metadata = MetaData()
742        tbl = Table(
743            "test",
744            metadata,
745            Column("id", Integer, primary_key=True),
746            schema="banana split.paj",
747        )
748        self.assert_compile(
749            tbl.delete().where(tbl.c.id == 1),
750            "DELETE FROM [banana split].paj.test WHERE "
751            "[banana split].paj.test.id = :id_1",
752        )
753        s = select(tbl.c.id).where(tbl.c.id == 1)
754        self.assert_compile(
755            tbl.delete().where(tbl.c.id.in_(s)),
756            "DELETE FROM [banana split].paj.test WHERE "
757            "[banana split].paj.test.id IN ("
758            "SELECT [banana split].paj.test.id FROM "
759            "[banana split].paj.test WHERE "
760            "[banana split].paj.test.id = :id_1)",
761        )
762
763    def test_delete_schema_multipart_both_need_quoting(self):
764        metadata = MetaData()
765        tbl = Table(
766            "test",
767            metadata,
768            Column("id", Integer, primary_key=True),
769            schema="banana split.paj with a space",
770        )
771        self.assert_compile(
772            tbl.delete().where(tbl.c.id == 1),
773            "DELETE FROM [banana split].[paj with a "
774            "space].test WHERE [banana split].[paj "
775            "with a space].test.id = :id_1",
776        )
777        s = select(tbl.c.id).where(tbl.c.id == 1)
778        self.assert_compile(
779            tbl.delete().where(tbl.c.id.in_(s)),
780            "DELETE FROM [banana split].[paj with a space].test "
781            "WHERE [banana split].[paj with a space].test.id IN "
782            "(SELECT [banana split].[paj with a space].test.id "
783            "FROM [banana split].[paj with a space].test "
784            "WHERE [banana split].[paj with a space].test.id = :id_1)",
785        )
786
787    def test_union(self):
788        t1 = table(
789            "t1",
790            column("col1"),
791            column("col2"),
792            column("col3"),
793            column("col4"),
794        )
795        t2 = table(
796            "t2",
797            column("col1"),
798            column("col2"),
799            column("col3"),
800            column("col4"),
801        )
802        s1, s2 = (
803            select(t1.c.col3.label("col3"), t1.c.col4.label("col4")).where(
804                t1.c.col2.in_(["t1col2r1", "t1col2r2"]),
805            ),
806            select(t2.c.col3.label("col3"), t2.c.col4.label("col4")).where(
807                t2.c.col2.in_(["t2col2r2", "t2col2r3"]),
808            ),
809        )
810        u = union(s1, s2).order_by("col3", "col4")
811        self.assert_compile(
812            u,
813            "SELECT t1.col3 AS col3, t1.col4 AS col4 "
814            "FROM t1 WHERE t1.col2 IN (__[POSTCOMPILE_col2_1]) "
815            "UNION SELECT t2.col3 AS col3, "
816            "t2.col4 AS col4 FROM t2 WHERE t2.col2 IN "
817            "(__[POSTCOMPILE_col2_2]) ORDER BY col3, col4",
818            checkparams={
819                "col2_1": ["t1col2r1", "t1col2r2"],
820                "col2_2": ["t2col2r2", "t2col2r3"],
821            },
822        )
823        self.assert_compile(
824            u.alias("bar").select(),
825            "SELECT bar.col3, bar.col4 FROM (SELECT "
826            "t1.col3 AS col3, t1.col4 AS col4 FROM t1 "
827            "WHERE t1.col2 IN (__[POSTCOMPILE_col2_1]) UNION "
828            "SELECT t2.col3 AS col3, t2.col4 AS col4 "
829            "FROM t2 WHERE t2.col2 IN (__[POSTCOMPILE_col2_2])) AS bar",
830            checkparams={
831                "col2_1": ["t1col2r1", "t1col2r2"],
832                "col2_2": ["t2col2r2", "t2col2r3"],
833            },
834        )
835
836    def test_function(self):
837        self.assert_compile(func.foo(1, 2), "foo(:foo_1, :foo_2)")
838        self.assert_compile(func.current_time(), "CURRENT_TIME")
839        self.assert_compile(func.foo(), "foo()")
840        m = MetaData()
841        t = Table(
842            "sometable", m, Column("col1", Integer), Column("col2", Integer)
843        )
844        self.assert_compile(
845            select(func.max(t.c.col1)),
846            "SELECT max(sometable.col1) AS max_1 FROM " "sometable",
847        )
848
849    def test_function_overrides(self):
850        self.assert_compile(func.current_date(), "GETDATE()")
851        self.assert_compile(func.length(3), "LEN(:length_1)")
852
853    def test_extract(self):
854        t = table("t", column("col1"))
855
856        for field in "day", "month", "year":
857            self.assert_compile(
858                select(extract(field, t.c.col1)),
859                "SELECT DATEPART(%s, t.col1) AS anon_1 FROM t" % field,
860            )
861
862    def test_update_returning(self):
863        table1 = table(
864            "mytable",
865            column("myid", Integer),
866            column("name", String(128)),
867            column("description", String(128)),
868        )
869        u = (
870            update(table1)
871            .values(dict(name="foo"))
872            .returning(table1.c.myid, table1.c.name)
873        )
874        self.assert_compile(
875            u,
876            "UPDATE mytable SET name=:name OUTPUT "
877            "inserted.myid, inserted.name",
878        )
879        u = update(table1).values(dict(name="foo")).returning(table1)
880        self.assert_compile(
881            u,
882            "UPDATE mytable SET name=:name OUTPUT "
883            "inserted.myid, inserted.name, "
884            "inserted.description",
885        )
886        u = (
887            update(table1)
888            .values(dict(name="foo"))
889            .returning(table1)
890            .where(table1.c.name == "bar")
891        )
892        self.assert_compile(
893            u,
894            "UPDATE mytable SET name=:name OUTPUT "
895            "inserted.myid, inserted.name, "
896            "inserted.description WHERE mytable.name = "
897            ":name_1",
898        )
899        u = (
900            update(table1)
901            .values(dict(name="foo"))
902            .returning(func.length(table1.c.name))
903        )
904        self.assert_compile(
905            u,
906            "UPDATE mytable SET name=:name OUTPUT "
907            "LEN(inserted.name) AS length_1",
908        )
909
910    def test_delete_returning(self):
911        table1 = table(
912            "mytable",
913            column("myid", Integer),
914            column("name", String(128)),
915            column("description", String(128)),
916        )
917        d = delete(table1).returning(table1.c.myid, table1.c.name)
918        self.assert_compile(
919            d, "DELETE FROM mytable OUTPUT deleted.myid, " "deleted.name"
920        )
921        d = (
922            delete(table1)
923            .where(table1.c.name == "bar")
924            .returning(table1.c.myid, table1.c.name)
925        )
926        self.assert_compile(
927            d,
928            "DELETE FROM mytable OUTPUT deleted.myid, "
929            "deleted.name WHERE mytable.name = :name_1",
930        )
931
932    def test_insert_returning(self):
933        table1 = table(
934            "mytable",
935            column("myid", Integer),
936            column("name", String(128)),
937            column("description", String(128)),
938        )
939        i = (
940            insert(table1)
941            .values(dict(name="foo"))
942            .returning(table1.c.myid, table1.c.name)
943        )
944        self.assert_compile(
945            i,
946            "INSERT INTO mytable (name) OUTPUT "
947            "inserted.myid, inserted.name VALUES "
948            "(:name)",
949        )
950        i = insert(table1).values(dict(name="foo")).returning(table1)
951        self.assert_compile(
952            i,
953            "INSERT INTO mytable (name) OUTPUT "
954            "inserted.myid, inserted.name, "
955            "inserted.description VALUES (:name)",
956        )
957        i = (
958            insert(table1)
959            .values(dict(name="foo"))
960            .returning(func.length(table1.c.name))
961        )
962        self.assert_compile(
963            i,
964            "INSERT INTO mytable (name) OUTPUT "
965            "LEN(inserted.name) AS length_1 VALUES "
966            "(:name)",
967        )
968
969    def test_limit_using_top(self):
970        t = table("t", column("x", Integer), column("y", Integer))
971
972        s = select(t).where(t.c.x == 5).order_by(t.c.y).limit(10)
973
974        self.assert_compile(
975            s,
976            "SELECT TOP __[POSTCOMPILE_param_1] t.x, t.y FROM t "
977            "WHERE t.x = :x_1 ORDER BY t.y",
978            checkparams={"x_1": 5, "param_1": 10},
979        )
980
981    def test_limit_using_top_literal_binds(self):
982        """test #6863"""
983        t = table("t", column("x", Integer), column("y", Integer))
984
985        s = select(t).where(t.c.x == 5).order_by(t.c.y).limit(10)
986
987        eq_ignore_whitespace(
988            str(
989                s.compile(
990                    dialect=mssql.dialect(),
991                    compile_kwargs={"literal_binds": True},
992                )
993            ),
994            "SELECT TOP 10 t.x, t.y FROM t WHERE t.x = 5 ORDER BY t.y",
995        )
996
997    def test_limit_zero_using_top(self):
998        t = table("t", column("x", Integer), column("y", Integer))
999
1000        s = select(t).where(t.c.x == 5).order_by(t.c.y).limit(0)
1001
1002        self.assert_compile(
1003            s,
1004            "SELECT TOP __[POSTCOMPILE_param_1] t.x, t.y FROM t "
1005            "WHERE t.x = :x_1 ORDER BY t.y",
1006            checkparams={"x_1": 5, "param_1": 0},
1007        )
1008        c = s.compile(dialect=mssql.dialect())
1009        eq_(len(c._result_columns), 2)
1010        assert t.c.x in set(c._create_result_map()["x"][1])
1011
1012    def test_offset_using_window(self):
1013        t = table("t", column("x", Integer), column("y", Integer))
1014
1015        s = select(t).where(t.c.x == 5).order_by(t.c.y).offset(20)
1016
1017        # test that the select is not altered with subsequent compile
1018        # calls
1019        for i in range(2):
1020            self.assert_compile(
1021                s,
1022                "SELECT anon_1.x, anon_1.y FROM (SELECT t.x AS x, t.y "
1023                "AS y, ROW_NUMBER() OVER (ORDER BY t.y) AS "
1024                "mssql_rn FROM t WHERE t.x = :x_1) AS "
1025                "anon_1 WHERE mssql_rn > :param_1",
1026                checkparams={"param_1": 20, "x_1": 5},
1027            )
1028
1029            c = s.compile(dialect=mssql.dialect())
1030            eq_(len(c._result_columns), 2)
1031            assert t.c.x in set(c._create_result_map()["x"][1])
1032
1033    def test_simple_limit_expression_offset_using_window(self):
1034        t = table("t", column("x", Integer), column("y", Integer))
1035
1036        s = (
1037            select(t)
1038            .where(t.c.x == 5)
1039            .order_by(t.c.y)
1040            .limit(10)
1041            .offset(literal_column("20"))
1042        )
1043
1044        self.assert_compile(
1045            s,
1046            "SELECT anon_1.x, anon_1.y "
1047            "FROM (SELECT t.x AS x, t.y AS y, "
1048            "ROW_NUMBER() OVER (ORDER BY t.y) AS mssql_rn "
1049            "FROM t "
1050            "WHERE t.x = :x_1) AS anon_1 "
1051            "WHERE mssql_rn > 20 AND mssql_rn <= :param_1 + 20",
1052            checkparams={"param_1": 10, "x_1": 5},
1053        )
1054
1055    def test_limit_offset_using_window(self):
1056        t = table("t", column("x", Integer), column("y", Integer))
1057
1058        s = select(t).where(t.c.x == 5).order_by(t.c.y).limit(10).offset(20)
1059
1060        self.assert_compile(
1061            s,
1062            "SELECT anon_1.x, anon_1.y "
1063            "FROM (SELECT t.x AS x, t.y AS y, "
1064            "ROW_NUMBER() OVER (ORDER BY t.y) AS mssql_rn "
1065            "FROM t "
1066            "WHERE t.x = :x_1) AS anon_1 "
1067            "WHERE mssql_rn > :param_1 AND mssql_rn <= :param_2 + :param_1",
1068            checkparams={"param_1": 20, "param_2": 10, "x_1": 5},
1069        )
1070        c = s.compile(dialect=mssql.dialect())
1071        eq_(len(c._result_columns), 2)
1072        assert t.c.x in set(c._create_result_map()["x"][1])
1073        assert t.c.y in set(c._create_result_map()["y"][1])
1074
1075    def test_limit_offset_using_offset_fetch(self, dialect_2012):
1076        t = table("t", column("x", Integer), column("y", Integer))
1077        s = select(t).where(t.c.x == 5).order_by(t.c.y).limit(10).offset(20)
1078
1079        self.assert_compile(
1080            s,
1081            "SELECT t.x, t.y "
1082            "FROM t "
1083            "WHERE t.x = :x_1 ORDER BY t.y "
1084            "OFFSET :param_1 ROWS "
1085            "FETCH FIRST :param_2 ROWS ONLY",
1086            checkparams={"param_1": 20, "param_2": 10, "x_1": 5},
1087            dialect=dialect_2012,
1088        )
1089
1090        c = s.compile(dialect=dialect_2012)
1091        eq_(len(c._result_columns), 2)
1092        assert t.c.x in set(c._create_result_map()["x"][1])
1093        assert t.c.y in set(c._create_result_map()["y"][1])
1094
1095    def test_limit_offset_w_ambiguous_cols(self):
1096        t = table("t", column("x", Integer), column("y", Integer))
1097
1098        cols = [t.c.x, t.c.x.label("q"), t.c.x.label("p"), t.c.y]
1099        s = (
1100            select(*cols)
1101            .where(t.c.x == 5)
1102            .order_by(t.c.y)
1103            .limit(10)
1104            .offset(20)
1105        )
1106
1107        self.assert_compile(
1108            s,
1109            "SELECT anon_1.x, anon_1.q, anon_1.p, anon_1.y "
1110            "FROM (SELECT t.x AS x, t.x AS q, t.x AS p, t.y AS y, "
1111            "ROW_NUMBER() OVER (ORDER BY t.y) AS mssql_rn "
1112            "FROM t "
1113            "WHERE t.x = :x_1) AS anon_1 "
1114            "WHERE mssql_rn > :param_1 AND mssql_rn <= :param_2 + :param_1",
1115            checkparams={"param_1": 20, "param_2": 10, "x_1": 5},
1116        )
1117        c = s.compile(dialect=mssql.dialect())
1118        eq_(len(c._result_columns), 4)
1119
1120        result_map = c._create_result_map()
1121
1122        for col in cols:
1123            is_(result_map[col.key][1][0], col)
1124
1125    def test_limit_offset_with_correlated_order_by(self):
1126        t1 = table("t1", column("x", Integer), column("y", Integer))
1127        t2 = table("t2", column("x", Integer), column("y", Integer))
1128
1129        order_by = select(t2.c.y).where(t1.c.x == t2.c.x).scalar_subquery()
1130        s = (
1131            select(t1)
1132            .where(t1.c.x == 5)
1133            .order_by(order_by)
1134            .limit(10)
1135            .offset(20)
1136        )
1137
1138        self.assert_compile(
1139            s,
1140            "SELECT anon_1.x, anon_1.y "
1141            "FROM (SELECT t1.x AS x, t1.y AS y, "
1142            "ROW_NUMBER() OVER (ORDER BY "
1143            "(SELECT t2.y FROM t2 WHERE t1.x = t2.x)"
1144            ") AS mssql_rn "
1145            "FROM t1 "
1146            "WHERE t1.x = :x_1) AS anon_1 "
1147            "WHERE mssql_rn > :param_1 AND mssql_rn <= :param_2 + :param_1",
1148            checkparams={"param_1": 20, "param_2": 10, "x_1": 5},
1149        )
1150
1151        c = s.compile(dialect=mssql.dialect())
1152        eq_(len(c._result_columns), 2)
1153        assert t1.c.x in set(c._create_result_map()["x"][1])
1154        assert t1.c.y in set(c._create_result_map()["y"][1])
1155
1156    def test_offset_dont_misapply_labelreference(self):
1157        m = MetaData()
1158
1159        t = Table("t", m, Column("x", Integer))
1160
1161        expr1 = func.foo(t.c.x).label("x")
1162        expr2 = func.foo(t.c.x).label("y")
1163
1164        stmt1 = select(expr1).order_by(expr1.desc()).offset(1)
1165        stmt2 = select(expr2).order_by(expr2.desc()).offset(1)
1166
1167        self.assert_compile(
1168            stmt1,
1169            "SELECT anon_1.x FROM (SELECT foo(t.x) AS x, "
1170            "ROW_NUMBER() OVER (ORDER BY foo(t.x) DESC) AS mssql_rn FROM t) "
1171            "AS anon_1 WHERE mssql_rn > :param_1",
1172        )
1173
1174        self.assert_compile(
1175            stmt2,
1176            "SELECT anon_1.y FROM (SELECT foo(t.x) AS y, "
1177            "ROW_NUMBER() OVER (ORDER BY foo(t.x) DESC) AS mssql_rn FROM t) "
1178            "AS anon_1 WHERE mssql_rn > :param_1",
1179        )
1180
1181    def test_limit_zero_offset_using_window(self):
1182        t = table("t", column("x", Integer), column("y", Integer))
1183
1184        s = select(t).where(t.c.x == 5).order_by(t.c.y).limit(0).offset(0)
1185
1186        # offset is zero but we need to cache a compatible statement
1187        self.assert_compile(
1188            s,
1189            "SELECT anon_1.x, anon_1.y FROM (SELECT t.x AS x, t.y AS y, "
1190            "ROW_NUMBER() OVER (ORDER BY t.y) AS mssql_rn FROM t "
1191            "WHERE t.x = :x_1) AS anon_1 WHERE mssql_rn > :param_1 "
1192            "AND mssql_rn <= :param_2 + :param_1",
1193            checkparams={"x_1": 5, "param_1": 0, "param_2": 0},
1194        )
1195
1196    def test_limit_zero_using_window(self):
1197        t = table("t", column("x", Integer), column("y", Integer))
1198
1199        s = select(t).where(t.c.x == 5).order_by(t.c.y).limit(0)
1200
1201        # render the LIMIT of zero, but not the OFFSET
1202        # of zero, so produces TOP 0
1203        self.assert_compile(
1204            s,
1205            "SELECT TOP __[POSTCOMPILE_param_1] t.x, t.y FROM t "
1206            "WHERE t.x = :x_1 ORDER BY t.y",
1207            checkparams={"x_1": 5, "param_1": 0},
1208        )
1209
1210    def test_table_pkc_clustering(self):
1211        metadata = MetaData()
1212        tbl = Table(
1213            "test",
1214            metadata,
1215            Column("x", Integer, autoincrement=False),
1216            Column("y", Integer, autoincrement=False),
1217            PrimaryKeyConstraint("x", "y", mssql_clustered=True),
1218        )
1219        self.assert_compile(
1220            schema.CreateTable(tbl),
1221            "CREATE TABLE test (x INTEGER NOT NULL, y INTEGER NOT NULL, "
1222            "PRIMARY KEY CLUSTERED (x, y))",
1223        )
1224
1225    def test_table_pkc_explicit_nonclustered(self):
1226        metadata = MetaData()
1227        tbl = Table(
1228            "test",
1229            metadata,
1230            Column("x", Integer, autoincrement=False),
1231            Column("y", Integer, autoincrement=False),
1232            PrimaryKeyConstraint("x", "y", mssql_clustered=False),
1233        )
1234        self.assert_compile(
1235            schema.CreateTable(tbl),
1236            "CREATE TABLE test (x INTEGER NOT NULL, y INTEGER NOT NULL, "
1237            "PRIMARY KEY NONCLUSTERED (x, y))",
1238        )
1239
1240    def test_table_idx_explicit_nonclustered(self):
1241        metadata = MetaData()
1242        tbl = Table(
1243            "test",
1244            metadata,
1245            Column("x", Integer, autoincrement=False),
1246            Column("y", Integer, autoincrement=False),
1247        )
1248
1249        idx = Index("myidx", tbl.c.x, tbl.c.y, mssql_clustered=False)
1250        self.assert_compile(
1251            schema.CreateIndex(idx),
1252            "CREATE NONCLUSTERED INDEX myidx ON test (x, y)",
1253        )
1254
1255    def test_table_uc_explicit_nonclustered(self):
1256        metadata = MetaData()
1257        tbl = Table(
1258            "test",
1259            metadata,
1260            Column("x", Integer, autoincrement=False),
1261            Column("y", Integer, autoincrement=False),
1262            UniqueConstraint("x", "y", mssql_clustered=False),
1263        )
1264        self.assert_compile(
1265            schema.CreateTable(tbl),
1266            "CREATE TABLE test (x INTEGER NULL, y INTEGER NULL, "
1267            "UNIQUE NONCLUSTERED (x, y))",
1268        )
1269
1270    def test_table_uc_clustering(self):
1271        metadata = MetaData()
1272        tbl = Table(
1273            "test",
1274            metadata,
1275            Column("x", Integer, autoincrement=False),
1276            Column("y", Integer, autoincrement=False),
1277            PrimaryKeyConstraint("x"),
1278            UniqueConstraint("y", mssql_clustered=True),
1279        )
1280        self.assert_compile(
1281            schema.CreateTable(tbl),
1282            "CREATE TABLE test (x INTEGER NOT NULL, y INTEGER NULL, "
1283            "PRIMARY KEY (x), UNIQUE CLUSTERED (y))",
1284        )
1285
1286    def test_index_clustering(self):
1287        metadata = MetaData()
1288        tbl = Table("test", metadata, Column("id", Integer))
1289        idx = Index("foo", tbl.c.id, mssql_clustered=True)
1290        self.assert_compile(
1291            schema.CreateIndex(idx), "CREATE CLUSTERED INDEX foo ON test (id)"
1292        )
1293
1294    def test_index_where(self):
1295        metadata = MetaData()
1296        tbl = Table("test", metadata, Column("data", Integer))
1297        idx = Index("test_idx_data_1", tbl.c.data, mssql_where=tbl.c.data > 1)
1298        self.assert_compile(
1299            schema.CreateIndex(idx),
1300            "CREATE INDEX test_idx_data_1 ON test (data) WHERE data > 1",
1301        )
1302
1303        idx = Index("test_idx_data_1", tbl.c.data, mssql_where="data > 1")
1304        self.assert_compile(
1305            schema.CreateIndex(idx),
1306            "CREATE INDEX test_idx_data_1 ON test (data) WHERE data > 1",
1307        )
1308
1309    def test_index_ordering(self):
1310        metadata = MetaData()
1311        tbl = Table(
1312            "test",
1313            metadata,
1314            Column("x", Integer),
1315            Column("y", Integer),
1316            Column("z", Integer),
1317        )
1318        idx = Index("foo", tbl.c.x.desc(), "y")
1319        self.assert_compile(
1320            schema.CreateIndex(idx), "CREATE INDEX foo ON test (x DESC, y)"
1321        )
1322
1323    def test_create_index_expr(self):
1324        m = MetaData()
1325        t1 = Table("foo", m, Column("x", Integer))
1326        self.assert_compile(
1327            schema.CreateIndex(Index("bar", t1.c.x > 5)),
1328            "CREATE INDEX bar ON foo (x > 5)",
1329        )
1330
1331    def test_drop_index_w_schema(self):
1332        m = MetaData()
1333        t1 = Table("foo", m, Column("x", Integer), schema="bar")
1334        self.assert_compile(
1335            schema.DropIndex(Index("idx_foo", t1.c.x)),
1336            "DROP INDEX idx_foo ON bar.foo",
1337        )
1338
1339    def test_index_extra_include_1(self):
1340        metadata = MetaData()
1341        tbl = Table(
1342            "test",
1343            metadata,
1344            Column("x", Integer),
1345            Column("y", Integer),
1346            Column("z", Integer),
1347        )
1348        idx = Index("foo", tbl.c.x, mssql_include=["y"])
1349        self.assert_compile(
1350            schema.CreateIndex(idx), "CREATE INDEX foo ON test (x) INCLUDE (y)"
1351        )
1352
1353    def test_index_extra_include_2(self):
1354        metadata = MetaData()
1355        tbl = Table(
1356            "test",
1357            metadata,
1358            Column("x", Integer),
1359            Column("y", Integer),
1360            Column("z", Integer),
1361        )
1362        idx = Index("foo", tbl.c.x, mssql_include=[tbl.c.y])
1363        self.assert_compile(
1364            schema.CreateIndex(idx), "CREATE INDEX foo ON test (x) INCLUDE (y)"
1365        )
1366
1367    def test_index_include_where(self):
1368        metadata = MetaData()
1369        tbl = Table(
1370            "test",
1371            metadata,
1372            Column("x", Integer),
1373            Column("y", Integer),
1374            Column("z", Integer),
1375        )
1376        idx = Index(
1377            "foo", tbl.c.x, mssql_include=[tbl.c.y], mssql_where=tbl.c.y > 1
1378        )
1379        self.assert_compile(
1380            schema.CreateIndex(idx),
1381            "CREATE INDEX foo ON test (x) INCLUDE (y) WHERE y > 1",
1382        )
1383
1384        idx = Index(
1385            "foo", tbl.c.x, mssql_include=[tbl.c.y], mssql_where=text("y > 1")
1386        )
1387        self.assert_compile(
1388            schema.CreateIndex(idx),
1389            "CREATE INDEX foo ON test (x) INCLUDE (y) WHERE y > 1",
1390        )
1391
1392    def test_try_cast(self):
1393        metadata = MetaData()
1394        t1 = Table("t1", metadata, Column("id", Integer, primary_key=True))
1395
1396        self.assert_compile(
1397            select(try_cast(t1.c.id, Integer)),
1398            "SELECT TRY_CAST (t1.id AS INTEGER) AS id FROM t1",
1399        )
1400
1401    @testing.combinations(
1402        ("no_persisted", "", "ignore"),
1403        ("persisted_none", "", None),
1404        ("persisted_true", " PERSISTED", True),
1405        ("persisted_false", "", False),
1406        id_="iaa",
1407    )
1408    def test_column_computed(self, text, persisted):
1409        m = MetaData()
1410        kwargs = {"persisted": persisted} if persisted != "ignore" else {}
1411        t = Table(
1412            "t",
1413            m,
1414            Column("x", Integer),
1415            Column("y", Integer, Computed("x + 2", **kwargs)),
1416        )
1417        self.assert_compile(
1418            schema.CreateTable(t),
1419            "CREATE TABLE t (x INTEGER NULL, y AS (x + 2)%s)" % text,
1420        )
1421
1422    @testing.combinations(
1423        (
1424            5,
1425            10,
1426            {},
1427            "OFFSET :param_1 ROWS FETCH FIRST :param_2 ROWS ONLY",
1428            {"param_1": 10, "param_2": 5},
1429        ),
1430        (None, 10, {}, "OFFSET :param_1 ROWS", {"param_1": 10}),
1431        (
1432            5,
1433            None,
1434            {},
1435            "OFFSET 0 ROWS FETCH FIRST :param_1 ROWS ONLY",
1436            {"param_1": 5},
1437        ),
1438        (
1439            0,
1440            0,
1441            {},
1442            "OFFSET :param_1 ROWS FETCH FIRST :param_2 ROWS ONLY",
1443            {"param_1": 0, "param_2": 0},
1444        ),
1445        (
1446            5,
1447            0,
1448            {"percent": True},
1449            "TOP __[POSTCOMPILE_param_1] PERCENT",
1450            {"param_1": 5},
1451        ),
1452        (
1453            5,
1454            None,
1455            {"percent": True, "with_ties": True},
1456            "TOP __[POSTCOMPILE_param_1] PERCENT WITH TIES",
1457            {"param_1": 5},
1458        ),
1459        (
1460            5,
1461            0,
1462            {"with_ties": True},
1463            "TOP __[POSTCOMPILE_param_1] WITH TIES",
1464            {"param_1": 5},
1465        ),
1466        (
1467            literal_column("Q"),
1468            literal_column("Y"),
1469            {},
1470            "OFFSET Y ROWS FETCH FIRST Q ROWS ONLY",
1471            {},
1472        ),
1473        (
1474            column("Q"),
1475            column("Y"),
1476            {},
1477            "OFFSET [Y] ROWS FETCH FIRST [Q] ROWS ONLY",
1478            {},
1479        ),
1480        (
1481            bindparam("Q", 3),
1482            bindparam("Y", 7),
1483            {},
1484            "OFFSET :Y ROWS FETCH FIRST :Q ROWS ONLY",
1485            {"Q": 3, "Y": 7},
1486        ),
1487        (
1488            literal_column("Q") + literal_column("Z"),
1489            literal_column("Y") + literal_column("W"),
1490            {},
1491            "OFFSET Y + W ROWS FETCH FIRST Q + Z ROWS ONLY",
1492            {},
1493        ),
1494        argnames="fetch, offset, fetch_kw, exp, params",
1495    )
1496    def test_fetch(self, dialect_2012, fetch, offset, fetch_kw, exp, params):
1497        t = table("t", column("a"))
1498        if "TOP" in exp:
1499            sel = "SELECT %s t.a FROM t ORDER BY t.a" % exp
1500        else:
1501            sel = "SELECT t.a FROM t ORDER BY t.a " + exp
1502
1503        stmt = select(t).order_by(t.c.a).fetch(fetch, **fetch_kw)
1504        if "with_ties" not in fetch_kw and "percent" not in fetch_kw:
1505            stmt = stmt.offset(offset)
1506
1507        self.assert_compile(
1508            stmt,
1509            sel,
1510            checkparams=params,
1511            dialect=dialect_2012,
1512        )
1513
1514    @testing.combinations(
1515        (
1516            5,
1517            10,
1518            {},
1519            "mssql_rn > :param_1 AND mssql_rn <= :param_2 + :param_1",
1520            {"param_1": 10, "param_2": 5},
1521        ),
1522        (None, 10, {}, "mssql_rn > :param_1", {"param_1": 10}),
1523        (
1524            5,
1525            None,
1526            {},
1527            "mssql_rn <= :param_1",
1528            {"param_1": 5},
1529        ),
1530        (
1531            0,
1532            0,
1533            {},
1534            "mssql_rn > :param_1 AND mssql_rn <= :param_2 + :param_1",
1535            {"param_1": 0, "param_2": 0},
1536        ),
1537        (
1538            5,
1539            0,
1540            {"percent": True},
1541            "TOP __[POSTCOMPILE_param_1] PERCENT",
1542            {"param_1": 5},
1543        ),
1544        (
1545            5,
1546            None,
1547            {"percent": True, "with_ties": True},
1548            "TOP __[POSTCOMPILE_param_1] PERCENT WITH TIES",
1549            {"param_1": 5},
1550        ),
1551        (
1552            5,
1553            0,
1554            {"with_ties": True},
1555            "TOP __[POSTCOMPILE_param_1] WITH TIES",
1556            {"param_1": 5},
1557        ),
1558        (
1559            literal_column("Q"),
1560            literal_column("Y"),
1561            {},
1562            "mssql_rn > Y AND mssql_rn <= Q + Y",
1563            {},
1564        ),
1565        (
1566            column("Q"),
1567            column("Y"),
1568            {},
1569            "mssql_rn > [Y] AND mssql_rn <= [Q] + [Y]",
1570            {},
1571        ),
1572        (
1573            bindparam("Q", 3),
1574            bindparam("Y", 7),
1575            {},
1576            "mssql_rn > :Y AND mssql_rn <= :Q + :Y",
1577            {"Q": 3, "Y": 7},
1578        ),
1579        (
1580            literal_column("Q") + literal_column("Z"),
1581            literal_column("Y") + literal_column("W"),
1582            {},
1583            "mssql_rn > Y + W AND mssql_rn <= Q + Z + Y + W",
1584            {},
1585        ),
1586        argnames="fetch, offset, fetch_kw, exp, params",
1587    )
1588    def test_fetch_old_version(self, fetch, offset, fetch_kw, exp, params):
1589        t = table("t", column("a"))
1590        if "TOP" in exp:
1591            sel = "SELECT %s t.a FROM t ORDER BY t.a" % exp
1592        else:
1593            sel = (
1594                "SELECT anon_1.a FROM (SELECT t.a AS a, ROW_NUMBER() "
1595                "OVER (ORDER BY t.a) AS mssql_rn FROM t) AS anon_1 WHERE "
1596                + exp
1597            )
1598
1599        stmt = select(t).order_by(t.c.a).fetch(fetch, **fetch_kw)
1600        if "with_ties" not in fetch_kw and "percent" not in fetch_kw:
1601            stmt = stmt.offset(offset)
1602
1603        self.assert_compile(
1604            stmt,
1605            sel,
1606            checkparams=params,
1607        )
1608
1609    _no_offset = (
1610        "MSSQL needs TOP to use PERCENT and/or WITH TIES. "
1611        "Only simple fetch without offset can be used."
1612    )
1613
1614    _order_by = (
1615        "MSSQL requires an order_by when using an OFFSET "
1616        "or a non-simple LIMIT clause"
1617    )
1618
1619    @testing.combinations(
1620        (
1621            select(tbl).order_by(tbl.c.a).fetch(5, percent=True).offset(3),
1622            _no_offset,
1623        ),
1624        (
1625            select(tbl).order_by(tbl.c.a).fetch(5, with_ties=True).offset(3),
1626            _no_offset,
1627        ),
1628        (
1629            select(tbl)
1630            .order_by(tbl.c.a)
1631            .fetch(5, percent=True, with_ties=True)
1632            .offset(3),
1633            _no_offset,
1634        ),
1635        (
1636            select(tbl)
1637            .order_by(tbl.c.a)
1638            .fetch(bindparam("x"), with_ties=True),
1639            _no_offset,
1640        ),
1641        (select(tbl).fetch(5).offset(3), _order_by),
1642        (select(tbl).fetch(5), _order_by),
1643        (select(tbl).offset(5), _order_by),
1644        argnames="stmt, error",
1645    )
1646    def test_row_limit_compile_error(self, dialect_2012, stmt, error):
1647        with testing.expect_raises_message(exc.CompileError, error):
1648            print(stmt.compile(dialect=dialect_2012))
1649        with testing.expect_raises_message(exc.CompileError, error):
1650            print(stmt.compile(dialect=self.__dialect__))
1651
1652
1653class CompileIdentityTest(fixtures.TestBase, AssertsCompiledSQL):
1654    __dialect__ = mssql.dialect()
1655
1656    def assert_compile_with_warning(self, *args, **kwargs):
1657        with testing.expect_deprecated(
1658            "The dialect options 'mssql_identity_start' and "
1659            "'mssql_identity_increment' are deprecated. "
1660            "Use the 'Identity' object instead."
1661        ):
1662            return self.assert_compile(*args, **kwargs)
1663
1664    def test_primary_key_no_identity(self):
1665        metadata = MetaData()
1666        tbl = Table(
1667            "test",
1668            metadata,
1669            Column("id", Integer, autoincrement=False, primary_key=True),
1670        )
1671        self.assert_compile(
1672            schema.CreateTable(tbl),
1673            "CREATE TABLE test (id INTEGER NOT NULL, PRIMARY KEY (id))",
1674        )
1675
1676    def test_primary_key_defaults_to_identity(self):
1677        metadata = MetaData()
1678        tbl = Table("test", metadata, Column("id", Integer, primary_key=True))
1679        self.assert_compile(
1680            schema.CreateTable(tbl),
1681            "CREATE TABLE test (id INTEGER NOT NULL IDENTITY, "
1682            "PRIMARY KEY (id))",
1683        )
1684
1685    def test_primary_key_with_identity_object(self):
1686        metadata = MetaData()
1687        tbl = Table(
1688            "test",
1689            metadata,
1690            Column(
1691                "id",
1692                Integer,
1693                Identity(start=3, increment=42),
1694                primary_key=True,
1695            ),
1696        )
1697        self.assert_compile(
1698            schema.CreateTable(tbl),
1699            "CREATE TABLE test (id INTEGER NOT NULL IDENTITY(3,42), "
1700            "PRIMARY KEY (id))",
1701        )
1702
1703    def test_identity_no_primary_key(self):
1704        metadata = MetaData()
1705        tbl = Table(
1706            "test", metadata, Column("id", Integer, autoincrement=True)
1707        )
1708        self.assert_compile(
1709            schema.CreateTable(tbl),
1710            "CREATE TABLE test (id INTEGER NOT NULL IDENTITY)",
1711        )
1712
1713    def test_identity_object_no_primary_key(self):
1714        metadata = MetaData()
1715        tbl = Table(
1716            "test",
1717            metadata,
1718            Column("id", Integer, Identity(increment=42)),
1719        )
1720        self.assert_compile(
1721            schema.CreateTable(tbl),
1722            "CREATE TABLE test (id INTEGER NOT NULL IDENTITY(1,42))",
1723        )
1724
1725    def test_identity_object_1_1(self):
1726        metadata = MetaData()
1727        tbl = Table(
1728            "test",
1729            metadata,
1730            Column("id", Integer, Identity(start=1, increment=1)),
1731        )
1732        self.assert_compile(
1733            schema.CreateTable(tbl),
1734            "CREATE TABLE test (id INTEGER NOT NULL IDENTITY(1,1))",
1735        )
1736
1737    def test_identity_object_no_primary_key_non_nullable(self):
1738        metadata = MetaData()
1739        tbl = Table(
1740            "test",
1741            metadata,
1742            Column(
1743                "id",
1744                Integer,
1745                Identity(start=3),
1746                nullable=False,
1747            ),
1748        )
1749        self.assert_compile(
1750            schema.CreateTable(tbl),
1751            "CREATE TABLE test (id INTEGER NOT NULL IDENTITY(3,1)" ")",
1752        )
1753
1754    def test_identity_separate_from_primary_key(self):
1755        metadata = MetaData()
1756        tbl = Table(
1757            "test",
1758            metadata,
1759            Column("id", Integer, autoincrement=False, primary_key=True),
1760            Column("x", Integer, autoincrement=True),
1761        )
1762        self.assert_compile(
1763            schema.CreateTable(tbl),
1764            "CREATE TABLE test (id INTEGER NOT NULL, "
1765            "x INTEGER NOT NULL IDENTITY, "
1766            "PRIMARY KEY (id))",
1767        )
1768
1769    def test_identity_object_separate_from_primary_key(self):
1770        metadata = MetaData()
1771        tbl = Table(
1772            "test",
1773            metadata,
1774            Column("id", Integer, autoincrement=False, primary_key=True),
1775            Column(
1776                "x",
1777                Integer,
1778                Identity(start=3, increment=42),
1779            ),
1780        )
1781        self.assert_compile(
1782            schema.CreateTable(tbl),
1783            "CREATE TABLE test (id INTEGER NOT NULL, "
1784            "x INTEGER NOT NULL IDENTITY(3,42), "
1785            "PRIMARY KEY (id))",
1786        )
1787
1788    def test_identity_illegal_two_autoincrements(self):
1789        metadata = MetaData()
1790        tbl = Table(
1791            "test",
1792            metadata,
1793            Column("id", Integer, autoincrement=True),
1794            Column("id2", Integer, autoincrement=True),
1795        )
1796        # this will be rejected by the database, just asserting this is what
1797        # the two autoincrements will do right now
1798        self.assert_compile(
1799            schema.CreateTable(tbl),
1800            "CREATE TABLE test (id INTEGER NOT NULL IDENTITY, "
1801            "id2 INTEGER NOT NULL IDENTITY)",
1802        )
1803
1804    def test_identity_object_illegal_two_autoincrements(self):
1805        metadata = MetaData()
1806        tbl = Table(
1807            "test",
1808            metadata,
1809            Column(
1810                "id",
1811                Integer,
1812                Identity(start=3, increment=42),
1813                autoincrement=True,
1814            ),
1815            Column(
1816                "id2",
1817                Integer,
1818                Identity(start=7, increment=2),
1819            ),
1820        )
1821        # this will be rejected by the database, just asserting this is what
1822        # the two autoincrements will do right now
1823        self.assert_compile(
1824            schema.CreateTable(tbl),
1825            "CREATE TABLE test (id INTEGER NOT NULL IDENTITY(3,42), "
1826            "id2 INTEGER NOT NULL IDENTITY(7,2))",
1827        )
1828
1829    def test_identity_start_0(self):
1830        metadata = MetaData()
1831        tbl = Table(
1832            "test",
1833            metadata,
1834            Column("id", Integer, mssql_identity_start=0, primary_key=True),
1835        )
1836        self.assert_compile_with_warning(
1837            schema.CreateTable(tbl),
1838            "CREATE TABLE test (id INTEGER NOT NULL IDENTITY(0,1), "
1839            "PRIMARY KEY (id))",
1840        )
1841
1842    def test_identity_increment_5(self):
1843        metadata = MetaData()
1844        tbl = Table(
1845            "test",
1846            metadata,
1847            Column(
1848                "id", Integer, mssql_identity_increment=5, primary_key=True
1849            ),
1850        )
1851        self.assert_compile_with_warning(
1852            schema.CreateTable(tbl),
1853            "CREATE TABLE test (id INTEGER NOT NULL IDENTITY(1,5), "
1854            "PRIMARY KEY (id))",
1855        )
1856
1857    @testing.combinations(
1858        schema.CreateTable(
1859            Table(
1860                "test",
1861                MetaData(),
1862                Column(
1863                    "id",
1864                    Integer,
1865                    Identity(start=2, increment=2),
1866                    mssql_identity_start=0,
1867                ),
1868            )
1869        ),
1870        schema.CreateTable(
1871            Table(
1872                "test1",
1873                MetaData(),
1874                Column(
1875                    "id2",
1876                    Integer,
1877                    Identity(start=3, increment=3),
1878                    mssql_identity_increment=5,
1879                ),
1880            )
1881        ),
1882    )
1883    def test_identity_options_ignored_with_identity_object(self, create_table):
1884        assert_raises_message(
1885            exc.CompileError,
1886            "Cannot specify options 'mssql_identity_start' and/or "
1887            "'mssql_identity_increment' while also using the "
1888            "'Identity' construct.",
1889            create_table.compile,
1890            dialect=self.__dialect__,
1891        )
1892
1893    def test_identity_object_no_options(self):
1894        metadata = MetaData()
1895        tbl = Table(
1896            "test",
1897            metadata,
1898            Column("id", Integer, Identity()),
1899        )
1900        self.assert_compile(
1901            schema.CreateTable(tbl),
1902            "CREATE TABLE test (id INTEGER NOT NULL IDENTITY)",
1903        )
1904
1905
class SchemaTest(fixtures.TestBase):
    """NULL / NOT NULL rendering of MSSQL ``get_column_specification``."""

    def setup_test(self):
        # Two-column table; only the String column is inspected by the tests.
        sometable = Table(
            "sometable",
            MetaData(),
            Column("pk_column", Integer),
            Column("test_column", String),
        )
        mssql_dialect = mssql.dialect()
        self.column = sometable.c.test_column
        self.ddl_compiler = mssql_dialect.ddl_compiler(
            mssql_dialect, schema.CreateTable(sometable)
        )

    def _column_spec(self):
        """Return the DDL column specification rendered for ``self.column``."""
        compiler = self.ddl_compiler
        return compiler.get_column_specification(self.column)

    def _spec_with_nullable(self, nullable):
        """Assign ``nullable`` to the column, then render its specification."""
        self.column.nullable = nullable
        return self._column_spec()

    def test_that_mssql_default_nullability_emits_null(self):
        # nullable is left as Column() set it; the spec renders NULL.
        eq_("test_column VARCHAR(max) NULL", self._column_spec())

    def test_that_mssql_none_nullability_does_not_emit_nullability(self):
        eq_("test_column VARCHAR(max)", self._spec_with_nullable(None))

    def test_that_mssql_specified_nullable_emits_null(self):
        eq_("test_column VARCHAR(max) NULL", self._spec_with_nullable(True))

    def test_that_mssql_specified_not_nullable_emits_not_null(self):
        eq_(
            "test_column VARCHAR(max) NOT NULL",
            self._spec_with_nullable(False),
        )
1938