1import os
2import sys
3import unittest
4
5
# Python 3 folded ``long`` into ``int``; keep the old name available so type
# checks further down can be written once for both major versions.
if sys.version_info[0] >= 3:
    long = int
10
11
def u(string, encoding='latin-1'):
    """Return *string* as text, decoding it first if it is a byte string.

    Stand-in for ``u'...'`` literals, which are missing in Python 3.0-3.2.
    Byte strings are decoded with *encoding*; any other object is handed
    back unchanged.
    """
    if not isinstance(string, bytes):
        return string
    return string.decode(encoding)
17
18
try:
    from unqlite import UnQLite
except ImportError:
    # Print a hint on stderr, then re-raise so the original ImportError
    # (and its traceback) still reaches the caller.
    hint = ('Unable to import `unqlite`. Make sure it is properly '
            'installed.\n')
    sys.stderr.write(hint)
    sys.stderr.flush()
    raise
26
27
class BaseTestCase(unittest.TestCase):
    """Shared fixture providing two databases for every test.

    ``self.db`` is created with no filename, ``self.file_db`` is opened on
    ``self._filename``.  ``tearDown`` closes both and removes the file.
    """

    def setUp(self):
        super(BaseTestCase, self).setUp()
        self._filename = 'test.db'
        # A run that died before tearDown may have left the file behind;
        # remove it so every test starts from an empty file database.
        self._remove_db_file()
        self.db = UnQLite()
        self.file_db = UnQLite(self._filename)

    def tearDown(self):
        # Nested try/finally: a failing close() on one database must not
        # prevent closing the other one or removing the file on disk.
        try:
            self._close_if_open(self.db)
        finally:
            try:
                self._close_if_open(self.file_db)
            finally:
                self._remove_db_file()
                super(BaseTestCase, self).tearDown()

    @staticmethod
    def _close_if_open(db):
        """Close *db* unless it reports that it is already closed."""
        if db.is_open:
            db.close()

    def _remove_db_file(self):
        """Delete the file behind ``self.file_db`` if it exists."""
        if os.path.exists(self._filename):
            os.unlink(self._filename)

    def store_range(self, n, db=None):
        """Store ``'k0'..'k<n-1>'`` mapped to ``'0'..'<n-1>'`` in *db*.

        *db* defaults to ``self.db`` when it is ``None``.
        """
        if db is None:
            db = self.db
        for i in range(n):
            db['k%s' % i] = str(i)
