1# Copyright (C) 2018 The Electrum developers
2# Distributed under the MIT software license, see the accompanying
3# file LICENCE or http://www.opensource.org/licenses/mit-license.php
4
import functools
import hashlib
from typing import List, Tuple, NamedTuple, Union, Iterable, Sequence, Optional
7
8from .util import bfh, bh2u, BitcoinException
9from . import constants
10from . import ecc
11from .crypto import hash_160, hmac_oneshot
12from .bitcoin import rev_hex, int_to_hex, EncodeBase58Check, DecodeBase58Check
13from .logging import get_logger
14
15
_logger = get_logger(__name__)
# bit 31 of a bip32 child index: when set, the level uses hardened derivation
BIP32_PRIME = 1 << 31
# largest value that fits into an unsigned 32-bit integer
UINT32_MAX = 0xFFFF_FFFF
19
20
def protect_against_invalid_ecpoint(func):
    """Decorator for bip32 child-key derivation functions.

    The wrapped function receives the child index as its last argument
    (positionally, or as the keyword ``child_index``) and is always called
    with it as the keyword ``child_index``. If it raises
    ecc.InvalidECPointException, the derivation is retried with the next
    index (BIP32: "proceed with the next value for i").

    Retrying never crosses the hardened/non-hardened boundary: if
    incrementing the index would flip the BIP32_PRIME bit, OverflowError
    is raised instead.
    """
    @functools.wraps(func)
    def func_wrapper(*args, **kwargs):
        if 'child_index' in kwargs:
            child_index = kwargs.pop('child_index')
        else:
            child_index = args[-1]
            args = args[:-1]
        # the hardened flag of the requested index must be preserved across retries
        is_prime = child_index & BIP32_PRIME
        while True:
            try:
                return func(*args, child_index=child_index, **kwargs)
            except ecc.InvalidECPointException:
                _logger.warning('bip32 protect_against_invalid_ecpoint: skipping index')
                child_index += 1
                if (child_index & BIP32_PRIME) != is_prime:
                    raise OverflowError()
    return func_wrapper
34
35
@protect_against_invalid_ecpoint
def CKD_priv(parent_privkey: bytes, parent_chaincode: bytes, child_index: int) -> Tuple[bytes, bytes]:
    """Child private key derivation function (from the parent private key).

    If child_index is hardened (i.e. the 32nd bit, BIP32_PRIME, is set), the
    resulting private key's corresponding public key can NOT be determined
    without the parent private key. If it is not hardened, the child's public
    key can also be derived from the parent public key alone (see CKD_pub).

    Returns (child_privkey, child_chaincode).
    Raises ValueError for a negative index, and OverflowError for an index
    that does not fit into 32 bits.
    """
    if child_index < 0: raise ValueError('the bip32 index needs to be non-negative')
    is_hardened_child = bool(child_index & BIP32_PRIME)
    return _CKD_priv(parent_privkey=parent_privkey,
                     parent_chaincode=parent_chaincode,
                     # ser32(i): 4-byte big-endian encoding of the index
                     child_index=child_index.to_bytes(length=4, byteorder='big', signed=False),
                     is_hardened_child=is_hardened_child)
50
51
def _CKD_priv(parent_privkey: bytes, parent_chaincode: bytes,
              child_index: bytes, is_hardened_child: bool) -> Tuple[bytes, bytes]:
    """Private child derivation with an already-serialized child index.

    Returns (child_privkey, child_chaincode), with child_privkey as 32
    big-endian bytes.
    Raises BitcoinException if parent_privkey is not a valid private key, and
    ecc.InvalidECPointException if the derived child key is invalid (the
    caller may then retry with another index).
    """
    try:
        keypair = ecc.ECPrivkey(parent_privkey)
    except ecc.InvalidECPointException as e:
        raise BitcoinException('Impossible xprv (not within curve order)') from e
    if is_hardened_child:
        # hardened: HMAC input is 0x00 || parent privkey || index
        data = bytes([0]) + parent_privkey + child_index
    else:
        # non-hardened: HMAC input is compressed parent pubkey || index.
        # The pubkey (an EC multiplication) is only needed on this branch.
        data = keypair.get_public_key_bytes(compressed=True) + child_index
    I = hmac_oneshot(parent_chaincode, data, hashlib.sha512)
    I_left = ecc.string_to_number(I[0:32])
    child_privkey = (I_left + ecc.string_to_number(parent_privkey)) % ecc.CURVE_ORDER
    if I_left >= ecc.CURVE_ORDER or child_privkey == 0:
        raise ecc.InvalidECPointException()
    child_privkey = int.to_bytes(child_privkey, length=32, byteorder='big', signed=False)
    child_chaincode = I[32:]
    return child_privkey, child_chaincode
