1# pysqlite2/test/dbapi.py: tests for DB-API compliance
2#
3# Copyright (C) 2004-2010 Gerhard Häring <gh@ghaering.de>
4#
5# This file is part of pysqlite.
6#
7# This software is provided 'as-is', without any express or implied
8# warranty.  In no event will the authors be held liable for any damages
9# arising from the use of this software.
10#
11# Permission is granted to anyone to use this software for any purpose,
12# including commercial applications, and to alter it and redistribute it
13# freely, subject to the following restrictions:
14#
15# 1. The origin of this software must not be misrepresented; you must not
16#    claim that you wrote the original software. If you use this software
17#    in a product, an acknowledgment in the product documentation would be
18#    appreciated but is not required.
19# 2. Altered source versions must be plainly marked as such, and must not be
20#    misrepresented as being the original software.
21# 3. This notice may not be removed or altered from any source distribution.
22
23import contextlib
24import sqlite3 as sqlite
25import subprocess
26import sys
27import threading
28import unittest
29
30from test.support import (
31    SHORT_TIMEOUT,
32    check_disallow_instantiation,
33    threading_helper,
34)
35from test.support.os_helper import TESTFN, unlink, temp_dir
36
37
38# Helper for tests using TESTFN
@contextlib.contextmanager
def managed_connect(*args, in_mem=False, **kwargs):
    """Yield a connection opened with ``sqlite.connect(*args, **kwargs)``.

    On exit the connection is always closed.  Unless *in_mem* is true, the
    TESTFN database file is then removed as well, so tests that open TESTFN
    do not leave it behind.
    """
    connection = sqlite.connect(*args, **kwargs)
    remove_file = not in_mem
    try:
        yield connection
    finally:
        # Close first: the file can only be unlinked reliably once it is
        # no longer held open.
        connection.close()
        if remove_file:
            unlink(TESTFN)
48
49
50# Helper for temporary memory databases
def memory_database(*args, **kwargs):
    """Return a context manager around a new ``:memory:`` connection.

    Extra positional/keyword arguments go to ``sqlite.connect`` after the
    database name; the connection is closed when the ``with`` block exits.
    """
    return contextlib.closing(sqlite.connect(":memory:", *args, **kwargs))
