1# -*- coding: utf-8 -*-
2#
3#  SelfTest/Hash/common.py: Common code for Crypto.SelfTest.Hash
4#
5# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
6#
7# ===================================================================
8# The contents of this file are dedicated to the public domain.  To
9# the extent that dedication to the public domain is not available,
10# everyone is granted a worldwide, perpetual, royalty-free,
11# non-exclusive license to exercise all rights associated with the
12# contents of this file for any purpose whatsoever.
13# No rights are reserved.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
16# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
17# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
18# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
19# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
20# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
21# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22# SOFTWARE.
23# ===================================================================
24
25"""Self-testing for PyCrypto hash modules"""
26
27import re
28import sys
29import unittest
30import binascii
31import Crypto.Hash
32from binascii import hexlify, unhexlify
33from Crypto.Util.py3compat import b, tobytes
34from Crypto.Util.strxor import strxor_c
35
def t2b(hex_string):
    """Decode a hex string that may contain whitespace into bytes.

    *hex_string* is first converted with ``tobytes``.  Every ASCII whitespace
    byte (space, tab, newline, carriage return, form feed, vertical tab) is
    dropped, and the remaining hex digits are decoded with ``unhexlify``.
    """
    # bytes.split() with no separator splits on exactly the ASCII whitespace
    # set, so re-joining the pieces strips all of it.
    compact = b''.join(tobytes(hex_string).split())
    return unhexlify(compact)
