1from sqlalchemy.testing import eq_, ne_, assert_raises, \
2    expect_warnings, assert_raises_message
3import time
4from sqlalchemy import (
5    select, MetaData, Integer, String, create_engine, pool, exc, util)
6from sqlalchemy.testing.schema import Table, Column
7import sqlalchemy as tsa
8from sqlalchemy import testing
9from sqlalchemy.testing import mock
10from sqlalchemy.testing import engines
11from sqlalchemy.testing import fixtures
12from sqlalchemy.testing.engines import testing_engine
13from sqlalchemy.testing.mock import Mock, call, patch
14from sqlalchemy import event
15from sqlalchemy.testing.util import gc_collect
16
17
class MockError(Exception):
    """Base error class raised by the mock DBAPI (exposed as its ``Error``)."""
20
21
class MockDisconnect(MockError):
    """Mock DBAPI error that the mock-based tests classify as a disconnect."""
24
25
class MockExitIsh(BaseException):
    """Simulated interruption (keyboard / greenlet kill and the like).

    Derives from ``BaseException`` rather than ``Exception`` so that plain
    ``except Exception`` handlers do not catch it.
    """
28
29
def mock_connection():
    """Return a ``Mock`` DBAPI connection whose failures are scripted.

    Tests drive failures by assigning ``conn.explode`` (directly, or via the
    mock DBAPI's ``shutdown()``).  Recognized values:

    * ``'execute'`` -- ``cursor.execute()`` raises :class:`MockDisconnect`.
    * ``'interrupt'`` -- ``cursor.execute()`` raises :class:`MockExitIsh`
      and switches the mode to ``'explode_no_disconnect'``.
    * ``'interrupt_dont_break'`` -- ``cursor.execute()`` raises
      :class:`MockExitIsh` and resets the mode to ``None``.
    * ``'execute_no_disconnect'`` / ``'explode_no_disconnect'`` --
      ``cursor.execute()`` raises a plain :class:`MockError`.
    * ``'rollback'`` -- ``cursor.execute()`` raises :class:`MockError` and
      ``rollback()`` raises :class:`MockDisconnect`.
    * ``'rollback_no_disconnect'`` -- both ``cursor.execute()`` and
      ``rollback()`` raise :class:`MockError`.

    With any other value (including the child ``Mock`` auto-created when
    ``explode`` was never assigned) calls succeed.  Executing a statement
    containing ``"SELECT"`` gives the cursor a one-column ``description``.
    Each ``conn.cursor()`` call returns a new mock cursor.  Closing a cursor
    makes its ``fetchone()`` / ``fetchall()`` raise :class:`MockError`.
    """
    # Modes in which cursor.execute() fails with a non-disconnect error.
    # The rollback modes are included so the statement fails first and the
    # error-handling rollback then hits the failing rollback() below.
    no_disconnect_execute_modes = (
        'execute_no_disconnect', 'explode_no_disconnect',
        'rollback', 'rollback_no_disconnect')

    def mock_cursor():
        def execute(*args, **kwargs):
            if conn.explode == 'execute':
                raise MockDisconnect("Lost the DB connection on execute")
            elif conn.explode == 'interrupt':
                # interrupt once, then keep failing without a disconnect
                conn.explode = "explode_no_disconnect"
                raise MockExitIsh("Keyboard / greenlet / etc interruption")
            elif conn.explode == 'interrupt_dont_break':
                # interrupt once, then behave normally again
                conn.explode = None
                raise MockExitIsh("Keyboard / greenlet / etc interruption")
            elif conn.explode in no_disconnect_execute_modes:
                raise MockError(
                    "something broke on execute but we didn't lose the "
                    "connection")
            elif args and "SELECT" in args[0]:
                cursor.description = [('foo', None, None, None, None, None)]
            else:
                return

        def close():
            # a closed cursor can no longer fetch
            cursor.fetchall = cursor.fetchone = \
                Mock(side_effect=MockError("cursor closed"))
        cursor = Mock(
            execute=Mock(side_effect=execute),
            close=Mock(side_effect=close))
        return cursor

    def cursor_factory():
        # infinite iterator: every conn.cursor() call yields a new cursor
        while True:
            yield mock_cursor()

    def rollback():
        if conn.explode == 'rollback':
            raise MockDisconnect("Lost the DB connection on rollback")
        if conn.explode == 'rollback_no_disconnect':
            raise MockError(
                "something broke on rollback but we didn't lose the "
                "connection")

    conn = Mock(
        rollback=Mock(side_effect=rollback),
        cursor=Mock(side_effect=cursor_factory()))
    return conn
