1// Copyright (c) 2013-2016 The btcsuite developers
2// Use of this source code is governed by an ISC
3// license that can be found in the LICENSE file.
4
5package wire
6
7import (
8	"crypto/rand"
9	"encoding/binary"
10	"fmt"
11	"io"
12	"math"
13	"time"
14
15	"github.com/btcsuite/btcd/chaincfg/chainhash"
16)
17
// MaxVarIntPayload is the largest number of bytes a variable length integer
// can occupy on the wire: one discriminant byte plus up to eight value bytes.
const MaxVarIntPayload = 9

// binaryFreeListMaxItems caps how many idle scratch buffers the package-level
// free list keeps around for binary serialization and deserialization.
const binaryFreeListMaxItems = 1024

// Short aliases for the standard library byte orders, which are otherwise
// verbose to spell out at every call site.
var (
	littleEndian = binary.LittleEndian
	bigEndian    = binary.BigEndian
)
36
// binaryFreeList is a pool of reusable scratch buffers implemented as a
// buffered channel, which makes it safe for concurrent use. The channel's
// capacity bounds how many idle buffers are retained. Every buffer has a
// capacity of 8 bytes, which is enough to hold any primitive integer up to a
// uint64. Reusing these buffers avoids allocating a new slice each time a
// primitive number is encoded or decoded.
//
// The UintN helpers borrow a buffer, read exactly N/8 bytes from an
// io.Reader, decode them, and hand the buffer back. The PutUintN helpers
// borrow a buffer, encode a value into it, write it to an io.Writer, and hand
// the buffer back.
type binaryFreeList chan []byte

// Borrow hands out a slice of length 8. It reuses an idle buffer from the free
// list when one is available and allocates a fresh one otherwise.
func (l binaryFreeList) Borrow() []byte {
	select {
	case recycled := <-l:
		// Recycled buffers may have been resliced shorter; restore the full
		// 8-byte view (their capacity is 8 per the Return contract).
		return recycled[:8]
	default:
		return make([]byte, 8)
	}
}

// Return offers buf back to the free list for reuse. buf MUST have come from
// Borrow and therefore have a capacity of 8. When the free list is already
// full, the buffer is dropped.
func (l binaryFreeList) Return(buf []byte) {
	select {
	case l <- buf:
	default:
		// Free list is full; leave buf to the garbage collector.
	}
}

// Uint8 reads a single byte from r via a pooled buffer and returns it. Any
// error from io.ReadFull is returned alongside a zero value.
func (l binaryFreeList) Uint8(r io.Reader) (uint8, error) {
	scratch := l.Borrow()
	_, err := io.ReadFull(r, scratch[:1])
	var v uint8
	if err == nil {
		v = scratch[0]
	}
	l.Return(scratch)
	return v, err
}

// Uint16 reads two bytes from r via a pooled buffer and decodes them with
// byteOrder. Any error from io.ReadFull is returned alongside a zero value.
func (l binaryFreeList) Uint16(r io.Reader, byteOrder binary.ByteOrder) (uint16, error) {
	scratch := l.Borrow()
	_, err := io.ReadFull(r, scratch[:2])
	var v uint16
	if err == nil {
		v = byteOrder.Uint16(scratch)
	}
	l.Return(scratch)
	return v, err
}

// Uint32 reads four bytes from r via a pooled buffer and decodes them with
// byteOrder. Any error from io.ReadFull is returned alongside a zero value.
func (l binaryFreeList) Uint32(r io.Reader, byteOrder binary.ByteOrder) (uint32, error) {
	scratch := l.Borrow()
	_, err := io.ReadFull(r, scratch[:4])
	var v uint32
	if err == nil {
		v = byteOrder.Uint32(scratch)
	}
	l.Return(scratch)
	return v, err
}

// Uint64 reads eight bytes from r via a pooled buffer and decodes them with
// byteOrder. Any error from io.ReadFull is returned alongside a zero value.
func (l binaryFreeList) Uint64(r io.Reader, byteOrder binary.ByteOrder) (uint64, error) {
	scratch := l.Borrow()
	_, err := io.ReadFull(r, scratch[:8])
	var v uint64
	if err == nil {
		v = byteOrder.Uint64(scratch)
	}
	l.Return(scratch)
	return v, err
}