39
40
41class HashDigestSizeSelfTest(unittest.TestCase):
42
43    def __init__(self, hashmod, description, expected, extra_params):
44        unittest.TestCase.__init__(self)
45        self.hashmod = hashmod
46        self.expected = expected
47        self.description = description
48        self.extra_params = extra_params
49
50    def shortDescription(self):
51        return self.description
52
53    def runTest(self):
54        if "truncate" not in self.extra_params:
55            self.failUnless(hasattr(self.hashmod, "digest_size"))
56            self.assertEquals(self.hashmod.digest_size, self.expected)
57        h = self.hashmod.new(**self.extra_params)
58        self.failUnless(hasattr(h, "digest_size"))
59        self.assertEquals(h.digest_size, self.expected)
60
61
62class HashSelfTest(unittest.TestCase):
63
64    def __init__(self, hashmod, description, expected, input, extra_params):
65        unittest.TestCase.__init__(self)
66        self.hashmod = hashmod
67        self.expected = expected.lower()
68        self.input = input
69        self.description = description
70        self.extra_params = extra_params
71
72    def shortDescription(self):
73        return self.description
74
75    def runTest(self):
76        h = self.hashmod.new(**self.extra_params)
77        h.update(self.input)
78
79        out1 = binascii.b2a_hex(h.digest())
80        out2 = h.hexdigest()
81
82        h = self.hashmod.new(self.input, **self.extra_params)
83
84        out3 = h.hexdigest()
85        out4 = binascii.b2a_hex(h.digest())
86
87        # PY3K: hexdigest() should return str(), and digest() bytes
88        self.assertEqual(self.expected, out1)   # h = .new(); h.update(data); h.digest()
89        if sys.version_info[0] == 2:
90            self.assertEqual(self.expected, out2)   # h = .new(); h.update(data); h.hexdigest()
91            self.assertEqual(self.expected, out3)   # h = .new(data); h.hexdigest()
92        else:
93            self.assertEqual(self.expected.decode(), out2)   # h = .new(); h.update(data); h.hexdigest()
94            self.assertEqual(self.expected.decode(), out3)   # h = .new(data); h.hexdigest()
95        self.assertEqual(self.expected, out4)   # h = .new(data); h.digest()
96
97        # Verify that the .new() method produces a fresh hash object, except
98        # for MD5 and SHA1, which are hashlib objects.  (But test any .new()
99        # method that does exist.)
100        if self.hashmod.__name__ not in ('Crypto.Hash.MD5', 'Crypto.Hash.SHA1') or hasattr(h, 'new'):
101            h2 = h.new()
102            h2.update(self.input)
103            out5 = binascii.b2a_hex(h2.digest())
104            self.assertEqual(self.expected, out5)
105
106
107class HashTestOID(unittest.TestCase):
108    def __init__(self, hashmod, oid, extra_params):
109        unittest.TestCase.__init__(self)
110        self.hashmod = hashmod
111        self.oid = oid
112        self.extra_params = extra_params
113
114    def runTest(self):
115        h = self.hashmod.new(**self.extra_params)
116        self.assertEqual(h.oid, self.oid)
117
118
119class ByteArrayTest(unittest.TestCase):
120
121    def __init__(self, module, extra_params):
122        unittest.TestCase.__init__(self)
123        self.module = module
124        self.extra_params = extra_params
125
126    def runTest(self):
127        data = b("\x00\x01\x02")
128
129        # Data can be a bytearray (during initialization)
130        ba = bytearray(data)
131
132        h1 = self.module.new(data, **self.extra_params)
133        h2 = self.module.new(ba, **self.extra_params)
134        ba[:1] = b'\xFF'
135        self.assertEqual(h1.digest(), h2.digest())
136
137        # Data can be a bytearray (during operation)
138        ba = bytearray(data)
139
140        h1 = self.module.new(**self.extra_params)
141        h2 = self.module.new(**self.extra_params)
142
143        h1.update(data)
144        h2.update(ba)
145
146        ba[:1] = b'\xFF'
147        self.assertEqual(h1.digest(), h2.digest())
148
149
150class MemoryViewTest(unittest.TestCase):
151
152    def __init__(self, module, extra_params):
153        unittest.TestCase.__init__(self)
154        self.module = module
155        self.extra_params = extra_params
156
157    def runTest(self):
158
159        data = b"\x00\x01\x02"
160
161        def get_mv_ro(data):
162            return memoryview(data)
163
164        def get_mv_rw(data):
165            return memoryview(bytearray(data))
166
167        for get_mv in get_mv_ro, get_mv_rw:
168
169            # Data can be a memoryview (during initialization)
170            mv = get_mv(data)
171
172            h1 = self.module.new(data, **self.extra_params)
173            h2 = self.module.new(mv, **self.extra_params)
174            if not mv.readonly:
175                mv[:1] = b'\xFF'
176            self.assertEqual(h1.digest(), h2.digest())
177
178            # Data can be a memoryview (during operation)
179            mv = get_mv(data)
180
181            h1 = self.module.new(**self.extra_params)
182            h2 = self.module.new(**self.extra_params)
183            h1.update(data)
184            h2.update(mv)
185            if not mv.readonly:
186                mv[:1] = b'\xFF'
187            self.assertEqual(h1.digest(), h2.digest())
188
189
class MACSelfTest(unittest.TestCase):
    """Known-answer test for a MAC module.

    ``result``, ``data`` and ``key`` are given as hex strings (whitespace is
    allowed) and decoded with ``t2b``. ``runTest`` checks:

    - ``digest()`` and ``hexdigest()``;
    - ``verify()``/``hexverify()`` with the correct tag and with wrong tags;
    - data passed directly to ``new()``;
    - ``copy()``, when the MAC object implements it.
    """

    def __init__(self, module, description, result, data, key, params):
        """
        :param module: module exposing ``new(key, [data], **params)``.
        :param description: text returned by :meth:`shortDescription`.
        :param result: expected MAC tag, as a hex string.
        :param data: message, as a hex string.
        :param key: key, as a hex string.
        :param params: extra keyword arguments forwarded to ``module.new``.
        """
        unittest.TestCase.__init__(self)
        self.module = module
        self.result = t2b(result)
        self.data = t2b(data)
        self.key = t2b(key)
        self.params = params
        self.description = description

    def shortDescription(self):
        return self.description

    def runTest(self):
        result_hex = hexlify(self.result)            # bytes
        result_hex_str = result_hex.decode('ascii')  # str

        # Verify result
        h = self.module.new(self.key, **self.params)
        h.update(self.data)
        self.assertEqual(self.result, h.digest())
        self.assertEqual(result_hex_str, h.hexdigest())

        # Verify that correct MAC does not raise any exception
        h.verify(self.result)
        h.hexverify(result_hex)

        # Verify that incorrect MACs raise ValueError. Use a corrupted copy of
        # the correct tag (strxor_c with 255) and an unrelated hex string.
        wrong_mac = strxor_c(self.result, 255)
        self.assertRaises(ValueError, h.verify, wrong_mac)
        self.assertRaises(ValueError, h.hexverify, "4556")

        # Verify again, with data passed to new()
        h = self.module.new(self.key, self.data, **self.params)
        self.assertEqual(self.result, h.digest())
        self.assertEqual(result_hex_str, h.hexdigest())

        # Test .copy() when it is supported. Only the copy() calls are
        # guarded, so a NotImplementedError raised by update() or digest()
        # is not mistaken for "copy() not implemented" and silently ignored.
        h = self.module.new(self.key, self.data, **self.params)
        try:
            h2 = h.copy()
            h3 = h.copy()
        except NotImplementedError:
            pass
        else:
            # Verify that changing the copy does not change the original
            h2.update(b"bla")
            self.assertEqual(h3.digest(), self.result)

            # Verify that both can reach the same state
            h.update(b"bla")
            self.assertEqual(h.digest(), h2.digest())

        # PY3K: Check that hexdigest() returns str and digest() returns bytes
        self.assertTrue(isinstance(h.digest(), type(b"")))
        self.assertTrue(isinstance(h.hexdigest(), type("")))

        # PY3K: Check that .hexverify() accepts bytes or str
        h.hexverify(h.hexdigest())
        h.hexverify(h.hexdigest().encode('ascii'))