82
83
def MockDBAPI():
    """Return a ``Mock`` standing in for a DBAPI module.

    ``connect()`` hands out a fresh :func:`mock_connection` on every call and
    records it in ``connections``.  ``shutdown(explode='execute')`` sets the
    ``explode`` failure mode on every connection handed out so far.
    ``dispose()`` clears that mode on those connections and empties the
    list in place.
    """
    opened = []

    def connection_stream():
        while True:
            new_conn = mock_connection()
            opened.append(new_conn)
            yield new_conn

    def shutdown(explode='execute'):
        for each in opened:
            each.explode = explode

    def dispose():
        for each in opened:
            each.explode = None
        # mutate in place: the same list object is exposed as .connections
        del opened[:]

    return Mock(
        Error=MockError,
        paramstyle='named',
        connections=opened,
        connect=Mock(side_effect=connection_stream()),
        shutdown=Mock(side_effect=shutdown),
        dispose=Mock(side_effect=dispose),
    )
109
110
class MockReconnectTest(fixtures.TestBase):
    """Reconnect / invalidation behavior against the mock DBAPI.

    Each test gets a fresh :func:`MockDBAPI` and an engine built on it whose
    dialect treats only :class:`MockDisconnect` as a disconnect.
    """

    def setup(self):
        self.dbapi = MockDBAPI()

        self.db = testing_engine(
            'postgresql://foo:bar@localhost/test',
            options=dict(module=self.dbapi, _initialize=False))

        # the call the engine is expected to make on dbapi.connect()
        self.mock_connect = call(
            host='localhost', password='bar', user='foo', database='test')
        # monkeypatch disconnect checker
        self.db.dialect.is_disconnect = \
            lambda e, conn, cursor: isinstance(e, MockDisconnect)

    def teardown(self):
        self.dbapi.dispose()

    def test_reconnect(self):
        """test that an 'is_disconnect' condition will invalidate the
        connection, and that the idle connection left in the pool is
        closed and replaced rather than reused."""

        # make a connection

        conn = self.db.connect()

        # connection works

        conn.execute(select([1]))

        # create a second connection within the pool, which we'll ensure
        # also goes away

        conn2 = self.db.connect()
        conn2.close()

        # two connections opened total now

        assert len(self.dbapi.connections) == 2

        # set it to fail

        self.dbapi.shutdown()
        assert_raises(
            tsa.exc.DBAPIError,
            conn.execute, select([1])
        )

        # assert was invalidated

        assert not conn.closed
        assert conn.invalidated

        # close shouldn't break

        conn.close()

        # ensure one connection closed...
        eq_(
            [c.close.mock_calls for c in self.dbapi.connections],
            [[call()], []]
        )

        conn = self.db.connect()

        eq_(
            [c.close.mock_calls for c in self.dbapi.connections],
            [[call()], [call()], []]
        )

        conn.execute(select([1]))
        conn.close()

        eq_(
            [c.close.mock_calls for c in self.dbapi.connections],
            [[call()], [call()], []]
        )

    def test_invalidate_trans(self):
        """A connection invalidated inside a transaction refuses to
        reconnect until that transaction is rolled back."""
        conn = self.db.connect()
        trans = conn.begin()
        self.dbapi.shutdown()

        assert_raises(
            tsa.exc.DBAPIError,
            conn.execute, select([1])
        )

        eq_(
            [c.close.mock_calls for c in self.dbapi.connections],
            [[call()]]
        )
        assert not conn.closed
        assert conn.invalidated
        assert trans.is_active
        assert_raises_message(
            tsa.exc.StatementError,
            "Can't reconnect until invalid transaction is rolled back",
            conn.execute, select([1])
        )
        assert trans.is_active

        assert_raises_message(
            tsa.exc.InvalidRequestError,
            "Can't reconnect until invalid transaction is rolled back",
            trans.commit)

        assert trans.is_active
        trans.rollback()
        assert not trans.is_active
        conn.execute(select([1]))
        assert not conn.invalidated
        eq_(
            [c.close.mock_calls for c in self.dbapi.connections],
            [[call()], []]
        )

    def test_invalidate_dont_call_finalizer(self):
        """Connection.invalidate() does not fire the connection record's
        finalize callbacks."""
        conn = self.db.connect()
        finalizer = mock.Mock()
        conn.connection._connection_record.\
            finalize_callback.append(finalizer)
        conn.invalidate()
        assert conn.invalidated
        eq_(finalizer.call_count, 0)

    def test_conn_reusable(self):
        """After a disconnect, the same Connection reconnects on its next
        execute using a newly opened DBAPI connection."""
        conn = self.db.connect()

        conn.execute(select([1]))

        eq_(
            self.dbapi.connect.mock_calls,
            [self.mock_connect]
        )

        self.dbapi.shutdown()

        assert_raises(
            tsa.exc.DBAPIError,
            conn.execute, select([1])
        )

        assert not conn.closed
        assert conn.invalidated

        eq_(
            [c.close.mock_calls for c in self.dbapi.connections],
            [[call()]]
        )

        # test reconnects
        conn.execute(select([1]))
        assert not conn.invalidated

        eq_(
            [c.close.mock_calls for c in self.dbapi.connections],
            [[call()], []]
        )

    def test_invalidated_close(self):
        """Closing an invalidated connection leaves it closed and
        invalidated; further execution fails."""
        conn = self.db.connect()

        self.dbapi.shutdown()

        assert_raises(
            tsa.exc.DBAPIError,
            conn.execute, select([1])
        )

        conn.close()
        assert conn.closed
        assert conn.invalidated
        assert_raises_message(
            tsa.exc.StatementError,
            "This Connection is closed",
            conn.execute, select([1])
        )

    def test_noreconnect_execute_plus_closewresult(self):
        """A non-disconnect execute error closes a close_with_result
        connection without invalidating it."""
        conn = self.db.connect(close_with_result=True)

        self.dbapi.shutdown("execute_no_disconnect")

        # raises error
        assert_raises_message(
            tsa.exc.DBAPIError,
            "something broke on execute but we didn't lose the connection",
            conn.execute, select([1])
        )

        assert conn.closed
        assert not conn.invalidated

    def test_noreconnect_rollback_plus_closewresult(self):
        """A non-disconnect rollback failure during error handling is the
        error raised; the close_with_result connection is closed but not
        invalidated."""
        conn = self.db.connect(close_with_result=True)

        self.dbapi.shutdown("rollback_no_disconnect")

        # raises error
        with expect_warnings(
            "An exception has occurred during handling .*"
            "something broke on execute but we didn't lose the connection",
            py2konly=True
        ):
            assert_raises_message(
                tsa.exc.DBAPIError,
                "something broke on rollback but we didn't "
                "lose the connection",
                conn.execute, select([1])
            )

        assert conn.closed
        assert not conn.invalidated

        assert_raises_message(
            tsa.exc.StatementError,
            "This Connection is closed",
            conn.execute, select([1])
        )

    def test_reconnect_on_reentrant(self):
        """A disconnect raised by rollback() while handling an execute
        error invalidates the connection."""
        conn = self.db.connect()

        conn.execute(select([1]))

        assert len(self.dbapi.connections) == 1

        self.dbapi.shutdown("rollback")

        # raises error
        with expect_warnings(
            "An exception has occurred during handling .*"
            "something broke on execute but we didn't lose the connection",
            py2konly=True
        ):
            assert_raises_message(
                tsa.exc.DBAPIError,
                "Lost the DB connection on rollback",
                conn.execute, select([1])
            )

        assert not conn.closed
        assert conn.invalidated

    def test_reconnect_on_reentrant_plus_closewresult(self):
        """As test_reconnect_on_reentrant, but a close_with_result
        connection also ends up closed."""
        conn = self.db.connect(close_with_result=True)

        self.dbapi.shutdown("rollback")

        # raises error
        with expect_warnings(
            "An exception has occurred during handling .*"
            "something broke on execute but we didn't lose the connection",
            py2konly=True
        ):
            assert_raises_message(
                tsa.exc.DBAPIError,
                "Lost the DB connection on rollback",
                conn.execute, select([1])
            )

        assert conn.closed
        assert conn.invalidated

        assert_raises_message(
            tsa.exc.StatementError,
            "This Connection is closed",
            conn.execute, select([1])
        )

    def test_check_disconnect_no_cursor(self):
        """Fetching from a result whose cursor was closed raises a
        DBAPIError carrying the cursor's error."""
        conn = self.db.connect()
        result = conn.execute(select([1]))
        result.cursor.close()
        conn.close()

        assert_raises_message(
            tsa.exc.DBAPIError,
            "cursor closed",
            list, result
        )

    def test_dialect_initialize_once(self):
        """Dialect.initialize() runs only once per engine, even after
        engine.dispose()."""
        from sqlalchemy.engine.url import URL
        from sqlalchemy.engine.default import DefaultDialect
        dbapi = self.dbapi

        class MyURL(URL):
            def _get_entrypoint(self):
                return Dialect

            def get_dialect(self):
                return Dialect

        class Dialect(DefaultDialect):
            initialize = Mock()

        engine = create_engine(MyURL("foo://"), module=dbapi)
        c1 = engine.connect()
        engine.dispose()
        c2 = engine.connect()
        eq_(Dialect.initialize.call_count, 1)

        # release both checked-out connections rather than leaving them
        # for garbage collection
        c1.close()
        c2.close()

    def test_invalidate_conn_w_contextmanager_interrupt(self):
        """A BaseException raised inside conn.begin() invalidates the
        connection but not the pool."""
        # test [ticket:3803]
        db_pool = self.db.pool

        conn = self.db.connect()
        self.dbapi.shutdown("interrupt")

        def go():
            with conn.begin():
                conn.execute(select([1]))

        assert_raises(
            MockExitIsh,
            go
        )

        assert conn.invalidated

        eq_(db_pool._invalidate_time, 0)  # pool not invalidated

        conn.execute(select([1]))
        assert not conn.invalidated

    def test_invalidate_conn_interrupt_nodisconnect_workaround(self):
        """A handle_error listener that sets is_disconnect = False keeps
        the connection valid after an interruption."""
        # test [ticket:3803] workaround for no disconnect on keyboard interrupt

        @event.listens_for(self.db, "handle_error")
        def cancel_disconnect(ctx):
            ctx.is_disconnect = False

        db_pool = self.db.pool

        conn = self.db.connect()
        self.dbapi.shutdown("interrupt_dont_break")

        def go():
            with conn.begin():
                conn.execute(select([1]))

        assert_raises(
            MockExitIsh,
            go
        )

        assert not conn.invalidated

        eq_(db_pool._invalidate_time, 0)  # pool not invalidated

        conn.execute(select([1]))
        assert not conn.invalidated

    def test_invalidate_conn_w_contextmanager_disconnect(self):
        """A genuine disconnect inside conn.begin() still invalidates the
        whole pool."""
        # test [ticket:3803] change maintains old behavior

        db_pool = self.db.pool

        conn = self.db.connect()
        self.dbapi.shutdown("execute")

        def go():
            with conn.begin():
                conn.execute(select([1]))

        assert_raises(
            exc.DBAPIError,  # wraps a MockDisconnect
            go
        )

        assert conn.invalidated

        ne_(db_pool._invalidate_time, 0)  # pool is invalidated

        conn.execute(select([1]))
        assert not conn.invalidated
