1"""
2Copyright (c) 2008-2020, Jesus Cea Avion <jcea@jcea.es>
3All rights reserved.
4
5Redistribution and use in source and binary forms, with or without
6modification, are permitted provided that the following conditions
7are met:
8
9    1. Redistributions of source code must retain the above copyright
10    notice, this list of conditions and the following disclaimer.
11
12    2. Redistributions in binary form must reproduce the above
13    copyright notice, this list of conditions and the following
14    disclaimer in the documentation and/or other materials provided
15    with the distribution.
16
17    3. Neither the name of Jesus Cea Avion nor the names of its
18    contributors may be used to endorse or promote products derived
19    from this software without specific prior written permission.
20
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
22    CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
23    INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
24    MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
25    DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
26    BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
27    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
28            TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29            DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
30    ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
31    TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
32    THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
33    SUCH DAMAGE.
34    """
35
36"""
Test cases for the Python DB duplicate and Btree key comparison functions.
38"""
39
40import sys, os, re
41import test_all
42from cStringIO import StringIO
43
44import unittest
45
46from test_all import db, dbshelve, test_support, \
47        get_new_environment_path, get_new_database_path
48
49
50# Needed for python 3. "cmp" vanished in 3.0.1
def cmp(a, b):
    """Three-way compare: return 0 if a == b, -1 if a < b, otherwise 1."""
    if a == b:
        return 0
    elif a < b:
        return -1
    else:
        return 1
55
# Plain lexical ordering: simply an alias for the module-level cmp() helper.
lexical_cmp = cmp
57
def lowercase_cmp(left, right):
    """Three-way compare of two strings after lowercasing both of them."""
    folded_left = left.lower()
    folded_right = right.lower()
    return cmp(folded_left, folded_right)
60
def make_reverse_comparator(cmp):
    """Return a comparator that yields the negated result of `cmp`.

    The wrapped comparator is bound as the default of the inner function's
    third parameter, so callers normally pass only the two items.
    """
    def reverse(left, right, delegate=cmp):
        outcome = delegate(left, right)
        return -outcome

    return reverse
65
# Sample strings listed in the order the lexical comparator tests expect.
_expected_lexical_test_data = ['', 'CCCP', 'a', 'aaa', 'b', 'c', 'cccce', 'ccccf']
# Sample strings listed in the order the lowercase comparator test expects.
_expected_lowercase_test_data = ['', 'a', 'aaa', 'b', 'c', 'CC', 'cccce', 'ccccf', 'CCCP']
68
class ComparatorTests(unittest.TestCase) :
    """Sanity checks for the pure-Python comparator functions of this module."""

    def comparator_test_helper(self, comparator, expected_data) :
        """Sort the items of expected_data with comparator and check the order.

        comparator    -- cmp-style callable taking two items and returning a
                         negative, zero or positive number.
        expected_data -- list whose order is what comparator should produce.

        Fails the test when the insertion sort driven by comparator does not
        reproduce expected_data.
        """
        # Feed the items in reverse order so the sort has real work to do.
        # Sorting already-ordered input would also accept a comparator that
        # never reports "greater than".
        scrambled = expected_data[:]
        scrambled.reverse()

        # Insertion sort driven purely by the comparator under test.
        data = []
        for item in scrambled :
            for position, existing in enumerate(data) :
                # Any positive value (not only 1) means "existing sorts
                # after item".
                if comparator(existing, item) > 0 :
                    data.insert(position, item)
                    break
            else :
                data.append(item)

        self.assertEqual(data, expected_data,
                         "comparator `%s' is not right: %s vs. %s"
                         % (comparator, expected_data, data))

    def test_lexical_comparator(self) :
        """lexical_cmp must reproduce the expected lexical ordering."""
        self.comparator_test_helper(lexical_cmp, _expected_lexical_test_data)

    def test_reverse_lexical_comparator(self) :
        """The reversed lexical comparator must produce the reversed list."""
        rev = _expected_lexical_test_data[:]
        rev.reverse()
        self.comparator_test_helper(make_reverse_comparator(lexical_cmp),
                                     rev)

    def test_lowercase_comparator(self) :
        """lowercase_cmp must reproduce the expected case-insensitive order."""
        self.comparator_test_helper(lowercase_cmp,
                                     _expected_lowercase_test_data)
