# -*- encoding: utf-8
"""Tests for the SQL Server (mssql) dialect's pyodbc and pymssql drivers.

Covers URL -> DBAPI connect-argument translation, disconnect detection,
pyodbc ``fast_executemany`` / ``setinputsizes`` handling, server version
detection, and transaction isolation level handling.
"""

from decimal import Decimal

from sqlalchemy import Column
from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import Integer
from sqlalchemy import Numeric
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy.dialects.mssql import base
from sqlalchemy.dialects.mssql import pymssql
from sqlalchemy.dialects.mssql import pyodbc
from sqlalchemy.engine import url
from sqlalchemy.exc import DBAPIError
from sqlalchemy.exc import IntegrityError
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import assert_warnings
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_raises
from sqlalchemy.testing import expect_warnings
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import mock
from sqlalchemy.testing.mock import Mock


class ParseConnectTest(fixtures.TestBase):
    """URL parsing and disconnect classification for the mssql drivers.

    Each test builds a dialect directly (no database needed) and checks
    either the ``[args, kwargs]`` pair returned by
    ``create_connect_args()`` or the boolean from ``is_disconnect()``.
    """

    def test_pyodbc_connect_dsn_trusted(self):
        """A bare host with no credentials is used as a DSN with
        Trusted_Connection=Yes."""
        dialect = pyodbc.dialect()
        u = url.make_url("mssql://mydsn")
        connection = dialect.create_connect_args(u)
        eq_([["dsn=mydsn;Trusted_Connection=Yes"], {}], connection)

    def test_pyodbc_connect_old_style_dsn_trusted(self):
        """The legacy ``?dsn=`` query form yields the same trusted DSN
        string as the host form."""
        dialect = pyodbc.dialect()
        u = url.make_url("mssql:///?dsn=mydsn")
        connection = dialect.create_connect_args(u)
        eq_([["dsn=mydsn;Trusted_Connection=Yes"], {}], connection)

    def test_pyodbc_connect_dsn_non_trusted(self):
        """With a username/password, the DSN string carries UID/PWD instead
        of Trusted_Connection."""
        dialect = pyodbc.dialect()
        u = url.make_url("mssql://username:password@mydsn")
        connection = dialect.create_connect_args(u)
        eq_([["dsn=mydsn;UID=username;PWD=password"], {}], connection)

    def test_pyodbc_connect_dsn_extra(self):
        """Unrecognized query parameters are appended to the DSN string."""
        dialect = pyodbc.dialect()
        u = url.make_url(
            "mssql://username:password@mydsn/?LANGUAGE=us_" "english&foo=bar"
        )
        connection = dialect.create_connect_args(u)
        dsn_string = connection[0][0]
        # only presence is checked; the order of the extras is not asserted
        assert ";LANGUAGE=us_english" in dsn_string
        assert ";foo=bar" in dsn_string

    def test_pyodbc_hostname(self):
        """A hostname plus ``driver=`` produces a DSN-less connection
        string with the driver name wrapped in braces."""
        dialect = pyodbc.dialect()
        u = url.make_url(
            "mssql://username:password@hostspec/database?driver=SQL+Server"
        )
        connection = dialect.create_connect_args(u)
        eq_(
            [
                [
                    "DRIVER={SQL Server};Server=hostspec;Database=database;UI"
                    "D=username;PWD=password"
                ],
                {},
            ],
            connection,
        )

    def test_pyodbc_empty_url_no_warning(self):
        """An empty pyodbc URL must not trigger the "no driver" warning."""
        dialect = pyodbc.dialect()
        u = url.make_url("mssql+pyodbc://")

        # no warning is emitted
        dialect.create_connect_args(u)

    def test_pyodbc_host_no_driver(self):
        """A DSN-less URL without ``driver=`` warns, and the connection
        string omits the DRIVER attribute."""
        dialect = pyodbc.dialect()
        u = url.make_url("mssql://username:password@hostspec/database")

        def go():
            return dialect.create_connect_args(u)

        # assert_warnings returns go()'s result once the expected warning
        # text has been observed
        connection = assert_warnings(
            go,
            [
                "No driver name specified; this is expected by "
                "PyODBC when using DSN-less connections"
            ],
        )

        eq_(
            [
                [
                    "Server=hostspec;Database=database;UI"
                    "D=username;PWD=password"
                ],
                {},
            ],
            connection,
        )

    def test_pyodbc_connect_comma_port(self):
        """A URL port is rendered as ``Server=host,port``."""
        dialect = pyodbc.dialect()
        u = url.make_url(
            "mssql://username:password@hostspec:12345/data"
            "base?driver=SQL Server"
        )
        connection = dialect.create_connect_args(u)
        eq_(
            [
                [
                    "DRIVER={SQL Server};Server=hostspec,12345;Database=datab"
                    "ase;UID=username;PWD=password"
                ],
                {},
            ],
            connection,
        )

    def test_pyodbc_connect_config_port(self):
        """A ``port=`` query parameter is passed through as a separate
        ``port=`` attribute rather than merged into Server."""
        dialect = pyodbc.dialect()
        u = url.make_url(
            "mssql://username:password@hostspec/database?p"
            "ort=12345&driver=SQL+Server"
        )
        connection = dialect.create_connect_args(u)
        eq_(
            [
                [
                    "DRIVER={SQL Server};Server=hostspec;Database=database;UI"
                    "D=username;PWD=password;port=12345"
                ],
                {},
            ],
            connection,
        )

    def test_pyodbc_extra_connect(self):
        """Extra query parameters are appended after the credentials; no
        keyword arguments are produced."""
        dialect = pyodbc.dialect()
        u = url.make_url(
            "mssql://username:password@hostspec/database?L"
            "ANGUAGE=us_english&foo=bar&driver=SQL+Server"
        )
        connection = dialect.create_connect_args(u)
        eq_(connection[1], {})
        # either relative order of the two extras is accepted
        eq_(
            connection[0][0]
            in (
                "DRIVER={SQL Server};Server=hostspec;Database=database;"
                "UID=username;PWD=password;foo=bar;LANGUAGE=us_english",
                "DRIVER={SQL Server};Server=hostspec;Database=database;UID="
                "username;PWD=password;LANGUAGE=us_english;foo=bar",
            ),
            True,
        )

    def test_pyodbc_extra_connect_azure(self):
        """A host with an empty userinfo (``@server``) plus an
        ``authentication=`` parameter yields no UID/PWD, and the parameter
        is rendered as ``Authentication=...``."""
        # issue #5592
        dialect = pyodbc.dialect()
        u = url.make_url(
            "mssql+pyodbc://@server_name/db_name?"
            "driver=ODBC+Driver+17+for+SQL+Server&"
            "authentication=ActiveDirectoryIntegrated"
        )
        connection = dialect.create_connect_args(u)
        eq_(connection[1], {})
        eq_(
            connection[0][0]
            in (
                "DRIVER={ODBC Driver 17 for SQL Server};"
                "Server=server_name;Database=db_name;"
                "Authentication=ActiveDirectoryIntegrated",
            ),
            True,
        )

    def test_pyodbc_odbc_connect(self):
        """``odbc_connect=`` is URL-decoded and used verbatim as the
        connection string."""
        dialect = pyodbc.dialect()
        u = url.make_url(
            "mssql:///?odbc_connect=DRIVER%3D%7BSQL+Server"
            "%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase"
            "%3BUID%3Dusername%3BPWD%3Dpassword"
        )
        connection = dialect.create_connect_args(u)
        eq_(
            [
                [
                    "DRIVER={SQL Server};Server=hostspec;Database=database;UI"
                    "D=username;PWD=password"
                ],
                {},
            ],
            connection,
        )

    def test_pyodbc_odbc_connect_with_dsn(self):
        """A DSN embedded in ``odbc_connect=`` is passed through unchanged
        (no Trusted_Connection is added)."""
        dialect = pyodbc.dialect()
        u = url.make_url(
            "mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase"
            "%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword"
        )
        connection = dialect.create_connect_args(u)
        eq_(
            [["dsn=mydsn;Database=database;UID=username;PWD=password"], {}],
            connection,
        )

    def test_pyodbc_odbc_connect_ignores_other_values(self):
        """When ``odbc_connect=`` is present, the URL's own user, password,
        host and database are ignored."""
        dialect = pyodbc.dialect()
        u = url.make_url(
            "mssql://userdiff:passdiff@localhost/dbdiff?od"
            "bc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer"
            "%3Dhostspec%3BDatabase%3Ddatabase%3BUID%3Duse"
            "rname%3BPWD%3Dpassword"
        )
        connection = dialect.create_connect_args(u)
        eq_(
            [
                [
                    "DRIVER={SQL Server};Server=hostspec;Database=database;UI"
                    "D=username;PWD=password"
                ],
                {},
            ],
            connection,
        )

    def test_pyodbc_token_injection(self):
        """URL components containing an encoded ``;`` must not inject extra
        ODBC attributes.

        The expected string shows that UID and PWD are decoded and wrapped
        in braces, with a literal ``}`` doubled. Server and Database are
        passed through still percent-encoded.
        """
        token1 = "someuser%3BPORT%3D50001"
        token2 = "some{strange}pw%3BPORT%3D50001"
        token3 = "somehost%3BPORT%3D50001"
        token4 = "somedb%3BPORT%3D50001"

        u = url.make_url(
            "mssql+pyodbc://%s:%s@%s/%s?driver=foob"
            % (token1, token2, token3, token4)
        )
        dialect = pyodbc.dialect()
        connection = dialect.create_connect_args(u)
        eq_(
            [
                [
                    "DRIVER={foob};Server=somehost%3BPORT%3D50001;"
                    "Database=somedb%3BPORT%3D50001;UID={someuser;PORT=50001};"
                    "PWD={some{strange}}pw;PORT=50001}"
                ],
                {},
            ],
            connection,
        )

    def test_pymssql_port_setting(self):
        """pymssql receives keyword arguments only; a URL port is folded
        into ``host`` as ``host:port``."""
        dialect = pymssql.dialect()

        u = url.make_url("mssql+pymssql://scott:tiger@somehost/test")
        connection = dialect.create_connect_args(u)
        eq_(
            [
                [],
                {
                    "host": "somehost",
                    "password": "tiger",
                    "user": "scott",
                    "database": "test",
                },
            ],
            connection,
        )

        u = url.make_url("mssql+pymssql://scott:tiger@somehost:5000/test")
        connection = dialect.create_connect_args(u)
        eq_(
            [
                [],
                {
                    "host": "somehost:5000",
                    "password": "tiger",
                    "user": "scott",
                    "database": "test",
                },
            ],
            connection,
        )

    def test_pymssql_disconnect(self):
        """Each listed pymssql error message is classified as a disconnect;
        an unrelated message is not."""
        dialect = pymssql.dialect()

        # plain strings are passed as the error object here
        for error in [
            "Adaptive Server connection timed out",
            "Net-Lib error during Connection reset by peer",
            "message 20003",
            "Error 10054",
            "Not connected to any MS SQL server",
            "Connection is closed",
            "message 20006",  # Write to the server failed
            "message 20017",  # Unexpected EOF from the server
            "message 20047",  # DBPROCESS is dead or not enabled
        ]:
            eq_(dialect.is_disconnect(error, None, None), True)

        eq_(dialect.is_disconnect("not an error", None, None), False)

    def test_pyodbc_disconnect(self):
        """pyodbc errors carrying the listed SQLSTATE codes, or the listed
        "closed connection" ProgrammingError messages, count as
        disconnects."""
        dialect = pyodbc.dialect()

        class MockDBAPIError(Exception):
            pass

        class MockProgrammingError(MockDBAPIError):
            pass

        # stand-in DBAPI module exposing only the exception classes used
        # by this test
        dialect.dbapi = Mock(
            Error=MockDBAPIError, ProgrammingError=MockProgrammingError
        )

        for error in [
            MockDBAPIError(code, "[%s] some pyodbc message" % code)
            for code in [
                "08S01",
                "01002",
                "08003",
                "08007",
                "08S02",
                "08001",
                "HYT00",
                "HY010",
            ]
        ] + [
            MockProgrammingError(message)
            for message in [
                "(some pyodbc stuff) The cursor's connection has been closed.",
                "(some pyodbc stuff) Attempt to use a closed connection.",
            ]
        ]:
            eq_(dialect.is_disconnect(error, None, None), True)

        # a code ("08007") that appears only inside ordinary message text
        # must not be mistaken for a disconnect
        eq_(
            dialect.is_disconnect(
                MockProgrammingError("Query with abc08007def failed"),
                None,
                None,
            ),
            False,
        )

    @testing.requires.mssql_freetds
    def test_bad_freetds_warning(self):
        """An implausible server version (95, 10, 255) makes connect()
        emit "Unrecognized server version info" as an SAWarning."""
        engine = engines.testing_engine()

        def _bad_version(connection):
            return 95, 10, 255

        engine.dialect._get_server_version_info = _bad_version
        # NOTE(review): asserting a *raised* SAWarning presumes the test
        # harness escalates warnings to errors -- confirm in test config
        assert_raises_message(
            exc.SAWarning, "Unrecognized server version info", engine.connect
        )


