1import binascii
2import json
3import unittest
4from util import *
5
# Signing flag passed to wally_psbt_sign by the tests in this file. The value
# presumably mirrors libwally's EC_FLAG_GRIND_R (0x4) -- confirm against the
# wally_crypto.h header.
FLAG_GRIND_R = 1 << 2
7
class PSBTTests(unittest.TestCase):
    """Tests for libwally's PSBT bindings and the wally_map helpers."""

    # ------------------------------------------------------------------
    # Shared helpers (not collected by unittest: names don't start with "test")
    # ------------------------------------------------------------------

    def _psbt_from_b64(self, b64):
        """Parse base64 text into a newly allocated PSBT, asserting success.

        The caller owns the returned PSBT and should release it with
        wally_psbt_free().
        """
        psbt = pointer(wally_psbt())
        self.assertEqual(WALLY_OK, wally_psbt_from_base64(b64.encode('utf-8'), psbt))
        return psbt

    def _psbt_to_b64(self, psbt):
        """Serialize psbt to base64 text, asserting the call succeeds."""
        ret, b64 = wally_psbt_to_base64(psbt, 0)
        self.assertEqual(WALLY_OK, ret)
        return b64

    def _wif_to_privkey(self, wif):
        """Decode a WIF private key into a fresh 32-byte buffer.

        Returns (buffer, buffer_length). 0xEF is the testnet/regtest WIF
        version byte, matching the keys used by the vectors here.
        """
        buf, buf_len = make_cbuffer('00' * 32)
        self.assertEqual(WALLY_OK, wally_wif_to_bytes(wif.encode('utf-8'), 0xEF, 0, buf, buf_len))
        return buf, buf_len

    # ------------------------------------------------------------------
    # Vector-driven serialization checks
    # ------------------------------------------------------------------

    def test_serialization(self):
        """Testing serialization and deserialization"""
        with open(root_dir + 'src/data/psbt.json', 'r', encoding='utf-8') as f:
            d = json.load(f)
        # Pull every section out up front so a malformed vector file fails
        # before any of the checks run.
        invalids = d['invalid']
        valids = d['valid']
        creators = d['creator']
        signers = d['signer']
        inval_signers = d['inval_signer']
        combiners = d['combiner']
        finalizers = d['finalizer']
        extractors = d['extractor']

        self._check_invalid(invalids)
        self._check_valid(valids)
        self._check_creators(creators)
        self._check_combiners(combiners)
        self._check_signers(signers)
        self._check_invalid_signers(inval_signers)
        self._check_finalizers(finalizers)
        self._check_extractors(extractors)

    def _check_invalid(self, invalids):
        """Every invalid vector must be rejected by the base64 parser."""
        for invalid in invalids:
            self.assertEqual(WALLY_EINVAL, wally_psbt_from_base64(invalid.encode('utf-8'), pointer(wally_psbt())))

    def _check_valid(self, valids):
        """Valid vectors must round-trip exactly and report the expected length."""
        for valid in valids:
            psbt = self._psbt_from_b64(valid['psbt'])
            self.assertEqual(valid['psbt'], self._psbt_to_b64(psbt))
            ret, length = wally_psbt_get_length(psbt, 0)
            self.assertEqual(WALLY_OK, ret)
            self.assertEqual(length, valid['len'])
            self.assertEqual(WALLY_OK, wally_psbt_free(psbt))

    def _check_creators(self, creators):
        """Build each PSBT from its inputs/outputs and compare the serialization."""
        for creator in creators:
            psbt = pointer(wally_psbt())
            self.assertEqual(WALLY_OK, wally_psbt_init_alloc(0, 2, 2, 0, psbt))

            tx = pointer(wally_tx())
            self.assertEqual(WALLY_OK, wally_tx_init_alloc(2, 0, 2, 2, tx))
            for txin in creator['inputs']:
                # The JSON txid is byte-reversed before use (presumably display
                # order vs. the internal byte order the library expects).
                txid = binascii.unhexlify(txin['txid'])[::-1]
                tx_in = pointer(wally_tx_input())
                self.assertEqual(WALLY_OK, wally_tx_input_init_alloc(txid, len(txid), txin['vout'], 0xffffffff, None, 0, None, tx_in))
                self.assertEqual(WALLY_OK, wally_tx_add_input(tx, tx_in))
                # NOTE(review): wally_tx_add_input copies its (const) input per
                # the libwally API, so our allocation must be released here.
                self.assertEqual(WALLY_OK, wally_tx_input_free(tx_in))
            for txout in creator['outputs']:
                # Room for a 2-byte version/push prefix plus a 32-byte program.
                spk, spk_len = make_cbuffer('00' * (32 + 2))
                ret, written = wally_addr_segwit_to_bytes(txout['addr'].encode('utf-8'), 'bcrt'.encode('utf-8'), 0, spk, spk_len)
                self.assertEqual(WALLY_OK, ret)
                output = pointer(wally_tx_output())
                self.assertEqual(WALLY_OK, wally_tx_output_init_alloc(txout['amt'], spk, written, output))
                self.assertEqual(WALLY_OK, wally_tx_add_output(tx, output))
                # NOTE(review): wally_tx_add_output copies its (const) output.
                self.assertEqual(WALLY_OK, wally_tx_output_free(output))

            self.assertEqual(WALLY_OK, wally_psbt_set_global_tx(psbt, tx))
            # NOTE(review): wally_psbt_set_global_tx clones the tx, so the
            # original is still ours to free.
            self.assertEqual(WALLY_OK, wally_tx_free(tx))
            self.assertEqual(creator['result'], self._psbt_to_b64(psbt))
            self.assertEqual(WALLY_OK, wally_psbt_free(psbt))

    def _check_combiners(self, combiners):
        """Combine each list of PSBTs into the first and compare the result."""
        for combiner in combiners:
            psbt = self._psbt_from_b64(combiner['combine'][0])
            for src_b64 in combiner['combine'][1:]:
                src = self._psbt_from_b64(src_b64)
                self.assertEqual(WALLY_OK, wally_psbt_combine(psbt, src))
                self.assertEqual(WALLY_OK, wally_psbt_free(src))
            # _psbt_to_b64 also asserts the serialization call succeeded,
            # which the combiner check previously ignored.
            self.assertEqual(combiner['result'], self._psbt_to_b64(psbt))
            self.assertEqual(WALLY_OK, wally_psbt_free(psbt))

    def _check_signers(self, signers):
        """Sign with every listed key and compare the signed serialization."""
        for signer in signers:
            psbt = self._psbt_from_b64(signer['psbt'])
            for priv in signer['privkeys']:
                buf, buf_len = self._wif_to_privkey(priv)
                self.assertEqual(WALLY_OK, wally_psbt_sign(psbt, buf, buf_len, FLAG_GRIND_R))

            reser = self._psbt_to_b64(psbt)
            # Check that we can *demarshal* the signed PSBT (some bugs only
            # appear here). Parse into a separate PSBT (and encode the text,
            # as every other parse does) so the signed PSBT is not overwritten
            # and leaked.
            reparsed = self._psbt_from_b64(reser)
            self.assertEqual(signer['result'], reser)
            self.assertEqual(WALLY_OK, wally_psbt_free(reparsed))
            self.assertEqual(WALLY_OK, wally_psbt_free(psbt))

    def _check_invalid_signers(self, inval_signers):
        """Signing these PSBTs with the listed keys must be rejected."""
        for inval_signer in inval_signers:
            psbt = self._psbt_from_b64(inval_signer['psbt'])
            for priv in inval_signer['privkeys']:
                buf, buf_len = self._wif_to_privkey(priv)
                self.assertEqual(WALLY_EINVAL, wally_psbt_sign(psbt, buf, buf_len, FLAG_GRIND_R))
            self.assertEqual(WALLY_OK, wally_psbt_free(psbt))

    def _check_finalizers(self, finalizers):
        """Finalize each PSBT, confirm it reports finalized, compare output."""
        for finalizer in finalizers:
            psbt = self._psbt_from_b64(finalizer['finalize'])
            self.assertEqual(WALLY_OK, wally_psbt_finalize(psbt))
            ret, is_finalized = wally_psbt_is_finalized(psbt)
            self.assertEqual(WALLY_OK, ret)
            self.assertEqual(1, is_finalized)
            self.assertEqual(finalizer['result'], self._psbt_to_b64(psbt))
            self.assertEqual(WALLY_OK, wally_psbt_free(psbt))

    def _check_extractors(self, extractors):
        """Extract the network transaction from each PSBT and compare its hex."""
        for extractor in extractors:
            psbt = self._psbt_from_b64(extractor['extract'])
            tx = pointer(wally_tx())
            self.assertEqual(WALLY_OK, wally_psbt_extract(psbt, tx))
            # Flag 1 is WALLY_TX_FLAG_USE_WITNESS: serialize witness data too.
            ret, reser = wally_tx_to_hex(tx, 1)
            self.assertEqual(WALLY_OK, ret)
            self.assertEqual(extractor['result'], reser)
            self.assertEqual(WALLY_OK, wally_tx_free(tx))
            self.assertEqual(WALLY_OK, wally_psbt_free(psbt))

    # ------------------------------------------------------------------
    # wally_map helpers
    # ------------------------------------------------------------------

    def test_map(self):
        """Test PSBT map helper functions"""
        m = pointer(wally_map())
        # Test keys. Once sorted we expect order k3, k2, k1
        key1, key1_len = make_cbuffer('505050')
        key2, key2_len = make_cbuffer('40404040')
        key3, key3_len = make_cbuffer('404040')
        val, val_len = make_cbuffer('ffffffff')

        # Check invalid args
        self.assertEqual(wally_map_init_alloc(0, None), WALLY_EINVAL)
        self.assertEqual(wally_map_init_alloc(0, m), WALLY_OK)

        # TODO: wally_map_add_keypath_item
        for args in [(None, key1, key1_len, val,  val_len), # Null map
                     (m,    None, key1_len, val,  val_len), # Null key
                     (m,    key1, 0,        val,  val_len), # 0 length key
                     (m,    key1, key1_len, None, val_len), # Null value
                     (m,    key1, key1_len, val,  0)]:      # 0 length value
            self.assertEqual(wally_map_add(*args), WALLY_EINVAL)

        for args in [(None, key1, key1_len), # Null map
                     (m,    None, key1_len), # Null key
                     (m,    key1, 0)]:       # 0 length key
            self.assertEqual(wally_map_find(*args), (WALLY_EINVAL, 0))

        self.assertEqual(wally_map_sort(None, 0), WALLY_EINVAL) # Null map
        self.assertEqual(wally_map_sort(m, 1),    WALLY_EINVAL) # Invalid flags

        self.assertEqual(wally_map_free(None), WALLY_OK) # Null is OK

        # Add and find each key; wally_map_find reports 1-based positions,
        # so insertion order gives positions 1, 2, 3.
        for k, l, i in [(key1, key1_len, 1),
                        (key2, key2_len, 2),
                        (key3, key3_len, 3)]:
            self.assertEqual(wally_map_add(m, k, l, val, val_len), WALLY_OK)
            self.assertEqual(wally_map_find(m, k, l), (WALLY_OK, i))

        # Sort
        self.assertEqual(wally_map_sort(m, 0), WALLY_OK)

        # Verify sort order (k3, k2, k1 as noted above)
        for k, l, i in [(key1, key1_len, 3),
                        (key2, key2_len, 2),
                        (key3, key3_len, 1)]:
            self.assertEqual(wally_map_find(m, k, l), (WALLY_OK, i))

        self.assertEqual(wally_map_free(m), WALLY_OK)

    # ------------------------------------------------------------------
    # Regression tests
    # ------------------------------------------------------------------

    def test_v20dot1_changes(self):
        """See https://github.com/ElementsProject/libwally-core/issues/213
           Verify that core v20.1 changes to address the segwit fee attack now work"""
        b64 = "cHNidP8BAJoCAAAAAvezqpNxOIDkwNFhfZVLYvuhQxqmqNPJwlyXbhc8cuLPAQAAAAD9////krlOMdd9VVzPWn5+oadTb4C3NnUFWA3tF6cb1RiI4JAAAAAAAP3///8CESYAAAAAAAAWABQn/PFABd2EW5RsCUvJitAYNshf9BAnAAAAAAAAFgAUFpodxCngMIyYnbJ1mhpDwQykN4cAAAAAAAEAiQIAAAABfRJscM0GWu793LYoAX15Mnj+dVr0G7yvRMBeWSmvPpQAAAAAFxYAFESkW2FnrJlkwmQZjTXL1IVM95lW/f///wK76QAAAAAAABYAFB33sq8WtoOlpvUpCvoWbxJJl5rhECcAAAAAAAAXqRTFhAlcZBMRkG4iAustDT6iSw6wkIcAAAAAAQEgECcAAAAAAAAXqRTFhAlcZBMRkG4iAustDT6iSw6wkIcBBBYAFIsieXd6AAeP8TXHKZ329Z0nuSeZIgYD/ajyzV90ghQ+0zIO2mVSd3fGYhvwYjakGCY4WNYxoeYEiyJ5dwABAHICAAAAAfezqpNxOIDkwNFhfZVLYvuhQxqmqNPJwlyXbhc8cuLPAAAAAAD9////AhAnAAAAAAAAF6kUXJfUn/nNbND+a+QhqHnyCSy9oPmHHcIAAAAAAAAWABSUD3a8pIYaaLvKdZxoEPFfo8vlDwAAAAABASAQJwAAAAAAABepFFyX1J/5zWzQ/mvkIah58gksvaD5hwEEFgAUyRIBhZwlI4RLT6NDHluovlrN3iAiBgIs+YA2N8B5O6nF4SgVEG765xfHZFKrLiKbjZuo8/9vPATJEgGFACICAq8h+ABETC5Tczuts3xhCtXAzIEUHM5iMugvwFMrtCc4EBK06cYAAACAAQAAgMMAAIAAAA=="
        psbt = pointer(wally_psbt())
        self.assertEqual(wally_psbt_from_base64(b64.encode('utf-8'), psbt), WALLY_OK)
        for priv in ['cTatuMdjH4YA4F1pAm11QdbCt88T8t2TTMoAvVGzAxWAWmQZtkBZ',
                     'cR5yyo2g1SzzwCw2QAREzF7XhYuXZS9SzTTf8A9qerri9EXZcRYS']:
            buf, buf_len = self._wif_to_privkey(priv)
            self.assertEqual(wally_psbt_sign(psbt, buf, buf_len, FLAG_GRIND_R), WALLY_OK)
        self.assertEqual(wally_psbt_finalize(psbt), WALLY_OK)
        ret, new64 = wally_psbt_to_base64(psbt, 0)
        self.assertEqual(ret, WALLY_OK)
        expected_b64 = "cHNidP8BAJoCAAAAAvezqpNxOIDkwNFhfZVLYvuhQxqmqNPJwlyXbhc8cuLPAQAAAAD9////krlOMdd9VVzPWn5+oadTb4C3NnUFWA3tF6cb1RiI4JAAAAAAAP3///8CESYAAAAAAAAWABQn/PFABd2EW5RsCUvJitAYNshf9BAnAAAAAAAAFgAUFpodxCngMIyYnbJ1mhpDwQykN4cAAAAAAAEAiQIAAAABfRJscM0GWu793LYoAX15Mnj+dVr0G7yvRMBeWSmvPpQAAAAAFxYAFESkW2FnrJlkwmQZjTXL1IVM95lW/f///wK76QAAAAAAABYAFB33sq8WtoOlpvUpCvoWbxJJl5rhECcAAAAAAAAXqRTFhAlcZBMRkG4iAustDT6iSw6wkIcAAAAAAQEgECcAAAAAAAAXqRTFhAlcZBMRkG4iAustDT6iSw6wkIcBBxcWABSLInl3egAHj/E1xymd9vWdJ7knmQEIawJHMEQCIAkPXe9sdpRjSDTjJ0gIrpwGGIWJby9xSd1rS9hPe1f0AiAJgqR7PL3G/MXyUu4KZdS1Z2O14fjxstF43k634u+4GAEhA/2o8s1fdIIUPtMyDtplUnd3xmIb8GI2pBgmOFjWMaHmAAEAcgIAAAAB97Oqk3E4gOTA0WF9lUti+6FDGqao08nCXJduFzxy4s8AAAAAAP3///8CECcAAAAAAAAXqRRcl9Sf+c1s0P5r5CGoefIJLL2g+YcdwgAAAAAAABYAFJQPdrykhhpou8p1nGgQ8V+jy+UPAAAAAAEBIBAnAAAAAAAAF6kUXJfUn/nNbND+a+QhqHnyCSy9oPmHAQcXFgAUyRIBhZwlI4RLT6NDHluovlrN3iABCGsCRzBEAiAOzRsNZ+2Et+VGCY/nXWO7WxGI3u39kpi025cUaJXQJgIgL6KtMqPfAwXGktQFWr9SNnOrHF2xjvKQI2VdeuQbxt0BIQIs+YA2N8B5O6nF4SgVEG765xfHZFKrLiKbjZuo8/9vPAAiAgKvIfgAREwuU3M7rbN8YQrVwMyBFBzOYjLoL8BTK7QnOBAStOnGAAAAgAEAAIDDAACAAAA="
        # Compare the text directly: same pass/fail as comparing encoded
        # bytes, but a failure now shows a readable string diff.
        self.assertEqual(new64, expected_b64)
        self.assertEqual(wally_psbt_free(psbt), WALLY_OK)
182
if __name__ == "__main__":
    # Run the test cases when this file is executed directly as a script.
    unittest.main()
185