1#-*- coding: iso-8859-1 -*-
2# pysqlite2/test/types.py: tests for type conversion and detection
3#
# Copyright (C) 2005 Gerhard Häring <gh@ghaering.de>
5#
6# This file is part of pysqlite.
7#
8# This software is provided 'as-is', without any express or implied
9# warranty.  In no event will the authors be held liable for any damages
10# arising from the use of this software.
11#
12# Permission is granted to anyone to use this software for any purpose,
13# including commercial applications, and to alter it and redistribute it
14# freely, subject to the following restrictions:
15#
16# 1. The origin of this software must not be misrepresented; you must not
17#    claim that you wrote the original software. If you use this software
18#    in a product, an acknowledgment in the product documentation would be
19#    appreciated but is not required.
20# 2. Altered source versions must be plainly marked as such, and must not be
21#    misrepresented as being the original software.
22# 3. This notice may not be removed or altered from any source distribution.
23
24import datetime
25import unittest
26import sqlite3 as sqlite
27try:
28    import zlib
29except ImportError:
30    zlib = None
31
32
class SqliteTypeTests(unittest.TestCase):
    """Round-trip values through a connection without type detection.

    No ``detect_types`` flag is passed to ``connect``, so no converters are
    consulted and values come back as SQLite returns them.
    """

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.cur = self.con.cursor()
        self.cur.execute("create table test(i integer, s varchar, f number, b blob)")

    def tearDown(self):
        self.cur.close()
        self.con.close()

    def CheckString(self):
        # Non-ASCII text is written as an escape ("\xd6" is a capital O with
        # diaeresis) so the test does not depend on the encoding this source
        # file is saved or decoded with.
        text = "\xd6sterreich"
        self.cur.execute("insert into test(s) values (?)", (text,))
        self.cur.execute("select s from test")
        row = self.cur.fetchone()
        self.assertEqual(row[0], text)

    def CheckSmallInt(self):
        self.cur.execute("insert into test(i) values (?)", (42,))
        self.cur.execute("select i from test")
        row = self.cur.fetchone()
        self.assertEqual(row[0], 42)

    def CheckLargeInt(self):
        # 2**40 does not fit in 32 bits.
        num = 2**40
        self.cur.execute("insert into test(i) values (?)", (num,))
        self.cur.execute("select i from test")
        row = self.cur.fetchone()
        self.assertEqual(row[0], num)

    def CheckFloat(self):
        val = 3.14
        self.cur.execute("insert into test(f) values (?)", (val,))
        self.cur.execute("select f from test")
        row = self.cur.fetchone()
        self.assertEqual(row[0], val)

    def CheckBlob(self):
        # A memoryview is accepted on input; bytes come back on output.
        sample = b"Guglhupf"
        val = memoryview(sample)
        self.cur.execute("insert into test(b) values (?)", (val,))
        self.cur.execute("select b from test")
        row = self.cur.fetchone()
        self.assertEqual(row[0], sample)

    def CheckUnicodeExecute(self):
        # The non-ASCII text is part of the SQL statement itself, not a
        # bound parameter.
        self.cur.execute("select '\xd6sterreich'")
        row = self.cur.fetchone()
        self.assertEqual(row[0], "\xd6sterreich")
