import hashlib
import binascii

# Python >= 3.6 provides the sha3_* constructors directly in hashlib. Only on
# interpreters that lack them do we fall back to the third-party ``sha3``
# module (pysha3), which is imported for its side effect of adding those
# constructors to hashlib.
if not hasattr(hashlib, "sha3_256"):
    try:
        import sha3  # noqa: F401  (imported only for its side effect)
    except ImportError:
        from warnings import warn
        warn("sha3 is not working!")


class MerkleTools(object):
    """Build a Merkle tree over a list of leaves and create/verify inclusion proofs.

    Leaves are stored internally as raw bytes. Methods that expose a hash
    (``get_leaf``, ``get_merkle_root``, ``get_proof``) return it as a hex
    string. When a level has an odd number of nodes, its last node is
    promoted unchanged to the next level rather than being paired.
    """

    _SUPPORTED_HASH_TYPES = (
        'sha256', 'md5', 'sha224', 'sha384', 'sha512',
        'sha3_256', 'sha3_224', 'sha3_384', 'sha3_512',
    )

    def __init__(self, hash_type="sha256"):
        """Create an empty tree that hashes with ``hashlib.<hash_type>``.

        Args:
            hash_type: case-insensitive name of one of ``_SUPPORTED_HASH_TYPES``.

        Raises:
            ValueError: if ``hash_type`` is not one of the supported names.
        """
        hash_type = hash_type.lower()
        if hash_type not in self._SUPPORTED_HASH_TYPES:
            raise ValueError('`hash_type` {} not supported'.format(hash_type))
        self.hash_function = getattr(hashlib, hash_type)
        self.reset_tree()

    def _to_hex(self, x):
        """Return the hex encoding of the byte string ``x``."""
        try:
            return x.hex()  # Python 3: bytes/bytearray provide .hex()
        except AttributeError:
            return binascii.hexlify(x)  # Python 2 fallback

    def reset_tree(self):
        """Discard all leaves and any built levels."""
        self.leaves = list()
        self.levels = None
        self.is_ready = False

    def add_leaf(self, values, do_hash=False):
        """Append one leaf, or a list/tuple of leaves, and mark the tree as not ready.

        Args:
            values: a single value or a list/tuple of values.
            do_hash: if False, each value must already be a hex string and is
                stored as its decoded bytes. If True, each value is hashed
                first: ``str`` values are UTF-8 encoded before hashing, and
                ``bytes``/``bytearray`` values are hashed as-is.
        """
        self.is_ready = False
        if not isinstance(values, (tuple, list)):
            values = [values]
        for v in values:
            if do_hash:
                data = v if isinstance(v, (bytes, bytearray)) else v.encode('utf-8')
                v = self.hash_function(data).hexdigest()
            self.leaves.append(bytearray.fromhex(v))

    def get_leaf(self, index):
        """Return the leaf at ``index`` as a hex string."""
        return self._to_hex(self.leaves[index])

    def get_leaf_count(self):
        """Return the number of leaves added so far."""
        return len(self.leaves)

    def get_tree_ready_state(self):
        """Return True once ``make_tree`` has run since the last modification."""
        return self.is_ready

    def _calculate_next_level(self):
        """Hash adjacent pairs of ``levels[0]`` and prepend the result as the new top level.

        If the current top level has an odd length, its last node is carried up
        unchanged.
        """
        current = self.levels[0]
        pair_count_end = len(current) - (len(current) % 2)  # even prefix length
        new_level = [
            self.hash_function(left + right).digest()
            for left, right in zip(current[0:pair_count_end:2],
                                   current[1:pair_count_end:2])
        ]
        if len(current) % 2 == 1:
            new_level.append(current[-1])  # odd node is promoted as-is
        self.levels = [new_level] + self.levels

    def make_tree(self):
        """Build all levels from the current leaves and mark the tree ready.

        After this call, ``levels[0]`` is the single-element root level and
        ``levels[-1]`` is a copy of the leaves. With no leaves, ``levels`` is
        left unchanged.
        """
        self.is_ready = False
        if self.get_leaf_count() > 0:
            # Copy so that later add_leaf() calls do not mutate a built tree.
            self.levels = [list(self.leaves)]
            while len(self.levels[0]) > 1:
                self._calculate_next_level()
        self.is_ready = True

    def get_merkle_root(self):
        """Return the root hash as a hex string, or None if no tree is ready."""
        if self.is_ready and self.levels is not None:
            return self._to_hex(self.levels[0][0])
        return None

    def get_proof(self, index):
        """Return the inclusion proof for the leaf at ``index``.

        Returns:
            A list of single-entry dicts, ordered from the leaf level upward.
            Each dict maps ``'left'`` or ``'right'`` (the sibling's side) to
            the sibling's hex hash. Levels where the node was an odd, promoted
            last node contribute no entry. Returns None if the tree has not
            been built, is not ready, or ``index`` is out of range.
        """
        if self.levels is None:
            return None
        if not self.is_ready or not 0 <= index < len(self.leaves):
            return None
        proof = []
        # Walk from the leaf level (last in self.levels) up to, but excluding, the root.
        for level_idx in range(len(self.levels) - 1, 0, -1):
            level = self.levels[level_idx]
            level_len = len(level)
            if index == level_len - 1 and level_len % 2 == 1:
                # Odd end node: it has no sibling and was promoted unchanged.
                index //= 2
                continue
            is_right_node = index % 2
            sibling_index = index - 1 if is_right_node else index + 1
            sibling_pos = "left" if is_right_node else "right"
            proof.append({sibling_pos: self._to_hex(level[sibling_index])})
            index //= 2
        return proof

    def validate_proof(self, proof, target_hash, merkle_root):
        """Check that ``proof`` links ``target_hash`` to ``merkle_root``.

        Args:
            proof: list of single-entry dicts in the format produced by
                ``get_proof``.
            target_hash: hex string of the leaf being proven.
            merkle_root: hex string of the expected root.

        Returns:
            True if folding the proof's siblings onto ``target_hash`` yields
            ``merkle_root``. An empty proof succeeds only when the two are equal.

        Raises:
            KeyError: if a proof step has neither a ``'left'`` nor a ``'right'`` key.
            ValueError: if a hash is not valid hex.
        """
        merkle_root = bytearray.fromhex(merkle_root)
        proof_hash = bytearray.fromhex(target_hash)
        for step in proof:
            if 'left' in step:
                sibling = bytearray.fromhex(step['left'])
                proof_hash = self.hash_function(sibling + proof_hash).digest()
            else:
                sibling = bytearray.fromhex(step['right'])
                proof_hash = self.hash_function(proof_hash + sibling).digest()
        return proof_hash == merkle_root