48
49
class TestKeyValueStorage(BaseTestCase):
    """Tests for the key/value API: store/fetch, dict access, iteration."""

    def test_basic_operations(self):
        # store / fetch / delete / exists round trip on both databases.
        for db in (self.db, self.file_db):
            db.store('k1', 'v1')
            db.store('k2', 'v2')
            self.assertEqual(db.fetch('k1'), 'v1')
            self.assertEqual(db.fetch('k2'), 'v2')
            self.assertRaises(KeyError, db.fetch, 'k3')

            db.delete('k2')
            self.assertRaises(KeyError, db.fetch, 'k2')

            self.assertTrue(db.exists('k1'))
            self.assertFalse(db.exists('k2'))

    def test_dict_interface(self):
        # Same round trip as above, through item access, ``del`` and ``in``.
        for db in (self.db, self.file_db):
            db['k1'] = 'v1'
            db['k2'] = 'v2'
            self.assertEqual(db['k1'], 'v1')
            self.assertEqual(db['k2'], 'v2')
            self.assertRaises(KeyError, lambda: db['k3'])

            del db['k2']
            self.assertRaises(KeyError, lambda: db['k2'])

            self.assertTrue('k1' in db)
            self.assertFalse('k2' in db)

    def test_append(self):
        # Appending to an existing key concatenates the values ...
        self.db['k1'] = 'v1'
        self.db.append('k1', 'V1')
        self.assertEqual(self.db['k1'], 'v1V1')

        # ... and appending to a missing key creates it.
        self.db.append('k2', 'V2')
        self.assertEqual(self.db['k2'], 'V2')

    def test_iteration(self):
        # ``self.db`` is expected to yield items in insertion order.
        self.store_range(4, self.db)
        data = list(self.db)
        self.assertEqual(data, [
            ('k0', '0'),
            ('k1', '1'),
            ('k2', '2'),
            ('k3', '3'),
        ])

        del self.db['k2']
        self.assertEqual([key for key, _ in self.db], ['k0', 'k1', 'k3'])

    def test_file_iteration(self):
        # ``self.file_db`` is expected to yield the most recent item first.
        self.store_range(4, self.file_db)
        data = list(self.file_db)
        self.assertEqual(data, [
            ('k3', '3'),
            ('k2', '2'),
            ('k1', '1'),
            ('k0', '0'),
        ])

        del self.file_db['k2']
        self.assertEqual([key for key, _ in self.file_db], ['k3', 'k1', 'k0'])

    def test_range(self):
        self.store_range(10, self.db)
        # Both end points are included in the result.
        data = list(self.db.range('k4', 'k6'))
        self.assertEqual(data, [
            ('k4', '4'),
            ('k5', '5'),
            ('k6', '6'),
        ])

        # An end key that is never reached runs to the last record.
        data = list(self.db.range('k8', 'kX'))
        self.assertEqual(data, [
            ('k8', '8'),
            ('k9', '9'),
        ])

        # A start key that does not exist is an error.
        def invalid_start():
            list(self.db.range('kx', 'k2'))
        self.assertRaises(KeyError, invalid_start)

    def test_file_range(self):
        # Ranges follow iteration order, which is reversed for the file
        # database, so the start key is the "larger" one here.
        self.store_range(10, self.file_db)
        data = list(self.file_db.range('k6', 'k4'))
        self.assertEqual(data, [
            ('k6', '6'),
            ('k5', '5'),
            ('k4', '4'),
        ])

        data = list(self.file_db.range('k2', 'k0'))
        self.assertEqual(data, [
            ('k2', '2'),
            ('k1', '1'),
            ('k0', '0'),
        ])

        def invalid_start():
            list(self.file_db.range('kx', 'k2'))
        self.assertRaises(KeyError, invalid_start)

    def test_flush(self):
        # flush() is expected to leave the database empty.
        for db in (self.db, self.file_db):
            self.store_range(10, db)
            self.assertEqual(len(list(db)), 10)
            db.flush()
            self.assertEqual(list(db), [])

    def test_len(self):
        for db in (self.db, self.file_db):
            self.store_range(10, db)
            self.assertEqual(len(db), 10)
            db.flush()
            self.assertEqual(len(db), 0)
            db['a'] = 'A'
            db['b'] = 'B'
            # Overwriting an existing key must not change the count.
            db['b'] = 'Bb'
            self.assertEqual(len(db), 2)

    def test_autocommit(self):
        # By default a write is still there after close() + open().
        self.file_db['k1'] = 'v1'
        self.file_db.close()
        self.file_db.open()
        self.assertEqual(self.file_db['k1'], 'v1')

        # After disable_autocommit() an uncommitted write is lost on close.
        self.file_db.disable_autocommit()
        self.file_db['k2'] = 'v2'
        self.file_db.close()
        self.file_db.open()
        self.assertRaises(KeyError, lambda: self.file_db['k2'])

    def test_dict_methods(self):
        for db in (self.db, self.file_db):
            self.store_range(3, db)
            # Sorted, because iteration order differs between the databases.
            self.assertEqual(sorted(db.keys()), ['k0', 'k1', 'k2'])
            self.assertEqual(sorted(db.values()), ['0', '1', '2'])
            self.assertEqual(sorted(db.items()), [
                ('k0', '0'),
                ('k1', '1'),
                ('k2', '2')])

            db.update({'foo': 'bar', 'baz': 'nug'})
            self.assertEqual(db['foo'], 'bar')
            self.assertEqual(db['baz'], 'nug')

    def test_byte_strings(self):
        # Byte-string keys and values must come back as identical bytes.
        byte_data = [
            (b'k\xe4se', b'sp\xe4tzle'),
            (b'kn\xf6dli', b'br\xf6tli'),
            (b'w\xfcrstel', b's\xfclzli')]
        for db in (self.db, self.file_db):
            for k, v in byte_data:
                db.store(k, v)
            for k, v in byte_data:
                w = db.fetch(k)
                self.assertTrue(isinstance(w, bytes))
                self.assertEqual(w, v)

    def test_unicode_strings(self):
        unicode_data = [
            (u('k\xe4se'), u('sp\xe4tzle')),
            (u('kn\xf6dli'), u('br\xf6tli')),
            (u('w\xfcrstel'), u('s\xfclzli'))]
        for db in (self.db, self.file_db):
            for k, v in unicode_data:
                db.store(k, v)
            for k, v in unicode_data:
                w = db.fetch(k)
                self.assertTrue(isinstance(w, str))
                # Where ``str`` is ``bytes`` (Python 2) the fetched value is
                # a byte string and is decoded before comparing.
                if str is bytes:
                    w = w.decode('utf-8')
                self.assertEqual(w, v)
