1# -*- encoding: utf-8
2
3from decimal import Decimal
4
5from sqlalchemy import Column
6from sqlalchemy import event
7from sqlalchemy import exc
8from sqlalchemy import Integer
9from sqlalchemy import Numeric
10from sqlalchemy import select
11from sqlalchemy import String
12from sqlalchemy import Table
13from sqlalchemy import testing
14from sqlalchemy.dialects.mssql import base
15from sqlalchemy.dialects.mssql import pymssql
16from sqlalchemy.dialects.mssql import pyodbc
17from sqlalchemy.engine import url
18from sqlalchemy.exc import DBAPIError
19from sqlalchemy.exc import IntegrityError
20from sqlalchemy.testing import assert_raises
21from sqlalchemy.testing import assert_raises_message
22from sqlalchemy.testing import assert_warnings
23from sqlalchemy.testing import engines
24from sqlalchemy.testing import eq_
25from sqlalchemy.testing import expect_raises
26from sqlalchemy.testing import expect_warnings
27from sqlalchemy.testing import fixtures
28from sqlalchemy.testing import mock
29from sqlalchemy.testing.mock import Mock
30
31
class ParseConnectTest(fixtures.TestBase):
    """Offline tests for MSSQL connect-argument parsing and disconnect
    detection.

    Nearly every test builds a pyodbc / pymssql dialect directly and checks
    the output of ``create_connect_args()`` or ``is_disconnect()`` against a
    literal; only ``test_bad_freetds_warning`` needs a real engine.
    """

    def test_pyodbc_connect_dsn_trusted(self):
        dialect = pyodbc.dialect()
        u = url.make_url("mssql://mydsn")
        connection = dialect.create_connect_args(u)
        eq_([["dsn=mydsn;Trusted_Connection=Yes"], {}], connection)

    def test_pyodbc_connect_old_style_dsn_trusted(self):
        dialect = pyodbc.dialect()
        u = url.make_url("mssql:///?dsn=mydsn")
        connection = dialect.create_connect_args(u)
        eq_([["dsn=mydsn;Trusted_Connection=Yes"], {}], connection)

    def test_pyodbc_connect_dsn_non_trusted(self):
        dialect = pyodbc.dialect()
        u = url.make_url("mssql://username:password@mydsn")
        connection = dialect.create_connect_args(u)
        eq_([["dsn=mydsn;UID=username;PWD=password"], {}], connection)

    def test_pyodbc_connect_dsn_extra(self):
        dialect = pyodbc.dialect()
        u = url.make_url(
            "mssql://username:password@mydsn/?LANGUAGE=us_english&foo=bar"
        )
        connection = dialect.create_connect_args(u)
        dsn_string = connection[0][0]
        # extra query-string options are appended; their relative order is
        # not fixed, so only check for presence
        assert ";LANGUAGE=us_english" in dsn_string
        assert ";foo=bar" in dsn_string

    def test_pyodbc_hostname(self):
        dialect = pyodbc.dialect()
        u = url.make_url(
            "mssql://username:password@hostspec/database?driver=SQL+Server"
        )
        connection = dialect.create_connect_args(u)
        eq_(
            [
                [
                    "DRIVER={SQL Server};Server=hostspec;Database=database;"
                    "UID=username;PWD=password"
                ],
                {},
            ],
            connection,
        )

    def test_pyodbc_empty_url_no_warning(self):
        import warnings

        dialect = pyodbc.dialect()
        u = url.make_url("mssql+pyodbc://")

        # check explicitly that no SAWarning is emitted, instead of relying
        # on the suite-wide warnings-as-errors filter to make this test fail
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            dialect.create_connect_args(u)

        eq_(
            [
                str(w.message)
                for w in caught
                if issubclass(w.category, exc.SAWarning)
            ],
            [],
        )

    def test_pyodbc_host_no_driver(self):
        dialect = pyodbc.dialect()
        u = url.make_url("mssql://username:password@hostspec/database")

        def go():
            return dialect.create_connect_args(u)

        connection = assert_warnings(
            go,
            [
                "No driver name specified; this is expected by "
                "PyODBC when using DSN-less connections"
            ],
        )

        eq_(
            [
                [
                    "Server=hostspec;Database=database;"
                    "UID=username;PWD=password"
                ],
                {},
            ],
            connection,
        )

    def test_pyodbc_connect_comma_port(self):
        dialect = pyodbc.dialect()
        u = url.make_url(
            "mssql://username:password@hostspec:12345/database"
            "?driver=SQL Server"
        )
        connection = dialect.create_connect_args(u)
        # a URL port is attached to the server name with a comma
        eq_(
            [
                [
                    "DRIVER={SQL Server};Server=hostspec,12345;"
                    "Database=database;UID=username;PWD=password"
                ],
                {},
            ],
            connection,
        )

    def test_pyodbc_connect_config_port(self):
        dialect = pyodbc.dialect()
        u = url.make_url(
            "mssql://username:password@hostspec/database"
            "?port=12345&driver=SQL+Server"
        )
        connection = dialect.create_connect_args(u)
        # a "port" query option is passed through as an ordinary extra key
        eq_(
            [
                [
                    "DRIVER={SQL Server};Server=hostspec;Database=database;"
                    "UID=username;PWD=password;port=12345"
                ],
                {},
            ],
            connection,
        )

    def test_pyodbc_extra_connect(self):
        dialect = pyodbc.dialect()
        u = url.make_url(
            "mssql://username:password@hostspec/database"
            "?LANGUAGE=us_english&foo=bar&driver=SQL+Server"
        )
        connection = dialect.create_connect_args(u)
        eq_(connection[1], {})

        # the two extra options may be emitted in either order
        prefix = (
            "DRIVER={SQL Server};Server=hostspec;Database=database;"
            "UID=username;PWD=password;"
        )
        dsn_string = connection[0][0]
        assert dsn_string in (
            prefix + "foo=bar;LANGUAGE=us_english",
            prefix + "LANGUAGE=us_english;foo=bar",
        ), dsn_string

    def test_pyodbc_extra_connect_azure(self):
        # issue #5592
        dialect = pyodbc.dialect()
        u = url.make_url(
            "mssql+pyodbc://@server_name/db_name?"
            "driver=ODBC+Driver+17+for+SQL+Server&"
            "authentication=ActiveDirectoryIntegrated"
        )
        connection = dialect.create_connect_args(u)
        eq_(connection[1], {})
        # the lowercase "authentication" query key is emitted as
        # "Authentication" in the ODBC connection string
        eq_(
            connection[0][0],
            "DRIVER={ODBC Driver 17 for SQL Server};"
            "Server=server_name;Database=db_name;"
            "Authentication=ActiveDirectoryIntegrated",
        )

    def test_pyodbc_odbc_connect(self):
        dialect = pyodbc.dialect()
        u = url.make_url(
            "mssql:///?odbc_connect=DRIVER%3D%7BSQL+Server"
            "%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase"
            "%3BUID%3Dusername%3BPWD%3Dpassword"
        )
        connection = dialect.create_connect_args(u)
        # odbc_connect is URL-decoded and used verbatim
        eq_(
            [
                [
                    "DRIVER={SQL Server};Server=hostspec;Database=database;"
                    "UID=username;PWD=password"
                ],
                {},
            ],
            connection,
        )

    def test_pyodbc_odbc_connect_with_dsn(self):
        dialect = pyodbc.dialect()
        u = url.make_url(
            "mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase"
            "%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword"
        )
        connection = dialect.create_connect_args(u)
        eq_(
            [["dsn=mydsn;Database=database;UID=username;PWD=password"], {}],
            connection,
        )

    def test_pyodbc_odbc_connect_ignores_other_values(self):
        dialect = pyodbc.dialect()
        # user / password / host / database in the URL itself must be
        # ignored in favor of the odbc_connect string
        u = url.make_url(
            "mssql://userdiff:passdiff@localhost/dbdiff"
            "?odbc_connect=DRIVER%3D%7BSQL+Server%7D"
            "%3BServer%3Dhostspec%3BDatabase%3Ddatabase"
            "%3BUID%3Dusername%3BPWD%3Dpassword"
        )
        connection = dialect.create_connect_args(u)
        eq_(
            [
                [
                    "DRIVER={SQL Server};Server=hostspec;Database=database;"
                    "UID=username;PWD=password"
                ],
                {},
            ],
            connection,
        )

    def test_pyodbc_token_injection(self):
        # every URL component tries to smuggle in ";PORT=50001"
        token1 = "someuser%3BPORT%3D50001"
        token2 = "some{strange}pw%3BPORT%3D50001"
        token3 = "somehost%3BPORT%3D50001"
        token4 = "somedb%3BPORT%3D50001"

        u = url.make_url(
            "mssql+pyodbc://%s:%s@%s/%s?driver=foob"
            % (token1, token2, token3, token4)
        )
        dialect = pyodbc.dialect()
        connection = dialect.create_connect_args(u)
        # UID / PWD are brace-quoted (with "}" doubled), while server and
        # database remain percent-encoded, so no extra key can be injected
        eq_(
            [
                [
                    "DRIVER={foob};Server=somehost%3BPORT%3D50001;"
                    "Database=somedb%3BPORT%3D50001;UID={someuser;PORT=50001};"
                    "PWD={some{strange}}pw;PORT=50001}"
                ],
                {},
            ],
            connection,
        )

    def test_pymssql_port_setting(self):
        dialect = pymssql.dialect()

        u = url.make_url("mssql+pymssql://scott:tiger@somehost/test")
        connection = dialect.create_connect_args(u)
        eq_(
            [
                [],
                {
                    "host": "somehost",
                    "password": "tiger",
                    "user": "scott",
                    "database": "test",
                },
            ],
            connection,
        )

        # pymssql receives the port folded into "host" as host:port
        u = url.make_url("mssql+pymssql://scott:tiger@somehost:5000/test")
        connection = dialect.create_connect_args(u)
        eq_(
            [
                [],
                {
                    "host": "somehost:5000",
                    "password": "tiger",
                    "user": "scott",
                    "database": "test",
                },
            ],
            connection,
        )

    def test_pymssql_disconnect(self):
        dialect = pymssql.dialect()

        for error in [
            "Adaptive Server connection timed out",
            "Net-Lib error during Connection reset by peer",
            "message 20003",
            "Error 10054",
            "Not connected to any MS SQL server",
            "Connection is closed",
            "message 20006",  # Write to the server failed
            "message 20017",  # Unexpected EOF from the server
            "message 20047",  # DBPROCESS is dead or not enabled
        ]:
            eq_(dialect.is_disconnect(error, None, None), True)

        eq_(dialect.is_disconnect("not an error", None, None), False)

    def test_pyodbc_disconnect(self):
        dialect = pyodbc.dialect()

        class MockDBAPIError(Exception):
            pass

        class MockProgrammingError(MockDBAPIError):
            pass

        dialect.dbapi = Mock(
            Error=MockDBAPIError, ProgrammingError=MockProgrammingError
        )

        # errors carrying a disconnect code as their first argument, plus
        # ProgrammingErrors with known "closed connection" messages
        for error in [
            MockDBAPIError(code, "[%s] some pyodbc message" % code)
            for code in [
                "08S01",
                "01002",
                "08003",
                "08007",
                "08S02",
                "08001",
                "HYT00",
                "HY010",
            ]
        ] + [
            MockProgrammingError(message)
            for message in [
                "(some pyodbc stuff) The cursor's connection has been closed.",
                "(some pyodbc stuff) Attempt to use a closed connection.",
            ]
        ]:
            eq_(dialect.is_disconnect(error, None, None), True)

        # a code appearing only as a substring of the message text must not
        # be treated as a disconnect
        eq_(
            dialect.is_disconnect(
                MockProgrammingError("Query with abc08007def failed"),
                None,
                None,
            ),
            False,
        )

    @testing.requires.mssql_freetds
    def test_bad_freetds_warning(self):
        engine = engines.testing_engine()

        def _bad_version(connection):
            return 95, 10, 255

        engine.dialect._get_server_version_info = _bad_version
        # relies on the suite raising SAWarning as an exception
        assert_raises_message(
            exc.SAWarning, "Unrecognized server version info", engine.connect
        )