71
72
73
@protect_against_invalid_ecpoint
def CKD_pub(parent_pubkey: bytes, parent_chaincode: bytes, child_index: int) -> Tuple[bytes, bytes]:
    """Child public key derivation function (from the parent public key only).

    This can only derive non-hardened children; a hardened child needs the
    parent private key (see CKD_priv).

    Returns (child_pubkey, child_chaincode), child_pubkey in compressed form.
    Raises ValueError for a negative index, Exception for a hardened index,
    and OverflowError for an index that does not fit into 32 bits.
    """
    if child_index < 0: raise ValueError('the bip32 index needs to be non-negative')
    if child_index & BIP32_PRIME: raise Exception('not possible to derive hardened child from parent pubkey')
    return _CKD_pub(parent_pubkey=parent_pubkey,
                    parent_chaincode=parent_chaincode,
                    # ser32(i): 4-byte big-endian encoding of the index
                    child_index=child_index.to_bytes(length=4, byteorder='big', signed=False))
85
86
# helper function, callable with arbitrary 'child_index' byte-string.
# i.e.: 'child_index' does not need to fit into 32 bits here! (c.f. trustedcoin billing)
def _CKD_pub(parent_pubkey: bytes, parent_chaincode: bytes, child_index: bytes) -> Tuple[bytes, bytes]:
    """Public child derivation with an already-serialized child index.

    The left half of HMAC-SHA512(parent_chaincode, parent_pubkey || child_index)
    is used as a scalar and added (as a point) to the parent public key; the
    right half becomes the child chain code.

    Returns (child_pubkey, child_chaincode), child_pubkey in compressed form.
    Raises ecc.InvalidECPointException if the resulting point is at infinity.
    """
    digest = hmac_oneshot(parent_chaincode, parent_pubkey + child_index, hashlib.sha512)
    tweak, child_chaincode = digest[:32], digest[32:]
    child_point = ecc.ECPrivkey(tweak) + ecc.ECPubkey(parent_pubkey)
    if child_point.is_at_infinity():
        raise ecc.InvalidECPointException()
    return child_point.get_public_key_bytes(compressed=True), child_chaincode
97
98
def xprv_header(xtype: str, *, net=None) -> bytes:
    """Return the 4 big-endian version bytes used to serialize an xprv of
    type `xtype`, taken from net.XPRV_HEADERS (net defaults to constants.net).
    """
    headers = (constants.net if net is None else net).XPRV_HEADERS
    return int.to_bytes(headers[xtype], length=4, byteorder="big")
103
104
def xpub_header(xtype: str, *, net=None) -> bytes:
    """Return the 4 big-endian version bytes used to serialize an xpub of
    type `xtype`, taken from net.XPUB_HEADERS (net defaults to constants.net).
    """
    headers = (constants.net if net is None else net).XPUB_HEADERS
    return int.to_bytes(headers[xtype], length=4, byteorder="big")
109
110
class InvalidMasterKeyVersionBytes(BitcoinException):
    """Raised by BIP32Node.from_xkey when an extended key's version bytes
    match neither the xprv nor the xpub headers of the network."""
