1// Copyright (c) 2018 The btcsuite developers
2// Use of this source code is governed by an ISC
3// license that can be found in the LICENSE file.
4
5package psbt
6
7import (
8	"bytes"
9	"encoding/binary"
10	"errors"
11	"fmt"
12	"io"
13	"sort"
14
15	"github.com/btcsuite/btcd/txscript"
16	"github.com/btcsuite/btcd/wire"
17)
18
19// WriteTxWitness is a utility function due to non-exported witness
20// serialization (writeTxWitness encodes the bitcoin protocol encoding for a
21// transaction input's witness into w).
22func WriteTxWitness(w io.Writer, wit [][]byte) error {
23	if err := wire.WriteVarInt(w, 0, uint64(len(wit))); err != nil {
24		return err
25	}
26
27	for _, item := range wit {
28		err := wire.WriteVarBytes(w, 0, item)
29		if err != nil {
30			return err
31		}
32	}
33	return nil
34}
35
36// writePKHWitness writes a witness for a p2wkh spending input
37func writePKHWitness(sig []byte, pub []byte) ([]byte, error) {
38	var (
39		buf          bytes.Buffer
40		witnessItems = [][]byte{sig, pub}
41	)
42
43	if err := WriteTxWitness(&buf, witnessItems); err != nil {
44		return nil, err
45	}
46
47	return buf.Bytes(), nil
48}
49
50// checkIsMultisigScript is a utility function to check whether a given
51// redeemscript fits the standard multisig template used in all P2SH based
52// multisig, given a set of pubkeys for redemption.
53func checkIsMultiSigScript(pubKeys [][]byte, sigs [][]byte,
54	script []byte) bool {
55
56	// First insist that the script type is multisig.
57	if txscript.GetScriptClass(script) != txscript.MultiSigTy {
58		return false
59	}
60
61	// Inspect the script to ensure that the number of sigs and pubkeys is
62	// correct
63	_, numSigs, err := txscript.CalcMultiSigStats(script)
64	if err != nil {
65		return false
66	}
67
68	// If the number of sigs provided, doesn't match the number of required
69	// pubkeys, then we can't proceed as we're not yet final.
70	if numSigs != len(pubKeys) || numSigs != len(sigs) {
71		return false
72	}
73
74	return true
75}
76
77// extractKeyOrderFromScript is a utility function to extract an ordered list
78// of signatures, given a serialized script (redeemscript or witness script), a
79// list of pubkeys and the signatures corresponding to those pubkeys. This
80// function is used to ensure that the signatures will be embedded in the final
81// scriptSig or scriptWitness in the correct order.
82func extractKeyOrderFromScript(script []byte, expectedPubkeys [][]byte,
83	sigs [][]byte) ([][]byte, error) {
84
85	// If this isn't a proper finalized multi-sig script, then we can't
86	// proceed.
87	if !checkIsMultiSigScript(expectedPubkeys, sigs, script) {
88		return nil, ErrUnsupportedScriptType
89	}
90
91	// Arrange the pubkeys and sigs into a slice of format:
92	//   * [[pub,sig], [pub,sig],..]
93	type sigWithPub struct {
94		pubKey []byte
95		sig    []byte
96	}
97	var pubsSigs []sigWithPub
98	for i, pub := range expectedPubkeys {
99		pubsSigs = append(pubsSigs, sigWithPub{
100			pubKey: pub,
101			sig:    sigs[i],
102		})
103	}
104
105	// Now that we have the set of (pubkey, sig) pairs, we'll construct a
106	// position map that we can use to swap the order in the slice above to
107	// match how things are laid out in the script.
108	type positionEntry struct {
109		index int
110		value sigWithPub
111	}
112	var positionMap []positionEntry
113
114	// For each pubkey in our pubsSigs slice, we'll now construct a proper
115	// positionMap entry, based on _where_ in the script the pubkey first
116	// appears.
117	for _, p := range pubsSigs {
118		pos := bytes.Index(script, p.pubKey)
119		if pos < 0 {
120			return nil, errors.New("script does not contain pubkeys")
121		}
122
123		positionMap = append(positionMap, positionEntry{
124			index: pos,
125			value: p,
126		})
127	}
128
129	// Now that we have the position map full populated, we'll use the
130	// index data to properly sort the entries in the map based on where
131	// they appear in the script.
132	sort.Slice(positionMap, func(i, j int) bool {
133		return positionMap[i].index < positionMap[j].index
134	})
135
136	// Finally, we can simply iterate through the position map in order to
137	// extract the proper signature ordering.
138	sortedSigs := make([][]byte, 0, len(positionMap))
139	for _, x := range positionMap {
140		sortedSigs = append(sortedSigs, x.value.sig)
141	}
142
143	return sortedSigs, nil
144}
145
146// getMultisigScriptWitness creates a full psbt serialized Witness field for
147// the transaction, given the public keys and signatures to be appended. This
148// function will only accept witnessScripts of the type M of N multisig. This
149// is used for both p2wsh and nested p2wsh multisig cases.
150func getMultisigScriptWitness(witnessScript []byte, pubKeys [][]byte,
151	sigs [][]byte) ([]byte, error) {
152
153	// First using the script as a guide, we'll properly order the sigs
154	// according to how their corresponding pubkeys appear in the
155	// witnessScript.
156	orderedSigs, err := extractKeyOrderFromScript(
157		witnessScript, pubKeys, sigs,
158	)
159	if err != nil {
160		return nil, err
161	}
162
163	// Now that we know the proper order, we'll append each of the
164	// signatures into a new witness stack, then top it off with the
165	// witness script at the end, prepending the nil as we need the extra
166	// pop..
167	witnessElements := make(wire.TxWitness, 0, len(sigs)+2)
168	witnessElements = append(witnessElements, nil)
169	for _, os := range orderedSigs {
170		witnessElements = append(witnessElements, os)
171	}
172	witnessElements = append(witnessElements, witnessScript)
173
174	// Now that we have the full witness stack, we'll serialize it in the
175	// expected format, and return the final bytes.
176	var buf bytes.Buffer
177	if err = WriteTxWitness(&buf, witnessElements); err != nil {
178		return nil, err
179	}
180	return buf.Bytes(), nil
181}
182
183// checkSigHashFlags compares the sighash flag byte on a signature with the
184// value expected according to any PsbtInSighashType field in this section of
185// the PSBT, and returns true if they match, false otherwise.
186// If no SighashType field exists, it is assumed to be SIGHASH_ALL.
187//
188// TODO(waxwing): sighash type not restricted to one byte in future?
189func checkSigHashFlags(sig []byte, input *PInput) bool {
190	expectedSighashType := txscript.SigHashAll
191	if input.SighashType != 0 {
192		expectedSighashType = input.SighashType
193	}
194
195	return expectedSighashType == txscript.SigHashType(sig[len(sig)-1])
196}
197
198// serializeKVpair writes out a kv pair using a varbyte prefix for each.
199func serializeKVpair(w io.Writer, key []byte, value []byte) error {
200	if err := wire.WriteVarBytes(w, 0, key); err != nil {
201		return err
202	}
203
204	return wire.WriteVarBytes(w, 0, value)
205}
206
207// serializeKVPairWithType writes out to the passed writer a type coupled with
208// a key.
209func serializeKVPairWithType(w io.Writer, kt uint8, keydata []byte,
210	value []byte) error {
211
212	// If the key has no data, then we write a blank slice.
213	if keydata == nil {
214		keydata = []byte{}
215	}
216
217	// The final key to be written is: {type} || {keyData}
218	serializedKey := append([]byte{kt}, keydata...)
219	return serializeKVpair(w, serializedKey, value)
220}
221
222// getKey retrieves a single key - both the key type and the keydata (if
223// present) from the stream and returns the key type as an integer, or -1 if
224// the key was of zero length. This integer is is used to indicate the presence
225// of a separator byte which indicates the end of a given key-value pair list,
226// and the keydata as a byte slice or nil if none is present.
227func getKey(r io.Reader) (int, []byte, error) {
228
229	// For the key, we read the varint separately, instead of using the
230	// available ReadVarBytes, because we have a specific treatment of 0x00
231	// here:
232	count, err := wire.ReadVarInt(r, 0)
233	if err != nil {
234		return -1, nil, ErrInvalidPsbtFormat
235	}
236	if count == 0 {
237		// A separator indicates end of key-value pair list.
238		return -1, nil, nil
239	}
240
241	// Check that we don't attempt to decode a dangerously large key.
242	if count > MaxPsbtKeyLength {
243		return -1, nil, ErrInvalidKeydata
244	}
245
246	// Next, we ready out the designated number of bytes, which may include
247	// a type, key, and optional data.
248	keyTypeAndData := make([]byte, count)
249	if _, err := io.ReadFull(r, keyTypeAndData[:]); err != nil {
250		return -1, nil, err
251	}
252
253	keyType := int(string(keyTypeAndData)[0])
254
255	// Note that the second return value will usually be empty, since most
256	// keys contain no more than the key type byte.
257	if len(keyTypeAndData) == 1 {
258		return keyType, nil, nil
259	}
260
261	// Otherwise, we return the key, along with any data that it may
262	// contain.
263	return keyType, keyTypeAndData[1:], nil
264
265}
266
267// readTxOut is a limited version of wire.ReadTxOut, because the latter is not
268// exported.
269func readTxOut(txout []byte) (*wire.TxOut, error) {
270	if len(txout) < 10 {
271		return nil, ErrInvalidPsbtFormat
272	}
273
274	valueSer := binary.LittleEndian.Uint64(txout[:8])
275	scriptPubKey := txout[9:]
276
277	return wire.NewTxOut(int64(valueSer), scriptPubKey), nil
278}
279
280// SumUtxoInputValues tries to extract the sum of all inputs specified in the
281// UTXO fields of the PSBT. An error is returned if an input is specified that
282// does not contain any UTXO information.
283func SumUtxoInputValues(packet *Packet) (int64, error) {
284	// We take the TX ins of the unsigned TX as the truth for how many
285	// inputs there should be, as the fields in the extra data part of the
286	// PSBT can be empty.
287	if len(packet.UnsignedTx.TxIn) != len(packet.Inputs) {
288		return 0, fmt.Errorf("TX input length doesn't match PSBT " +
289			"input length")
290	}
291
292	inputSum := int64(0)
293	for idx, in := range packet.Inputs {
294		switch {
295		case in.WitnessUtxo != nil:
296			// Witness UTXOs only need to reference the TxOut.
297			inputSum += in.WitnessUtxo.Value
298
299		case in.NonWitnessUtxo != nil:
300			// Non-witness UTXOs reference to the whole transaction
301			// the UTXO resides in.
302			utxOuts := in.NonWitnessUtxo.TxOut
303			txIn := packet.UnsignedTx.TxIn[idx]
304			inputSum += utxOuts[txIn.PreviousOutPoint.Index].Value
305
306		default:
307			return 0, fmt.Errorf("input %d has no UTXO information",
308				idx)
309		}
310	}
311	return inputSum, nil
312}
313
314// TxOutsEqual returns true if two transaction outputs are equal.
315func TxOutsEqual(out1, out2 *wire.TxOut) bool {
316	if out1 == nil || out2 == nil {
317		return out1 == out2
318	}
319	return out1.Value == out2.Value &&
320		bytes.Equal(out1.PkScript, out2.PkScript)
321}
322
323// VerifyOutputsEqual verifies that the two slices of transaction outputs are
324// deep equal to each other. We do the length check and manual loop to provide
325// better error messages to the user than just returning "not equal".
326func VerifyOutputsEqual(outs1, outs2 []*wire.TxOut) error {
327	if len(outs1) != len(outs2) {
328		return fmt.Errorf("number of outputs are different")
329	}
330	for idx, out := range outs1 {
331		// There is a byte slice in the output so we can't use the
332		// equality operator.
333		if !TxOutsEqual(out, outs2[idx]) {
334			return fmt.Errorf("output %d is different", idx)
335		}
336	}
337	return nil
338}
339
340// VerifyInputPrevOutpointsEqual verifies that the previous outpoints of the
341// two slices of transaction inputs are deep equal to each other. We do the
342// length check and manual loop to provide better error messages to the user
343// than just returning "not equal".
344func VerifyInputPrevOutpointsEqual(ins1, ins2 []*wire.TxIn) error {
345	if len(ins1) != len(ins2) {
346		return fmt.Errorf("number of inputs are different")
347	}
348	for idx, in := range ins1 {
349		if in.PreviousOutPoint != ins2[idx].PreviousOutPoint {
350			return fmt.Errorf("previous outpoint of input %d is "+
351				"different", idx)
352		}
353	}
354	return nil
355}
356
357// VerifyInputOutputLen makes sure a packet is non-nil, contains a non-nil wire
358// transaction and that the wire input/output lengths match the partial input/
359// output lengths. A caller also can specify if they expect any inputs and/or
360// outputs to be contained in the packet.
361func VerifyInputOutputLen(packet *Packet, needInputs, needOutputs bool) error {
362	if packet == nil || packet.UnsignedTx == nil {
363		return fmt.Errorf("PSBT packet cannot be nil")
364	}
365
366	if len(packet.UnsignedTx.TxIn) != len(packet.Inputs) {
367		return fmt.Errorf("invalid PSBT, wire inputs don't match " +
368			"partial inputs")
369	}
370	if len(packet.UnsignedTx.TxOut) != len(packet.Outputs) {
371		return fmt.Errorf("invalid PSBT, wire outputs don't match " +
372			"partial outputs")
373	}
374
375	if needInputs && len(packet.UnsignedTx.TxIn) == 0 {
376		return fmt.Errorf("PSBT packet must contain at least one " +
377			"input")
378	}
379	if needOutputs && len(packet.UnsignedTx.TxOut) == 0 {
380		return fmt.Errorf("PSBT packet must contain at least one " +
381			"output")
382	}
383
384	return nil
385}
386
387// NewFromSignedTx is a utility function to create a packet from an
388// already-signed transaction. Returned are: an unsigned transaction
389// serialization, a list of scriptSigs, one per input, and a list of witnesses,
390// one per input.
391func NewFromSignedTx(tx *wire.MsgTx) (*Packet, [][]byte,
392	[]wire.TxWitness, error) {
393
394	scriptSigs := make([][]byte, 0, len(tx.TxIn))
395	witnesses := make([]wire.TxWitness, 0, len(tx.TxIn))
396	tx2 := tx.Copy()
397
398	// Blank out signature info in inputs
399	for i, tin := range tx2.TxIn {
400		tin.SignatureScript = nil
401		scriptSigs = append(scriptSigs, tx.TxIn[i].SignatureScript)
402		tin.Witness = nil
403		witnesses = append(witnesses, tx.TxIn[i].Witness)
404	}
405
406	// Outputs always contain: (value, scriptPubkey) so don't need
407	// amending.  Now tx2 is tx with all signing data stripped out
408	unsignedPsbt, err := NewFromUnsignedTx(tx2)
409	if err != nil {
410		return nil, nil, nil, err
411	}
412	return unsignedPsbt, scriptSigs, witnesses, nil
413}
414