251
252
def make_hash_tests(module, module_name, test_data, digest_size, oid=None,
                    extra_params=None):
    """Build the list of self-tests for a hash module.

    :param module: hash module under test.
    :param module_name: prefix used in the test descriptions.
    :param test_data: sequence of rows ``(expected_hex, input[, description])``.
        The first two items are converted with ``tobytes``. When the
        description is missing, ``repr`` of the converted input is used.
    :param digest_size: expected digest size (see HashDigestSizeSelfTest).
    :param oid: expected ``oid`` of hash objects. The OID test is added only
        when this is not ``None``.
    :param extra_params: keyword arguments forwarded to ``module.new`` in every
        test. ``None`` (the default) means no extra arguments.
    :return: list of ``unittest.TestCase`` instances.
    """
    # A None sentinel replaces the former ``{}`` default. A literal default
    # would be a single dict shared by every call and every test built here.
    if extra_params is None:
        extra_params = {}

    tests = []
    for index, row in enumerate(test_data):
        expected, data = map(tobytes, row[0:2])
        description = row[2] if len(row) >= 3 else repr(data)
        name = "%s #%d: %s" % (module_name, index + 1, description)
        tests.append(HashSelfTest(module, name, expected, data, extra_params))

    name = "%s #%d: digest_size" % (module_name, len(test_data) + 1)
    tests.append(HashDigestSizeSelfTest(module, name, digest_size, extra_params))

    if oid is not None:
        tests.append(HashTestOID(module, oid, extra_params))

    tests.append(ByteArrayTest(module, extra_params))
    tests.append(MemoryViewTest(module, extra_params))

    return tests
277
278
def make_mac_tests(module, module_name, test_data):
    """Build one MACSelfTest per row of *test_data*.

    Each row is ``(key, data, result, description[, params])``. The first
    three items are hex strings, as consumed by MACSelfTest. A 4-item row
    gets an empty ``params`` dict. Rows of any other length fail when they
    are unpacked.
    """
    suite = []
    for number, entry in enumerate(test_data, 1):
        if len(entry) == 4:
            fields = tuple(entry) + ({},)
        else:
            fields = entry
        key, data, results, description, params = fields
        label = "%s #%d: %s" % (module_name, number, description)
        suite.append(MACSelfTest(module, label, results, data, key, params))
    return suite
289
290# vim:set ts=4 sw=4 sts=4 expandtab:
291