223
224
class TestTransaction(BaseTestCase):
    """
    Transaction tests, all run against ``self.file_db``.

    We must use a file-based database to test the transaction functions. See
    http://unqlite.org/forum/trouble-with-transactions+1 for details.
    """

    def test_transaction(self):
        # ``commit_on_success`` keeps the write of a function that returns
        # normally and discards the write of one that raises.
        db = self.file_db

        @db.commit_on_success
        def store(key, value):
            db[key] = value

        @db.commit_on_success
        def store_then_fail(key, value):
            db[key] = value
            raise Exception('intentional exception raised')

        store('k1', 'v1')
        self.assertEqual(db['k1'], 'v1')

        self.assertRaises(Exception, store_then_fail, 'k2', 'v2')
        self.assertRaises(KeyError, lambda: db['k2'])

    def test_context_manager(self):
        db = self.file_db

        # Leaving the block normally: the write is visible afterwards.
        with db.transaction():
            db['foo'] = 'bar'
        self.assertEqual(db['foo'], 'bar')

        # Rolling back inside the block: the write is gone afterwards.
        with db.transaction():
            db['baz'] = 'nug'
            db.rollback()
        self.assertRaises(KeyError, lambda: db['baz'])

    def test_explicit_transaction(self):
        # Reopen the database, then begin() / write / rollback() by hand.
        db = self.file_db
        db.close()
        db.open()

        db.begin()
        db['k1'] = 'v1'
        db.rollback()

        self.assertRaises(KeyError, lambda: db['k1'])
