1# Copyright (C) 2016-present the asyncpg authors and contributors
2# <see AUTHORS file>
3#
4# This module is part of asyncpg and is released under
5# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
6
7
8import asyncio
9import asyncpg
10import gc
11import unittest
12
13from asyncpg import _testbase as tb
14from asyncpg import exceptions
15
16
17class TestPrepare(tb.ConnectedTestCase):
18
19    async def test_prepare_01(self):
20        self.assertEqual(self.con._protocol.queries_count, 0)
21        st = await self.con.prepare('SELECT 1 = $1 AS test')
22        self.assertEqual(self.con._protocol.queries_count, 0)
23        self.assertEqual(st.get_query(), 'SELECT 1 = $1 AS test')
24
25        rec = await st.fetchrow(1)
26        self.assertEqual(self.con._protocol.queries_count, 1)
27        self.assertTrue(rec['test'])
28        self.assertEqual(len(rec), 1)
29
30        self.assertEqual(False, await st.fetchval(10))
31        self.assertEqual(self.con._protocol.queries_count, 2)
32
33    async def test_prepare_02(self):
34        with self.assertRaisesRegex(Exception, 'column "a" does not exist'):
35            await self.con.prepare('SELECT a')
36
37    async def test_prepare_03(self):
38        cases = [
39            ('text', ("'NULL'", 'NULL'), [
40                'aaa',
41                None
42            ]),
43
44            ('decimal', ('0', 0), [
45                123,
46                123.5,
47                None
48            ])
49        ]
50
51        for type, (none_name, none_val), vals in cases:
52            st = await self.con.prepare('''
53                    SELECT CASE WHEN $1::{type} IS NULL THEN {default}
54                    ELSE $1::{type} END'''.format(
55                type=type, default=none_name))
56
57            for val in vals:
58                with self.subTest(type=type, value=val):
59                    res = await st.fetchval(val)
60                    if val is None:
61                        self.assertEqual(res, none_val)
62                    else:
63                        self.assertEqual(res, val)
64
65    async def test_prepare_04(self):
66        s = await self.con.prepare('SELECT $1::smallint')
67        self.assertEqual(await s.fetchval(10), 10)
68
69        s = await self.con.prepare('SELECT $1::smallint * 2')
70        self.assertEqual(await s.fetchval(10), 20)
71
72        s = await self.con.prepare('SELECT generate_series(5,10)')
73        self.assertEqual(await s.fetchval(), 5)
74        # Since the "execute" message was sent with a limit=1,
75        # we will receive a PortalSuspended message, instead of
76        # CommandComplete.  Which means there will be no status
77        # message set.
78        self.assertIsNone(s.get_statusmsg())
79        # Repeat the same test for 'fetchrow()'.
80        self.assertEqual(await s.fetchrow(), (5,))
81        self.assertIsNone(s.get_statusmsg())
82
83    async def test_prepare_05_unknownoid(self):
84        s = await self.con.prepare("SELECT 'test'")
85        self.assertEqual(await s.fetchval(), 'test')
86
87    async def test_prepare_06_interrupted_close(self):
88        stmt = await self.con.prepare('''SELECT pg_sleep(10)''')
89        fut = self.loop.create_task(stmt.fetch())
90
91        await asyncio.sleep(0.2)
92
93        self.assertFalse(self.con.is_closed())
94        await self.con.close()
95        self.assertTrue(self.con.is_closed())
96
97        with self.assertRaises(asyncpg.QueryCanceledError):
98            await fut
99
100        # Test that it's OK to call close again
101        await self.con.close()
102
103    async def test_prepare_07_interrupted_terminate(self):
104        stmt = await self.con.prepare('''SELECT pg_sleep(10)''')
105        fut = self.loop.create_task(stmt.fetchval())
106
107        await asyncio.sleep(0.2)
108
109        self.assertFalse(self.con.is_closed())
110        self.con.terminate()
111        self.assertTrue(self.con.is_closed())
112
113        with self.assertRaisesRegex(asyncpg.ConnectionDoesNotExistError,
114                                    'closed in the middle'):
115            await fut
116
117        # Test that it's OK to call terminate again
118        self.con.terminate()
119
120    async def test_prepare_08_big_result(self):
121        stmt = await self.con.prepare('select generate_series(0,10000)')
122        result = await stmt.fetch()
123
124        self.assertEqual(len(result), 10001)
125        self.assertEqual(
126            [r[0] for r in result],
127            list(range(10001)))
128
129    async def test_prepare_09_raise_error(self):
130        # Stress test ReadBuffer.read_cstr()
131        msg = '0' * 1024 * 100
132        query = """
133        DO language plpgsql $$
134        BEGIN
135        RAISE EXCEPTION '{}';
136        END
137        $$;""".format(msg)
138
139        stmt = await self.con.prepare(query)
140        with self.assertRaisesRegex(asyncpg.RaiseError, msg):
141            with tb.silence_asyncio_long_exec_warning():
142                await stmt.fetchval()
143
    async def test_prepare_10_stmt_lru(self):
        """Check LRU eviction from the statement cache and the deferred
        closing of evicted statements once nothing references them."""
        cache = self.con._stmt_cache

        query = 'select {}'
        cache_max = cache.get_max_size()
        # Prepare well over twice the cache capacity so that many
        # statements get evicted from the LRU cache.
        iter_max = cache_max * 2 + 11

        # First, we have no cached statements.
        self.assertEqual(len(cache), 0)

        stmts = []
        for i in range(iter_max):
            s = await self.con._prepare(query.format(i), use_cache=True)
            self.assertEqual(await s.fetchval(), i)
            stmts.append(s)

        # At this point our cache should be full.
        self.assertEqual(len(cache), cache_max)
        self.assertTrue(all(not s.closed for s in cache.iter_statements()))

        # Since there are references to the statements (`stmts` list),
        # no statements are scheduled to be closed.
        self.assertEqual(len(self.con._stmts_to_close), 0)

        # Drop the references to the statements.  Collecting them only
        # *schedules* the evicted ones for closing; the actual cleanup
        # happens on the next prepare below.
        # NOTE: the loop variable `s` still references the last statement,
        # which is presumably still cached, so it does not affect the
        # count checked below.
        stmts.clear()
        gc.collect()

        # Now every evicted statement (iter_max - cache_max of them) has
        # no refs left and is scheduled to be closed.
        self.assertEqual(len(self.con._stmts_to_close), iter_max - cache_max)
        self.assertTrue(all(s.closed for s in self.con._stmts_to_close))
        self.assertTrue(all(not s.closed for s in cache.iter_statements()))

        # Preparing a new statement triggers the cleanup...
        zero = await self.con.prepare(query.format(0))
        # ...so no stale statements remain scheduled for closing.
        self.assertEqual(len(self.con._stmts_to_close), 0)

        # The number of cached statements will stay the same though.
        self.assertEqual(len(cache), cache_max)
        self.assertTrue(all(not s.closed for s in cache.iter_statements()))

        # Closing the connection empties both the close queue and the cache.
        await self.con.close()
        self.assertEqual(len(self.con._stmts_to_close), 0)
        self.assertEqual(len(cache), 0)

        # An attempt to perform an operation on a closed statement
        # will trigger an error.
        with self.assertRaisesRegex(asyncpg.InterfaceError, 'is closed'):
            await zero.fetchval()
