import binascii
import json
import unittest
from util import *

# Signing flag passed to wally_psbt_sign by these tests (named for R-grinding;
# the value is defined locally here rather than taken from util).
FLAG_GRIND_R = 0x4


class PSBTTests(unittest.TestCase):

    def _parse(self, b64):
        """Parse base64 PSBT text, assert success, and return the new handle."""
        psbt = pointer(wally_psbt())
        self.assertEqual(WALLY_OK, wally_psbt_from_base64(b64.encode('utf-8'), psbt))
        return psbt

    def _privkey_from_wif(self, wif):
        """Decode a WIF string (network prefix 0xEF, as used by the vectors) into a 32-byte buffer.

        Returns the (buffer, length) pair produced by make_cbuffer.
        """
        buf, buf_len = make_cbuffer('00' * 32)
        self.assertEqual(WALLY_OK, wally_wif_to_bytes(wif.encode('utf-8'), 0xEF, 0, buf, buf_len))
        return buf, buf_len

    def _check_invalid(self, invalids):
        """Every 'invalid' vector must be rejected by the base64 parser with WALLY_EINVAL."""
        for invalid in invalids:
            self.assertEqual(WALLY_EINVAL, wally_psbt_from_base64(invalid.encode('utf-8'), pointer(wally_psbt())))

    def _check_valid(self, valids):
        """'valid' vectors must round-trip through base64 and report the expected length."""
        for valid in valids:
            psbt = self._parse(valid['psbt'])
            ret, reser = wally_psbt_to_base64(psbt, 0)
            self.assertEqual(WALLY_OK, ret)
            self.assertEqual(valid['psbt'], reser)
            ret, length = wally_psbt_get_length(psbt, 0)
            self.assertEqual(WALLY_OK, ret)
            self.assertEqual(length, valid['len'])
            self.assertEqual(WALLY_OK, wally_psbt_free(psbt))

    def _check_creators(self, creators):
        """Build a PSBT from each 'creator' vector's inputs/outputs and compare its base64 form."""
        for creator in creators:
            psbt = pointer(wally_psbt())
            self.assertEqual(WALLY_OK, wally_psbt_init_alloc(0, 2, 2, 0, psbt))

            tx = pointer(wally_tx())
            self.assertEqual(WALLY_OK, wally_tx_init_alloc(2, 0, 2, 2, tx))
            for txin in creator['inputs']:
                tx_in = pointer(wally_tx_input())
                # The hex txid is byte-reversed before being stored in the input
                txid = binascii.unhexlify(txin['txid'])[::-1]
                self.assertEqual(WALLY_OK, wally_tx_input_init_alloc(txid, len(txid), txin['vout'], 0xffffffff, None, 0, None, tx_in))
                self.assertEqual(WALLY_OK, wally_tx_add_input(tx, tx_in))
            for txout in creator['outputs']:
                # Buffer presumably sized for a 32-byte witness program plus
                # a 2-byte script header — confirm against the address types used
                spk, spk_len = make_cbuffer('00' * (32 + 2))
                ret, written = wally_addr_segwit_to_bytes(txout['addr'].encode('utf-8'), 'bcrt'.encode('utf-8'), 0, spk, spk_len)
                self.assertEqual(WALLY_OK, ret)
                output = pointer(wally_tx_output())
                self.assertEqual(WALLY_OK, wally_tx_output_init_alloc(txout['amt'], spk, written, output))
                self.assertEqual(WALLY_OK, wally_tx_add_output(tx, output))

            self.assertEqual(WALLY_OK, wally_psbt_set_global_tx(psbt, tx))
            ret, ser = wally_psbt_to_base64(psbt, 0)
            self.assertEqual(WALLY_OK, ret)
            self.assertEqual(creator['result'], ser)
            self.assertEqual(WALLY_OK, wally_psbt_free(psbt))

    def _check_combiners(self, combiners):
        """Combine each vector's PSBTs left-to-right and compare the merged base64 form."""
        for combiner in combiners:
            psbt = self._parse(combiner['combine'][0])
            for src_b64 in combiner['combine'][1:]:
                src = self._parse(src_b64)
                self.assertEqual(WALLY_OK, wally_psbt_combine(psbt, src))
                self.assertEqual(WALLY_OK, wally_psbt_free(src))
            ret, psbt_b64 = wally_psbt_to_base64(psbt, 0)
            self.assertEqual(WALLY_OK, ret)
            self.assertEqual(combiner['result'], psbt_b64)
            self.assertEqual(WALLY_OK, wally_psbt_free(psbt))

    def _check_signers(self, signers):
        """Sign each 'signer' vector with its keys and compare the signed base64 form."""
        for signer in signers:
            psbt = self._parse(signer['psbt'])
            for priv in signer['privkeys']:
                buf, buf_len = self._privkey_from_wif(priv)
                self.assertEqual(WALLY_OK, wally_psbt_sign(psbt, buf, buf_len, FLAG_GRIND_R))

            ret, reser = wally_psbt_to_base64(psbt, 0)
            self.assertEqual(WALLY_OK, ret)
            # Check that we can *demarshal* the signed PSBT (some bugs only appear here).
            # Parse into a fresh handle so the signed PSBT is not overwritten/leaked.
            reparsed = self._parse(reser)
            self.assertEqual(WALLY_OK, wally_psbt_free(reparsed))
            self.assertEqual(signer['result'], reser)
            self.assertEqual(WALLY_OK, wally_psbt_free(psbt))

    def _check_inval_signers(self, inval_signers):
        """Signing an 'inval_signer' vector with any of its keys must fail with WALLY_EINVAL."""
        for inval_signer in inval_signers:
            psbt = self._parse(inval_signer['psbt'])
            for priv in inval_signer['privkeys']:
                buf, buf_len = self._privkey_from_wif(priv)
                self.assertEqual(WALLY_EINVAL, wally_psbt_sign(psbt, buf, buf_len, FLAG_GRIND_R))
            self.assertEqual(WALLY_OK, wally_psbt_free(psbt))

    def _check_finalizers(self, finalizers):
        """Finalize each vector, confirm it reports finalized, and compare the base64 form."""
        for finalizer in finalizers:
            psbt = self._parse(finalizer['finalize'])
            self.assertEqual(WALLY_OK, wally_psbt_finalize(psbt))
            ret, is_finalized = wally_psbt_is_finalized(psbt)
            self.assertEqual(WALLY_OK, ret)
            self.assertEqual(1, is_finalized)
            ret, reser = wally_psbt_to_base64(psbt, 0)
            self.assertEqual(WALLY_OK, ret)
            self.assertEqual(finalizer['result'], reser)
            self.assertEqual(WALLY_OK, wally_psbt_free(psbt))

    def _check_extractors(self, extractors):
        """Extract the transaction from each vector and compare its hex serialization."""
        for extractor in extractors:
            psbt = self._parse(extractor['extract'])
            tx = pointer(wally_tx())
            self.assertEqual(WALLY_OK, wally_psbt_extract(psbt, tx))
            # Flag 1 presumably selects witness serialization — confirm in wally headers
            ret, reser = wally_tx_to_hex(tx, 1)
            self.assertEqual(WALLY_OK, ret)
            self.assertEqual(extractor['result'], reser)
            self.assertEqual(WALLY_OK, wally_psbt_free(psbt))

    def test_serialization(self):
        """Testing serialization and deserialization"""
        with open(root_dir + 'src/data/psbt.json', 'r', encoding='utf-8') as f:
            d = json.load(f)

        self._check_invalid(d['invalid'])
        self._check_valid(d['valid'])
        self._check_creators(d['creator'])
        self._check_combiners(d['combiner'])
        self._check_signers(d['signer'])
        self._check_inval_signers(d['inval_signer'])
        self._check_finalizers(d['finalizer'])
        self._check_extractors(d['extractor'])

    def test_map(self):
        """Test PSBT map helper functions"""
        m = pointer(wally_map())
        # Test keys. Once sorted we expect order k3, k2, k1
        key1, key1_len = make_cbuffer('505050')
        key2, key2_len = make_cbuffer('40404040')
        key3, key3_len = make_cbuffer('404040')
        val, val_len = make_cbuffer('ffffffff')

        # Check invalid args
        self.assertEqual(wally_map_init_alloc(0, None), WALLY_EINVAL)
        self.assertEqual(wally_map_init_alloc(0, m), WALLY_OK)

        for args in [(None, key1, key1_len, val, val_len),  # Null map
                     (m, None, key1_len, val, val_len),     # Null key
                     (m, key1, 0, val, val_len),            # 0 length key
                     (m, key1, key1_len, None, val_len),    # Null value
                     (m, key1, key1_len, val, 0)]:          # 0 length value
            self.assertEqual(wally_map_add(*args), WALLY_EINVAL)
        # TODO: wally_map_add_keypath_item

        for args in [(None, key1, key1_len),  # Null map
                     (m, None, key1_len),     # Null key
                     (m, key1, 0)]:           # 0 length key
            self.assertEqual(wally_map_find(*args), (WALLY_EINVAL, 0))

        self.assertEqual(wally_map_sort(None, 0), WALLY_EINVAL)  # Null map
        self.assertEqual(wally_map_sort(m, 1), WALLY_EINVAL)     # Invalid flags

        self.assertEqual(wally_map_free(None), WALLY_OK)  # Null is OK

        # Add and find each key; find() reports a 1-based position
        for k, l, i in [(key1, key1_len, 1),
                        (key2, key2_len, 2),
                        (key3, key3_len, 3)]:
            self.assertEqual(wally_map_add(m, k, l, val, val_len), WALLY_OK)
            self.assertEqual(wally_map_find(m, k, l), (WALLY_OK, i))

        # Sort
        self.assertEqual(wally_map_sort(m, 0), WALLY_OK)

        # Verify sort order
        for k, l, i in [(key1, key1_len, 3),
                        (key2, key2_len, 2),
                        (key3, key3_len, 1)]:
            self.assertEqual(wally_map_find(m, k, l), (WALLY_OK, i))

        self.assertEqual(wally_map_free(m), WALLY_OK)

    def test_v20dot1_changes(self):
        """See https://github.com/ElementsProject/libwally-core/issues/213
        Verify that core v20.1 changes to address the segwit fee attack now work"""
        b64 = "cHNidP8BAJoCAAAAAvezqpNxOIDkwNFhfZVLYvuhQxqmqNPJwlyXbhc8cuLPAQAAAAD9////krlOMdd9VVzPWn5+oadTb4C3NnUFWA3tF6cb1RiI4JAAAAAAAP3///8CESYAAAAAAAAWABQn/PFABd2EW5RsCUvJitAYNshf9BAnAAAAAAAAFgAUFpodxCngMIyYnbJ1mhpDwQykN4cAAAAAAAEAiQIAAAABfRJscM0GWu793LYoAX15Mnj+dVr0G7yvRMBeWSmvPpQAAAAAFxYAFESkW2FnrJlkwmQZjTXL1IVM95lW/f///wK76QAAAAAAABYAFB33sq8WtoOlpvUpCvoWbxJJl5rhECcAAAAAAAAXqRTFhAlcZBMRkG4iAustDT6iSw6wkIcAAAAAAQEgECcAAAAAAAAXqRTFhAlcZBMRkG4iAustDT6iSw6wkIcBBBYAFIsieXd6AAeP8TXHKZ329Z0nuSeZIgYD/ajyzV90ghQ+0zIO2mVSd3fGYhvwYjakGCY4WNYxoeYEiyJ5dwABAHICAAAAAfezqpNxOIDkwNFhfZVLYvuhQxqmqNPJwlyXbhc8cuLPAAAAAAD9////AhAnAAAAAAAAF6kUXJfUn/nNbND+a+QhqHnyCSy9oPmHHcIAAAAAAAAWABSUD3a8pIYaaLvKdZxoEPFfo8vlDwAAAAABASAQJwAAAAAAABepFFyX1J/5zWzQ/mvkIah58gksvaD5hwEEFgAUyRIBhZwlI4RLT6NDHluovlrN3iAiBgIs+YA2N8B5O6nF4SgVEG765xfHZFKrLiKbjZuo8/9vPATJEgGFACICAq8h+ABETC5Tczuts3xhCtXAzIEUHM5iMugvwFMrtCc4EBK06cYAAACAAQAAgMMAAIAAAA=="
        psbt = self._parse(b64)
        for priv in ['cTatuMdjH4YA4F1pAm11QdbCt88T8t2TTMoAvVGzAxWAWmQZtkBZ',
                     'cR5yyo2g1SzzwCw2QAREzF7XhYuXZS9SzTTf8A9qerri9EXZcRYS']:
            buf, buf_len = self._privkey_from_wif(priv)
            self.assertEqual(wally_psbt_sign(psbt, buf, buf_len, FLAG_GRIND_R), WALLY_OK)
        self.assertEqual(wally_psbt_finalize(psbt), WALLY_OK)
        ret, new64 = wally_psbt_to_base64(psbt, 0)
        self.assertEqual(ret, WALLY_OK)
        expected_b64 = "cHNidP8BAJoCAAAAAvezqpNxOIDkwNFhfZVLYvuhQxqmqNPJwlyXbhc8cuLPAQAAAAD9////krlOMdd9VVzPWn5+oadTb4C3NnUFWA3tF6cb1RiI4JAAAAAAAP3///8CESYAAAAAAAAWABQn/PFABd2EW5RsCUvJitAYNshf9BAnAAAAAAAAFgAUFpodxCngMIyYnbJ1mhpDwQykN4cAAAAAAAEAiQIAAAABfRJscM0GWu793LYoAX15Mnj+dVr0G7yvRMBeWSmvPpQAAAAAFxYAFESkW2FnrJlkwmQZjTXL1IVM95lW/f///wK76QAAAAAAABYAFB33sq8WtoOlpvUpCvoWbxJJl5rhECcAAAAAAAAXqRTFhAlcZBMRkG4iAustDT6iSw6wkIcAAAAAAQEgECcAAAAAAAAXqRTFhAlcZBMRkG4iAustDT6iSw6wkIcBBxcWABSLInl3egAHj/E1xymd9vWdJ7knmQEIawJHMEQCIAkPXe9sdpRjSDTjJ0gIrpwGGIWJby9xSd1rS9hPe1f0AiAJgqR7PL3G/MXyUu4KZdS1Z2O14fjxstF43k634u+4GAEhA/2o8s1fdIIUPtMyDtplUnd3xmIb8GI2pBgmOFjWMaHmAAEAcgIAAAAB97Oqk3E4gOTA0WF9lUti+6FDGqao08nCXJduFzxy4s8AAAAAAP3///8CECcAAAAAAAAXqRRcl9Sf+c1s0P5r5CGoefIJLL2g+YcdwgAAAAAAABYAFJQPdrykhhpou8p1nGgQ8V+jy+UPAAAAAAEBIBAnAAAAAAAAF6kUXJfUn/nNbND+a+QhqHnyCSy9oPmHAQcXFgAUyRIBhZwlI4RLT6NDHluovlrN3iABCGsCRzBEAiAOzRsNZ+2Et+VGCY/nXWO7WxGI3u39kpi025cUaJXQJgIgL6KtMqPfAwXGktQFWr9SNnOrHF2xjvKQI2VdeuQbxt0BIQIs+YA2N8B5O6nF4SgVEG765xfHZFKrLiKbjZuo8/9vPAAiAgKvIfgAREwuU3M7rbN8YQrVwMyBFBzOYjLoL8BTK7QnOBAStOnGAAAAgAEAAIDDAACAAAA="
        self.assertEqual(new64, expected_b64)
        self.assertEqual(wally_psbt_free(psbt), WALLY_OK)


if __name__ == '__main__':
    unittest.main()