493
494
class CursorErrTest(fixtures.TestBase):
    """Errors raised by ``cursor.close()`` are logged by the pool.

    This isn't really a "reconnect" test -- it is more generic error
    recovery; arguably this suite should have been named
    "test_error_recovery".
    """

    def _fixture(self, explode_on_exec, initialize):
        """Return an engine on a mock DBAPI whose cursors fail to close.

        :param explode_on_exec: when true, cursors raise the DBAPI's own
            error class from both ``execute()`` and ``close()``; otherwise
            only ``close()`` fails, with a plain ``Exception``.
        :param initialize: forwarded as the engine's ``_initialize`` option.

        The pool's ``logger`` is replaced with a ``Mock`` so tests can
        inspect what was logged.
        """
        class DBAPIError(Exception):
            pass

        def make_cursor():
            if explode_on_exec:
                return Mock(
                    description=[],
                    close=Mock(side_effect=DBAPIError("explode")),
                    execute=Mock(side_effect=DBAPIError("explode"))
                )
            return Mock(
                description=[],
                close=Mock(side_effect=Exception("explode")),
            )

        def cursor_stream():
            while True:
                yield make_cursor()

        def connection_stream():
            while True:
                yield Mock(
                    spec=['cursor', 'commit', 'rollback', 'close'],
                    cursor=Mock(side_effect=cursor_stream()),)

        fake_dbapi = Mock(
            Error=DBAPIError, paramstyle='qmark',
            connect=Mock(side_effect=connection_stream()))

        from sqlalchemy.engine import default
        fake_url = Mock(
            get_dialect=lambda: default.DefaultDialect,
            _get_entrypoint=lambda: default.DefaultDialect,
            _instantiate_plugins=lambda kwargs: (),
            translate_connect_args=lambda: {}, query={},)
        engine = testing_engine(
            fake_url, options=dict(module=fake_dbapi, _initialize=initialize))
        engine.pool.logger = Mock()
        return engine

    def test_cursor_explode(self):
        """A failing cursor.close() after a normal statement is logged,
        not raised."""
        engine = self._fixture(False, False)
        connection = engine.connect()
        connection.execute("select foo").close()
        connection.close()
        eq_(
            engine.pool.logger.error.mock_calls,
            [call('Error closing cursor', exc_info=True)]
        )

    def test_cursor_shutdown_in_initialize(self):
        """Cursor failures during dialect initialization surface as a
        warning from connect() and the close failure is logged."""
        engine = self._fixture(True, True)
        assert_raises_message(
            exc.SAWarning,
            "Exception attempting to detect",
            engine.connect
        )
        eq_(
            engine.pool.logger.error.mock_calls,
            [call('Error closing cursor', exc_info=True)]
        )