112
113
class BIP32Node(NamedTuple):
    """A bip32 extended key node, either private (xprv) or public (xpub).

    Fields follow the bip32 serialization format:
      xtype        -- key type label; selects the version-bytes header
                      (see xprv_header / xpub_header)
      eckey        -- ecc.ECPrivkey for private nodes, ecc.ECPubkey for public ones
      chaincode    -- 32-byte chain code
      depth        -- number of derivation steps from the root (0 for the root)
      fingerprint  -- 4-byte fingerprint of the *parent* node
      child_number -- 4-byte big-endian child index of this node
    """
    xtype: str
    eckey: Union[ecc.ECPubkey, ecc.ECPrivkey]
    chaincode: bytes
    depth: int = 0
    fingerprint: bytes = b'\x00'*4  # as in serialized format, this is the *parent's* fingerprint
    child_number: bytes = b'\x00'*4

    @classmethod
    def from_xkey(cls, xkey: str, *, net=None) -> 'BIP32Node':
        """Parse a base58check-encoded extended key (xprv or xpub).

        The 4-byte header is looked up in net.XPRV_HEADERS_INV and then
        net.XPUB_HEADERS_INV (net defaults to constants.net) to determine
        both whether the key is private and its xtype.

        Raises:
            BitcoinException: the payload is not 78 bytes, or an xprv's key
                data is not prefixed with 0x00.
            InvalidMasterKeyVersionBytes: the header is unknown for `net`.
        """
        if net is None:
            net = constants.net
        xkey = DecodeBase58Check(xkey)
        if len(xkey) != 78:
            raise BitcoinException('Invalid length for extended key: {}'
                                   .format(len(xkey)))
        depth = xkey[4]
        fingerprint = xkey[5:9]
        child_number = xkey[9:13]
        chaincode = xkey[13:13 + 32]
        header = int.from_bytes(xkey[0:4], byteorder='big')
        if header in net.XPRV_HEADERS_INV:
            headers_inv = net.XPRV_HEADERS_INV
            is_private = True
        elif header in net.XPUB_HEADERS_INV:
            headers_inv = net.XPUB_HEADERS_INV
            is_private = False
        else:
            raise InvalidMasterKeyVersionBytes(f'Invalid extended key format: {hex(header)}')
        xtype = headers_inv[header]
        if is_private:
            # bip32 serializes private keys as 0x00 || ser256(k); reject anything else
            if xkey[13 + 32] != 0:
                raise BitcoinException('Invalid extended private key: key data not prefixed with 0x00')
            eckey = ecc.ECPrivkey(xkey[13 + 33:])
        else:
            eckey = ecc.ECPubkey(xkey[13 + 32:])
        return cls(xtype=xtype,
                   eckey=eckey,
                   chaincode=chaincode,
                   depth=depth,
                   fingerprint=fingerprint,
                   child_number=child_number)

    @classmethod
    def from_rootseed(cls, seed: bytes, *, xtype: str) -> 'BIP32Node':
        """Create the master (depth 0) private node from a bip32 seed.

        I = HMAC-SHA512(key=b"Bitcoin seed", seed); the left 32 bytes are the
        master private key, the right 32 bytes the master chain code.
        """
        I = hmac_oneshot(b"Bitcoin seed", seed, hashlib.sha512)
        master_k = I[0:32]
        master_c = I[32:]
        return cls(xtype=xtype,
                   eckey=ecc.ECPrivkey(master_k),
                   chaincode=master_c)

    @classmethod
    def from_bytes(cls, b: bytes, *, net=None) -> 'BIP32Node':
        """Parse a raw (78-byte, not base58-encoded) extended key.

        Raises Exception if `b` is not exactly 78 bytes long; otherwise
        behaves like from_xkey (and `net` is forwarded to it).
        """
        if len(b) != 78:
            raise Exception(f"unexpected xkey raw bytes len {len(b)} != 78")
        xkey = EncodeBase58Check(b)
        return cls.from_xkey(xkey, net=net)

    def to_xprv(self, *, net=None) -> str:
        """Return the base58check-encoded xprv string of this private node."""
        payload = self.to_xprv_bytes(net=net)
        return EncodeBase58Check(payload)

    def to_xprv_bytes(self, *, net=None) -> bytes:
        """Return the raw 78-byte xprv serialization of this node.

        Raises Exception if this node has no private key.
        """
        if not self.is_private():
            raise Exception("cannot serialize as xprv; private key missing")
        payload = (xprv_header(self.xtype, net=net) +
                   bytes([self.depth]) +
                   self.fingerprint +
                   self.child_number +
                   self.chaincode +
                   bytes([0]) +  # private keys are prefixed with 0x00
                   self.eckey.get_secret_bytes())
        assert len(payload) == 78, f"unexpected xprv payload len {len(payload)}"
        return payload

    def to_xpub(self, *, net=None) -> str:
        """Return the base58check-encoded xpub string of this node."""
        payload = self.to_xpub_bytes(net=net)
        return EncodeBase58Check(payload)

    def to_xpub_bytes(self, *, net=None) -> bytes:
        """Return the raw 78-byte xpub serialization of this node.

        Works for private nodes too; only the compressed public key is used.
        """
        payload = (xpub_header(self.xtype, net=net) +
                   bytes([self.depth]) +
                   self.fingerprint +
                   self.child_number +
                   self.chaincode +
                   self.eckey.get_public_key_bytes(compressed=True))
        assert len(payload) == 78, f"unexpected xpub payload len {len(payload)}"
        return payload

    def to_xkey(self, *, net=None) -> str:
        """Return the xprv string for a private node, else the xpub string."""
        if self.is_private():
            return self.to_xprv(net=net)
        else:
            return self.to_xpub(net=net)

    def to_bytes(self, *, net=None) -> bytes:
        """Return the raw xprv bytes for a private node, else the raw xpub bytes."""
        if self.is_private():
            return self.to_xprv_bytes(net=net)
        else:
            return self.to_xpub_bytes(net=net)

    def convert_to_public(self) -> 'BIP32Node':
        """Return a copy of this node holding only the public key (self if already public)."""
        if not self.is_private():
            return self
        pubkey = ecc.ECPubkey(self.eckey.get_public_key_bytes())
        return self._replace(eckey=pubkey)

    def is_private(self) -> bool:
        """Return whether this node holds a private key."""
        return isinstance(self.eckey, ecc.ECPrivkey)

    def subkey_at_private_derivation(self, path: Union[str, Iterable[int]]) -> 'BIP32Node':
        """Derive the private descendant node at `path`, relative to this node.

        `path` is a bip32 path string or an iterable of uint32 child indices;
        hardened and non-hardened levels are both allowed. An empty path
        returns self.

        Raises Exception if `path` is None or this node has no private key.
        """
        if path is None:
            raise Exception("derivation path must not be None")
        if isinstance(path, str):
            path = convert_bip32_path_to_list_of_uint32(path)
        else:
            # Materialize one-shot iterables: a generator is always truthy, so
            # an empty one would skip the early return below and then fail on
            # the unbound loop variables after the loop.
            path = list(path)
        if not self.is_private():
            raise Exception("cannot do bip32 private derivation; private key missing")
        if not path:
            return self
        depth = self.depth
        chaincode = self.chaincode
        privkey = self.eckey.get_secret_bytes()
        for child_index in path:
            parent_privkey = privkey
            privkey, chaincode = CKD_priv(privkey, chaincode, child_index)
            depth += 1
        parent_pubkey = ecc.ECPrivkey(parent_privkey).get_public_key_bytes(compressed=True)
        fingerprint = hash_160(parent_pubkey)[0:4]
        # NOTE(review): CKD_priv is wrapped by protect_against_invalid_ecpoint,
        # which may derive with a later index than requested if a derived key is
        # invalid; child_number here always records the requested index.
        child_number = child_index.to_bytes(length=4, byteorder="big")
        return BIP32Node(xtype=self.xtype,
                         eckey=ecc.ECPrivkey(privkey),
                         chaincode=chaincode,
                         depth=depth,
                         fingerprint=fingerprint,
                         child_number=child_number)

    def subkey_at_public_derivation(self, path: Union[str, Iterable[int]]) -> 'BIP32Node':
        """Derive the public descendant node at `path`, relative to this node.

        `path` is a bip32 path string or an iterable of uint32 child indices;
        all levels must be non-hardened (CKD_pub raises otherwise). Only the
        public key is used, so this also works on private nodes, but the
        result is always a public node. An empty path returns
        convert_to_public().

        Raises Exception if `path` is None.
        """
        if path is None:
            raise Exception("derivation path must not be None")
        if isinstance(path, str):
            path = convert_bip32_path_to_list_of_uint32(path)
        else:
            # Materialize one-shot iterables (see subkey_at_private_derivation).
            path = list(path)
        if not path:
            return self.convert_to_public()
        depth = self.depth
        chaincode = self.chaincode
        pubkey = self.eckey.get_public_key_bytes(compressed=True)
        for child_index in path:
            parent_pubkey = pubkey
            pubkey, chaincode = CKD_pub(pubkey, chaincode, child_index)
            depth += 1
        fingerprint = hash_160(parent_pubkey)[0:4]
        child_number = child_index.to_bytes(length=4, byteorder="big")
        return BIP32Node(xtype=self.xtype,
                         eckey=ecc.ECPubkey(pubkey),
                         chaincode=chaincode,
                         depth=depth,
                         fingerprint=fingerprint,
                         child_number=child_number)

    def calc_fingerprint_of_this_node(self) -> bytes:
        """Returns the fingerprint of this node: the first 4 bytes of
        hash_160 of its compressed public key.
        Note that self.fingerprint is of the *parent*.
        """
        # TODO cache this
        return hash_160(self.eckey.get_public_key_bytes(compressed=True))[0:4]
