1#!/usr/bin/env python
2
3# test_connection.py - unit test for connection attributes
4#
5# Copyright (C) 2008-2019 James Henstridge  <james@jamesh.id.au>
6# Copyright (C) 2020-2021 The Psycopg Team
7#
8# psycopg2 is free software: you can redistribute it and/or modify it
9# under the terms of the GNU Lesser General Public License as published
10# by the Free Software Foundation, either version 3 of the License, or
11# (at your option) any later version.
12#
13# In addition, as a special exception, the copyright holders give
14# permission to link this program with the OpenSSL library (or with
15# modified versions of OpenSSL that use the same license as OpenSSL),
16# and distribute linked combinations including the two.
17#
18# You must obey the GNU Lesser General Public License in all respects for
19# all of the code used other than OpenSSL.
20#
21# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
22# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
23# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
24# License for more details.
25
26import gc
27import os
28import re
29import sys
30import time
31import ctypes
32import shutil
33import tempfile
34import threading
35import subprocess as sp
36from collections import deque
37from operator import attrgetter
38from weakref import ref
39
40import psycopg2
41import psycopg2.extras
42from psycopg2 import extensions as ext
43
44from .testutils import (
45    unittest, skip_if_no_superuser, skip_before_postgres,
46    skip_after_postgres, skip_before_libpq, skip_after_libpq,
47    ConnectingTestCase, skip_if_tpc_disabled, skip_if_windows, slow,
48    skip_if_crdb, crdb_version)
49
50from .testconfig import dbhost, dsn, dbname
51
52
class ConnectionTests(ConnectingTestCase):
    # Tests for the basic attributes and behaviour of a connection object.
    #
    # NOTE(review): `self.conn`, `self.connect()` and `self.libpq` are not
    # defined in this class; presumably ConnectingTestCase provides them --
    # confirm in testutils.

    def test_closed_attribute(self):
        # `closed` compares equal to False while open and to True after close()
        conn = self.conn
        self.assertEqual(conn.closed, False)
        conn.close()
        self.assertEqual(conn.closed, True)

    def test_close_idempotent(self):
        # Calling close() a second time must not raise
        conn = self.conn
        conn.close()
        conn.close()
        self.assertTrue(conn.closed)

    def test_cursor_closed_attribute(self):
        # A cursor reports `closed` after being closed explicitly...
        conn = self.conn
        curs = conn.cursor()
        self.assertEqual(curs.closed, False)
        curs.close()
        self.assertEqual(curs.closed, True)

        # Closing the connection closes the cursor:
        curs = conn.cursor()
        conn.close()
        self.assertEqual(curs.closed, True)

    @skip_if_crdb("backend pid")
    @skip_before_postgres(8, 4)
    @skip_if_no_superuser
    @skip_if_windows
    def test_cleanup_on_badconn_close(self):
        # ticket #148
        # Terminating our own backend makes the query fail; the connection
        # then reports closed == 2, and closed == 1 once close() is called.
        conn = self.conn
        cur = conn.cursor()
        self.assertRaises(psycopg2.OperationalError,
            cur.execute, "select pg_terminate_backend(pg_backend_pid())")

        self.assertEqual(conn.closed, 2)
        conn.close()
        self.assertEqual(conn.closed, 1)

    @skip_if_crdb("isolation level")
    def test_reset(self):
        # reset() must revert every session characteristic changed below
        conn = self.conn
        # switch session characteristics
        conn.autocommit = True
        conn.isolation_level = 'serializable'
        conn.readonly = True
        if self.conn.info.server_version >= 90100:
            conn.deferrable = False

        self.assertTrue(conn.autocommit)
        self.assertEqual(conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE)
        self.assertIs(conn.readonly, True)
        if self.conn.info.server_version >= 90100:
            self.assertIs(conn.deferrable, False)

        conn.reset()
        # now the session characteristics should be reverted
        self.assertFalse(conn.autocommit)
        self.assertEqual(conn.isolation_level, ext.ISOLATION_LEVEL_DEFAULT)
        self.assertIsNone(conn.readonly)
        if self.conn.info.server_version >= 90100:
            self.assertIsNone(conn.deferrable)

    @skip_if_crdb("notice")
    def test_notices(self):
        # Creating a table with a serial primary key must leave at least one
        # entry in conn.notices
        conn = self.conn
        cur = conn.cursor()
        if self.conn.info.server_version >= 90300:
            cur.execute("set client_min_messages=debug1")
        cur.execute("create temp table chatty (id serial primary key);")
        self.assertEqual("CREATE TABLE", cur.statusmessage)
        self.assertTrue(conn.notices)

    @skip_if_crdb("notice")
    def test_notices_consistent_order(self):
        # Notices from several statements and several execute() calls must
        # be stored in the order the statements were issued
        conn = self.conn
        cur = conn.cursor()
        if self.conn.info.server_version >= 90300:
            cur.execute("set client_min_messages=debug1")
        cur.execute("""
            create temp table table1 (id serial);
            create temp table table2 (id serial);
            """)
        cur.execute("""
            create temp table table3 (id serial);
            create temp table table4 (id serial);
            """)
        self.assertEqual(4, len(conn.notices))
        for pos, table in enumerate(('table1', 'table2', 'table3', 'table4')):
            self.assertIn(table, conn.notices[pos])

    @slow
    @skip_if_crdb("notice")
    def test_notices_limited(self):
        # After 100 table creations only 50 notices are kept, the last one
        # being about the last table created
        conn = self.conn
        cur = conn.cursor()
        if self.conn.info.server_version >= 90300:
            cur.execute("set client_min_messages=debug1")
        for i in range(0, 100, 10):
            sql = " ".join(["create temp table table%d (id serial);" % j
                            for j in range(i, i + 10)])
            cur.execute(sql)

        self.assertEqual(50, len(conn.notices))
        self.assertIn('table99', conn.notices[-1])

    @slow
    @skip_if_crdb("notice")
    def test_notices_deque(self):
        # conn.notices may be replaced by a deque: notices are appended in
        # order and, unlike the default container, none is discarded
        conn = self.conn
        self.conn.notices = deque()
        cur = conn.cursor()
        if self.conn.info.server_version >= 90300:
            cur.execute("set client_min_messages=debug1")

        cur.execute("""
            create temp table table1 (id serial);
            create temp table table2 (id serial);
            """)
        cur.execute("""
            create temp table table3 (id serial);
            create temp table table4 (id serial);""")
        self.assertEqual(len(conn.notices), 4)
        for table in ('table1', 'table2', 'table3', 'table4'):
            self.assertIn(table, conn.notices.popleft())
        self.assertEqual(len(conn.notices), 0)

        # not limited, but no error
        for i in range(0, 100, 10):
            sql = " ".join(["create temp table table2_%d (id serial);" % j
                            for j in range(i, i + 10)])
            cur.execute(sql)

        self.assertEqual(len([n for n in conn.notices if 'CREATE TABLE' in n]),
            100)

    @skip_if_crdb("notice")
    def test_notices_noappend(self):
        conn = self.conn
        # Notices cannot be appended to None: the resulting error must be
        # swallowed, leaving the statement successful and the attribute as is
        self.conn.notices = None
        cur = conn.cursor()
        if self.conn.info.server_version >= 90300:
            cur.execute("set client_min_messages=debug1")

        cur.execute("create temp table table1 (id serial);")

        self.assertEqual(self.conn.notices, None)

    def test_server_version(self):
        # server_version must be a truthy value
        self.assertTrue(self.conn.server_version)

    def test_protocol_version(self):
        # Only protocol versions 2 and 3 are accepted here
        self.assertIn(self.conn.protocol_version, (2, 3))

    def test_tpc_unsupported(self):
        # On servers older than 8.1 creating a xid must raise
        # NotSupportedError; on anything newer the test is skipped
        cnn = self.conn
        if cnn.info.server_version >= 80100:
            return self.skipTest("tpc is supported")

        self.assertRaises(psycopg2.NotSupportedError,
            cnn.xid, 42, "foo", "bar")

    @slow
    @skip_before_postgres(8, 2)
    def test_concurrent_execution(self):
        # Two threads each sleeping 4 seconds in the database must finish in
        # less than 7 seconds overall, i.e. they must not run serially
        def worker():
            cnn = self.connect()
            cur = cnn.cursor()
            cur.execute("select pg_sleep(4)")
            cur.close()
            cnn.close()

        t1 = threading.Thread(target=worker)
        t2 = threading.Thread(target=worker)
        t0 = time.time()
        t1.start()
        t2.start()
        t1.join()
        t2.join()
        self.assertLess(time.time() - t0, 7,
            "something broken in concurrency")

    @skip_if_crdb("encoding")
    def test_encoding_name(self):
        # Text must still be decoded after selecting an encoding by name
        self.conn.set_client_encoding("EUC_JP")
        # conn.encoding is 'EUCJP' now.
        cur = self.conn.cursor()
        ext.register_type(ext.UNICODE, cur)
        cur.execute("select 'foo'::text;")
        self.assertEqual(cur.fetchone()[0], 'foo')

    def test_connect_nonnormal_envvar(self):
        # We must perform encoding normalization at connection time
        self.conn.close()
        oldenc = os.environ.get('PGCLIENTENCODING')
        os.environ['PGCLIENTENCODING'] = 'utf-8'    # malformed spelling
        try:
            self.conn = self.connect()
        finally:
            # put the environment back exactly as it was found
            if oldenc is not None:
                os.environ['PGCLIENTENCODING'] = oldenc
            else:
                del os.environ['PGCLIENTENCODING']

    def test_connect_no_string(self):
        # connect() must accept an instance of a str subclass as the dsn
        class MyString(str):
            pass

        conn = psycopg2.connect(MyString(dsn))
        conn.close()

    def test_weakref(self):
        # A connection can be weakly referenced and is collected once the
        # last strong reference is dropped
        conn = psycopg2.connect(dsn)
        w = ref(conn)
        conn.close()
        del conn
        gc.collect()
        self.assertIsNone(w())

    @slow
    def test_commit_concurrency(self):
        # The problem is the one reported in ticket #103. Because of bad
        # status check, we commit even when a commit is already on its way.
        # We can detect this condition by the warnings.
        conn = self.conn
        notices = []
        stop = []

        def committer():
            while not stop:
                conn.commit()
                while conn.notices:
                    notices.append((2, conn.notices.pop()))

        cur = conn.cursor()
        t1 = threading.Thread(target=committer)
        t1.start()
        try:
            for i in range(1000):
                cur.execute("select %s;", (i,))
                conn.commit()
                while conn.notices:
                    notices.append((1, conn.notices.pop()))
        finally:
            # Stop the committer thread, even if the loop above failed, and
            # wait for it: otherwise it would keep committing on the
            # connection after the test has returned.
            stop.append(True)
            t1.join()

        self.assertFalse(notices, f"{len(notices)} notices raised")

    def test_connect_cursor_factory(self):
        # cursor_factory passed at connection time is used by cursor()
        conn = self.connect(cursor_factory=psycopg2.extras.DictCursor)
        cur = conn.cursor()
        cur.execute("select 1 as a")
        self.assertEqual(cur.fetchone()['a'], 1)

    def test_cursor_factory(self):
        # The cursor_factory attribute can be set and reset at runtime:
        # rows can be indexed by column name only while DictCursor is set
        self.assertEqual(self.conn.cursor_factory, None)
        cur = self.conn.cursor()
        cur.execute("select 1 as a")
        self.assertRaises(TypeError, (lambda r: r['a']), cur.fetchone())

        self.conn.cursor_factory = psycopg2.extras.DictCursor
        self.assertEqual(self.conn.cursor_factory, psycopg2.extras.DictCursor)
        cur = self.conn.cursor()
        cur.execute("select 1 as a")
        self.assertEqual(cur.fetchone()['a'], 1)

        self.conn.cursor_factory = None
        self.assertEqual(self.conn.cursor_factory, None)
        cur = self.conn.cursor()
        cur.execute("select 1 as a")
        self.assertRaises(TypeError, (lambda r: r['a']), cur.fetchone())

    def test_cursor_factory_none(self):
        # issue #210
        # cursor(cursor_factory=None) means "use the connection's factory"
        conn = self.connect()
        cur = conn.cursor(cursor_factory=None)
        self.assertEqual(type(cur), ext.cursor)

        conn = self.connect(cursor_factory=psycopg2.extras.DictCursor)
        cur = conn.cursor(cursor_factory=None)
        self.assertEqual(type(cur), psycopg2.extras.DictCursor)

    @skip_if_crdb("connect any db")
    def test_failed_init_status(self):
        class SubConnection(ext.connection):
            def __init__(self, dsn):
                # Deliberately swallow the connection failure, so that the
                # state of the half-initialised object can be inspected
                try:
                    super().__init__(dsn)
                except Exception:
                    pass

        c = SubConnection("dbname=thereisnosuchdatabasemate password=foobar")
        self.assertTrue(c.closed, "connection failed so it must be closed")
        self.assertNotIn('foobar', c.dsn, "password was not obscured")

    def test_get_native_connection(self):
        conn = self.connect()
        capsule = conn.get_native_connection()
        # we can't do anything else in Python
        self.assertIsNotNone(capsule)

    def test_pgconn_ptr(self):
        # pgconn_ptr is set while the connection is open, None once closed
        conn = self.connect()
        self.assertIsNotNone(conn.pgconn_ptr)

        # If the libpq exposes PQserverVersion, call it through ctypes on
        # the pointer and compare the result with what the connection says
        try:
            f = self.libpq.PQserverVersion
        except AttributeError:
            pass
        else:
            f.argtypes = [ctypes.c_void_p]
            f.restype = ctypes.c_int
            ver = f(conn.pgconn_ptr)
            if ver == 0 and sys.platform == 'darwin':
                return self.skipTest(
                    "I don't know why this func returns 0 on OSX")

            self.assertEqual(ver, conn.server_version)

        conn.close()
        self.assertIsNone(conn.pgconn_ptr)

    @slow
    def test_multiprocess_close(self):
        # Run, in a separate interpreter, a script in which a thread uses a
        # connection while a multiprocessing child is started: the script
        # must exit successfully without printing anything.
        tmpdir = tempfile.mkdtemp()
        try:
            # helper module imported by the script below
            with open(os.path.join(tmpdir, "mptest.py"), 'w') as f:
                f.write(f"""import time
import psycopg2

def thread():
    conn = psycopg2.connect({dsn!r})
    curs = conn.cursor()
    for i in range(10):
        curs.execute("select 1")
        time.sleep(0.1)

def process():
    time.sleep(0.2)
""")

            script = ("""\
import sys
sys.path.insert(0, {dir!r})
import time
import threading
import multiprocessing
import mptest

t = threading.Thread(target=mptest.thread, name='mythread')
t.start()
time.sleep(0.2)
multiprocessing.Process(target=mptest.process, name='myprocess').start()
t.join()
""".format(dir=tmpdir))

            # stderr is merged into stdout so that any warning or traceback
            # makes the comparison below fail
            out = sp.check_output(
                [sys.executable, '-c', script], stderr=sp.STDOUT)
            self.assertEqual(out, b'', out)
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)
420
421
class ParseDsnTestCase(ConnectingTestCase):
    # Tests for ext.parse_dsn(), which splits a connection string into a
    # dict of parameters.

    def test_parse_dsn(self):
        self.assertEqual(
            ext.parse_dsn('dbname=test user=tester password=secret'),
            dict(user='tester', password='secret', dbname='test'),
            "simple DSN parsed")

        # a value containing a space must be quoted
        self.assertRaises(psycopg2.ProgrammingError, ext.parse_dsn,
                          "dbname=test 2 user=tester password=secret")

        self.assertEqual(
            ext.parse_dsn("dbname='test 2' user=tester password=secret"),
            dict(user='tester', password='secret', dbname='test 2'),
            "DSN with quoting parsed")

        # assertRaisesRegex() can only require a pattern to be present:
        # capture the exception instead, to check that the password is
        # *not* exposed in the error message.
        with self.assertRaises(psycopg2.ProgrammingError) as cm:
            # unterminated quote after dbname:
            ext.parse_dsn("dbname='test 2 user=tester password=secret")
        self.assertNotIn('secret', str(cm.exception),
                         "password was exposed in the error message")

    @skip_before_libpq(9, 2)
    def test_parse_dsn_uri(self):
        self.assertEqual(ext.parse_dsn('postgresql://tester:secret@/test'),
                         dict(user='tester', password='secret', dbname='test'),
                         "valid URI dsn parsed")

        # The dsn is passed by keyword here, unlike in the other tests.
        with self.assertRaises(psycopg2.ProgrammingError) as cm:
            # extra '=' after port value
            ext.parse_dsn(dsn='postgresql://tester:secret@/test?port=1111=x')
        self.assertNotIn('secret', str(cm.exception),
                         "password was exposed in the error message")

    def test_unicode_value(self):
        # a non-ASCII value is returned unchanged
        snowman = "\u2603"
        d = ext.parse_dsn('dbname=' + snowman)
        self.assertEqual(d['dbname'], snowman)

    def test_unicode_key(self):
        # a non-ASCII parameter name is rejected
        snowman = "\u2603"
        self.assertRaises(psycopg2.ProgrammingError, ext.parse_dsn,
            snowman + '=' + snowman)

    def test_bad_param(self):
        # only strings are accepted
        self.assertRaises(TypeError, ext.parse_dsn, None)
        self.assertRaises(TypeError, ext.parse_dsn, 42)

    def test_str_subclass(self):
        # an instance of a str subclass is accepted like a plain str
        class MyString(str):
            pass

        res = ext.parse_dsn(MyString("dbname=test"))
        self.assertEqual(res, {'dbname': 'test'})