562
563
564def _assert_invalidated(fn, *args):
565    try:
566        fn(*args)
567        assert False
568    except tsa.exc.DBAPIError as e:
569        if not e.connection_invalidated:
570            raise
571
572
class RealReconnectTest(fixtures.TestBase):
    """Reconnect behavior against a real backend.

    The reconnecting engine's ``test_shutdown()`` cuts the connections
    mid-test.
    """
    __backend__ = True
    __requires__ = 'graceful_disconnects',

    def setup(self):
        self.engine = engines.reconnecting_engine()

    def teardown(self):
        self.engine.dispose()

    def test_reconnect(self):
        """An invalidated Connection reconnects on its next execute, and
        can do so repeatedly."""
        conn = self.engine.connect()

        eq_(conn.execute(select([1])).scalar(), 1)
        assert not conn.closed

        self.engine.test_shutdown()

        _assert_invalidated(conn.execute, select([1]))

        assert not conn.closed
        assert conn.invalidated

        eq_(conn.execute(select([1])).scalar(), 1)
        assert not conn.invalidated

        # one more time
        self.engine.test_shutdown()
        _assert_invalidated(conn.execute, select([1]))

        assert conn.invalidated
        eq_(conn.execute(select([1])).scalar(), 1)
        assert not conn.invalidated

        conn.close()

    def test_multiple_invalidate(self):
        """Two connections both report invalidation after one shutdown.

        The second invalidation does not replace the engine's pool.
        """
        c1 = self.engine.connect()
        c2 = self.engine.connect()

        eq_(c1.execute(select([1])).scalar(), 1)

        self.engine.test_shutdown()

        _assert_invalidated(c1.execute, select([1]))

        p2 = self.engine.pool

        _assert_invalidated(c2.execute, select([1]))

        # pool isn't replaced
        assert self.engine.pool is p2

    def test_branched_invalidate_branch_to_parent(self):
        """Invalidating a branch also invalidates its parent.

        Revalidating the branch revalidates both, and the pool logs the
        invalidation.
        """
        c1 = self.engine.connect()

        with patch.object(self.engine.pool, "logger") as logger:
            c1_branch = c1.connect()
            eq_(c1_branch.execute(select([1])).scalar(), 1)

            self.engine.test_shutdown()

            _assert_invalidated(c1_branch.execute, select([1]))
            assert c1.invalidated
            assert c1_branch.invalidated

            c1_branch._revalidate_connection()
            assert not c1.invalidated
            assert not c1_branch.invalidated

        assert "Invalidate connection" in logger.mock_calls[0][1][0]

    def test_branched_invalidate_parent_to_branch(self):
        """Invalidating the parent is visible from its branch.

        Revalidating the parent revalidates the branch.
        """
        c1 = self.engine.connect()

        c1_branch = c1.connect()
        eq_(c1_branch.execute(select([1])).scalar(), 1)

        self.engine.test_shutdown()

        _assert_invalidated(c1.execute, select([1]))
        assert c1.invalidated
        assert c1_branch.invalidated

        c1._revalidate_connection()
        assert not c1.invalidated
        assert not c1_branch.invalidated

    def test_branch_invalidate_state(self):
        """After a disconnect, a branch is invalid but not closed."""
        c1 = self.engine.connect()

        c1_branch = c1.connect()

        eq_(c1_branch.execute(select([1])).scalar(), 1)

        self.engine.test_shutdown()

        _assert_invalidated(c1_branch.execute, select([1]))
        assert not c1_branch.closed
        assert not c1_branch._connection_is_valid

    def test_ensure_is_disconnect_gets_connection(self):
        """The dialect's is_disconnect() hook sees a Connection whose DBAPI
        connection is still attached."""
        def is_disconnect(e, conn, cursor):
            # connection is still present
            assert conn.connection is not None
            # the error usually occurs on connection.cursor(),
            # though MySQLdb we get a non-working cursor.
            # assert cursor is None
            # (returns None, i.e. the error is not treated as a disconnect)

        self.engine.dialect.is_disconnect = is_disconnect
        conn = self.engine.connect()
        self.engine.test_shutdown()
        with expect_warnings(
            "An exception has occurred during handling .*",
            py2konly=True
        ):
            assert_raises(
                tsa.exc.DBAPIError,
                conn.execute, select([1])
            )

    def test_rollback_on_invalid_plain(self):
        """Rolling back a transaction on an invalidated connection does
        not raise."""
        conn = self.engine.connect()
        trans = conn.begin()
        conn.invalidate()
        trans.rollback()

    @testing.requires.two_phase_transactions
    def test_rollback_on_invalid_twophase(self):
        """Rolling back a two-phase transaction on an invalidated
        connection does not raise."""
        conn = self.engine.connect()
        trans = conn.begin_twophase()
        conn.invalidate()
        trans.rollback()

    @testing.requires.savepoints
    def test_rollback_on_invalid_savepoint(self):
        """Rolling back a savepoint on an invalidated connection does not
        raise."""
        conn = self.engine.connect()
        trans = conn.begin()
        trans2 = conn.begin_nested()
        conn.invalidate()
        trans2.rollback()

    def test_invalidate_twice(self):
        """Calling invalidate() twice does not raise."""
        conn = self.engine.connect()
        conn.invalidate()
        conn.invalidate()

    @testing.skip_if(
        [lambda: util.py3k, "oracle+cx_oracle"],
        "Crashes on py3k+cx_oracle")
    def test_explode_in_initializer(self):
        """A failing dialect.initialize() surfaces from connect() as
        DBAPIError."""
        engine = engines.testing_engine()

        def broken_initialize(connection):
            connection.execute("select fake_stuff from _fake_table")

        engine.dialect.initialize = broken_initialize

        # raises a DBAPIError, not an AttributeError
        assert_raises(exc.DBAPIError, engine.connect)

    @testing.skip_if(
        [lambda: util.py3k, "oracle+cx_oracle"],
        "Crashes on py3k+cx_oracle")
    def test_explode_in_initializer_disconnect(self):
        """Same as test_explode_in_initializer, with every error classified
        as a disconnect."""
        engine = engines.testing_engine()

        def broken_initialize(connection):
            connection.execute("select fake_stuff from _fake_table")

        engine.dialect.initialize = broken_initialize

        def is_disconnect(e, conn, cursor):
            return True

        engine.dialect.is_disconnect = is_disconnect

        # invalidate() also doesn't screw up
        assert_raises(exc.DBAPIError, engine.connect)

    def test_null_pool(self):
        """Invalidation and reconnect also work with NullPool."""
        engine = \
            engines.reconnecting_engine(options=dict(poolclass=pool.NullPool))
        conn = engine.connect()
        eq_(conn.execute(select([1])).scalar(), 1)
        assert not conn.closed
        engine.test_shutdown()
        _assert_invalidated(conn.execute, select([1]))
        assert not conn.closed
        assert conn.invalidated
        eq_(conn.execute(select([1])).scalar(), 1)
        assert not conn.invalidated

    def test_close(self):
        """An invalidated connection can be closed, and a new connection
        then works."""
        conn = self.engine.connect()
        eq_(conn.execute(select([1])).scalar(), 1)
        assert not conn.closed

        self.engine.test_shutdown()

        _assert_invalidated(conn.execute, select([1]))

        conn.close()
        conn = self.engine.connect()
        eq_(conn.execute(select([1])).scalar(), 1)

    def test_with_transaction(self):
        """A connection invalidated inside a transaction reconnects only
        after the transaction is rolled back."""
        conn = self.engine.connect()
        trans = conn.begin()
        eq_(conn.execute(select([1])).scalar(), 1)
        assert not conn.closed
        self.engine.test_shutdown()
        _assert_invalidated(conn.execute, select([1]))
        assert not conn.closed
        assert conn.invalidated
        assert trans.is_active
        assert_raises_message(
            tsa.exc.StatementError,
            "Can't reconnect until invalid transaction is rolled back",
            conn.execute, select([1]))
        assert trans.is_active
        assert_raises_message(
            tsa.exc.InvalidRequestError,
            "Can't reconnect until invalid transaction is rolled back",
            trans.commit
        )
        assert trans.is_active
        trans.rollback()
        assert not trans.is_active
        assert conn.invalidated
        eq_(conn.execute(select([1])).scalar(), 1)
        assert not conn.invalidated
