1 #include "config.h"
2 #include <bitcoin/psbt.h>
3 #include <bitcoin/script.h>
4 #include <ccan/asort/asort.h>
5 #include <ccan/ccan/endian/endian.h>
6 #include <ccan/ccan/mem/mem.h>
7 #include <common/psbt_open.h>
8 #include <common/pseudorand.h>
9 #include <common/utils.h>
10
/* Look up the PSBT_TYPE_SERIAL_ID entry in @map and decode it.
 *
 * Returns false if psbt_get_lightning() yields nothing, or if the stored
 * value is not exactly sizeof(beint64_t) bytes long.  On success the
 * value is converted from big-endian and written to @serial_id. */
bool psbt_get_serial_id(const struct wally_map *map, u64 *serial_id)
{
	size_t len;
	beint64_t be_serial;
	void *val = psbt_get_lightning(map, PSBT_TYPE_SERIAL_ID, &len);

	if (!val || len != sizeof(be_serial))
		return false;

	/* len has just been checked to equal sizeof(be_serial) */
	memcpy(&be_serial, val, sizeof(be_serial));
	*serial_id = be64_to_cpu(be_serial);
	return true;
}
26
compare_serials(const struct wally_map * map_a,const struct wally_map * map_b)27 static int compare_serials(const struct wally_map *map_a,
28 const struct wally_map *map_b)
29 {
30 u64 serial_left, serial_right;
31 bool ok;
32
33 ok = psbt_get_serial_id(map_a, &serial_left);
34 assert(ok);
35 ok = psbt_get_serial_id(map_b, &serial_right);
36 assert(ok);
37 if (serial_left > serial_right)
38 return 1;
39 if (serial_left < serial_right)
40 return -1;
41 return 0;
42 }
43
/* asort() comparator: orders two input_set entries by the serial id kept
 * in each input's unknowns map.  The third argument is ignored. */
static int compare_inputs_at(const struct input_set *a,
			     const struct input_set *b,
			     void *unused UNUSED)
{
	const struct wally_map *left = &a->input.unknowns;
	const struct wally_map *right = &b->input.unknowns;

	return compare_serials(left, right);
}
51
/* asort() comparator: orders two output_set entries by the serial id kept
 * in each output's unknowns map.  The third argument is ignored. */
