1// Copyright (c) 2013-2016 The btcsuite developers
2// Use of this source code is governed by an ISC
3// license that can be found in the LICENSE file.
4
5package wire
6
7import (
8	"bytes"
9	"encoding/binary"
10	"io"
11	"net"
12	"reflect"
13	"testing"
14	"time"
15
16	"github.com/btcsuite/btcd/chaincfg/chainhash"
17	"github.com/davecgh/go-spew/spew"
18)
19
20// makeHeader is a convenience function to make a message header in the form of
21// a byte slice.  It is used to force errors when reading messages.
22func makeHeader(btcnet BitcoinNet, command string,
23	payloadLen uint32, checksum uint32) []byte {
24
25	// The length of a bitcoin message header is 24 bytes.
26	// 4 byte magic number of the bitcoin network + 12 byte command + 4 byte
27	// payload length + 4 byte checksum.
28	buf := make([]byte, 24)
29	binary.LittleEndian.PutUint32(buf, uint32(btcnet))
30	copy(buf[4:], []byte(command))
31	binary.LittleEndian.PutUint32(buf[16:], payloadLen)
32	binary.LittleEndian.PutUint32(buf[20:], checksum)
33	return buf
34}
35
36// TestMessage tests the Read/WriteMessage and Read/WriteMessageN API.
37func TestMessage(t *testing.T) {
38	pver := ProtocolVersion
39
40	// Create the various types of messages to test.
41
42	// MsgVersion.
43	addrYou := &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 8333}
44	you := NewNetAddress(addrYou, SFNodeNetwork)
45	you.Timestamp = time.Time{} // Version message has zero value timestamp.
46	addrMe := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8333}
47	me := NewNetAddress(addrMe, SFNodeNetwork)
48	me.Timestamp = time.Time{} // Version message has zero value timestamp.
49	msgVersion := NewMsgVersion(me, you, 123123, 0)
50
51	msgVerack := NewMsgVerAck()
52	msgGetAddr := NewMsgGetAddr()
53	msgAddr := NewMsgAddr()
54	msgGetBlocks := NewMsgGetBlocks(&chainhash.Hash{})
55	msgBlock := &blockOne
56	msgInv := NewMsgInv()
57	msgGetData := NewMsgGetData()
58	msgNotFound := NewMsgNotFound()
59	msgTx := NewMsgTx(1)
60	msgPing := NewMsgPing(123123)
61	msgPong := NewMsgPong(123123)
62	msgGetHeaders := NewMsgGetHeaders()
63	msgHeaders := NewMsgHeaders()
64	msgAlert := NewMsgAlert([]byte("payload"), []byte("signature"))
65	msgMemPool := NewMsgMemPool()
66	msgFilterAdd := NewMsgFilterAdd([]byte{0x01})
67	msgFilterClear := NewMsgFilterClear()
68	msgFilterLoad := NewMsgFilterLoad([]byte{0x01}, 10, 0, BloomUpdateNone)
69	bh := NewBlockHeader(1, &chainhash.Hash{}, &chainhash.Hash{}, 0, 0)
70	msgMerkleBlock := NewMsgMerkleBlock(bh)
71	msgReject := NewMsgReject("block", RejectDuplicate, "duplicate block")
72	msgGetCFilters := NewMsgGetCFilters(GCSFilterRegular, 0, &chainhash.Hash{})
73	msgGetCFHeaders := NewMsgGetCFHeaders(GCSFilterRegular, 0, &chainhash.Hash{})
74	msgGetCFCheckpt := NewMsgGetCFCheckpt(GCSFilterRegular, &chainhash.Hash{})
75	msgCFilter := NewMsgCFilter(GCSFilterRegular, &chainhash.Hash{},
76		[]byte("payload"))
77	msgCFHeaders := NewMsgCFHeaders()
78	msgCFCheckpt := NewMsgCFCheckpt(GCSFilterRegular, &chainhash.Hash{}, 0)
79
80	tests := []struct {
81		in     Message    // Value to encode
82		out    Message    // Expected decoded value
83		pver   uint32     // Protocol version for wire encoding
84		btcnet BitcoinNet // Network to use for wire encoding
85		bytes  int        // Expected num bytes read/written
86	}{
87		{msgVersion, msgVersion, pver, MainNet, 125},
88		{msgVerack, msgVerack, pver, MainNet, 24},
89		{msgGetAddr, msgGetAddr, pver, MainNet, 24},
90		{msgAddr, msgAddr, pver, MainNet, 25},
91		{msgGetBlocks, msgGetBlocks, pver, MainNet, 61},
92		{msgBlock, msgBlock, pver, MainNet, 239},
93		{msgInv, msgInv, pver, MainNet, 25},
94		{msgGetData, msgGetData, pver, MainNet, 25},
95		{msgNotFound, msgNotFound, pver, MainNet, 25},
96		{msgTx, msgTx, pver, MainNet, 34},
97		{msgPing, msgPing, pver, MainNet, 32},
98		{msgPong, msgPong, pver, MainNet, 32},
99		{msgGetHeaders, msgGetHeaders, pver, MainNet, 61},
100		{msgHeaders, msgHeaders, pver, MainNet, 25},
101		{msgAlert, msgAlert, pver, MainNet, 42},
102		{msgMemPool, msgMemPool, pver, MainNet, 24},
103		{msgFilterAdd, msgFilterAdd, pver, MainNet, 26},
104		{msgFilterClear, msgFilterClear, pver, MainNet, 24},
105		{msgFilterLoad, msgFilterLoad, pver, MainNet, 35},
106		{msgMerkleBlock, msgMerkleBlock, pver, MainNet, 110},
107		{msgReject, msgReject, pver, MainNet, 79},
108		{msgGetCFilters, msgGetCFilters, pver, MainNet, 61},
109		{msgGetCFHeaders, msgGetCFHeaders, pver, MainNet, 61},
110		{msgGetCFCheckpt, msgGetCFCheckpt, pver, MainNet, 57},
111		{msgCFilter, msgCFilter, pver, MainNet, 65},
112		{msgCFHeaders, msgCFHeaders, pver, MainNet, 90},
113		{msgCFCheckpt, msgCFCheckpt, pver, MainNet, 58},
114	}
115
116	t.Logf("Running %d tests", len(tests))
117	for i, test := range tests {
118		// Encode to wire format.
119		var buf bytes.Buffer
120		nw, err := WriteMessageN(&buf, test.in, test.pver, test.btcnet)
121		if err != nil {
122			t.Errorf("WriteMessage #%d error %v", i, err)
123			continue
124		}
125
126		// Ensure the number of bytes written match the expected value.
127		if nw != test.bytes {
128			t.Errorf("WriteMessage #%d unexpected num bytes "+
129				"written - got %d, want %d", i, nw, test.bytes)
130		}
131
132		// Decode from wire format.
133		rbuf := bytes.NewReader(buf.Bytes())
134		nr, msg, _, err := ReadMessageN(rbuf, test.pver, test.btcnet)
135		if err != nil {
136			t.Errorf("ReadMessage #%d error %v, msg %v", i, err,
137				spew.Sdump(msg))
138			continue
139		}
140		if !reflect.DeepEqual(msg, test.out) {
141			t.Errorf("ReadMessage #%d\n got: %v want: %v", i,
142				spew.Sdump(msg), spew.Sdump(test.out))
143			continue
144		}
145
146		// Ensure the number of bytes read match the expected value.
147		if nr != test.bytes {
148			t.Errorf("ReadMessage #%d unexpected num bytes read - "+
149				"got %d, want %d", i, nr, test.bytes)
150		}
151	}
152
153	// Do the same thing for Read/WriteMessage, but ignore the bytes since
154	// they don't return them.
155	t.Logf("Running %d tests", len(tests))
156	for i, test := range tests {
157		// Encode to wire format.
158		var buf bytes.Buffer
159		err := WriteMessage(&buf, test.in, test.pver, test.btcnet)
160		if err != nil {
161			t.Errorf("WriteMessage #%d error %v", i, err)
162			continue
163		}
164
165		// Decode from wire format.
166		rbuf := bytes.NewReader(buf.Bytes())
167		msg, _, err := ReadMessage(rbuf, test.pver, test.btcnet)
168		if err != nil {
169			t.Errorf("ReadMessage #%d error %v, msg %v", i, err,
170				spew.Sdump(msg))
171			continue
172		}
173		if !reflect.DeepEqual(msg, test.out) {
174			t.Errorf("ReadMessage #%d\n got: %v want: %v", i,
175				spew.Sdump(msg), spew.Sdump(test.out))
176			continue
177		}
178	}
179}
180
181// TestReadMessageWireErrors performs negative tests against wire decoding into
182// concrete messages to confirm error paths work correctly.
183func TestReadMessageWireErrors(t *testing.T) {
184	pver := ProtocolVersion
185	btcnet := MainNet
186
187	// Ensure message errors are as expected with no function specified.
188	wantErr := "something bad happened"
189	testErr := MessageError{Description: wantErr}
190	if testErr.Error() != wantErr {
191		t.Errorf("MessageError: wrong error - got %v, want %v",
192			testErr.Error(), wantErr)
193	}
194
195	// Ensure message errors are as expected with a function specified.
196	wantFunc := "foo"
197	testErr = MessageError{Func: wantFunc, Description: wantErr}
198	if testErr.Error() != wantFunc+": "+wantErr {
199		t.Errorf("MessageError: wrong error - got %v, want %v",
200			testErr.Error(), wantErr)
201	}
202
203	// Wire encoded bytes for main and testnet3 networks magic identifiers.
204	testNet3Bytes := makeHeader(TestNet3, "", 0, 0)
205
206	// Wire encoded bytes for a message that exceeds max overall message
207	// length.
208	mpl := uint32(MaxMessagePayload)
209	exceedMaxPayloadBytes := makeHeader(btcnet, "getaddr", mpl+1, 0)
210
211	// Wire encoded bytes for a command which is invalid utf-8.
212	badCommandBytes := makeHeader(btcnet, "bogus", 0, 0)
213	badCommandBytes[4] = 0x81
214
215	// Wire encoded bytes for a command which is valid, but not supported.
216	unsupportedCommandBytes := makeHeader(btcnet, "bogus", 0, 0)
217
218	// Wire encoded bytes for a message which exceeds the max payload for
219	// a specific message type.
220	exceedTypePayloadBytes := makeHeader(btcnet, "getaddr", 1, 0)
221
222	// Wire encoded bytes for a message which does not deliver the full
223	// payload according to the header length.
224	shortPayloadBytes := makeHeader(btcnet, "version", 115, 0)
225
226	// Wire encoded bytes for a message with a bad checksum.
227	badChecksumBytes := makeHeader(btcnet, "version", 2, 0xbeef)
228	badChecksumBytes = append(badChecksumBytes, []byte{0x0, 0x0}...)
229
230	// Wire encoded bytes for a message which has a valid header, but is
231	// the wrong format.  An addr starts with a varint of the number of
232	// contained in the message.  Claim there is two, but don't provide
233	// them.  At the same time, forge the header fields so the message is
234	// otherwise accurate.
235	badMessageBytes := makeHeader(btcnet, "addr", 1, 0xeaadc31c)
236	badMessageBytes = append(badMessageBytes, 0x2)
237
238	// Wire encoded bytes for a message which the header claims has 15k
239	// bytes of data to discard.
240	discardBytes := makeHeader(btcnet, "bogus", 15*1024, 0)
241
242	tests := []struct {
243		buf     []byte     // Wire encoding
244		pver    uint32     // Protocol version for wire encoding
245		btcnet  BitcoinNet // Bitcoin network for wire encoding
246		max     int        // Max size of fixed buffer to induce errors
247		readErr error      // Expected read error
248		bytes   int        // Expected num bytes read
249	}{
250		// Latest protocol version with intentional read errors.
251
252		// Short header.
253		{
254			[]byte{},
255			pver,
256			btcnet,
257			0,
258			io.EOF,
259			0,
260		},
261
262		// Wrong network.  Want MainNet, but giving TestNet3.
263		{
264			testNet3Bytes,
265			pver,
266			btcnet,
267			len(testNet3Bytes),
268			&MessageError{},
269			24,
270		},
271
272		// Exceed max overall message payload length.
273		{
274			exceedMaxPayloadBytes,
275			pver,
276			btcnet,
277			len(exceedMaxPayloadBytes),
278			&MessageError{},
279			24,
280		},
281
282		// Invalid UTF-8 command.
283		{
284			badCommandBytes,
285			pver,
286			btcnet,
287			len(badCommandBytes),
288			&MessageError{},
289			24,
290		},
291
292		// Valid, but unsupported command.
293		{
294			unsupportedCommandBytes,
295			pver,
296			btcnet,
297			len(unsupportedCommandBytes),
298			&MessageError{},
299			24,
300		},
301
302		// Exceed max allowed payload for a message of a specific type.
303		{
304			exceedTypePayloadBytes,
305			pver,
306			btcnet,
307			len(exceedTypePayloadBytes),
308			&MessageError{},
309			24,
310		},
311
312		// Message with a payload shorter than the header indicates.
313		{
314			shortPayloadBytes,
315			pver,
316			btcnet,
317			len(shortPayloadBytes),
318			io.EOF,
319			24,
320		},
321
322		// Message with a bad checksum.
323		{
324			badChecksumBytes,
325			pver,
326			btcnet,
327			len(badChecksumBytes),
328			&MessageError{},
329			26,
330		},
331
332		// Message with a valid header, but wrong format.
333		{
334			badMessageBytes,
335			pver,
336			btcnet,
337			len(badMessageBytes),
338			io.EOF,
339			25,
340		},
341
342		// 15k bytes of data to discard.
343		{
344			discardBytes,
345			pver,
346			btcnet,
347			len(discardBytes),
348			&MessageError{},
349			24,
350		},
351	}
352
353	t.Logf("Running %d tests", len(tests))
354	for i, test := range tests {
355		// Decode from wire format.
356		r := newFixedReader(test.max, test.buf)
357		nr, _, _, err := ReadMessageN(r, test.pver, test.btcnet)
358		if reflect.TypeOf(err) != reflect.TypeOf(test.readErr) {
359			t.Errorf("ReadMessage #%d wrong error got: %v <%T>, "+
360				"want: %T", i, err, err, test.readErr)
361			continue
362		}
363
364		// Ensure the number of bytes written match the expected value.
365		if nr != test.bytes {
366			t.Errorf("ReadMessage #%d unexpected num bytes read - "+
367				"got %d, want %d", i, nr, test.bytes)
368		}
369
370		// For errors which are not of type MessageError, check them for
371		// equality.
372		if _, ok := err.(*MessageError); !ok {
373			if err != test.readErr {
374				t.Errorf("ReadMessage #%d wrong error got: %v <%T>, "+
375					"want: %v <%T>", i, err, err,
376					test.readErr, test.readErr)
377				continue
378			}
379		}
380	}
381}
382
383// TestWriteMessageWireErrors performs negative tests against wire encoding from
384// concrete messages to confirm error paths work correctly.
385func TestWriteMessageWireErrors(t *testing.T) {
386	pver := ProtocolVersion
387	btcnet := MainNet
388	wireErr := &MessageError{}
389
390	// Fake message with a command that is too long.
391	badCommandMsg := &fakeMessage{command: "somethingtoolong"}
392
393	// Fake message with a problem during encoding
394	encodeErrMsg := &fakeMessage{forceEncodeErr: true}
395
396	// Fake message that has payload which exceeds max overall message size.
397	exceedOverallPayload := make([]byte, MaxMessagePayload+1)
398	exceedOverallPayloadErrMsg := &fakeMessage{payload: exceedOverallPayload}
399
400	// Fake message that has payload which exceeds max allowed per message.
401	exceedPayload := make([]byte, 1)
402	exceedPayloadErrMsg := &fakeMessage{payload: exceedPayload, forceLenErr: true}
403
404	// Fake message that is used to force errors in the header and payload
405	// writes.
406	bogusPayload := []byte{0x01, 0x02, 0x03, 0x04}
407	bogusMsg := &fakeMessage{command: "bogus", payload: bogusPayload}
408
409	tests := []struct {
410		msg    Message    // Message to encode
411		pver   uint32     // Protocol version for wire encoding
412		btcnet BitcoinNet // Bitcoin network for wire encoding
413		max    int        // Max size of fixed buffer to induce errors
414		err    error      // Expected error
415		bytes  int        // Expected num bytes written
416	}{
417		// Command too long.
418		{badCommandMsg, pver, btcnet, 0, wireErr, 0},
419		// Force error in payload encode.
420		{encodeErrMsg, pver, btcnet, 0, wireErr, 0},
421		// Force error due to exceeding max overall message payload size.
422		{exceedOverallPayloadErrMsg, pver, btcnet, 0, wireErr, 0},
423		// Force error due to exceeding max payload for message type.
424		{exceedPayloadErrMsg, pver, btcnet, 0, wireErr, 0},
425		// Force error in header write.
426		{bogusMsg, pver, btcnet, 0, io.ErrShortWrite, 0},
427		// Force error in payload write.
428		{bogusMsg, pver, btcnet, 24, io.ErrShortWrite, 24},
429	}
430
431	t.Logf("Running %d tests", len(tests))
432	for i, test := range tests {
433		// Encode wire format.
434		w := newFixedWriter(test.max)
435		nw, err := WriteMessageN(w, test.msg, test.pver, test.btcnet)
436		if reflect.TypeOf(err) != reflect.TypeOf(test.err) {
437			t.Errorf("WriteMessage #%d wrong error got: %v <%T>, "+
438				"want: %T", i, err, err, test.err)
439			continue
440		}
441
442		// Ensure the number of bytes written match the expected value.
443		if nw != test.bytes {
444			t.Errorf("WriteMessage #%d unexpected num bytes "+
445				"written - got %d, want %d", i, nw, test.bytes)
446		}
447
448		// For errors which are not of type MessageError, check them for
449		// equality.
450		if _, ok := err.(*MessageError); !ok {
451			if err != test.err {
452				t.Errorf("ReadMessage #%d wrong error got: %v <%T>, "+
453					"want: %v <%T>", i, err, err,
454					test.err, test.err)
455				continue
456			}
457		}
458	}
459}
460