54
55
56# Temporarily limit a database connection parameter
57@contextlib.contextmanager
58def cx_limit(cx, category=sqlite.SQLITE_LIMIT_SQL_LENGTH, limit=128):
59    try:
60        _prev = cx.setlimit(category, limit)
61        yield limit
62    finally:
63        cx.setlimit(category, _prev)
64
65
class ModuleTests(unittest.TestCase):
    """Module-level DB-API attributes, exception hierarchy and constants."""

    def test_api_level(self):
        self.assertEqual(sqlite.apilevel, "2.0",
                         "apilevel is %s, should be 2.0" % sqlite.apilevel)

    def test_thread_safety(self):
        self.assertIn(sqlite.threadsafety, {0, 1, 3},
                      "threadsafety is %d, should be 0, 1 or 3" %
                      sqlite.threadsafety)

    def test_param_style(self):
        self.assertEqual(sqlite.paramstyle, "qmark",
                         "paramstyle is '%s', should be 'qmark'" %
                         sqlite.paramstyle)

    def test_warning(self):
        self.assertTrue(issubclass(sqlite.Warning, Exception),
                        "Warning is not a subclass of Exception")

    def test_error(self):
        self.assertTrue(issubclass(sqlite.Error, Exception),
                        "Error is not a subclass of Exception")

    def test_interface_error(self):
        self.assertTrue(issubclass(sqlite.InterfaceError, sqlite.Error),
                        "InterfaceError is not a subclass of Error")

    def test_database_error(self):
        self.assertTrue(issubclass(sqlite.DatabaseError, sqlite.Error),
                        "DatabaseError is not a subclass of Error")

    def test_data_error(self):
        self.assertTrue(issubclass(sqlite.DataError, sqlite.DatabaseError),
                        "DataError is not a subclass of DatabaseError")

    def test_operational_error(self):
        self.assertTrue(issubclass(sqlite.OperationalError, sqlite.DatabaseError),
                        "OperationalError is not a subclass of DatabaseError")

    def test_integrity_error(self):
        self.assertTrue(issubclass(sqlite.IntegrityError, sqlite.DatabaseError),
                        "IntegrityError is not a subclass of DatabaseError")

    def test_internal_error(self):
        self.assertTrue(issubclass(sqlite.InternalError, sqlite.DatabaseError),
                        "InternalError is not a subclass of DatabaseError")

    def test_programming_error(self):
        self.assertTrue(issubclass(sqlite.ProgrammingError, sqlite.DatabaseError),
                        "ProgrammingError is not a subclass of DatabaseError")

    def test_not_supported_error(self):
        self.assertTrue(issubclass(sqlite.NotSupportedError,
                                   sqlite.DatabaseError),
                        "NotSupportedError is not a subclass of DatabaseError")

    def test_module_constants(self):
        consts = [
            "SQLITE_ABORT",
            "SQLITE_ALTER_TABLE",
            "SQLITE_ANALYZE",
            "SQLITE_ATTACH",
            "SQLITE_AUTH",
            "SQLITE_BUSY",
            "SQLITE_CANTOPEN",
            "SQLITE_CONSTRAINT",
            "SQLITE_CORRUPT",
            "SQLITE_CREATE_INDEX",
            "SQLITE_CREATE_TABLE",
            "SQLITE_CREATE_TEMP_INDEX",
            "SQLITE_CREATE_TEMP_TABLE",
            "SQLITE_CREATE_TEMP_TRIGGER",
            "SQLITE_CREATE_TEMP_VIEW",
            "SQLITE_CREATE_TRIGGER",
            "SQLITE_CREATE_VIEW",
            "SQLITE_CREATE_VTABLE",
            "SQLITE_DELETE",
            "SQLITE_DENY",
            "SQLITE_DETACH",
            "SQLITE_DONE",
            "SQLITE_DROP_INDEX",
            "SQLITE_DROP_TABLE",
            "SQLITE_DROP_TEMP_INDEX",
            "SQLITE_DROP_TEMP_TABLE",
            "SQLITE_DROP_TEMP_TRIGGER",
            "SQLITE_DROP_TEMP_VIEW",
            "SQLITE_DROP_TRIGGER",
            "SQLITE_DROP_VIEW",
            "SQLITE_DROP_VTABLE",
            "SQLITE_EMPTY",
            "SQLITE_ERROR",
            "SQLITE_FORMAT",
            "SQLITE_FULL",
            "SQLITE_FUNCTION",
            "SQLITE_IGNORE",
            "SQLITE_INSERT",
            "SQLITE_INTERNAL",
            "SQLITE_INTERRUPT",
            "SQLITE_IOERR",
            "SQLITE_LOCKED",
            "SQLITE_MISMATCH",
            "SQLITE_MISUSE",
            "SQLITE_NOLFS",
            "SQLITE_NOMEM",
            "SQLITE_NOTADB",
            "SQLITE_NOTFOUND",
            "SQLITE_OK",
            "SQLITE_PERM",
            "SQLITE_PRAGMA",
            "SQLITE_PROTOCOL",
            "SQLITE_RANGE",
            "SQLITE_READ",
            "SQLITE_READONLY",
            "SQLITE_REINDEX",
            "SQLITE_ROW",
            "SQLITE_SAVEPOINT",
            "SQLITE_SCHEMA",
            "SQLITE_SELECT",
            "SQLITE_TOOBIG",
            "SQLITE_TRANSACTION",
            "SQLITE_UPDATE",
            # Run-time limit categories
            "SQLITE_LIMIT_LENGTH",
            "SQLITE_LIMIT_SQL_LENGTH",
            "SQLITE_LIMIT_COLUMN",
            "SQLITE_LIMIT_EXPR_DEPTH",
            "SQLITE_LIMIT_COMPOUND_SELECT",
            "SQLITE_LIMIT_VDBE_OP",
            "SQLITE_LIMIT_FUNCTION_ARG",
            "SQLITE_LIMIT_ATTACHED",
            "SQLITE_LIMIT_LIKE_PATTERN_LENGTH",
            "SQLITE_LIMIT_VARIABLE_NUMBER",
            "SQLITE_LIMIT_TRIGGER_DEPTH",
        ]
        # All version gates below compare against sqlite_version_info, the
        # version of the SQLite *library*.  (sqlite.version_info is the
        # sqlite3 module's own version number and must not be used here:
        # it would make every extended-result-code branch dead code.)
        sqlite_version = sqlite.sqlite_version_info
        if sqlite_version >= (3, 7, 17):
            consts += ["SQLITE_NOTICE", "SQLITE_WARNING"]
        if sqlite_version >= (3, 8, 3):
            consts.append("SQLITE_RECURSIVE")
        if sqlite_version >= (3, 8, 7):
            consts.append("SQLITE_LIMIT_WORKER_THREADS")
        consts += ["PARSE_DECLTYPES", "PARSE_COLNAMES"]
        # Extended result codes
        consts += [
            "SQLITE_ABORT_ROLLBACK",
            "SQLITE_BUSY_RECOVERY",
            "SQLITE_CANTOPEN_FULLPATH",
            "SQLITE_CANTOPEN_ISDIR",
            "SQLITE_CANTOPEN_NOTEMPDIR",
            "SQLITE_CORRUPT_VTAB",
            "SQLITE_IOERR_ACCESS",
            "SQLITE_IOERR_BLOCKED",
            "SQLITE_IOERR_CHECKRESERVEDLOCK",
            "SQLITE_IOERR_CLOSE",
            "SQLITE_IOERR_DELETE",
            "SQLITE_IOERR_DELETE_NOENT",
            "SQLITE_IOERR_DIR_CLOSE",
            "SQLITE_IOERR_DIR_FSYNC",
            "SQLITE_IOERR_FSTAT",
            "SQLITE_IOERR_FSYNC",
            "SQLITE_IOERR_LOCK",
            "SQLITE_IOERR_NOMEM",
            "SQLITE_IOERR_RDLOCK",
            "SQLITE_IOERR_READ",
            "SQLITE_IOERR_SEEK",
            "SQLITE_IOERR_SHMLOCK",
            "SQLITE_IOERR_SHMMAP",
            "SQLITE_IOERR_SHMOPEN",
            "SQLITE_IOERR_SHMSIZE",
            "SQLITE_IOERR_SHORT_READ",
            "SQLITE_IOERR_TRUNCATE",
            "SQLITE_IOERR_UNLOCK",
            "SQLITE_IOERR_WRITE",
            "SQLITE_LOCKED_SHAREDCACHE",
            "SQLITE_READONLY_CANTLOCK",
            "SQLITE_READONLY_RECOVERY",
        ]
        if sqlite_version >= (3, 7, 16):
            consts += [
                "SQLITE_CONSTRAINT_CHECK",
                "SQLITE_CONSTRAINT_COMMITHOOK",
                "SQLITE_CONSTRAINT_FOREIGNKEY",
                "SQLITE_CONSTRAINT_FUNCTION",
                "SQLITE_CONSTRAINT_NOTNULL",
                "SQLITE_CONSTRAINT_PRIMARYKEY",
                "SQLITE_CONSTRAINT_TRIGGER",
                "SQLITE_CONSTRAINT_UNIQUE",
                "SQLITE_CONSTRAINT_VTAB",
                "SQLITE_READONLY_ROLLBACK",
            ]
        if sqlite_version >= (3, 7, 17):
            consts += [
                "SQLITE_IOERR_MMAP",
                "SQLITE_NOTICE_RECOVER_ROLLBACK",
                "SQLITE_NOTICE_RECOVER_WAL",
            ]
        if sqlite_version >= (3, 8, 0):
            consts += [
                "SQLITE_BUSY_SNAPSHOT",
                "SQLITE_IOERR_GETTEMPPATH",
                "SQLITE_WARNING_AUTOINDEX",
            ]
        if sqlite_version >= (3, 8, 1):
            consts += ["SQLITE_CANTOPEN_CONVPATH", "SQLITE_IOERR_CONVPATH"]
        if sqlite_version >= (3, 8, 2):
            consts.append("SQLITE_CONSTRAINT_ROWID")
        if sqlite_version >= (3, 8, 3):
            consts.append("SQLITE_READONLY_DBMOVED")
        if sqlite_version >= (3, 8, 7):
            consts.append("SQLITE_AUTH_USER")
        if sqlite_version >= (3, 9, 0):
            consts.append("SQLITE_IOERR_VNODE")
        if sqlite_version >= (3, 10, 0):
            consts.append("SQLITE_IOERR_AUTH")
        if sqlite_version >= (3, 14, 1):
            consts.append("SQLITE_OK_LOAD_PERMANENTLY")
        if sqlite_version >= (3, 21, 0):
            consts += [
                "SQLITE_IOERR_BEGIN_ATOMIC",
                "SQLITE_IOERR_COMMIT_ATOMIC",
                "SQLITE_IOERR_ROLLBACK_ATOMIC",
            ]
        if sqlite_version >= (3, 22, 0):
            consts += [
                "SQLITE_ERROR_MISSING_COLLSEQ",
                "SQLITE_ERROR_RETRY",
                "SQLITE_READONLY_CANTINIT",
                "SQLITE_READONLY_DIRECTORY",
            ]
        if sqlite_version >= (3, 24, 0):
            consts += ["SQLITE_CORRUPT_SEQUENCE", "SQLITE_LOCKED_VTAB"]
        if sqlite_version >= (3, 25, 0):
            consts += ["SQLITE_CANTOPEN_DIRTYWAL", "SQLITE_ERROR_SNAPSHOT"]
        if sqlite_version >= (3, 31, 0):
            consts += [
                "SQLITE_CANTOPEN_SYMLINK",
                "SQLITE_CONSTRAINT_PINNED",
                "SQLITE_OK_SYMLINK",
            ]
        if sqlite_version >= (3, 32, 0):
            consts += [
                "SQLITE_BUSY_TIMEOUT",
                "SQLITE_CORRUPT_INDEX",
                "SQLITE_IOERR_DATA",
            ]
        if sqlite_version >= (3, 34, 0):
            consts.append("SQLITE_IOERR_CORRUPTFS")
        for const in consts:
            with self.subTest(const=const):
                self.assertTrue(hasattr(sqlite, const))

    def test_error_code_on_exception(self):
        err_msg = "unable to open database file"
        if sys.platform.startswith("win"):
            err_code = sqlite.SQLITE_CANTOPEN_ISDIR
        else:
            err_code = sqlite.SQLITE_CANTOPEN

        with temp_dir() as db:
            with self.assertRaisesRegex(sqlite.Error, err_msg) as cm:
                sqlite.connect(db)
            e = cm.exception
            self.assertEqual(e.sqlite_errorcode, err_code)
            self.assertTrue(e.sqlite_errorname.startswith("SQLITE_CANTOPEN"))

    # SQLITE_CONSTRAINT_CHECK exists from SQLite 3.7.16 on, so only skip
    # strictly older versions (`<=` would wrongly skip 3.7.16 itself).
    @unittest.skipIf(sqlite.sqlite_version_info < (3, 7, 16),
                     "Requires SQLite 3.7.16 or newer")
    def test_extended_error_code_on_exception(self):
        with managed_connect(":memory:", in_mem=True) as con:
            with con:
                con.execute("create table t(t integer check(t > 0))")
            errmsg = "constraint failed"
            with self.assertRaisesRegex(sqlite.IntegrityError, errmsg) as cm:
                con.execute("insert into t values(-1)")
            exc = cm.exception
            self.assertEqual(exc.sqlite_errorcode,
                             sqlite.SQLITE_CONSTRAINT_CHECK)
            self.assertEqual(exc.sqlite_errorname, "SQLITE_CONSTRAINT_CHECK")

    # sqlite3_enable_shared_cache() is deprecated on macOS and calling it may raise
    # OperationalError on some buildbots.
    @unittest.skipIf(sys.platform == "darwin", "shared cache is deprecated on macOS")
    def test_shared_cache_deprecated(self):
        for enable in (True, False):
            with self.assertWarns(DeprecationWarning) as cm:
                sqlite.enable_shared_cache(enable)
            self.assertIn("dbapi.py", cm.filename)

    def test_disallow_instantiation(self):
        # Calling a connection with SQL yields a statement object; its type
        # must not be instantiable directly.  Close the connection afterwards
        # instead of leaking it.
        with memory_database() as cx:
            check_disallow_instantiation(self, type(cx("select 1")))

    def test_complete_statement(self):
        self.assertFalse(sqlite.complete_statement("select t"))
        self.assertTrue(sqlite.complete_statement("create table t(t);"))
