1package psbt
2
3import (
4	"bytes"
5	"encoding/binary"
6	"io"
7	"sort"
8
9	"github.com/btcsuite/btcd/txscript"
10	"github.com/btcsuite/btcd/wire"
11)
12
// PInput is a struct encapsulating all the data that can be attached to any
// specific input of the PSBT.
//
// Field notes, as used by the code in this file:
//   - NonWitnessUtxo holds a full transaction and WitnessUtxo a single
//     transaction output; IsSane rejects an input that sets both.
//   - PartialSigs and Bip32Derivation are keyed by public key; deserialize
//     rejects a second entry for a public key it has already seen.
//   - SighashType uses its zero value to mean "not set": serialize omits a
//     zero sighash type entirely.
//   - WitnessScript and FinalScriptWitness are only accepted by IsSane when
//     WitnessUtxo is also set.
//   - When FinalScriptSig or FinalScriptWitness is set, serialize omits the
//     partial signatures, sighash type, redeem/witness scripts and BIP32
//     derivations.
//   - Unknowns keeps entries with unrecognized key types verbatim; each Key
//     is the key type byte followed by any key data.
type PInput struct {
	NonWitnessUtxo     *wire.MsgTx
	WitnessUtxo        *wire.TxOut
	PartialSigs        []*PartialSig
	SighashType        txscript.SigHashType
	RedeemScript       []byte
	WitnessScript      []byte
	Bip32Derivation    []*Bip32Derivation
	FinalScriptSig     []byte
	FinalScriptWitness []byte
	Unknowns           []*Unknown
}
27
28// NewPsbtInput creates an instance of PsbtInput given either a nonWitnessUtxo
29// or a witnessUtxo.
30//
31// NOTE: Only one of the two arguments should be specified, with the other
32// being `nil`; otherwise the created PsbtInput object will fail IsSane()
33// checks and will not be usable.
34func NewPsbtInput(nonWitnessUtxo *wire.MsgTx,
35	witnessUtxo *wire.TxOut) *PInput {
36
37	return &PInput{
38		NonWitnessUtxo:     nonWitnessUtxo,
39		WitnessUtxo:        witnessUtxo,
40		PartialSigs:        []*PartialSig{},
41		SighashType:        0,
42		RedeemScript:       nil,
43		WitnessScript:      nil,
44		Bip32Derivation:    []*Bip32Derivation{},
45		FinalScriptSig:     nil,
46		FinalScriptWitness: nil,
47		Unknowns:           nil,
48	}
49}
50
51// IsSane returns true only if there are no conflicting values in the Psbt
52// PInput. It checks that witness and non-witness utxo entries do not both
53// exist, and that witnessScript entries are only added to witness inputs.
54func (pi *PInput) IsSane() bool {
55
56	if pi.NonWitnessUtxo != nil && pi.WitnessUtxo != nil {
57		return false
58	}
59	if pi.WitnessUtxo == nil && pi.WitnessScript != nil {
60		return false
61	}
62	if pi.WitnessUtxo == nil && pi.FinalScriptWitness != nil {
63		return false
64	}
65
66	return true
67}
68
69// deserialize attempts to deserialize a new PInput from the passed io.Reader.
70func (pi *PInput) deserialize(r io.Reader) error {
71	for {
72		keyint, keydata, err := getKey(r)
73		if err != nil {
74			return err
75		}
76		if keyint == -1 {
77			// Reached separator byte
78			break
79		}
80		value, err := wire.ReadVarBytes(
81			r, 0, MaxPsbtValueLength, "PSBT value",
82		)
83		if err != nil {
84			return err
85		}
86
87		switch InputType(keyint) {
88
89		case NonWitnessUtxoType:
90			if pi.NonWitnessUtxo != nil {
91				return ErrDuplicateKey
92			}
93			if keydata != nil {
94				return ErrInvalidKeydata
95			}
96			tx := wire.NewMsgTx(2)
97
98			err := tx.Deserialize(bytes.NewReader(value))
99			if err != nil {
100				return err
101			}
102			pi.NonWitnessUtxo = tx
103
104		case WitnessUtxoType:
105			if pi.WitnessUtxo != nil {
106				return ErrDuplicateKey
107			}
108			if keydata != nil {
109				return ErrInvalidKeydata
110			}
111			txout, err := readTxOut(value)
112			if err != nil {
113				return err
114			}
115			pi.WitnessUtxo = txout
116
117		case PartialSigType:
118			newPartialSig := PartialSig{
119				PubKey:    keydata,
120				Signature: value,
121			}
122
123			if !newPartialSig.checkValid() {
124				return ErrInvalidPsbtFormat
125			}
126
127			// Duplicate keys are not allowed
128			for _, x := range pi.PartialSigs {
129				if bytes.Equal(x.PubKey, newPartialSig.PubKey) {
130					return ErrDuplicateKey
131				}
132			}
133
134			pi.PartialSigs = append(pi.PartialSigs, &newPartialSig)
135
136		case SighashType:
137			if pi.SighashType != 0 {
138				return ErrDuplicateKey
139			}
140			if keydata != nil {
141				return ErrInvalidKeydata
142			}
143
144			// Bounds check on value here since the sighash type must be a
145			// 32-bit unsigned integer.
146			if len(value) != 4 {
147				return ErrInvalidKeydata
148			}
149
150			shtype := txscript.SigHashType(
151				binary.LittleEndian.Uint32(value),
152			)
153			pi.SighashType = shtype
154
155		case RedeemScriptInputType:
156			if pi.RedeemScript != nil {
157				return ErrDuplicateKey
158			}
159			if keydata != nil {
160				return ErrInvalidKeydata
161			}
162			pi.RedeemScript = value
163
164		case WitnessScriptInputType:
165			if pi.WitnessScript != nil {
166				return ErrDuplicateKey
167			}
168			if keydata != nil {
169				return ErrInvalidKeydata
170			}
171			pi.WitnessScript = value
172
173		case Bip32DerivationInputType:
174			if !validatePubkey(keydata) {
175				return ErrInvalidPsbtFormat
176			}
177			master, derivationPath, err := readBip32Derivation(value)
178			if err != nil {
179				return err
180			}
181
182			// Duplicate keys are not allowed
183			for _, x := range pi.Bip32Derivation {
184				if bytes.Equal(x.PubKey, keydata) {
185					return ErrDuplicateKey
186				}
187			}
188
189			pi.Bip32Derivation = append(
190				pi.Bip32Derivation,
191				&Bip32Derivation{
192					PubKey:               keydata,
193					MasterKeyFingerprint: master,
194					Bip32Path:            derivationPath,
195				},
196			)
197
198		case FinalScriptSigType:
199			if pi.FinalScriptSig != nil {
200				return ErrDuplicateKey
201			}
202			if keydata != nil {
203				return ErrInvalidKeydata
204			}
205
206			pi.FinalScriptSig = value
207
208		case FinalScriptWitnessType:
209			if pi.FinalScriptWitness != nil {
210				return ErrDuplicateKey
211			}
212			if keydata != nil {
213				return ErrInvalidKeydata
214			}
215
216			pi.FinalScriptWitness = value
217
218		default:
219			// A fall through case for any proprietary types.
220			keyintanddata := []byte{byte(keyint)}
221			keyintanddata = append(keyintanddata, keydata...)
222			newUnknown := &Unknown{
223				Key:   keyintanddata,
224				Value: value,
225			}
226
227			// Duplicate key+keydata are not allowed
228			for _, x := range pi.Unknowns {
229				if bytes.Equal(x.Key, newUnknown.Key) &&
230					bytes.Equal(x.Value, newUnknown.Value) {
231					return ErrDuplicateKey
232				}
233			}
234
235			pi.Unknowns = append(pi.Unknowns, newUnknown)
236		}
237	}
238
239	return nil
240}
241
242// serialize attempts to serialize the target PInput into the passed io.Writer.
243func (pi *PInput) serialize(w io.Writer) error {
244
245	if !pi.IsSane() {
246		return ErrInvalidPsbtFormat
247	}
248
249	if pi.NonWitnessUtxo != nil {
250		var buf bytes.Buffer
251		err := pi.NonWitnessUtxo.Serialize(&buf)
252		if err != nil {
253			return err
254		}
255
256		err = serializeKVPairWithType(
257			w, uint8(NonWitnessUtxoType), nil, buf.Bytes(),
258		)
259		if err != nil {
260			return err
261		}
262	}
263	if pi.WitnessUtxo != nil {
264		var buf bytes.Buffer
265		err := wire.WriteTxOut(&buf, 0, 0, pi.WitnessUtxo)
266		if err != nil {
267			return err
268		}
269
270		err = serializeKVPairWithType(
271			w, uint8(WitnessUtxoType), nil, buf.Bytes(),
272		)
273		if err != nil {
274			return err
275		}
276	}
277
278	if pi.FinalScriptSig == nil && pi.FinalScriptWitness == nil {
279		sort.Sort(PartialSigSorter(pi.PartialSigs))
280		for _, ps := range pi.PartialSigs {
281			err := serializeKVPairWithType(
282				w, uint8(PartialSigType), ps.PubKey,
283				ps.Signature,
284			)
285			if err != nil {
286				return err
287			}
288		}
289
290		if pi.SighashType != 0 {
291			var shtBytes [4]byte
292			binary.LittleEndian.PutUint32(
293				shtBytes[:], uint32(pi.SighashType),
294			)
295
296			err := serializeKVPairWithType(
297				w, uint8(SighashType), nil, shtBytes[:],
298			)
299			if err != nil {
300				return err
301			}
302		}
303
304		if pi.RedeemScript != nil {
305			err := serializeKVPairWithType(
306				w, uint8(RedeemScriptInputType), nil,
307				pi.RedeemScript,
308			)
309			if err != nil {
310				return err
311			}
312		}
313
314		if pi.WitnessScript != nil {
315			err := serializeKVPairWithType(
316				w, uint8(WitnessScriptInputType), nil,
317				pi.WitnessScript,
318			)
319			if err != nil {
320				return err
321			}
322		}
323
324		sort.Sort(Bip32Sorter(pi.Bip32Derivation))
325		for _, kd := range pi.Bip32Derivation {
326			err := serializeKVPairWithType(
327				w,
328				uint8(Bip32DerivationInputType), kd.PubKey,
329				SerializeBIP32Derivation(
330					kd.MasterKeyFingerprint, kd.Bip32Path,
331				),
332			)
333			if err != nil {
334				return err
335			}
336		}
337	}
338
339	if pi.FinalScriptSig != nil {
340		err := serializeKVPairWithType(
341			w, uint8(FinalScriptSigType), nil, pi.FinalScriptSig,
342		)
343		if err != nil {
344			return err
345		}
346	}
347
348	if pi.FinalScriptWitness != nil {
349		err := serializeKVPairWithType(
350			w, uint8(FinalScriptWitnessType), nil, pi.FinalScriptWitness,
351		)
352		if err != nil {
353			return err
354		}
355	}
356
357	// Unknown is a special case; we don't have a key type, only a key and
358	// a value field
359	for _, kv := range pi.Unknowns {
360		err := serializeKVpair(w, kv.Key, kv.Value)
361		if err != nil {
362			return err
363		}
364	}
365
366	return nil
367}
368