1"""
Test cases for the Python DB duplicate-data and Btree key comparison functions.
3"""
4
5import sys, os, re
6import test_all
7from cStringIO import StringIO
8
9import unittest
10
11from test_all import db, dbshelve, test_support, \
12        get_new_environment_path, get_new_database_path
13
14
# Needed for python 3. "cmp" vanished in 3.0.1
def cmp(a, b) :
    """Three-way comparison: return 0, -1 or 1 as a is ==, < or > b.

    Stand-in for the Python 2 builtin of the same name.  Only ``==`` and
    ``<`` are consulted; anything neither equal nor less counts as greater.
    """
    if a == b :
        return 0
    return -1 if a < b else 1
20
# Case-sensitive ordering via the module's cmp(); uppercase sorts before
# lowercase, as _expected_lexical_test_data below shows.
lexical_cmp = cmp
22
def lowercase_cmp(left, right) :
    """Case-insensitive three-way comparison of two strings."""
    folded_left = left.lower()
    folded_right = right.lower()
    return cmp(folded_left, folded_right)
25
def make_reverse_comparator(cmp) :
    """Return a comparator that orders items opposite to *cmp*.

    The wrapped comparator is captured as the default value of the inner
    function's ``delegate`` argument; the result is the negation of it.
    """
    def reverse(left, right, delegate=cmp) :
        forward = delegate(left, right)
        return -forward
    return reverse
30
# Keys in the order lexical_cmp must produce (case-sensitive: the uppercase
# 'CCCP' sorts before every lowercase string).
_expected_lexical_test_data = ['', 'CCCP', 'a', 'aaa', 'b', 'c', 'cccce', 'ccccf']
# Keys in the order lowercase_cmp must produce (case-insensitive: the
# mixed-case 'CC' and 'CCCP' sit among their lowercase neighbours).
_expected_lowercase_test_data = ['', 'a', 'aaa', 'b', 'c', 'CC', 'cccce', 'ccccf', 'CCCP']
33
class ComparatorTests(unittest.TestCase) :
    """Check the module-level comparators with a pure-Python sort.

    No database is involved: each comparator sorts a list of strings and
    the result is compared with the expected order.
    """

    def comparator_test_helper(self, comparator, expected_data) :
        """Sort *expected_data* with *comparator*; assert the order survives.

        The items are fed in scrambled order (odd-indexed items, then the
        even-indexed ones).  Feeding them already sorted -- as this helper
        used to -- lets a stable sort leave the list untouched even when the
        comparator is broken (e.g. always 0), so the check could never fail.
        *expected_data* must not contain items the comparator treats as
        equal, otherwise the stable sort would keep them in scrambled order.
        """
        data = expected_data[1::2] + expected_data[::2]

        if sys.version_info < (2, 6) :
            data.sort(cmp=comparator)
        else :
            # Stable insertion sort driven only by the comparator.
            ordered = []
            for item in data :
                for position, existing in enumerate(ordered) :
                    # Any positive result means "existing sorts after item";
                    # comparators are not obliged to return exactly 1.
                    if comparator(existing, item) > 0 :
                        ordered.insert(position, item)
                        break
                else :
                    ordered.append(item)
            data = ordered

        self.assertEqual(data, expected_data,
                         "comparator `%s' is not right: %s vs. %s"
                         % (comparator, expected_data, data))

    def test_lexical_comparator(self) :
        self.comparator_test_helper(lexical_cmp, _expected_lexical_test_data)

    def test_reverse_lexical_comparator(self) :
        rev = _expected_lexical_test_data[:]
        rev.reverse()
        self.comparator_test_helper(make_reverse_comparator(lexical_cmp),
                                     rev)

    def test_lowercase_comparator(self) :
        self.comparator_test_helper(lowercase_cmp,
                                     _expected_lowercase_test_data)
