1#-*- coding: iso-8859-1 -*-
2# pysqlite2/test/transactions.py: tests transactions
3#
# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
5#
6# This file is part of pysqlite.
7#
8# This software is provided 'as-is', without any express or implied
9# warranty.  In no event will the authors be held liable for any damages
10# arising from the use of this software.
11#
12# Permission is granted to anyone to use this software for any purpose,
13# including commercial applications, and to alter it and redistribute it
14# freely, subject to the following restrictions:
15#
16# 1. The origin of this software must not be misrepresented; you must not
17#    claim that you wrote the original software. If you use this software
18#    in a product, an acknowledgment in the product documentation would be
19#    appreciated but is not required.
20# 2. Altered source versions must be plainly marked as such, and must not be
21#    misrepresented as being the original software.
22# 3. This notice may not be removed or altered from any source distribution.
23
24import os, unittest
25import sqlite3 as sqlite
26
def get_db_path():
    """Return the filename of the on-disk database used by TransactionTests.

    The name is relative, so the file is created in the current working
    directory.
    """
    db_filename = "sqlite_testdb"
    return db_filename
29
class TransactionTests(unittest.TestCase):
    """Transaction behaviour observed through two connections to one database.

    ``con1``/``cur1`` issue the writes; ``con2``/``cur2`` observe which of
    those writes have (or have not) become visible to another connection.
    """

    @staticmethod
    def _remove_db_file():
        # Best effort: the file is absent on a fresh run and may already be
        # gone, so a failed removal is deliberately ignored.
        try:
            os.unlink(get_db_path())
        except OSError:
            pass

    def setUp(self):
        self._remove_db_file()

        # The short timeout keeps the lock-contention checks below fast: a
        # blocked connection gives up after 0.1 seconds.
        self.con1 = sqlite.connect(get_db_path(), timeout=0.1)
        self.cur1 = self.con1.cursor()

        self.con2 = sqlite.connect(get_db_path(), timeout=0.1)
        self.cur2 = self.con2.cursor()

    def tearDown(self):
        self.cur1.close()
        self.con1.close()

        self.cur2.close()
        self.con2.close()

        self._remove_db_file()

    def CheckDMLDoesNotAutoCommitBefore(self):
        # A DDL statement issued after an INSERT must not commit the pending
        # insert, so con2 still sees no rows.
        self.cur1.execute("create table test(i)")
        self.cur1.execute("insert into test(i) values (5)")
        self.cur1.execute("create table test2(j)")
        self.cur2.execute("select i from test")
        res = self.cur2.fetchall()
        self.assertEqual(len(res), 0)

    def CheckInsertStartsTransaction(self):
        # The INSERT is not yet committed, so it is invisible to con2.
        self.cur1.execute("create table test(i)")
        self.cur1.execute("insert into test(i) values (5)")
        self.cur2.execute("select i from test")
        res = self.cur2.fetchall()
        self.assertEqual(len(res), 0)

    def CheckUpdateStartsTransaction(self):
        # con2 keeps reading the committed value until con1's UPDATE commits.
        self.cur1.execute("create table test(i)")
        self.cur1.execute("insert into test(i) values (5)")
        self.con1.commit()
        self.cur1.execute("update test set i=6")
        self.cur2.execute("select i from test")
        res = self.cur2.fetchone()[0]
        self.assertEqual(res, 5)

    def CheckDeleteStartsTransaction(self):
        # The uncommitted DELETE must leave the committed row visible to con2.
        self.cur1.execute("create table test(i)")
        self.cur1.execute("insert into test(i) values (5)")
        self.con1.commit()
        self.cur1.execute("delete from test")
        self.cur2.execute("select i from test")
        res = self.cur2.fetchall()
        self.assertEqual(len(res), 1)

    def CheckReplaceStartsTransaction(self):
        # con2 must see only the committed row, not the pending REPLACE.
        self.cur1.execute("create table test(i)")
        self.cur1.execute("insert into test(i) values (5)")
        self.con1.commit()
        self.cur1.execute("replace into test(i) values (6)")
        self.cur2.execute("select i from test")
        res = self.cur2.fetchall()
        self.assertEqual(len(res), 1)
        self.assertEqual(res[0][0], 5)

    def CheckToggleAutoCommit(self):
        self.cur1.execute("create table test(i)")
        self.cur1.execute("insert into test(i) values (5)")

        # Switching con1 to autocommit makes its pending insert visible.
        self.con1.isolation_level = None
        self.assertIsNone(self.con1.isolation_level)
        self.cur2.execute("select i from test")
        res = self.cur2.fetchall()
        self.assertEqual(len(res), 1)

        # Back in transactional mode, a new insert stays invisible to con2.
        self.con1.isolation_level = "DEFERRED"
        self.assertEqual(self.con1.isolation_level, "DEFERRED")
        self.cur1.execute("insert into test(i) values (5)")
        self.cur2.execute("select i from test")
        res = self.cur2.fetchall()
        self.assertEqual(len(res), 1)

    @unittest.skipIf(sqlite.sqlite_version_info < (3, 2, 2),
                     'test hangs on sqlite versions older than 3.2.2')
    def CheckRaiseTimeout(self):
        # con1's uncommitted insert keeps the database locked, so con2's
        # insert gives up after the connection timeout.
        self.cur1.execute("create table test(i)")
        self.cur1.execute("insert into test(i) values (5)")
        with self.assertRaises(sqlite.OperationalError):
            self.cur2.execute("insert into test(i) values (5)")

    @unittest.skipIf(sqlite.sqlite_version_info < (3, 2, 2),
                     'test hangs on sqlite versions older than 3.2.2')
    def CheckLocking(self):
        """
        This tests the improved concurrency introduced in pysqlite 2.3.4.
        Before that version, con2 had to be rolled back before con1 could
        commit.
        """
        self.cur1.execute("create table test(i)")
        self.cur1.execute("insert into test(i) values (5)")
        with self.assertRaises(sqlite.OperationalError):
            self.cur2.execute("insert into test(i) values (5)")
        # NO self.con2.rollback() HERE!!!
        self.con1.commit()

    def CheckRollbackCursorConsistency(self):
        """
        Checks if cursors on the connection are set into a "reset" state
        when a rollback is done on the connection.
        """
        con = sqlite.connect(":memory:")
        # This connection is private to the test and unknown to tearDown();
        # close it even when an assertion below fails.
        self.addCleanup(con.close)
        cur = con.cursor()
        cur.execute("create table test(x)")
        cur.execute("insert into test(x) values (5)")
        cur.execute("select 1 union select 2 union select 3")

        con.rollback()
        with self.assertRaises(sqlite.InterfaceError):
            cur.fetchall()