static int compare_outputs_at(const struct output_set *a,
			      const struct output_set *b,
			      void *unused UNUSED)
{
	const struct wally_map *left = &a->output.unknowns;
	const struct wally_map *right = &b->output.unknowns;

	return compare_serials(left, right);
}
59
linearize_input(const tal_t * ctx,const struct wally_psbt_input * in,const struct wally_tx_input * tx_in)60 static const u8 *linearize_input(const tal_t *ctx,
61 const struct wally_psbt_input *in,
62 const struct wally_tx_input *tx_in)
63 {
64 struct wally_psbt *psbt = create_psbt(NULL, 1, 0, 0);
65 size_t byte_len;
66
67 tal_wally_start();
68 if (wally_tx_add_input(psbt->tx, tx_in) != WALLY_OK)
69 abort();
70 tal_wally_end(psbt->tx);
71
72 psbt->inputs[0] = *in;
73 psbt->num_inputs++;
74
75
76 /* Sort the inputs, so serializing them is ok */
77 wally_map_sort(&psbt->inputs[0].unknowns, 0);
78
79 /* signatures, keypaths, etc - we dont care if they change */
80 psbt->inputs[0].final_witness = NULL;
81 psbt->inputs[0].final_scriptsig_len = 0;
82 psbt->inputs[0].witness_script = NULL;
83 psbt->inputs[0].witness_script_len = 0;
84 psbt->inputs[0].redeem_script_len = 0;
85 psbt->inputs[0].keypaths.num_items = 0;
86 psbt->inputs[0].signatures.num_items = 0;
87
88 const u8 *bytes = psbt_get_bytes(ctx, psbt, &byte_len);
89
90 /* Hide the inputs we added, so it doesn't get freed */
91 psbt->num_inputs--;
92 tal_free(psbt);
93 return bytes;
94 }
95
linearize_output(const tal_t * ctx,const struct wally_psbt_output * out,const struct wally_tx_output * tx_out)96 static const u8 *linearize_output(const tal_t *ctx,
97 const struct wally_psbt_output *out,
98 const struct wally_tx_output *tx_out)
99 {
100 struct wally_psbt *psbt = create_psbt(NULL, 1, 1, 0);
101 size_t byte_len;
102 struct bitcoin_outpoint outpoint;
103
104 /* Add a 'fake' input so this will linearize the tx */
105 memset(&outpoint, 0, sizeof(outpoint));
106 psbt_append_input(psbt, &outpoint, 0, NULL, NULL, NULL);
107
108 tal_wally_start();
109 if (wally_tx_add_output(psbt->tx, tx_out) != WALLY_OK)
110 abort();
111 tal_wally_end(psbt->tx);
112
113 psbt->outputs[0] = *out;
114 psbt->num_outputs++;
115 /* Sort the outputs, so serializing them is ok */
116 wally_map_sort(&psbt->outputs[0].unknowns, 0);
117
118 /* We don't care if the keypaths change */
119 psbt->outputs[0].keypaths.num_items = 0;
120 /* And you can add scripts, no problem */
121 psbt->outputs[0].witness_script_len = 0;
122 psbt->outputs[0].redeem_script_len = 0;
123
124 const u8 *bytes = psbt_get_bytes(ctx, psbt, &byte_len);
125
126 /* Hide the outputs we added, so it doesn't get freed */
127 psbt->num_outputs--;
128 tal_free(psbt);
129 return bytes;
130 }
131
input_identical(const struct wally_psbt * a,size_t a_index,const struct wally_psbt * b,size_t b_index)132 static bool input_identical(const struct wally_psbt *a,
133 size_t a_index,
134 const struct wally_psbt *b,
135 size_t b_index)
136 {
137 const u8 *a_in = linearize_input(tmpctx,
138 &a->inputs[a_index],
139 &a->tx->inputs[a_index]);
140 const u8 *b_in = linearize_input(tmpctx,
141 &b->inputs[b_index],
142 &b->tx->inputs[b_index]);
143
144 return memeq(a_in, tal_bytelen(a_in),
145 b_in, tal_bytelen(b_in));
146 }
147
output_identical(const struct wally_psbt * a,size_t a_index,const struct wally_psbt * b,size_t b_index)148 static bool output_identical(const struct wally_psbt *a,
149 size_t a_index,
150 const struct wally_psbt *b,
151 size_t b_index)
152 {
153 const u8 *a_out = linearize_output(tmpctx,
154 &a->outputs[a_index],
155 &a->tx->outputs[a_index]);
156 const u8 *b_out = linearize_output(tmpctx,
157 &b->outputs[b_index],
158 &b->tx->outputs[b_index]);
159 return memeq(a_out, tal_bytelen(a_out),
160 b_out, tal_bytelen(b_out));
161 }
162
sort_inputs(struct wally_psbt * psbt)163 static void sort_inputs(struct wally_psbt *psbt)
164 {
165 /* Build an input map */
166 struct input_set *set = tal_arr(NULL,
167 struct input_set,
168 psbt->num_inputs);
169
170 for (size_t i = 0; i < tal_count(set); i++) {
171 set[i].tx_input = psbt->tx->inputs[i];
172 set[i].input = psbt->inputs[i];
173 }
174
175 asort(set, tal_count(set),
176 compare_inputs_at, NULL);
177
178 /* Put PSBT parts into place */
179 for (size_t i = 0; i < tal_count(set); i++) {
180 psbt->inputs[i] = set[i].input;
181 psbt->tx->inputs[i] = set[i].tx_input;
182 }
183
184 tal_free(set);
185 }
186
sort_outputs(struct wally_psbt * psbt)187 static void sort_outputs(struct wally_psbt *psbt)
188 {
189 /* Build an output map */
190 struct output_set *set = tal_arr(NULL,
191 struct output_set,
192 psbt->num_outputs);
193 for (size_t i = 0; i < tal_count(set); i++) {
194 set[i].tx_output = psbt->tx->outputs[i];
195 set[i].output = psbt->outputs[i];
196 }
197
198 asort(set, tal_count(set),
199 compare_outputs_at, NULL);
200
201 /* Put PSBT parts into place */
202 for (size_t i = 0; i < tal_count(set); i++) {
203 psbt->outputs[i] = set[i].output;
204 psbt->tx->outputs[i] = set[i].tx_output;
205 }
206
207 tal_free(set);
208 }
209
/* Sort @psbt in place: first its inputs, then its outputs, each by
 * serial id.  Every input and output must already carry a serial id
 * (the comparison helper asserts this). */