196
    async def test_prepare_11_stmt_gc(self):
        """Prepared statements stay in the cache after they are GCed."""

        cache = self.con._stmt_cache

        # First, we have no cached statements.
        self.assertEqual(len(cache), 0)
        self.assertEqual(len(self.con._stmts_to_close), 0)

        # The prepared statement that we'll create will be GCed
        # right away.  However, its state should still be
        # in the statements LRU cache, so nothing gets scheduled to close.
        await self.con._prepare('select 1', use_cache=True)
        gc.collect()

        self.assertEqual(len(cache), 1)
        self.assertEqual(len(self.con._stmts_to_close), 0)
215
    async def test_prepare_12_stmt_gc(self):
        """Statements evicted from the LRU cache are scheduled for closing
        once no references to them remain."""

        cache = self.con._stmt_cache
        cache_max = cache.get_max_size()

        # First, we have no cached statements.
        self.assertEqual(len(cache), 0)
        self.assertEqual(len(self.con._stmts_to_close), 0)

        # Keep a reference to this statement so it outlives its eviction
        # from the cache below.
        stmt = await self.con._prepare('select 100000000', use_cache=True)
        self.assertEqual(len(cache), 1)
        self.assertEqual(len(self.con._stmts_to_close), 0)

        # Fill the cache with cache_max other statements, pushing `stmt`
        # out of the LRU.  These are discarded immediately but remain
        # cached, so none of them is scheduled for closing.
        for i in range(cache_max):
            await self.con._prepare('select {}'.format(i), use_cache=True)

        self.assertEqual(len(cache), cache_max)
        # `stmt` is evicted but still referenced, so it isn't scheduled
        # for closing yet either.
        self.assertEqual(len(self.con._stmts_to_close), 0)

        # Dropping the last reference to the evicted statement schedules
        # it for closing.
        del stmt
        gc.collect()

        self.assertEqual(len(cache), cache_max)
        self.assertEqual(len(self.con._stmts_to_close), 1)