class FastExecutemanyTest(fixtures.TestBase):
    """Backend tests for pyodbc's ``fast_executemany`` engine option and
    its interaction with ``use_setinputsizes``."""

    __only_on__ = "mssql"
    __backend__ = True
    __requires__ = ("pyodbc_fast_executemany",)

    def test_flag_on(self, metadata):
        """With fast_executemany=True, every executemany-style execution
        sees ``cursor.fast_executemany`` set."""
        t = Table(
            "t",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
        )
        t.create(testing.db)

        eng = engines.testing_engine(options={"fast_executemany": True})

        @event.listens_for(eng, "after_cursor_execute")
        def after_cursor_execute(
            conn, cursor, statement, parameters, context, executemany
        ):
            # single-row executes are not checked
            if executemany:
                assert cursor.fast_executemany

        with eng.begin() as conn:
            # list of parameter sets -> executemany path
            conn.execute(
                t.insert(),
                [{"id": i, "data": "data_%d" % i} for i in range(100)],
            )

            # single parameter set -> plain execute path
            conn.execute(t.insert(), {"id": 200, "data": "data_200"})

    @testing.fixture
    def fe_engine(self, testing_engine):
        """Return a factory that builds a testing engine with the given
        ``fast_executemany`` and ``use_setinputsizes`` options."""

        def go(use_fastexecutemany, apply_setinputsizes_flag):
            engine = testing_engine(
                options={
                    "fast_executemany": use_fastexecutemany,
                    "use_setinputsizes": apply_setinputsizes_flag,
                }
            )
            return engine

        return go

    @testing.combinations(
        (
            "setinputsizeshook",
            True,
        ),
        (
            "nosetinputsizeshook",
            False,
        ),
        argnames="include_setinputsizes",
        id_="ia",
    )
    @testing.combinations(
        (
            "setinputsizesflag",
            True,
        ),
        (
            "nosetinputsizesflag",
            False,
        ),
        argnames="apply_setinputsizes_flag",
        id_="ia",
    )
    @testing.combinations(
        (
            "fastexecutemany",
            True,
        ),
        (
            "nofastexecutemany",
            False,
        ),
        argnames="use_fastexecutemany",
        id_="ia",
    )
    def test_insert_floats(
        self,
        metadata,
        fe_engine,
        include_setinputsizes,
        use_fastexecutemany,
        apply_setinputsizes_flag,
    ):
        """Round-trip high-precision Decimals into Numeric(19, 15) columns
        under every combination of fast_executemany, the use_setinputsizes
        flag, and a ``do_setinputsizes`` hook.

        Exactly one combination is expected to raise DBAPIError: flag on,
        fast_executemany on, and no hook overriding the Numeric input
        sizes.
        """
        expect_failure = (
            apply_setinputsizes_flag
            and not include_setinputsizes
            and use_fastexecutemany
        )

        engine = fe_engine(use_fastexecutemany, apply_setinputsizes_flag)

        observations = Table(
            "Observations",
            metadata,
            Column("id", Integer, nullable=False, primary_key=True),
            Column("obs1", Numeric(19, 15), nullable=True),
            Column("obs2", Numeric(19, 15), nullable=True),
            schema="test_schema",
        )
        with engine.begin() as conn:
            metadata.create_all(conn)

        records = [
            {
                "id": 1,
                "obs1": Decimal("60.1722066045792"),
                "obs2": Decimal("24.929289808227466"),
            },
            {
                "id": 2,
                "obs1": Decimal("60.16325715615476"),
                "obs2": Decimal("24.93886459535008"),
            },
            {
                "id": 3,
                "obs1": Decimal("60.16445165123469"),
                "obs2": Decimal("24.949856300109516"),
            },
        ]

        if include_setinputsizes:
            canary = mock.Mock()

            @event.listens_for(engine, "do_setinputsizes")
            def do_setinputsizes(
                inputsizes, cursor, statement, parameters, context
            ):
                # record the input sizes proposed before this hook edits them
                canary(list(inputsizes.values()))

                # force every Numeric parameter to SQL_DECIMAL(19, 15),
                # matching the column definitions above
                for key in inputsizes:
                    if isinstance(key.type, Numeric):
                        inputsizes[key] = (
                            engine.dialect.dbapi.SQL_DECIMAL,
                            19,
                            15,
                        )

        with engine.begin() as conn:

            if expect_failure:
                with expect_raises(DBAPIError):
                    conn.execute(observations.insert(), records)
            else:
                conn.execute(observations.insert(), records)

                # Decimals must round-trip without loss of precision
                eq_(
                    conn.execute(
                        select(observations).order_by(observations.c.id)
                    )
                    .mappings()
                    .all(),
                    records,
                )

        if include_setinputsizes:
            if apply_setinputsizes_flag:
                # one call for the INSERT (three parameters), then one
                # with no sizes (presumably the SELECT -- confirm)
                eq_(
                    canary.mock_calls,
                    [
                        # float for int? this seems wrong
                        mock.call([float, float, float]),
                        mock.call([]),
                    ],
                )
            else:
                # with use_setinputsizes off the hook is never invoked
                eq_(canary.mock_calls, [])