360
361
class ConnectionTests(unittest.TestCase):
    """Tests for Connection: transactions, closing, limits and re-init."""

    def setUp(self):
        self.cx = sqlite.connect(":memory:")
        cu = self.cx.cursor()
        cu.execute("create table test(id integer primary key, name text)")
        cu.execute("insert into test(name) values (?)", ("foo",))

    def tearDown(self):
        self.cx.close()

    def test_commit(self):
        self.cx.commit()

    def test_commit_after_no_changes(self):
        """
        A commit should also work when no changes were made to the database.
        """
        self.cx.commit()
        self.cx.commit()

    def test_rollback(self):
        self.cx.rollback()

    def test_rollback_after_no_changes(self):
        """
        A rollback should also work when no changes were made to the database.
        """
        self.cx.rollback()
        self.cx.rollback()

    def test_cursor(self):
        cu = self.cx.cursor()
        self.assertIsInstance(cu, sqlite.Cursor)

    def test_failed_open(self):
        YOU_CANNOT_OPEN_THIS = "/foo/bar/bla/23534/mydb.db"
        with self.assertRaises(sqlite.OperationalError):
            sqlite.connect(YOU_CANNOT_OPEN_THIS)

    def test_close(self):
        self.cx.close()

    def test_use_after_close(self):
        sql = "select 1"
        cu = self.cx.cursor()
        res = cu.execute(sql)
        self.cx.close()
        self.assertRaises(sqlite.ProgrammingError, res.fetchall)
        self.assertRaises(sqlite.ProgrammingError, cu.execute, sql)
        self.assertRaises(sqlite.ProgrammingError, cu.executemany, sql, [])
        self.assertRaises(sqlite.ProgrammingError, cu.executescript, sql)
        self.assertRaises(sqlite.ProgrammingError, self.cx.execute, sql)
        self.assertRaises(sqlite.ProgrammingError,
                          self.cx.executemany, sql, [])
        self.assertRaises(sqlite.ProgrammingError, self.cx.executescript, sql)
        self.assertRaises(sqlite.ProgrammingError,
                          self.cx.create_function, "t", 1, lambda x: x)
        self.assertRaises(sqlite.ProgrammingError, self.cx.cursor)
        with self.assertRaises(sqlite.ProgrammingError):
            with self.cx:
                pass

    def test_exceptions(self):
        # Optional DB-API extension.
        self.assertEqual(self.cx.Warning, sqlite.Warning)
        self.assertEqual(self.cx.Error, sqlite.Error)
        self.assertEqual(self.cx.InterfaceError, sqlite.InterfaceError)
        self.assertEqual(self.cx.DatabaseError, sqlite.DatabaseError)
        self.assertEqual(self.cx.DataError, sqlite.DataError)
        self.assertEqual(self.cx.OperationalError, sqlite.OperationalError)
        self.assertEqual(self.cx.IntegrityError, sqlite.IntegrityError)
        self.assertEqual(self.cx.InternalError, sqlite.InternalError)
        self.assertEqual(self.cx.ProgrammingError, sqlite.ProgrammingError)
        self.assertEqual(self.cx.NotSupportedError, sqlite.NotSupportedError)

    def test_in_transaction(self):
        # Can't use db from setUp because we want to test initial state.
        cx = sqlite.connect(":memory:")
        self.addCleanup(cx.close)
        cu = cx.cursor()
        self.assertEqual(cx.in_transaction, False)
        cu.execute("create table transactiontest(id integer primary key, name text)")
        self.assertEqual(cx.in_transaction, False)
        cu.execute("insert into transactiontest(name) values (?)", ("foo",))
        self.assertEqual(cx.in_transaction, True)
        cu.execute("select name from transactiontest where name=?", ["foo"])
        cu.fetchone()
        self.assertEqual(cx.in_transaction, True)
        cx.commit()
        self.assertEqual(cx.in_transaction, False)
        cu.execute("select name from transactiontest where name=?", ["foo"])
        cu.fetchone()
        self.assertEqual(cx.in_transaction, False)

    def test_in_transaction_ro(self):
        with self.assertRaises(AttributeError):
            self.cx.in_transaction = True

    def test_connection_exceptions(self):
        # Every DB-API exception class must be reachable from the connection
        # and be the very same object as the module-level one.
        exceptions = [
            "DataError",
            "DatabaseError",
            "Error",
            "IntegrityError",
            "InterfaceError",
            "InternalError",
            "NotSupportedError",
            "OperationalError",
            "ProgrammingError",
            "Warning",
        ]
        for exc in exceptions:
            with self.subTest(exc=exc):
                self.assertTrue(hasattr(self.cx, exc))
                self.assertIs(getattr(sqlite, exc), getattr(self.cx, exc))

    def test_interrupt_on_closed_db(self):
        cx = sqlite.connect(":memory:")
        cx.close()
        with self.assertRaises(sqlite.ProgrammingError):
            cx.interrupt()

    def test_interrupt(self):
        self.assertIsNone(self.cx.interrupt())

    def test_drop_unused_refs(self):
        for n in range(500):
            cu = self.cx.execute(f"select {n}")
            self.assertEqual(cu.fetchone()[0], n)

    def test_connection_limits(self):
        category = sqlite.SQLITE_LIMIT_SQL_LENGTH
        saved_limit = self.cx.getlimit(category)
        try:
            new_limit = 10
            prev_limit = self.cx.setlimit(category, new_limit)
            self.assertEqual(saved_limit, prev_limit)
            self.assertEqual(self.cx.getlimit(category), new_limit)
            msg = "query string is too large"
            self.assertRaisesRegex(sqlite.DataError, msg,
                                   self.cx.execute, "select 1 as '16'")
        finally:  # restore saved limit
            self.cx.setlimit(category, saved_limit)

    def test_connection_bad_limit_category(self):
        msg = "'category' is out of bounds"
        cat = 1111
        self.assertRaisesRegex(sqlite.ProgrammingError, msg,
                               self.cx.getlimit, cat)
        self.assertRaisesRegex(sqlite.ProgrammingError, msg,
                               self.cx.setlimit, cat, 0)

    def test_connection_init_bad_isolation_level(self):
        msg = (
            "isolation_level string must be '', 'DEFERRED', 'IMMEDIATE', or "
            "'EXCLUSIVE'"
        )
        levels = (
            "BOGUS",
            " ",
            "DEFERRE",
            "IMMEDIAT",
            "EXCLUSIV",
            "DEFERREDS",
            "IMMEDIATES",
            "EXCLUSIVES",
        )
        for level in levels:
            with self.subTest(level=level):
                with self.assertRaisesRegex(ValueError, msg):
                    memory_database(isolation_level=level)
                with memory_database() as cx:
                    with self.assertRaisesRegex(ValueError, msg):
                        cx.isolation_level = level
                    # Check that the default level is not changed
                    self.assertEqual(cx.isolation_level, "")

    def test_connection_init_good_isolation_levels(self):
        for level in ("", "DEFERRED", "IMMEDIATE", "EXCLUSIVE", None):
            with self.subTest(level=level):
                with memory_database(isolation_level=level) as cx:
                    self.assertEqual(cx.isolation_level, level)
                with memory_database() as cx:
                    self.assertEqual(cx.isolation_level, "")
                    cx.isolation_level = level
                    self.assertEqual(cx.isolation_level, level)

    def test_connection_reinit(self):
        db = ":memory:"
        cx = sqlite.connect(db)
        # Re-running __init__ below leaves cx initialised, so closing it at
        # cleanup time is valid.
        self.addCleanup(cx.close)
        cx.text_factory = bytes
        cx.row_factory = sqlite.Row
        cu = cx.cursor()
        cu.execute("create table foo (bar)")
        cu.executemany("insert into foo (bar) values (?)",
                       ((str(v),) for v in range(4)))
        cu.execute("select bar from foo")

        rows = [r for r in cu.fetchmany(2)]
        self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
        self.assertEqual([r[0] for r in rows], [b"0", b"1"])

        cx.__init__(db)
        cx.execute("create table foo (bar)")
        cx.executemany("insert into foo (bar) values (?)",
                       ((v,) for v in ("a", "b", "c", "d")))

        # This uses the old database, old row factory, but new text factory
        rows = [r for r in cu.fetchall()]
        self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
        self.assertEqual([r[0] for r in rows], ["2", "3"])

    def test_connection_bad_reinit(self):
        cx = sqlite.connect(":memory:")
        with cx:
            cx.execute("create table t(t)")
        with temp_dir() as db:
            self.assertRaisesRegex(sqlite.OperationalError,
                                   "unable to open database file",
                                   cx.__init__, db)
            # The failed re-init leaves cx uninitialised (checked below), so
            # cx.close() would itself raise ProgrammingError; it is
            # deliberately not closed here.
            self.assertRaisesRegex(sqlite.ProgrammingError,
                                   "Base Connection.__init__ not called",
                                   cx.executemany, "insert into t values(?)",
                                   ((v,) for v in range(3)))
