1// Copyright (c) 2013-2016 The btcsuite developers
2// Use of this source code is governed by an ISC
3// license that can be found in the LICENSE file.
4
5package wire
6
7import (
8	"bytes"
9	"io"
10)
11
// fixedWriter implements the io.Writer interface and intentionally allows
// testing of error paths by forcing short writes.
type fixedWriter struct {
	b   []byte // fixed-size backing storage; len(b) is the write limit
	pos int    // number of bytes of b that have been written so far
}

// Write copies the contents of p into the fixed writer.  When p does not fit
// in the space remaining, io.ErrShortWrite is returned, nothing is copied, and
// the writer is left unchanged.  On success it returns len(p) and a nil error.
//
// This satisfies the io.Writer interface.
func (w *fixedWriter) Write(p []byte) (n int, err error) {
	// The remaining room is measured against len(w.b) rather than cap(w.b):
	// copy can only fill a slice up to its length, so a capacity-based check
	// could accept a write that is then silently truncated while still
	// reporting len(p) bytes written.
	if len(p) > len(w.b)-w.pos {
		return 0, io.ErrShortWrite
	}

	n = copy(w.b[w.pos:], p)
	w.pos += n
	return n, nil
}

// Bytes returns the bytes already written to the fixed writer.  The returned
// slice shares memory with the writer's internal buffer; it is not a copy.
func (w *fixedWriter) Bytes() []byte {
	// Only the written prefix is exposed.  The capacity is capped as well so
	// that an append by the caller reallocates instead of overwriting the
	// still-unwritten part of the buffer.
	return w.b[:w.pos:w.pos]
}

// newFixedWriter returns a new io.Writer that will error once more bytes than
// the specified max have been written.
func newFixedWriter(max int) io.Writer {
	return &fixedWriter{b: make([]byte, max)}
}
46
// fixedReader implements the io.Reader interface and intentionally allows
// testing of error paths by forcing short reads.
type fixedReader struct {
	buf   []byte        // fixed-size data slice that also backs iobuf
	pos   int           // running total of bytes handed out by Read
	iobuf *bytes.Buffer // buffer that Read actually draws from
}

// Read fills p with the next unread bytes of the fixed reader by delegating
// to the underlying bytes.Buffer, and adds the number of bytes obtained to the
// reader's position.  The buffer's result is passed through untouched, so once
// the fixed amount of data has been consumed the buffer's error (io.EOF) is
// what the caller receives.
//
// This satisfies the io.Reader interface.
func (fr *fixedReader) Read(p []byte) (n int, err error) {
	got, readErr := fr.iobuf.Read(p)
	fr.pos += got
	return got, readErr
}

// newFixedReader returns a new io.Reader that holds exactly max bytes and
// reports an error once a read is attempted after all of them have been
// consumed.  The reader's data is seeded from buf: a longer buf is truncated
// to max bytes and a shorter (or nil) buf leaves the remainder zeroed.
func newFixedReader(max int, buf []byte) io.Reader {
	// copy stops at the shorter of the two slices, so it both truncates an
	// oversized buf and is a no-op for a nil one.
	data := make([]byte, max)
	copy(data, buf)

	return &fixedReader{
		buf:   data,
		pos:   0,
		iobuf: bytes.NewBuffer(data),
	}
}
78