# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

"""Tests for prepared statements and the connection's statement cache."""


import asyncio
import asyncpg
import gc
import unittest

from asyncpg import _testbase as tb
from asyncpg import exceptions


class TestPrepare(tb.ConnectedTestCase):
    """Exercise ``Connection.prepare()``, prepared statement execution,
    and the LRU statement cache (size, lifetime and GC-driven cleanup)."""

    async def test_prepare_01(self):
        # prepare() itself must not bump the query counter; every
        # fetch on the prepared statement must.
        self.assertEqual(self.con._protocol.queries_count, 0)
        st = await self.con.prepare('SELECT 1 = $1 AS test')
        self.assertEqual(self.con._protocol.queries_count, 0)
        self.assertEqual(st.get_query(), 'SELECT 1 = $1 AS test')

        rec = await st.fetchrow(1)
        self.assertEqual(self.con._protocol.queries_count, 1)
        self.assertTrue(rec['test'])
        self.assertEqual(len(rec), 1)

        self.assertEqual(False, await st.fetchval(10))
        self.assertEqual(self.con._protocol.queries_count, 2)

    async def test_prepare_02(self):
        # Preparing an invalid query must surface the server error.
        with self.assertRaisesRegex(Exception, 'column "a" does not exist'):
            await self.con.prepare('SELECT a')

    async def test_prepare_03(self):
        # Typed parameters round-trip, and NULL arguments reach the
        # CASE branch that substitutes a per-type default.
        cases = [
            ('text', ("'NULL'", 'NULL'), [
                'aaa',
                None
            ]),

            ('decimal', ('0', 0), [
                123,
                123.5,
                None
            ])
        ]

        # `type_name` rather than `type` to avoid shadowing the builtin.
        for type_name, (none_name, none_val), vals in cases:
            st = await self.con.prepare('''
                SELECT CASE WHEN $1::{type} IS NULL THEN {default}
                ELSE $1::{type} END'''.format(
                    type=type_name, default=none_name))

            for val in vals:
                with self.subTest(type=type_name, value=val):
                    res = await st.fetchval(val)
                    if val is None:
                        self.assertEqual(res, none_val)
                    else:
                        self.assertEqual(res, val)

    async def test_prepare_04(self):
        s = await self.con.prepare('SELECT $1::smallint')
        self.assertEqual(await s.fetchval(10), 10)

        s = await self.con.prepare('SELECT $1::smallint * 2')
        self.assertEqual(await s.fetchval(10), 20)

        s = await self.con.prepare('SELECT generate_series(5,10)')
        self.assertEqual(await s.fetchval(), 5)
        # Since the "execute" message was sent with a limit=1,
        # we will receive a PortalSuspended message, instead of
        # CommandComplete. Which means there will be no status
        # message set.
        self.assertIsNone(s.get_statusmsg())
        # Repeat the same test for 'fetchrow()'.
        self.assertEqual(await s.fetchrow(), (5,))
        self.assertIsNone(s.get_statusmsg())

    async def test_prepare_05_unknownoid(self):
        # An untyped string literal must come back as a Python str.
        s = await self.con.prepare("SELECT 'test'")
        self.assertEqual(await s.fetchval(), 'test')

    async def test_prepare_06_interrupted_close(self):
        # Gracefully closing the connection while a statement is running
        # makes the pending fetch fail with QueryCanceledError.
        stmt = await self.con.prepare('''SELECT pg_sleep(10)''')
        fut = self.loop.create_task(stmt.fetch())

        await asyncio.sleep(0.2)

        self.assertFalse(self.con.is_closed())
        await self.con.close()
        self.assertTrue(self.con.is_closed())

        with self.assertRaises(asyncpg.QueryCanceledError):
            await fut

        # Test that it's OK to call close again
        await self.con.close()

    async def test_prepare_07_interrupted_terminate(self):
        # Terminating the connection mid-query makes the pending fetch
        # fail with ConnectionDoesNotExistError.
        stmt = await self.con.prepare('''SELECT pg_sleep(10)''')
        fut = self.loop.create_task(stmt.fetchval())

        await asyncio.sleep(0.2)

        self.assertFalse(self.con.is_closed())
        self.con.terminate()
        self.assertTrue(self.con.is_closed())

        with self.assertRaisesRegex(asyncpg.ConnectionDoesNotExistError,
                                    'closed in the middle'):
            await fut

        # Test that it's OK to call terminate again
        self.con.terminate()

    async def test_prepare_08_big_result(self):
        # A large result set must be returned completely and in order.
        stmt = await self.con.prepare('select generate_series(0,10000)')
        result = await stmt.fetch()

        self.assertEqual(len(result), 10001)
        self.assertEqual(
            [r[0] for r in result],
            list(range(10001)))

    async def test_prepare_09_raise_error(self):
        # Stress test ReadBuffer.read_cstr()
        msg = '0' * 1024 * 100
        query = """
        DO language plpgsql $$
        BEGIN
        RAISE EXCEPTION '{}';
        END
        $$;""".format(msg)

        stmt = await self.con.prepare(query)
        with self.assertRaisesRegex(asyncpg.RaiseError, msg):
            with tb.silence_asyncio_long_exec_warning():
                await stmt.fetchval()

    async def test_prepare_10_stmt_lru(self):
        cache = self.con._stmt_cache

        query = 'select {}'
        cache_max = cache.get_max_size()
        iter_max = cache_max * 2 + 11

        # First, we have no cached statements.
        self.assertEqual(len(cache), 0)

        stmts = []
        for i in range(iter_max):
            s = await self.con._prepare(query.format(i), use_cache=True)
            self.assertEqual(await s.fetchval(), i)
            stmts.append(s)

        # At this point our cache should be full.
        self.assertEqual(len(cache), cache_max)
        self.assertTrue(all(not s.closed for s in cache.iter_statements()))

        # Since there are references to the statements (`stmts` list),
        # no statements are scheduled to be closed.
        self.assertEqual(len(self.con._stmts_to_close), 0)

        # Removing refs to statements and preparing a new statement
        # will cause connection to cleanup any stale statements.
        stmts.clear()
        gc.collect()

        # Now we have a bunch of statements that have no refs to them
        # scheduled to be closed.
        self.assertEqual(len(self.con._stmts_to_close), iter_max - cache_max)
        self.assertTrue(all(s.closed for s in self.con._stmts_to_close))
        self.assertTrue(all(not s.closed for s in cache.iter_statements()))

        zero = await self.con.prepare(query.format(0))
        # Hence, all stale statements should be closed now.
        self.assertEqual(len(self.con._stmts_to_close), 0)

        # The number of cached statements will stay the same though.
        self.assertEqual(len(cache), cache_max)
        self.assertTrue(all(not s.closed for s in cache.iter_statements()))

        # After closing the connection, all statements will be closed.
        await self.con.close()
        self.assertEqual(len(self.con._stmts_to_close), 0)
        self.assertEqual(len(cache), 0)

        # An attempt to perform an operation on a closed statement
        # will trigger an error.
        with self.assertRaisesRegex(asyncpg.InterfaceError, 'is closed'):
            await zero.fetchval()

    async def test_prepare_11_stmt_gc(self):
        # Test that prepared statements should stay in the cache after
        # they are GCed.

        cache = self.con._stmt_cache

        # First, we have no cached statements.
        self.assertEqual(len(cache), 0)
        self.assertEqual(len(self.con._stmts_to_close), 0)

        # The prepared statement that we'll create will be GCed
        # right away. However, its state should still be
        # in the statements LRU cache.
        await self.con._prepare('select 1', use_cache=True)
        gc.collect()

        self.assertEqual(len(cache), 1)
        self.assertEqual(len(self.con._stmts_to_close), 0)

    async def test_prepare_12_stmt_gc(self):
        # Test that prepared statements are closed when there is no space
        # for them in the LRU cache and there are no references to them.

        cache = self.con._stmt_cache
        cache_max = cache.get_max_size()

        # First, we have no cached statements.
        self.assertEqual(len(cache), 0)
        self.assertEqual(len(self.con._stmts_to_close), 0)

        stmt = await self.con._prepare('select 100000000', use_cache=True)
        self.assertEqual(len(cache), 1)
        self.assertEqual(len(self.con._stmts_to_close), 0)

        # Fill the cache so that `stmt` gets evicted from it.
        for i in range(cache_max):
            await self.con._prepare('select {}'.format(i), use_cache=True)

        self.assertEqual(len(cache), cache_max)
        # `stmt` is still referenced, so it is not scheduled for closing.
        self.assertEqual(len(self.con._stmts_to_close), 0)

        del stmt
        gc.collect()

        self.assertEqual(len(cache), cache_max)
        self.assertEqual(len(self.con._stmts_to_close), 1)

    async def test_prepare_13_connect(self):
        # The Connection-level fetch* shortcuts accept query arguments.
        v = await self.con.fetchval(
            'SELECT $1::smallint AS foo', 10, column='foo')
        self.assertEqual(v, 10)

        r = await self.con.fetchrow('SELECT $1::smallint * 2 AS test', 10)
        self.assertEqual(r['test'], 20)

        rows = await self.con.fetch('SELECT generate_series(0,$1::int)', 3)
        self.assertEqual([r[0] for r in rows], [0, 1, 2, 3])

    async def test_prepare_14_explain(self):
        # Test simple EXPLAIN.
        stmt = await self.con.prepare('SELECT typname FROM pg_type')
        plan = await stmt.explain()
        self.assertEqual(plan[0]['Plan']['Relation Name'], 'pg_type')

        # Test "EXPLAIN ANALYZE".
        stmt = await self.con.prepare(
            'SELECT typname, typlen FROM pg_type WHERE typlen > $1')
        plan = await stmt.explain(2, analyze=True)
        self.assertEqual(plan[0]['Plan']['Relation Name'], 'pg_type')
        self.assertIn('Actual Total Time', plan[0]['Plan'])

        # Test that 'EXPLAIN ANALYZE' is executed in a transaction
        # that gets rolled back.
        tr = self.con.transaction()
        await tr.start()
        try:
            await self.con.execute('CREATE TABLE mytab (a int)')
            stmt = await self.con.prepare(
                'INSERT INTO mytab (a) VALUES (1), (2)')
            plan = await stmt.explain(analyze=True)
            self.assertEqual(plan[0]['Plan']['Operation'], 'Insert')

            # Check that no data was inserted
            res = await self.con.fetch('SELECT * FROM mytab')
            self.assertEqual(res, [])
        finally:
            await tr.rollback()

    async def test_prepare_15_stmt_gc_cache_disabled(self):
        # Test that even if the statements cache is off, we're still
        # cleaning up GCed statements.

        cache = self.con._stmt_cache

        self.assertEqual(len(cache), 0)
        self.assertEqual(len(self.con._stmts_to_close), 0)

        # Disable cache
        cache.set_max_size(0)

        stmt = await self.con._prepare('select 100000000', use_cache=True)
        self.assertEqual(len(cache), 0)
        self.assertEqual(len(self.con._stmts_to_close), 0)

        del stmt
        gc.collect()

        # After GC, _stmts_to_close should contain stmt's state
        self.assertEqual(len(cache), 0)
        self.assertEqual(len(self.con._stmts_to_close), 1)

        # Next "prepare" call will trigger a cleanup
        stmt = await self.con._prepare('select 1', use_cache=True)
        self.assertEqual(len(cache), 0)
        self.assertEqual(len(self.con._stmts_to_close), 0)

        del stmt

    async def test_prepare_16_command_result(self):
        # get_statusmsg() reports the server's command tag after fetch().
        async def status(query):
            stmt = await self.con.prepare(query)
            await stmt.fetch()
            return stmt.get_statusmsg()

        try:
            self.assertEqual(
                await status('CREATE TABLE mytab (a int)'),
                'CREATE TABLE')

            self.assertEqual(
                await status('INSERT INTO mytab (a) VALUES (1), (2)'),
                'INSERT 0 2')

            self.assertEqual(
                await status('SELECT a FROM mytab'),
                'SELECT 2')

            self.assertEqual(
                await status('UPDATE mytab SET a = 3 WHERE a = 1'),
                'UPDATE 1')
        finally:
            self.assertEqual(
                await status('DROP TABLE mytab'),
                'DROP TABLE')

    async def test_prepare_17_stmt_closed_lru(self):
        # A statement whose state was marked closed refuses to run, and
        # re-preparing the same query yields a working statement.
        st = await self.con.prepare('SELECT 1')
        st._state.mark_closed()
        with self.assertRaisesRegex(asyncpg.InterfaceError, 'is closed'):
            await st.fetch()

        st = await self.con.prepare('SELECT 1')
        self.assertEqual(await st.fetchval(), 1)

    async def test_prepare_18_empty_result(self):
        # test EmptyQueryResponse protocol message
        st = await self.con.prepare('')
        self.assertEqual(await st.fetch(), [])
        self.assertIsNone(await st.fetchval())
        self.assertIsNone(await st.fetchrow())

        self.assertEqual(await self.con.fetch(''), [])
        self.assertIsNone(await self.con.fetchval(''))
        self.assertIsNone(await self.con.fetchrow(''))

    async def test_prepare_19_concurrent_calls(self):
        task = self.loop.create_task(self.con.fetchval(
            'SELECT ROW(pg_sleep(0.1), 1)'))

        # Wait for some time to make sure the first query is fully
        # prepared (!) and is now awaiting the results (!!).
        await asyncio.sleep(0.01)

        with self.assertRaisesRegex(asyncpg.InterfaceError,
                                    'another operation'):
            await self.con.execute('SELECT 2')

        self.assertEqual(await task, (None, 1))

    async def test_prepare_20_concurrent_calls(self):
        # Same as test_prepare_19, for each of the fetch* methods.
        expected = ((None, 1),)

        for methname, val in [('fetch', [expected]),
                              ('fetchval', expected[0]),
                              ('fetchrow', expected)]:

            with self.subTest(meth=methname):

                meth = getattr(self.con, methname)

                vf = self.loop.create_task(
                    meth('SELECT ROW(pg_sleep(0.1), 1)'))

                await asyncio.sleep(0.01)

                with self.assertRaisesRegex(asyncpg.InterfaceError,
                                            'another operation'):
                    await meth('SELECT 2')

                self.assertEqual(await vf, val)

    async def test_prepare_21_errors(self):
        # A server-side error does not make the prepared statement unusable.
        stmt = await self.con.prepare('SELECT 10 / $1::int')

        with self.assertRaises(asyncpg.DivisionByZeroError):
            await stmt.fetchval(0)

        self.assertEqual(await stmt.fetchval(5), 2)

    async def test_prepare_22_empty(self):
        # Support for empty target list was added in PostgreSQL 9.4
        if self.server_version < (9, 4):
            raise unittest.SkipTest(
                'PostgreSQL servers < 9.4 do not support empty target list.')

        result = await self.con.fetchrow('SELECT')
        self.assertEqual(result, ())
        self.assertEqual(repr(result), '<Record>')

    async def test_prepare_statement_invalid(self):
        # Altering a column type after prepare() invalidates the plan,
        # which must surface as InvalidCachedStatementError.
        await self.con.execute('CREATE TABLE tab1(a int, b int)')

        try:
            await self.con.execute('INSERT INTO tab1 VALUES (1, 2)')

            stmt = await self.con.prepare('SELECT * FROM tab1')

            await self.con.execute(
                'ALTER TABLE tab1 ALTER COLUMN b SET DATA TYPE text')

            with self.assertRaisesRegex(asyncpg.InvalidCachedStatementError,
                                        'cached statement plan is invalid'):
                await stmt.fetchrow()

        finally:
            await self.con.execute('DROP TABLE tab1')

    @tb.with_connection_options(statement_cache_size=0)
    async def test_prepare_23_no_stmt_cache_seq(self):
        self.assertEqual(self.con._stmt_cache.get_max_size(), 0)

        async def check_simple():
            # Run a simple query a few times.
            self.assertEqual(await self.con.fetchval('SELECT 1'), 1)
            self.assertEqual(await self.con.fetchval('SELECT 2'), 2)
            self.assertEqual(await self.con.fetchval('SELECT 1'), 1)

        await check_simple()

        # Run a query that timeouts.
        with self.assertRaises(asyncio.TimeoutError):
            await self.con.fetchrow('select pg_sleep(10)', timeout=0.02)

        # Check that we can run new queries after a timeout.
        await check_simple()

        # Try a cursor/timeout combination. Cursors should always use
        # named prepared statements.
        async with self.con.transaction():
            with self.assertRaises(asyncio.TimeoutError):
                async for _ in self.con.cursor(  # NOQA
                        'select pg_sleep(10)', timeout=0.1):
                    pass

        # Check that we can run queries after a failed cursor
        # operation.
        await check_simple()

    @tb.with_connection_options(max_cached_statement_lifetime=142)
    async def test_prepare_24_max_lifetime(self):
        cache = self.con._stmt_cache

        self.assertEqual(cache.get_max_lifetime(), 142)
        cache.set_max_lifetime(1)

        s = await self.con._prepare('SELECT 1', use_cache=True)
        state = s._state

        # Within the lifetime the cached state is reused.
        s = await self.con._prepare('SELECT 1', use_cache=True)
        self.assertIs(s._state, state)

        s = await self.con._prepare('SELECT 1', use_cache=True)
        self.assertIs(s._state, state)

        await asyncio.sleep(1)

        # Once the lifetime has elapsed a fresh state is created.
        s = await self.con._prepare('SELECT 1', use_cache=True)
        self.assertIsNot(s._state, state)

    @tb.with_connection_options(max_cached_statement_lifetime=0.5)
    async def test_prepare_25_max_lifetime_reset(self):
        cache = self.con._stmt_cache

        s = await self.con._prepare('SELECT 1', use_cache=True)
        state = s._state

        # Disable max_lifetime
        cache.set_max_lifetime(0)

        await asyncio.sleep(1)

        # The statement should still be cached (as we disabled the timeout).
        s = await self.con._prepare('SELECT 1', use_cache=True)
        self.assertIs(s._state, state)

    @tb.with_connection_options(max_cached_statement_lifetime=0.5)
    async def test_prepare_26_max_lifetime_max_size(self):
        cache = self.con._stmt_cache

        s = await self.con._prepare('SELECT 1', use_cache=True)
        state = s._state

        # Disable the statement cache itself (max size 0) while the
        # lifetime timeout is still active.
        cache.set_max_size(0)

        s = await self.con._prepare('SELECT 1', use_cache=True)
        self.assertIsNot(s._state, state)

        # Check that nothing crashes after the initial timeout
        await asyncio.sleep(1)

    @tb.with_connection_options(max_cacheable_statement_size=50)
    async def test_prepare_27_max_cacheable_statement_size(self):
        cache = self.con._stmt_cache
        # Longer than the 50-character max_cacheable_statement_size.
        long_query = "SELECT '" + "a" * 50 + "'"

        await self.con._prepare('SELECT 1', use_cache=True)
        self.assertEqual(len(cache), 1)

        # Test that long and explicitly created prepared statements
        # are not cached.
        await self.con._prepare(long_query, use_cache=True)
        self.assertEqual(len(cache), 1)

        # Test that implicitly created long prepared statements
        # are not cached.
        await self.con.fetchval(long_query)
        self.assertEqual(len(cache), 1)

        # Test that short prepared statements can still be cached.
        await self.con._prepare('SELECT 2', use_cache=True)
        self.assertEqual(len(cache), 2)

    async def test_prepare_28_max_args(self):
        # One argument more than the protocol limit must be rejected
        # client-side with a clear message.
        N = 32768
        args = ','.join('${}'.format(i) for i in range(1, N + 1))
        query = 'SELECT ARRAY[{}]'.format(args)

        with self.assertRaisesRegex(
                exceptions.InterfaceError,
                'the number of query arguments cannot exceed 32767'):
            await self.con.fetchval(query, *range(1, N + 1))

    async def test_prepare_29_duplicates(self):
        # In addition to test_record.py, let's have a full functional
        # test for records with duplicate keys.
        r = await self.con.fetchrow('SELECT 1 as a, 2 as b, 3 as a')
        self.assertEqual(list(r.items()), [('a', 1), ('b', 2), ('a', 3)])
        self.assertEqual(list(r.keys()), ['a', 'b', 'a'])
        self.assertEqual(list(r.values()), [1, 2, 3])
        self.assertEqual(r['a'], 3)
        self.assertEqual(r['b'], 2)
        self.assertEqual(r[0], 1)
        self.assertEqual(r[1], 2)
        self.assertEqual(r[2], 3)

    async def test_prepare_30_invalid_arg_count(self):
        # Argument-count mismatches produce a grammatically correct error.
        with self.assertRaisesRegex(
                exceptions.InterfaceError,
                'the server expects 1 argument for this query, 0 were passed'):
            await self.con.fetchval('SELECT $1::int')

        with self.assertRaisesRegex(
                exceptions.InterfaceError,
                'the server expects 0 arguments for this query, 1 was passed'):
            await self.con.fetchval('SELECT 1', 1)

    async def test_prepare_31_pgbouncer_note(self):
        # Prepared-statement errors (SQLSTATE 42P05 and 26000) carry a
        # hint mentioning pgbouncer.
        with self.assertRaises(asyncpg.DuplicatePreparedStatementError) as cm:
            await self.con.execute("""
                DO $$ BEGIN
                    RAISE EXCEPTION
                        'duplicate statement' USING ERRCODE = '42P05';
                END; $$ LANGUAGE plpgsql;
            """)
        self.assertIn('pgbouncer', cm.exception.hint)

        with self.assertRaises(asyncpg.InvalidSQLStatementNameError) as cm:
            await self.con.execute("""
                DO $$ BEGIN
                    RAISE EXCEPTION
                        'invalid statement' USING ERRCODE = '26000';
                END; $$ LANGUAGE plpgsql;
            """)
        self.assertIn('pgbouncer', cm.exception.hint)

    async def test_prepare_does_not_use_cache(self):
        cache = self.con._stmt_cache

        # The public prepare() must not populate the statement cache.
        await self.con.prepare('select 1')
        self.assertEqual(len(cache), 0)

    async def test_prepare_explicitly_named(self):
        # An explicit name is used server-side and must be unique.
        ps = await self.con.prepare('select 1', name='foobar')
        self.assertEqual(ps.get_name(), 'foobar')
        self.assertEqual(await self.con.fetchval('EXECUTE foobar'), 1)

        with self.assertRaisesRegex(
            exceptions.DuplicatePreparedStatementError,
            'prepared statement "foobar" already exists',
        ):
            await self.con.prepare('select 1', name='foobar')