485
486
class MakeDsnTestCase(ConnectingTestCase):
    # Tests for ext.make_dsn(), which builds a connection string from an
    # optional base dsn and keyword parameters.
    #
    # Locals are called `result` rather than `dsn` to avoid shadowing the
    # `dsn` imported from testconfig at module level.

    def test_empty_arguments(self):
        self.assertEqual(ext.make_dsn(), '')

    def test_empty_string(self):
        result = ext.make_dsn('')
        self.assertEqual(result, '')

    def test_params_validation(self):
        # unknown parameters are rejected, both in the dsn and as keywords
        self.assertRaises(psycopg2.ProgrammingError,
            ext.make_dsn, 'dbnamo=a')
        self.assertRaises(psycopg2.ProgrammingError,
            ext.make_dsn, dbnamo='a')
        self.assertRaises(psycopg2.ProgrammingError,
            ext.make_dsn, 'dbname=a', nosuchparam='b')

    def test_empty_param(self):
        # an empty value is emitted as an empty quoted string
        result = ext.make_dsn(dbname='sony', password='')
        self.assertDsnEqual(result, "dbname=sony password=''")

    def test_escape(self):
        # values with whitespace are quoted; backslashes and single quotes
        # are escaped with a backslash
        result = ext.make_dsn(dbname='hello world')
        self.assertEqual(result, "dbname='hello world'")

        result = ext.make_dsn(dbname=r'back\slash')
        self.assertEqual(result, r"dbname=back\\slash")

        result = ext.make_dsn(dbname="quo'te")
        self.assertEqual(result, r"dbname=quo\'te")

        result = ext.make_dsn(dbname="with\ttab")
        self.assertEqual(result, "dbname='with\ttab'")

        result = ext.make_dsn(dbname=r"\every thing'")
        self.assertEqual(result, r"dbname='\\every thing\''")

    def test_database_is_a_keyword(self):
        # `database` is accepted as an alias of `dbname`
        self.assertEqual(ext.make_dsn(database='sigh'), "dbname=sigh")

    def test_params_merging(self):
        # keywords override, or are added to, the parameters in the dsn
        result = ext.make_dsn('dbname=foo host=bar', host='baz')
        self.assertDsnEqual(result, 'dbname=foo host=baz')

        result = ext.make_dsn('dbname=foo', user='postgres')
        self.assertDsnEqual(result, 'dbname=foo user=postgres')

    def test_no_dsn_munging(self):
        # with no keyword the dsn is returned exactly as passed
        dsnin = 'dbname=a host=b user=c password=d'
        result = ext.make_dsn(dsnin)
        self.assertEqual(result, dsnin)

    def test_null_args(self):
        # keywords set to None are dropped
        result = ext.make_dsn("dbname=foo", user="bar", password=None)
        self.assertDsnEqual(result, "dbname=foo user=bar")

    @skip_before_libpq(9, 2)
    def test_url_is_cool(self):
        # an url is returned untouched if alone, converted to the
        # key=value format when merged with keywords
        url = 'postgresql://tester:secret@/test?application_name=wat'
        result = ext.make_dsn(url)
        self.assertEqual(result, url)

        result = ext.make_dsn(url, application_name='woot')
        self.assertDsnEqual(result,
            'dbname=test user=tester password=secret application_name=woot')

        self.assertRaises(psycopg2.ProgrammingError,
            ext.make_dsn, 'postgresql://tester:secret@/test?nosuch=param')
        self.assertRaises(psycopg2.ProgrammingError,
            ext.make_dsn, url, nosuch="param")

    @skip_before_libpq(9, 3)
    def test_get_dsn_parameters(self):
        conn = self.connect()
        d = conn.get_dsn_parameters()
        self.assertEqual(d['dbname'], dbname)  # the only param we can check reliably
        # the password must not be returned
        self.assertNotIn('password', d)
