1from datetime import datetime as Datetime, timezone as Timezone
2from warnings import filterwarnings
3
4import pytest
5
6import pg8000
7from pg8000.converters import INET_ARRAY, INTEGER
8
9
10# Tests relating to the basic operation of the database driver, driven by the
11# pg8000 custom interface.
12
13
@pytest.fixture
def db_table(request, con):
    """Provide ``con`` with a fresh temporary table ``t1`` and "format" paramstyle.

    Table layout: f1 int primary key, f2 bigint not null, f3 varchar(50) null.
    A finalizer drops ``t1`` afterwards; a ProgrammingError raised by that
    drop is ignored because some tests drop the table themselves.
    """
    for message in (
        "DB-API extension cursor.next()",
        "DB-API extension cursor.__iter__()",
    ):
        filterwarnings("ignore", message)

    con.paramstyle = "format"
    create_sql = (
        "CREATE TEMPORARY TABLE t1 (f1 int primary key, "
        "f2 bigint not null, f3 varchar(50) null) "
    )
    with con.cursor() as setup_cursor:
        setup_cursor.execute(create_sql)

    def drop_t1():
        try:
            with con.cursor() as teardown_cursor:
                teardown_cursor.execute("drop table t1")
        except pg8000.ProgrammingError:
            # Best effort: the test may already have dropped t1.
            pass

    request.addfinalizer(drop_t1)
    return con
34
35
def test_database_error(cursor):
    """Inserting into the (absent) table t99 raises ProgrammingError."""
    missing_table_insert = "INSERT INTO t99 VALUES (1, 2, 3)"
    with pytest.raises(pg8000.ProgrammingError):
        cursor.execute(missing_table_insert)
39
40
def test_parallel_queries(db_table):
    """Re-executing one cursor while iterating another gives correct results.

    Rows with f1 = 1..5 are inserted. For each row read from ``c1``, ``c2``
    selects the rows with a larger f1, and both result sets are checked.
    """
    seed_rows = ((1, 1), (2, 10), (3, 100), (4, 1000), (5, 10000))
    with db_table.cursor() as cursor:
        for f1, f2 in seed_rows:
            cursor.execute(
                "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (f1, f2, None)
            )
        with db_table.cursor() as c1, db_table.cursor() as c2:
            c1.execute("SELECT f1, f2, f3 FROM t1")
            seen_outer = []
            # Distinct loop names so the inner loop never rebinds the outer row.
            for outer_f1, _, _ in c1:
                seen_outer.append(outer_f1)
                c2.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (outer_f1,))
                inner_keys = sorted(inner_f1 for inner_f1, _, _ in c2)
                assert inner_keys == list(range(outer_f1 + 1, len(seed_rows) + 1))
            assert sorted(seen_outer) == [1, 2, 3, 4, 5]
61
62
def test_parallel_open_portals(con):
    """Two cursors with pending results can be drained in either order.

    Both cursors run the same 100-row query before either is read, and each
    must then yield every row.
    """
    row_total = 100
    q = "select * from generate_series(1, %s)"
    params = (row_total,)
    with con.cursor() as c1, con.cursor() as c2:
        c1.execute(q, params)
        c2.execute(q, params)
        # Drain c2 first even though c1 was executed first.
        c2count = sum(1 for _ in c2)
        c1count = sum(1 for _ in c1)

    # Compare against row_total too: equal counts alone would also hold if
    # both cursors came back empty.
    assert c1count == c2count == row_total
76
77
78# Run a query on a table, alter the structure of the table, then run the
79# original query again.
80
81
def test_alter(db_table):
    """A query still runs after the table it reads has been altered."""
    query = "select * from t1"
    with db_table.cursor() as cur:
        cur.execute(query)
        cur.execute("alter table t1 drop column f3")
        # Re-run the identical SQL now that the column set has changed.
        cur.execute(query)
87
88
89# Run a query on a table, drop then re-create the table, then run the
90# original query again.
91
92
def test_create(db_table):
    """A query still runs after its table is dropped and re-created."""
    query = "select * from t1"
    with db_table.cursor() as cur:
        cur.execute(query)
        cur.execute("drop table t1")
        cur.execute("create temporary table t1 (f1 int primary key)")
        # Same SQL text, but t1 is now a different table with one column.
        cur.execute(query)
99
100
def test_insert_returning(db_table):
    """INSERT ... RETURNING yields the generated ids and sets rowcount."""
    with db_table.cursor() as cursor:
        # TEMPORARY so the table cannot outlive the session (and clash with a
        # later run) if the connection ever commits.
        cursor.execute("CREATE TEMPORARY TABLE t2 (id serial, data text)")

        # One row: the returned id must identify the inserted row.
        cursor.execute("INSERT INTO t2 (data) VALUES (%s) RETURNING id", ("test1",))
        row_id = cursor.fetchone()[0]
        cursor.execute("SELECT data FROM t2 WHERE id = %s", (row_id,))
        assert "test1" == cursor.fetchone()[0]
        assert cursor.rowcount == 1

        # Several rows: rowcount and the RETURNING result set cover them all.
        cursor.execute(
            "INSERT INTO t2 (data) VALUES (%s), (%s), (%s) RETURNING id",
            ("test2", "test3", "test4"),
        )
        assert cursor.rowcount == 3
        ids = tuple(row[0] for row in cursor)
        assert len(ids) == 3
        # serial ids are unique, and none can repeat the first insert's id.
        assert len(set(ids)) == 3
        assert row_id not in ids
