1package fwd
2
3import "io"
4
5const (
6	// DefaultWriterSize is the
7	// default write buffer size.
8	DefaultWriterSize = 2048
9
10	minWriterSize = minReaderSize
11)
12
// Writer is a buffered writer.
//
// Written bytes are staged in buf and handed to the
// underlying writer w when the writer is flushed.
type Writer struct {
	// w receives the bytes when the buffer is flushed.
	w io.Writer

	// buf holds the pending bytes in buf[:len(buf)];
	// cap(buf) is the size of the buffer.
	buf []byte
}
18
19// NewWriter returns a new writer
20// that writes to 'w' and has a buffer
21// that is `DefaultWriterSize` bytes.
22func NewWriter(w io.Writer) *Writer {
23	if wr, ok := w.(*Writer); ok {
24		return wr
25	}
26	return &Writer{
27		w:   w,
28		buf: make([]byte, 0, DefaultWriterSize),
29	}
30}
31
32// NewWriterSize returns a new writer
33// that writes to 'w' and has a buffer
34// that is 'size' bytes.
35func NewWriterSize(w io.Writer, size int) *Writer {
36	if wr, ok := w.(*Writer); ok && cap(wr.buf) >= size {
37		return wr
38	}
39	return &Writer{
40		w:   w,
41		buf: make([]byte, 0, max(size, minWriterSize)),
42	}
43}
44
45// Buffered returns the number of buffered bytes
46// in the reader.
47func (w *Writer) Buffered() int { return len(w.buf) }
48
49// BufferSize returns the maximum size of the buffer.
50func (w *Writer) BufferSize() int { return cap(w.buf) }
51
52// Flush flushes any buffered bytes
53// to the underlying writer.
54func (w *Writer) Flush() error {
55	l := len(w.buf)
56	if l > 0 {
57		n, err := w.w.Write(w.buf)
58
59		// if we didn't write the whole
60		// thing, copy the unwritten
61		// bytes to the beginnning of the
62		// buffer.
63		if n < l && n > 0 {
64			w.pushback(n)
65			if err == nil {
66				err = io.ErrShortWrite
67			}
68		}
69		if err != nil {
70			return err
71		}
72		w.buf = w.buf[:0]
73		return nil
74	}
75	return nil
76}
77
78// Write implements `io.Writer`
79func (w *Writer) Write(p []byte) (int, error) {
80	c, l, ln := cap(w.buf), len(w.buf), len(p)
81	avail := c - l
82
83	// requires flush
84	if avail < ln {
85		if err := w.Flush(); err != nil {
86			return 0, err
87		}
88		l = len(w.buf)
89	}
90	// too big to fit in buffer;
91	// write directly to w.w
92	if c < ln {
93		return w.w.Write(p)
94	}
95
96	// grow buf slice; copy; return
97	w.buf = w.buf[:l+ln]
98	return copy(w.buf[l:], p), nil
99}
100
101// WriteString is analogous to Write, but it takes a string.
102func (w *Writer) WriteString(s string) (int, error) {
103	c, l, ln := cap(w.buf), len(w.buf), len(s)
104	avail := c - l
105
106	// requires flush
107	if avail < ln {
108		if err := w.Flush(); err != nil {
109			return 0, err
110		}
111		l = len(w.buf)
112	}
113	// too big to fit in buffer;
114	// write directly to w.w
115	//
116	// yes, this is unsafe. *but*
117	// io.Writer is not allowed
118	// to mutate its input or
119	// maintain a reference to it,
120	// per the spec in package io.
121	//
122	// plus, if the string is really
123	// too big to fit in the buffer, then
124	// creating a copy to write it is
125	// expensive (and, strictly speaking,
126	// unnecessary)
127	if c < ln {
128		return w.w.Write(unsafestr(s))
129	}
130
131	// grow buf slice; copy; return
132	w.buf = w.buf[:l+ln]
133	return copy(w.buf[l:], s), nil
134}
135
136// WriteByte implements `io.ByteWriter`
137func (w *Writer) WriteByte(b byte) error {
138	if len(w.buf) == cap(w.buf) {
139		if err := w.Flush(); err != nil {
140			return err
141		}
142	}
143	w.buf = append(w.buf, b)
144	return nil
145}
146
147// Next returns the next 'n' free bytes
148// in the write buffer, flushing the writer
149// as necessary. Next will return `io.ErrShortBuffer`
150// if 'n' is greater than the size of the write buffer.
151// Calls to 'next' increment the write position by
152// the size of the returned buffer.
153func (w *Writer) Next(n int) ([]byte, error) {
154	c, l := cap(w.buf), len(w.buf)
155	if n > c {
156		return nil, io.ErrShortBuffer
157	}
158	avail := c - l
159	if avail < n {
160		if err := w.Flush(); err != nil {
161			return nil, err
162		}
163		l = len(w.buf)
164	}
165	w.buf = w.buf[:l+n]
166	return w.buf[l:], nil
167}
168
169// take the bytes from w.buf[n:len(w.buf)]
170// and put them at the beginning of w.buf,
171// and resize to the length of the copied segment.
172func (w *Writer) pushback(n int) {
173	w.buf = w.buf[:copy(w.buf, w.buf[n:])]
174}
175
176// ReadFrom implements `io.ReaderFrom`
177func (w *Writer) ReadFrom(r io.Reader) (int64, error) {
178	// anticipatory flush
179	if err := w.Flush(); err != nil {
180		return 0, err
181	}
182
183	w.buf = w.buf[0:cap(w.buf)] // expand buffer
184
185	var nn int64  // written
186	var err error // error
187	var x int     // read
188
189	// 1:1 reads and writes
190	for err == nil {
191		x, err = r.Read(w.buf)
192		if x > 0 {
193			n, werr := w.w.Write(w.buf[:x])
194			nn += int64(n)
195
196			if err != nil {
197				if n < x && n > 0 {
198					w.pushback(n - x)
199				}
200				return nn, werr
201			}
202			if n < x {
203				w.pushback(n - x)
204				return nn, io.ErrShortWrite
205			}
206		} else if err == nil {
207			err = io.ErrNoProgress
208			break
209		}
210	}
211	if err != io.EOF {
212		return nn, err
213	}
214
215	// we only clear here
216	// because we are sure
217	// the writes have
218	// succeeded. otherwise,
219	// we retain the data in case
220	// future writes succeed.
221	w.buf = w.buf[0:0]
222
223	return nn, nil
224}
225