266
267
class TestCursor(BaseTestCase):
    """Cursor navigation, deletion and iteration on both databases."""

    def setUp(self):
        # Both databases start every test holding k0..k9 -> '0'..'9'.
        super(TestCursor, self).setUp()
        for db in (self.db, self.file_db):
            self.store_range(10, db)

    def assertIndex(self, cursor, idx):
        # Assert that the cursor is valid and sits on the record written by
        # store_range() for index ``idx`` (key 'k<idx>', value str(idx)).
        self.assertTrue(cursor.is_valid())
        self.assertEqual(cursor.key(), 'k%d' % idx)
        self.assertEqual(cursor.value(), str(idx))

    def test_cursor_basic(self):
        # On ``self.db`` a new cursor is expected to sit on the first
        # inserted record, and records are visited in insertion order.
        cursor = self.db.cursor()
        self.assertIndex(cursor, 0)
        cursor.next_entry()
        self.assertIndex(cursor, 1)
        cursor.last()
        self.assertIndex(cursor, 9)
        cursor.previous_entry()
        self.assertIndex(cursor, 8)
        cursor.first()
        self.assertIndex(cursor, 0)
        # Deleting the current record leaves the cursor on the next one.
        cursor.delete()
        self.assertIndex(cursor, 1)
        # NOTE(review): presumably dropped explicitly so the cursor is
        # released before tearDown closes the database - confirm.
        del cursor

    def test_cursor_basic_file(self):
        # On ``self.file_db`` the order is reversed: first() is the most
        # recently stored record (k9) and last() the oldest one (k0).
        cursor = self.file_db.cursor()
        cursor.first()
        self.assertIndex(cursor, 9)
        cursor.next_entry()
        self.assertIndex(cursor, 8)
        cursor.last()
        self.assertIndex(cursor, 0)
        cursor.previous_entry()
        self.assertIndex(cursor, 1)
        # k1 is deleted; the cursor is then expected on k0, and stepping
        # back from there reaches k2 because k1 no longer exists.
        cursor.delete()
        self.assertIndex(cursor, 0)
        cursor.previous_entry()
        self.assertIndex(cursor, 2)
        # One step forward is still allowed; the step after that must
        # raise StopIteration.
        cursor.next_entry()
        self.assertRaises(StopIteration, cursor.next_entry)

    def test_cursor_iteration(self):
        with self.db.cursor() as cursor:
            # Remove k4 through the cursor, then iterate from the start.
            cursor.seek('k4')
            cursor.delete()
            cursor.reset()
            results = [item for item in cursor]
            self.assertEqual(results, [
                ('k0', '0'),
                ('k1', '1'),
                ('k2', '2'),
                ('k3', '3'),
                ('k5', '5'),
                ('k6', '6'),
                ('k7', '7'),
                ('k8', '8'),
                ('k9', '9'),
            ])

            # Iterating after seek() starts at the sought key.
            cursor.seek('k5')
            self.assertEqual(cursor.value(), '5')
            keys = [key for key, _ in cursor]
            self.assertEqual(keys, ['k5', 'k6', 'k7', 'k8', 'k9'])

        with self.db.cursor() as cursor:
            # k4 was deleted above, so seeking to it must fail.
            self.assertRaises(Exception, cursor.seek, 'k4')
            cursor.seek('k5')
            # Walk forward by hand, collecting keys up to and including k7.
            keys = []
            while True:
                key = cursor.key()
                keys.append(key)
                if key == 'k7':
                    break
                else:
                    cursor.next_entry()
        self.assertEqual(keys, ['k5', 'k6', 'k7'])

        # New items are appended to the end of the database.
        del self.db['k5']
        del self.db['k9']
        del self.db['k7']
        self.db['a0'] = 'x0'
        self.db['k5'] = 'x5'

        with self.db.cursor() as cursor:
            self.assertEqual(cursor.key(), 'k0')
            items = [k for k, _ in cursor]
            # The re-inserted k5 comes last, after the new a0 key.
            self.assertEqual(
                items,
                ['k0', 'k1', 'k2', 'k3', 'k6', 'k8', 'a0', 'k5'])