242
243    async def test_prepare_13_connect(self):
244        v = await self.con.fetchval(
245            'SELECT $1::smallint AS foo', 10, column='foo')
246        self.assertEqual(v, 10)
247
248        r = await self.con.fetchrow('SELECT $1::smallint * 2 AS test', 10)
249        self.assertEqual(r['test'], 20)
250
251        rows = await self.con.fetch('SELECT generate_series(0,$1::int)', 3)
252        self.assertEqual([r[0] for r in rows], [0, 1, 2, 3])
253
254    async def test_prepare_14_explain(self):
255        # Test simple EXPLAIN.
256        stmt = await self.con.prepare('SELECT typname FROM pg_type')
257        plan = await stmt.explain()
258        self.assertEqual(plan[0]['Plan']['Relation Name'], 'pg_type')
259
260        # Test "EXPLAIN ANALYZE".
261        stmt = await self.con.prepare(
262            'SELECT typname, typlen FROM pg_type WHERE typlen > $1')
263        plan = await stmt.explain(2, analyze=True)
264        self.assertEqual(plan[0]['Plan']['Relation Name'], 'pg_type')
265        self.assertIn('Actual Total Time', plan[0]['Plan'])
266
267        # Test that 'EXPLAIN ANALYZE' is executed in a transaction
268        # that gets rollbacked.
269        tr = self.con.transaction()
270        await tr.start()
271        try:
272            await self.con.execute('CREATE TABLE mytab (a int)')
273            stmt = await self.con.prepare(
274                'INSERT INTO mytab (a) VALUES (1), (2)')
275            plan = await stmt.explain(analyze=True)
276            self.assertEqual(plan[0]['Plan']['Operation'], 'Insert')
277
278            # Check that no data was inserted
279            res = await self.con.fetch('SELECT * FROM mytab')
280            self.assertEqual(res, [])
281        finally:
282            await tr.rollback()
283
    async def test_prepare_15_stmt_gc_cache_disabled(self):
        """Even with the statement cache off, GCed statements are still
        cleaned up."""

        cache = self.con._stmt_cache

        self.assertEqual(len(cache), 0)
        self.assertEqual(len(self.con._stmts_to_close), 0)

        # Disable cache
        cache.set_max_size(0)

        # Nothing gets cached, and `stmt` is still referenced, so nothing
        # is scheduled for closing yet.
        stmt = await self.con._prepare('select 100000000', use_cache=True)
        self.assertEqual(len(cache), 0)
        self.assertEqual(len(self.con._stmts_to_close), 0)

        del stmt
        gc.collect()

        # After GC, _stmts_to_close should contain stmt's state
        self.assertEqual(len(cache), 0)
        self.assertEqual(len(self.con._stmts_to_close), 1)

        # Next "prepare" call will trigger a cleanup
        stmt = await self.con._prepare('select 1', use_cache=True)
        self.assertEqual(len(cache), 0)
        self.assertEqual(len(self.con._stmts_to_close), 0)

        # Drop the second statement explicitly (presumably so it is
        # released before teardown).
        del stmt