99
class AbstractBtreeKeyCompareTestCase(unittest.TestCase) :
    """Shared fixture for the Btree key comparison tests.

    Each test gets a fresh DBEnv in a new directory.  Helper methods create
    a Btree database with a key comparator installed through
    set_bt_compare(), load keys into it and verify the order in which a
    cursor returns those keys.
    """
    env = None
    db = None

    # assertLess() is missing from unittest before Python 2.7.
    if sys.version_info < (2, 7) :
        def assertLess(self, a, b, msg=None) :
            return self.assertTrue(a<b, msg=msg)

    def setUp(self) :
        """Open a new environment in a freshly obtained directory."""
        self.filename = self.__class__.__name__ + '.db'
        self.homeDir = get_new_environment_path()
        env = db.DBEnv()
        env.open(self.homeDir,
                  db.DB_CREATE | db.DB_INIT_MPOOL
                  | db.DB_INIT_LOCK | db.DB_THREAD)
        self.env = env

    def tearDown(self) :
        """Close the database and environment, then remove the directory.

        Each cleanup step runs even if an earlier one raises, so a failing
        close() can no longer leak the environment or the directory.
        """
        try :
            try :
                self.closeDB()
            finally :
                if self.env is not None :
                    self.env.close()
                    self.env = None
        finally :
            test_support.rmtree(self.homeDir)

    def addDataToDB(self, data) :
        """Store every item of data as a key; the value is its index as str."""
        for i, item in enumerate(data) :
            self.db.put(item, str(i))

    def createDB(self, key_comparator) :
        """Create self.db, install key_comparator and open it as a Btree."""
        self.db = db.DB(self.env)
        self.setupDB(key_comparator)
        self.db.open(self.filename, "test", db.DB_BTREE, db.DB_CREATE)

    def setupDB(self, key_comparator) :
        """Install key_comparator on self.db via set_bt_compare()."""
        self.db.set_bt_compare(key_comparator)

    def closeDB(self) :
        """Close self.db if one is open and forget it."""
        if self.db is not None :
            self.db.close()
            self.db = None

    def startTest(self) :
        """Hook called at the start of each test; does nothing here."""
        pass

    def finishTest(self, expected = None) :
        """Optionally verify the stored keys against expected, then close."""
        if expected is not None :
            self.check_results(expected)
        self.closeDB()

    def check_results(self, expected) :
        """Assert that a cursor returns exactly the keys in expected, in order.

        The cursor is always closed, even when an assertion fails.
        """
        curs = self.db.cursor()
        try :
            index = 0
            rec = curs.first()
            while rec :
                key, ignore = rec
                # Guard the expected[index] lookup below against IndexError.
                self.assertLess(index, len(expected),
                                 "too many values returned from cursor")
                self.assertEqual(expected[index], key,
                                 "expected value `%s' at %d but got `%s'"
                                 % (expected[index], index, key))
                index = index + 1
                rec = curs.next()
            self.assertEqual(index, len(expected),
                             "not enough values returned from cursor")
        finally :
            curs.close()
169
class BtreeKeyCompareTestCase(AbstractBtreeKeyCompareTestCase):
    """Checks that keys come back in the order of the installed comparator."""

    def runCompareTest(self, comparator, data):
        """Store data under comparator and expect it back in the same order."""
        self.startTest()
        self.createDB(comparator)
        self.addDataToDB(data)
        self.finishTest(data)

    def test_lexical_ordering(self):
        self.runCompareTest(lexical_cmp, _expected_lexical_test_data)

    def test_reverse_lexical_ordering(self):
        reversed_comparator = make_reverse_comparator(lexical_cmp)
        descending = _expected_lexical_test_data[::-1]
        self.runCompareTest(reversed_comparator, descending)

    def test_compare_function_useless(self):
        def always_equal(l, r):
            return 0

        self.startTest()
        self.createDB(always_equal)
        self.addDataToDB(['b', 'a', 'd'])
        # Every key compares equal to every other one, so only the first
        # key survives in the database (holding the last key's value).
        self.finishTest(['b'])