81
class DeclTypesTests(unittest.TestCase):
    """Conversion driven by declared column types (``PARSE_DECLTYPES``).

    setUp installs several converters in the module-global
    ``sqlite.converters`` table; tearDown removes them again.
    """

    class Foo:
        """Value type stored via ``__conform__`` and rebuilt by the FOO converter."""

        def __init__(self, _val):
            # When rebuilt by a converter the stored text arrives as bytes;
            # decode it so equality compares text with text.
            self.val = _val.decode('utf-8') if isinstance(_val, bytes) else _val

        def __eq__(self, other):
            if isinstance(other, DeclTypesTests.Foo):
                return self.val == other.val
            return NotImplemented

        def __conform__(self, protocol):
            # Adapt only for sqlite3's own protocol.
            return self.val if protocol is sqlite.PrepareProtocol else None

        def __str__(self):
            return "<%s>" % self.val

    class BadConform:
        """Object whose adaptation hook raises the exception it was given."""

        def __init__(self, exc):
            self.exc = exc

        def __conform__(self, protocol):
            raise self.exc

    def setUp(self):
        self.con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_DECLTYPES)
        self.cur = self.con.cursor()
        self.cur.execute("create table test(i int, s str, f float, b bool, u unicode, foo foo, bin blob, n1 number, n2 number(5), bad bad)")

        sqlite.converters.update({
            # FLOAT is overridden to a constant so its use is observable.
            "FLOAT": lambda x: 47.2,
            "BOOL": lambda x: bool(int(x)),
            "FOO": DeclTypesTests.Foo,
            "BAD": DeclTypesTests.BadConform,
            "WRONG": lambda x: "WRONG",
            "NUMBER": float,
        })

    def tearDown(self):
        for converter_name in ("FLOAT", "BOOL", "FOO", "BAD", "WRONG", "NUMBER"):
            del sqlite.converters[converter_name]
        self.cur.close()
        self.con.close()

    def _select_one(self, query):
        """Execute *query* and return the first column of the first row."""
        self.cur.execute(query)
        return self.cur.fetchone()[0]

    def CheckString(self):
        # The "[WRONG]" marker in the alias is not applied: only
        # PARSE_DECLTYPES is enabled on this connection.
        self.cur.execute("insert into test(s) values (?)", ("foo",))
        self.assertEqual(self._select_one('select s as "s [WRONG]" from test'), "foo")

    def CheckSmallInt(self):
        self.cur.execute("insert into test(i) values (?)", (42,))
        self.assertEqual(self._select_one("select i from test"), 42)

    def CheckLargeInt(self):
        big = 2**40
        self.cur.execute("insert into test(i) values (?)", (big,))
        self.assertEqual(self._select_one("select i from test"), big)

    def CheckFloat(self):
        # Whatever was stored, the overridden FLOAT converter yields 47.2.
        self.cur.execute("insert into test(f) values (?)", (3.14,))
        self.assertEqual(self._select_one("select f from test"), 47.2)

    def CheckBool(self):
        self.cur.execute("insert into test(b) values (?)", (False,))
        self.assertIs(self._select_one("select b from test"), False)

        self.cur.execute("delete from test")
        self.cur.execute("insert into test(b) values (?)", (True,))
        self.assertIs(self._select_one("select b from test"), True)

    def CheckUnicode(self):
        text = "\xd6sterreich"
        self.cur.execute("insert into test(u) values (?)", (text,))
        self.assertEqual(self._select_one("select u from test"), text)

    def CheckFoo(self):
        stored = DeclTypesTests.Foo("bla")
        self.cur.execute("insert into test(foo) values (?)", (stored,))
        self.assertEqual(self._select_one("select foo from test"), stored)

    def CheckErrorInConform(self):
        # A TypeError from __conform__ is reported as InterfaceError, while
        # KeyboardInterrupt propagates as-is; both parameter styles are tried.
        scenarios = (
            (TypeError, sqlite.InterfaceError),
            (KeyboardInterrupt, KeyboardInterrupt),
        )
        for raised, expected in scenarios:
            bad = DeclTypesTests.BadConform(raised)
            with self.assertRaises(expected):
                self.cur.execute("insert into test(bad) values (?)", (bad,))
            with self.assertRaises(expected):
                self.cur.execute("insert into test(bad) values (:val)", {"val": bad})

    def CheckUnsupportedSeq(self):
        class Bar: pass
        with self.assertRaises(sqlite.InterfaceError):
            self.cur.execute("insert into test(f) values (?)", (Bar(),))

    def CheckUnsupportedDict(self):
        class Bar: pass
        with self.assertRaises(sqlite.InterfaceError):
            self.cur.execute("insert into test(f) values (:val)", {"val": Bar()})

    def CheckBlob(self):
        sample = b"Guglhupf"
        self.cur.execute("insert into test(bin) values (?)", (memoryview(sample),))
        self.assertEqual(self._select_one("select bin from test"), sample)

    def CheckNumber1(self):
        self.cur.execute("insert into test(n1) values (5)")
        # Without the NUMBER converter this would be an int.
        self.assertIs(type(self._select_one("select n1 from test")), float)

    def CheckNumber2(self):
        """Checks whether converter names are cut off at '(' characters"""
        self.cur.execute("insert into test(n2) values (5)")
        # Declared type "number(5)" must still select the NUMBER converter.
        self.assertIs(type(self._select_one("select n2 from test")), float)
