1 #include "config.h"
2 #include <bitcoin/psbt.h>
3 #include <bitcoin/script.h>
4 #include <ccan/asort/asort.h>
5 #include <ccan/ccan/endian/endian.h>
6 #include <ccan/ccan/mem/mem.h>
7 #include <common/psbt_open.h>
8 #include <common/pseudorand.h>
9 #include <common/utils.h>
10 
psbt_get_serial_id(const struct wally_map * map,u64 * serial_id)11 bool psbt_get_serial_id(const struct wally_map *map, u64 *serial_id)
12 {
13 	size_t value_len;
14 	beint64_t bev;
15 	void *result = psbt_get_lightning(map, PSBT_TYPE_SERIAL_ID, &value_len);
16 	if (!result)
17 		return false;
18 
19 	if (value_len != sizeof(bev))
20 		return false;
21 
22 	memcpy(&bev, result, value_len);
23 	*serial_id = be64_to_cpu(bev);
24 	return true;
25 }
26 
compare_serials(const struct wally_map * map_a,const struct wally_map * map_b)27 static int compare_serials(const struct wally_map *map_a,
28 			   const struct wally_map *map_b)
29 {
30 	u64 serial_left, serial_right;
31 	bool ok;
32 
33 	ok = psbt_get_serial_id(map_a, &serial_left);
34 	assert(ok);
35 	ok = psbt_get_serial_id(map_b, &serial_right);
36 	assert(ok);
37 	if (serial_left > serial_right)
38 		return 1;
39 	if (serial_left < serial_right)
40 		return -1;
41 	return 0;
42 }
43 
compare_inputs_at(const struct input_set * a,const struct input_set * b,void * unused UNUSED)44 static int compare_inputs_at(const struct input_set *a,
45 			     const struct input_set *b,
46 			     void *unused UNUSED)
47 {
48 	return compare_serials(&a->input.unknowns,
49 			       &b->input.unknowns);
50 }
51 
compare_outputs_at(const struct output_set * a,const struct output_set * b,void * unused UNUSED)52 static int compare_outputs_at(const struct output_set *a,
53 			      const struct output_set *b,
54 			      void *unused UNUSED)
55 {
56 	return compare_serials(&a->output.unknowns,
57 			       &b->output.unknowns);
58 }
59 
linearize_input(const tal_t * ctx,const struct wally_psbt_input * in,const struct wally_tx_input * tx_in)60 static const u8 *linearize_input(const tal_t *ctx,
61 				 const struct wally_psbt_input *in,
62 				 const struct wally_tx_input *tx_in)
63 {
64 	struct wally_psbt *psbt = create_psbt(NULL, 1, 0, 0);
65 	size_t byte_len;
66 
67 	tal_wally_start();
68 	if (wally_tx_add_input(psbt->tx, tx_in) != WALLY_OK)
69 		abort();
70 	tal_wally_end(psbt->tx);
71 
72 	psbt->inputs[0] = *in;
73 	psbt->num_inputs++;
74 
75 
76 	/* Sort the inputs, so serializing them is ok */
77 	wally_map_sort(&psbt->inputs[0].unknowns, 0);
78 
79 	/* signatures, keypaths, etc - we dont care if they change */
80 	psbt->inputs[0].final_witness = NULL;
81 	psbt->inputs[0].final_scriptsig_len = 0;
82 	psbt->inputs[0].witness_script = NULL;
83 	psbt->inputs[0].witness_script_len = 0;
84 	psbt->inputs[0].redeem_script_len = 0;
85 	psbt->inputs[0].keypaths.num_items = 0;
86 	psbt->inputs[0].signatures.num_items = 0;
87 
88 	const u8 *bytes = psbt_get_bytes(ctx, psbt, &byte_len);
89 
90 	/* Hide the inputs we added, so it doesn't get freed */
91 	psbt->num_inputs--;
92 	tal_free(psbt);
93 	return bytes;
94 }
95 
linearize_output(const tal_t * ctx,const struct wally_psbt_output * out,const struct wally_tx_output * tx_out)96 static const u8 *linearize_output(const tal_t *ctx,
97 				  const struct wally_psbt_output *out,
98 				  const struct wally_tx_output *tx_out)
99 {
100 	struct wally_psbt *psbt = create_psbt(NULL, 1, 1, 0);
101 	size_t byte_len;
102 	struct bitcoin_outpoint outpoint;
103 
104 	/* Add a 'fake' input so this will linearize the tx */
105 	memset(&outpoint, 0, sizeof(outpoint));
106 	psbt_append_input(psbt, &outpoint, 0, NULL, NULL, NULL);
107 
108 	tal_wally_start();
109 	if (wally_tx_add_output(psbt->tx, tx_out) != WALLY_OK)
110 		abort();
111 	tal_wally_end(psbt->tx);
112 
113 	psbt->outputs[0] = *out;
114 	psbt->num_outputs++;
115 	/* Sort the outputs, so serializing them is ok */
116 	wally_map_sort(&psbt->outputs[0].unknowns, 0);
117 
118 	/* We don't care if the keypaths change */
119 	psbt->outputs[0].keypaths.num_items = 0;
120 	/* And you can add scripts, no problem */
121 	psbt->outputs[0].witness_script_len = 0;
122 	psbt->outputs[0].redeem_script_len = 0;
123 
124 	const u8 *bytes = psbt_get_bytes(ctx, psbt, &byte_len);
125 
126 	/* Hide the outputs we added, so it doesn't get freed */
127 	psbt->num_outputs--;
128 	tal_free(psbt);
129 	return bytes;
130 }
131 
input_identical(const struct wally_psbt * a,size_t a_index,const struct wally_psbt * b,size_t b_index)132 static bool input_identical(const struct wally_psbt *a,
133 			    size_t a_index,
134 			    const struct wally_psbt *b,
135 			    size_t b_index)
136 {
137 	const u8 *a_in = linearize_input(tmpctx,
138 					 &a->inputs[a_index],
139 					 &a->tx->inputs[a_index]);
140 	const u8 *b_in = linearize_input(tmpctx,
141 					 &b->inputs[b_index],
142 					 &b->tx->inputs[b_index]);
143 
144 	return memeq(a_in, tal_bytelen(a_in),
145 		     b_in, tal_bytelen(b_in));
146 }
147 
output_identical(const struct wally_psbt * a,size_t a_index,const struct wally_psbt * b,size_t b_index)148 static bool output_identical(const struct wally_psbt *a,
149 			     size_t a_index,
150 			     const struct wally_psbt *b,
151 			     size_t b_index)
152 {
153 	const u8 *a_out = linearize_output(tmpctx,
154 					   &a->outputs[a_index],
155 					   &a->tx->outputs[a_index]);
156 	const u8 *b_out = linearize_output(tmpctx,
157 					   &b->outputs[b_index],
158 					   &b->tx->outputs[b_index]);
159 	return memeq(a_out, tal_bytelen(a_out),
160 		     b_out, tal_bytelen(b_out));
161 }
162 
sort_inputs(struct wally_psbt * psbt)163 static void sort_inputs(struct wally_psbt *psbt)
164 {
165 	/* Build an input map */
166 	struct input_set *set = tal_arr(NULL,
167 					struct input_set,
168 					psbt->num_inputs);
169 
170 	for (size_t i = 0; i < tal_count(set); i++) {
171 		set[i].tx_input = psbt->tx->inputs[i];
172 		set[i].input = psbt->inputs[i];
173 	}
174 
175 	asort(set, tal_count(set),
176 	      compare_inputs_at, NULL);
177 
178 	/* Put PSBT parts into place */
179 	for (size_t i = 0; i < tal_count(set); i++) {
180 		psbt->inputs[i] = set[i].input;
181 		psbt->tx->inputs[i] = set[i].tx_input;
182 	}
183 
184 	tal_free(set);
185 }
186 
sort_outputs(struct wally_psbt * psbt)187 static void sort_outputs(struct wally_psbt *psbt)
188 {
189 	/* Build an output map */
190 	struct output_set *set = tal_arr(NULL,
191 					 struct output_set,
192 					 psbt->num_outputs);
193 	for (size_t i = 0; i < tal_count(set); i++) {
194 		set[i].tx_output = psbt->tx->outputs[i];
195 		set[i].output = psbt->outputs[i];
196 	}
197 
198 	asort(set, tal_count(set),
199 	      compare_outputs_at, NULL);
200 
201 	/* Put PSBT parts into place */
202 	for (size_t i = 0; i < tal_count(set); i++) {
203 		psbt->outputs[i] = set[i].output;
204 		psbt->tx->outputs[i] = set[i].tx_output;
205 	}
206 
207 	tal_free(set);
208 }
209 
psbt_sort_by_serial_id(struct wally_psbt * psbt)210 void psbt_sort_by_serial_id(struct wally_psbt *psbt)
211 {
212 	sort_inputs(psbt);
213 	sort_outputs(psbt);
214 }
215 
/* Append entry @index of PSBT @from to the tal array @add_to.
 *
 * @type is the token `input` or `output`. The appended struct type##_set
 * holds a struct copy of from->type##s[index], a struct copy of the matching
 * from->tx->type##s[index], and the index itself. Every argument is
 * parenthesized, and @index is evaluated exactly once, so expressions with
 * side effects are safe to pass. */