66
class AbstractBtreeKeyCompareTestCase(unittest.TestCase) :
    """Shared fixture for tests that install a Btree key comparator.

    setUp() opens a DBEnv in a fresh directory.  createDB() creates a DB,
    lets setupDB() register the comparator via set_bt_compare(), then opens
    it as DB_BTREE.  check_results() walks the keys with a cursor and
    compares them with the expected order.  startTest(), finishTest() and
    check_results() are hooks subclasses may override.
    """
    env = None
    db = None

    # Fallback assertLess for the interpreter versions singled out here.
    if (sys.version_info < (2, 7)) or ((sys.version_info >= (3,0)) and
            (sys.version_info < (3, 2))) :
        def assertLess(self, a, b, msg=None) :
            return self.assertTrue(a<b, msg=msg)

    def setUp(self) :
        self.filename = self.__class__.__name__ + '.db'
        self.homeDir = get_new_environment_path()
        env = db.DBEnv()
        env.open(self.homeDir,
                  db.DB_CREATE | db.DB_INIT_MPOOL
                  | db.DB_INIT_LOCK | db.DB_THREAD)
        self.env = env

    def tearDown(self) :
        # Nested try/finally: a failure while closing the DB (or the env)
        # must not leave the environment open or its directory behind.
        try:
            self.closeDB()
        finally:
            try:
                if self.env is not None:
                    self.env.close()
                    self.env = None
            finally:
                test_support.rmtree(self.homeDir)

    def addDataToDB(self, data) :
        """Store every key in *data*; each value is its index as a string."""
        for i, item in enumerate(data) :
            self.db.put(item, str(i))

    def createDB(self, key_comparator) :
        """Create self.db, install *key_comparator*, then open it as a Btree."""
        self.db = db.DB(self.env)
        self.setupDB(key_comparator)
        self.db.open(self.filename, "test", db.DB_BTREE, db.DB_CREATE)

    def setupDB(self, key_comparator) :
        self.db.set_bt_compare(key_comparator)

    def closeDB(self) :
        """Close and forget self.db; a no-op when there is none."""
        if self.db is not None:
            self.db.close()
            self.db = None

    def startTest(self) :
        pass

    def finishTest(self, expected = None) :
        """Check the contents against *expected* (if given), then close."""
        if expected is not None:
            self.check_results(expected)
        self.closeDB()

    def check_results(self, expected) :
        """Assert a cursor walk yields exactly the keys in *expected*."""
        curs = self.db.cursor()
        try:
            index = 0
            rec = curs.first()
            while rec:
                key, ignore = rec
                self.assertLess(index, len(expected),
                                 "too many values returned from cursor")
                self.assertEqual(expected[index], key,
                                 "expected value `%s' at %d but got `%s'"
                                 % (expected[index], index, key))
                index = index + 1
                rec = curs.next()
            self.assertEqual(index, len(expected),
                             "not enough values returned from cursor")
        finally:
            curs.close()
137
class BtreeKeyCompareTestCase(AbstractBtreeKeyCompareTestCase) :
    """Verify that Btree key order follows the installed comparator."""

    def runCompareTest(self, comparator, data) :
        """Insert the keys of *data* scrambled; expect them back as *data*.

        *data* is the expected sorted order.  The keys are inserted as the
        odd-indexed items followed by the even-indexed ones.  Inserting them
        already sorted (or exactly reversed) could let a comparator that
        answers "greater" (or "less") for every distinct pair pass unnoticed.
        """
        self.startTest()
        self.createDB(comparator)
        self.addDataToDB(data[1::2] + data[::2])
        self.finishTest(data)

    def test_lexical_ordering(self) :
        self.runCompareTest(lexical_cmp, _expected_lexical_test_data)

    def test_reverse_lexical_ordering(self) :
        backwards = list(reversed(_expected_lexical_test_data))
        self.runCompareTest(make_reverse_comparator(lexical_cmp), backwards)

    def test_compare_function_useless(self) :
        self.startTest()
        def always_equal(left, right) :
            return 0
        self.createDB(always_equal)
        self.addDataToDB(['b', 'a', 'd'])
        # Every key compares equal, so the first key inserted is the only one
        # left in the database (holding the value of the last put).
        self.finishTest(['b'])
