1from datetime import datetime as Datetime, timezone as Timezone
2
3import pytest
4
5import pg8000.dbapi
6from pg8000.converters import INET_ARRAY, INTEGER
7
8
# Tests relating to the basic operation of the database driver, driven by
# pg8000's DB-API interface (pg8000.dbapi).
11
12
@pytest.fixture
def db_table(request, con):
    """Provide ``con`` with a temporary ``t1`` table and "format" paramstyle.

    ``t1`` has columns ``f1 int primary key``, ``f2 bigint not null`` and
    ``f3 varchar(50) null``. A finalizer drops the table after the test;
    any ``DatabaseError`` raised by that drop is deliberately ignored.
    """
    con.paramstyle = "format"
    ddl = (
        "CREATE TEMPORARY TABLE t1 (f1 int primary key, "
        "f2 bigint not null, f3 varchar(50) null) "
    )
    con.cursor().execute(ddl)

    def drop_t1():
        # Best-effort teardown: some tests (test_create, test_row_count)
        # drop t1 themselves, so the table may already be gone.
        try:
            con.cursor().execute("drop table t1")
        except pg8000.dbapi.DatabaseError:
            pass

    request.addfinalizer(drop_t1)
    return con
31
32
def test_database_error(cursor):
    """Inserting into ``t99``, which this test never creates, raises DatabaseError."""
    missing_table_insert = "INSERT INTO t99 VALUES (1, 2, 3)"
    with pytest.raises(pg8000.dbapi.DatabaseError):
        cursor.execute(missing_table_insert)
36
37
def test_parallel_queries(db_table):
    """Run queries on a second cursor while walking the results of the first.

    Five rows with ``f1`` = 1..5 are inserted. For every row read through
    ``c1``, ``c2`` selects the rows with a larger ``f1``; there must be
    exactly ``5 - f1`` of them, each with ``f1`` above the outer value.
    """
    cursor = db_table.cursor()
    for f1, f2 in ((1, 1), (2, 10), (3, 100), (4, 1000), (5, 10000)):
        cursor.execute(
            "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (f1, f2, None)
        )

    c1 = db_table.cursor()
    c2 = db_table.cursor()
    c1.execute("SELECT f1, f2, f3 FROM t1")
    outer_rows = c1.fetchall()
    assert len(outer_rows) == 5

    for outer_f1, _, _ in outer_rows:
        c2.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (outer_f1,))
        # Separate names so the inner rows never shadow the outer row.
        inner_rows = c2.fetchall()
        assert len(inner_rows) == 5 - outer_f1
        assert all(inner_f1 > outer_f1 for inner_f1, _, _ in inner_rows)
53
54
def test_parallel_open_portals(con):
    """Two cursors can each hold the results of the same query at once.

    Both cursors execute before either is read; each must then yield all
    rows of ``generate_series(1, 100)``.
    """
    row_limit = 100
    query = "select * from generate_series(1, %s)"
    params = (row_limit,)
    c1 = con.cursor()
    c2 = con.cursor()
    c1.execute(query, params)
    c2.execute(query, params)

    # Read in the reverse of the execution order.
    c2count = len(c2.fetchall())
    c1count = len(c1.fetchall())

    # Comparing the two counts alone would also pass if both were empty.
    assert c2count == row_limit
    assert c1count == row_limit
69
70
71# Run a query on a table, alter the structure of the table, then run the
72# original query again.
73
74
def test_alter(db_table):
    """Re-running a query after ALTER TABLE reflects the new column set.

    ``t1`` starts with three columns; once ``f3`` is dropped the same
    ``select *`` must describe two.
    """
    cursor = db_table.cursor()
    cursor.execute("select * from t1")
    assert len(cursor.description) == 3
    cursor.execute("alter table t1 drop column f3")
    cursor.execute("select * from t1")
    # If the first execution's metadata were reused this would still be 3.
    assert len(cursor.description) == 2
80
81
82# Run a query on a table, drop then re-create the table, then run the
83# original query again.
84
85
def test_create(db_table):
    """Re-running a query after dropping and re-creating the table works.

    The re-created ``t1`` has a single column, so the repeated
    ``select *`` must describe one column instead of the original three.
    """
    cursor = db_table.cursor()
    cursor.execute("select * from t1")
    assert len(cursor.description) == 3
    cursor.execute("drop table t1")
    cursor.execute("create temporary table t1 (f1 int primary key)")
    cursor.execute("select * from t1")
    assert len(cursor.description) == 1
92
93
def test_insert_returning(db_table):
    """INSERT ... RETURNING yields the generated ids and a matching rowcount."""
    cursor = db_table.cursor()
    cursor.execute("CREATE TEMPORARY TABLE t2 (id serial, data text)")

    # Test INSERT ... RETURNING with one row...
    cursor.execute("INSERT INTO t2 (data) VALUES (%s) RETURNING id", ("test1",))
    row_id = cursor.fetchone()[0]
    cursor.execute("SELECT data FROM t2 WHERE id = %s", (row_id,))
    assert "test1" == cursor.fetchone()[0]

    # This rowcount belongs to the SELECT above, which matched one row.
    assert cursor.rowcount == 1

    # Test with multiple rows...
    cursor.execute(
        "INSERT INTO t2 (data) VALUES (%s), (%s), (%s) RETURNING id",
        ("test2", "test3", "test4"),
    )
    assert cursor.rowcount == 3
    ids = tuple(row[0] for row in cursor.fetchall())
    assert len(ids) == 3
    # serial draws from a sequence: ids are distinct and new.
    assert len(set(ids)) == 3
    assert row_id not in ids