584
585
class UninitialisedConnectionTests(unittest.TestCase):
    """A Connection whose __init__ never ran must refuse every operation."""

    def setUp(self):
        # Allocating via __new__ alone skips Connection.__init__ entirely.
        self.cx = sqlite.Connection.__new__(sqlite.Connection)

    def test_uninit_operations(self):
        expected = "Base Connection.__init__ not called"
        probes = [
            lambda: self.cx.isolation_level,
            lambda: self.cx.total_changes,
            lambda: self.cx.in_transaction,
            lambda: self.cx.iterdump(),
            lambda: self.cx.cursor(),
            lambda: self.cx.close(),
        ]
        for probe in probes:
            with self.subTest(func=probe):
                with self.assertRaisesRegex(sqlite.ProgrammingError, expected):
                    probe()
604
605
class OpenTests(unittest.TestCase):
    """Opening databases via path-like objects, URIs and keyword arguments."""

    _sql = "create table test(id integer)"

    def test_open_with_path_like_object(self):
        """ Checks that we can successfully connect to a database using an object that
            is PathLike, i.e. has __fspath__(). """
        class Path:
            def __fspath__(self):
                return TESTFN
        path = Path()
        with managed_connect(path) as cx:
            cx.execute(self._sql)

    def test_open_uri(self):
        with managed_connect(TESTFN) as cx:
            cx.execute(self._sql)
        with managed_connect(f"file:{TESTFN}", uri=True) as cx:
            cx.execute(self._sql)
        # A read-only URI must not be usable to create or modify a database.
        # NOTE(review): managed_connect() unlinks TESTFN on exit, so this
        # likely fails already at open time rather than at the write.
        with self.assertRaises(sqlite.OperationalError):
            with managed_connect(f"file:{TESTFN}?mode=ro", uri=True) as cx:
                cx.execute(self._sql)

    def test_database_keyword(self):
        # A Connection used as a context manager only commits/rolls back; it
        # does not close the connection, so wrap it in closing() to avoid
        # leaking it.
        with contextlib.closing(sqlite.connect(database=":memory:")) as cx:
            self.assertEqual(type(cx), sqlite.Connection)