278
279
def xpub_type(x):
    """Return the xtype of extended key string `x` (an xprv works as well)."""
    node = BIP32Node.from_xkey(x)
    return node.xtype
282
283
def is_xpub(text):
    """Return whether `text` parses as a public extended key (xpub).

    Any parse failure yields False. Unlike a bare ``except:``, this does not
    swallow KeyboardInterrupt / SystemExit.
    """
    try:
        node = BIP32Node.from_xkey(text)
    except Exception:
        return False
    return not node.is_private()
290
291
def is_xprv(text):
    """Return whether `text` parses as a private extended key (xprv).

    Any parse failure yields False. Unlike a bare ``except:``, this does not
    swallow KeyboardInterrupt / SystemExit.
    """
    try:
        node = BIP32Node.from_xkey(text)
    except Exception:
        return False
    return node.is_private()
298
299
def xpub_from_xprv(xprv):
    """Return the xpub string for extended key string `xprv` (an xpub
    input is re-serialized unchanged)."""
    node = BIP32Node.from_xkey(xprv)
    return node.to_xpub()
302
303
def convert_bip32_path_to_list_of_uint32(n: str) -> List[int]:
    """Parse a bip32 derivation path string into uint32 child indices.

    Example: m/0/-1/1' -> [0, 0x80000001, 0x80000001]

    A level is hardened (BIP32_PRIME set) when it ends with "'" or "h", or
    when it starts with "-"; using both markers on one level is a
    ValueError. The leading "m" is optional, and empty levels (from
    repeated or trailing "/") are skipped, which makes concatenating paths
    easier. Non-numeric levels and indices above UINT32_MAX raise ValueError.

    based on code in trezorlib
    """
    if not n:
        return []
    if n.endswith("/"):
        n = n[:-1]
    levels = n.split('/')
    if levels[0] == "m":
        levels = levels[1:]
    indices = []
    for level in filter(None, levels):  # drop empty levels
        flag = BIP32_PRIME if level.endswith(("'", "h")) else 0
        if flag:
            level = level[:-1]
        if level.startswith('-'):
            if flag:
                raise ValueError(f"bip32 path child index is signalling hardened level in multiple ways")
            flag = BIP32_PRIME
        child_index = abs(int(level)) | flag
        if child_index > UINT32_MAX:
            raise ValueError(f"bip32 path child index too large: {child_index} > {UINT32_MAX}")
        indices.append(child_index)
    return indices