366
367
class FastExecutemanyTest(fixtures.TestBase):
    """Backend tests for pyodbc's ``fast_executemany`` engine option and
    its interaction with ``use_setinputsizes`` / ``do_setinputsizes``."""

    __only_on__ = "mssql"
    __backend__ = True
    __requires__ = ("pyodbc_fast_executemany",)

    def test_flag_on(self, metadata):
        t = Table(
            "t",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
        )
        t.create(testing.db)

        eng = engines.testing_engine(options={"fast_executemany": True})

        # executemany flag of every cursor execution on ``eng``; checked at
        # the end so the test cannot pass without any executemany happening
        seen_executemany = []

        @event.listens_for(eng, "after_cursor_execute")
        def after_cursor_execute(
            conn, cursor, statement, parameters, context, executemany
        ):
            seen_executemany.append(executemany)
            if executemany:
                assert cursor.fast_executemany

        with eng.begin() as conn:
            # multi-row INSERT: expected to run as an executemany
            conn.execute(
                t.insert(),
                [{"id": i, "data": "data_%d" % i} for i in range(100)],
            )

            # single-row INSERT must also work with the flag enabled
            conn.execute(t.insert(), {"id": 200, "data": "data_200"})

        assert any(seen_executemany), seen_executemany

    @testing.fixture
    def fe_engine(self, testing_engine):
        """Factory for engines with the given ``fast_executemany`` and
        ``use_setinputsizes`` options."""

        def go(use_fastexecutemany, apply_setinputsizes_flag):
            engine = testing_engine(
                options={
                    "fast_executemany": use_fastexecutemany,
                    "use_setinputsizes": apply_setinputsizes_flag,
                }
            )
            return engine

        return go

    @testing.combinations(
        ("setinputsizeshook", True),
        ("nosetinputsizeshook", False),
        argnames="include_setinputsizes",
        id_="ia",
    )
    @testing.combinations(
        ("setinputsizesflag", True),
        ("nosetinputsizesflag", False),
        argnames="apply_setinputsizes_flag",
        id_="ia",
    )
    @testing.combinations(
        ("fastexecutemany", True),
        ("nofastexecutemany", False),
        argnames="use_fastexecutemany",
        id_="ia",
    )
    def test_insert_floats(
        self,
        metadata,
        fe_engine,
        include_setinputsizes,
        use_fastexecutemany,
        apply_setinputsizes_flag,
    ):
        """Insert high-precision Decimals into Numeric(19, 15) columns under
        every combination of fast_executemany, use_setinputsizes and a
        do_setinputsizes hook, then read them back."""

        # with setinputsizes on and fast_executemany on, but no hook to
        # rewrite the Numeric input sizes, the INSERT is expected to fail
        expect_failure = (
            apply_setinputsizes_flag
            and not include_setinputsizes
            and use_fastexecutemany
        )

        engine = fe_engine(use_fastexecutemany, apply_setinputsizes_flag)

        observations = Table(
            "Observations",
            metadata,
            Column("id", Integer, nullable=False, primary_key=True),
            Column("obs1", Numeric(19, 15), nullable=True),
            Column("obs2", Numeric(19, 15), nullable=True),
            schema="test_schema",
        )
        with engine.begin() as conn:
            metadata.create_all(conn)

        records = [
            {
                "id": 1,
                "obs1": Decimal("60.1722066045792"),
                "obs2": Decimal("24.929289808227466"),
            },
            {
                "id": 2,
                "obs1": Decimal("60.16325715615476"),
                "obs2": Decimal("24.93886459535008"),
            },
            {
                "id": 3,
                "obs1": Decimal("60.16445165123469"),
                "obs2": Decimal("24.949856300109516"),
            },
        ]

        if include_setinputsizes:
            canary = mock.Mock()

            @event.listens_for(engine, "do_setinputsizes")
            def do_setinputsizes(
                inputsizes, cursor, statement, parameters, context
            ):
                # record the DBAPI types chosen by SQLAlchemy, then force
                # Numeric parameters to SQL_DECIMAL(19, 15) to match the
                # column definitions above
                canary(list(inputsizes.values()))

                for key in inputsizes:
                    if isinstance(key.type, Numeric):
                        inputsizes[key] = (
                            engine.dialect.dbapi.SQL_DECIMAL,
                            19,
                            15,
                        )

        with engine.begin() as conn:

            if expect_failure:
                with expect_raises(DBAPIError):
                    conn.execute(observations.insert(), records)
            else:
                conn.execute(observations.insert(), records)

                eq_(
                    conn.execute(
                        select(observations).order_by(observations.c.id)
                    )
                    .mappings()
                    .all(),
                    records,
                )

        if include_setinputsizes:
            if apply_setinputsizes_flag:
                eq_(
                    canary.mock_calls,
                    [
                        # float for int?  this seems wrong
                        mock.call([float, float, float]),
                        # presumably the parameterless SELECT
                        mock.call([]),
                    ],
                )
            else:
                # without the flag the hook must never be invoked
                eq_(canary.mock_calls, [])