114
115
def test_row_count(db_table):
    """rowcount reports the full result size however much has been read.

    Also checks rowcount after executemany, and that a command returning no
    rows (DROP TABLE) reports -1.
    """
    expected_count = 57
    first = db_table.cursor()
    first.executemany(
        "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)",
        tuple((n, n, None) for n in range(expected_count)),
    )
    # After executemany, rowcount covers every parameter set.
    assert expected_count == first.rowcount

    first.execute("SELECT * FROM t1")
    # Nothing fetched yet.
    assert expected_count == first.rowcount

    # Half the rows consumed: rowcount must not change.
    for _ in range(expected_count // 2):
        first.fetchone()
    assert expected_count == first.rowcount

    # A fresh cursor, a third of the rows consumed, must agree as well.
    second = db_table.cursor()
    second.execute("SELECT * FROM t1")
    for _ in range(expected_count // 3):
        second.fetchone()
    assert expected_count == second.rowcount

    # A command with no result rows reports -1.
    second.execute("DROP TABLE t1")
    assert -1 == second.rowcount
149
150
def test_row_count_update(db_table):
    """UPDATE sets rowcount to the number of rows it changed."""
    cursor = db_table.cursor()
    seed_rows = ((1, 1), (2, 10), (3, 100), (4, 1000), (5, 10000))
    for key, amount in seed_rows:
        cursor.execute(
            "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (key, amount, None)
        )

    # Only the rows with f2 = 1000 and f2 = 10000 satisfy f2 > 101.
    cursor.execute("UPDATE t1 SET f3 = %s WHERE f2 > 101", ("Hello!",))
    assert cursor.rowcount == 2
160
161
def test_int_oid(cursor):
    """A Python int can be compared against pg_type's ``oid`` column.

    Regression test for https://bugs.launchpad.net/pg8000/+bug/230796; it
    only checks that execute() does not raise.
    """
    oid_value = 100
    cursor.execute("SELECT typname FROM pg_type WHERE oid = %s", (oid_value,))
165
166
def test_unicode_query(cursor):
    """DDL with non-ASCII (Cyrillic) table and column names executes."""
    # Identifiers spell "место" (table), "имя" and "адрес" (columns).
    ddl = (
        "CREATE TEMPORARY TABLE \u043c\u0435\u0441\u0442\u043e "
        "(\u0438\u043c\u044f VARCHAR(50), "
        "\u0430\u0434\u0440\u0435\u0441 VARCHAR(250))"
    )
    cursor.execute(ddl)
173
174
def test_executemany(db_table):
    """executemany accepts NULL values and both aware and naive datetimes."""
    cursor = db_table.cursor()
    insert_params = ((1, 1, "Avast ye!"), (2, 1, None))
    cursor.executemany(
        "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", insert_params
    )

    # One timezone-aware and one naive datetime in the same batch.
    aware = Datetime(2014, 5, 7, tzinfo=Timezone.utc)
    naive = Datetime(2014, 5, 7)
    cursor.executemany("select CAST(%s AS TIMESTAMP)", ((aware,), (naive,)))
186
187
def test_executemany_setinputsizes(cursor):
    """Make sure that setinputsizes works for all the parameter sets"""

    cursor.execute(
        "CREATE TEMPORARY TABLE t1 (f1 int primary key, f2 inet[] not null) "
    )

    cursor.setinputsizes(INTEGER, INET_ARRAY)
    cursor.executemany(
        "INSERT INTO t1 (f1, f2) VALUES (%s, %s)", ((1, ["1.1.1.1"]), (2, ["0.0.0.0"]))
    )
    # Both parameter sets must have been applied, not just the first.
    assert cursor.rowcount == 2
    cursor.execute("SELECT count(*) FROM t1")
    assert cursor.fetchone()[0] == 2
199
200
def test_executemany_no_param_sets(cursor):
    """executemany with zero parameter sets leaves rowcount at -1."""
    no_param_sets = []
    cursor.executemany("INSERT INTO t1 (f1, f2) VALUES (%s, %s)", no_param_sets)
    assert -1 == cursor.rowcount
204
205
206# Check that autocommit stays off
207# We keep track of whether we're in a transaction or not by using the
208# READY_FOR_QUERY message.
def test_transactions(db_table):
    """A rolled-back INSERT leaves t1 empty, since autocommit is off."""
    cur = db_table.cursor()
    cur.execute("commit")
    # With autocommit off, the rollback below should undo this insert.
    cur.execute(
        "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (1, 1, "Zombie")
    )
    cur.execute("rollback")

    cur.execute("select * from t1")
    assert 0 == cur.rowcount
217
218
def test_in(cursor):
    """A Python list parameter can be used with ``= any(%s)``.

    Rows are ordered by oid so the assertion does not depend on pg_type's
    physical row order; oids 16 and 23 are the built-in ``bool`` and
    ``int4`` types.
    """
    cursor.execute(
        "SELECT typname FROM pg_type WHERE oid = any(%s) ORDER BY oid", ([16, 23],)
    )
    ret = cursor.fetchall()
    assert [row[0] for row in ret] == ["bool", "int4"]
223
224
def test_no_previous_tpc(con):
    """A two-phase transaction can be begun, used and committed."""
    xid = "Stacey"
    con.tpc_begin(xid)
    con.cursor().execute("SELECT * FROM pg_type")
    con.tpc_commit()
230
231
232# Check that tpc_recover() doesn't start a transaction
def test_tpc_recover(con):
    """tpc_recover() must not leave a transaction open."""
    con.tpc_recover()
    cur = con.cursor()
    con.autocommit = True

    # VACUUM refuses to run inside a transaction block, so this raises if
    # tpc_recover() started one.
    cur.execute("VACUUM")
240
241
def test_tpc_prepare(con):
    """A prepared two-phase transaction can be rolled back by its xid."""
    transaction_id = "Stacey"
    con.tpc_begin(transaction_id)
    con.tpc_prepare()
    con.tpc_rollback(transaction_id)
247
248
# An empty query should raise a DatabaseError
def test_empty_query(cursor):
    """Executing an empty SQL string is rejected with a DatabaseError."""
    empty_sql = ""
    with pytest.raises(pg8000.dbapi.DatabaseError):
        cursor.execute(empty_sql)
253
254
# Rolling back via con.rollback() when not in a transaction doesn't generate a warning
def test_rollback_no_transaction(con):
    """con.rollback() outside a transaction produces no server notice.

    A raw ROLLBACK sent outside a transaction makes the server emit a notice
    with SQLSTATE 25P01 (no_active_sql_transaction). The rollback() method
    avoids sending it because it knows no transaction is open.
    """
    # Start from an empty notice queue.
    con.notices.clear()

    # Baseline: a raw rollback does produce exactly one notice.
    con.execute_unnamed("rollback")
    assert len(con.notices) == 1

    # Only the SQLSTATE code is checked: the message and severity text may
    # be localized or differ between server versions.
    notice = con.notices.pop()
    assert notice.get(b"C") == b"25P01"

    # The rollback() method should emit nothing at all.
    con.rollback()
    assert len(con.notices) == 0
275
276
def test_setinputsizes(con):
    """A NULL parameter with an explicitly declared input type returns None."""
    cur = con.cursor()
    # 20 should be the PostgreSQL int8 type oid — confirm if changing it.
    cur.setinputsizes(20)
    cur.execute("select %s", (None,))
    rows = cur.fetchall()
    assert rows[0][0] is None
283
284
def test_unexecuted_cursor_rowcount(con):
    """A cursor that has executed nothing reports rowcount -1."""
    assert -1 == con.cursor().rowcount
288
289
def test_unexecuted_cursor_description(con):
    """A cursor that has executed nothing has no description."""
    fresh_cursor = con.cursor()
    assert fresh_cursor.description is None
293
294
def test_callproc(cursor):
    """callproc() invokes a procedure and returns its INOUT parameter.

    Procedures (CREATE PROCEDURE) are only available from PostgreSQL 11,
    so the test is skipped on older servers instead of silently passing.
    """
    cursor.execute("select current_setting('server_version_num')")
    # server_version_num is a plain integer (e.g. 110005 for 11.5), which
    # compares reliably, unlike prefix checks on the free-form
    # server_version string (which let 8.x through).
    server_version_num = int(cursor.fetchone()[0])
    if server_version_num < 110000:
        pytest.skip("CREATE PROCEDURE requires PostgreSQL 11 or later")

    # OR REPLACE so an ``echo`` procedure left over from an earlier run
    # (should one ever have been committed) does not make this fail.
    cursor.execute(
        """
CREATE OR REPLACE PROCEDURE echo(INOUT val text)
  LANGUAGE plpgsql AS
$proc$
BEGIN
END
$proc$;
"""
    )

    cursor.callproc("echo", ["hello"])
    assert cursor.fetchall() == (["hello"],)
313
314
def test_null_result(db_table):
    """fetchall() after a statement with no result set raises ProgrammingError."""
    cur = db_table.cursor()
    values = (1, 1, "a")
    cur.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", values)
    with pytest.raises(pg8000.dbapi.ProgrammingError):
        cur.fetchall()
320
321
def test_not_parsed_if_no_params(mocker, cursor):
    """execute() with no parameters never calls convert_paramstyle."""
    patched_convert = mocker.patch("pg8000.dbapi.convert_paramstyle")
    cursor.execute("ROLLBACK")
    patched_convert.assert_not_called()
326