163
164
class BtreeExceptionsTestCase(AbstractBtreeKeyCompareTestCase) :
    """Error handling of set_bt_compare() and of misbehaving comparators."""

    def test_raises_non_callable(self) :
        self.startTest()
        self.assertRaises(TypeError, self.createDB, 'abc')
        self.assertRaises(TypeError, self.createDB, None)
        self.finishTest()

    def test_set_bt_compare_with_function(self) :
        self.startTest()
        self.createDB(lexical_cmp)
        self.finishTest()

    def check_results(self, results) :
        # These tests never inspect stored keys; result checks are a no-op.
        pass

    def test_compare_function_incorrect(self) :
        self.startTest()
        def bad_comparator(l, r) :
            return 1
        # verify that set_bt_compare checks that comparator('', '') == 0
        self.assertRaises(TypeError, self.createDB, bad_comparator)
        self.finishTest()

    def verifyStderr(self, method, successRe) :
        """
        Call method() while capturing sys.stderr output internally and
        call self.fail() if successRe.search() does not match the stderr
        output.  This is used to test for uncatchable exceptions.

        sys.stderr is always restored.  If method() itself raises, that
        exception propagates unchanged instead of being masked by a
        self.fail() raised from a ``finally`` clause.
        """
        stdErr = sys.stderr
        sys.stderr = StringIO()
        try:
            method()
        finally:
            captured = sys.stderr
            sys.stderr = stdErr
        errorOut = captured.getvalue()
        if not successRe.search(errorOut) :
            self.fail("unexpected stderr output:\n"+errorOut)
        if sys.version_info < (3, 0) :  # XXX: How to do this in Py3k ???
            sys.exc_traceback = sys.last_traceback = None

    def _test_compare_function_exception(self) :
        self.startTest()
        def bad_comparator(l, r) :
            if l == r:
                # pass the set_bt_compare test
                return 0
            # Call form: valid on Python 2 and 3, unlike "raise E, msg".
            raise RuntimeError("i'm a naughty comparison function")
        self.createDB(bad_comparator)
        # The comparator's errors do not propagate out of put(); the caller
        # expects two tracebacks on stderr instead.
        self.addDataToDB(['a', 'b', 'c'])
        self.finishTest()

    def test_compare_function_exception(self) :
        self.verifyStderr(
                self._test_compare_function_exception,
                re.compile('(^RuntimeError:.* naughty.*){2}', re.M|re.S)
        )

    def _test_compare_function_bad_return(self) :
        self.startTest()
        def bad_comparator(l, r) :
            if l == r:
                # pass the set_bt_compare test
                return 0
            return l
        self.createDB(bad_comparator)
        # Returning a non-int does not raise from put(); the caller expects
        # two "return an int" TypeErrors on stderr instead.
        self.addDataToDB(['a', 'b', 'c'])
        self.finishTest()

    def test_compare_function_bad_return(self) :
        self.verifyStderr(
                self._test_compare_function_bad_return,
                re.compile('(^TypeError:.* return an int.*){2}', re.M|re.S)
        )

    def test_cannot_assign_twice(self) :

        def my_compare(a, b) :
            return 0

        self.startTest()
        self.createDB(my_compare)
        self.assertRaises(RuntimeError, self.db.set_bt_compare, my_compare)
        self.finishTest()