631
632
633class CursorTests(unittest.TestCase):
634    def setUp(self):
635        self.cx = sqlite.connect(":memory:")
636        self.cu = self.cx.cursor()
637        self.cu.execute(
638            "create table test(id integer primary key, name text, "
639            "income number, unique_test text unique)"
640        )
641        self.cu.execute("insert into test(name) values (?)", ("foo",))
642
643    def tearDown(self):
644        self.cu.close()
645        self.cx.close()
646
647    def test_execute_no_args(self):
648        self.cu.execute("delete from test")
649
650    def test_execute_illegal_sql(self):
651        with self.assertRaises(sqlite.OperationalError):
652            self.cu.execute("select asdf")
653
654    def test_execute_too_much_sql(self):
655        with self.assertRaises(sqlite.Warning):
656            self.cu.execute("select 5+4; select 4+5")
657
658    def test_execute_too_much_sql2(self):
659        self.cu.execute("select 5+4; -- foo bar")
660
661    def test_execute_too_much_sql3(self):
662        self.cu.execute("""
663            select 5+4;
664
665            /*
666            foo
667            */
668            """)
669
670    def test_execute_wrong_sql_arg(self):
671        with self.assertRaises(TypeError):
672            self.cu.execute(42)
673
674    def test_execute_arg_int(self):
675        self.cu.execute("insert into test(id) values (?)", (42,))
676
677    def test_execute_arg_float(self):
678        self.cu.execute("insert into test(income) values (?)", (2500.32,))
679
680    def test_execute_arg_string(self):
681        self.cu.execute("insert into test(name) values (?)", ("Hugo",))
682
683    def test_execute_arg_string_with_zero_byte(self):
684        self.cu.execute("insert into test(name) values (?)", ("Hu\x00go",))
685
686        self.cu.execute("select name from test where id=?", (self.cu.lastrowid,))
687        row = self.cu.fetchone()
688        self.assertEqual(row[0], "Hu\x00go")
689
690    def test_execute_non_iterable(self):
691        with self.assertRaises(ValueError) as cm:
692            self.cu.execute("insert into test(id) values (?)", 42)
693        self.assertEqual(str(cm.exception), 'parameters are of unsupported type')
694
695    def test_execute_wrong_no_of_args1(self):
696        # too many parameters
697        with self.assertRaises(sqlite.ProgrammingError):
698            self.cu.execute("insert into test(id) values (?)", (17, "Egon"))
699
700    def test_execute_wrong_no_of_args2(self):
701        # too little parameters
702        with self.assertRaises(sqlite.ProgrammingError):
703            self.cu.execute("insert into test(id) values (?)")
704
705    def test_execute_wrong_no_of_args3(self):
706        # no parameters, parameters are needed
707        with self.assertRaises(sqlite.ProgrammingError):
708            self.cu.execute("insert into test(id) values (?)")
709
710    def test_execute_param_list(self):
711        self.cu.execute("insert into test(name) values ('foo')")
712        self.cu.execute("select name from test where name=?", ["foo"])
713        row = self.cu.fetchone()
714        self.assertEqual(row[0], "foo")
715
716    def test_execute_param_sequence(self):
717        class L:
718            def __len__(self):
719                return 1
720            def __getitem__(self, x):
721                assert x == 0
722                return "foo"
723
724        self.cu.execute("insert into test(name) values ('foo')")
725        self.cu.execute("select name from test where name=?", L())
726        row = self.cu.fetchone()
727        self.assertEqual(row[0], "foo")
728
729    def test_execute_param_sequence_bad_len(self):
730        # Issue41662: Error in __len__() was overridden with ProgrammingError.
731        class L:
732            def __len__(self):
733                1/0
734            def __getitem__(slf, x):
735                raise AssertionError
736
737        self.cu.execute("insert into test(name) values ('foo')")
738        with self.assertRaises(ZeroDivisionError):
739            self.cu.execute("select name from test where name=?", L())
740
741    def test_execute_too_many_params(self):
742        category = sqlite.SQLITE_LIMIT_VARIABLE_NUMBER
743        msg = "too many SQL variables"
744        with cx_limit(self.cx, category=category, limit=1):
745            self.cu.execute("select * from test where id=?", (1,))
746            with self.assertRaisesRegex(sqlite.OperationalError, msg):
747                self.cu.execute("select * from test where id!=? and id!=?",
748                                (1, 2))
749
750    def test_execute_dict_mapping(self):
751        self.cu.execute("insert into test(name) values ('foo')")
752        self.cu.execute("select name from test where name=:name", {"name": "foo"})
753        row = self.cu.fetchone()
754        self.assertEqual(row[0], "foo")
755
756    def test_execute_dict_mapping_mapping(self):
757        class D(dict):
758            def __missing__(self, key):
759                return "foo"
760
761        self.cu.execute("insert into test(name) values ('foo')")
762        self.cu.execute("select name from test where name=:name", D())
763        row = self.cu.fetchone()
764        self.assertEqual(row[0], "foo")
765
766    def test_execute_dict_mapping_too_little_args(self):
767        self.cu.execute("insert into test(name) values ('foo')")
768        with self.assertRaises(sqlite.ProgrammingError):
769            self.cu.execute("select name from test where name=:name and id=:id", {"name": "foo"})
770
771    def test_execute_dict_mapping_no_args(self):
772        self.cu.execute("insert into test(name) values ('foo')")
773        with self.assertRaises(sqlite.ProgrammingError):
774            self.cu.execute("select name from test where name=:name")
775
776    def test_execute_dict_mapping_unnamed(self):
777        self.cu.execute("insert into test(name) values ('foo')")
778        with self.assertRaises(sqlite.ProgrammingError):
779            self.cu.execute("select name from test where name=?", {"name": "foo"})
780
781    def test_close(self):
782        self.cu.close()
783
784    def test_rowcount_execute(self):
785        self.cu.execute("delete from test")
786        self.cu.execute("insert into test(name) values ('foo')")
787        self.cu.execute("insert into test(name) values ('foo')")
788        self.cu.execute("update test set name='bar'")
789        self.assertEqual(self.cu.rowcount, 2)
790
791    def test_rowcount_select(self):
792        """
793        pysqlite does not know the rowcount of SELECT statements, because we
794        don't fetch all rows after executing the select statement. The rowcount
795        has thus to be -1.
796        """
797        self.cu.execute("select 5 union select 6")
798        self.assertEqual(self.cu.rowcount, -1)
799
800    def test_rowcount_executemany(self):
801        self.cu.execute("delete from test")
802        self.cu.executemany("insert into test(name) values (?)", [(1,), (2,), (3,)])
803        self.assertEqual(self.cu.rowcount, 3)
804
805    def test_total_changes(self):
806        self.cu.execute("insert into test(name) values ('foo')")
807        self.cu.execute("insert into test(name) values ('foo')")
808        self.assertLess(2, self.cx.total_changes, msg='total changes reported wrong value')
809
    # Checks for executemany:
    # Sequences are required by the DB-API; accepting iterators is
    # an enhancement provided by pysqlite.