void psbt_sort_by_serial_id(struct wally_psbt *psbt)
{
	/* Inputs first, then outputs */
	sort_inputs(psbt);
	sort_outputs(psbt);
}
215
/* ADD(type, add_to, from, index): append element @index of @from to the
 * tal array @add_to.
 *
 * @type must be the bare token `input` or `output`; it selects
 * struct <type>_set, from-><type>s[] and from->tx-><type>s[].  The new
 * entry records the PSBT-level struct, the tx-level struct and the index.
 *
 * Arguments are parenthesized and @index is evaluated exactly once, so
 * expression arguments expand safely. */
#define ADD(type, add_to, from, index)					\
	do {								\
		struct type##_set add_elem;				\
		const size_t add_at = (index);				\
		add_elem.type = (from)->type##s[add_at];		\
		add_elem.tx_##type = (from)->tx->type##s[add_at];	\
		add_elem.idx = add_at;					\
		tal_arr_expand(&(add_to), add_elem);			\
	} while (0)
224
new_changeset(const tal_t * ctx)225 static struct psbt_changeset *new_changeset(const tal_t *ctx)
226 {
227 struct psbt_changeset *set = tal(ctx, struct psbt_changeset);
228
229 set->added_ins = tal_arr(set, struct input_set, 0);
230 set->rm_ins = tal_arr(set, struct input_set, 0);
231 set->added_outs = tal_arr(set, struct output_set, 0);
232 set->rm_outs = tal_arr(set, struct output_set, 0);
233
234 return set;
235 }
236
237 /* this requires having a serial_id entry on everything */
238 /* YOU MUST KEEP orig + new AROUND TO USE THE RESULTING SETS */
psbt_get_changeset(const tal_t * ctx,struct wally_psbt * orig,struct wally_psbt * new)239 struct psbt_changeset *psbt_get_changeset(const tal_t *ctx,
240 struct wally_psbt *orig,
241 struct wally_psbt *new)
242 {
243 int result;
244 size_t i = 0, j = 0;
245 struct psbt_changeset *set;
246
247 psbt_sort_by_serial_id(orig);
248 psbt_sort_by_serial_id(new);
249
250 set = new_changeset(ctx);
251
252 /* Find the input diff */
253 while (i < orig->num_inputs || j < new->num_inputs) {
254 if (i >= orig->num_inputs) {
255 ADD(input, set->added_ins, new, j);
256 j++;
257 continue;
258 }
259 if (j >= new->num_inputs) {
260 ADD(input, set->rm_ins, orig, i);
261 i++;
262 continue;
263 }
264
265 result = compare_serials(&orig->inputs[i].unknowns,
266 &new->inputs[j].unknowns);
267 if (result == -1) {
268 ADD(input, set->rm_ins, orig, i);
269 i++;
270 continue;
271 }
272 if (result == 1) {
273 ADD(input, set->added_ins, new, j);
274 j++;
275 continue;
276 }
277
278 if (!input_identical(orig, i, new, j)) {
279 ADD(input, set->rm_ins, orig, i);
280 ADD(input, set->added_ins, new, j);
281 }
282 i++;
283 j++;
284 }
285 /* Find the output diff */
286 i = 0;
287 j = 0;
288 while (i < orig->num_outputs || j < new->num_outputs) {
289 if (i >= orig->num_outputs) {
290 ADD(output, set->added_outs, new, j);
291 j++;
292 continue;
293 }
294 if (j >= new->num_outputs) {
295 ADD(output, set->rm_outs, orig, i);
296 i++;
297 continue;
298 }
299
300 result = compare_serials(&orig->outputs[i].unknowns,
301 &new->outputs[j].unknowns);
302 if (result == -1) {
303 ADD(output, set->rm_outs, orig, i);
304 i++;
305 continue;
306 }
307 if (result == 1) {
308 ADD(output, set->added_outs, new, j);
309 j++;
310 continue;
311 }
312 if (!output_identical(orig, i, new, j)) {
313 ADD(output, set->rm_outs, orig, i);
314 ADD(output, set->added_outs, new, j);
315 }
316 i++;
317 j++;
318 }
319
320 return set;
321 }
322
323
/* Store @serial_id (encoded big-endian) on @input as an unknown field
 * under the PSBT_TYPE_SERIAL_ID key.  @ctx is handed on to
 * psbt_input_set_unknown(); the key itself is built off tmpctx. */