337
338
def convert_bip32_intpath_to_strpath(path: Sequence[int]) -> str:
    """Render uint32 child indices as a bip32 path string.

    Hardened indices get a "'" suffix, e.g. [0x8000002c, 0, 1] -> "m/44'/0/1";
    an empty path gives "m".
    Raises TypeError for a non-int element and ValueError for an element
    outside [0, UINT32_MAX].
    """
    levels = ["m"]
    for child_index in path:
        if not isinstance(child_index, int):
            raise TypeError(f"bip32 path child index must be int: {child_index}")
        if child_index < 0 or child_index > UINT32_MAX:
            raise ValueError(f"bip32 path child index out of range: {child_index}")
        if child_index & BIP32_PRIME:
            levels.append(str(child_index ^ BIP32_PRIME) + "'")
        else:
            levels.append(str(child_index))
    return "/".join(levels)
354
355
def is_bip32_derivation(s: str) -> bool:
    """Return whether `s` is a valid bip32 derivation path string.

    Unlike convert_bip32_path_to_list_of_uint32, an explicit leading "m" is
    required: `s` must be exactly "m" or start with "m/".
    """
    try:
        if not (s == 'm' or s.startswith('m/')):
            return False
        convert_bip32_path_to_list_of_uint32(s)
    except Exception:
        # e.g. a non-str argument, a non-numeric level, or an index out of
        # range; KeyboardInterrupt/SystemExit are deliberately not caught
        return False
    return True
