# -*- encoding: utf-8
import codecs
import datetime
import decimal
import os

import sqlalchemy as sa
from sqlalchemy import Boolean
from sqlalchemy import Column
from sqlalchemy import column
from sqlalchemy import Date
from sqlalchemy import DateTime
from sqlalchemy import DefaultClause
from sqlalchemy import Float
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import LargeBinary
from sqlalchemy import literal
from sqlalchemy import MetaData
from sqlalchemy import Numeric
from sqlalchemy import PickleType
from sqlalchemy import schema
from sqlalchemy import select
from sqlalchemy import Sequence
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy import Text
from sqlalchemy import text
from sqlalchemy import Time
from sqlalchemy import types
from sqlalchemy import Unicode
from sqlalchemy import UnicodeText
from sqlalchemy import util
from sqlalchemy.dialects.mssql import base as mssql
from sqlalchemy.dialects.mssql import ROWVERSION
from sqlalchemy.dialects.mssql import TIMESTAMP
from sqlalchemy.dialects.mssql.base import _MSDate
from sqlalchemy.dialects.mssql.base import BIT
from sqlalchemy.dialects.mssql.base import DATETIMEOFFSET
from sqlalchemy.dialects.mssql.base import MS_2005_VERSION
from sqlalchemy.dialects.mssql.base import MS_2008_VERSION
from sqlalchemy.dialects.mssql.base import TIME
from sqlalchemy.sql import sqltypes
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import AssertsExecutionResults
from sqlalchemy.testing import ComparesTables
from sqlalchemy.testing import emits_warning_on
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing import is_not
from sqlalchemy.testing import pickleable
from sqlalchemy.util import b


class TimeParameterTest(fixtures.TablesTest):
    __only_on__ = "mssql"
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "time_t",
            metadata,
            Column("id", Integer, primary_key=True, autoincrement=False),
            Column("time_col", Time),
        )

    @classmethod
    def insert_data(cls, connection):
        time_t = cls.tables.time_t
        connection.execute(
            time_t.insert(),
            [
                {"id": 1, "time_col": datetime.time(1, 23, 45, 67)},
                {"id": 2, "time_col": datetime.time(12, 0, 0)},
                {"id": 3, "time_col": datetime.time(16, 19, 59, 999999)},
                {"id": 4, "time_col": None},
            ],
        )

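    # in testing.combinations, id_="iaa" describes the tuple positionally:
    # "i" marks the element used as the test id, "a" marks elements passed
    # as arguments (my reading of the sqlalchemy.testing combinations spec)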
    @testing.combinations(
        ("not_null", datetime.time(1, 23, 45, 68), 2),
        ("null", None, 1),
        id_="iaa",
        argnames="time_value, expected_row_count",
    )
    def test_time_as_parameter_to_where(
        self, time_value, expected_row_count, connection
    ):
        # issue #5339
        t = self.tables.time_t

        if time_value is None:
            qry = t.select().where(t.c.time_col.is_(time_value))
        else:
            qry = t.select().where(t.c.time_col >= time_value)
        result = connection.execute(qry).fetchall()
        eq_(len(result), expected_row_count)


class TimeTypeTest(fixtures.TestBase):
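    # these tests drive the type's result processor directly with raw string
    # values, so no database is needed; dialect and coltype are passed as
    # None since the processor only parses strings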
    def test_result_processor_no_microseconds(self):
        expected = datetime.time(12, 34, 56)
        self._assert_result_processor(expected, "12:34:56")

    def test_result_processor_too_many_microseconds(self):
        # microsecond must be in 0..999999, should truncate (6 vs 7 digits)
        expected = datetime.time(12, 34, 56, 123456)
        self._assert_result_processor(expected, "12:34:56.1234567")

    def _assert_result_processor(self, expected, value):
        mssql_time_type = TIME()
        result_processor = mssql_time_type.result_processor(None, None)
        eq_(expected, result_processor(value))

    def test_result_processor_invalid(self):
        mssql_time_type = TIME()
        result_processor = mssql_time_type.result_processor(None, None)
        assert_raises_message(
            ValueError,
            "could not parse 'abc' as a time value",
            result_processor,
            "abc",
        )


class MSDateTypeTest(fixtures.TestBase):
    __only_on__ = "mssql"
    __backend__ = True

    def test_result_processor(self):
        expected = datetime.date(2000, 1, 2)
        self._assert_result_processor(expected, "2000-01-02")

    def _assert_result_processor(self, expected, value):
        mssql_date_type = _MSDate()
        result_processor = mssql_date_type.result_processor(None, None)
        eq_(expected, result_processor(value))

    def test_result_processor_invalid(self):
        mssql_date_type = _MSDate()
        result_processor = mssql_date_type.result_processor(None, None)
        assert_raises_message(
            ValueError,
            "could not parse 'abc' as a date value",
            result_processor,
            "abc",
        )

    def test_extract(self, connection):
        from sqlalchemy import extract

        fivedaysago = datetime.datetime.now() - datetime.timedelta(days=5)
        for field, exp in (
            ("year", fivedaysago.year),
            ("month", fivedaysago.month),
            ("day", fivedaysago.day),
        ):
            r = connection.execute(
                select(extract(field, fivedaysago))
            ).scalar()
            eq_(r, exp)


class RowVersionTest(fixtures.TablesTest):
    __only_on__ = "mssql"
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "rv_t",
            metadata,
            Column("data", String(50)),
            Column("rv", ROWVERSION),
        )

        Table(
            "ts_t",
            metadata,
            Column("data", String(50)),
            Column("rv", TIMESTAMP),
        )

    def test_rowversion_reflection(self):
        # ROWVERSION is only a synonym for TIMESTAMP
        insp = inspect(testing.db)
        assert isinstance(insp.get_columns("rv_t")[1]["type"], TIMESTAMP)

    def test_timestamp_reflection(self):
        insp = inspect(testing.db)
        assert isinstance(insp.get_columns("ts_t")[1]["type"], TIMESTAMP)

    def test_class_hierarchy(self):
        """TIMESTAMP and ROWVERSION aren't datetime types, they're binary."""

        assert issubclass(TIMESTAMP, sqltypes._Binary)
        assert issubclass(ROWVERSION, sqltypes._Binary)

    def test_round_trip_ts(self):
        self._test_round_trip("ts_t", TIMESTAMP, False)

    def test_round_trip_rv(self):
        self._test_round_trip("rv_t", ROWVERSION, False)

    def test_round_trip_ts_int(self):
        self._test_round_trip("ts_t", TIMESTAMP, True)

    def test_round_trip_rv_int(self):
        self._test_round_trip("rv_t", ROWVERSION, True)

    def _test_round_trip(self, tab, cls, convert_int):
        t = Table(
            tab,
            MetaData(),
            Column("data", String(50)),
            Column("rv", cls(convert_int=convert_int)),
        )

        with testing.db.begin() as conn:
            conn.execute(t.insert().values(data="foo"))
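            # @@DBTS returns the database's last-used rowversion value, which
            # should match what the INSERT just stamped onto the row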
            last_ts_1 = conn.exec_driver_sql("SELECT @@DBTS").scalar()

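            # with convert_int=True the type returns the 8-byte binary token
            # as an integer; mirror that here by decoding the hex value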
            if convert_int:
                last_ts_1 = int(codecs.encode(last_ts_1, "hex"), 16)

            eq_(conn.scalar(select(t.c.rv)), last_ts_1)

            conn.execute(
                t.update().values(data="bar").where(t.c.data == "foo")
            )
            last_ts_2 = conn.exec_driver_sql("SELECT @@DBTS").scalar()
            if convert_int:
                last_ts_2 = int(codecs.encode(last_ts_2, "hex"), 16)

            eq_(conn.scalar(select(t.c.rv)), last_ts_2)

    def test_cant_insert_rowvalue(self):
        self._test_cant_insert(self.tables.rv_t)

    def test_cant_insert_timestamp(self):
        self._test_cant_insert(self.tables.ts_t)

    def _test_cant_insert(self, tab):
        with testing.db.connect() as conn:
            assert_raises_message(
                sa.exc.DBAPIError,
                r".*Cannot insert an explicit value into a timestamp column.",
                conn.execute,
                tab.insert().values(data="ins", rv=b"000"),
            )


class TypeDDLTest(fixtures.TestBase):
    def test_boolean(self):
        """Exercise type specification for boolean type."""

        columns = [
            # column type, args, kwargs, expected ddl
            (Boolean, [], {}, "BIT")
        ]

        metadata = MetaData()
        table_args = ["test_mssql_boolean", metadata]
        for index, spec in enumerate(columns):
            type_, args, kw, res = spec
            table_args.append(
                Column("c%s" % index, type_(*args, **kw), nullable=None)
            )

        boolean_table = Table(*table_args)
        dialect = mssql.dialect()
        gen = dialect.ddl_compiler(dialect, schema.CreateTable(boolean_table))

        for col in boolean_table.c:
            index = int(col.name[1:])
            testing.eq_(
                gen.get_column_specification(col),
                "%s %s" % (col.name, columns[index][3]),
            )
            self.assert_(repr(col))

    def test_numeric(self):
        """Exercise type specification and options for numeric types."""

        columns = [
            # column type, args, kwargs, expected ddl
            (types.NUMERIC, [], {}, "NUMERIC"),
            (types.NUMERIC, [None], {}, "NUMERIC"),
            (types.NUMERIC, [12, 4], {}, "NUMERIC(12, 4)"),
            (types.Float, [], {}, "FLOAT"),
            (types.Float, [None], {}, "FLOAT"),
            (types.Float, [12], {}, "FLOAT(12)"),
            (mssql.MSReal, [], {}, "REAL"),
            (types.Integer, [], {}, "INTEGER"),
            (types.BigInteger, [], {}, "BIGINT"),
            (mssql.MSTinyInteger, [], {}, "TINYINT"),
            (types.SmallInteger, [], {}, "SMALLINT"),
        ]

        metadata = MetaData()
        table_args = ["test_mssql_numeric", metadata]
        for index, spec in enumerate(columns):
            type_, args, kw, res = spec
            table_args.append(
                Column("c%s" % index, type_(*args, **kw), nullable=None)
            )

        numeric_table = Table(*table_args)
        dialect = mssql.dialect()
        gen = dialect.ddl_compiler(dialect, schema.CreateTable(numeric_table))

        for col in numeric_table.c:
            index = int(col.name[1:])
            testing.eq_(
                gen.get_column_specification(col),
                "%s %s" % (col.name, columns[index][3]),
            )
            self.assert_(repr(col))

    def test_char(self):
        """Exercise COLLATE-ish options on string types."""

        columns = [
            (mssql.MSChar, [], {}, "CHAR"),
            (mssql.MSChar, [1], {}, "CHAR(1)"),
            (
                mssql.MSChar,
                [1],
                {"collation": "Latin1_General_CI_AS"},
                "CHAR(1) COLLATE Latin1_General_CI_AS",
            ),
            (mssql.MSNChar, [], {}, "NCHAR"),
            (mssql.MSNChar, [1], {}, "NCHAR(1)"),
            (
                mssql.MSNChar,
                [1],
                {"collation": "Latin1_General_CI_AS"},
                "NCHAR(1) COLLATE Latin1_General_CI_AS",
            ),
            (mssql.MSString, [], {}, "VARCHAR(max)"),
            (mssql.MSString, [1], {}, "VARCHAR(1)"),
            (
                mssql.MSString,
                [1],
                {"collation": "Latin1_General_CI_AS"},
                "VARCHAR(1) COLLATE Latin1_General_CI_AS",
            ),
            (mssql.MSNVarchar, [], {}, "NVARCHAR(max)"),
            (mssql.MSNVarchar, [1], {}, "NVARCHAR(1)"),
            (
                mssql.MSNVarchar,
                [1],
                {"collation": "Latin1_General_CI_AS"},
                "NVARCHAR(1) COLLATE Latin1_General_CI_AS",
            ),
            (mssql.MSText, [], {}, "TEXT"),
            (
                mssql.MSText,
                [],
                {"collation": "Latin1_General_CI_AS"},
                "TEXT COLLATE Latin1_General_CI_AS",
            ),
            (mssql.MSNText, [], {}, "NTEXT"),
            (
                mssql.MSNText,
                [],
                {"collation": "Latin1_General_CI_AS"},
                "NTEXT COLLATE Latin1_General_CI_AS",
            ),
        ]

        metadata = MetaData()
        table_args = ["test_mssql_charset", metadata]
        for index, spec in enumerate(columns):
            type_, args, kw, res = spec
            table_args.append(
                Column("c%s" % index, type_(*args, **kw), nullable=None)
            )

        charset_table = Table(*table_args)
        dialect = mssql.dialect()
        gen = dialect.ddl_compiler(dialect, schema.CreateTable(charset_table))

        for col in charset_table.c:
            index = int(col.name[1:])
            testing.eq_(
                gen.get_column_specification(col),
                "%s %s" % (col.name, columns[index][3]),
            )
            self.assert_(repr(col))

    @testing.combinations(
        # column type, args, kwargs, expected ddl
        (mssql.MSDateTime, [], {}, "DATETIME", None),
        (types.DATE, [], {}, "DATE", None),
        (types.Date, [], {}, "DATE", None),
        (types.Date, [], {}, "DATETIME", MS_2005_VERSION),
        (mssql.MSDate, [], {}, "DATE", None),
        (mssql.MSDate, [], {}, "DATETIME", MS_2005_VERSION),
        (types.TIME, [], {}, "TIME", None),
        (types.Time, [], {}, "TIME", None),
        (mssql.MSTime, [], {}, "TIME", None),
        (mssql.MSTime, [1], {}, "TIME(1)", None),
        (types.Time, [], {}, "DATETIME", MS_2005_VERSION),
        (mssql.MSTime, [], {}, "TIME", None),
        (mssql.MSSmallDateTime, [], {}, "SMALLDATETIME", None),
        (mssql.MSDateTimeOffset, [], {}, "DATETIMEOFFSET", None),
        (mssql.MSDateTimeOffset, [1], {}, "DATETIMEOFFSET(1)", None),
        (mssql.MSDateTime2, [], {}, "DATETIME2", None),
        (mssql.MSDateTime2, [0], {}, "DATETIME2(0)", None),
        (mssql.MSDateTime2, [1], {}, "DATETIME2(1)", None),
        (mssql.MSTime, [0], {}, "TIME(0)", None),
        (mssql.MSDateTimeOffset, [0], {}, "DATETIMEOFFSET(0)", None),
        (types.DateTime, [], {"timezone": True}, "DATETIMEOFFSET", None),
        (types.DateTime, [], {"timezone": False}, "DATETIME", None),
        argnames="type_, args, kw, res, server_version",
    )
    @testing.combinations((True,), (False,), argnames="use_type_descriptor")
    @testing.combinations(
        ("base",), ("pyodbc",), ("pymssql",), argnames="driver"
    )
    def test_dates(
        self, type_, args, kw, res, server_version, use_type_descriptor, driver
    ):
        """Exercise type specification for date types."""

        if driver == "base":
            from sqlalchemy.dialects.mssql import base

            dialect = base.MSDialect()
        elif driver == "pyodbc":
            from sqlalchemy.dialects.mssql import pyodbc

            dialect = pyodbc.dialect()
        elif driver == "pymssql":
            from sqlalchemy.dialects.mssql import pymssql

            dialect = pymssql.dialect()
        else:
            assert False

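        # MS_2005_VERSION and MS_2008_VERSION are the (9,) / (10,) version
        # tuples; DATE, TIME, DATETIME2 and DATETIMEOFFSET only exist on SQL
        # Server 2008 and later, so pre-2008 versions fall back to DATETIME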
        if server_version:
            dialect.server_version_info = server_version
        else:
            dialect.server_version_info = MS_2008_VERSION

        metadata = MetaData()

        typ = type_(*args, **kw)

        if use_type_descriptor:
            typ = dialect.type_descriptor(typ)

        col = Column("date_c", typ, nullable=None)

        date_table = Table("test_mssql_dates", metadata, col)
        gen = dialect.ddl_compiler(dialect, schema.CreateTable(date_table))

        testing.eq_(
            gen.get_column_specification(col),
            "%s %s"
            % (
                col.name,
                res,
            ),
        )

        self.assert_(repr(col))

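    # deprecate_large_types=True (implied automatically on SQL Server 2012+,
    # version tuple (11, 0)) renders the max-length types in place of the
    # legacy TEXT / NTEXT / IMAGE types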
    def test_large_type_deprecation(self):
        d1 = mssql.dialect(deprecate_large_types=True)
        d2 = mssql.dialect(deprecate_large_types=False)
        d3 = mssql.dialect()
        d3.server_version_info = (11, 0)
        d3._setup_version_attributes()
        d4 = mssql.dialect()
        d4.server_version_info = (10, 0)
        d4._setup_version_attributes()

        for dialect in (d1, d3):
            eq_(str(Text().compile(dialect=dialect)), "VARCHAR(max)")
            eq_(str(UnicodeText().compile(dialect=dialect)), "NVARCHAR(max)")
            eq_(str(LargeBinary().compile(dialect=dialect)), "VARBINARY(max)")

        for dialect in (d2, d4):
            eq_(str(Text().compile(dialect=dialect)), "TEXT")
            eq_(str(UnicodeText().compile(dialect=dialect)), "NTEXT")
            eq_(str(LargeBinary().compile(dialect=dialect)), "IMAGE")

    def test_money(self):
        """Exercise type specification for money types."""

        columns = [
            (mssql.MSMoney, [], {}, "MONEY"),
            (mssql.MSSmallMoney, [], {}, "SMALLMONEY"),
        ]
        metadata = MetaData()
        table_args = ["test_mssql_money", metadata]
        for index, spec in enumerate(columns):
            type_, args, kw, res = spec
            table_args.append(
                Column("c%s" % index, type_(*args, **kw), nullable=None)
            )
        money_table = Table(*table_args)
        dialect = mssql.dialect()
        gen = dialect.ddl_compiler(dialect, schema.CreateTable(money_table))
        for col in money_table.c:
            index = int(col.name[1:])
            testing.eq_(
                gen.get_column_specification(col),
                "%s %s" % (col.name, columns[index][3]),
            )
            self.assert_(repr(col))

    def test_binary(self):
        """Exercise type specification for binary types."""

        columns = [
            # column type, args, kwargs, expected ddl
            (mssql.MSBinary, [], {}, "BINARY"),
            (mssql.MSBinary, [10], {}, "BINARY(10)"),
            (types.BINARY, [], {}, "BINARY"),
            (types.BINARY, [10], {}, "BINARY(10)"),
            (mssql.MSVarBinary, [], {}, "VARBINARY(max)"),
            (mssql.MSVarBinary, [10], {}, "VARBINARY(10)"),
            (types.VARBINARY, [10], {}, "VARBINARY(10)"),
            (types.VARBINARY, [], {}, "VARBINARY(max)"),
            (mssql.MSImage, [], {}, "IMAGE"),
            (mssql.IMAGE, [], {}, "IMAGE"),
            (types.LargeBinary, [], {}, "IMAGE"),
        ]

        metadata = MetaData()
        table_args = ["test_mssql_binary", metadata]
        for index, spec in enumerate(columns):
            type_, args, kw, res = spec
            table_args.append(
                Column("c%s" % index, type_(*args, **kw), nullable=None)
            )
        binary_table = Table(*table_args)
        dialect = mssql.dialect()
        gen = dialect.ddl_compiler(dialect, schema.CreateTable(binary_table))
        for col in binary_table.c:
            index = int(col.name[1:])
            testing.eq_(
                gen.get_column_specification(col),
                "%s %s" % (col.name, columns[index][3]),
            )
            self.assert_(repr(col))


class TypeRoundTripTest(
    fixtures.TestBase, AssertsExecutionResults, ComparesTables
):
    __only_on__ = "mssql"

    __backend__ = True

    def test_decimal_notation(self, metadata, connection):
        numeric_table = Table(
            "numeric_table",
            metadata,
            Column(
                "id",
                Integer,
                Sequence("numeric_id_seq", optional=True),
                primary_key=True,
            ),
            Column(
                "numericcol", Numeric(precision=38, scale=20, asdecimal=True)
            ),
        )
        metadata.create_all(connection)
        test_items = [
            decimal.Decimal(d)
            for d in (
                "1500000.00000000000000000000",
                "-1500000.00000000000000000000",
                "1500000",
                "0.0000000000000000002",
                "0.2",
                "-0.0000000000000000002",
                "-2E-2",
                "156666.458923543",
                "-156666.458923543",
                "1",
                "-1",
                "-1234",
                "1234",
                "2E-12",
                "4E8",
                "3E-6",
                "3E-7",
                "4.1",
                "1E-1",
                "1E-2",
                "1E-3",
                "1E-4",
                "1E-5",
                "1E-6",
                "1E-7",
                "1E-8",
                "0.2732E2",
                "-0.2432E2",
                "4.35656E2",
                "-02452E-2",
                "45125E-2",
                "1234.58965E-2",
                "1.521E+15",
                # previously, these were at -1E-25, which were inserted
                # cleanly however we only got back 20 digits of accuracy.
                # pyodbc as of 4.0.22 now disallows the silent truncation.
                "-1E-20",
                "1E-20",
                "1254E-20",
                "-1203E-20",
                "0",
                "-0.00",
                "-0",
                "4585E12",
                "000000000000000000012",
                "000000000000.32E12",
                "00000000000000.1E+12",
                # these are no longer accepted by pyodbc 4.0.22 but it seems
                # they were not actually round-tripping correctly before that
                # in any case
                # '-1E-25',
                # '1E-25',
                # '1254E-25',
                # '-1203E-25',
                # '000000000000.2E-32',
            )
        ]

        for value in test_items:
            result = connection.execute(
                numeric_table.insert(), dict(numericcol=value)
            )
            primary_key = result.inserted_primary_key
            returned = connection.scalar(
                select(numeric_table.c.numericcol).where(
                    numeric_table.c.id == primary_key[0]
                )
            )
            eq_(value, returned)

    def test_float(self, metadata, connection):
        float_table = Table(
            "float_table",
            metadata,
            Column(
                "id",
                Integer,
                Sequence("numeric_id_seq", optional=True),
                primary_key=True,
            ),
            Column("floatcol", Float()),
        )

        metadata.create_all(connection)
        test_items = [
            float(d)
            for d in (
                "1500000.00000000000000000000",
                "-1500000.00000000000000000000",
                "1500000",
                "0.0000000000000000002",
                "0.2",
                "-0.0000000000000000002",
                "156666.458923543",
                "-156666.458923543",
                "1",
                "-1",
                "1234",
                "2E-12",
                "4E8",
                "3E-6",
                "3E-7",
                "4.1",
                "1E-1",
                "1E-2",
                "1E-3",
                "1E-4",
                "1E-5",
                "1E-6",
                "1E-7",
                "1E-8",
            )
        ]
        for value in test_items:
            result = connection.execute(
                float_table.insert(), dict(floatcol=value)
            )
            primary_key = result.inserted_primary_key
            returned = connection.scalar(
                select(float_table.c.floatcol).where(
                    float_table.c.id == primary_key[0]
                )
            )
            eq_(value, returned)

    @emits_warning_on("mssql+mxodbc", r".*does not have any indexes.*")
    def test_dates(self, metadata, connection):
        """Exercise type specification for date types."""

        columns = [
            # column type, args, kwargs, expected ddl
            (mssql.MSDateTime, [], {}, "DATETIME", []),
            (types.DATE, [], {}, "DATE", [">=", (10,)]),
            (types.Date, [], {}, "DATE", [">=", (10,)]),
            (types.Date, [], {}, "DATETIME", ["<", (10,)], mssql.MSDateTime),
            (mssql.MSDate, [], {}, "DATE", [">=", (10,)]),
            (mssql.MSDate, [], {}, "DATETIME", ["<", (10,)], mssql.MSDateTime),
            (types.TIME, [], {}, "TIME", [">=", (10,)]),
            (types.Time, [], {}, "TIME", [">=", (10,)]),
            (mssql.MSTime, [], {}, "TIME", [">=", (10,)]),
            (mssql.MSTime, [1], {}, "TIME(1)", [">=", (10,)]),
            (types.Time, [], {}, "DATETIME", ["<", (10,)], mssql.MSDateTime),
            (mssql.MSTime, [], {}, "TIME", [">=", (10,)]),
            (mssql.MSSmallDateTime, [], {}, "SMALLDATETIME", []),
            (mssql.MSDateTimeOffset, [], {}, "DATETIMEOFFSET", [">=", (10,)]),
            (
                mssql.MSDateTimeOffset,
                [1],
                {},
                "DATETIMEOFFSET(1)",
                [">=", (10,)],
            ),
            (mssql.MSDateTime2, [], {}, "DATETIME2", [">=", (10,)]),
            (mssql.MSDateTime2, [0], {}, "DATETIME2(0)", [">=", (10,)]),
            (mssql.MSDateTime2, [1], {}, "DATETIME2(1)", [">=", (10,)]),
        ]

        table_args = ["test_mssql_dates", metadata]
        for index, spec in enumerate(columns):
            type_, args, kw, res, requires = spec[0:5]
            if (
                requires
                and testing._is_excluded("mssql", *requires)
                or not requires
            ):
                c = Column("c%s" % index, type_(*args, **kw), nullable=None)
                connection.dialect.type_descriptor(c.type)
                table_args.append(c)
        dates_table = Table(*table_args)
        gen = connection.dialect.ddl_compiler(
            connection.dialect, schema.CreateTable(dates_table)
        )
        for col in dates_table.c:
            index = int(col.name[1:])
            testing.eq_(
                gen.get_column_specification(col),
                "%s %s" % (col.name, columns[index][3]),
            )
            self.assert_(repr(col))
        dates_table.create(connection)
        reflected_dates = Table(
            "test_mssql_dates", MetaData(), autoload_with=connection
        )
        for col in reflected_dates.c:
            self.assert_types_base(col, dates_table.c[col.key])

    @testing.metadata_fixture()
    def date_fixture(self, metadata):
        t = Table(
            "test_dates",
            metadata,
            Column("adate", Date),
            Column("atime1", Time),
            Column("atime2", Time),
            Column("adatetime", DateTime),
            Column("adatetimeoffset", DATETIMEOFFSET),
            Column("adatetimewithtimezone", DateTime(timezone=True)),
        )

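        # util.timezone is a fixed-offset tzinfo implementation; d3 carries
        # an explicit UTC-5 offset for the timezone-aware column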
        d1 = datetime.date(2007, 10, 30)
        t1 = datetime.time(11, 2, 32)
        d2 = datetime.datetime(2007, 10, 30, 11, 2, 32)
        d3 = datetime.datetime(
            2007,
            10,
            30,
            11,
            2,
            32,
            123456,
            util.timezone(datetime.timedelta(hours=-5)),
        )
        return t, (d1, t1, d2, d3)

    def test_date_roundtrips(self, date_fixture, connection):
        t, (d1, t1, d2, d3) = date_fixture
        connection.execute(
            t.insert(),
            dict(
                adate=d1,
                adatetime=d2,
                atime1=t1,
                atime2=d2,
                adatetimewithtimezone=d3,
            ),
        )

        row = connection.execute(t.select()).first()
        eq_(
            (
                row.adate,
                row.adatetime,
                row.atime1,
                row.atime2,
                row.adatetimewithtimezone,
            ),
            (d1, d2, t1, d2.time(), d3),
        )

    @testing.combinations(
        (
            datetime.datetime(
                2007,
                10,
                30,
                11,
                2,
                32,
                tzinfo=util.timezone(datetime.timedelta(hours=-5)),
            ),
        ),
        (datetime.datetime(2007, 10, 30, 11, 2, 32),),
        argnames="date",
    )
    def test_tz_present_or_non_in_dates(self, date_fixture, connection, date):
        t, (d1, t1, d2, d3) = date_fixture
        connection.execute(
            t.insert(),
            dict(
                adatetime=date,
                adatetimewithtimezone=date,
            ),
        )

        row = connection.execute(
            select(t.c.adatetime, t.c.adatetimewithtimezone)
        ).first()

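        # DATETIME discards tzinfo while DATETIMEOFFSET preserves it: a naive
        # value comes back from the tz-aware column tagged as UTC, and an
        # aware value comes back from the plain column stripped of its offset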
        if not date.tzinfo:
            eq_(row, (date, date.replace(tzinfo=util.timezone.utc)))
        else:
            eq_(row, (date.replace(tzinfo=None), date))

    @testing.metadata_fixture()
    def datetimeoffset_fixture(self, metadata):
        t = Table(
            "test_dates",
            metadata,
            Column("adatetimeoffset", DATETIMEOFFSET),
        )

        return t

    @testing.combinations(
        ("dto_param_none", lambda: None, None, False),
        (
            "dto_param_datetime_aware_positive",
            lambda: datetime.datetime(
                2007,
                10,
                30,
                11,
                2,
                32,
                123456,
                util.timezone(datetime.timedelta(hours=1)),
            ),
            1,
            False,
        ),
        (
            "dto_param_datetime_aware_negative",
            lambda: datetime.datetime(
                2007,
                10,
                30,
                11,
                2,
                32,
                123456,
                util.timezone(datetime.timedelta(hours=-5)),
            ),
            -5,
            False,
        ),
        (
            "dto_param_datetime_aware_seconds_frac_fail",
            lambda: datetime.datetime(
                2007,
                10,
                30,
                11,
                2,
                32,
                123456,
                util.timezone(datetime.timedelta(seconds=4000)),
            ),
            None,
            True,
            testing.requires.python37,
        ),
        (
            "dto_param_datetime_naive",
            lambda: datetime.datetime(2007, 10, 30, 11, 2, 32, 123456),
            0,
            False,
        ),
        (
            "dto_param_string_one",
            lambda: "2007-10-30 11:02:32.123456 +01:00",
            1,
            False,
        ),
        # "wow": SQL Server will parse even this loose month-name string
        # format, applying a zero offset
        (
            "dto_param_string_two",
            lambda: "October 30, 2007 11:02:32.123456",
            0,
            False,
        ),
        ("dto_param_string_invalid", lambda: "this is not a date", 0, True),
        id_="iaaa",
        argnames="dto_param_value, expected_offset_hours, should_fail",
    )
    def test_datetime_offset(
        self,
        datetimeoffset_fixture,
        dto_param_value,
        expected_offset_hours,
        should_fail,
        connection,
    ):
        t = datetimeoffset_fixture
        dto_param_value = dto_param_value()

        if should_fail:
            assert_raises(
                sa.exc.DBAPIError,
                connection.execute,
                t.insert(),
                dict(adatetimeoffset=dto_param_value),
            )
            return

        connection.execute(
            t.insert(),
            dict(adatetimeoffset=dto_param_value),
        )

        row = connection.execute(t.select()).first()

        if dto_param_value is None:
            is_(row.adatetimeoffset, None)
        else:
            eq_(
                row.adatetimeoffset,
                datetime.datetime(
                    2007,
                    10,
                    30,
                    11,
                    2,
                    32,
                    123456,
                    util.timezone(
                        datetime.timedelta(hours=expected_offset_hours)
                    ),
                ),
            )

    @emits_warning_on("mssql+mxodbc", r".*does not have any indexes.*")
    @testing.combinations(
        ("legacy_large_types", False),
        ("sql2012_large_types", True, lambda: testing.only_on("mssql >= 11")),
        id_="ia",
        argnames="deprecate_large_types",
    )
    def test_binary_reflection(self, metadata, deprecate_large_types):
        """Exercise type specification for binary types."""

        columns = [
            # column type, args, kwargs, expected ddl from reflected
            (mssql.MSBinary, [], {}, "BINARY(1)"),
            (mssql.MSBinary, [10], {}, "BINARY(10)"),
            (types.BINARY, [], {}, "BINARY(1)"),
            (types.BINARY, [10], {}, "BINARY(10)"),
            (mssql.MSVarBinary, [], {}, "VARBINARY(max)"),
            (mssql.MSVarBinary, [10], {}, "VARBINARY(10)"),
            (types.VARBINARY, [10], {}, "VARBINARY(10)"),
            (types.VARBINARY, [], {}, "VARBINARY(max)"),
            (mssql.MSImage, [], {}, "IMAGE"),
            (mssql.IMAGE, [], {}, "IMAGE"),
            (
                types.LargeBinary,
                [],
                {},
                "IMAGE" if not deprecate_large_types else "VARBINARY(max)",
            ),
        ]

        engine = engines.testing_engine(
            options={"deprecate_large_types": deprecate_large_types}
        )
        with engine.begin() as conn:
            table_args = ["test_mssql_binary", metadata]
            for index, spec in enumerate(columns):
                type_, args, kw, res = spec
                table_args.append(
                    Column("c%s" % index, type_(*args, **kw), nullable=None)
                )
            binary_table = Table(*table_args)
            metadata.create_all(conn)
            reflected_binary = Table(
                "test_mssql_binary", MetaData(), autoload_with=conn
            )
            for col, spec in zip(reflected_binary.c, columns):
                eq_(
                    col.type.compile(dialect=mssql.dialect()),
                    spec[3],
                    "column %s %s != %s"
                    % (
                        col.key,
                        col.type.compile(dialect=conn.dialect),
                        spec[3],
                    ),
                )
                c1 = conn.dialect.type_descriptor(col.type).__class__
                c2 = conn.dialect.type_descriptor(
                    binary_table.c[col.name].type
                ).__class__
                assert issubclass(
                    c1, c2
                ), "column %s: %r is not a subclass of %r" % (col.key, c1, c2)
                if binary_table.c[col.name].type.length:
                    testing.eq_(
                        col.type.length, binary_table.c[col.name].type.length
                    )

    def test_autoincrement(self, metadata, connection):
        Table(
            "ai_1",
            metadata,
            Column("int_y", Integer, primary_key=True, autoincrement=True),
            Column("int_n", Integer, DefaultClause("0"), primary_key=True),
        )
        Table(
            "ai_2",
            metadata,
            Column("int_y", Integer, primary_key=True, autoincrement=True),
            Column("int_n", Integer, DefaultClause("0"), primary_key=True),
        )
        Table(
            "ai_3",
            metadata,
            Column("int_n", Integer, DefaultClause("0"), primary_key=True),
            Column("int_y", Integer, primary_key=True, autoincrement=True),
        )

        Table(
            "ai_4",
            metadata,
            Column("int_n", Integer, DefaultClause("0"), primary_key=True),
            Column("int_n2", Integer, DefaultClause("0"), primary_key=True),
        )
        Table(
            "ai_5",
            metadata,
            Column("int_y", Integer, primary_key=True, autoincrement=True),
            Column("int_n", Integer, DefaultClause("0"), primary_key=True),
        )
        Table(
            "ai_6",
            metadata,
            Column("o1", String(1), DefaultClause("x"), primary_key=True),
            Column("int_y", Integer, primary_key=True, autoincrement=True),
        )
        Table(
            "ai_7",
            metadata,
            Column("o1", String(1), DefaultClause("x"), primary_key=True),
            Column("o2", String(1), DefaultClause("x"), primary_key=True),
            Column("int_y", Integer, autoincrement=True, primary_key=True),
        )
        Table(
            "ai_8",
            metadata,
            Column("o1", String(1), DefaultClause("x"), primary_key=True),
            Column("o2", String(1), DefaultClause("x"), primary_key=True),
        )
        metadata.create_all(connection)

        table_names = [
            "ai_1",
            "ai_2",
            "ai_3",
            "ai_4",
            "ai_5",
            "ai_6",
            "ai_7",
            "ai_8",
        ]
        mr = MetaData()

        for name in table_names:
            tbl = Table(name, mr, autoload_with=connection)
            tbl = metadata.tables[name]

            # test that the flag itself reflects appropriately
            for col in tbl.c:
                if "int_y" in col.name:
                    is_(col.autoincrement, True)
                    is_(tbl._autoincrement_column, col)
                else:
                    eq_(col.autoincrement, "auto")
                    is_not(tbl._autoincrement_column, col)

            # mxodbc can't handle scope_identity() with DEFAULT VALUES

            if testing.db.driver == "mxodbc":
                eng = [
                    engines.testing_engine(
                        options={"implicit_returning": True}
                    )
                ]
            else:
                eng = [
                    engines.testing_engine(
                        options={"implicit_returning": False}
                    ),
                    engines.testing_engine(
                        options={"implicit_returning": True}
                    ),
                ]

            for counter, engine in enumerate(eng):
                connection.execute(tbl.insert())
                if "int_y" in tbl.c:
                    eq_(
                        connection.execute(select(tbl.c.int_y)).scalar(),
                        counter + 1,
                    )
                    assert (
                        list(connection.execute(tbl.select()).first()).count(
                            counter + 1
                        )
                        == 1
                    )
                else:
                    assert 1 not in list(
                        connection.execute(tbl.select()).first()
                    )
                connection.execute(tbl.delete())


class StringTest(fixtures.TestBase, AssertsCompiledSQL):
    __dialect__ = mssql.dialect()

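    # MSSQL renders Unicode/UnicodeText literals with the N'' prefix; plain
    # String and Text literals render without it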
    def test_unicode_literal_binds(self):
        self.assert_compile(
            column("x", Unicode()) == "foo", "x = N'foo'", literal_binds=True
        )

    def test_unicode_text_literal_binds(self):
        self.assert_compile(
            column("x", UnicodeText()) == "foo",
            "x = N'foo'",
            literal_binds=True,
        )

    def test_string_text_literal_binds(self):
        self.assert_compile(
            column("x", String()) == "foo", "x = 'foo'", literal_binds=True
        )

    def test_string_text_literal_binds_explicit_unicode_right(self):
        self.assert_compile(
            column("x", String()) == util.u("foo"),
            "x = 'foo'",
            literal_binds=True,
        )

    def test_string_text_explicit_literal_binds(self):
        # the literal expression here coerces the right side to
        # Unicode on Python 3 for plain string, test with unicode
        # string just to confirm literal is doing this
        self.assert_compile(
            column("x", String()) == literal(util.u("foo")),
            "x = N'foo'",
            literal_binds=True,
        )

    def test_text_text_literal_binds(self):
        self.assert_compile(
            column("x", Text()) == "foo", "x = 'foo'", literal_binds=True
        )


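# MyPickleType tags the pickled payload on the way in ("BIND") and on the way
# out ("RESULT") so the round-trip test below can verify that both the bind
# and result processors actually ran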
class MyPickleType(types.TypeDecorator):
    impl = PickleType
    cache_ok = True

    def process_bind_param(self, value, dialect):
        if value:
            value.stuff = "BIND" + value.stuff
        return value

    def process_result_value(self, value, dialect):
        if value:
            value.stuff = value.stuff + "RESULT"
        return value


class BinaryTest(fixtures.TestBase):
    __only_on__ = "mssql"
    __requires__ = ("non_broken_binary",)
    __backend__ = True

    @testing.combinations(
        (
            mssql.MSVarBinary(800),
            b("some normal data"),
            None,
            True,
            None,
            False,
        ),
        (
            mssql.VARBINARY("max"),
            "binary_data_one.dat",
            None,
            False,
            None,
            False,
        ),
        (
            mssql.VARBINARY("max"),
            "binary_data_one.dat",
            None,
            True,
            None,
            False,
        ),
        (
            sqltypes.LargeBinary,
            "binary_data_one.dat",
            None,
            False,
            None,
            False,
        ),
        (sqltypes.LargeBinary, "binary_data_one.dat", None, True, None, False),
        (mssql.MSImage, "binary_data_one.dat", None, True, None, False),
        (PickleType, pickleable.Foo("im foo 1"), None, True, None, False),
        (
            MyPickleType,
            pickleable.Foo("im foo 1"),
            pickleable.Foo("im foo 1", stuff="BINDim stuffRESULT"),
            True,
            None,
            False,
        ),
        (types.BINARY(100), "binary_data_one.dat", None, True, 100, False),
        (types.VARBINARY(100), "binary_data_one.dat", None, True, 100, False),
        (mssql.VARBINARY(100), "binary_data_one.dat", None, True, 100, False),
        (types.BINARY(100), "binary_data_two.dat", None, True, 99, True),
        (types.VARBINARY(100), "binary_data_two.dat", None, True, 99, False),
        (mssql.VARBINARY(100), "binary_data_two.dat", None, True, 99, False),
        argnames="type_, data, expected, deprecate_large_types, "
        "slice_, zeropad",
    )
    def test_round_trip(
        self,
        metadata,
        type_,
        data,
        expected,
        deprecate_large_types,
        slice_,
        zeropad,
    ):
        if (
            testing.db.dialect.deprecate_large_types
            is not deprecate_large_types
        ):
            engine = engines.testing_engine(
                options={"deprecate_large_types": deprecate_large_types}
            )
        else:
            engine = testing.db

        binary_table = Table(
            "binary_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", type_),
        )
        binary_table.create(engine)

        if isinstance(data, str) and (
            data == "binary_data_one.dat" or data == "binary_data_two.dat"
        ):
            data = self._load_stream(data)

        if slice_ is not None:
            data = data[0:slice_]

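        # fixed-length BINARY pads short values with zero bytes on the
        # server, so a 99-byte insert into BINARY(100) comes back with a
        # trailing \x00; the zeropad flag models that expectation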
        if expected is None:
            if zeropad:
                expected = data[0:slice_] + b"\x00"
            else:
                expected = data

        with engine.begin() as conn:
            conn.execute(binary_table.insert(), dict(data=data))

            eq_(conn.scalar(select(binary_table.c.data)), expected)

            eq_(
                conn.scalar(
                    text("select data from binary_table").columns(
                        binary_table.c.data
                    )
                ),
                expected,
            )

            conn.execute(binary_table.delete())

            conn.execute(binary_table.insert(), dict(data=None))
            eq_(conn.scalar(select(binary_table.c.data)), None)

            eq_(
                conn.scalar(
                    text("select data from binary_table").columns(
                        binary_table.c.data
                    )
                ),
                None,
            )

    def _load_stream(self, name, len_=3000):
        with open(
            os.path.join(os.path.dirname(__file__), "..", "..", name), "rb"
        ) as fp:
            return fp.read(len_)


class BooleanTest(fixtures.TestBase, AssertsCompiledSQL):
    __only_on__ = "mssql"

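    # SQL Server has no native BOOLEAN type; both the generic Boolean and the
    # dialect's BIT render as BIT, and as_generic() maps BIT back to Boolean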
    @testing.provide_metadata
    @testing.combinations(
        ("as_boolean_null", Boolean, True, "CREATE TABLE tbl (boo BIT NULL)"),
        ("as_bit_null", BIT, True, "CREATE TABLE tbl (boo BIT NULL)"),
        (
            "as_boolean_not_null",
            Boolean,
            False,
            "CREATE TABLE tbl (boo BIT NOT NULL)",
        ),
        ("as_bit_not_null", BIT, False, "CREATE TABLE tbl (boo BIT NOT NULL)"),
        id_="iaaa",
        argnames="col_type, is_nullable, ddl",
    )
    def test_boolean_as_bit(self, col_type, is_nullable, ddl):
        tbl = Table(
            "tbl", self.metadata, Column("boo", col_type, nullable=is_nullable)
        )
        self.assert_compile(
            schema.CreateTable(tbl),
            ddl,
        )
        assert isinstance(tbl.c.boo.type.as_generic(), Boolean)