195
196
class BtreeExceptionsTestCase(AbstractBtreeKeyCompareTestCase) :
    """Error handling of set_bt_compare() and of misbehaving comparators."""

    def check_results(self, results) :
        """These tests never inspect the stored keys, so do nothing."""
        pass

    def test_raises_non_callable(self) :
        self.startTest()
        self.assertRaises(TypeError, self.createDB, 'abc')
        self.assertRaises(TypeError, self.createDB, None)
        self.finishTest()

    def test_set_bt_compare_with_function(self) :
        self.startTest()
        self.createDB(lexical_cmp)
        self.finishTest()

    def test_compare_function_incorrect(self) :
        self.startTest()
        def bad_comparator(l, r) :
            return 1
        # verify that set_bt_compare checks that comparator('', '') == 0
        self.assertRaises(TypeError, self.createDB, bad_comparator)
        self.finishTest()

    def verifyStderr(self, method, successRe) :
        """
        Call method() while capturing sys.stderr output internally and
        call self.fail() if successRe.search() does not match the stderr
        output.  This is used to test for uncatchable exceptions.

        method    -- callable taking no arguments.
        successRe -- compiled regular expression the captured output
                     must match.
        """
        stdErr = sys.stderr
        captured = StringIO()
        sys.stderr = captured
        try :
            method()
        finally :
            # Only restore stderr here.  Calling self.fail() inside the
            # finally block would hide an exception raised by method().
            sys.stderr = stdErr
        errorOut = captured.getvalue()
        if not successRe.search(errorOut) :
            self.fail("unexpected stderr output:\n"+errorOut)
        if sys.version_info < (3, 0) :  # XXX: How to do this in Py3k ???
            sys.exc_traceback = sys.last_traceback = None

    def _test_compare_function_exception(self) :
        self.startTest()
        def bad_comparator(l, r) :
            if l == r :
                # pass the set_bt_compare test
                return 0
            # Call form of raise: valid in both Python 2 and Python 3.
            raise RuntimeError("i'm a naughty comparison function")
        self.createDB(bad_comparator)
        # The caller expects two RuntimeError reports on stderr.
        self.addDataToDB(['a', 'b', 'c'])  # this should raise, but...
        self.finishTest()

    def test_compare_function_exception(self) :
        self.verifyStderr(
                self._test_compare_function_exception,
                re.compile('(^RuntimeError:.* naughty.*){2}', re.M|re.S)
        )

    def _test_compare_function_bad_return(self) :
        self.startTest()
        def bad_comparator(l, r) :
            if l == r :
                # pass the set_bt_compare test
                return 0
            return l
        self.createDB(bad_comparator)
        # The caller expects two TypeError reports on stderr.
        self.addDataToDB(['a', 'b', 'c'])  # this should raise, but...
        self.finishTest()

    def test_compare_function_bad_return(self) :
        self.verifyStderr(
                self._test_compare_function_bad_return,
                re.compile('(^TypeError:.* return an int.*){2}', re.M|re.S)
        )

    def test_cannot_assign_twice(self) :

        def my_compare(a, b) :
            return 0

        self.startTest()
        self.createDB(my_compare)
        self.assertRaises(RuntimeError, self.db.set_bt_compare, my_compare)