void psbt_input_set_serial_id(const tal_t *ctx,
			      struct wally_psbt_input *input,
			      u64 serial_id)
{
	beint64_t be_serial = cpu_to_be64(serial_id);
	u8 *serial_key = psbt_make_key(tmpctx, PSBT_TYPE_SERIAL_ID, NULL);

	psbt_input_set_unknown(ctx, input, serial_key,
			       &be_serial, sizeof(be_serial));
}
333
334
/* Store @serial_id (encoded big-endian) on @output as an unknown field
 * under the PSBT_TYPE_SERIAL_ID key.  @ctx is handed on to
 * psbt_output_set_unknown(); the key itself is built off tmpctx. */
void psbt_output_set_serial_id(const tal_t *ctx,
			       struct wally_psbt_output *output,
			       u64 serial_id)
{
	beint64_t be_serial = cpu_to_be64(serial_id);
	u8 *serial_key = psbt_make_key(tmpctx, PSBT_TYPE_SERIAL_ID, NULL);

	psbt_output_set_unknown(ctx, output, serial_key,
				&be_serial, sizeof(be_serial));
}
343
/* Return the index of the first input of @psbt whose serial id equals
 * @serial_id, or -1 if there is none.  Inputs without a readable serial
 * id are skipped. */
int psbt_find_serial_input(struct wally_psbt *psbt, u64 serial_id)
{
	for (size_t idx = 0; idx < psbt->num_inputs; idx++) {
		u64 candidate;

		if (psbt_get_serial_id(&psbt->inputs[idx].unknowns,
				       &candidate)
		    && candidate == serial_id)
			return idx;
	}
	return -1;
}
355
/* Return the index of the first output of @psbt whose serial id equals
 * @serial_id, or -1 if there is none.  Outputs without a readable serial
 * id are skipped. */