813
814    def test_execute_many_sequence(self):
815        self.cu.executemany("insert into test(income) values (?)", [(x,) for x in range(100, 110)])
816
817    def test_execute_many_iterator(self):
818        class MyIter:
819            def __init__(self):
820                self.value = 5
821
822            def __iter__(self):
823                return self
824
825            def __next__(self):
826                if self.value == 10:
827                    raise StopIteration
828                else:
829                    self.value += 1
830                    return (self.value,)
831
832        self.cu.executemany("insert into test(income) values (?)", MyIter())
833
834    def test_execute_many_generator(self):
835        def mygen():
836            for i in range(5):
837                yield (i,)
838
839        self.cu.executemany("insert into test(income) values (?)", mygen())
840
841    def test_execute_many_wrong_sql_arg(self):
842        with self.assertRaises(TypeError):
843            self.cu.executemany(42, [(3,)])
844
845    def test_execute_many_select(self):
846        with self.assertRaises(sqlite.ProgrammingError):
847            self.cu.executemany("select ?", [(3,)])
848
849    def test_execute_many_not_iterable(self):
850        with self.assertRaises(TypeError):
851            self.cu.executemany("insert into test(income) values (?)", 42)
852
853    def test_fetch_iter(self):
854        # Optional DB-API extension.
855        self.cu.execute("delete from test")
856        self.cu.execute("insert into test(id) values (?)", (5,))
857        self.cu.execute("insert into test(id) values (?)", (6,))
858        self.cu.execute("select id from test order by id")
859        lst = []
860        for row in self.cu:
861            lst.append(row[0])
862        self.assertEqual(lst[0], 5)
863        self.assertEqual(lst[1], 6)
864
865    def test_fetchone(self):
866        self.cu.execute("select name from test")
867        row = self.cu.fetchone()
868        self.assertEqual(row[0], "foo")
869        row = self.cu.fetchone()
870        self.assertEqual(row, None)
871
872    def test_fetchone_no_statement(self):
873        cur = self.cx.cursor()
874        row = cur.fetchone()
875        self.assertEqual(row, None)
876
877    def test_array_size(self):
878        # must default to 1
879        self.assertEqual(self.cu.arraysize, 1)
880
881        # now set to 2
882        self.cu.arraysize = 2
883
884        # now make the query return 3 rows
885        self.cu.execute("delete from test")
886        self.cu.execute("insert into test(name) values ('A')")
887        self.cu.execute("insert into test(name) values ('B')")
888        self.cu.execute("insert into test(name) values ('C')")
889        self.cu.execute("select name from test")
890        res = self.cu.fetchmany()
891
892        self.assertEqual(len(res), 2)
893
894    def test_fetchmany(self):
895        self.cu.execute("select name from test")
896        res = self.cu.fetchmany(100)
897        self.assertEqual(len(res), 1)
898        res = self.cu.fetchmany(100)
899        self.assertEqual(res, [])
900
901    def test_fetchmany_kw_arg(self):
902        """Checks if fetchmany works with keyword arguments"""
903        self.cu.execute("select name from test")
904        res = self.cu.fetchmany(size=100)
905        self.assertEqual(len(res), 1)
906
907    def test_fetchall(self):
908        self.cu.execute("select name from test")
909        res = self.cu.fetchall()
910        self.assertEqual(len(res), 1)
911        res = self.cu.fetchall()
912        self.assertEqual(res, [])
913
914    def test_setinputsizes(self):
915        self.cu.setinputsizes([3, 4, 5])
916
917    def test_setoutputsize(self):
918        self.cu.setoutputsize(5, 0)
919
920    def test_setoutputsize_no_column(self):
921        self.cu.setoutputsize(42)
922
923    def test_cursor_connection(self):
924        # Optional DB-API extension.
925        self.assertEqual(self.cu.connection, self.cx)
926
927    def test_wrong_cursor_callable(self):
928        with self.assertRaises(TypeError):
929            def f(): pass
930            cur = self.cx.cursor(f)
931
932    def test_cursor_wrong_class(self):
933        class Foo: pass
934        foo = Foo()
935        with self.assertRaises(TypeError):
936            cur = sqlite.Cursor(foo)
937
938    def test_last_row_id_on_replace(self):
939        """
940        INSERT OR REPLACE and REPLACE INTO should produce the same behavior.
941        """
942        sql = '{} INTO test(id, unique_test) VALUES (?, ?)'
943        for statement in ('INSERT OR REPLACE', 'REPLACE'):
944            with self.subTest(statement=statement):
945                self.cu.execute(sql.format(statement), (1, 'foo'))
946                self.assertEqual(self.cu.lastrowid, 1)
947
948    def test_last_row_id_on_ignore(self):
949        self.cu.execute(
950            "insert or ignore into test(unique_test) values (?)",
951            ('test',))
952        self.assertEqual(self.cu.lastrowid, 2)
953        self.cu.execute(
954            "insert or ignore into test(unique_test) values (?)",
955            ('test',))
956        self.assertEqual(self.cu.lastrowid, 2)
957
958    def test_last_row_id_insert_o_r(self):
959        results = []
960        for statement in ('FAIL', 'ABORT', 'ROLLBACK'):
961            sql = 'INSERT OR {} INTO test(unique_test) VALUES (?)'
962            with self.subTest(statement='INSERT OR {}'.format(statement)):
963                self.cu.execute(sql.format(statement), (statement,))
964                results.append((statement, self.cu.lastrowid))
965                with self.assertRaises(sqlite.IntegrityError):
966                    self.cu.execute(sql.format(statement), (statement,))
967                results.append((statement, self.cu.lastrowid))
968        expected = [
969            ('FAIL', 2), ('FAIL', 2),
970            ('ABORT', 3), ('ABORT', 3),
971            ('ROLLBACK', 4), ('ROLLBACK', 4),
972        ]
973        self.assertEqual(results, expected)
974
975    def test_column_count(self):
976        # Check that column count is updated correctly for cached statements
977        select = "select * from test"
978        res = self.cu.execute(select)
979        old_count = len(res.description)
980        # Add a new column and execute the cached select query again
981        self.cu.execute("alter table test add newcol")
982        res = self.cu.execute(select)
983        new_count = len(res.description)
984        self.assertEqual(new_count - old_count, 1)
985
986    def test_same_query_in_multiple_cursors(self):
987        cursors = [self.cx.execute("select 1") for _ in range(3)]
988        for cu in cursors:
989            self.assertEqual(cu.fetchall(), [(1,)])
990
991
992class ThreadTests(unittest.TestCase):
993    def setUp(self):
994        self.con = sqlite.connect(":memory:")
995        self.cur = self.con.cursor()
996        self.cur.execute("create table test(name text)")
997
998    def tearDown(self):
999        self.cur.close()
1000        self.con.close()
1001
1002    @threading_helper.reap_threads
1003    def _run_test(self, fn, *args, **kwds):
1004        def run(err):
1005            try:
1006                fn(*args, **kwds)
1007                err.append("did not raise ProgrammingError")
1008            except sqlite.ProgrammingError:
1009                pass
1010            except:
1011                err.append("raised wrong exception")
1012
1013        err = []
1014        t = threading.Thread(target=run, kwargs={"err": err})
1015        t.start()
1016        t.join()
1017        if err:
1018            self.fail("\n".join(err))
1019
1020    def test_check_connection_thread(self):
1021        fns = [
1022            lambda: self.con.cursor(),
1023            lambda: self.con.commit(),
1024            lambda: self.con.rollback(),
1025            lambda: self.con.close(),
1026            lambda: self.con.set_trace_callback(None),
1027            lambda: self.con.set_authorizer(None),
1028            lambda: self.con.create_collation("foo", None),
1029            lambda: self.con.setlimit(sqlite.SQLITE_LIMIT_LENGTH, -1),
1030            lambda: self.con.getlimit(sqlite.SQLITE_LIMIT_LENGTH),
1031        ]
1032        for fn in fns:
1033            with self.subTest(fn=fn):
1034                self._run_test(fn)
1035
1036    def test_check_cursor_thread(self):
1037        fns = [
1038            lambda: self.cur.execute("insert into test(name) values('a')"),
1039            lambda: self.cur.close(),
1040            lambda: self.cur.execute("select name from test"),
1041            lambda: self.cur.fetchone(),
1042        ]
1043        for fn in fns:
1044            with self.subTest(fn=fn):
1045                self._run_test(fn)
1046
1047
1048    @threading_helper.reap_threads
1049    def test_dont_check_same_thread(self):
1050        def run(con, err):
1051            try:
1052                con.execute("select 1")
1053            except sqlite.Error:
1054                err.append("multi-threading not allowed")
1055
1056        con = sqlite.connect(":memory:", check_same_thread=False)
1057        err = []
1058        t = threading.Thread(target=run, kwargs={"con": con, "err": err})
1059        t.start()
1060        t.join()
1061        self.assertEqual(len(err), 0, "\n".join(err))
1062
1063
class ConstructorTests(unittest.TestCase):
    """The DB-API type constructors exposed by the sqlite3 module."""

    def test_date(self):
        d = sqlite.Date(2004, 10, 28)
        self.assertEqual((d.year, d.month, d.day), (2004, 10, 28))

    def test_time(self):
        t = sqlite.Time(12, 39, 35)
        self.assertEqual((t.hour, t.minute, t.second), (12, 39, 35))

    def test_timestamp(self):
        ts = sqlite.Timestamp(2004, 10, 28, 12, 39, 35)
        self.assertEqual(
            (ts.year, ts.month, ts.day, ts.hour, ts.minute, ts.second),
            (2004, 10, 28, 12, 39, 35),
        )

    # The *FromTicks results depend on the local timezone, so only their
    # types are checked.
    def test_date_from_ticks(self):
        d = sqlite.DateFromTicks(42)
        self.assertIsInstance(d, sqlite.Date)

    def test_time_from_ticks(self):
        t = sqlite.TimeFromTicks(42)
        self.assertIsInstance(t, sqlite.Time)

    def test_timestamp_from_ticks(self):
        ts = sqlite.TimestampFromTicks(42)
        self.assertIsInstance(ts, sqlite.Timestamp)

    def test_binary(self):
        b = sqlite.Binary(b"\0'")
        self.assertEqual(bytes(b), b"\0'")