284
class AbstractDuplicateCompareTestCase(unittest.TestCase) :
    """Shared fixture for the duplicate comparison tests.

    Each test gets a fresh DBEnv in a new directory.  Helper methods create
    a Btree database with DB_DUPSORT set and a comparator installed through
    set_dup_compare(), store several values under the single key "key" and
    verify the order in which a cursor returns those values.
    """
    env = None
    db = None

    # assertLess() is missing from unittest before Python 2.7.
    if sys.version_info < (2, 7) :
        def assertLess(self, a, b, msg=None) :
            return self.assertTrue(a<b, msg=msg)

    def setUp(self) :
        """Open a new environment in a freshly obtained directory."""
        self.filename = self.__class__.__name__ + '.db'
        self.homeDir = get_new_environment_path()
        env = db.DBEnv()
        env.open(self.homeDir,
                  db.DB_CREATE | db.DB_INIT_MPOOL
                  | db.DB_INIT_LOCK | db.DB_THREAD)
        self.env = env

    def tearDown(self) :
        """Close the database and environment, then remove the directory.

        Each cleanup step runs even if an earlier one raises, so a failing
        close() can no longer leak the environment or the directory.
        """
        try :
            try :
                self.closeDB()
            finally :
                if self.env is not None :
                    self.env.close()
                    self.env = None
        finally :
            test_support.rmtree(self.homeDir)

    def addDataToDB(self, data) :
        """Store every item of data as a value under the key "key"."""
        for item in data :
            self.db.put("key", item)

    def createDB(self, dup_comparator) :
        """Create self.db, install dup_comparator and open it as a Btree."""
        self.db = db.DB(self.env)
        self.setupDB(dup_comparator)
        self.db.open(self.filename, "test", db.DB_BTREE, db.DB_CREATE)

    def setupDB(self, dup_comparator) :
        """Set DB_DUPSORT on self.db, then install dup_comparator."""
        self.db.set_flags(db.DB_DUPSORT)
        self.db.set_dup_compare(dup_comparator)

    def closeDB(self) :
        """Close self.db if one is open and forget it."""
        if self.db is not None :
            self.db.close()
            self.db = None

    def startTest(self) :
        """Hook called at the start of each test; does nothing here."""
        pass

    def finishTest(self, expected = None) :
        """Optionally verify the stored values against expected, then close."""
        if expected is not None :
            self.check_results(expected)
        self.closeDB()

    def check_results(self, expected) :
        """Assert that a cursor returns exactly the values in expected, in order.

        The cursor is always closed, even when an assertion fails.
        """
        curs = self.db.cursor()
        try :
            index = 0
            rec = curs.first()
            while rec :
                ignore, data = rec
                # Guard the expected[index] lookup below against IndexError.
                self.assertLess(index, len(expected),
                                 "too many values returned from cursor")
                self.assertEqual(expected[index], data,
                                 "expected value `%s' at %d but got `%s'"
                                 % (expected[index], index, data))
                index = index + 1
                rec = curs.next()
            self.assertEqual(index, len(expected),
                             "not enough values returned from cursor")
        finally :
            curs.close()
353
class DuplicateCompareTestCase(AbstractDuplicateCompareTestCase):
    """Checks that duplicates come back in the installed comparator's order."""

    def runCompareTest(self, comparator, data):
        """Store data under comparator and expect it back in the same order."""
        self.startTest()
        self.createDB(comparator)
        self.addDataToDB(data)
        self.finishTest(data)

    def test_lexical_ordering(self):
        self.runCompareTest(lexical_cmp, _expected_lexical_test_data)

    def test_reverse_lexical_ordering(self):
        reversed_comparator = make_reverse_comparator(lexical_cmp)
        descending = _expected_lexical_test_data[::-1]
        self.runCompareTest(reversed_comparator, descending)