class VersionDetectionTest(fixtures.TestBase):
    """Server version parsing for the pymssql and pyodbc dialects, using
    mocked connections."""

    @testing.fixture
    def mock_conn_scalar(self):
        """Return a factory for mock connections whose
        ``exec_driver_sql(...).scalar()`` returns the given text."""
        return lambda text: Mock(
            exec_driver_sql=Mock(
                return_value=Mock(scalar=Mock(return_value=text))
            )
        )

    def test_pymssql_version(self, mock_conn_scalar):
        """Various banner formats all parse to (11, 0, 9216, 62)."""
        dialect = pymssql.MSDialect_pymssql()

        for vers in [
            "Microsoft SQL Server Blah - 11.0.9216.62",
            "Microsoft SQL Server (XYZ) - 11.0.9216.62 \n"
            "Jul 18 2014 22:00:21 \nCopyright (c) Microsoft Corporation",
            "Microsoft SQL Azure (RTM) - 11.0.9216.62 \n"
            "Jul 18 2014 22:00:21 \nCopyright (c) Microsoft Corporation",
        ]:
            conn = mock_conn_scalar(vers)
            eq_(dialect._get_server_version_info(conn), (11, 0, 9216, 62))

    def test_pyodbc_version_productversion(self, mock_conn_scalar):
        """A dotted product-version string from the query is split into an
        integer tuple."""
        dialect = pyodbc.MSDialect_pyodbc()

        conn = mock_conn_scalar("11.0.9216.62")
        eq_(dialect._get_server_version_info(conn), (11, 0, 9216, 62))

    def test_pyodbc_version_fallback(self):
        """If the version query raises DBAPIError, the version is taken
        from ``connection.getinfo()`` instead.

        The expected tuples pin the dialect's current parsing of the
        getinfo string, including the ``(5,)`` result for the
        non-SQL-Server case.
        """
        dialect = pyodbc.MSDialect_pyodbc()
        dialect.dbapi = Mock()

        for vers, expected in [
            ("11.0.9216.62", (11, 0, 9216, 62)),
            ("notsqlserver.11.foo.0.9216.BAR.62", (11, 0, 9216, 62)),
            ("Not SQL Server Version 10.5", (5,)),
        ]:
            conn = Mock(
                # the SQL-level version query fails...
                exec_driver_sql=Mock(
                    return_value=Mock(
                        scalar=Mock(
                            side_effect=exc.DBAPIError("stmt", "params", None)
                        )
                    )
                ),
                # ...so the raw DBAPI connection's getinfo() is consulted
                connection=Mock(getinfo=Mock(return_value=vers)),
            )

            eq_(dialect._get_server_version_info(conn), expected)


