1import datetime
2import warnings
3
4import pytest
5
6import pg8000
7
8
def test_unix_socket_missing():
    """Connecting through a nonexistent Unix socket path raises InterfaceError."""
    with pytest.raises(pg8000.dbapi.InterfaceError):
        pg8000.dbapi.connect(unix_sock="/file-does-not-exist", user="doesn't-matter")
14
15
def test_internet_socket_connection_refused():
    """Connecting to TCP port 0 raises InterfaceError with the expected message."""
    # The parentheses are escaped because pytest treats `match` as a regex.
    expected_message = (
        "Can't create a connection to host localhost and port 0 "
        r"\(timeout is None and source_address is None\)."
    )

    with pytest.raises(pg8000.dbapi.InterfaceError, match=expected_message):
        pg8000.dbapi.connect(port=0, user="doesn't-matter")
25
26
def test_database_missing(db_kwargs):
    """Connecting to the database "missing-db" raises DatabaseError."""
    db_kwargs.update(database="missing-db")

    with pytest.raises(pg8000.dbapi.DatabaseError):
        pg8000.dbapi.connect(**db_kwargs)
31
32
def test_database_name_unicode(db_kwargs):
    """A non-ASCII database name should fail only because the db doesn't exist."""
    unicode_name = "pg8000_sn\uFF6Fw"
    db_kwargs.update(database=unicode_name)

    # The only acceptable failure is the "3D000" error, i.e. the name itself
    # was handled and the server merely reports that the db is absent.
    expected_failure = pytest.raises(pg8000.dbapi.DatabaseError, match="3D000")
    with expected_failure:
        pg8000.dbapi.connect(**db_kwargs)
39
40
def test_database_name_bytes(db_kwargs):
    """A UTF-8 bytes database name should fail only because the db doesn't exist."""
    encoded_name = "pg8000_sn\uFF6Fw".encode("utf8")
    db_kwargs.update(database=encoded_name)

    expected_failure = pytest.raises(pg8000.dbapi.DatabaseError, match="3D000")
    with expected_failure:
        pg8000.dbapi.connect(**db_kwargs)
47
48
def test_password_bytes(con, db_kwargs):
    """Connect with a non-ASCII password supplied as UTF-8 bytes.

    A role with a non-ASCII password is created through ``con``; a new
    connection is then attempted with that password encoded as bytes against
    the database "pg8000_md5". The attempt must fail with a "3D000" error.

    NOTE(review): presumably "pg8000_md5" does not exist, so getting as far
    as 3D000 implies the bytes password was accepted -- confirm.
    """
    username = "boltzmann"
    password = "cha\uFF6Fs"

    # Create user
    cur = con.cursor()
    cur.execute("create user " + username + " with password '" + password + "';")
    con.commit()

    try:
        db_kwargs["user"] = username
        db_kwargs["password"] = password.encode("utf8")
        db_kwargs["database"] = "pg8000_md5"
        with pytest.raises(pg8000.dbapi.DatabaseError, match="3D000"):
            pg8000.dbapi.connect(**db_kwargs)
    finally:
        # Drop the role even when the assertion above fails; otherwise the
        # leftover role would make the "create user" of the next run fail.
        cur.execute("drop role " + username)
        con.commit()
65
66
def test_application_name(db_kwargs):
    """The application_name passed to connect() shows up in pg_stat_activity."""
    expected = "my test application name"
    db_kwargs["application_name"] = expected
    query = (
        "select application_name from pg_stat_activity "
        " where pid = pg_backend_pid()"
    )

    with pg8000.dbapi.connect(**db_kwargs) as db:
        cursor = db.cursor()
        cursor.execute(query)
        row = cursor.fetchone()
        assert row[0] == expected
79
80
def test_application_name_integer(db_kwargs):
    """An integer application_name is rejected with an InterfaceError."""
    db_kwargs.update(application_name=1)
    expected_message = "The parameter application_name can't be of type <class 'int'>."

    with pytest.raises(pg8000.dbapi.InterfaceError, match=expected_message):
        pg8000.dbapi.connect(**db_kwargs)
88
89
def test_application_name_bytearray(db_kwargs):
    """Connecting with a bytearray application_name must not raise."""
    db_kwargs["application_name"] = bytearray(b"Philby")
    db = pg8000.dbapi.connect(**db_kwargs)
    # The connection exists only to prove connect() succeeds; close it so the
    # test does not leave an open server session behind.
    db.close()
93
94
def test_notify(con):
    """A NOTIFY without payload is recorded once in con.notifications."""
    cur = con.cursor()

    cur.execute("select pg_backend_pid()")
    pid_rows = cur.fetchall()
    own_pid = pid_rows[0][0]

    # Nothing may be queued before we start listening.
    assert list(con.notifications) == []

    for statement in ("LISTEN test", "NOTIFY test"):
        cur.execute(statement)
    con.commit()

    # Run one more statement before inspecting the notification queue.
    cur.execute("VALUES (1, 2), (3, 4), (5, 6)")
    assert len(con.notifications) == 1
    expected_notification = (own_pid, "test", "")
    assert con.notifications[0] == expected_notification