150
class SpecialCommandTests(unittest.TestCase):
    """Check that DROP TABLE and PRAGMA statements execute without error
    after an INSERT has been issued on the same connection."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.cur = self.con.cursor()

    def tearDown(self):
        self.cur.close()
        self.con.close()

    def _create_and_insert(self):
        # Shared preamble: a table holding one freshly inserted row.
        for statement in ("create table test(i)",
                          "insert into test(i) values (5)"):
            self.cur.execute(statement)

    def CheckDropTable(self):
        self._create_and_insert()
        self.cur.execute("drop table test")

    def CheckPragma(self):
        self._create_and_insert()
        self.cur.execute("pragma count_changes=1")
169
class TransactionalDDL(unittest.TestCase):
    """DDL statements are transactional only when a BEGIN is issued by hand."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

    def tearDown(self):
        self.con.close()

    def _create_inside_then_rollback(self, begin_statement):
        # Open the transaction explicitly, create the table, then undo it.
        for statement in (begin_statement, "create table test(i)"):
            self.con.execute(statement)
        self.con.rollback()

    def _assert_test_table_missing(self):
        self.assertRaises(sqlite.OperationalError,
                          self.con.execute, "select * from test")

    def CheckDdlDoesNotAutostartTransaction(self):
        # Backwards compatibility: a bare DDL statement must not implicitly
        # open a transaction, so the rollback cannot undo the CREATE TABLE.
        self.con.execute("create table test(i)")
        self.con.rollback()
        rows = self.con.execute("select * from test").fetchall()
        self.assertEqual(rows, [])

    def CheckImmediateTransactionalDDL(self):
        # A manual BEGIN IMMEDIATE makes the CREATE TABLE roll back-able.
        self._create_inside_then_rollback("begin immediate")
        self._assert_test_table_missing()

    def CheckTransactionalDDL(self):
        # A plain manual BEGIN works the same way.
        self._create_inside_then_rollback("begin")
        self._assert_test_table_missing()
202
def suite():
    """Return one TestSuite holding every ``Check*`` method of this module.

    Uses a ``TestLoader`` configured with the ``Check`` prefix instead of
    ``unittest.makeSuite``, which was deprecated in Python 3.11 and removed
    in 3.13. Method ordering is unchanged (the loader's default sort).
    """
    loader = unittest.TestLoader()
    loader.testMethodPrefix = "Check"
    default_suite = loader.loadTestsFromTestCase(TransactionTests)
    special_command_suite = loader.loadTestsFromTestCase(SpecialCommandTests)
    ddl_suite = loader.loadTestsFromTestCase(TransactionalDDL)
    return unittest.TestSuite((default_suite, special_command_suite, ddl_suite))
208
def test():
    """Run the module suite with a text runner and return the TestResult.

    Returning the result (previously discarded) lets callers check
    ``wasSuccessful()``; existing callers that ignore it are unaffected.
    """
    runner = unittest.TextTestRunner()
    return runner.run(suite())
212
if "__main__" == __name__:
    # Executed as a script rather than imported: run the whole module suite.
    test()
215