class RealIsolationLevelTest(fixtures.TestBase):
    """Set and read back each isolation level against a live SQL Server."""

    __only_on__ = "mssql"
    __backend__ = True

    def test_isolation_level(self, metadata):
        """Each level set through execution_options is reported back, and a
        fresh connection afterwards is back at the original default."""
        Table("test", metadata, Column("id", Integer)).create(
            testing.db, checkfirst=True
        )

        with testing.db.connect() as c:
            default = testing.db.dialect.get_isolation_level(c.connection)

        values = [
            "READ UNCOMMITTED",
            "READ COMMITTED",
            "REPEATABLE READ",
            "SERIALIZABLE",
            "SNAPSHOT",
        ]
        for value in values:
            with testing.db.connect() as c:
                c.execution_options(isolation_level=value)

                # run a query so the connection is actually used at this
                # level before it is inspected
                c.exec_driver_sql("SELECT TOP 10 * FROM test")

                eq_(
                    testing.db.dialect.get_isolation_level(c.connection), value
                )

        # the level must not leak into connections checked out later
        with testing.db.connect() as c:
            eq_(testing.db.dialect.get_isolation_level(c.connection), default)


class IsolationLevelDetectTest(fixtures.TestBase):
    """get_isolation_level() tries more than one system view and warns
    and raises when none of them can be queried."""

    def _fixture(self, view):
        """Build a pyodbc dialect and a mock DBAPI connection.

        The mock cursor's ``execute`` succeeds only when the statement text
        contains *view*; then ``fetchone()`` yields ``("SERIALIZABLE",)``.
        Otherwise it raises the dialect's DBAPI ``Error``. A *view* of None
        makes every statement fail.
        """

        class Error(Exception):
            pass

        dialect = pyodbc.MSDialect_pyodbc()
        dialect.dbapi = Mock(Error=Error)
        dialect.server_version_info = base.MS_2012_VERSION

        # filled by fail_on_exec on success; read back by fetchone()
        result = []

        def fail_on_exec(
            stmt,
        ):
            if view is not None and view in stmt:
                result.append(("SERIALIZABLE",))
            else:
                raise Error("that didn't work")

        connection = Mock(
            cursor=Mock(
                return_value=Mock(
                    execute=fail_on_exec, fetchone=lambda: result[0]
                )
            )
        )

        return dialect, connection

    def test_dm_pdw_nodes(self):
        """Detection succeeds when only the dm_pdw_nodes_exec_sessions view
        is queryable."""
        dialect, connection = self._fixture("dm_pdw_nodes_exec_sessions")

        eq_(dialect.get_isolation_level(connection), "SERIALIZABLE")

    def test_exec_sessions(self):
        """Detection succeeds when a statement mentioning exec_sessions is
        queryable."""
        dialect, connection = self._fixture("exec_sessions")

        eq_(dialect.get_isolation_level(connection), "SERIALIZABLE")

    def test_not_supported(self):
        """With every query failing, a warning is emitted and
        NotImplementedError is raised."""
        dialect, connection = self._fixture(None)

        with expect_warnings("Could not fetch transaction isolation level"):
            assert_raises_message(
                NotImplementedError,
                "Can't fetch isolation",
                dialect.get_isolation_level,
                connection,
            )


class InvalidTransactionFalsePositiveTest(fixtures.TablesTest):
    """Regression test for issue #5359.

    An IntegrityError must not leave the connection in an
    invalid-transaction state.
    """

    __only_on__ = "mssql"
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "error_t",
            metadata,
            Column("error_code", String(50), primary_key=True),
        )

    @classmethod
    def insert_data(cls, connection):
        # "01002" is also one of the SQLSTATE codes that
        # test_pyodbc_disconnect treats as a disconnect; presumably chosen
        # so the duplicate-key error text contains it -- confirm
        connection.execute(
            cls.tables.error_t.insert(),
            [{"error_code": "01002"}],
        )

    def test_invalid_transaction_detection(self, connection):
        # issue #5359
        t = self.tables.error_t

        # force duplicate PK error
        assert_raises(
            IntegrityError,
            connection.execute,
            t.insert(),
            {"error_code": "01002"},
        )

        # this should not fail with
        # "Can't reconnect until invalid transaction is rolled back."
        result = connection.execute(t.select()).fetchall()
        eq_(len(result), 1)