563
564
565class IsolationLevelsTestCase(ConnectingTestCase):
566
567    def setUp(self):
568        ConnectingTestCase.setUp(self)
569
570        conn = self.connect()
571        cur = conn.cursor()
572        if crdb_version(conn) is not None:
573            cur.execute("create table if not exists isolevel (id integer)")
574            cur.execute("truncate isolevel")
575            conn.commit()
576            return
577
578        try:
579            cur.execute("drop table isolevel;")
580        except psycopg2.ProgrammingError:
581            conn.rollback()
582        try:
583            cur.execute("create table isolevel (id integer);")
584            conn.commit()
585        finally:
586            conn.close()
587
588    def test_isolation_level(self):
589        conn = self.connect()
590        self.assertEqual(
591            conn.isolation_level,
592            ext.ISOLATION_LEVEL_DEFAULT)
593
594    def test_encoding(self):
595        conn = self.connect()
596        self.assert_(conn.encoding in ext.encodings)
597
598    @skip_if_crdb("isolation level")
599    def test_set_isolation_level(self):
600        conn = self.connect()
601        curs = conn.cursor()
602
603        levels = [
604            ('read uncommitted',
605                ext.ISOLATION_LEVEL_READ_UNCOMMITTED),
606            ('read committed', ext.ISOLATION_LEVEL_READ_COMMITTED),
607            ('repeatable read', ext.ISOLATION_LEVEL_REPEATABLE_READ),
608            ('serializable', ext.ISOLATION_LEVEL_SERIALIZABLE),
609        ]
610        for name, level in levels:
611            conn.set_isolation_level(level)
612
613            # the only values available on prehistoric PG versions
614            if conn.info.server_version < 80000:
615                if level in (
616                        ext.ISOLATION_LEVEL_READ_UNCOMMITTED,
617                        ext.ISOLATION_LEVEL_REPEATABLE_READ):
618                    name, level = levels[levels.index((name, level)) + 1]
619
620            self.assertEqual(conn.isolation_level, level)
621
622            curs.execute('show transaction_isolation;')
623            got_name = curs.fetchone()[0]
624
625            self.assertEqual(name, got_name)
626            conn.commit()
627
628        self.assertRaises(ValueError, conn.set_isolation_level, -1)
629        self.assertRaises(ValueError, conn.set_isolation_level, 5)
630
631    def test_set_isolation_level_autocommit(self):
632        conn = self.connect()
633        curs = conn.cursor()
634
635        conn.set_isolation_level(ext.ISOLATION_LEVEL_AUTOCOMMIT)
636        self.assertEqual(conn.isolation_level, ext.ISOLATION_LEVEL_DEFAULT)
637        self.assert_(conn.autocommit)
638
639        conn.isolation_level = 'serializable'
640        self.assertEqual(conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE)
641        self.assert_(conn.autocommit)
642
643        curs.execute('show transaction_isolation;')
644        self.assertEqual(curs.fetchone()[0], 'serializable')
645
646    @skip_if_crdb("isolation level")
647    def test_set_isolation_level_default(self):
648        conn = self.connect()
649        curs = conn.cursor()
650
651        conn.autocommit = True
652        curs.execute("set default_transaction_isolation to 'read committed'")
653
654        conn.autocommit = False
655        conn.set_isolation_level(ext.ISOLATION_LEVEL_SERIALIZABLE)
656        self.assertEqual(conn.isolation_level,
657            ext.ISOLATION_LEVEL_SERIALIZABLE)
658        curs.execute("show transaction_isolation")
659        self.assertEqual(curs.fetchone()[0], "serializable")
660
661        conn.rollback()
662        conn.set_isolation_level(ext.ISOLATION_LEVEL_DEFAULT)
663        curs.execute("show transaction_isolation")
664        self.assertEqual(curs.fetchone()[0], "read committed")
665
666    def test_set_isolation_level_abort(self):
667        conn = self.connect()
668        cur = conn.cursor()
669
670        self.assertEqual(ext.TRANSACTION_STATUS_IDLE,
671            conn.info.transaction_status)
672        cur.execute("insert into isolevel values (10);")
673        self.assertEqual(ext.TRANSACTION_STATUS_INTRANS,
674            conn.info.transaction_status)
675
676        conn.set_isolation_level(
677            psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE)
678        self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_IDLE,
679            conn.info.transaction_status)
680        cur.execute("select count(*) from isolevel;")
681        self.assertEqual(0, cur.fetchone()[0])
682
683        cur.execute("insert into isolevel values (10);")
684        self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_INTRANS,
685            conn.info.transaction_status)
686        conn.set_isolation_level(
687            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
688        self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_IDLE,
689            conn.info.transaction_status)
690        cur.execute("select count(*) from isolevel;")
691        self.assertEqual(0, cur.fetchone()[0])
692
693        cur.execute("insert into isolevel values (10);")
694        self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_IDLE,
695            conn.info.transaction_status)
696        conn.set_isolation_level(
697            psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
698        self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_IDLE,
699            conn.info.transaction_status)
700        cur.execute("select count(*) from isolevel;")
701        self.assertEqual(1, cur.fetchone()[0])
702        self.assertEqual(conn.isolation_level,
703            psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
704
705    def test_isolation_level_autocommit(self):
706        cnn1 = self.connect()
707        cnn2 = self.connect()
708        cnn2.set_isolation_level(ext.ISOLATION_LEVEL_AUTOCOMMIT)
709
710        cur1 = cnn1.cursor()
711        cur1.execute("select count(*) from isolevel;")
712        self.assertEqual(0, cur1.fetchone()[0])
713        cnn1.commit()
714
715        cur2 = cnn2.cursor()
716        cur2.execute("insert into isolevel values (10);")
717
718        cur1.execute("select count(*) from isolevel;")
719        self.assertEqual(1, cur1.fetchone()[0])
720
721    @skip_if_crdb("isolation level")
722    def test_isolation_level_read_committed(self):
723        cnn1 = self.connect()
724        cnn2 = self.connect()
725        cnn2.set_isolation_level(ext.ISOLATION_LEVEL_READ_COMMITTED)
726
727        cur1 = cnn1.cursor()
728        cur1.execute("select count(*) from isolevel;")
729        self.assertEqual(0, cur1.fetchone()[0])
730        cnn1.commit()
731
732        cur2 = cnn2.cursor()
733        cur2.execute("insert into isolevel values (10);")
734        cur1.execute("insert into isolevel values (20);")
735
736        cur2.execute("select count(*) from isolevel;")
737        self.assertEqual(1, cur2.fetchone()[0])
738        cnn1.commit()
739        cur2.execute("select count(*) from isolevel;")
740        self.assertEqual(2, cur2.fetchone()[0])
741
742        cur1.execute("select count(*) from isolevel;")
743        self.assertEqual(1, cur1.fetchone()[0])
744        cnn2.commit()
745        cur1.execute("select count(*) from isolevel;")
746        self.assertEqual(2, cur1.fetchone()[0])
747
748    @skip_if_crdb("isolation level")
749    def test_isolation_level_serializable(self):
750        cnn1 = self.connect()
751        cnn2 = self.connect()
752        cnn2.set_isolation_level(ext.ISOLATION_LEVEL_SERIALIZABLE)
753
754        cur1 = cnn1.cursor()
755        cur1.execute("select count(*) from isolevel;")
756        self.assertEqual(0, cur1.fetchone()[0])
757        cnn1.commit()
758
759        cur2 = cnn2.cursor()
760        cur2.execute("insert into isolevel values (10);")
761        cur1.execute("insert into isolevel values (20);")
762
763        cur2.execute("select count(*) from isolevel;")
764        self.assertEqual(1, cur2.fetchone()[0])
765        cnn1.commit()
766        cur2.execute("select count(*) from isolevel;")
767        self.assertEqual(1, cur2.fetchone()[0])
768
769        cur1.execute("select count(*) from isolevel;")
770        self.assertEqual(1, cur1.fetchone()[0])
771        cnn2.commit()
772        cur1.execute("select count(*) from isolevel;")
773        self.assertEqual(2, cur1.fetchone()[0])
774
775        cur2.execute("select count(*) from isolevel;")
776        self.assertEqual(2, cur2.fetchone()[0])
777
778    def test_isolation_level_closed(self):
779        cnn = self.connect()
780        cnn.close()
781        self.assertRaises(psycopg2.InterfaceError,
782            cnn.set_isolation_level, 0)
783        self.assertRaises(psycopg2.InterfaceError,
784            cnn.set_isolation_level, 1)
785
786    @skip_if_crdb("isolation level")
787    def test_setattr_isolation_level_int(self):
788        cur = self.conn.cursor()
789        self.conn.isolation_level = ext.ISOLATION_LEVEL_SERIALIZABLE
790        self.assertEqual(self.conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE)
791
792        cur.execute("SHOW transaction_isolation;")
793        self.assertEqual(cur.fetchone()[0], 'serializable')
794        self.conn.rollback()
795
796        self.conn.isolation_level = ext.ISOLATION_LEVEL_REPEATABLE_READ
797        cur.execute("SHOW transaction_isolation;")
798        if self.conn.info.server_version > 80000:
799            self.assertEqual(self.conn.isolation_level,
800                ext.ISOLATION_LEVEL_REPEATABLE_READ)
801            self.assertEqual(cur.fetchone()[0], 'repeatable read')
802        else:
803            self.assertEqual(self.conn.isolation_level,
804                ext.ISOLATION_LEVEL_SERIALIZABLE)
805            self.assertEqual(cur.fetchone()[0], 'serializable')
806        self.conn.rollback()
807
808        self.conn.isolation_level = ext.ISOLATION_LEVEL_READ_COMMITTED
809        self.assertEqual(self.conn.isolation_level,
810            ext.ISOLATION_LEVEL_READ_COMMITTED)
811        cur.execute("SHOW transaction_isolation;")
812        self.assertEqual(cur.fetchone()[0], 'read committed')
813        self.conn.rollback()
814
815        self.conn.isolation_level = ext.ISOLATION_LEVEL_READ_UNCOMMITTED
816        cur.execute("SHOW transaction_isolation;")
817        if self.conn.info.server_version > 80000:
818            self.assertEqual(self.conn.isolation_level,
819                ext.ISOLATION_LEVEL_READ_UNCOMMITTED)
820            self.assertEqual(cur.fetchone()[0], 'read uncommitted')
821        else:
822            self.assertEqual(self.conn.isolation_level,
823                ext.ISOLATION_LEVEL_READ_COMMITTED)
824            self.assertEqual(cur.fetchone()[0], 'read committed')
825        self.conn.rollback()
826
827        self.assertEqual(ext.ISOLATION_LEVEL_DEFAULT, None)
828        self.conn.isolation_level = ext.ISOLATION_LEVEL_DEFAULT
829        self.assertEqual(self.conn.isolation_level, None)
830        cur.execute("SHOW transaction_isolation;")
831        isol = cur.fetchone()[0]
832        cur.execute("SHOW default_transaction_isolation;")
833        self.assertEqual(cur.fetchone()[0], isol)
834
835    @skip_if_crdb("isolation level")
836    def test_setattr_isolation_level_str(self):
837        cur = self.conn.cursor()
838        self.conn.isolation_level = "serializable"
839        self.assertEqual(self.conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE)
840
841        cur.execute("SHOW transaction_isolation;")
842        self.assertEqual(cur.fetchone()[0], 'serializable')
843        self.conn.rollback()
844
845        self.conn.isolation_level = "repeatable read"
846        cur.execute("SHOW transaction_isolation;")
847        if self.conn.info.server_version > 80000:
848            self.assertEqual(self.conn.isolation_level,
849                ext.ISOLATION_LEVEL_REPEATABLE_READ)
850            self.assertEqual(cur.fetchone()[0], 'repeatable read')
851        else:
852            self.assertEqual(self.conn.isolation_level,
853                ext.ISOLATION_LEVEL_SERIALIZABLE)
854            self.assertEqual(cur.fetchone()[0], 'serializable')
855        self.conn.rollback()
856
857        self.conn.isolation_level = "read committed"
858        self.assertEqual(self.conn.isolation_level,
859            ext.ISOLATION_LEVEL_READ_COMMITTED)
860        cur.execute("SHOW transaction_isolation;")
861        self.assertEqual(cur.fetchone()[0], 'read committed')
862        self.conn.rollback()
863
864        self.conn.isolation_level = "read uncommitted"
865        cur.execute("SHOW transaction_isolation;")
866        if self.conn.info.server_version > 80000:
867            self.assertEqual(self.conn.isolation_level,
868                ext.ISOLATION_LEVEL_READ_UNCOMMITTED)
869            self.assertEqual(cur.fetchone()[0], 'read uncommitted')
870        else:
871            self.assertEqual(self.conn.isolation_level,
872                ext.ISOLATION_LEVEL_READ_COMMITTED)
873            self.assertEqual(cur.fetchone()[0], 'read committed')
874        self.conn.rollback()
875
876        self.conn.isolation_level = "default"
877        self.assertEqual(self.conn.isolation_level, None)
878        cur.execute("SHOW transaction_isolation;")
879        isol = cur.fetchone()[0]
880        cur.execute("SHOW default_transaction_isolation;")
881        self.assertEqual(cur.fetchone()[0], isol)
882
883    def test_setattr_isolation_level_invalid(self):
884        self.assertRaises(ValueError, setattr, self.conn, 'isolation_level', 0)
885        self.assertRaises(ValueError, setattr, self.conn, 'isolation_level', -1)
886        self.assertRaises(ValueError, setattr, self.conn, 'isolation_level', 5)
887        self.assertRaises(ValueError, setattr, self.conn, 'isolation_level', 'bah')
888
889    def test_attribs_segfault(self):
890        # bug #790
891        for i in range(10000):
892            self.conn.autocommit
893            self.conn.readonly
894            self.conn.deferrable
895            self.conn.isolation_level
896
897
898@skip_if_tpc_disabled
class ConnectionTwoPhaseTests(ConnectingTestCase):
    """Tests for the two-phase commit (TPC) methods of the connection.

    ``setUp`` recreates (or truncates) the ``test_tpc`` table and rolls back
    the prepared transactions found for the test database; ``tearDown``
    rolls them back again so that a test does not leave pending state behind.
    """

    def setUp(self):
        ConnectingTestCase.setUp(self)

        self.make_test_table()
        self.clear_test_xacts()

    def tearDown(self):
        self.clear_test_xacts()
        ConnectingTestCase.tearDown(self)

    def clear_test_xacts(self):
        """Rollback all the prepared transactions in the testing db."""
        cnn = self.connect()
        try:
            # Level 0 is psycopg2's autocommit level: each statement below
            # runs outside a transaction block.
            cnn.set_isolation_level(0)
            cur = cnn.cursor()
            try:
                cur.execute(
                    "select gid from pg_prepared_xacts where database = %s",
                    (dbname,))
            except psycopg2.ProgrammingError:
                # Presumably the view is not available on this server:
                # nothing to clean up.
                cnn.rollback()
                return

            gids = [r[0] for r in cur]
            for gid in gids:
                cur.execute("rollback prepared %s;", (gid,))
        finally:
            # Release the helper connection even if a statement failed.
            cnn.close()

    def make_test_table(self):
        """Leave an empty ``test_tpc`` table in the testing db."""
        cnn = self.connect()
        try:
            cur = cnn.cursor()
            if crdb_version(cnn) is not None:
                # On CockroachDB the table is reused and emptied instead of
                # being dropped and recreated.
                cur.execute("CREATE TABLE IF NOT EXISTS test_tpc (data text)")
                cur.execute("TRUNCATE test_tpc")
                cnn.commit()
                return

            try:
                cur.execute("DROP TABLE test_tpc;")
            except psycopg2.ProgrammingError:
                # Presumably the table did not exist yet; discard the failed
                # transaction before creating it.
                cnn.rollback()

            cur.execute("CREATE TABLE test_tpc (data text);")
            cnn.commit()
        finally:
            cnn.close()

    def count_xacts(self):
        """Return the number of prepared xacts currently in the test db."""
        cnn = self.connect()
        try:
            cur = cnn.cursor()
            cur.execute("""
                select count(*) from pg_prepared_xacts
                where database = %s;""",
                (dbname,))
            return cur.fetchone()[0]
        finally:
            cnn.close()

    def count_test_records(self):
        """Return the number of records in the test table."""
        cnn = self.connect()
        try:
            cur = cnn.cursor()
            cur.execute("select count(*) from test_tpc;")
            return cur.fetchone()[0]
        finally:
            cnn.close()

    def test_tpc_commit(self):
        # begin -> prepare -> commit: the record becomes visible to other
        # connections only after the final commit.
        cnn = self.connect()
        xid = cnn.xid(1, "gtrid", "bqual")
        self.assertEqual(cnn.status, ext.STATUS_READY)

        cnn.tpc_begin(xid)
        self.assertEqual(cnn.status, ext.STATUS_BEGIN)

        cur = cnn.cursor()
        cur.execute("insert into test_tpc values ('test_tpc_commit');")
        self.assertEqual(0, self.count_xacts())
        self.assertEqual(0, self.count_test_records())

        cnn.tpc_prepare()
        self.assertEqual(cnn.status, ext.STATUS_PREPARED)
        self.assertEqual(1, self.count_xacts())
        self.assertEqual(0, self.count_test_records())

        cnn.tpc_commit()
        self.assertEqual(cnn.status, ext.STATUS_READY)
        self.assertEqual(0, self.count_xacts())
        self.assertEqual(1, self.count_test_records())

    def test_tpc_commit_one_phase(self):
        # begin -> commit without prepare: no prepared transaction is ever
        # listed on the server.
        cnn = self.connect()
        xid = cnn.xid(1, "gtrid", "bqual")
        self.assertEqual(cnn.status, ext.STATUS_READY)

        cnn.tpc_begin(xid)
        self.assertEqual(cnn.status, ext.STATUS_BEGIN)

        cur = cnn.cursor()
        cur.execute("insert into test_tpc values ('test_tpc_commit_1p');")
        self.assertEqual(0, self.count_xacts())
        self.assertEqual(0, self.count_test_records())

        cnn.tpc_commit()
        self.assertEqual(cnn.status, ext.STATUS_READY)
        self.assertEqual(0, self.count_xacts())
        self.assertEqual(1, self.count_test_records())

    def test_tpc_commit_recovered(self):
        # A transaction prepared on a connection that has since been closed
        # can be committed from a new connection using an equal xid.
        cnn = self.connect()
        xid = cnn.xid(1, "gtrid", "bqual")
        self.assertEqual(cnn.status, ext.STATUS_READY)

        cnn.tpc_begin(xid)
        self.assertEqual(cnn.status, ext.STATUS_BEGIN)

        cur = cnn.cursor()
        cur.execute("insert into test_tpc values ('test_tpc_commit_rec');")
        self.assertEqual(0, self.count_xacts())
        self.assertEqual(0, self.count_test_records())

        cnn.tpc_prepare()
        cnn.close()
        self.assertEqual(1, self.count_xacts())
        self.assertEqual(0, self.count_test_records())

        cnn = self.connect()
        xid = cnn.xid(1, "gtrid", "bqual")
        cnn.tpc_commit(xid)

        self.assertEqual(cnn.status, ext.STATUS_READY)
        self.assertEqual(0, self.count_xacts())
        self.assertEqual(1, self.count_test_records())

    def test_tpc_rollback(self):
        # begin -> prepare -> rollback: the record never becomes visible.
        cnn = self.connect()
        xid = cnn.xid(1, "gtrid", "bqual")
        self.assertEqual(cnn.status, ext.STATUS_READY)

        cnn.tpc_begin(xid)
        self.assertEqual(cnn.status, ext.STATUS_BEGIN)

        cur = cnn.cursor()
        cur.execute("insert into test_tpc values ('test_tpc_rollback');")
        self.assertEqual(0, self.count_xacts())
        self.assertEqual(0, self.count_test_records())

        cnn.tpc_prepare()
        self.assertEqual(cnn.status, ext.STATUS_PREPARED)
        self.assertEqual(1, self.count_xacts())
        self.assertEqual(0, self.count_test_records())

        cnn.tpc_rollback()
        self.assertEqual(cnn.status, ext.STATUS_READY)
        self.assertEqual(0, self.count_xacts())
        self.assertEqual(0, self.count_test_records())

    def test_tpc_rollback_one_phase(self):
        # begin -> rollback without prepare.
        cnn = self.connect()
        xid = cnn.xid(1, "gtrid", "bqual")
        self.assertEqual(cnn.status, ext.STATUS_READY)

        cnn.tpc_begin(xid)
        self.assertEqual(cnn.status, ext.STATUS_BEGIN)

        cur = cnn.cursor()
        cur.execute("insert into test_tpc values ('test_tpc_rollback_1p');")
        self.assertEqual(0, self.count_xacts())
        self.assertEqual(0, self.count_test_records())

        cnn.tpc_rollback()
        self.assertEqual(cnn.status, ext.STATUS_READY)
        self.assertEqual(0, self.count_xacts())
        self.assertEqual(0, self.count_test_records())

    def test_tpc_rollback_recovered(self):
        # A transaction prepared on a connection that has since been closed
        # can be rolled back from a new connection using an equal xid.
        cnn = self.connect()
        xid = cnn.xid(1, "gtrid", "bqual")
        self.assertEqual(cnn.status, ext.STATUS_READY)

        cnn.tpc_begin(xid)
        self.assertEqual(cnn.status, ext.STATUS_BEGIN)

        cur = cnn.cursor()
        # The value only labels the row with the test that wrote it.
        cur.execute("insert into test_tpc values ('test_tpc_rollback_rec');")
        self.assertEqual(0, self.count_xacts())
        self.assertEqual(0, self.count_test_records())

        cnn.tpc_prepare()
        cnn.close()
        self.assertEqual(1, self.count_xacts())
        self.assertEqual(0, self.count_test_records())

        cnn = self.connect()
        xid = cnn.xid(1, "gtrid", "bqual")
        cnn.tpc_rollback(xid)

        self.assertEqual(cnn.status, ext.STATUS_READY)
        self.assertEqual(0, self.count_xacts())
        self.assertEqual(0, self.count_test_records())

    def test_status_after_recover(self):
        # tpc_recover() must leave the connection status as it found it,
        # both when idle and inside a transaction.
        cnn = self.connect()
        self.assertEqual(ext.STATUS_READY, cnn.status)
        cnn.tpc_recover()
        self.assertEqual(ext.STATUS_READY, cnn.status)

        cur = cnn.cursor()
        cur.execute("select 1")
        self.assertEqual(ext.STATUS_BEGIN, cnn.status)
        cnn.tpc_recover()
        self.assertEqual(ext.STATUS_BEGIN, cnn.status)

    def test_recovered_xids(self):
        # insert a few test xns
        cnn = self.connect()
        cnn.set_isolation_level(0)
        cur = cnn.cursor()
        cur.execute("begin; prepare transaction '1-foo';")
        cur.execute("begin; prepare transaction '2-bar';")

        # read the values to return
        cur.execute("""
            select gid, prepared, owner, database
            from pg_prepared_xacts
            where database = %s;""",
            (dbname,))
        okvals = cur.fetchall()
        okvals.sort()

        cnn = self.connect()
        xids = cnn.tpc_recover()
        xids = [xid for xid in xids if xid.database == dbname]
        xids.sort(key=attrgetter('gtrid'))

        # check the values returned
        self.assertEqual(len(okvals), len(xids))
        for (xid, (gid, prepared, owner, database)) in zip(xids, okvals):
            self.assertEqual(xid.gtrid, gid)
            self.assertEqual(xid.prepared, prepared)
            self.assertEqual(xid.owner, owner)
            self.assertEqual(xid.database, database)

    def test_xid_encoding(self):
        # Check the string under which a structured xid is stored on the
        # server.
        cnn = self.connect()
        xid = cnn.xid(42, "gtrid", "bqual")
        cnn.tpc_begin(xid)
        cnn.tpc_prepare()

        cnn = self.connect()
        cur = cnn.cursor()
        cur.execute("select gid from pg_prepared_xacts where database = %s;",
            (dbname,))
        self.assertEqual('42_Z3RyaWQ=_YnF1YWw=', cur.fetchone()[0])

    @slow
    def test_xid_roundtrip(self):
        # Structured xids must come back from tpc_recover() with the same
        # three components they were created with.
        for fid, gtrid, bqual in [
            (0, "", ""),
            (42, "gtrid", "bqual"),
            (0x7fffffff, "x" * 64, "y" * 64),
        ]:
            cnn = self.connect()
            xid = cnn.xid(fid, gtrid, bqual)
            cnn.tpc_begin(xid)
            cnn.tpc_prepare()
            cnn.close()

            cnn = self.connect()
            xids = [x for x in cnn.tpc_recover() if x.database == dbname]
            self.assertEqual(1, len(xids))
            xid = xids[0]
            self.assertEqual(xid.format_id, fid)
            self.assertEqual(xid.gtrid, gtrid)
            self.assertEqual(xid.bqual, bqual)

            # Clean up so that the next iteration finds a single xid again.
            cnn.tpc_rollback(xid)

    @slow
    def test_unparsed_roundtrip(self):
        # Plain string transaction ids come back with the whole string in
        # gtrid and no format_id/bqual.
        for tid in [
            '',
            'hello, world!',
            'x' * 199,  # PostgreSQL's limit in transaction id length
        ]:
            cnn = self.connect()
            cnn.tpc_begin(tid)
            cnn.tpc_prepare()
            cnn.close()

            cnn = self.connect()
            xids = [x for x in cnn.tpc_recover() if x.database == dbname]
            self.assertEqual(1, len(xids))
            xid = xids[0]
            self.assertEqual(xid.format_id, None)
            self.assertEqual(xid.gtrid, tid)
            self.assertEqual(xid.bqual, None)

            # Clean up so that the next iteration finds a single xid again.
            cnn.tpc_rollback(xid)

    def test_xid_construction(self):
        x1 = ext.Xid(74, 'foo', 'bar')
        self.assertEqual(74, x1.format_id)
        self.assertEqual('foo', x1.gtrid)
        self.assertEqual('bar', x1.bqual)

    def test_xid_from_string(self):
        # A string in the structured format is split into its components...
        x2 = ext.Xid.from_string('42_Z3RyaWQ=_YnF1YWw=')
        self.assertEqual(42, x2.format_id)
        self.assertEqual('gtrid', x2.gtrid)
        self.assertEqual('bqual', x2.bqual)

        # ...while this one is expected to be kept whole as gtrid.
        x3 = ext.Xid.from_string('99_xxx_yyy')
        self.assertEqual(None, x3.format_id)
        self.assertEqual('99_xxx_yyy', x3.gtrid)
        self.assertEqual(None, x3.bqual)

    def test_xid_to_string(self):
        # str() must return the string the xid was parsed from.
        x1 = ext.Xid.from_string('42_Z3RyaWQ=_YnF1YWw=')
        self.assertEqual(str(x1), '42_Z3RyaWQ=_YnF1YWw=')

        x2 = ext.Xid.from_string('99_xxx_yyy')
        self.assertEqual(str(x2), '99_xxx_yyy')

    def test_xid_unicode(self):
        cnn = self.connect()
        x1 = cnn.xid(10, 'uni', 'code')
        cnn.tpc_begin(x1)
        cnn.tpc_prepare()
        cnn.reset()
        xid = [x for x in cnn.tpc_recover() if x.database == dbname][0]
        self.assertEqual(10, xid.format_id)
        self.assertEqual('uni', xid.gtrid)
        self.assertEqual('code', xid.bqual)

    def test_xid_unicode_unparsed(self):
        # We don't expect people shooting snowmen as transaction ids,
        # so if something explodes in an encode error I don't mind.
        # Let's just check unicode is accepted as type.
        cnn = self.connect()
        cnn.set_client_encoding('utf8')
        cnn.tpc_begin("transaction-id")
        cnn.tpc_prepare()
        cnn.reset()

        xid = [x for x in cnn.tpc_recover() if x.database == dbname][0]
        self.assertEqual(None, xid.format_id)
        self.assertEqual('transaction-id', xid.gtrid)
        self.assertEqual(None, xid.bqual)

    def test_cancel_fails_prepared(self):
        # cancel() must be refused once the transaction has been prepared.
        cnn = self.connect()
        cnn.tpc_begin('cancel')
        cnn.tpc_prepare()
        self.assertRaises(psycopg2.ProgrammingError, cnn.cancel)

    def test_tpc_recover_non_dbapi_connection(self):
        # tpc_recover() must return usable xids on a connection created
        # with a different connection factory too.
        cnn = self.connect(connection_factory=psycopg2.extras.RealDictConnection)
        cnn.tpc_begin('dict-connection')
        cnn.tpc_prepare()
        cnn.reset()

        xids = cnn.tpc_recover()
        xid = [x for x in xids if x.database == dbname][0]
        self.assertEqual(None, xid.format_id)
        self.assertEqual('dict-connection', xid.gtrid)
        self.assertEqual(None, xid.bqual)
1270
1271
1272@skip_if_crdb("isolation level")
class TransactionControlTests(ConnectingTestCase):
    """Tests for ``set_session()`` and the session attributes.

    Covers the isolation level, ``readonly`` and ``deferrable`` settings,
    checking the value stored on the connection and the one reported by the
    server through ``SHOW``.
    """

    def test_closed(self):
        # set_session() on a closed connection is an interface error.
        self.conn.close()
        self.assertRaises(psycopg2.InterfaceError,
            self.conn.set_session,
            ext.ISOLATION_LEVEL_SERIALIZABLE)

    def test_not_in_transaction(self):
        # set_session() is refused once a statement has started a
        # transaction.
        cur = self.conn.cursor()
        cur.execute("select 1")
        self.assertRaises(psycopg2.ProgrammingError,
            self.conn.set_session,
            ext.ISOLATION_LEVEL_SERIALIZABLE)

    def test_set_isolation_level(self):
        # Each level is passed as ext.ISOLATION_LEVEL_* constant, both
        # positionally and by keyword.
        cur = self.conn.cursor()
        self.conn.set_session(
            ext.ISOLATION_LEVEL_SERIALIZABLE)
        cur.execute("SHOW transaction_isolation;")
        self.assertEqual(cur.fetchone()[0], 'serializable')
        self.conn.rollback()

        self.conn.set_session(
            ext.ISOLATION_LEVEL_REPEATABLE_READ)
        cur.execute("SHOW transaction_isolation;")
        if self.conn.info.server_version > 80000:
            self.assertEqual(cur.fetchone()[0], 'repeatable read')
        else:
            self.assertEqual(cur.fetchone()[0], 'serializable')
        self.conn.rollback()

        self.conn.set_session(
            isolation_level=ext.ISOLATION_LEVEL_READ_COMMITTED)
        cur.execute("SHOW transaction_isolation;")
        self.assertEqual(cur.fetchone()[0], 'read committed')
        self.conn.rollback()

        self.conn.set_session(
            isolation_level=ext.ISOLATION_LEVEL_READ_UNCOMMITTED)
        cur.execute("SHOW transaction_isolation;")
        if self.conn.info.server_version > 80000:
            self.assertEqual(cur.fetchone()[0], 'read uncommitted')
        else:
            self.assertEqual(cur.fetchone()[0], 'read committed')
        self.conn.rollback()

    def test_set_isolation_level_str(self):
        # Same as test_set_isolation_level, with the level given by name.
        cur = self.conn.cursor()
        self.conn.set_session("serializable")
        cur.execute("SHOW transaction_isolation;")
        self.assertEqual(cur.fetchone()[0], 'serializable')
        self.conn.rollback()

        self.conn.set_session("repeatable read")
        cur.execute("SHOW transaction_isolation;")
        if self.conn.info.server_version > 80000:
            self.assertEqual(cur.fetchone()[0], 'repeatable read')
        else:
            self.assertEqual(cur.fetchone()[0], 'serializable')
        self.conn.rollback()

        self.conn.set_session("read committed")
        cur.execute("SHOW transaction_isolation;")
        self.assertEqual(cur.fetchone()[0], 'read committed')
        self.conn.rollback()

        self.conn.set_session("read uncommitted")
        cur.execute("SHOW transaction_isolation;")
        if self.conn.info.server_version > 80000:
            self.assertEqual(cur.fetchone()[0], 'read uncommitted')
        else:
            self.assertEqual(cur.fetchone()[0], 'read committed')
        self.conn.rollback()

    def test_bad_isolation_level(self):
        self.assertRaises(ValueError, self.conn.set_session, 0)
        self.assertRaises(ValueError, self.conn.set_session, 5)
        self.assertRaises(ValueError, self.conn.set_session, 'whatever')

    def test_set_read_only(self):
        self.assertIsNone(self.conn.readonly)

        cur = self.conn.cursor()
        self.conn.set_session(readonly=True)
        self.assertIs(self.conn.readonly, True)
        cur.execute("SHOW transaction_read_only;")
        self.assertEqual(cur.fetchone()[0], 'on')
        self.conn.rollback()
        # The setting must still apply to the transaction after a rollback.
        cur.execute("SHOW transaction_read_only;")
        self.assertEqual(cur.fetchone()[0], 'on')
        self.conn.rollback()

        self.conn.set_session(readonly=False)
        self.assertIs(self.conn.readonly, False)
        cur.execute("SHOW transaction_read_only;")
        self.assertEqual(cur.fetchone()[0], 'off')
        self.conn.rollback()

    def test_setattr_read_only(self):
        cur = self.conn.cursor()
        self.conn.readonly = True
        self.assertIs(self.conn.readonly, True)
        cur.execute("SHOW transaction_read_only;")
        self.assertEqual(cur.fetchone()[0], 'on')
        # Changing the attribute inside a transaction must fail and leave
        # the previous value in place.
        self.assertRaises(self.conn.ProgrammingError,
            setattr, self.conn, 'readonly', False)
        self.assertIs(self.conn.readonly, True)
        self.conn.rollback()
        cur.execute("SHOW transaction_read_only;")
        self.assertEqual(cur.fetchone()[0], 'on')
        self.conn.rollback()

        cur = self.conn.cursor()
        self.conn.readonly = None
        self.assertIsNone(self.conn.readonly)
        cur.execute("SHOW transaction_read_only;")
        self.assertEqual(cur.fetchone()[0], 'off')  # assume defined by server
        self.conn.rollback()

        self.conn.readonly = False
        self.assertIs(self.conn.readonly, False)
        cur.execute("SHOW transaction_read_only;")
        self.assertEqual(cur.fetchone()[0], 'off')
        self.conn.rollback()

    def test_set_default(self):
        # Record what the server reports before any session change...
        cur = self.conn.cursor()
        cur.execute("SHOW transaction_isolation;")
        isolevel = cur.fetchone()[0]
        cur.execute("SHOW transaction_read_only;")
        readonly = cur.fetchone()[0]
        self.conn.rollback()

        # ...and check that 'default' brings both settings back to it.
        self.conn.set_session(isolation_level='serializable', readonly=True)
        self.conn.set_session(isolation_level='default', readonly='default')

        cur.execute("SHOW transaction_isolation;")
        self.assertEqual(cur.fetchone()[0], isolevel)
        cur.execute("SHOW transaction_read_only;")
        self.assertEqual(cur.fetchone()[0], readonly)

    @skip_before_postgres(9, 1)
    def test_set_deferrable(self):
        self.assertIsNone(self.conn.deferrable)
        cur = self.conn.cursor()
        self.conn.set_session(readonly=True, deferrable=True)
        self.assertIs(self.conn.deferrable, True)
        cur.execute("SHOW transaction_read_only;")
        self.assertEqual(cur.fetchone()[0], 'on')
        cur.execute("SHOW transaction_deferrable;")
        self.assertEqual(cur.fetchone()[0], 'on')
        self.conn.rollback()
        # The setting must still apply to the transaction after a rollback.
        cur.execute("SHOW transaction_deferrable;")
        self.assertEqual(cur.fetchone()[0], 'on')
        self.conn.rollback()

        # Only deferrable is changed here: readonly must be left alone.
        self.conn.set_session(deferrable=False)
        self.assertIs(self.conn.deferrable, False)
        cur.execute("SHOW transaction_read_only;")
        self.assertEqual(cur.fetchone()[0], 'on')
        cur.execute("SHOW transaction_deferrable;")
        self.assertEqual(cur.fetchone()[0], 'off')
        self.conn.rollback()

    @skip_after_postgres(9, 1)
    def test_set_deferrable_error(self):
        self.assertRaises(psycopg2.ProgrammingError,
            self.conn.set_session, readonly=True, deferrable=True)
        self.assertRaises(psycopg2.ProgrammingError,
            setattr, self.conn, 'deferrable', True)

    @skip_before_postgres(9, 1)
    def test_setattr_deferrable(self):
        cur = self.conn.cursor()
        self.conn.deferrable = True
        self.assertIs(self.conn.deferrable, True)
        cur.execute("SHOW transaction_deferrable;")
        self.assertEqual(cur.fetchone()[0], 'on')
        # Changing the attribute inside a transaction must fail and leave
        # the previous value in place.
        self.assertRaises(self.conn.ProgrammingError,
            setattr, self.conn, 'deferrable', False)
        self.assertIs(self.conn.deferrable, True)
        self.conn.rollback()
        cur.execute("SHOW transaction_deferrable;")
        self.assertEqual(cur.fetchone()[0], 'on')
        self.conn.rollback()

        cur = self.conn.cursor()
        self.conn.deferrable = None
        self.assertIsNone(self.conn.deferrable)
        cur.execute("SHOW transaction_deferrable;")
        self.assertEqual(cur.fetchone()[0], 'off')  # assume defined by server
        self.conn.rollback()

        self.conn.deferrable = False
        self.assertIs(self.conn.deferrable, False)
        cur.execute("SHOW transaction_deferrable;")
        self.assertEqual(cur.fetchone()[0], 'off')
        self.conn.rollback()

    def test_mixing_session_attribs(self):
        # In autocommit the readonly setting is expected to show up in the
        # session default too; back in transaction mode the default must be
        # 'off' again while the transaction itself stays read only.
        cur = self.conn.cursor()
        self.conn.autocommit = True
        self.conn.readonly = True

        cur.execute("SHOW transaction_read_only;")
        self.assertEqual(cur.fetchone()[0], 'on')

        cur.execute("SHOW default_transaction_read_only;")
        self.assertEqual(cur.fetchone()[0], 'on')

        self.conn.autocommit = False
        cur.execute("SHOW transaction_read_only;")
        self.assertEqual(cur.fetchone()[0], 'on')

        cur.execute("SHOW default_transaction_read_only;")
        self.assertEqual(cur.fetchone()[0], 'off')

    def test_idempotence_check(self):
        # Setting readonly to the value it already has, after switching to
        # autocommit, must still result in a read only session.
        self.conn.autocommit = False
        self.conn.readonly = True
        self.conn.autocommit = True
        self.conn.readonly = True

        cur = self.conn.cursor()
        cur.execute("SHOW transaction_read_only")
        self.assertEqual(cur.fetchone()[0], 'on')
1499
1500
class TestEncryptPassword(ConnectingTestCase):
    """Tests for ``psycopg2.extensions.encrypt_password()``."""

    @skip_before_postgres(10)
    def test_encrypt_password_post_9_6(self):
        # MD5 algorithm
        self.assertEqual(
            ext.encrypt_password('psycopg2', 'ashesh', self.conn, 'md5'),
            'md594839d658c28a357126f105b9cb14cfc')

        # keywords
        self.assertEqual(
            ext.encrypt_password(
                password='psycopg2', user='ashesh',
                scope=self.conn, algorithm='md5'),
            'md594839d658c28a357126f105b9cb14cfc')

    @skip_if_crdb("password_encryption")
    @skip_before_libpq(10)
    @skip_before_postgres(10)
    def test_encrypt_server(self):
        # Without an explicit algorithm the result is checked against the
        # server's password_encryption setting.
        cur = self.conn.cursor()
        cur.execute("SHOW password_encryption;")
        server_encryption_algorithm = cur.fetchone()[0]

        enc_password = ext.encrypt_password(
            'psycopg2', 'ashesh', self.conn)

        if server_encryption_algorithm == 'md5':
            self.assertEqual(
                enc_password, 'md594839d658c28a357126f105b9cb14cfc')
        elif server_encryption_algorithm == 'scram-sha-256':
            # Only the prefix is compared for SCRAM results.
            self.assertEqual(enc_password[:14], 'SCRAM-SHA-256$')

        self.assertEqual(
            ext.encrypt_password(
                'psycopg2', 'ashesh', self.conn, 'scram-sha-256'
            )[:14], 'SCRAM-SHA-256$')

        self.assertRaises(psycopg2.ProgrammingError,
            ext.encrypt_password, 'psycopg2', 'ashesh', self.conn, 'abc')

    def test_encrypt_md5(self):
        # md5 must work without passing a connection as scope.
        self.assertEqual(
            ext.encrypt_password('psycopg2', 'ashesh', algorithm='md5'),
            'md594839d658c28a357126f105b9cb14cfc')

    @skip_before_libpq(10)
    def test_encrypt_bad_libpq_10(self):
        self.assertRaises(psycopg2.ProgrammingError,
            ext.encrypt_password, 'psycopg2', 'ashesh', self.conn, 'abc')

    @skip_after_libpq(10)
    def test_encrypt_bad_before_libpq_10(self):
        self.assertRaises(psycopg2.NotSupportedError,
            ext.encrypt_password, 'psycopg2', 'ashesh', self.conn, 'abc')

    @skip_before_libpq(10)
    def test_encrypt_scram(self):
        self.assertTrue(
            ext.encrypt_password(
                'psycopg2', 'ashesh', self.conn, 'scram-sha-256')
            .startswith('SCRAM-SHA-256$'))

    @skip_after_libpq(10)
    def test_encrypt_scram_pre_10(self):
        self.assertRaises(psycopg2.NotSupportedError,
            ext.encrypt_password,
            password='psycopg2', user='ashesh',
            scope=self.conn, algorithm='scram-sha-256')

    def test_bad_types(self):
        # Missing arguments and arguments of the wrong type must raise
        # TypeError.
        self.assertRaises(TypeError, ext.encrypt_password)
        self.assertRaises(TypeError, ext.encrypt_password,
            'password', 42, self.conn, 'md5')
        self.assertRaises(TypeError, ext.encrypt_password,
            42, 'user', self.conn, 'md5')
        self.assertRaises(TypeError, ext.encrypt_password,
            42, 'user', 'wat', 'abc')
        self.assertRaises(TypeError, ext.encrypt_password,
            'password', 'user', 'wat', 42)
1580
1581
class AutocommitTests(ConnectingTestCase):
    """Tests for the ``autocommit`` attribute and ``set_session(autocommit=)``.

    Each step checks both the connection ``status`` and the
    ``info.transaction_status`` value.
    """

    def test_closed(self):
        self.conn.close()
        self.assertRaises(psycopg2.InterfaceError,
            setattr, self.conn, 'autocommit', True)

        # The getter doesn't have a guard. We may change this in future
        # to make it consistent with other methods; meanwhile let's just check
        # it doesn't explode.
        try:
            self.assertIn(self.conn.autocommit, (True, False))
        except psycopg2.InterfaceError:
            pass

    def test_default_no_autocommit(self):
        # A fresh connection is not in autocommit: the first statement
        # opens a transaction, which rollback() closes.
        self.assertFalse(self.conn.autocommit)
        self.assertEqual(self.conn.status, ext.STATUS_READY)
        self.assertEqual(self.conn.info.transaction_status,
            ext.TRANSACTION_STATUS_IDLE)

        cur = self.conn.cursor()
        cur.execute('select 1;')
        self.assertEqual(self.conn.status, ext.STATUS_BEGIN)
        self.assertEqual(self.conn.info.transaction_status,
            ext.TRANSACTION_STATUS_INTRANS)

        self.conn.rollback()
        self.assertEqual(self.conn.status, ext.STATUS_READY)
        self.assertEqual(self.conn.info.transaction_status,
            ext.TRANSACTION_STATUS_IDLE)

    def test_set_autocommit(self):
        self.conn.autocommit = True
        self.assertTrue(self.conn.autocommit)
        self.assertEqual(self.conn.status, ext.STATUS_READY)
        self.assertEqual(self.conn.info.transaction_status,
            ext.TRANSACTION_STATUS_IDLE)

        # In autocommit a statement must not leave a transaction open.
        cur = self.conn.cursor()
        cur.execute('select 1;')
        self.assertEqual(self.conn.status, ext.STATUS_READY)
        self.assertEqual(self.conn.info.transaction_status,
            ext.TRANSACTION_STATUS_IDLE)

        self.conn.autocommit = False
        self.assertFalse(self.conn.autocommit)
        self.assertEqual(self.conn.status, ext.STATUS_READY)
        self.assertEqual(self.conn.info.transaction_status,
            ext.TRANSACTION_STATUS_IDLE)

        cur.execute('select 1;')
        self.assertEqual(self.conn.status, ext.STATUS_BEGIN)
        self.assertEqual(self.conn.info.transaction_status,
            ext.TRANSACTION_STATUS_INTRANS)

    def test_set_intrans_error(self):
        # autocommit cannot be changed while a transaction is open.
        cur = self.conn.cursor()
        cur.execute('select 1;')
        self.assertRaises(psycopg2.ProgrammingError,
            setattr, self.conn, 'autocommit', True)

    def test_set_session_autocommit(self):
        self.conn.set_session(autocommit=True)
        self.assertTrue(self.conn.autocommit)
        self.assertEqual(self.conn.status, ext.STATUS_READY)
        self.assertEqual(self.conn.info.transaction_status,
            ext.TRANSACTION_STATUS_IDLE)

        cur = self.conn.cursor()
        cur.execute('select 1;')
        self.assertEqual(self.conn.status, ext.STATUS_READY)
        self.assertEqual(self.conn.info.transaction_status,
            ext.TRANSACTION_STATUS_IDLE)

        self.conn.set_session(autocommit=False)
        self.assertFalse(self.conn.autocommit)
        self.assertEqual(self.conn.status, ext.STATUS_READY)
        self.assertEqual(self.conn.info.transaction_status,
            ext.TRANSACTION_STATUS_IDLE)

        cur.execute('select 1;')
        self.assertEqual(self.conn.status, ext.STATUS_BEGIN)
        self.assertEqual(self.conn.info.transaction_status,
            ext.TRANSACTION_STATUS_INTRANS)
        self.conn.rollback()

        # autocommit combined with other session settings in one call.
        self.conn.set_session('serializable', readonly=True, autocommit=True)
        self.assertTrue(self.conn.autocommit)
        cur.execute('select 1;')
        self.assertEqual(self.conn.status, ext.STATUS_READY)
        self.assertEqual(self.conn.info.transaction_status,
            ext.TRANSACTION_STATUS_IDLE)
        cur.execute("SHOW transaction_isolation;")
        self.assertEqual(cur.fetchone()[0], 'serializable')
        cur.execute("SHOW transaction_read_only;")
        self.assertEqual(cur.fetchone()[0], 'on')
1678
1679
class PasswordLeakTestCase(ConnectingTestCase):
    """Check that the password is obscured in the dsn of a failed connection."""

    def setUp(self):
        super().setUp()
        # Clear the class-level slot GrassingConnection reports the dsn into.
        PasswordLeakTestCase.dsn = None

    class GrassingConnection(ext.connection):
        """A connection snitching the dsn away.

        This connection passes the dsn to the test case class even if init
        fails (e.g. connection error). Test that we mangle the dsn ok anyway.
        """

        def __init__(self, *args, **kwargs):
            try:
                super().__init__(*args, **kwargs)
            finally:
                # Even when the connection is only partially initialized
                # the C code should have stored the dsn already, with the
                # password scrubbed away.
                PasswordLeakTestCase.dsn = self.dsn

    @skip_if_crdb("connect any db")
    def test_leak(self):
        with self.assertRaises(psycopg2.DatabaseError):
            self.GrassingConnection("dbname=nosuch password=whateva")

        self.assertDsnEqual(self.dsn, "dbname=nosuch password=xxx")

    @skip_before_libpq(9, 2)
    def test_url_leak(self):
        with self.assertRaises(psycopg2.DatabaseError):
            self.GrassingConnection(
                "postgres://someone:whateva@localhost/nosuch")

        expected = "user=someone password=xxx host=localhost dbname=nosuch"
        self.assertDsnEqual(self.dsn, expected)
1716
1717
class SignalTestCase(ConnectingTestCase):
    """Regression tests for bug 551: a signal received while a query runs.

    Each test runs a child interpreter that executes INSERT statements in a
    loop and sends SIGABRT to itself from a helper thread; the installed
    handler calls ``sys.exit(1)``. The child must exit with a non-zero
    status and without writing anything to stderr.
    """

    @slow
    @skip_before_postgres(8, 2)
    def test_bug_551_returning(self):
        # Raise an exception trying to decode 'id'
        self._test_bug_551(query="""
            INSERT INTO test551 (num) VALUES (%s) RETURNING id
            """)

    @slow
    def test_bug_551_no_returning(self):
        # Raise an exception trying to decode 'INSERT 0 1'
        self._test_bug_551(query="""
            INSERT INTO test551 (num) VALUES (%s)
            """)

    def _test_bug_551(self, query):
        # Run `query` in a loop in a subprocess which gets interrupted by a
        # signal, then check how the subprocess terminated.
        #
        # The script is an f-string: the test dsn and the query are embedded
        # as Python literals using their repr().
        script = f"""import os
import sys
import time
import signal
import warnings
import threading

# ignore wheel deprecation warning
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    import psycopg2

def handle_sigabort(sig, frame):
    sys.exit(1)

def killer():
    time.sleep(0.5)
    os.kill(os.getpid(), signal.SIGABRT)

signal.signal(signal.SIGABRT, handle_sigabort)

conn = psycopg2.connect({dsn!r})

cur = conn.cursor()

cur.execute("create table test551 (id serial, num varchar(50))")

t = threading.Thread(target=killer)
t.daemon = True
t.start()

while True:
    cur.execute({query!r}, ("Hello, world!",))
"""

        proc = sp.Popen([sys.executable, '-c', script],
            stdout=sp.PIPE, stderr=sp.PIPE)
        (out, err) = proc.communicate()

        # The signal handler exits with status 1: a clean exit would mean
        # the loop was never interrupted.
        self.assertNotEqual(proc.returncode, 0)

        # Strip [NNN refs] from output
        err = re.sub(br'\[[^\]]+\]', b'', err).strip()
        # Decode stderr so that a failure shows readable text, not bytes.
        self.assertFalse(err, err.decode('utf-8', 'replace'))
1777
1778
class TestConnectionInfo(ConnectingTestCase):
    """Tests for the attributes and methods of the ``connection.info`` object.

    Most tests check both a working connection (``self.conn``) and a
    "broken" one (``self.bconn``), whose class never calls the superclass
    ``__init__``.
    """

    def setUp(self):
        super().setUp()

        class BrokenConn(psycopg2.extensions.connection):
            def __init__(self, *args, **kwargs):
                # don't call superclass
                pass

        # A "broken" connection
        self.bconn = self.connect(connection_factory=BrokenConn)

    def test_dbname(self):
        self.assertIsInstance(self.conn.info.dbname, str)
        self.assertIsNone(self.bconn.info.dbname)

    def test_user(self):
        cur = self.conn.cursor()
        cur.execute("select user")
        self.assertEqual(self.conn.info.user, cur.fetchone()[0])
        self.assertIsNone(self.bconn.info.user)

    def test_password(self):
        self.assertIsInstance(self.conn.info.password, str)
        self.assertIsNone(self.bconn.info.password)

    def test_host(self):
        # If no host is configured, only check that the value contains a
        # "/" (presumably a unix socket directory -- TODO confirm).
        expected = dbhost if dbhost else "/"
        self.assertIn(expected, self.conn.info.host)
        self.assertIsNone(self.bconn.info.host)

    def test_host_readonly(self):
        with self.assertRaises(AttributeError):
            self.conn.info.host = 'override'

    def test_port(self):
        self.assertIsInstance(self.conn.info.port, int)
        self.assertIsNone(self.bconn.info.port)

    def test_options(self):
        self.assertIsInstance(self.conn.info.options, str)
        self.assertIsNone(self.bconn.info.options)

    @skip_before_libpq(9, 3)
    def test_dsn_parameters(self):
        d = self.conn.info.dsn_parameters
        self.assertIsInstance(d, dict)
        self.assertEqual(d['dbname'], dbname)  # the only param we can check reliably
        # The password must not be exposed among the parameters.
        self.assertNotIn('password', d)

    def test_status(self):
        # NOTE(review): 0 and 1 are presumably libpq's CONNECTION_OK and
        # CONNECTION_BAD -- confirm against the PQstatus() documentation.
        self.assertEqual(self.conn.info.status, 0)
        self.assertEqual(self.bconn.info.status, 1)

    def test_transaction_status(self):
        self.assertEqual(
            self.conn.info.transaction_status, ext.TRANSACTION_STATUS_IDLE)
        cur = self.conn.cursor()
        cur.execute("select 1")
        self.assertEqual(
            self.conn.info.transaction_status, ext.TRANSACTION_STATUS_INTRANS)
        self.assertEqual(
            self.bconn.info.transaction_status, ext.TRANSACTION_STATUS_UNKNOWN)

    def test_parameter_status(self):
        cur = self.conn.cursor()
        try:
            cur.execute("show server_version")
        except psycopg2.DatabaseError:
            # The server can't be queried for the value: only check the type.
            self.assertIsInstance(
                self.conn.info.parameter_status('server_version'), str)
        else:
            self.assertEqual(
                self.conn.info.parameter_status('server_version'),
                cur.fetchone()[0])

        self.assertIsNone(self.conn.info.parameter_status('wat'))
        self.assertIsNone(self.bconn.info.parameter_status('server_version'))

    def test_protocol_version(self):
        self.assertEqual(self.conn.info.protocol_version, 3)
        self.assertEqual(self.bconn.info.protocol_version, 0)

    def test_server_version(self):
        cur = self.conn.cursor()
        try:
            cur.execute("show server_version_num")
        except psycopg2.DatabaseError:
            # The server can't be queried for the value: only check the type.
            self.assertIsInstance(self.conn.info.server_version, int)
        else:
            self.assertEqual(
                self.conn.info.server_version, int(cur.fetchone()[0]))

        self.assertEqual(self.bconn.info.server_version, 0)

    def test_error_message(self):
        self.assertIsNone(self.conn.info.error_message)
        self.assertIsNotNone(self.bconn.info.error_message)

        cur = self.conn.cursor()
        try:
            cur.execute("select 1 from nosuchtable")
        except psycopg2.DatabaseError:
            # The failure is expected: the message it leaves is checked below.
            pass

        self.assertIn('nosuchtable', self.conn.info.error_message)

    def test_socket(self):
        self.assertGreaterEqual(self.conn.info.socket, 0)
        self.assertLess(self.bconn.info.socket, 0)

    @skip_if_crdb("backend pid")
    def test_backend_pid(self):
        cur = self.conn.cursor()
        try:
            cur.execute("select pg_backend_pid()")
        except psycopg2.DatabaseError:
            # The server can't be queried for the pid: only check it is set.
            self.assertGreater(self.conn.info.backend_pid, 0)
        else:
            self.assertEqual(
                self.conn.info.backend_pid, int(cur.fetchone()[0]))

        self.assertEqual(self.bconn.info.backend_pid, 0)

    def test_needs_password(self):
        self.assertIs(self.conn.info.needs_password, False)
        self.assertIs(self.bconn.info.needs_password, False)

    def test_used_password(self):
        self.assertIsInstance(self.conn.info.used_password, bool)
        self.assertIs(self.bconn.info.used_password, False)

    @skip_before_libpq(9, 5)
    def test_ssl_in_use(self):
        self.assertIsInstance(self.conn.info.ssl_in_use, bool)
        self.assertIs(self.bconn.info.ssl_in_use, False)

    @skip_after_libpq(9, 5)
    def test_ssl_not_supported(self):
        # The bare attribute accesses below are intentional: reading the
        # attribute is what is expected to raise.
        with self.assertRaises(psycopg2.NotSupportedError):
            self.conn.info.ssl_in_use
        with self.assertRaises(psycopg2.NotSupportedError):
            self.conn.info.ssl_attribute_names
        with self.assertRaises(psycopg2.NotSupportedError):
            self.conn.info.ssl_attribute('wat')

    @skip_before_libpq(9, 5)
    def test_ssl_attribute(self):
        attribs = self.conn.info.ssl_attribute_names
        self.assertTrue(attribs)
        if self.conn.info.ssl_in_use:
            for attrib in attribs:
                self.assertIsInstance(self.conn.info.ssl_attribute(attrib), str)
        else:
            for attrib in attribs:
                self.assertIsNone(self.conn.info.ssl_attribute(attrib))

        self.assertIsNone(self.conn.info.ssl_attribute('wat'))

        for attrib in attribs:
            self.assertIsNone(self.bconn.info.ssl_attribute(attrib))
1937
1938
def test_suite():
    """Return a suite with all the tests defined in this module."""
    loader = unittest.TestLoader()
    return loader.loadTestsFromName(__name__)


# Run the tests when the module is executed as a script.
if __name__ == "__main__":
    unittest.main()
1945