int psbt_find_serial_output(struct wally_psbt *psbt, u64 serial_id)
{
	for (size_t idx = 0; idx < psbt->num_outputs; idx++) {
		u64 candidate;

		if (psbt_get_serial_id(&psbt->outputs[idx].unknowns,
				       &candidate)
		    && candidate == serial_id)
			return idx;
	}
	return -1;
}
367
get_random_serial(enum tx_role role)368 static u64 get_random_serial(enum tx_role role)
369 {
370 return pseudorand_u64() << 1 | role;
371 }
372
psbt_new_input_serial(struct wally_psbt * psbt,enum tx_role role)373 u64 psbt_new_input_serial(struct wally_psbt *psbt, enum tx_role role)
374 {
375 u64 serial_id;
376
377 while ((serial_id = get_random_serial(role)) &&
378 psbt_find_serial_input(psbt, serial_id) != -1) {
379 /* keep going; */
380 }
381
382 return serial_id;
383 }
384
psbt_new_output_serial(struct wally_psbt * psbt,enum tx_role role)385 u64 psbt_new_output_serial(struct wally_psbt *psbt, enum tx_role role)
386 {
387 u64 serial_id;
388
389 while ((serial_id = get_random_serial(role)) &&
390 psbt_find_serial_output(psbt, serial_id) != -1) {
391 /* keep going; */
392 }
393
394 return serial_id;
395 }
396
psbt_has_required_fields(struct wally_psbt * psbt)397 bool psbt_has_required_fields(struct wally_psbt *psbt)
398 {
399 u64 serial_id;
400 for (size_t i = 0; i < psbt->num_inputs; i++) {
401 struct wally_psbt_input *input = &psbt->inputs[i];
402
403 if (!psbt_get_serial_id(&input->unknowns, &serial_id))
404 return false;
405
406 /* Required because we send the full tx over the wire now */
407 if (!input->utxo)
408 return false;
409
410 /* If is P2SH, redeemscript must be present */
411 assert(psbt->tx->inputs[i].index < input->utxo->num_outputs);
412 const u8 *outscript =
413 wally_tx_output_get_script(tmpctx,
414 &input->utxo->outputs[psbt->tx->inputs[i].index]);
415 if (is_p2sh(outscript, NULL) && input->redeem_script_len == 0)
416 return false;
417
418 }
419
420 for (size_t i = 0; i < psbt->num_outputs; i++) {
421 if (!psbt_get_serial_id(&psbt->outputs[i].unknowns, &serial_id))
422 return false;
423 }
424
425 return true;
426 }
427
psbt_side_finalized(const struct wally_psbt * psbt,enum tx_role role)428 bool psbt_side_finalized(const struct wally_psbt *psbt, enum tx_role role)
429 {
430 u64 serial_id;
431 for (size_t i = 0; i < psbt->num_inputs; i++) {
432 if (!psbt_get_serial_id(&psbt->inputs[i].unknowns,
433 &serial_id)) {
434 return false;
435 }
436 if (serial_id % 2 == role) {
437 if (!psbt->inputs[i].final_witness ||
438 psbt->inputs[i].final_witness->num_items == 0)
439 return false;
440 }
441 }
442 return true;
443 }
444
445 /* Adds serials to inputs + outputs that don't have one yet */
psbt_add_serials(struct wally_psbt * psbt,enum tx_role role)446 void psbt_add_serials(struct wally_psbt *psbt, enum tx_role role)
447 {
448 u64 serial_id;
449 for (size_t i = 0; i < psbt->num_inputs; i++) {
450 /* Skip ones that already have a serial id */
451 if (psbt_get_serial_id(&psbt->inputs[i].unknowns, &serial_id))
452 continue;
453
454 serial_id = psbt_new_input_serial(psbt, role);
455 psbt_input_set_serial_id(psbt, &psbt->inputs[i], serial_id);
456 }
457 for (size_t i = 0; i < psbt->num_outputs; i++) {
458 /* Skip ones that already have a serial id */
459 if (psbt_get_serial_id(&psbt->outputs[i].unknowns, &serial_id))
460 continue;
461
462 serial_id = psbt_new_output_serial(psbt, role);
463 psbt_output_set_serial_id(psbt, &psbt->outputs[i], serial_id);
464 }
465 }
466
psbt_input_mark_ours(const tal_t * ctx,struct wally_psbt_input * input)467 void psbt_input_mark_ours(const tal_t *ctx,
468 struct wally_psbt_input *input)
469 {
470 u8 *key = psbt_make_key(tmpctx, PSBT_TYPE_INPUT_MARKER, NULL);
471 beint16_t bev = cpu_to_be16(1);
472
473 psbt_input_set_unknown(ctx, input, key, &bev, sizeof(bev));
474 }
475
psbt_input_is_ours(const struct wally_psbt_input * input)476 bool psbt_input_is_ours(const struct wally_psbt_input *input)
477 {
478 size_t unused;
479 void *result = psbt_get_lightning(&input->unknowns,
480 PSBT_TYPE_INPUT_MARKER, &unused);
481 return !(!result);
482 }
483
psbt_has_our_input(const struct wally_psbt * psbt)484 bool psbt_has_our_input(const struct wally_psbt *psbt)
485 {
486 for (size_t i = 0; i < psbt->num_inputs; i++) {
487 if (psbt_input_is_ours(&psbt->inputs[i]))
488 return true;
489 }
490
491 return false;
492 }
493
psbt_contribs_changed(struct wally_psbt * orig,struct wally_psbt * new)494 bool psbt_contribs_changed(struct wally_psbt *orig,
495 struct wally_psbt *new)
496 {
497 struct psbt_changeset *cs;
498 bool ok;
499 cs = psbt_get_changeset(NULL, orig, new);
500
501 ok = tal_count(cs->added_ins) > 0 ||
502 tal_count(cs->rm_ins) > 0 ||
503 tal_count(cs->added_outs) > 0 ||
504 tal_count(cs->rm_outs) > 0;
505
506 tal_free(cs);
507 return ok;
508 }
509