1085
class ExtensionTests(unittest.TestCase):
    """Tests for executescript() and the Connection.execute* shortcuts."""

    def _connect(self):
        """Return a new in-memory connection that is closed after the test."""
        con = sqlite.connect(":memory:")
        self.addCleanup(con.close)
        return con

    def test_script_string_sql(self):
        con = self._connect()
        cur = con.cursor()
        cur.executescript("""
            -- bla bla
            /* a stupid comment */
            create table a(i);
            insert into a(i) values (5);
            """)
        cur.execute("select i from a")
        res = cur.fetchone()[0]
        self.assertEqual(res, 5)

    def test_script_syntax_error(self):
        con = self._connect()
        cur = con.cursor()
        with self.assertRaises(sqlite.OperationalError):
            cur.executescript("create table test(x); asdf; create table test2(x)")

    def test_script_error_normal(self):
        con = self._connect()
        cur = con.cursor()
        with self.assertRaises(sqlite.OperationalError):
            cur.executescript("create table test(sadfsadfdsa); select foo from hurz;")

    def test_cursor_executescript_as_bytes(self):
        con = self._connect()
        cur = con.cursor()
        with self.assertRaises(TypeError):
            cur.executescript(b"create table test(foo); insert into test(foo) values (5);")

    def test_cursor_executescript_with_null_characters(self):
        con = self._connect()
        cur = con.cursor()
        with self.assertRaises(ValueError):
            cur.executescript("""
                create table a(i);\0
                insert into a(i) values (5);
                """)

    def test_cursor_executescript_with_surrogates(self):
        con = self._connect()
        cur = con.cursor()
        with self.assertRaises(UnicodeEncodeError):
            cur.executescript("""
                create table a(s);
                insert into a(s) values ('\ud8ff');
                """)

    def test_cursor_executescript_too_large_script(self):
        msg = "query string is too large"
        with memory_database() as cx, cx_limit(cx) as lim:
            cx.executescript("select 'almost too large'".ljust(lim))
            with self.assertRaisesRegex(sqlite.DataError, msg):
                cx.executescript("select 'too large'".ljust(lim+1))

    def test_cursor_executescript_tx_control(self):
        con = self._connect()
        con.execute("begin")
        self.assertTrue(con.in_transaction)
        # executescript() commits the pending transaction before running.
        con.executescript("select 1")
        self.assertFalse(con.in_transaction)

    def test_connection_execute(self):
        con = self._connect()
        result = con.execute("select 5").fetchone()[0]
        self.assertEqual(result, 5, "Basic test of Connection.execute")

    def test_connection_executemany(self):
        con = self._connect()
        con.execute("create table test(foo)")
        con.executemany("insert into test(foo) values (?)", [(3,), (4,)])
        result = con.execute("select foo from test order by foo").fetchall()
        self.assertEqual(result[0][0], 3, "Basic test of Connection.executemany")
        self.assertEqual(result[1][0], 4, "Basic test of Connection.executemany")

    def test_connection_executescript(self):
        con = self._connect()
        con.executescript("create table test(foo); insert into test(foo) values (5);")
        result = con.execute("select foo from test").fetchone()[0]
        self.assertEqual(result, 5, "Basic test of Connection.executescript")
1168
class ClosedConTests(unittest.TestCase):
    """Operations on a closed connection raise ProgrammingError."""

    @staticmethod
    def _closed_connection():
        """Open an in-memory connection and close it straight away."""
        con = sqlite.connect(":memory:")
        con.close()
        return con

    def test_closed_con_cursor(self):
        con = self._closed_connection()
        with self.assertRaises(sqlite.ProgrammingError):
            con.cursor()

    def test_closed_con_commit(self):
        con = self._closed_connection()
        with self.assertRaises(sqlite.ProgrammingError):
            con.commit()

    def test_closed_con_rollback(self):
        con = self._closed_connection()
        with self.assertRaises(sqlite.ProgrammingError):
            con.rollback()

    def test_closed_cur_execute(self):
        # The cursor must exist before its connection is closed.
        con = sqlite.connect(":memory:")
        cur = con.cursor()
        con.close()
        with self.assertRaises(sqlite.ProgrammingError):
            cur.execute("select 4")

    def test_closed_create_function(self):
        def f(x):
            return 17

        con = self._closed_connection()
        with self.assertRaises(sqlite.ProgrammingError):
            con.create_function("foo", 1, f)

    def test_closed_create_aggregate(self):
        class Agg:
            def step(self, x):
                pass

            def finalize(self):
                return 17

        con = self._closed_connection()
        with self.assertRaises(sqlite.ProgrammingError):
            con.create_aggregate("foo", 1, Agg)

    def test_closed_set_authorizer(self):
        def authorizer(*args):
            return sqlite.DENY

        con = self._closed_connection()
        with self.assertRaises(sqlite.ProgrammingError):
            con.set_authorizer(authorizer)

    def test_closed_set_progress_callback(self):
        def progress():
            pass

        con = self._closed_connection()
        with self.assertRaises(sqlite.ProgrammingError):
            con.set_progress_handler(progress, 100)

    def test_closed_call(self):
        con = self._closed_connection()
        with self.assertRaises(sqlite.ProgrammingError):
            con()
1235
class ClosedCurTests(unittest.TestCase):
    """Statement and fetch methods of a closed cursor raise ProgrammingError."""

    def test_closed(self):
        con = sqlite.connect(":memory:")
        # Only the cursor is closed below; release the connection as well.
        self.addCleanup(con.close)
        cur = con.cursor()
        cur.close()

        calls = {
            "execute": ("select 4 union select 5",),
            "executemany": ("insert into foo(bar) values (?)", [(3,), (4,)]),
            "executescript": ("select 4 union select 5",),
            "fetchall": (),
            "fetchmany": (),
            "fetchone": (),
        }
        for method_name, params in calls.items():
            # subTest reports each failing method on its own instead of
            # stopping at the first one.
            with self.subTest(method=method_name):
                method = getattr(cur, method_name)
                with self.assertRaises(sqlite.ProgrammingError):
                    method(*params)