539
540
class VersionDetectionTest(fixtures.TestBase):
    """Server-version parsing for the pymssql and pyodbc dialects, driven
    entirely by mock connections."""

    @testing.fixture
    def mock_conn_scalar(self):
        """Return a factory: given ``text``, build a mock connection whose
        ``exec_driver_sql(...).scalar()`` returns that text."""

        def make_conn(text):
            scalar_result = Mock(scalar=Mock(return_value=text))
            return Mock(exec_driver_sql=Mock(return_value=scalar_result))

        return make_conn

    def test_pymssql_version(self, mock_conn_scalar):
        dialect = pymssql.MSDialect_pymssql()

        banners = (
            "Microsoft SQL Server Blah - 11.0.9216.62",
            "Microsoft SQL Server (XYZ) - 11.0.9216.62 \n"
            "Jul 18 2014 22:00:21 \nCopyright (c) Microsoft Corporation",
            "Microsoft SQL Azure (RTM) - 11.0.9216.62 \n"
            "Jul 18 2014 22:00:21 \nCopyright (c) Microsoft Corporation",
        )
        for banner in banners:
            mocked = mock_conn_scalar(banner)
            eq_(dialect._get_server_version_info(mocked), (11, 0, 9216, 62))

    def test_pyodbc_version_productversion(self, mock_conn_scalar):
        dialect = pyodbc.MSDialect_pyodbc()

        mocked = mock_conn_scalar("11.0.9216.62")
        eq_(dialect._get_server_version_info(mocked), (11, 0, 9216, 62))

    def test_pyodbc_version_fallback(self):
        dialect = pyodbc.MSDialect_pyodbc()
        dialect.dbapi = Mock()

        def failing_conn(version_string):
            # the SQL-based lookup raises DBAPIError, so the version has to
            # come from the DBAPI connection's getinfo() value instead
            broken_result = Mock(
                scalar=Mock(
                    side_effect=exc.DBAPIError("stmt", "params", None)
                )
            )
            return Mock(
                exec_driver_sql=Mock(return_value=broken_result),
                connection=Mock(getinfo=Mock(return_value=version_string)),
            )

        cases = [
            ("11.0.9216.62", (11, 0, 9216, 62)),
            ("notsqlserver.11.foo.0.9216.BAR.62", (11, 0, 9216, 62)),
            ("Not SQL Server Version 10.5", (5,)),
        ]
        for version_string, expected in cases:
            eq_(
                dialect._get_server_version_info(failing_conn(version_string)),
                expected,
            )