#define ADD(type, add_to, from, index)					\
	do {								\
		struct type##_set add_entry_;				\
		size_t add_idx_ = (index);				\
		add_entry_.type = (from)->type##s[add_idx_];		\
		add_entry_.tx_##type = (from)->tx->type##s[add_idx_];	\
		add_entry_.idx = add_idx_;				\
		tal_arr_expand(&(add_to), add_entry_);			\
	} while (0)
224 
new_changeset(const tal_t * ctx)225 static struct psbt_changeset *new_changeset(const tal_t *ctx)
226 {
227 	struct psbt_changeset *set = tal(ctx, struct psbt_changeset);
228 
229 	set->added_ins = tal_arr(set, struct input_set, 0);
230 	set->rm_ins = tal_arr(set, struct input_set, 0);
231 	set->added_outs = tal_arr(set, struct output_set, 0);
232 	set->rm_outs = tal_arr(set, struct output_set, 0);
233 
234 	return set;
235 }
236 
237 /* this requires having a serial_id entry on everything */
238 /* YOU MUST KEEP orig + new AROUND TO USE THE RESULTING SETS */
psbt_get_changeset(const tal_t * ctx,struct wally_psbt * orig,struct wally_psbt * new)239 struct psbt_changeset *psbt_get_changeset(const tal_t *ctx,
240 					  struct wally_psbt *orig,
241 					  struct wally_psbt *new)
242 {
243 	int result;
244 	size_t i = 0, j = 0;
245 	struct psbt_changeset *set;
246 
247 	psbt_sort_by_serial_id(orig);
248 	psbt_sort_by_serial_id(new);
249 
250 	set = new_changeset(ctx);
251 
252 	/* Find the input diff */
253 	while (i < orig->num_inputs || j < new->num_inputs) {
254 		if (i >= orig->num_inputs) {
255 			ADD(input, set->added_ins, new, j);
256 			j++;
257 			continue;
258 		}
259 		if (j >= new->num_inputs) {
260 			ADD(input, set->rm_ins, orig, i);
261 			i++;
262 			continue;
263 		}
264 
265 		result = compare_serials(&orig->inputs[i].unknowns,
266 					 &new->inputs[j].unknowns);
267 		if (result == -1) {
268 			ADD(input, set->rm_ins, orig, i);
269 			i++;
270 			continue;
271 		}
272 		if (result == 1) {
273 			ADD(input, set->added_ins, new, j);
274 			j++;
275 			continue;
276 		}
277 
278 		if (!input_identical(orig, i, new, j)) {
279 			ADD(input, set->rm_ins, orig, i);
280 			ADD(input, set->added_ins, new, j);
281 		}
282 		i++;
283 		j++;
284 	}
285 	/* Find the output diff */
286 	i = 0;
287 	j = 0;
288 	while (i < orig->num_outputs || j < new->num_outputs) {
289 		if (i >= orig->num_outputs) {
290 			ADD(output, set->added_outs, new, j);
291 			j++;
292 			continue;
293 		}
294 		if (j >= new->num_outputs) {
295 			ADD(output, set->rm_outs, orig, i);
296 			i++;
297 			continue;
298 		}
299 
300 		result = compare_serials(&orig->outputs[i].unknowns,
301 					 &new->outputs[j].unknowns);
302 		if (result == -1) {
303 			ADD(output, set->rm_outs, orig, i);
304 			i++;
305 			continue;
306 		}
307 		if (result == 1) {
308 			ADD(output, set->added_outs, new, j);
309 			j++;
310 			continue;
311 		}
312 		if (!output_identical(orig, i, new, j)) {
313 			ADD(output, set->rm_outs, orig, i);
314 			ADD(output, set->added_outs, new, j);
315 		}
316 		i++;
317 		j++;
318 	}
319 
320 	return set;
321 }
322 
323 
psbt_input_set_serial_id(const tal_t * ctx,struct wally_psbt_input * input,u64 serial_id)324 void psbt_input_set_serial_id(const tal_t *ctx,
325 			      struct wally_psbt_input *input,
326 			      u64 serial_id)
327 {
328 	u8 *key = psbt_make_key(tmpctx, PSBT_TYPE_SERIAL_ID, NULL);
329 	beint64_t bev = cpu_to_be64(serial_id);
330 
331 	psbt_input_set_unknown(ctx, input, key, &bev, sizeof(bev));
332 }
333 
334 
psbt_output_set_serial_id(const tal_t * ctx,struct wally_psbt_output * output,u64 serial_id)335 void psbt_output_set_serial_id(const tal_t *ctx,
336 			       struct wally_psbt_output *output,
337 			       u64 serial_id)
338 {
339 	u8 *key = psbt_make_key(tmpctx, PSBT_TYPE_SERIAL_ID, NULL);
340 	beint64_t bev = cpu_to_be64(serial_id);
341 	psbt_output_set_unknown(ctx, output, key, &bev, sizeof(bev));
342 }
343 
psbt_find_serial_input(struct wally_psbt * psbt,u64 serial_id)344 int psbt_find_serial_input(struct wally_psbt *psbt, u64 serial_id)
345 {
346 	for (size_t i = 0; i < psbt->num_inputs; i++) {
347 		u64 in_serial;
348 		if (!psbt_get_serial_id(&psbt->inputs[i].unknowns, &in_serial))
349 			continue;
350 		if (in_serial == serial_id)
351 			return i;
352 	}
353 	return -1;
354 }
355 
psbt_find_serial_output(struct wally_psbt * psbt,u64 serial_id)356 int psbt_find_serial_output(struct wally_psbt *psbt, u64 serial_id)
357 {
358 	for (size_t i = 0; i < psbt->num_outputs; i++) {
359 		u64 out_serial;
360 		if (!psbt_get_serial_id(&psbt->outputs[i].unknowns, &out_serial))
361 			continue;
362 		if (out_serial == serial_id)
363 			return i;
364 	}
365 	return -1;
366 }
367 
get_random_serial(enum tx_role role)368 static u64 get_random_serial(enum tx_role role)
369 {
370 	return pseudorand_u64() << 1 | role;
371 }
372 
psbt_new_input_serial(struct wally_psbt * psbt,enum tx_role role)373 u64 psbt_new_input_serial(struct wally_psbt *psbt, enum tx_role role)
374 {
375 	u64 serial_id;
376 
377 	while ((serial_id = get_random_serial(role)) &&
378 		psbt_find_serial_input(psbt, serial_id) != -1) {
379 		/* keep going; */
380 	}
381 
382 	return serial_id;
383 }
384 
psbt_new_output_serial(struct wally_psbt * psbt,enum tx_role role)385 u64 psbt_new_output_serial(struct wally_psbt *psbt, enum tx_role role)
386 {
387 	u64 serial_id;
388 
389 	while ((serial_id = get_random_serial(role)) &&
390 		psbt_find_serial_output(psbt, serial_id) != -1) {
391 		/* keep going; */
392 	}
393 
394 	return serial_id;
395 }
396 
psbt_has_required_fields(struct wally_psbt * psbt)397 bool psbt_has_required_fields(struct wally_psbt *psbt)
398 {
399 	u64 serial_id;
400 	for (size_t i = 0; i < psbt->num_inputs; i++) {
401 		struct wally_psbt_input *input = &psbt->inputs[i];
402 
403 		if (!psbt_get_serial_id(&input->unknowns, &serial_id))
404 			return false;
405 
406 		/* Required because we send the full tx over the wire now */
407 		if (!input->utxo)
408 			return false;
409 
410 		/* If is P2SH, redeemscript must be present */
411 		assert(psbt->tx->inputs[i].index < input->utxo->num_outputs);
412 		const u8 *outscript =
413 			wally_tx_output_get_script(tmpctx,
414 				&input->utxo->outputs[psbt->tx->inputs[i].index]);
415 		if (is_p2sh(outscript, NULL) && input->redeem_script_len == 0)
416 			return false;
417 
418 	}
419 
420 	for (size_t i = 0; i < psbt->num_outputs; i++) {
421 		if (!psbt_get_serial_id(&psbt->outputs[i].unknowns, &serial_id))
422 			return false;
423 	}
424 
425 	return true;
426 }
427 
psbt_side_finalized(const struct wally_psbt * psbt,enum tx_role role)428 bool psbt_side_finalized(const struct wally_psbt *psbt, enum tx_role role)
429 {
430 	u64 serial_id;
431 	for (size_t i = 0; i < psbt->num_inputs; i++) {
432 		if (!psbt_get_serial_id(&psbt->inputs[i].unknowns,
433 					&serial_id)) {
434 			return false;
435 		}
436 		if (serial_id % 2 == role) {
437 			if (!psbt->inputs[i].final_witness ||
438 					psbt->inputs[i].final_witness->num_items == 0)
439 				return false;
440 		}
441 	}
442 	return true;
443 }
444 
445 /* Adds serials to inputs + outputs that don't have one yet */
psbt_add_serials(struct wally_psbt * psbt,enum tx_role role)446 void psbt_add_serials(struct wally_psbt *psbt, enum tx_role role)
447 {
448 	u64 serial_id;
449 	for (size_t i = 0; i < psbt->num_inputs; i++) {
450 		/* Skip ones that already have a serial id */
451 		if (psbt_get_serial_id(&psbt->inputs[i].unknowns, &serial_id))
452 			continue;
453 
454 		serial_id = psbt_new_input_serial(psbt, role);
455 		psbt_input_set_serial_id(psbt, &psbt->inputs[i], serial_id);
456 	}
457 	for (size_t i = 0; i < psbt->num_outputs; i++) {
458 		/* Skip ones that already have a serial id */
459 		if (psbt_get_serial_id(&psbt->outputs[i].unknowns, &serial_id))
460 			continue;
461 
462 		serial_id = psbt_new_output_serial(psbt, role);
463 		psbt_output_set_serial_id(psbt, &psbt->outputs[i], serial_id);
464 	}
465 }
466 
psbt_input_mark_ours(const tal_t * ctx,struct wally_psbt_input * input)467 void psbt_input_mark_ours(const tal_t *ctx,
468 			  struct wally_psbt_input *input)
469 {
470 	u8 *key = psbt_make_key(tmpctx, PSBT_TYPE_INPUT_MARKER, NULL);
471 	beint16_t bev = cpu_to_be16(1);
472 
473 	psbt_input_set_unknown(ctx, input, key, &bev, sizeof(bev));
474 }
475 
psbt_input_is_ours(const struct wally_psbt_input * input)476 bool psbt_input_is_ours(const struct wally_psbt_input *input)
477 {
478 	size_t unused;
479 	void *result = psbt_get_lightning(&input->unknowns,
480 					  PSBT_TYPE_INPUT_MARKER, &unused);
481 	return !(!result);
482 }
483 
psbt_has_our_input(const struct wally_psbt * psbt)484 bool psbt_has_our_input(const struct wally_psbt *psbt)
485 {
486 	for (size_t i = 0; i < psbt->num_inputs; i++) {
487 		if (psbt_input_is_ours(&psbt->inputs[i]))
488 			return true;
489 	}
490 
491 	return false;
492 }
493 
psbt_contribs_changed(struct wally_psbt * orig,struct wally_psbt * new)494 bool psbt_contribs_changed(struct wally_psbt *orig,
495 			   struct wally_psbt *new)
496 {
497 	struct psbt_changeset *cs;
498 	bool ok;
499 	cs = psbt_get_changeset(NULL, orig, new);
500 
501 	ok = tal_count(cs->added_ins) > 0 ||
502 	    tal_count(cs->rm_ins) > 0 ||
503 	    tal_count(cs->added_outs) > 0 ||
504 	    tal_count(cs->rm_outs) > 0;
505 
506 	tal_free(cs);
507 	return ok;
508 }
509