240
class ColNamesTests(unittest.TestCase):
    """Conversion driven by ``[type]`` markers in column names (``PARSE_COLNAMES``).

    setUp installs converters in the module-global ``sqlite.converters``
    table; tearDown removes them again.
    """

    def setUp(self):
        self.con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_COLNAMES)
        self.cur = self.con.cursor()
        self.cur.execute("create table test(x foo)")

        sqlite.converters.update({
            "FOO": lambda x: "[%s]" % x.decode("ascii"),
            "BAR": lambda x: "<%s>" % x.decode("ascii"),
            "EXC": lambda x: 5/0,
            "B1B1": lambda x: "MARKER",
        })

    def tearDown(self):
        for converter_name in ("FOO", "BAR", "EXC", "B1B1"):
            del sqlite.converters[converter_name]
        self.cur.close()
        self.con.close()

    def _first_value(self, query):
        """Execute *query* and return the first column of the first row."""
        self.cur.execute(query)
        return self.cur.fetchone()[0]

    def CheckDeclTypeNotUsed(self):
        """
        Assures that the declared type is not used when PARSE_DECLTYPES
        is not set.
        """
        self.cur.execute("insert into test(x) values (?)", ("xxx",))
        self.assertEqual(self._first_value("select x from test"), "xxx")

    def CheckNone(self):
        self.cur.execute("insert into test(x) values (?)", (None,))
        self.assertIsNone(self._first_value("select x from test"))

    def CheckColName(self):
        self.cur.execute("insert into test(x) values (?)", ("xxx",))
        self.assertEqual(self._first_value('select x as "x y [bar]" from test'), "<xxx>")

        # The reported column name must be stripped of everything from the
        # first '[' onwards, including the space preceding it.
        self.assertEqual(self.cur.description[0][0], "x y")

    def CheckCaseInConverterName(self):
        # "b1b1" in the query must find the converter stored as "B1B1".
        self.assertEqual(self._first_value("select 'other' as \"x [b1b1]\""), "MARKER")

    def CheckCursorDescriptionNoRow(self):
        """
        cursor.description should at least provide the column name(s), even if
        no row returned.
        """
        self.cur.execute("select * from test where 0 = 1")
        self.assertEqual(self.cur.description[0][0], "x")

    def CheckCursorDescriptionInsert(self):
        # Statements that return no columns leave description as None.
        self.cur.execute("insert into test values (1)")
        self.assertIsNone(self.cur.description)