360
361
class TestJx9(BaseTestCase):
    """Runs scripts through ``db.vm()`` and checks the variables exchanged."""

    def test_simple_compilation(self):
        # Execute a script that fills a collection and defines a few
        # variables, then read those variables back from the VM.
        jx9_source = """
            $collection = 'users';
            if (!db_exists($collection)) {
                db_create($collection);
            }
            db_store($collection, {"username": "huey", "age": 3});
            $huey_id = db_last_record_id($collection);
            db_store($collection, {"username": "mickey", "age": 5});
            $mickey_id = db_last_record_id($collection);
            $something = 'hello world';
            $users = db_fetch_all($collection);
            $nested = {
                "k1": {"foo": [1, 2, 3]},
                "k2": ["v2", ["v3", "v4"]]};
        """
        expected_scalars = (
            ('huey_id', 0),
            ('mickey_id', 1),
            ('something', 'hello world'),
        )
        expected_users = [
            {'__id': 0, 'age': 3, 'username': 'huey'},
            {'__id': 1, 'age': 5, 'username': 'mickey'},
        ]
        expected_nested = {
            'k1': {'foo': [1, 2, 3]},
            'k2': ['v2', ['v3', 'v4']]}

        with self.db.vm(jx9_source) as vm:
            vm.execute()
            for name, expected in expected_scalars:
                self.assertEqual(vm[name], expected)
            self.assertEqual(vm['users'], expected_users)
            self.assertEqual(vm['nested'], expected_nested)

    def test_setting_values(self):
        # Inject a Python list into the VM before execution; the script
        # stores it and the records come back with ``__id`` fields added.
        jx9_source = """
            $collection = 'users';
            db_create($collection);
            db_store($collection, $values);
            $users = db_fetch_all($collection);
        """
        records = [
            {'username': 'hubie', 'color': 'white'},
            {'username': 'michael', 'color': 'black'},
        ]
        expected_users = [
            {'username': 'hubie', 'color': 'white', '__id': 0},
            {'username': 'michael', 'color': 'black', '__id': 1},
        ]

        with self.db.vm(jx9_source) as vm:
            vm['values'] = records
            vm.execute()
            self.assertEqual(vm['users'], expected_users)
418
419
class TestUtils(BaseTestCase):
    """Checks the random-value helpers available on the database object."""

    def test_random(self):
        # random_int() yields an integer (``long`` covers Python 2).
        integer_types = (int, long)
        number = self.db.random_int()
        self.assertTrue(isinstance(number, integer_types))

        # random_string(n) yields a value of length n.
        text = self.db.random_string(10)
        self.assertEqual(len(text), 10)