252
class AbstractDuplicateCompareTestCase(unittest.TestCase) :
    """Shared fixture for tests that install a duplicate-data comparator.

    setUp() opens a DBEnv in a fresh directory.  createDB() creates a DB,
    lets setupDB() enable DB_DUPSORT and register the comparator via
    set_dup_compare(), then opens it as DB_BTREE.  All data is stored under
    the single key "key", and check_results() walks the records with a
    cursor and compares the data values with the expected order.
    """
    env = None
    db = None

    # Fallback assertLess for the interpreter versions singled out here.
    if (sys.version_info < (2, 7)) or ((sys.version_info >= (3,0)) and
            (sys.version_info < (3, 2))) :
        def assertLess(self, a, b, msg=None) :
            return self.assertTrue(a<b, msg=msg)

    def setUp(self) :
        self.filename = self.__class__.__name__ + '.db'
        self.homeDir = get_new_environment_path()
        env = db.DBEnv()
        env.open(self.homeDir,
                  db.DB_CREATE | db.DB_INIT_MPOOL
                  | db.DB_INIT_LOCK | db.DB_THREAD)
        self.env = env

    def tearDown(self) :
        # Nested try/finally: a failure while closing the DB (or the env)
        # must not leave the environment open or its directory behind.
        try:
            self.closeDB()
        finally:
            try:
                if self.env is not None:
                    self.env.close()
                    self.env = None
            finally:
                test_support.rmtree(self.homeDir)

    def addDataToDB(self, data) :
        """Store every item of *data* as a duplicate of the key "key"."""
        for item in data:
            self.db.put("key", item)

    def createDB(self, dup_comparator) :
        """Create self.db, install *dup_comparator*, then open it as a Btree."""
        self.db = db.DB(self.env)
        self.setupDB(dup_comparator)
        self.db.open(self.filename, "test", db.DB_BTREE, db.DB_CREATE)

    def setupDB(self, dup_comparator) :
        """Enable sorted duplicates and install *dup_comparator* for them."""
        self.db.set_flags(db.DB_DUPSORT)
        self.db.set_dup_compare(dup_comparator)

    def closeDB(self) :
        """Close and forget self.db; a no-op when there is none."""
        if self.db is not None:
            self.db.close()
            self.db = None

    def startTest(self) :
        pass

    def finishTest(self, expected = None) :
        """Check the contents against *expected* (if given), then close."""
        if expected is not None:
            self.check_results(expected)
        self.closeDB()

    def check_results(self, expected) :
        """Assert a cursor walk yields exactly the data in *expected*."""
        curs = self.db.cursor()
        try:
            index = 0
            rec = curs.first()
            while rec:
                ignore, data = rec
                self.assertLess(index, len(expected),
                                 "too many values returned from cursor")
                self.assertEqual(expected[index], data,
                                 "expected value `%s' at %d but got `%s'"
                                 % (expected[index], index, data))
                index = index + 1
                rec = curs.next()
            self.assertEqual(index, len(expected),
                             "not enough values returned from cursor")
        finally:
            curs.close()
322
class DuplicateCompareTestCase(AbstractDuplicateCompareTestCase) :
    """Verify that sorted duplicates follow the installed comparator."""

    def runCompareTest(self, comparator, data) :
        """Store *data* scrambled as duplicates; expect them back as *data*.

        *data* is the expected sorted order.  The items are inserted as the
        odd-indexed ones followed by the even-indexed ones.  Inserting them
        already sorted (or exactly reversed) could let a comparator that
        answers "greater" (or "less") for every distinct pair pass unnoticed.
        """
        self.startTest()
        self.createDB(comparator)
        self.addDataToDB(data[1::2] + data[::2])
        self.finishTest(data)

    def test_lexical_ordering(self) :
        self.runCompareTest(lexical_cmp, _expected_lexical_test_data)

    def test_reverse_lexical_ordering(self) :
        backwards = list(reversed(_expected_lexical_test_data))
        self.runCompareTest(make_reverse_comparator(lexical_cmp), backwards)
