1import unittest
2from util import *
3from binascii import unhexlify
4
# Output block sizes, in bytes, of HMAC-SHA256 and HMAC-SHA512. The vector
# test only exercises outputs whose length is a multiple of these.
PBKDF2_HMAC_SHA256_LEN = 32
PBKDF2_HMAC_SHA512_LEN = 64
6
class PBKDF2Case(object):
    """One PBKDF2-HMAC-SHA test vector parsed from a comma-separated line.

    Line format: HMAC_SHA_TYPE, PASSWORD, SALT, COST, EXPECTED
    """

    # Hash widths (in bits) that have a matching wally_pbkdf2_hmac_sha* call.
    _SUPPORTED_TYPES = (256, 512)

    def __init__(self, items):
        """Parse ``items``, the already-split fields of one vector line.

        :param items: sequence of strings: hash type (256 or 512),
            hex-encoded password, salt, iteration count, expected output.
        :raises ValueError: if the hash type is not 256 or 512. ``int()`` and
            ``unhexlify()`` also raise ``ValueError`` (or its subclass
            ``binascii.Error``) on malformed fields.
        """
        self.typ = int(items[0])
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, and the test would then silently treat any non-256
        # type as SHA-512.
        if self.typ not in self._SUPPORTED_TYPES:
            raise ValueError('Unsupported HMAC-SHA type %d (expected 256 or 512)'
                             % self.typ)
        self.passwd = unhexlify(items[1])
        # Salt stays as text here; the test converts it with make_cbuffer.
        self.salt = items[2]
        self.cost = int(items[3])
        # make_cbuffer (from util) yields a (buffer, length) pair.
        self.expected, self.expected_len = make_cbuffer(items[4])
16
17
class PBKDF2Tests(unittest.TestCase):
    """Test wally_pbkdf2_hmac_sha256/512 against the shared vector file."""

    # Keep unittest's default failure text alongside the per-vector messages
    # below (this is already the default on Python 3).
    longMessage = True

    # Vector file location, relative to root_dir (provided by util).
    VECTORS_PATH = 'src/data/pbkdf2_hmac_sha_vectors.txt'

    # Parsed vectors, loaded on the first setUp and shared by all test
    # methods. unittest builds a fresh instance per test, so a cache stored
    # on ``self`` would never be reused.
    _cached_cases = None

    @staticmethod
    def _load_cases(path):
        """Parse every non-blank, non-'#' line of ``path`` into a PBKDF2Case."""
        cases = []
        with open(path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('#'):
                    continue
                cases.append(PBKDF2Case(line.split(',')))
        return cases

    def setUp(self):
        """Expose the parsed test vectors as ``self.cases``."""
        cls = type(self)
        if cls._cached_cases is None:
            cls._cached_cases = self._load_cases(root_dir + self.VECTORS_PATH)
        self.cases = cls._cached_cases

    def test_pbkdf2_hmac_sha(self):
        """Derive each supported vector and compare with the expected output."""
        # Guard against a vacuous pass if the vector file is empty.
        self.assertTrue(self.cases, 'No PBKDF2 test vectors were loaded')

        # hash type -> (wally function, output block size in bytes)
        impls = {
            256: (wally_pbkdf2_hmac_sha256, PBKDF2_HMAC_SHA256_LEN),
            512: (wally_pbkdf2_hmac_sha512, PBKDF2_HMAC_SHA512_LEN),
        }

        # Some test vectors are nuts (e.g. 2097152 cost), so only run the
        # first few expensive (cost > 100) vectors of each type. Set these
        # to -1 to run the whole suite (only needed when refactoring the impl).
        crazy_budget = {256: 8, 512: 8}

        for case in self.cases:
            # An unknown type raises KeyError rather than silently running
            # the SHA-512 implementation.
            fn, mult = impls[case.typ]

            if case.cost > 100:
                if crazy_budget[case.typ] == 0:
                    continue
                crazy_budget[case.typ] -= 1

            if case.expected_len % mult != 0:
                # We only support output multiples of the hmac length
                continue

            out_buf, out_len = make_cbuffer('00' * case.expected_len)
            salt, salt_len = make_cbuffer(case.salt)
            ret = fn(case.passwd, len(case.passwd), salt, salt_len,
                     0, case.cost, out_buf, out_len)

            msg = 'HMAC-SHA%d vector with cost %d' % (case.typ, case.cost)
            self.assertEqual(ret, 0, msg)
            self.assertEqual(h(out_buf), h(case.expected), msg)

    def _pbkdf2_hmac_sha_malloc_fail(self, fn, out_size):
        """Call ``fn`` with dummy inputs and require it to return WALLY_ENOMEM.

        :param fn: wally_pbkdf2_hmac_sha256 or wally_pbkdf2_hmac_sha512.
        :param out_size: output buffer size in bytes.

        Used by the ``malloc_fail`` decorated tests below. That decorator is
        presumably what makes an allocation fail (confirm in util).
        """
        fake_buf, fake_len = make_cbuffer('aabbccdd')
        out_buf, out_len = make_cbuffer('00' * out_size)
        ret = fn(fake_buf, fake_len, fake_buf, fake_len, 0, 1, out_buf, out_len)
        self.assertEqual(ret, WALLY_ENOMEM)

    @malloc_fail([1])
    def test_pbkdf2_hmac_sha256_malloc(self):
        """SHA-256 variant must report WALLY_ENOMEM under injected malloc failure."""
        self._pbkdf2_hmac_sha_malloc_fail(wally_pbkdf2_hmac_sha256, PBKDF2_HMAC_SHA256_LEN)

    @malloc_fail([1])
    def test_pbkdf2_hmac_sha512_malloc(self):
        """SHA-512 variant must report WALLY_ENOMEM under injected malloc failure."""
        self._pbkdf2_hmac_sha_malloc_fail(wally_pbkdf2_hmac_sha512, PBKDF2_HMAC_SHA512_LEN)
84
if __name__ == "__main__":
    # Executed directly as a script: hand control to the unittest runner.
    unittest.main()
87