// PutUint8 places val in a pooled buffer and writes that single byte to w,
// returning the writer's error.
func (l binaryFreeList) PutUint8(w io.Writer, val uint8) error {
	scratch := l.Borrow()
	scratch[0] = val
	_, err := w.Write(scratch[:1])
	l.Return(scratch)
	return err
}

// PutUint16 encodes val with byteOrder into a pooled buffer and writes the
// resulting two bytes to w, returning the writer's error.
func (l binaryFreeList) PutUint16(w io.Writer, byteOrder binary.ByteOrder, val uint16) error {
	scratch := l.Borrow()
	byteOrder.PutUint16(scratch, val)
	_, err := w.Write(scratch[:2])
	l.Return(scratch)
	return err
}

// PutUint32 encodes val with byteOrder into a pooled buffer and writes the
// resulting four bytes to w, returning the writer's error.
func (l binaryFreeList) PutUint32(w io.Writer, byteOrder binary.ByteOrder, val uint32) error {
	scratch := l.Borrow()
	byteOrder.PutUint32(scratch, val)
	_, err := w.Write(scratch[:4])
	l.Return(scratch)
	return err
}

// PutUint64 encodes val with byteOrder into a pooled buffer and writes the
// resulting eight bytes to w, returning the writer's error.
func (l binaryFreeList) PutUint64(w io.Writer, byteOrder binary.ByteOrder, val uint64) error {
	scratch := l.Borrow()
	byteOrder.PutUint64(scratch, val)
	_, err := w.Write(scratch[:8])
	l.Return(scratch)
	return err
}
169
170// binarySerializer provides a free list of buffers to use for serializing and
171// deserializing primitive integer values to and from io.Readers and io.Writers.
172var binarySerializer binaryFreeList = make(chan []byte, binaryFreeListMaxItems)
173
// errNonCanonicalVarInt is the format string for errors about variable length
// integers that were not encoded in their shortest form. Its verbs are, in
// order: the decoded value, the discriminant byte, and the smallest value that
// discriminant may legitimately encode. That smallest value is an inclusive
// bound, so the message says "at least" rather than "greater than".
var errNonCanonicalVarInt = "non-canonical varint %x - discriminant %x must " +
	"encode a value of at least %x"
