1package psbt
2
3import (
4	"bytes"
5	"reflect"
6	"testing"
7
8	"github.com/btcsuite/btcd/chaincfg/chainhash"
9	"github.com/btcsuite/btcd/wire"
10)
11
12func TestSumUtxoInputValues(t *testing.T) {
13	// Expect sum to fail for packet with non-matching txIn and PInputs.
14	tx := wire.NewMsgTx(2)
15	badPacket, err := NewFromUnsignedTx(tx)
16	if err != nil {
17		t.Fatalf("could not create packet from TX: %v", err)
18	}
19	badPacket.Inputs = append(badPacket.Inputs, PInput{})
20
21	_, err = SumUtxoInputValues(badPacket)
22	if err == nil {
23		t.Fatalf("expected sum of bad packet to fail")
24	}
25
26	// Expect sum to fail if any inputs don't have UTXO information added.
27	op := []*wire.OutPoint{{}, {}}
28	noUtxoInfoPacket, err := New(op, nil, 2, 0, []uint32{0, 0})
29	if err != nil {
30		t.Fatalf("could not create new packet: %v", err)
31	}
32
33	_, err = SumUtxoInputValues(noUtxoInfoPacket)
34	if err == nil {
35		t.Fatalf("expected sum of missing UTXO info to fail")
36	}
37
38	// Create a packet that is OK and contains both witness and non-witness
39	// UTXO information.
40	okPacket, err := New(op, nil, 2, 0, []uint32{0, 0})
41	if err != nil {
42		t.Fatalf("could not create new packet: %v", err)
43	}
44	okPacket.Inputs[0].WitnessUtxo = &wire.TxOut{Value: 1234}
45	okPacket.Inputs[1].NonWitnessUtxo = &wire.MsgTx{
46		TxOut: []*wire.TxOut{{Value: 6543}},
47	}
48
49	sum, err := SumUtxoInputValues(okPacket)
50	if err != nil {
51		t.Fatalf("could not sum input: %v", err)
52	}
53	if sum != (1234 + 6543) {
54		t.Fatalf("unexpected sum, got %d wanted %d", sum, 1234+6543)
55	}
56}
57
58func TestTxOutsEqual(t *testing.T) {
59	testCases := []struct {
60		name        string
61		out1        *wire.TxOut
62		out2        *wire.TxOut
63		expectEqual bool
64	}{{
65		name:        "both nil",
66		out1:        nil,
67		out2:        nil,
68		expectEqual: true,
69	}, {
70		name:        "one nil",
71		out1:        nil,
72		out2:        &wire.TxOut{},
73		expectEqual: false,
74	}, {
75		name:        "both empty",
76		out1:        &wire.TxOut{},
77		out2:        &wire.TxOut{},
78		expectEqual: true,
79	}, {
80		name: "one pk script set",
81		out1: &wire.TxOut{},
82		out2: &wire.TxOut{
83			PkScript: []byte("foo"),
84		},
85		expectEqual: false,
86	}, {
87		name: "both fully set",
88		out1: &wire.TxOut{
89			Value:    1234,
90			PkScript: []byte("bar"),
91		},
92		out2: &wire.TxOut{
93			Value:    1234,
94			PkScript: []byte("bar"),
95		},
96		expectEqual: true,
97	}}
98
99	for _, tc := range testCases {
100		tc := tc
101		t.Run(tc.name, func(t *testing.T) {
102			result := TxOutsEqual(tc.out1, tc.out2)
103			if result != tc.expectEqual {
104				t.Fatalf("unexpected result, got %v wanted %v",
105					result, tc.expectEqual)
106			}
107		})
108	}
109}
110
111func TestVerifyOutputsEqual(t *testing.T) {
112	testCases := []struct {
113		name      string
114		outs1     []*wire.TxOut
115		outs2     []*wire.TxOut
116		expectErr bool
117	}{{
118		name:      "both nil",
119		outs1:     nil,
120		outs2:     nil,
121		expectErr: false,
122	}, {
123		name:      "one nil",
124		outs1:     nil,
125		outs2:     []*wire.TxOut{{}},
126		expectErr: true,
127	}, {
128		name:      "both empty",
129		outs1:     []*wire.TxOut{{}},
130		outs2:     []*wire.TxOut{{}},
131		expectErr: false,
132	}, {
133		name:  "one pk script set",
134		outs1: []*wire.TxOut{{}},
135		outs2: []*wire.TxOut{{
136			PkScript: []byte("foo"),
137		}},
138		expectErr: true,
139	}, {
140		name: "both fully set",
141		outs1: []*wire.TxOut{{
142			Value:    1234,
143			PkScript: []byte("bar"),
144		}, {}},
145		outs2: []*wire.TxOut{{
146			Value:    1234,
147			PkScript: []byte("bar"),
148		}, {}},
149		expectErr: false,
150	}}
151
152	for _, tc := range testCases {
153		tc := tc
154		t.Run(tc.name, func(t *testing.T) {
155			err := VerifyOutputsEqual(tc.outs1, tc.outs2)
156			if (tc.expectErr && err == nil) ||
157				(!tc.expectErr && err != nil) {
158
159				t.Fatalf("got error '%v' but wanted it to be "+
160					"nil: %v", err, tc.expectErr)
161			}
162		})
163	}
164}
165
166func TestVerifyInputPrevOutpointsEqual(t *testing.T) {
167	testCases := []struct {
168		name      string
169		ins1      []*wire.TxIn
170		ins2      []*wire.TxIn
171		expectErr bool
172	}{{
173		name:      "both nil",
174		ins1:      nil,
175		ins2:      nil,
176		expectErr: false,
177	}, {
178		name:      "one nil",
179		ins1:      nil,
180		ins2:      []*wire.TxIn{{}},
181		expectErr: true,
182	}, {
183		name:      "both empty",
184		ins1:      []*wire.TxIn{{}},
185		ins2:      []*wire.TxIn{{}},
186		expectErr: false,
187	}, {
188		name: "one previous output set",
189		ins1: []*wire.TxIn{{}},
190		ins2: []*wire.TxIn{{
191			PreviousOutPoint: wire.OutPoint{
192				Hash:  chainhash.Hash{11, 22, 33},
193				Index: 7,
194			},
195		}},
196		expectErr: true,
197	}, {
198		name: "both fully set",
199		ins1: []*wire.TxIn{{
200			PreviousOutPoint: wire.OutPoint{
201				Hash:  chainhash.Hash{11, 22, 33},
202				Index: 7,
203			},
204		}, {}},
205		ins2: []*wire.TxIn{{
206			PreviousOutPoint: wire.OutPoint{
207				Hash:  chainhash.Hash{11, 22, 33},
208				Index: 7,
209			},
210		}, {}},
211		expectErr: false,
212	}}
213
214	for _, tc := range testCases {
215		tc := tc
216		t.Run(tc.name, func(t *testing.T) {
217			err := VerifyInputPrevOutpointsEqual(tc.ins1, tc.ins2)
218			if (tc.expectErr && err == nil) ||
219				(!tc.expectErr && err != nil) {
220
221				t.Fatalf("got error '%v' but wanted it to be "+
222					"nil: %v", err, tc.expectErr)
223			}
224		})
225	}
226}
227
228func TestVerifyInputOutputLen(t *testing.T) {
229	testCases := []struct {
230		name        string
231		packet      *Packet
232		needInputs  bool
233		needOutputs bool
234		expectErr   bool
235	}{{
236		name:      "packet nil",
237		packet:    nil,
238		expectErr: true,
239	}, {
240		name:      "wire tx nil",
241		packet:    &Packet{},
242		expectErr: true,
243	}, {
244		name: "both empty don't need outputs",
245		packet: &Packet{
246			UnsignedTx: &wire.MsgTx{},
247		},
248		expectErr: false,
249	}, {
250		name: "both empty but need outputs",
251		packet: &Packet{
252			UnsignedTx: &wire.MsgTx{},
253		},
254		needOutputs: true,
255		expectErr:   true,
256	}, {
257		name: "both empty but need inputs",
258		packet: &Packet{
259			UnsignedTx: &wire.MsgTx{},
260		},
261		needInputs: true,
262		expectErr:  true,
263	}, {
264		name: "input len mismatch",
265		packet: &Packet{
266			UnsignedTx: &wire.MsgTx{
267				TxIn: []*wire.TxIn{{}},
268			},
269		},
270		needInputs: true,
271		expectErr:  true,
272	}, {
273		name: "output len mismatch",
274		packet: &Packet{
275			UnsignedTx: &wire.MsgTx{
276				TxOut: []*wire.TxOut{{}},
277			},
278		},
279		needOutputs: true,
280		expectErr:   true,
281	}, {
282		name: "all fully set",
283		packet: &Packet{
284			UnsignedTx: &wire.MsgTx{
285				TxIn:  []*wire.TxIn{{}},
286				TxOut: []*wire.TxOut{{}},
287			},
288			Inputs:  []PInput{{}},
289			Outputs: []POutput{{}},
290		},
291		needInputs:  true,
292		needOutputs: true,
293		expectErr:   false,
294	}}
295
296	for _, tc := range testCases {
297		tc := tc
298		t.Run(tc.name, func(t *testing.T) {
299			err := VerifyInputOutputLen(
300				tc.packet, tc.needInputs, tc.needOutputs,
301			)
302			if (tc.expectErr && err == nil) ||
303				(!tc.expectErr && err != nil) {
304
305				t.Fatalf("got error '%v' but wanted it to be "+
306					"nil: %v", err, tc.expectErr)
307			}
308		})
309	}
310}
311
312func TestNewFromSignedTx(t *testing.T) {
313	orig := &wire.MsgTx{
314		TxIn: []*wire.TxIn{{
315			PreviousOutPoint: wire.OutPoint{},
316			SignatureScript:  []byte("script"),
317			Witness:          [][]byte{[]byte("witness")},
318			Sequence:         1234,
319		}},
320		TxOut: []*wire.TxOut{{
321			PkScript: []byte{77, 88},
322			Value:    99,
323		}},
324	}
325
326	packet, scripts, witnesses, err := NewFromSignedTx(orig)
327	if err != nil {
328		t.Fatalf("could not create packet from signed TX: %v", err)
329	}
330
331	tx := packet.UnsignedTx
332	expectedTxIn := []*wire.TxIn{{
333		PreviousOutPoint: wire.OutPoint{},
334		Sequence:         1234,
335	}}
336	if !reflect.DeepEqual(tx.TxIn, expectedTxIn) {
337		t.Fatalf("unexpected txin, got %#v wanted %#v",
338			tx.TxIn, expectedTxIn)
339	}
340	if !reflect.DeepEqual(tx.TxOut, orig.TxOut) {
341		t.Fatalf("unexpected txout, got %#v wanted %#v",
342			tx.TxOut, orig.TxOut)
343	}
344	if len(scripts) != 1 || !bytes.Equal(scripts[0], []byte("script")) {
345		t.Fatalf("script not extracted correctly")
346	}
347	if len(witnesses) != 1 ||
348		!bytes.Equal(witnesses[0][0], []byte("witness")) {
349
350		t.Fatalf("witness not extracted correctly")
351	}
352}
353