369
class DuplicateExceptionsTestCase(AbstractDuplicateCompareTestCase) :
    """Error handling of set_dup_compare() and of misbehaving comparators."""

    def check_results(self, results) :
        """These tests never inspect the stored values, so do nothing."""
        pass

    def test_raises_non_callable(self) :
        self.startTest()
        self.assertRaises(TypeError, self.createDB, 'abc')
        self.assertRaises(TypeError, self.createDB, None)
        self.finishTest()

    def test_set_dup_compare_with_function(self) :
        self.startTest()
        self.createDB(lexical_cmp)
        self.finishTest()

    def test_compare_function_incorrect(self) :
        self.startTest()
        def bad_comparator(l, r) :
            return 1
        # verify that set_dup_compare checks that comparator('', '') == 0
        self.assertRaises(TypeError, self.createDB, bad_comparator)
        self.finishTest()

    def test_compare_function_useless(self) :
        self.startTest()
        def socialist_comparator(l, r) :
            return 0
        self.createDB(socialist_comparator)
        # DUPSORT does not allow "duplicate duplicates"
        self.assertRaises(db.DBKeyExistError, self.addDataToDB, ['b', 'a', 'd'])
        self.finishTest()

    def verifyStderr(self, method, successRe) :
        """
        Call method() while capturing sys.stderr output internally and
        call self.fail() if successRe.search() does not match the stderr
        output.  This is used to test for uncatchable exceptions.

        method    -- callable taking no arguments.
        successRe -- compiled regular expression the captured output
                     must match.
        """
        stdErr = sys.stderr
        captured = StringIO()
        sys.stderr = captured
        try :
            method()
        finally :
            # Only restore stderr here.  Calling self.fail() inside the
            # finally block would hide an exception raised by method().
            sys.stderr = stdErr
        errorOut = captured.getvalue()
        if not successRe.search(errorOut) :
            self.fail("unexpected stderr output:\n"+errorOut)
        if sys.version_info < (3, 0) :  # XXX: How to do this in Py3k ???
            sys.exc_traceback = sys.last_traceback = None

    def _test_compare_function_exception(self) :
        self.startTest()
        def bad_comparator(l, r) :
            if l == r :
                # pass the set_dup_compare test
                return 0
            # Call form of raise: valid in both Python 2 and Python 3.
            raise RuntimeError("i'm a naughty comparison function")
        self.createDB(bad_comparator)
        # The caller expects two RuntimeError reports on stderr.
        self.addDataToDB(['a', 'b', 'c'])  # this should raise, but...
        self.finishTest()

    def test_compare_function_exception(self) :
        self.verifyStderr(
                self._test_compare_function_exception,
                re.compile('(^RuntimeError:.* naughty.*){2}', re.M|re.S)
        )

    def _test_compare_function_bad_return(self) :
        self.startTest()
        def bad_comparator(l, r) :
            if l == r :
                # pass the set_dup_compare test
                return 0
            return l
        self.createDB(bad_comparator)
        # The caller expects two TypeError reports on stderr.
        self.addDataToDB(['a', 'b', 'c'])  # this should raise, but...
        self.finishTest()

    def test_compare_function_bad_return(self) :
        self.verifyStderr(
                self._test_compare_function_bad_return,
                re.compile('(^TypeError:.* return an int.*){2}', re.M|re.S)
        )

    def test_cannot_assign_twice(self) :

        def my_compare(a, b) :
            return 0

        self.startTest()
        self.createDB(my_compare)
        self.assertRaises(RuntimeError, self.db.set_dup_compare, my_compare)
466
def test_suite() :
    """Return a TestSuite holding every test case class of this module.

    Uses the default test loader rather than unittest.makeSuite(), which is
    deprecated and no longer available in recent Python 3 releases.
    """
    load = unittest.defaultTestLoader.loadTestsFromTestCase
    res = unittest.TestSuite()
    for case in (ComparatorTests,
                 BtreeExceptionsTestCase,
                 BtreeKeyCompareTestCase,
                 DuplicateExceptionsTestCase,
                 DuplicateCompareTestCase) :
        res.addTest(load(case))
    return res
476
if __name__ == '__main__':
    # defaultTest must name an existing module attribute; the suite factory
    # in this file is called test_suite, not suite.
    unittest.main(defaultTest='test_suite')
479