590
591
class RealIsolationLevelTest(fixtures.TestBase):
    """Round-trip each isolation level against a live SQL Server."""

    __only_on__ = "mssql"
    __backend__ = True

    def test_isolation_level(self, metadata):
        """Each level set via execution_options is reported back by the
        dialect, and a fresh connection afterwards reports the original
        default level again."""
        dialect = testing.db.dialect

        Table("test", metadata, Column("id", Integer)).create(
            testing.db, checkfirst=True
        )

        with testing.db.connect() as conn:
            original_level = dialect.get_isolation_level(conn.connection)

        levels = (
            "READ UNCOMMITTED",
            "READ COMMITTED",
            "REPEATABLE READ",
            "SERIALIZABLE",
            "SNAPSHOT",
        )
        for level in levels:
            with testing.db.connect() as conn:
                conn.execution_options(isolation_level=level)

                conn.exec_driver_sql("SELECT TOP 10 * FROM test")

                eq_(dialect.get_isolation_level(conn.connection), level)

        with testing.db.connect() as conn:
            eq_(dialect.get_isolation_level(conn.connection), original_level)
623
624
class IsolationLevelDetectTest(fixtures.TestBase):
    """get_isolation_level() against mock DBAPI connections on which only
    some statements succeed."""

    def _fixture(self, view):
        """Build a pyodbc dialect plus a mock DBAPI connection.

        The mock cursor's ``execute()`` succeeds only for statements that
        contain ``view``; a successful execute makes ``fetchone()`` return
        ``("SERIALIZABLE",)``. Every other statement raises the mocked DBAPI
        ``Error``, and ``view=None`` makes every statement fail.
        """

        class Error(Exception):
            pass

        dialect = pyodbc.MSDialect_pyodbc()
        dialect.dbapi = Mock(Error=Error)
        dialect.server_version_info = base.MS_2012_VERSION

        rows = []

        def execute(stmt):
            if view is None or view not in stmt:
                raise Error("that didn't work")
            rows.append(("SERIALIZABLE",))

        cursor = Mock(execute=execute, fetchone=lambda: rows[0])
        connection = Mock(cursor=Mock(return_value=cursor))

        return dialect, connection

    def test_dm_pdw_nodes(self):
        dialect, dbapi_conn = self._fixture("dm_pdw_nodes_exec_sessions")
        eq_(dialect.get_isolation_level(dbapi_conn), "SERIALIZABLE")

    def test_exec_sessions(self):
        dialect, dbapi_conn = self._fixture("exec_sessions")
        eq_(dialect.get_isolation_level(dbapi_conn), "SERIALIZABLE")

    def test_not_supported(self):
        # no statement succeeds: a warning is emitted and the lookup raises
        dialect, dbapi_conn = self._fixture(None)

        with expect_warnings("Could not fetch transaction isolation level"):
            assert_raises_message(
                NotImplementedError,
                "Can't fetch isolation",
                dialect.get_isolation_level,
                dbapi_conn,
            )
674
675
class InvalidTransactionFalsePositiveTest(fixtures.TablesTest):
    """Regression test for issue #5359: an IntegrityError on a row whose
    value is "01002" (also one of the pyodbc disconnect codes) must leave
    the connection usable."""

    __only_on__ = "mssql"
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "error_t",
            metadata,
            Column("error_code", String(50), primary_key=True),
        )

    @classmethod
    def insert_data(cls, connection):
        error_t = cls.tables.error_t
        connection.execute(error_t.insert(), [{"error_code": "01002"}])

    def test_invalid_transaction_detection(self, connection):
        # issue #5359
        error_t = self.tables.error_t

        # re-inserting the existing key forces a duplicate PK error
        with expect_raises(IntegrityError):
            connection.execute(error_t.insert(), {"error_code": "01002"})

        # the connection must still work; this should not fail with
        # "Can't reconnect until invalid transaction is rolled back."
        rows = connection.execute(error_t.select()).fetchall()
        eq_(len(rows), 1)
711