121
122
def test_row_count(db_table):
    """rowcount is right after executemany, after SELECT (even partly read) and for DDL."""
    expected_count = 57
    insert_sql = "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)"
    param_sets = tuple((n, n, None) for n in range(expected_count))

    with db_table.cursor() as cursor:
        cursor.executemany(insert_sql, param_sets)
        # After executemany, rowcount covers every inserted row.
        assert expected_count == cursor.rowcount

        cursor.execute("SELECT * FROM t1")
        # Known before any row is fetched ...
        assert expected_count == cursor.rowcount
        # ... and unchanged after consuming half of the result.
        for _ in range(expected_count // 2):
            cursor.fetchone()
        assert expected_count == cursor.rowcount

    with db_table.cursor() as cursor:
        # Fresh cursor: re-run the SELECT and consume a third of the rows.
        cursor.execute("SELECT * FROM t1")
        for _ in range(expected_count // 3):
            cursor.fetchone()
        assert expected_count == cursor.rowcount

        # A statement that returns no rows reports -1.
        cursor.execute("DROP TABLE t1")
        assert -1 == cursor.rowcount
156
157
def test_row_count_update(db_table):
    """rowcount after an UPDATE equals the number of rows it modified."""
    seed_rows = ((1, 1), (2, 10), (3, 100), (4, 1000), (5, 10000))
    with db_table.cursor() as cursor:
        for key, amount in seed_rows:
            cursor.execute(
                "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (key, amount, None)
            )
        # Only the rows with f2 = 1000 and f2 = 10000 satisfy f2 > 101.
        cursor.execute("UPDATE t1 SET f3 = %s WHERE f2 > 101", ("Hello!",))
        assert cursor.rowcount == 2
173
174
def test_int_oid(cursor):
    """A Python int can be bound as the comparison value for pg_type.oid.

    Regression test for https://bugs.launchpad.net/pg8000/+bug/230796; it
    passes as long as the query executes without raising.
    """
    oid_param = (100,)
    cursor.execute("SELECT typname FROM pg_type WHERE oid = %s", oid_param)
178
179
def test_unicode_query(cursor):
    """DDL whose table and column names are non-ASCII (Cyrillic) executes."""
    table = "\u043c\u0435\u0441\u0442\u043e"
    name_col = "\u0438\u043c\u044f"
    address_col = "\u0430\u0434\u0440\u0435\u0441"
    cursor.execute(
        f"CREATE TEMPORARY TABLE {table} "
        f"({name_col} VARCHAR(50), "
        f"{address_col} VARCHAR(250))"
    )
186
187
def test_executemany(db_table):
    """executemany works for an INSERT and for a SELECT with timestamp params.

    The SELECT's parameter sets mix a timezone-aware (UTC) and a naive datetime.
    """
    rows = ((1, 1, "Avast ye!"), (2, 1, None))
    timestamps = (
        (Datetime(2014, 5, 7, tzinfo=Timezone.utc),),
        (Datetime(2014, 5, 7),),
    )
    with db_table.cursor() as cursor:
        cursor.executemany("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", rows)
        cursor.executemany("SELECT CAST(%s AS TIMESTAMP)", timestamps)
199
200
def test_executemany_setinputsizes(cursor):
    """Input sizes declared once via setinputsizes apply to every parameter set."""
    cursor.execute(
        "CREATE TEMPORARY TABLE t1 (f1 int primary key, f2 inet[] not null) "
    )

    # Declare the parameter types up front: f1 as INTEGER, f2 as INET_ARRAY.
    cursor.setinputsizes(INTEGER, INET_ARRAY)
    param_sets = ((1, ["1.1.1.1"]), (2, ["0.0.0.0"]))
    cursor.executemany("INSERT INTO t1 (f1, f2) VALUES (%s, %s)", param_sets)
212
213
def test_executemany_no_param_sets(cursor):
    """executemany with no parameter sets leaves rowcount at -1."""
    # NOTE(review): t1 is not created here; this presumably relies on no
    # statement being sent when the parameter list is empty — confirm.
    no_params = []
    cursor.executemany("INSERT INTO t1 (f1, f2) VALUES (%s, %s)", no_params)
    assert -1 == cursor.rowcount
217
218
219# Check that autocommit stays off
220# We keep track of whether we're in a transaction or not by using the
221# READY_FOR_QUERY message.
def test_transactions(db_table):
    """An INSERT is undone by ROLLBACK, so autocommit must be off."""
    with db_table.cursor() as cur:
        # Commit first so the table created by the fixture is not undone by
        # the rollback below.
        cur.execute("commit")
        cur.execute(
            "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (1, 1, "Zombie")
        )
        cur.execute("rollback")

        cur.execute("select * from t1")
        assert 0 == cur.rowcount
232
233
def test_in(cursor):
    """An array parameter works with ``= any(...)`` to match several OIDs."""
    # ORDER BY makes the row order deterministic; without it the server may
    # return bool (oid 16) and int4 (oid 23) in either order.
    cursor.execute(
        "SELECT typname FROM pg_type WHERE oid = any(%s) ORDER BY oid", ([16, 23],)
    )
    ret = cursor.fetchall()
    assert ret[0][0] == "bool"
    assert [row[0] for row in ret] == ["bool", "int4"]
238
239
def test_no_previous_tpc(con):
    """A two-phase transaction can be begun and committed with no earlier one."""
    xid = "Stacey"
    con.tpc_begin(xid)
    with con.cursor() as cur:
        cur.execute("SELECT * FROM pg_type")
        con.tpc_commit()
245
246
247# Check that tpc_recover() doesn't start a transaction
def test_tpc_recover(con):
    """tpc_recover() must not leave a transaction open behind it."""
    con.tpc_recover()
    with con.cursor() as cur:
        con.autocommit = True
        # VACUUM is rejected inside a transaction block, so this fails if
        # tpc_recover() had implicitly begun one.
        cur.execute("VACUUM")
255
256
def test_tpc_prepare(con):
    """A prepared two-phase transaction can be rolled back by its xid."""
    transaction_id = "Stacey"
    con.tpc_begin(transaction_id)
    con.tpc_prepare()
    con.tpc_rollback(transaction_id)
262
263
264# An empty query should raise a ProgrammingError
def test_empty_query(cursor):
    """Executing an empty SQL string raises ProgrammingError."""
    empty_sql = ""
    with pytest.raises(pg8000.ProgrammingError):
        cursor.execute(empty_sql)
268
269
# Rolling back via con.rollback() when not in a transaction doesn't generate a notice
def test_rollback_no_transaction(con):
    """con.rollback() outside a transaction produces no server notice.

    As a control, a raw ROLLBACK sent with execute_unnamed does produce a
    25P01 notice.
    """
    # Start from an empty notice queue.
    con.notices.clear()

    # Control: the raw statement yields exactly one notice ...
    con.execute_unnamed("rollback")
    assert len(con.notices) == 1
    # ... with SQLSTATE 25P01 (no_active_sql_transaction). Only the code is
    # compared, because the message text and severity may be localized or
    # vary by server version.
    assert b"25P01" == con.notices.pop().get(b"C")

    # The driver's rollback() knows no transaction is active, so no notice
    # is produced.
    con.rollback()
    assert len(con.notices) == 0
290
291
def test_context_manager_class(con):
    """The legacy Cursor class defines __enter__/__exit__ itself and works in ``with``."""
    cursor_attrs = pg8000.legacy.Cursor.__dict__
    for dunder in ("__enter__", "__exit__"):
        assert dunder in cursor_attrs

    with con.cursor() as cursor:
        cursor.execute("select 1")
298
299
def test_close_prepared_statement(con):
    """Closing a prepared statement removes it from pg_prepared_statements."""

    def prepared_count():
        return con.run("select count(*) from pg_prepared_statements")[0][0]

    ps = con.prepare("select 1")
    ps.run()
    assert prepared_count() == 1  # exactly our statement is prepared

    ps.close()
    assert prepared_count() == 0  # nothing is left after close()
310
311
def test_setinputsizes(con):
    """setinputsizes accepts a raw type OID, and a None parameter round-trips.

    The cursor is opened with ``with`` so it is closed when the test ends.
    """
    with con.cursor() as cursor:
        cursor.setinputsizes(20)  # a type OID (20 is int8 in PostgreSQL)
        cursor.execute("select %s", (None,))
        retval = cursor.fetchall()
    assert retval[0][0] is None
318
319
def test_setinputsizes_class(con):
    """setinputsizes accepts a Python type, and a None parameter round-trips.

    The cursor is opened with ``with`` so it is closed when the test ends.
    """
    with con.cursor() as cursor:
        cursor.setinputsizes(bytes)
        cursor.execute("select %s", (None,))
        retval = cursor.fetchall()
    assert retval[0][0] is None
326
327
def test_unexecuted_cursor_rowcount(con):
    """A cursor that has executed nothing reports rowcount -1."""
    assert -1 == con.cursor().rowcount
331
332
def test_unexecuted_cursor_description(con):
    """A cursor that has executed nothing has no description."""
    fresh = con.cursor()
    description = fresh.description
    assert description is None
336
337
def test_not_parsed_if_no_params(mocker, cursor):
    """A statement executed without parameters skips paramstyle conversion."""
    patched_convert = mocker.patch("pg8000.legacy.convert_paramstyle")
    cursor.execute("ROLLBACK")
    assert not patched_convert.called
342