1import hashlib
2import binascii
3try:
4    import sha3
5except:
6    from warnings import warn
7    warn("sha3 is not working!")
8
9
class MerkleTools(object):
    """Build a Merkle tree over hex-encoded leaves and produce/verify proofs.

    Leaves are stored as raw bytes (``bytearray``).  Each parent node is
    ``hash_function(left + right).digest()``.  When a level has an odd
    number of nodes, the last node is promoted to the next level unchanged
    (it is not hashed with itself).
    """

    # Names accepted for ``hash_type``; each is looked up on ``hashlib``.
    _SUPPORTED_HASH_TYPES = ('sha256', 'md5', 'sha224', 'sha384', 'sha512',
                             'sha3_256', 'sha3_224', 'sha3_384', 'sha3_512')

    def __init__(self, hash_type="sha256"):
        """Create an empty tree using the ``hashlib`` algorithm ``hash_type``.

        Raises:
            ValueError: if ``hash_type`` (case-insensitive) is not one of
                ``_SUPPORTED_HASH_TYPES``.
        """
        hash_type = hash_type.lower()
        if hash_type not in self._SUPPORTED_HASH_TYPES:
            raise ValueError('`hash_type` {} not supported'.format(hash_type))
        self.hash_function = getattr(hashlib, hash_type)

        self.reset_tree()

    def _to_hex(self, x):
        """Return the hex encoding of the bytes-like ``x``."""
        try:  # python3: bytes/bytearray have .hex()
            return x.hex()
        except AttributeError:  # python2 fallback
            return binascii.hexlify(x)

    def reset_tree(self):
        """Drop all leaves and any built tree."""
        self.leaves = list()
        self.levels = None
        self.is_ready = False

    def add_leaf(self, values, do_hash=False):
        """Append one leaf, or each item of a list/tuple, to the tree.

        Args:
            values: a single value or a list/tuple of values.  Without
                ``do_hash`` each value must be a hex string.
            do_hash: if True, hash each value first.  Text is UTF-8 encoded
                before hashing; bytes-like values are hashed as-is.

        Invalidates any previously built tree (``is_ready`` becomes False).
        """
        self.is_ready = False
        # a lone value is treated as a one-element batch
        if not isinstance(values, (tuple, list)):
            values = [values]
        for v in values:
            if do_hash:
                if not isinstance(v, (bytes, bytearray)):
                    v = v.encode('utf-8')
                v = self.hash_function(v).hexdigest()
            self.leaves.append(bytearray.fromhex(v))

    def get_leaf(self, index):
        """Return the leaf at ``index`` as a hex string."""
        return self._to_hex(self.leaves[index])

    def get_leaf_count(self):
        """Return the number of leaves added so far."""
        return len(self.leaves)

    def get_tree_ready_state(self):
        """Return True once ``make_tree`` has run since the last change."""
        return self.is_ready

    def _calculate_next_level(self):
        """Hash the current top level pairwise and prepend the result."""
        level = self.levels[0]
        new_level = [self.hash_function(level[i] + level[i + 1]).digest()
                     for i in range(0, len(level) - 1, 2)]
        if len(level) % 2 == 1:
            # odd node out is promoted unchanged
            new_level.append(level[-1])
        self.levels = [new_level] + self.levels  # prepend new level

    def make_tree(self):
        """Build all levels from the current leaves; ``levels[0]`` is the root.

        With no leaves, ``levels`` is None and the root is None.
        """
        self.is_ready = False
        if self.leaves:
            # copy so later add_leaf() calls cannot mutate the built tree
            self.levels = [list(self.leaves)]
            while len(self.levels[0]) > 1:
                self._calculate_next_level()
        else:
            self.levels = None
        self.is_ready = True

    def get_merkle_root(self):
        """Return the root as hex, or None if the tree is not built or empty."""
        if self.is_ready and self.levels is not None:
            return self._to_hex(self.levels[0][0])
        return None

    def get_proof(self, index):
        """Return the audit path for leaf ``index``.

        The result is a list of single-key dicts ordered from leaf to root.
        Each dict is ``{"left": hex}`` or ``{"right": hex}``, giving the
        sibling hash and which side it sits on.  Returns None if the tree is
        not built or ``index`` is out of range.
        """
        if self.levels is None or not self.is_ready:
            return None
        if index < 0 or index >= len(self.leaves):
            return None
        proof = []
        # walk from the leaf level (last) up to just below the root (0)
        for x in range(len(self.levels) - 1, 0, -1):
            level = self.levels[x]
            level_len = len(level)
            if index == level_len - 1 and level_len % 2 == 1:
                # odd end node was promoted unchanged: no sibling here
                index //= 2
                continue
            if index % 2:
                proof.append({"left": self._to_hex(level[index - 1])})
            else:
                proof.append({"right": self._to_hex(level[index + 1])})
            index //= 2
        return proof

    def validate_proof(self, proof, target_hash, merkle_root):
        """Return True if ``proof`` links ``target_hash`` to ``merkle_root``.

        ``target_hash`` and ``merkle_root`` are hex strings.  ``proof`` has
        the format produced by ``get_proof``.

        Raises:
            ValueError: if any hex string is malformed.
            KeyError: if a proof step has neither "left" nor "right".
        """
        merkle_root = bytearray.fromhex(merkle_root)
        proof_hash = bytearray.fromhex(target_hash)
        for p in proof:
            if 'left' in p:
                # the sibling is a left node
                sibling = bytearray.fromhex(p['left'])
                proof_hash = self.hash_function(sibling + proof_hash).digest()
            else:
                # the sibling is a right node
                sibling = bytearray.fromhex(p['right'])
                proof_hash = self.hash_function(proof_hash + sibling).digest()
        return proof_hash == merkle_root
121