313
314    async def test_prepare_16_command_result(self):
315        async def status(query):
316            stmt = await self.con.prepare(query)
317            await stmt.fetch()
318            return stmt.get_statusmsg()
319
320        try:
321            self.assertEqual(
322                await status('CREATE TABLE mytab (a int)'),
323                'CREATE TABLE')
324
325            self.assertEqual(
326                await status('INSERT INTO mytab (a) VALUES (1), (2)'),
327                'INSERT 0 2')
328
329            self.assertEqual(
330                await status('SELECT a FROM mytab'),
331                'SELECT 2')
332
333            self.assertEqual(
334                await status('UPDATE mytab SET a = 3 WHERE a = 1'),
335                'UPDATE 1')
336        finally:
337            self.assertEqual(
338                await status('DROP TABLE mytab'),
339                'DROP TABLE')
340
341    async def test_prepare_17_stmt_closed_lru(self):
342        st = await self.con.prepare('SELECT 1')
343        st._state.mark_closed()
344        with self.assertRaisesRegex(asyncpg.InterfaceError, 'is closed'):
345            await st.fetch()
346
347        st = await self.con.prepare('SELECT 1')
348        self.assertEqual(await st.fetchval(), 1)
349
350    async def test_prepare_18_empty_result(self):
351        # test EmptyQueryResponse protocol message
352        st = await self.con.prepare('')
353        self.assertEqual(await st.fetch(), [])
354        self.assertIsNone(await st.fetchval())
355        self.assertIsNone(await st.fetchrow())
356
357        self.assertEqual(await self.con.fetch(''), [])
358        self.assertIsNone(await self.con.fetchval(''))
359        self.assertIsNone(await self.con.fetchrow(''))
360
361    async def test_prepare_19_concurrent_calls(self):
362        st = self.loop.create_task(self.con.fetchval(
363            'SELECT ROW(pg_sleep(0.1), 1)'))
364
365        # Wait for some time to make sure the first query is fully
366        # prepared (!) and is now awaiting the results (!!).
367        await asyncio.sleep(0.01)
368
369        with self.assertRaisesRegex(asyncpg.InterfaceError,
370                                    'another operation'):
371            await self.con.execute('SELECT 2')
372
373        self.assertEqual(await st, (None, 1))
374
375    async def test_prepare_20_concurrent_calls(self):
376        expected = ((None, 1),)
377
378        for methname, val in [('fetch', [expected]),
379                              ('fetchval', expected[0]),
380                              ('fetchrow', expected)]:
381
382            with self.subTest(meth=methname):
383
384                meth = getattr(self.con, methname)
385
386                vf = self.loop.create_task(
387                    meth('SELECT ROW(pg_sleep(0.1), 1)'))
388
389                await asyncio.sleep(0.01)
390
391                with self.assertRaisesRegex(asyncpg.InterfaceError,
392                                            'another operation'):
393                    await meth('SELECT 2')
394
395                self.assertEqual(await vf, val)
396
397    async def test_prepare_21_errors(self):
398        stmt = await self.con.prepare('SELECT 10 / $1::int')
399
400        with self.assertRaises(asyncpg.DivisionByZeroError):
401            await stmt.fetchval(0)
402
403        self.assertEqual(await stmt.fetchval(5), 2)
404
405    async def test_prepare_22_empty(self):
406        # Support for empty target list was added in PostgreSQL 9.4
407        if self.server_version < (9, 4):
408            raise unittest.SkipTest(
409                'PostgreSQL servers < 9.4 do not support empty target list.')
410
411        result = await self.con.fetchrow('SELECT')
412        self.assertEqual(result, ())
413        self.assertEqual(repr(result), '<Record>')
414
415    async def test_prepare_statement_invalid(self):
416        await self.con.execute('CREATE TABLE tab1(a int, b int)')
417
418        try:
419            await self.con.execute('INSERT INTO tab1 VALUES (1, 2)')
420
421            stmt = await self.con.prepare('SELECT * FROM tab1')
422
423            await self.con.execute(
424                'ALTER TABLE tab1 ALTER COLUMN b SET DATA TYPE text')
425
426            with self.assertRaisesRegex(asyncpg.InvalidCachedStatementError,
427                                        'cached statement plan is invalid'):
428                await stmt.fetchrow()
429
430        finally:
431            await self.con.execute('DROP TABLE tab1')
432
433    @tb.with_connection_options(statement_cache_size=0)
434    async def test_prepare_23_no_stmt_cache_seq(self):
435        self.assertEqual(self.con._stmt_cache.get_max_size(), 0)
436
437        async def check_simple():
438            # Run a simple query a few times.
439            self.assertEqual(await self.con.fetchval('SELECT 1'), 1)
440            self.assertEqual(await self.con.fetchval('SELECT 2'), 2)
441            self.assertEqual(await self.con.fetchval('SELECT 1'), 1)
442
443        await check_simple()
444
445        # Run a query that timeouts.
446        with self.assertRaises(asyncio.TimeoutError):
447            await self.con.fetchrow('select pg_sleep(10)', timeout=0.02)
448
449        # Check that we can run new queries after a timeout.
450        await check_simple()
451
452        # Try a cursor/timeout combination. Cursors should always use
453        # named prepared statements.
454        async with self.con.transaction():
455            with self.assertRaises(asyncio.TimeoutError):
456                async for _ in self.con.cursor(   # NOQA
457                        'select pg_sleep(10)', timeout=0.1):
458                    pass
459
460        # Check that we can run queries after a failed cursor
461        # operation.
462        await check_simple()
463
464    @tb.with_connection_options(max_cached_statement_lifetime=142)
465    async def test_prepare_24_max_lifetime(self):
466        cache = self.con._stmt_cache
467
468        self.assertEqual(cache.get_max_lifetime(), 142)
469        cache.set_max_lifetime(1)
470
471        s = await self.con._prepare('SELECT 1', use_cache=True)
472        state = s._state
473
474        s = await self.con._prepare('SELECT 1', use_cache=True)
475        self.assertIs(s._state, state)
476
477        s = await self.con._prepare('SELECT 1', use_cache=True)
478        self.assertIs(s._state, state)
479
480        await asyncio.sleep(1)
481
482        s = await self.con._prepare('SELECT 1', use_cache=True)
483        self.assertIsNot(s._state, state)
484
485    @tb.with_connection_options(max_cached_statement_lifetime=0.5)
486    async def test_prepare_25_max_lifetime_reset(self):
487        cache = self.con._stmt_cache
488
489        s = await self.con._prepare('SELECT 1', use_cache=True)
490        state = s._state
491
492        # Disable max_lifetime
493        cache.set_max_lifetime(0)
494
495        await asyncio.sleep(1)
496
497        # The statement should still be cached (as we disabled the timeout).
498        s = await self.con._prepare('SELECT 1', use_cache=True)
499        self.assertIs(s._state, state)
500
501    @tb.with_connection_options(max_cached_statement_lifetime=0.5)
502    async def test_prepare_26_max_lifetime_max_size(self):
503        cache = self.con._stmt_cache
504
505        s = await self.con._prepare('SELECT 1', use_cache=True)
506        state = s._state
507
508        # Disable max_lifetime
509        cache.set_max_size(0)
510
511        s = await self.con._prepare('SELECT 1', use_cache=True)
512        self.assertIsNot(s._state, state)
513
514        # Check that nothing crashes after the initial timeout
515        await asyncio.sleep(1)
516
517    @tb.with_connection_options(max_cacheable_statement_size=50)
518    async def test_prepare_27_max_cacheable_statement_size(self):
519        cache = self.con._stmt_cache
520
521        await self.con._prepare('SELECT 1', use_cache=True)
522        self.assertEqual(len(cache), 1)
523
524        # Test that long and explicitly created prepared statements
525        # are not cached.
526        await self.con._prepare("SELECT \'" + "a" * 50 + "\'", use_cache=True)
527        self.assertEqual(len(cache), 1)
528
529        # Test that implicitly created long prepared statements
530        # are not cached.
531        await self.con.fetchval("SELECT \'" + "a" * 50 + "\'")
532        self.assertEqual(len(cache), 1)
533
534        # Test that short prepared statements can still be cached.
535        await self.con._prepare('SELECT 2', use_cache=True)
536        self.assertEqual(len(cache), 2)
537
538    async def test_prepare_28_max_args(self):
539        N = 32768
540        args = ','.join('${}'.format(i) for i in range(1, N + 1))
541        query = 'SELECT ARRAY[{}]'.format(args)
542
543        with self.assertRaisesRegex(
544                exceptions.InterfaceError,
545                'the number of query arguments cannot exceed 32767'):
546            await self.con.fetchval(query, *range(1, N + 1))
547
548    async def test_prepare_29_duplicates(self):
549        # In addition to test_record.py, let's have a full functional
550        # test for records with duplicate keys.
551        r = await self.con.fetchrow('SELECT 1 as a, 2 as b, 3 as a')
552        self.assertEqual(list(r.items()), [('a', 1), ('b', 2), ('a', 3)])
553        self.assertEqual(list(r.keys()), ['a', 'b', 'a'])
554        self.assertEqual(list(r.values()), [1, 2, 3])
555        self.assertEqual(r['a'], 3)
556        self.assertEqual(r['b'], 2)
557        self.assertEqual(r[0], 1)
558        self.assertEqual(r[1], 2)
559        self.assertEqual(r[2], 3)
560
561    async def test_prepare_30_invalid_arg_count(self):
562        with self.assertRaisesRegex(
563                exceptions.InterfaceError,
564                'the server expects 1 argument for this query, 0 were passed'):
565            await self.con.fetchval('SELECT $1::int')
566
567        with self.assertRaisesRegex(
568                exceptions.InterfaceError,
569                'the server expects 0 arguments for this query, 1 was passed'):
570            await self.con.fetchval('SELECT 1', 1)
571
572    async def test_prepare_31_pgbouncer_note(self):
573        try:
574            await self.con.execute("""
575                DO $$ BEGIN
576                    RAISE EXCEPTION
577                        'duplicate statement' USING ERRCODE = '42P05';
578                END; $$ LANGUAGE plpgsql;
579            """)
580        except asyncpg.DuplicatePreparedStatementError as e:
581            self.assertTrue('pgbouncer' in e.hint)
582        else:
583            self.fail('DuplicatePreparedStatementError not raised')
584
585        try:
586            await self.con.execute("""
587                DO $$ BEGIN
588                    RAISE EXCEPTION
589                        'invalid statement' USING ERRCODE = '26000';
590                END; $$ LANGUAGE plpgsql;
591            """)
592        except asyncpg.InvalidSQLStatementNameError as e:
593            self.assertTrue('pgbouncer' in e.hint)
594        else:
595            self.fail('InvalidSQLStatementNameError not raised')
596
597    async def test_prepare_does_not_use_cache(self):
598        cache = self.con._stmt_cache
599
600        # prepare with disabled cache
601        await self.con.prepare('select 1')
602        self.assertEqual(len(cache), 0)
603
604    async def test_prepare_explicitly_named(self):
605        ps = await self.con.prepare('select 1', name='foobar')
606        self.assertEqual(ps.get_name(), 'foobar')
607        self.assertEqual(await self.con.fetchval('EXECUTE foobar'), 1)
608
609        with self.assertRaisesRegex(
610            exceptions.DuplicatePreparedStatementError,
611            'prepared statement "foobar" already exists',
612        ):
613            await self.con.prepare('select 1', name='foobar')
614