1import unittest
2import copy
3from util import *
4
# Version values found in the first 4 bytes of a serialized extended key
VER_MAIN_PUBLIC = 0x0488B21E
VER_MAIN_PRIVATE = 0x0488ADE4
VER_TEST_PUBLIC = 0x043587CF
VER_TEST_PRIVATE = 0x04358394

# Flags passed to the bip32 calls exercised by the tests below
FLAG_KEY_PRIVATE = 0x0
FLAG_KEY_PUBLIC = 0x1
FLAG_SKIP_HASH = 0x2
ALL_DEFINED_FLAGS = FLAG_KEY_PRIVATE | FLAG_KEY_PUBLIC | FLAG_SKIP_HASH
# Serialized layout: 4 + 1 + 4 + 4 + 32 + 33 bytes
BIP32_SERIALIZED_LEN = 4 + 1 + 4 + 4 + 32 + 33
# Alternate spelling of FLAG_SKIP_HASH used by some tests
BIP32_FLAG_SKIP_HASH = 0x2
14
15# These vectors are expressed in binary rather than base 58. The spec base 58
16# representation just obfuscates the data we are validating. For example, the
17# chain codes in pub/priv results can be seen as equal in the hex data only.
18#
19# The vector results are the serialized resulting extended key using either the
20# contained public or private key. This is not to be confused with private or
21# public derivation - these vectors only derive privately.
vec_1 = {
    # Layout: 'seed' holds the hex seed for the master key; every other entry
    # is keyed by derivation path and maps FLAG_KEY_PUBLIC/FLAG_KEY_PRIVATE to
    # the expected serialized key in hex. Each value is 82 bytes long: the
    # 78 byte serialization followed by 4 check bytes.
    # In do_test_vector an 'H' path element corresponds to a child number
    # with the top bit set (e.g. 'm/0H' -> 0x80000000).
    'seed':               '000102030405060708090a0b0c0d0e0f',

    'm': {
        FLAG_KEY_PUBLIC:  '0488B21E000000000000000000873DFF'
                          '81C02F525623FD1FE5167EAC3A55A049'
                          'DE3D314BB42EE227FFED37D5080339A3'
                          '6013301597DAEF41FBE593A02CC513D0'
                          'B55527EC2DF1050E2E8FF49C85C2AB473B21',

        FLAG_KEY_PRIVATE: '0488ADE4000000000000000000873DFF'
                          '81C02F525623FD1FE5167EAC3A55A049'
                          'DE3D314BB42EE227FFED37D50800E8F3'
                          '2E723DECF4051AEFAC8E2C93C9C5B214'
                          '313817CDB01A1494B917C8436B35E77E9D71'
    },

    'm/0H': {
        FLAG_KEY_PUBLIC:  '0488B21E013442193E8000000047FDAC'
                          'BD0F1097043B78C63C20C34EF4ED9A11'
                          '1D980047AD16282C7AE6236141035A78'
                          '4662A4A20A65BF6AAB9AE98A6C068A81'
                          'C52E4B032C0FB5400C706CFCCC56B8B9C580',

        FLAG_KEY_PRIVATE: '0488ADE4013442193E8000000047FDAC'
                          'BD0F1097043B78C63C20C34EF4ED9A11'
                          '1D980047AD16282C7AE623614100EDB2'
                          'E14F9EE77D26DD93B4ECEDE8D16ED408'
                          'CE149B6CD80B0715A2D911A0AFEA0A794DEC'
    },

    'm/0H/1': {
        FLAG_KEY_PUBLIC:  '0488B21E025C1BD648000000012A7857'
                          '631386BA23DACAC34180DD1983734E44'
                          '4FDBF774041578E9B6ADB37C1903501E'
                          '454BF00751F24B1B489AA925215D66AF'
                          '2234E3891C3B21A52BEDB3CD711C6F6E2AF7',

        FLAG_KEY_PRIVATE: '0488ADE4025C1BD648000000012A7857'
                          '631386BA23DACAC34180DD1983734E44'
                          '4FDBF774041578E9B6ADB37C19003C6C'
                          'B8D0F6A264C91EA8B5030FADAA8E538B'
                          '020F0A387421A12DE9319DC93368B34BC442'
    },

    'm/0H/1/2H': {
        FLAG_KEY_PUBLIC:  '0488B21E03BEF5A2F98000000204466B'
                          '9CC8E161E966409CA52986C584F07E9D'
                          'C81F735DB683C3FF6EC7B1503F0357BF'
                          'E1E341D01C69FE5654309956CBEA5168'
                          '22FBA8A601743A012A7896EE8DC2A5162AFA',

        FLAG_KEY_PRIVATE: '0488ADE403BEF5A2F98000000204466B'
                          '9CC8E161E966409CA52986C584F07E9D'
                          'C81F735DB683C3FF6EC7B1503F00CBCE'
                          '0D719ECF7431D88E6A89FA1483E02E35'
                          '092AF60C042B1DF2FF59FA424DCA25814A3A'
    }
}
81
vec_3 = {
    # Same layout as vec_1 but with a 64 byte seed and only the 'm' and
    # 'm/0H' paths; do_test_vector skips the paths that are not present.
    # NOTE(review): presumably BIP32 test vector 3 - confirm against the spec.
    'seed':               '4B381541583BE4423346C643850DA4B3'
                          '20E46A87AE3D2A4E6DA11EBA819CD4AC'
                          'BA45D239319AC14F863B8D5AB5A0D0C6'
                          '4D2E8A1E7D1457DF2E5A3C51C73235BE',

    'm': {
        FLAG_KEY_PUBLIC:  '0488B21E00000000000000000001D28A'
                          '3E53CFFA419EC122C968B3259E16B650'
                          '76495494D97CAE10BBFEC3C36F03683A'
                          'F1BA5743BDFC798CF814EFEEAB2735EC'
                          '52D95ECED528E692B8E34C4E56696541E136',

        FLAG_KEY_PRIVATE: '0488ADE400000000000000000001D28A'
                          '3E53CFFA419EC122C968B3259E16B650'
                          '76495494D97CAE10BBFEC3C36F0000DD'
                          'B80B067E0D4993197FE10F2657A844A3'
                          '84589847602D56F0C629C81AAE3233C0C6BF'
    },

    'm/0H': {
        FLAG_KEY_PUBLIC:  '0488B21E0141D63B5080000000E5FEA1'
                          '2A97B927FC9DC3D2CB0D1EA1CF50AA5A'
                          '1FDC1F933E8906BB38DF3377BD026557'
                          'FDDA1D5D43D79611F784780471F086D5'
                          '8E8126B8C40ACB82272A7712E7F20158D8FD',

        FLAG_KEY_PRIVATE: '0488ADE40141D63B5080000000E5FEA1'
                          '2A97B927FC9DC3D2CB0D1EA1CF50AA5A'
                          '1FDC1F933E8906BB38DF3377BD00491F'
                          '7A2EEBC7B57028E0D3FAA0ACDA02E75C'
                          '33B03C48FB288C41E2EA44E1DAEF7332BB35'
    }
}
116
117class BIP32Tests(unittest.TestCase):
118
119    NULL_HASH160 = '00' * 20
120    SERIALIZED_LEN = 4 + 1 + 4 + 4 + 32 + 33
121
122    def unserialize_key(self, buf, buf_len):
123        key_out = ext_key()
124        ret = bip32_key_unserialize(buf, buf_len, byref(key_out))
125        return ret, key_out
126
127    def get_test_master_key(self, vec):
128        seed, seed_len = make_cbuffer(vec['seed'])
129        master = ext_key()
130        ret = bip32_key_from_seed(seed, seed_len,
131                                  VER_MAIN_PRIVATE, 0, byref(master))
132        self.assertEqual(ret, WALLY_OK)
133        return master
134
135    def get_test_key(self, vec, path, flags):
136        buf, buf_len = make_cbuffer(vec[path][flags])
137        ret, key_out = self.unserialize_key(buf, self.SERIALIZED_LEN)
138        self.assertEqual(ret, WALLY_OK)
139        return key_out
140
141    def derive_key(self, parent, child_num, flags):
142        key_out = ext_key()
143        ret = bip32_key_from_parent(byref(parent), child_num,
144                                    flags, byref(key_out))
145        self.assertEqual(ret, WALLY_OK)
146
147        # Verify that path derivation matches also
148        p_key_out = self.derive_key_by_path(parent, [child_num], flags)
149        self.compare_keys(p_key_out, key_out, flags)
150        return key_out
151
152    def path_to_c(self, path):
153        c_path = (c_uint * len(path))()
154        for i, n in enumerate(path):
155            c_path[i] = n
156        return c_path
157
158    def derive_key_by_path(self, parent, path, flags, expected=WALLY_OK):
159        key_out = ext_key()
160        c_path = self.path_to_c(path)
161        ret = bip32_key_from_parent_path(byref(parent), c_path, len(path),
162                                         flags, byref(key_out))
163        self.assertEqual(ret, expected)
164        return key_out
165
166    def compare_keys(self, key, expected, flags):
167        self.assertEqual(h(key.chain_code), h(expected.chain_code))
168        key_name = 'pub_key' if (flags & FLAG_KEY_PUBLIC) else 'priv_key'
169        expected_cmp = getattr(expected, key_name)
170        key_cmp = getattr(key, key_name)
171        self.assertEqual(h(key_cmp), h(expected_cmp))
172        self.assertEqual(key.depth, expected.depth)
173        self.assertEqual(key.child_num, expected.child_num)
174        self.assertEqual(h(key.chain_code), h(expected.chain_code))
175        # These would be more useful tests if there were any public
176        # derivation test vectors
177        # We can only compare the first 4 bytes of the parent fingerprint
178        # Since that is all thats serialized.
179        # FIXME: Implement bip32_key_set_parent and test it here
180        b32 = lambda k: h(k)[0:8]
181        if flags & FLAG_SKIP_HASH:
182            self.assertEqual(h(key.hash160), utf8(self.NULL_HASH160))
183            self.assertEqual(b32(key.parent160), utf8(self.NULL_HASH160[0:8]))
184        else:
185            self.assertEqual(h(key.hash160), h(expected.hash160))
186            self.assertEqual(b32(key.parent160), b32(expected.parent160))
187
188
189    def test_serialization(self):
190
191        # Try short, correct, long lengths. Trimming 8 chars is the correct
192        # length because the vector value contains 4 check bytes at the end.
193        for trim, expected in [(0, WALLY_EINVAL), (8, WALLY_OK), (16, WALLY_EINVAL)]:
194            serialized_hex = vec_1['m'][FLAG_KEY_PRIVATE][0:-trim]
195            buf, buf_len = make_cbuffer(serialized_hex)
196            ret, key_out = self.unserialize_key(buf, buf_len)
197            self.assertEqual(ret, expected)
198            if ret == WALLY_OK:
199                # Check this key serializes back to the same representation
200                buf, buf_len = make_cbuffer('0' * len(serialized_hex))
201                ret = bip32_key_serialize(key_out, FLAG_KEY_PRIVATE,
202                                          buf, buf_len)
203                self.assertEqual(ret, WALLY_OK)
204                self.assertEqual(h(buf).upper(), utf8(serialized_hex))
205
206        # Check correct and incorrect version numbers as well
207        # as mismatched key types and versions
208        ver_cases = [(VER_MAIN_PUBLIC,  FLAG_KEY_PUBLIC,  WALLY_OK),
209                     (VER_MAIN_PUBLIC,  FLAG_KEY_PRIVATE, WALLY_EINVAL),
210                     (VER_MAIN_PRIVATE, FLAG_KEY_PUBLIC,  WALLY_EINVAL),
211                     (VER_MAIN_PRIVATE, FLAG_KEY_PRIVATE, WALLY_OK),
212                     (VER_TEST_PUBLIC,  FLAG_KEY_PUBLIC,  WALLY_OK),
213                     (VER_TEST_PUBLIC , FLAG_KEY_PRIVATE, WALLY_EINVAL),
214                     (VER_TEST_PRIVATE, FLAG_KEY_PUBLIC,  WALLY_EINVAL),
215                     (VER_TEST_PRIVATE, FLAG_KEY_PRIVATE, WALLY_OK),
216                     (0x01111111,       FLAG_KEY_PUBLIC,  WALLY_EINVAL),
217                     (0x01111111,       FLAG_KEY_PRIVATE, WALLY_EINVAL)]
218
219        for ver, flags, expected in ver_cases:
220            no_ver = vec_1['m'][flags][8:-8]
221            v_str = '0' + hex(ver)[2:]
222            buf, buf_len = make_cbuffer(v_str + no_ver)
223            ret, _ = self.unserialize_key(buf, buf_len)
224            self.assertEqual(ret, expected)
225
226        # Check invalid arguments fail
227        master = self.get_test_master_key(vec_1)
228        pub = self.derive_key(master, 1, FLAG_KEY_PUBLIC)
229        key_out = ext_key()
230        cases = [
231            [~ALL_DEFINED_FLAGS, BIP32_SERIALIZED_LEN],
232            [FLAG_KEY_PRIVATE, BIP32_SERIALIZED_LEN],
233            [FLAG_KEY_PUBLIC, BIP32_SERIALIZED_LEN + 1],
234        ]
235        for (flags, len_out) in cases:
236            ret = bip32_key_serialize(byref(pub), flags, byref(key_out), len_out)
237            self.assertEqual(WALLY_EINVAL, ret)
238
239    def test_key_from_seed(self):
240
241        seed, seed_len = make_cbuffer(vec_1['seed'])
242        key_out = ext_key()
243
244        # Only private key versions can be used
245        ver_cases = [(VER_MAIN_PUBLIC,   0,               WALLY_EINVAL),
246                     (VER_MAIN_PRIVATE,  0,               WALLY_OK),
247                     (VER_TEST_PUBLIC,   0,               WALLY_EINVAL),
248                     (VER_TEST_PRIVATE,  0,               WALLY_OK),
249                     (VER_TEST_PRIVATE,  FLAG_KEY_PUBLIC, WALLY_EINVAL),
250                     (VER_TEST_PRIVATE,  FLAG_SKIP_HASH,  WALLY_OK)]
251        for ver, flags, expected in ver_cases:
252            ret = bip32_key_from_seed(seed, seed_len, ver, flags, byref(key_out))
253            self.assertEqual(ret, expected)
254
255    def test_key_init(self):
256        # Note we test bip32_key_init_alloc: it calls bip32_key_init internally
257        _, _, priv = self.create_master_pub_priv()
258
259        ver, depth, num = priv.version, priv.depth, priv.child_num
260        cc, cc_len = make_cbuffer(h(priv.chain_code))
261        pub_key, pub_key_len = make_cbuffer(h(priv.pub_key))
262        priv_key, priv_key_len = make_cbuffer(h(priv.priv_key)[2:])
263        h160, h160_len = make_cbuffer(h(priv.hash160))
264        p160, p160_len = make_cbuffer(h(priv.parent160))
265        key_out = POINTER(ext_key)()
266        valid_args = [ver, depth, num, cc, cc_len, pub_key, pub_key_len,
267                      priv_key, priv_key_len, h160, h160_len, p160, p160_len]
268
269        # Test cases
270        arg_diffs = [
271            (True,  12, p160_len),  # No change
272            (True,  12, 4),         # 4 byte fingerprint only
273            (False, 1,  256),       # Depth > 255
274            (False, 3,  None),      # Null chaincode, valid length
275            (False, 4,  15),        # Invalid chaincode length
276            (False, 5,  None),      # Null pub key, valid length
277            (False, 6,  15),        # Invalid pub key length
278            (False, 7,  None),      # Null priv key, valid length
279            (False, 8,  15),        # Invalid priv key length
280            (False, 9,  None),      # Null hash160, valid length
281            (False, 10, 15),        # Invalid hash160 length
282            (False, 11, None),      # Null parent160, valid length
283            (False, 12, 15),        # Invalid parent160 length
284        ]
285        for ok, idx, new_val in arg_diffs:
286            call_args = copy.deepcopy(valid_args) + [byref(key_out)]
287            call_args[idx] = new_val
288            ret = bip32_key_init_alloc(*call_args)
289            self.assertEqual(ret, WALLY_OK if ok else WALLY_EINVAL)
290
291    def test_bip32_vectors(self):
292        self.do_test_vector(vec_1)
293        self.do_test_vector(vec_3)
294
295    def do_test_vector(self, vec):
296
297        # BIP32 Test vector 1
298        master = self.get_test_master_key(vec)
299
300        # Chain m:
301        for flags in [FLAG_KEY_PUBLIC, FLAG_KEY_PRIVATE]:
302            expected = self.get_test_key(vec, 'm', flags)
303            self.compare_keys(master, expected, flags)
304
305        derived = master
306        for path, i in [('m/0H', 0x80000000),
307                        ('m/0H/1', 1),
308                        ('m/0H/1/2H', 0x80000002)]:
309
310            if path not in vec:
311                continue
312
313            # Derive a public and private child. Verify that the private child
314            # contains the public and private published vectors. Verify that
315            # the public child matches the public vector and has no private
316            # key. Finally, check that the child holds the correct parent hash.
317            parent160 = derived.hash160
318            derived_pub = self.derive_key(derived, i, FLAG_KEY_PUBLIC)
319            derived = self.derive_key(derived, i, FLAG_KEY_PRIVATE)
320            for flags in [FLAG_KEY_PUBLIC, FLAG_KEY_PRIVATE]:
321                expected = self.get_test_key(vec, path, flags)
322                self.compare_keys(derived, expected, flags)
323                if flags & FLAG_KEY_PUBLIC:
324                    self.compare_keys(derived_pub, expected, flags)
325                    # A neutered private key is indicated by
326                    # BIP32_FLAG_KEY_PUBLIC (0x1) as its first byte.
327                    self.assertEqual(h(derived_pub.priv_key), utf8('01' + '00' * 32))
328                self.assertEqual(h(derived.parent160), h(parent160))
329
330    def create_master_pub_priv(self):
331
332        # Start with BIP32 Test vector 1
333        master = self.get_test_master_key(vec_1)
334        # Derive the same child public and private keys from master
335        priv = self.derive_key(master, 1, FLAG_KEY_PRIVATE)
336        pub = self.derive_key(master, 1, FLAG_KEY_PUBLIC)
337        return master, pub, priv
338
339    def test_public_derivation_identities(self):
340
341        master, pub, priv = self.create_master_pub_priv()
342        # From the private child we can derive public and private keys
343        priv_pub = self.derive_key(priv, 1, FLAG_KEY_PUBLIC)
344        priv_priv = self.derive_key(priv, 1, FLAG_KEY_PRIVATE)
345        # From the public child we can only derive a public key
346        pub_pub = self.derive_key(pub, 1, FLAG_KEY_PUBLIC)
347
348        # Verify that trying to derive a private key doesn't work
349        key_out = ext_key()
350        ret = bip32_key_from_parent(byref(pub), 1,
351                                    FLAG_KEY_PRIVATE, byref(key_out))
352        self.assertEqual(ret, WALLY_EINVAL)
353
354        # Now our identities:
355        # The children share the same public key
356        self.assertEqual(h(pub.pub_key), h(priv.pub_key))
357        # The grand-children share the same public key
358        self.assertEqual(h(priv_pub.pub_key), h(priv_priv.pub_key))
359        self.assertEqual(h(priv_pub.pub_key), h(pub_pub.pub_key))
360        # The children and grand-children do not share the same public key
361        self.assertNotEqual(h(pub.pub_key), h(priv_pub.pub_key))
362
363        # Test path derivation with multiple child elements
364        for flags, expected in [(FLAG_KEY_PUBLIC,                   pub_pub),
365                                (FLAG_KEY_PRIVATE,                  priv_priv),
366                                (FLAG_KEY_PUBLIC  | FLAG_SKIP_HASH, pub_pub),
367                                (FLAG_KEY_PRIVATE | FLAG_SKIP_HASH, priv_priv)]:
368            path_derived = self.derive_key_by_path(master, [1, 1], flags)
369            self.compare_keys(path_derived, expected, flags)
370
371    def test_key_from_parent_invalid(self):
372        master, pub, priv = self.create_master_pub_priv()
373        key_out = byref(ext_key())
374
375        cases = [[None,        FLAG_KEY_PRIVATE,   key_out],  # Null parent
376                 [byref(priv), FLAG_KEY_PRIVATE,   None],     # Null output key
377                 [byref(pub),  ~ALL_DEFINED_FLAGS, key_out],  # Invalid flags (pub)
378                 [byref(priv), ~ALL_DEFINED_FLAGS, key_out]]  # Invalid flags (priv)
379
380        for key, flags, key_out in cases:
381            ret = bip32_key_from_parent(key, 1, flags, key_out)
382            self.assertEqual(ret, WALLY_EINVAL)
383
384        m = byref(master)
385        c_path = self.path_to_c([1, 1])
386        cases = [(None, len(c_path), FLAG_KEY_PRIVATE,   key_out), # Null parent
387                 (m,    len(c_path), FLAG_KEY_PRIVATE,   None),    # Null output key
388                 (m,    len(c_path), ~ALL_DEFINED_FLAGS, key_out), # Invalid flags
389                 (m,     0,          FLAG_KEY_PRIVATE,   key_out)] # Bad path length
390
391        for key, plen, flags, key_out in cases:
392            ret = bip32_key_from_parent_path(key, c_path, plen, flags, key_out)
393            self.assertEqual(ret, WALLY_EINVAL)
394
395        master.depth = 0xff # Cant derive from a parent of depth 255
396        ret = bip32_key_from_parent(m, 5, FLAG_KEY_PUBLIC, key_out)
397        self.assertEqual(ret, WALLY_EINVAL)
398        ret = bip32_key_from_parent_path(m, c_path, len(c_path), FLAG_KEY_PUBLIC, key_out)
399        self.assertEqual(ret, WALLY_EINVAL)
400
401    def test_free_invalid(self):
402        self.assertEqual(WALLY_EINVAL, bip32_key_free(None))
403
404    def test_base58(self):
405        key = self.create_master_pub_priv()[2]
406        buf, buf_len = make_cbuffer('00' * 78)
407
408        for flag in [FLAG_KEY_PRIVATE, FLAG_KEY_PUBLIC]:
409            self.assertEqual(bip32_key_serialize(key, flag, buf, buf_len), WALLY_OK)
410            exp_hex = h(buf).upper()
411
412            ret, out = bip32_key_to_base58(key, flag)
413            self.assertEqual(ret, WALLY_OK)
414
415            key_out = POINTER(ext_key)()
416            self.assertEqual(bip32_key_from_base58_alloc(utf8(out), byref(key_out)), WALLY_OK)
417            self.assertEqual(bip32_key_serialize(key_out, flag, buf, buf_len), WALLY_OK)
418            self.assertEqual(h(buf).upper(), exp_hex)
419
420    def test_strip_private_key(self):
421        self.assertEqual(bip32_key_strip_private_key(None), WALLY_EINVAL)
422
423        _, pub, priv = self.create_master_pub_priv()
424
425        self.assertEqual(priv.priv_key[0], FLAG_KEY_PRIVATE)
426        self.assertEqual(bip32_key_strip_private_key(priv), WALLY_OK)
427        self.assertEqual(priv.priv_key[0], FLAG_KEY_PUBLIC)
428        self.assertEqual(priv.priv_key[1:], [0] * 32)
429
430        self.assertEqual(bip32_key_strip_private_key(pub), WALLY_OK)
431        self.assertEqual(pub.priv_key[0], FLAG_KEY_PUBLIC)
432        self.assertEqual(pub.priv_key[1:], [0] * 32)
433
434    def test_get_fingerprint(self):
435        key = self.create_master_pub_priv()[2]
436        buf, buf_len = make_cbuffer('00' * 4)
437
438        self.assertEqual(bip32_key_get_fingerprint(key, buf, buf_len), WALLY_OK)
439        self.assertEqual(h(buf), b'bbe06d6a')
440
441        # As a sanity check, derive a child and ask for its parent fingerprint
442        child = self.derive_key(key, 0, FLAG_KEY_PUBLIC)
443        b32 = lambda k: h(k)[0:8]
444        self.assertEqual(b32(child.parent160), b'bbe06d6a')
445
446        # Check fingerprint when hash calculation was skipped during derivation
447        child = self.derive_key(key, 0, FLAG_KEY_PUBLIC | BIP32_FLAG_SKIP_HASH)
448        self.assertEqual(bip32_key_get_fingerprint(child, buf, buf_len), WALLY_OK)
449        self.assertEqual(h(buf), b'f09cb160')
450
451
# Run all tests in this module when the file is executed directly
if __name__ == '__main__':
    unittest.main()
454