178
// Wire timestamps come in two widths. These named wrappers around time.Time
// tell readElement which encoding to decode into a Go time.Time, since a bare
// *time.Time would be ambiguous.
type (
	// uint32Time marks a Unix timestamp (seconds) encoded as a uint32.
	uint32Time time.Time

	// int64Time marks a Unix timestamp (seconds) encoded as an int64.
	int64Time time.Time
)
188
189// readElement reads the next sequence of bytes from r using little endian
190// depending on the concrete type of element pointed to.
191func readElement(r io.Reader, element interface{}) error {
192	// Attempt to read the element based on the concrete type via fast
193	// type assertions first.
194	switch e := element.(type) {
195	case *int32:
196		rv, err := binarySerializer.Uint32(r, littleEndian)
197		if err != nil {
198			return err
199		}
200		*e = int32(rv)
201		return nil
202
203	case *uint32:
204		rv, err := binarySerializer.Uint32(r, littleEndian)
205		if err != nil {
206			return err
207		}
208		*e = rv
209		return nil
210
211	case *int64:
212		rv, err := binarySerializer.Uint64(r, littleEndian)
213		if err != nil {
214			return err
215		}
216		*e = int64(rv)
217		return nil
218
219	case *uint64:
220		rv, err := binarySerializer.Uint64(r, littleEndian)
221		if err != nil {
222			return err
223		}
224		*e = rv
225		return nil
226
227	case *bool:
228		rv, err := binarySerializer.Uint8(r)
229		if err != nil {
230			return err
231		}
232		if rv == 0x00 {
233			*e = false
234		} else {
235			*e = true
236		}
237		return nil
238
239	// Unix timestamp encoded as a uint32.
240	case *uint32Time:
241		rv, err := binarySerializer.Uint32(r, binary.LittleEndian)
242		if err != nil {
243			return err
244		}
245		*e = uint32Time(time.Unix(int64(rv), 0))
246		return nil
247
248	// Unix timestamp encoded as an int64.
249	case *int64Time:
250		rv, err := binarySerializer.Uint64(r, binary.LittleEndian)
251		if err != nil {
252			return err
253		}
254		*e = int64Time(time.Unix(int64(rv), 0))
255		return nil
256
257	// Message header checksum.
258	case *[4]byte:
259		_, err := io.ReadFull(r, e[:])
260		if err != nil {
261			return err
262		}
263		return nil
264
265	// Message header command.
266	case *[CommandSize]uint8:
267		_, err := io.ReadFull(r, e[:])
268		if err != nil {
269			return err
270		}
271		return nil
272
273	// IP address.
274	case *[16]byte:
275		_, err := io.ReadFull(r, e[:])
276		if err != nil {
277			return err
278		}
279		return nil
280
281	case *chainhash.Hash:
282		_, err := io.ReadFull(r, e[:])
283		if err != nil {
284			return err
285		}
286		return nil
287
288	case *ServiceFlag:
289		rv, err := binarySerializer.Uint64(r, littleEndian)
290		if err != nil {
291			return err
292		}
293		*e = ServiceFlag(rv)
294		return nil
295
296	case *InvType:
297		rv, err := binarySerializer.Uint32(r, littleEndian)
298		if err != nil {
299			return err
300		}
301		*e = InvType(rv)
302		return nil
303
304	case *BitcoinNet:
305		rv, err := binarySerializer.Uint32(r, littleEndian)
306		if err != nil {
307			return err
308		}
309		*e = BitcoinNet(rv)
310		return nil
311
312	case *BloomUpdateType:
313		rv, err := binarySerializer.Uint8(r)
314		if err != nil {
315			return err
316		}
317		*e = BloomUpdateType(rv)
318		return nil
319
320	case *RejectCode:
321		rv, err := binarySerializer.Uint8(r)
322		if err != nil {
323			return err
324		}
325		*e = RejectCode(rv)
326		return nil
327	}
328
329	// Fall back to the slower binary.Read if a fast path was not available
330	// above.
331	return binary.Read(r, littleEndian, element)
332}
333
334// readElements reads multiple items from r.  It is equivalent to multiple
335// calls to readElement.
336func readElements(r io.Reader, elements ...interface{}) error {
337	for _, element := range elements {
338		err := readElement(r, element)
339		if err != nil {
340			return err
341		}
342	}
343	return nil
344}
345
346// writeElement writes the little endian representation of element to w.
347func writeElement(w io.Writer, element interface{}) error {
348	// Attempt to write the element based on the concrete type via fast
349	// type assertions first.
350	switch e := element.(type) {
351	case int32:
352		err := binarySerializer.PutUint32(w, littleEndian, uint32(e))
353		if err != nil {
354			return err
355		}
356		return nil
357
358	case uint32:
359		err := binarySerializer.PutUint32(w, littleEndian, e)
360		if err != nil {
361			return err
362		}
363		return nil
364
365	case int64:
366		err := binarySerializer.PutUint64(w, littleEndian, uint64(e))
367		if err != nil {
368			return err
369		}
370		return nil
371
372	case uint64:
373		err := binarySerializer.PutUint64(w, littleEndian, e)
374		if err != nil {
375			return err
376		}
377		return nil
378
379	case bool:
380		var err error
381		if e {
382			err = binarySerializer.PutUint8(w, 0x01)
383		} else {
384			err = binarySerializer.PutUint8(w, 0x00)
385		}
386		if err != nil {
387			return err
388		}
389		return nil
390
391	// Message header checksum.
392	case [4]byte:
393		_, err := w.Write(e[:])
394		if err != nil {
395			return err
396		}
397		return nil
398
399	// Message header command.
400	case [CommandSize]uint8:
401		_, err := w.Write(e[:])
402		if err != nil {
403			return err
404		}
405		return nil
406
407	// IP address.
408	case [16]byte:
409		_, err := w.Write(e[:])
410		if err != nil {
411			return err
412		}
413		return nil
414
415	case *chainhash.Hash:
416		_, err := w.Write(e[:])
417		if err != nil {
418			return err
419		}
420		return nil
421
422	case ServiceFlag:
423		err := binarySerializer.PutUint64(w, littleEndian, uint64(e))
424		if err != nil {
425			return err
426		}
427		return nil
428
429	case InvType:
430		err := binarySerializer.PutUint32(w, littleEndian, uint32(e))
431		if err != nil {
432			return err
433		}
434		return nil
435
436	case BitcoinNet:
437		err := binarySerializer.PutUint32(w, littleEndian, uint32(e))
438		if err != nil {
439			return err
440		}
441		return nil
442
443	case BloomUpdateType:
444		err := binarySerializer.PutUint8(w, uint8(e))
445		if err != nil {
446			return err
447		}
448		return nil
449
450	case RejectCode:
451		err := binarySerializer.PutUint8(w, uint8(e))
452		if err != nil {
453			return err
454		}
455		return nil
456	}
457
458	// Fall back to the slower binary.Write if a fast path was not available
459	// above.
460	return binary.Write(w, littleEndian, element)
461}
462
463// writeElements writes multiple items to w.  It is equivalent to multiple
464// calls to writeElement.
465func writeElements(w io.Writer, elements ...interface{}) error {
466	for _, element := range elements {
467		err := writeElement(w, element)
468		if err != nil {
469			return err
470		}
471	}
472	return nil
473}
474
475// ReadVarInt reads a variable length integer from r and returns it as a uint64.
476func ReadVarInt(r io.Reader, pver uint32) (uint64, error) {
477	discriminant, err := binarySerializer.Uint8(r)
478	if err != nil {
479		return 0, err
480	}
481
482	var rv uint64
483	switch discriminant {
484	case 0xff:
485		sv, err := binarySerializer.Uint64(r, littleEndian)
486		if err != nil {
487			return 0, err
488		}
489		rv = sv
490
491		// The encoding is not canonical if the value could have been
492		// encoded using fewer bytes.
493		min := uint64(0x100000000)
494		if rv < min {
495			return 0, messageError("ReadVarInt", fmt.Sprintf(
496				errNonCanonicalVarInt, rv, discriminant, min))
497		}
498
499	case 0xfe:
500		sv, err := binarySerializer.Uint32(r, littleEndian)
501		if err != nil {
502			return 0, err
503		}
504		rv = uint64(sv)
505
506		// The encoding is not canonical if the value could have been
507		// encoded using fewer bytes.
508		min := uint64(0x10000)
509		if rv < min {
510			return 0, messageError("ReadVarInt", fmt.Sprintf(
511				errNonCanonicalVarInt, rv, discriminant, min))
512		}
513
514	case 0xfd:
515		sv, err := binarySerializer.Uint16(r, littleEndian)
516		if err != nil {
517			return 0, err
518		}
519		rv = uint64(sv)
520
521		// The encoding is not canonical if the value could have been
522		// encoded using fewer bytes.
523		min := uint64(0xfd)
524		if rv < min {
525			return 0, messageError("ReadVarInt", fmt.Sprintf(
526				errNonCanonicalVarInt, rv, discriminant, min))
527		}
528
529	default:
530		rv = uint64(discriminant)
531	}
532
533	return rv, nil
534}
535
536// WriteVarInt serializes val to w using a variable number of bytes depending
537// on its value.
538func WriteVarInt(w io.Writer, pver uint32, val uint64) error {
539	if val < 0xfd {
540		return binarySerializer.PutUint8(w, uint8(val))
541	}
542
543	if val <= math.MaxUint16 {
544		err := binarySerializer.PutUint8(w, 0xfd)
545		if err != nil {
546			return err
547		}
548		return binarySerializer.PutUint16(w, littleEndian, uint16(val))
549	}
550
551	if val <= math.MaxUint32 {
552		err := binarySerializer.PutUint8(w, 0xfe)
553		if err != nil {
554			return err
555		}
556		return binarySerializer.PutUint32(w, littleEndian, uint32(val))
557	}
558
559	err := binarySerializer.PutUint8(w, 0xff)
560	if err != nil {
561		return err
562	}
563	return binarySerializer.PutUint64(w, littleEndian, val)
564}
565
// VarIntSerializeSize reports how many bytes the variable length integer
// encoding of val occupies: the discriminant byte plus any trailing value
// bytes.
func VarIntSerializeSize(val uint64) int {
	switch {
	case val < 0xfd:
		// Stored entirely in the discriminant byte.
		return 1
	case val <= math.MaxUint16:
		// Discriminant plus a uint16.
		return 1 + 2
	case val <= math.MaxUint32:
		// Discriminant plus a uint32.
		return 1 + 4
	default:
		// Discriminant plus a uint64.
		return 1 + 8
	}
}
588
589// ReadVarString reads a variable length string from r and returns it as a Go
590// string.  A variable length string is encoded as a variable length integer
591// containing the length of the string followed by the bytes that represent the
592// string itself.  An error is returned if the length is greater than the
593// maximum block payload size since it helps protect against memory exhaustion
594// attacks and forced panics through malformed messages.
595func ReadVarString(r io.Reader, pver uint32) (string, error) {
596	count, err := ReadVarInt(r, pver)
597	if err != nil {
598		return "", err
599	}
600
601	// Prevent variable length strings that are larger than the maximum
602	// message size.  It would be possible to cause memory exhaustion and
603	// panics without a sane upper bound on this count.
604	if count > MaxMessagePayload {
605		str := fmt.Sprintf("variable length string is too long "+
606			"[count %d, max %d]", count, MaxMessagePayload)
607		return "", messageError("ReadVarString", str)
608	}
609
610	buf := make([]byte, count)
611	_, err = io.ReadFull(r, buf)
612	if err != nil {
613		return "", err
614	}
615	return string(buf), nil
616}
617
618// WriteVarString serializes str to w as a variable length integer containing
619// the length of the string followed by the bytes that represent the string
620// itself.
621func WriteVarString(w io.Writer, pver uint32, str string) error {
622	err := WriteVarInt(w, pver, uint64(len(str)))
623	if err != nil {
624		return err
625	}
626	_, err = w.Write([]byte(str))
627	return err
628}
629
630// ReadVarBytes reads a variable length byte array.  A byte array is encoded
631// as a varInt containing the length of the array followed by the bytes
632// themselves.  An error is returned if the length is greater than the
633// passed maxAllowed parameter which helps protect against memory exhaustion
634// attacks and forced panics through malformed messages.  The fieldName
635// parameter is only used for the error message so it provides more context in
636// the error.
637func ReadVarBytes(r io.Reader, pver uint32, maxAllowed uint32,
638	fieldName string) ([]byte, error) {
639
640	count, err := ReadVarInt(r, pver)
641	if err != nil {
642		return nil, err
643	}
644
645	// Prevent byte array larger than the max message size.  It would
646	// be possible to cause memory exhaustion and panics without a sane
647	// upper bound on this count.
648	if count > uint64(maxAllowed) {
649		str := fmt.Sprintf("%s is larger than the max allowed size "+
650			"[count %d, max %d]", fieldName, count, maxAllowed)
651		return nil, messageError("ReadVarBytes", str)
652	}
653
654	b := make([]byte, count)
655	_, err = io.ReadFull(r, b)
656	if err != nil {
657		return nil, err
658	}
659	return b, nil
660}
661
662// WriteVarBytes serializes a variable length byte array to w as a varInt
663// containing the number of bytes, followed by the bytes themselves.
664func WriteVarBytes(w io.Writer, pver uint32, bytes []byte) error {
665	slen := uint64(len(bytes))
666	err := WriteVarInt(w, pver, slen)
667	if err != nil {
668		return err
669	}
670
671	_, err = w.Write(bytes)
672	return err
673}
674
675// randomUint64 returns a cryptographically random uint64 value.  This
676// unexported version takes a reader primarily to ensure the error paths
677// can be properly tested by passing a fake reader in the tests.
678func randomUint64(r io.Reader) (uint64, error) {
679	rv, err := binarySerializer.Uint64(r, bigEndian)
680	if err != nil {
681		return 0, err
682	}
683	return rv, nil
684}
685
686// RandomUint64 returns a cryptographically random uint64 value.
687func RandomUint64() (uint64, error) {
688	return randomUint64(rand.Reader)
689}
690