302
303
@unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "CTEs not supported")
class CommonTableExpressionTests(unittest.TestCase):
    """cursor.description must be filled in for statements using a WITH clause."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.cur = self.con.cursor()
        self.cur.execute("create table test(x foo)")

    def tearDown(self):
        self.cur.close()
        self.con.close()

    def _assert_first_column(self, query, expected_name):
        """Run *query* and check that its first result column is *expected_name*."""
        self.cur.execute(query)
        description = self.cur.description
        self.assertIsNotNone(description)
        self.assertEqual(description[0][0], expected_name)

    def CheckCursorDescriptionCTESimple(self):
        self._assert_first_column("with one as (select 1) select * from one", "1")

    def CheckCursorDescriptionCTESMultipleColumns(self):
        for statement in ("insert into test values(1)", "insert into test values(2)"):
            self.cur.execute(statement)
        self._assert_first_column(
            "with testCTE as (select * from test) select * from testCTE", "x")

    def CheckCursorDescriptionCTE(self):
        self.cur.execute("insert into test values (1)")
        # One query matches the stored row, the other matches nothing; the
        # description must be present either way.
        queries = (
            "with bar as (select * from test) select * from test where x = 1",
            "with bar as (select * from test) select * from test where x = 2",
        )
        for query in queries:
            self._assert_first_column(query, "x")
336
337
class ObjectAdaptationTests(unittest.TestCase):
    """A registered adapter must be applied to bound parameters."""

    @staticmethod
    def cast(obj):
        """Adapter registered for ``int``: turn the value into a float."""
        return float(obj)

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        # Adapters are keyed by (type, protocol), as tearDown relies on; a
        # plain ``int`` key never matches.  Remember any adapter that was
        # already registered for int so tearDown can put it back instead of
        # leaving the module-global table modified.
        self._saved_int_adapter = sqlite.adapters.pop(
            (int, sqlite.PrepareProtocol), None)
        sqlite.register_adapter(int, ObjectAdaptationTests.cast)
        self.cur = self.con.cursor()

    def tearDown(self):
        key = (int, sqlite.PrepareProtocol)
        del sqlite.adapters[key]
        if self._saved_int_adapter is not None:
            sqlite.adapters[key] = self._saved_int_adapter
        self.cur.close()
        self.con.close()

    def CheckCasterIsUsed(self):
        # With the adapter in place, an int parameter comes back as a float.
        self.cur.execute("select ?", (4,))
        val = self.cur.fetchone()[0]
        self.assertEqual(type(val), float)
361
@unittest.skipUnless(zlib, "requires zlib")
class BinaryConverterTests(unittest.TestCase):
    """A converter selected by a column-name marker must receive the raw bytes."""

    @staticmethod
    def convert(s):
        """Converter for the "bin" marker: zlib-decompress the stored value."""
        return zlib.decompress(s)

    def setUp(self):
        self.con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_COLNAMES)
        sqlite.register_converter("bin", BinaryConverterTests.convert)

    def tearDown(self):
        # register_converter() stores the name upper-cased in the
        # module-global table; remove it so it does not leak into other tests.
        del sqlite.converters["BIN"]
        self.con.close()

    def CheckBinaryInputForConverter(self):
        testdata = b"abcdefg" * 10
        result = self.con.execute('select ? as "x [bin]"', (memoryview(zlib.compress(testdata)),)).fetchone()[0]
        self.assertEqual(testdata, result)
379
class DateTimeTests(unittest.TestCase):
    """Round-trip dates and timestamps through ``date``/``timestamp`` columns.

    The connection uses PARSE_DECLTYPES, so the values read back are rebuilt
    by the converters registered for those declared types.
    """

    def setUp(self):
        self.con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_DECLTYPES)
        self.cur = self.con.cursor()
        self.cur.execute("create table test(d date, ts timestamp)")

    def tearDown(self):
        self.cur.close()
        self.con.close()

    def _roundtrip_timestamp(self, ts):
        """Store *ts* in the ``ts`` column and return the value read back."""
        self.cur.execute("insert into test(ts) values (?)", (ts,))
        self.cur.execute("select ts from test")
        return self.cur.fetchone()[0]

    def CheckSqliteDate(self):
        d = sqlite.Date(2004, 2, 14)
        self.cur.execute("insert into test(d) values (?)", (d,))
        self.cur.execute("select d from test")
        d2 = self.cur.fetchone()[0]
        self.assertEqual(d, d2)

    def CheckSqliteTimestamp(self):
        ts = sqlite.Timestamp(2004, 2, 14, 7, 15, 0)
        self.assertEqual(ts, self._roundtrip_timestamp(ts))

    @unittest.skipIf(sqlite.sqlite_version_info < (3, 1),
                     'the date functions are available on 3.1 or later')
    def CheckSqlTimestamp(self):
        # SQLite's current_timestamp is in UTC.  Read the UTC clock before
        # and after (with an aware datetime; utcnow() is deprecated since
        # Python 3.12) so a year rollover in between cannot fail the test.
        before = datetime.datetime.now(datetime.timezone.utc)
        self.cur.execute("insert into test(ts) values (current_timestamp)")
        self.cur.execute("select ts from test")
        ts = self.cur.fetchone()[0]
        after = datetime.datetime.now(datetime.timezone.utc)
        self.assertEqual(type(ts), datetime.datetime)
        self.assertIn(ts.year, (before.year, after.year))

    def CheckDateTimeSubSeconds(self):
        ts = sqlite.Timestamp(2004, 2, 14, 7, 15, 0, 500000)
        self.assertEqual(ts, self._roundtrip_timestamp(ts))

    def CheckDateTimeSubSecondsFloatingPoint(self):
        # Microseconds that are not exactly representable as a binary
        # fraction of a second must still round-trip exactly.
        ts = sqlite.Timestamp(2004, 2, 14, 7, 15, 0, 510241)
        self.assertEqual(ts, self._roundtrip_timestamp(ts))
427
def suite():
    """Collect every ``Check*`` method of this module's test cases into one suite.

    ``unittest.makeSuite`` was deprecated in Python 3.11 and removed in 3.13,
    so an explicitly configured ``TestLoader`` is used instead.
    """
    loader = unittest.TestLoader()
    # Test methods in this module are named Check*, not test*.
    loader.testMethodPrefix = "Check"
    test_cases = (
        SqliteTypeTests,
        DeclTypesTests,
        ColNamesTests,
        ObjectAdaptationTests,
        BinaryConverterTests,
        DateTimeTests,
        CommonTableExpressionTests,
    )
    return unittest.TestSuite(
        loader.loadTestsFromTestCase(test_case) for test_case in test_cases)
437
def test():
    """Run the module's full suite with a plain text-mode runner."""
    unittest.TextTestRunner().run(suite())
441
if __name__ == "__main__":
    # Executed as a script: run the whole suite via test().
    test()
444