365
366
def normalize_bip32_derivation(s: Optional[str]) -> Optional[str]:
    """Return the canonical string form of bip32 path `s` (None passes through).

    For example "m/44h/0h/" becomes "m/44'/0'".
    Raises ValueError if `s` is not accepted by is_bip32_derivation.
    """
    if s is None:
        return None
    if is_bip32_derivation(s):
        return convert_bip32_intpath_to_strpath(convert_bip32_path_to_list_of_uint32(s))
    raise ValueError(f"invalid bip32 derivation: {s}")
374
375
def is_all_public_derivation(path: Union[str, Iterable[int]]) -> bool:
    """Return True iff no level of `path` uses hardened derivation.

    `path` is a bip32 path string or an iterable of int child indices. Levels
    are checked in order; a negative index met before any hardened one
    raises ValueError.
    """
    indices = convert_bip32_path_to_list_of_uint32(path) if isinstance(path, str) else path
    for idx in indices:
        if idx < 0:
            raise ValueError('the bip32 index needs to be non-negative')
        elif idx & BIP32_PRIME:
            return False
    return True
386
387
def root_fp_and_der_prefix_from_xkey(xkey: str) -> Tuple[Optional[str], Optional[str]]:
    """Returns the root bip32 fingerprint and the derivation path from the
    root to the given xkey, if they can be determined. Otherwise (None, None).

    Only depth 0 (the xkey is the root itself: prefix "m", fingerprint of the
    node) and depth 1 (the parent is the root: its fingerprint is stored in
    the xkey) can be resolved; deeper xkeys give (None, None).
    """
    node = BIP32Node.from_xkey(xkey)
    assert node.depth >= 0, node.depth
    if node.depth == 0:
        prefix = 'm'
        return node.calc_fingerprint_of_this_node().hex().lower(), prefix
    if node.depth == 1:
        prefix = convert_bip32_intpath_to_strpath([int.from_bytes(node.child_number, 'big')])
        return node.fingerprint.hex(), prefix
    return None, None
404
405
def is_xkey_consistent_with_key_origin_info(xkey: str, *,
                                            derivation_prefix: str = None,
                                            root_fingerprint: str = None) -> bool:
    """Check `xkey` against optional key-origin information.

    Returns False if any provided information contradicts what the xkey
    itself records:
      - derivation_prefix must have exactly `depth` levels, and (for
        depth > 0) its last level must equal the xkey's child number;
      - root_fingerprint (hex) must equal the node's own fingerprint at
        depth 0, or the stored parent fingerprint at depth 1 (deeper xkeys
        do not record the root fingerprint, so it is not checked there);
      - at depth 0 the child number must be zero.
    Arguments left as None are not checked (previously a None
    root_fingerprint crashed in bfh() at depth 0 and 1).
    """
    bip32node = BIP32Node.from_xkey(xkey)
    int_path = None
    if derivation_prefix is not None:
        int_path = convert_bip32_path_to_list_of_uint32(derivation_prefix)
    if int_path is not None and len(int_path) != bip32node.depth:
        return False
    if bip32node.depth == 0:
        if (root_fingerprint is not None
                and bfh(root_fingerprint) != bip32node.calc_fingerprint_of_this_node()):
            return False
        if bip32node.child_number != bytes(4):
            return False
    if int_path is not None and bip32node.depth > 0:
        if int.from_bytes(bip32node.child_number, 'big') != int_path[-1]:
            return False
    if bip32node.depth == 1:
        if root_fingerprint is not None and bfh(root_fingerprint) != bip32node.fingerprint:
            return False
    return True
427