427
428
class TestCollection(BaseTestCase):
    """Tests for the document-collection API returned by ``db.collection()``.

    Scenarios that must hold for both databases live in ``_test_*`` helpers
    that take the database as an argument; thin ``*_mem`` / ``*_file``
    wrappers run them against ``self.db`` and ``self.file_db``.
    """

    def test_basic_crud_mem(self):
        self._test_basic_crud(self.db)

    def test_basic_crud_file(self):
        self._test_basic_crud(self.file_db)

    def _test_basic_crud(self, db):
        # Create, read, update and delete single records in one collection.
        coll = db.collection('users')
        coll.create()

        # Record ids are handed out sequentially, starting at 0.
        first_id = coll.store({'username': 'huey'})
        self.assertEqual(first_id, 0)
        self.assertEqual(
            coll.fetch(coll.last_record_id()),
            {'__id': 0, 'username': 'huey'})

        second_id = coll.store({'username': u('mickey')})
        self.assertEqual(second_id, 1)
        self.assertEqual(
            coll.fetch(coll.last_record_id()),
            {'__id': 1, 'username': u('mickey')})

        self.assertEqual(coll.all(), [
            {'__id': 0, 'username': 'huey'},
            {'__id': 1, 'username': 'mickey'},
        ])

        # Looking up a deleted record gives None.
        coll.delete(1)
        self.assertEqual(coll[0], {'__id': 0, 'username': 'huey'})
        self.assertTrue(coll[1] is None)

        # update() replaces the stored document: 'username' is gone.
        updated = coll.update(0, {'color': 'white', 'name': 'hueybear'})
        self.assertTrue(updated)
        self.assertEqual(
            coll[0],
            {'__id': 0, 'color': 'white', 'name': 'hueybear'})

        # Updating the deleted record reports failure and creates nothing.
        updated = coll.update(1, {'name': 'zaizee'})
        self.assertFalse(updated)
        self.assertTrue(coll[1] is None)

        self.assertEqual(coll.all(), [
            {'__id': 0, 'color': 'white', 'name': 'hueybear'},
        ])

    def test_basic_operations_mem(self):
        self._test_basic_operations(self.db)

    def test_basic_operations_file(self):
        self._test_basic_operations(self.file_db)

    def _test_basic_operations(self, db):
        # exists / create / len, bulk store, fetch and delete.
        coll = db.collection('users')
        self.assertFalse(coll.exists())
        coll.create()
        self.assertTrue(coll.exists())
        self.assertEqual(len(coll), 0)

        people = [
            {'name': 'charlie', 'activities': ['coding', 'reading']},
            {'name': 'huey', 'activities': ['playing', 'sleeping']},
            {'name': 'mickey', 'activities': ['sleeping', 'hunger']}]

        # Storing a list stores one record per element.
        coll.store(people)
        self.assertEqual(len(coll), 3)

        # The stored records are the inputs plus a sequential ``__id``.
        expected = [
            dict(person, __id=position)
            for position, person in enumerate(people)]
        self.assertEqual(coll.all(), expected)

        coll.store({'name': 'leslie', 'activities': ['reading', 'surgery']})
        self.assertEqual(len(coll), 4)

        current = coll.fetch_current()
        self.assertEqual(current['name'], 'charlie')

        self.assertEqual(coll.fetch(3), {
            'name': 'leslie',
            'activities': ['reading', 'surgery'],
            '__id': 3})

        for record_id in (0, 2, 3):
            coll.delete(record_id)
        self.assertEqual(coll.all(), [
            {'name': 'huey', 'activities': ['playing', 'sleeping'], '__id': 1}
        ])

        # An id that was never assigned also gives None.
        self.assertTrue(coll[99] is None)

    def test_unicode_key(self):
        # Keys and values built with u() compare equal to plain literals
        # after a store/fetch round trip.
        coll = self.db.collection('users')
        coll.create()
        new_id = coll.store({u('key'): u('value')})
        self.assertEqual(new_id, 0)
        self.assertEqual(
            coll.fetch(coll.last_record_id()),
            {'__id': 0, 'key': 'value'})

    def test_filtering(self):
        # filter() returns the records for which the callback is true.
        coll = self.db.collection('values')
        coll.create()
        coll.store([{'val': number} for number in range(20)])
        self.assertEqual(len(coll), 20)

        matches = coll.filter(lambda obj: obj['val'] in range(7, 12))
        self.assertEqual(
            matches,
            [{'__id': number, 'val': number} for number in range(7, 12)])

    def test_odd_values_mem(self):
        self._test_odd_values(self.db)

    def test_odd_values_file(self):
        self._test_odd_values(self.file_db)

    def _test_odd_values(self, db):
        coll = db.collection('testing')
        coll.create()
        # A dict with a non-string key is expected to come back as a list.
        coll.store({1: 2})
        fetched = coll.fetch(coll.last_record_id())
        self.assertEqual(fetched, [2, 0])

        coll.drop()

        # Try storing in non-existent collection?
        self.assertRaises(ValueError, coll.store, {'f': 'f'})

    def test_data_type_integrity(self):
        # Strings, ints, floats and booleans keep their Python types.
        coll = self.db.collection('testing')
        coll.create()

        document = {
            'a': 'A',
            'b': 2,
            'c': 3.1,
            'd': True,
            'e': False,
            'f': 0}
        # Built before store() so it cannot depend on what store() does
        # with the dict it is given.
        expected = dict(document, __id=0)

        self.assertEqual(coll.store(document), 0)

        fetched = coll.fetch(coll.last_record_id())
        self.assertEqual(fetched, expected)
        # ``True == 1`` and ``False == 0``, so the equality check above
        # cannot tell bool from int; check the types explicitly.
        for key, kind in (('d', bool), ('e', bool), ('f', int)):
            self.assertTrue(isinstance(fetched[key], kind))
591
592
# Run the whole suite when this file is executed as a script, handing the
# command-line arguments to unittest.
if __name__ == '__main__':
    unittest.main(argv=sys.argv)
595