1253
1254
class SqliteOnConflictTests(unittest.TestCase):
    """
    Tests for SQLite's "insert on conflict" feature.

    See https://www.sqlite.org/lang_conflict.html for details.
    """

    def setUp(self):
        self.cx = sqlite.connect(":memory:")
        self.cu = self.cx.cursor()
        self.cu.execute("""
          CREATE TABLE test(
            id INTEGER PRIMARY KEY, name TEXT, unique_name TEXT UNIQUE
          );
        """)

    def tearDown(self):
        self.cu.close()
        self.cx.close()

    def _use_autocommit(self):
        """Switch the connection to autocommit mode and use a fresh cursor.

        The cursor created in setUp() is closed first, so replacing self.cu
        does not leave it open (tearDown only closes the current self.cu).
        """
        self.cx.isolation_level = None  # autocommit mode
        self.cu.close()
        self.cu = self.cx.cursor()

    def test_on_conflict_rollback_with_explicit_transaction(self):
        self._use_autocommit()
        # Start an explicit transaction.
        self.cu.execute("BEGIN")
        self.cu.execute("INSERT INTO test(name) VALUES ('abort_test')")
        self.cu.execute("INSERT OR ROLLBACK INTO test(unique_name) VALUES ('foo')")
        with self.assertRaises(sqlite.IntegrityError):
            self.cu.execute("INSERT OR ROLLBACK INTO test(unique_name) VALUES ('foo')")
        # Use connection to commit.
        self.cx.commit()
        self.cu.execute("SELECT name, unique_name from test")
        # Transaction should have rolled back and nothing should be in table.
        self.assertEqual(self.cu.fetchall(), [])

    def test_on_conflict_abort_raises_with_explicit_transactions(self):
        # Abort cancels the current sql statement but doesn't change anything
        # about the current transaction.
        self._use_autocommit()
        # Start an explicit transaction.
        self.cu.execute("BEGIN")
        self.cu.execute("INSERT INTO test(name) VALUES ('abort_test')")
        self.cu.execute("INSERT OR ABORT INTO test(unique_name) VALUES ('foo')")
        with self.assertRaises(sqlite.IntegrityError):
            self.cu.execute("INSERT OR ABORT INTO test(unique_name) VALUES ('foo')")
        self.cx.commit()
        self.cu.execute("SELECT name, unique_name FROM test")
        # Expect the first two inserts to work, third to do nothing.
        self.assertEqual(self.cu.fetchall(), [('abort_test', None), (None, 'foo',)])

    def test_on_conflict_rollback_without_transaction(self):
        # Start of implicit transaction
        self.cu.execute("INSERT INTO test(name) VALUES ('abort_test')")
        self.cu.execute("INSERT OR ROLLBACK INTO test(unique_name) VALUES ('foo')")
        with self.assertRaises(sqlite.IntegrityError):
            self.cu.execute("INSERT OR ROLLBACK INTO test(unique_name) VALUES ('foo')")
        self.cu.execute("SELECT name, unique_name FROM test")
        # Implicit transaction is rolled back on error.
        self.assertEqual(self.cu.fetchall(), [])

    def test_on_conflict_abort_raises_without_transactions(self):
        # Abort cancels the current sql statement but doesn't change anything
        # about the current transaction.
        self.cu.execute("INSERT INTO test(name) VALUES ('abort_test')")
        self.cu.execute("INSERT OR ABORT INTO test(unique_name) VALUES ('foo')")
        with self.assertRaises(sqlite.IntegrityError):
            self.cu.execute("INSERT OR ABORT INTO test(unique_name) VALUES ('foo')")
        # Make sure all other values were inserted.
        self.cu.execute("SELECT name, unique_name FROM test")
        self.assertEqual(self.cu.fetchall(), [('abort_test', None), (None, 'foo',)])

    def test_on_conflict_fail(self):
        self.cu.execute("INSERT OR FAIL INTO test(unique_name) VALUES ('foo')")
        with self.assertRaises(sqlite.IntegrityError):
            self.cu.execute("INSERT OR FAIL INTO test(unique_name) VALUES ('foo')")
        self.assertEqual(self.cu.fetchall(), [])
        # FAIL does not back out earlier changes: the first row must remain.
        self.cu.execute("SELECT unique_name FROM test")
        self.assertEqual(self.cu.fetchall(), [('foo',)])

    def test_on_conflict_ignore(self):
        self.cu.execute("INSERT OR IGNORE INTO test(unique_name) VALUES ('foo')")
        # Nothing should happen.
        self.cu.execute("INSERT OR IGNORE INTO test(unique_name) VALUES ('foo')")
        self.cu.execute("SELECT unique_name FROM test")
        self.assertEqual(self.cu.fetchall(), [('foo',)])

    def test_on_conflict_replace(self):
        self.cu.execute("INSERT OR REPLACE INTO test(name, unique_name) VALUES ('Data!', 'foo')")
        # There shouldn't be an IntegrityError exception.
        self.cu.execute("INSERT OR REPLACE INTO test(name, unique_name) VALUES ('Very different data!', 'foo')")
        self.cu.execute("SELECT name, unique_name FROM test")
        self.assertEqual(self.cu.fetchall(), [('Very different data!', 'foo')])
1346
1347
class MultiprocessTests(unittest.TestCase):
    """Tests in which a child Python process shares the TESTFN database."""

    # Busy timeout used by both parent and child connections.
    CONNECTION_TIMEOUT = SHORT_TIMEOUT / 1000.  # Defaults to 30 ms

    def tearDown(self):
        # Remove the on-disk database file used by parent and child.
        unlink(TESTFN)

    def test_ctx_mgr_rollback_if_commit_failed(self):
        # bpo-27334: ctx manager does not rollback if commit fails
        #
        # Child script overview: it creates table t, opens a transaction that
        # reads from t, then blocks in wait() until the parent writes a line
        # containing "database is locked". After rolling back, it starts a
        # second transaction that, per its inline comment, would fail if the
        # parent still held a lock.
        SCRIPT = f"""if 1:
            import sqlite3
            def wait():
                print("started")
                assert "database is locked" in input()

            cx = sqlite3.connect("{TESTFN}", timeout={self.CONNECTION_TIMEOUT})
            cx.create_function("wait", 0, wait)
            with cx:
                cx.execute("create table t(t)")
            try:
                # execute two transactions; both will try to lock the db
                cx.executescript('''
                    -- start a transaction and wait for parent
                    begin transaction;
                    select * from t;
                    select wait();
                    rollback;

                    -- start a new transaction; would fail if parent holds lock
                    begin transaction;
                    select * from t;
                    rollback;
                ''')
            finally:
                cx.close()
        """

        # spawn child process
        # bufsize=0 makes the parent's stdin writes reach the child unbuffered.
        proc = subprocess.Popen(
            [sys.executable, "-c", SCRIPT],
            encoding="utf-8",
            bufsize=0,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
        )
        # Ensure the child is reaped even if an assertion below fails.
        self.addCleanup(proc.communicate)

        # wait for child process to start
        # "started" is printed by the child's wait(), i.e. once it is inside
        # its first transaction.
        self.assertEqual("started", proc.stdout.readline().strip())

        cx = sqlite.connect(TESTFN, timeout=self.CONNECTION_TIMEOUT)
        try:  # context manager should correctly release the db lock
            with cx:
                cx.execute("insert into t values('test')")
        except sqlite.OperationalError as exc:
            # Forward the error text; the child asserts it mentions
            # "database is locked".
            proc.stdin.write(str(exc))
        else:
            proc.stdin.write("no error")
        finally:
            cx.close()

        # terminate child process
        # The child must still be running here (blocked reading stdin).
        self.assertIsNone(proc.returncode)
        try:
            # Sending "end" and closing stdin completes the child's input()
            # call, letting it finish both transactions.
            proc.communicate(input="end", timeout=SHORT_TIMEOUT)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.communicate()
            raise
        # A zero exit code means the child's assert and its second
        # transaction both succeeded.
        self.assertEqual(proc.returncode, 0)
1417
1418
if __name__ == "__main__":
    # Running this file directly executes the test cases it defines.
    unittest.main(module="__main__")
1421