809
810
class RecycleTest(fixtures.TestBase):
    """A connection killed server-side is replaced once it is older than
    the pool's recycle time."""
    __backend__ = True

    def _assert_select_one(self, engine):
        """Check out a contextual connection, check SELECT 1, and close it."""
        conn = engine.contextual_connect()
        eq_(conn.execute(select([1])).scalar(), 1)
        conn.close()

    def test_basic(self):
        for threadlocal in (False, True):
            engine = engines.reconnecting_engine(
                options={'pool_threadlocal': threadlocal})
            self._assert_select_one(engine)

            # Lower the recycle time to 1 second after the engine exists.
            # Configuring it at engine creation is avoided because
            # cx_oracle takes a long time to make the first connection,
            # which would build a large delay into this test.
            engine.pool._recycle = 1

            # kill the DB connection
            engine.test_shutdown()

            # wait until past the recycle period
            time.sleep(2)

            # can connect, no exception
            self._assert_select_one(engine)
841
842
class InvalidateDuringResultTest(fixtures.TestBase):
    """A shutdown while a result is partially fetched invalidates the
    connection on the next fetch."""
    __backend__ = True

    def setup(self):
        self.engine = engines.reconnecting_engine()
        self.meta = MetaData(self.engine)
        sometable = Table(
            'sometable', self.meta,
            Column('id', Integer, primary_key=True),
            Column('name', String(50)))
        self.meta.create_all()
        rows = [{'id': i, 'name': 'row %d' % i} for i in range(1, 100)]
        sometable.insert().execute(rows)

    def teardown(self):
        self.meta.drop_all()
        self.engine.dispose()

    @testing.crashes(
        "oracle",
        "cx_oracle 6 doesn't allow a close like this due to open cursors")
    @testing.fails_if([
        '+mysqlconnector', '+mysqldb', '+cymysql', '+pymysql', '+pg8000'],
        "Buffers the result set and doesn't check for connection close")
    def test_invalidate_on_results(self):
        conn = self.engine.connect()
        result = conn.execute('select * from sometable')
        # consume part of the result before killing the connection
        for _ in range(20):
            result.fetchone()
        self.engine.test_shutdown()
        _assert_invalidated(result.fetchone)
        assert conn.invalidated
876