338
class DuplicateExceptionsTestCase(AbstractDuplicateCompareTestCase) :
    """Error handling of set_dup_compare() and of misbehaving comparators."""

    def test_raises_non_callable(self) :
        self.startTest()
        self.assertRaises(TypeError, self.createDB, 'abc')
        self.assertRaises(TypeError, self.createDB, None)
        self.finishTest()

    def test_set_dup_compare_with_function(self) :
        self.startTest()
        self.createDB(lexical_cmp)
        self.finishTest()

    def check_results(self, results) :
        # These tests never inspect stored data; result checks are a no-op.
        pass

    def test_compare_function_incorrect(self) :
        self.startTest()
        def bad_comparator(l, r) :
            return 1
        # verify that set_dup_compare checks that comparator('', '') == 0
        self.assertRaises(TypeError, self.createDB, bad_comparator)
        self.finishTest()

    def test_compare_function_useless(self) :
        self.startTest()
        def socialist_comparator(l, r) :
            return 0
        self.createDB(socialist_comparator)
        # DUPSORT does not allow "duplicate duplicates"
        self.assertRaises(db.DBKeyExistError, self.addDataToDB, ['b', 'a', 'd'])
        self.finishTest()

    def verifyStderr(self, method, successRe) :
        """
        Call method() while capturing sys.stderr output internally and
        call self.fail() if successRe.search() does not match the stderr
        output.  This is used to test for uncatchable exceptions.

        sys.stderr is always restored.  If method() itself raises, that
        exception propagates unchanged instead of being masked by a
        self.fail() raised from a ``finally`` clause.
        """
        stdErr = sys.stderr
        sys.stderr = StringIO()
        try:
            method()
        finally:
            captured = sys.stderr
            sys.stderr = stdErr
        errorOut = captured.getvalue()
        if not successRe.search(errorOut) :
            self.fail("unexpected stderr output:\n"+errorOut)
        if sys.version_info < (3, 0) :  # XXX: How to do this in Py3k ???
            sys.exc_traceback = sys.last_traceback = None

    def _test_compare_function_exception(self) :
        self.startTest()
        def bad_comparator(l, r) :
            if l == r:
                # pass the set_dup_compare test
                return 0
            # Call form: valid on Python 2 and 3, unlike "raise E, msg".
            raise RuntimeError("i'm a naughty comparison function")
        self.createDB(bad_comparator)
        # The comparator's errors do not propagate out of put(); the caller
        # expects two tracebacks on stderr instead.
        self.addDataToDB(['a', 'b', 'c'])
        self.finishTest()

    def test_compare_function_exception(self) :
        self.verifyStderr(
                self._test_compare_function_exception,
                re.compile('(^RuntimeError:.* naughty.*){2}', re.M|re.S)
        )

    def _test_compare_function_bad_return(self) :
        self.startTest()
        def bad_comparator(l, r) :
            if l == r:
                # pass the set_dup_compare test
                return 0
            return l
        self.createDB(bad_comparator)
        # Returning a non-int does not raise from put(); the caller expects
        # two "return an int" TypeErrors on stderr instead.
        self.addDataToDB(['a', 'b', 'c'])
        self.finishTest()

    def test_compare_function_bad_return(self) :
        self.verifyStderr(
                self._test_compare_function_bad_return,
                re.compile('(^TypeError:.* return an int.*){2}', re.M|re.S)
        )

    def test_cannot_assign_twice(self) :

        def my_compare(a, b) :
            return 0

        self.startTest()
        self.createDB(my_compare)
        self.assertRaises(RuntimeError, self.db.set_dup_compare, my_compare)
        self.finishTest()
435
def test_suite() :
    """Build a suite holding every test-case class of this module.

    Uses TestLoader.loadTestsFromTestCase rather than the deprecated
    unittest.makeSuite (removed in Python 3.13); both collect the
    ``test*`` methods of each class in sorted order.
    """
    loader = unittest.defaultTestLoader
    res = unittest.TestSuite()
    for case in (ComparatorTests,
                 BtreeExceptionsTestCase,
                 BtreeKeyCompareTestCase,
                 DuplicateExceptionsTestCase,
                 DuplicateCompareTestCase) :
        res.addTest(loader.loadTestsFromTestCase(case))
    return res
445
if __name__ == '__main__':
    # The suite factory defined in this module is test_suite(); there is no
    # 'suite' attribute, so defaultTest='suite' made direct runs fail.
    unittest.main(defaultTest = 'test_suite')
448