1import os
2import unittest
3from openid.server.trustroot import TrustRoot
4
5
class _ParseTest(unittest.TestCase):
    """Check that one trust-root string parses with the expected sanity.

    ``sanity`` selects the expectation:

    * ``'sane'``   -- the string must parse and ``isSane()`` must be true;
    * ``'insane'`` -- the string must parse and ``isSane()`` must be false;
    * anything else (the fixture uses ``'bad'``) -- ``TrustRoot.parse``
      must return ``None``.
    """

    def __init__(self, sanity, desc, case):
        unittest.TestCase.__init__(self)
        self.desc = desc + ': ' + repr(case)
        self.case = case
        self.sanity = sanity

    def shortDescription(self):
        return self.desc

    def runTest(self):
        tr = TrustRoot.parse(self.case)
        # unittest assertion methods are used instead of bare ``assert`` so
        # the checks still run under ``python -O``.
        if self.sanity == 'sane':
            # Fail cleanly (not with AttributeError) if parsing returned None.
            self.assertIsNotNone(tr, '%r did not parse' % (self.case,))
            self.assertTrue(tr.isSane(), self.case)
        elif self.sanity == 'insane':
            self.assertIsNotNone(tr, '%r did not parse' % (self.case,))
            self.assertFalse(tr.isSane(), self.case)
        else:
            self.assertIsNone(
                tr, '%r should not parse, got %r' % (self.case, tr))
24
25
class _MatchTest(unittest.TestCase):
    """Check whether a trust root validates a given return-to URL.

    ``line`` holds two whitespace-separated tokens: the trust-root string
    and the URL to validate against it. ``match`` is truthy when
    ``validateURL`` is expected to succeed and falsy when it must fail.
    """

    def __init__(self, match, desc, line):
        unittest.TestCase.__init__(self)
        tr, rt = line.split()
        self.desc = desc + ': ' + repr(tr) + ' ' + repr(rt)
        self.tr = tr
        self.rt = rt
        self.match = match

    def shortDescription(self):
        return self.desc

    def runTest(self):
        tr = TrustRoot.parse(self.tr)
        self.assertIsNotNone(tr, self.tr)

        match = tr.validateURL(self.rt)
        # Use unittest assertions (not bare ``assert``) so the checks are not
        # stripped under ``python -O`` and failures carry a useful message.
        if self.match:
            self.assertTrue(
                match, '%r should match %r' % (self.rt, self.tr))
        else:
            self.assertFalse(
                match, '%r should not match %r' % (self.rt, self.tr))
47
48
def getTests(t, grps, head, dat):
    """Build test cases for every group in one fixture section.

    :param t: callable ``t(group, description, case)`` returning a test.
    :param grps: group keys, in the order the groups appear in ``dat``.
    :param head: section header text; stripped and used as the
        description prefix.
    :param dat: section body. Groups are delimited by a line of 40 dashes;
        the text before the first delimiter must be blank. Each group is a
        ``"N: description"`` header chunk followed by a chunk of ``N``
        newline-separated cases.
    :returns: list of tests, in file order.
    :raises ValueError: if ``dat`` does not match the expected layout.
    """
    sep = '-' * 40 + '\n'
    top = head.strip()
    gdat = [chunk.strip() for chunk in dat.split(sep)]
    # Explicit errors instead of ``assert`` so validation survives ``-O``.
    if gdat[0]:
        raise ValueError(
            'unexpected text before first group separator: %r' % (gdat[0],))
    if len(gdat) != len(grps) * 2 + 1:
        raise ValueError(
            'expected %d group(s), found %d chunk(s) after the preamble'
            % (len(grps), len(gdat) - 1))

    tests = []
    # Odd chunks are "N: description" headers; even chunks are case lists.
    for grp, header, body in zip(grps, gdat[1::2], gdat[2::2]):
        # maxsplit=1 lets the description itself contain ': '.
        n, desc = header.split(': ', 1)
        count = int(n)
        cases = body.split('\n')
        # ''.split('\n') == [''], so an empty group would otherwise count 1.
        if count == 0 and cases == ['']:
            cases = []
        if len(cases) != count:
            raise ValueError(
                'group %r declares %d case(s) but has %d'
                % (desc, count, len(cases)))
        tests.extend(t(grp, top + ' - ' + desc, case) for case in cases)
    return tests
64
65
def parseTests(data):
    """Turn the whole trust-root fixture text into a list of test cases.

    ``data`` is split on lines of 40 ``=`` characters. The text before the
    first separator must be blank. It is followed by four sections: the
    parse-test header, the parse-test groups (``bad``, ``insane``, ``sane``),
    the match-test header, and the match-test groups (match = 1, then
    no-match = 0).

    :raises ValueError: if ``data`` does not have that layout.
    """
    sep = '=' * 40 + '\n'
    parts = [part.strip() for part in data.split(sep)]
    # Explicit errors instead of ``assert`` so validation survives ``-O``.
    if parts[0]:
        raise ValueError(
            'unexpected text before first section separator: %r'
            % (parts[0],))
    if len(parts) != 5:
        raise ValueError(
            'expected 4 sections after the preamble, found %d'
            % (len(parts) - 1,))
    _, ph, pdat, mh, mdat = parts

    tests = []
    tests.extend(getTests(_ParseTest, ['bad', 'insane', 'sane'], ph, pdat))
    tests.extend(getTests(_MatchTest, [1, 0], mh, mdat))
    return tests
75
76
def pyUnitTests():
    """Return a ``unittest.TestSuite`` built from ``data/trustroot.txt``.

    The fixture path is resolved relative to this module's directory and
    read as UTF-8 text.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    test_data_file_name = os.path.join(here, 'data', 'trustroot.txt')
    # ``with`` guarantees the handle is closed even if reading fails.
    with open(test_data_file_name, encoding='utf-8') as test_data_file:
        test_data = test_data_file.read()

    tests = parseTests(test_data)
    return unittest.TestSuite(tests)
86
87
if __name__ == '__main__':
    import sys

    suite = pyUnitTests()
    runner = unittest.TextTestRunner()
    result = runner.run(suite)
    # Report failures through the exit status so shells/CI can detect them.
    sys.exit(0 if result.wasSuccessful() else 1)
92