107
108
def test_notify_with_payload(con):
    """A NOTIFY carrying a payload is recorded with that payload."""
    cur = con.cursor()

    cur.execute("select pg_backend_pid()")
    pid_rows = cur.fetchall()
    own_pid = pid_rows[0][0]

    # Nothing may be queued before we start listening.
    assert list(con.notifications) == []

    for statement in ("LISTEN test", "NOTIFY test, 'Parnham'"):
        cur.execute(statement)
    con.commit()

    # Run one more statement before inspecting the notification queue.
    cur.execute("VALUES (1, 2), (3, 4), (5, 6)")
    assert len(con.notifications) == 1
    expected_notification = (own_pid, "test", "Parnham")
    assert con.notifications[0] == expected_notification
121
122
def test_broken_pipe_read(con, db_kwargs):
    """Using a connection whose backend was terminated fails on read.

    A second connection is opened, its backend is terminated through ``con``,
    and the next statement on it must raise InterfaceError matching
    "network error on read".
    """
    db1 = pg8000.dbapi.connect(**db_kwargs)
    try:
        cur1 = db1.cursor()
        cur2 = con.cursor()
        cur1.execute("select pg_backend_pid()")
        pid1 = cur1.fetchone()[0]

        cur2.execute("select pg_terminate_backend(%s)", (pid1,))
        with pytest.raises(
            pg8000.dbapi.InterfaceError, match="network error on read"
        ):
            cur1.execute("select 1")
    finally:
        # Release the client-side socket. The backend is gone, so close()
        # itself may raise InterfaceError; that is expected and ignored.
        try:
            db1.close()
        except pg8000.dbapi.InterfaceError:
            pass
133
134
def test_broken_pipe_flush(con, db_kwargs):
    """Closing a connection whose backend was terminated may fail on flush.

    A second connection is opened and its backend terminated through ``con``.
    If the subsequent close() raises InterfaceError, its message must be
    "network error on flush".
    """
    db1 = pg8000.dbapi.connect(**db_kwargs)
    cur1 = db1.cursor()
    cur2 = con.cursor()
    cur1.execute("select pg_backend_pid()")
    pid1 = cur1.fetchone()[0]

    cur2.execute("select pg_terminate_backend(%s)", (pid1,))
    try:
        cur1.execute("select 1")
    except Exception:
        # Best effort: any failure here is irrelevant to this test. Catching
        # Exception (not BaseException) lets KeyboardInterrupt and SystemExit
        # propagate instead of being silently swallowed.
        pass

    # Sometimes raises and sometimes doesn't
    try:
        db1.close()
    except pg8000.exceptions.InterfaceError as e:
        assert str(e) == "network error on flush"
153
154
def test_broken_pipe_unpack(con):
    """Terminating the connection's own backend raises a network InterfaceError."""
    cursor = con.cursor()
    cursor.execute("select pg_backend_pid()")
    own_pid = cursor.fetchone()[0]
    terminate_sql = "select pg_terminate_backend(%s)"

    with pytest.raises(pg8000.dbapi.InterfaceError, match="network error"):
        cursor.execute(terminate_sql, (own_pid,))
162
163
def test_py_value_fail(con, mocker):
    """An exception raised by a py_types converter propagates unchanged.

    Ensure that if the converter registered for ``datetime.time`` throws,
    the original exception (PG8000TestException) is raised, and that the
    connection is still usable after the error.
    """

    class PG8000TestException(Exception):
        pass

    def raise_exception(val):
        raise PG8000TestException("oh noes!")

    mocker.patch.object(con, "py_types")
    con.py_types = {datetime.time: raise_exception}

    c = con.cursor()
    with pytest.raises(PG8000TestException):
        c.execute("SELECT CAST(%s AS TIME) AS f1", (datetime.time(10, 30),))
        c.fetchall()

    # This check must sit outside the pytest.raises block: statements placed
    # after the raising call inside that block are never executed.
    # Ensure that the connection is still usable for a new query.
    c.execute("VALUES ('hw3'::text)")
    assert c.fetchone()[0] == "hw3"
186
187
def test_no_data_error_recovery(con):
    """Dropping table t1 fails with code 42P01 on each of three attempts."""
    attempts = 3
    for _ in range(attempts):
        with pytest.raises(pg8000.DatabaseError) as exc_info:
            cursor = con.cursor()
            cursor.execute("DROP TABLE t1")

        error_fields = exc_info.value.args[0]
        assert error_fields["C"] == "42P01"

        # Roll back before the next attempt.
        con.rollback()
195
196
def test_closed_connection(db_kwargs):
    """Executing on a cursor whose connection was closed raises InterfaceError."""
    # catch_warnings() restores the previous warning filters on exit, even if
    # an assertion fails, and leaves filters installed by the test runner
    # intact (a bare resetwarnings() would discard those as well).
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        my_db = pg8000.connect(**db_kwargs)
        cursor = my_db.cursor()
        my_db.close()
        with pytest.raises(my_db.InterfaceError, match="connection is closed"):
            cursor.execute("VALUES ('hw1'::text)")
207