1package snappystream
2
3import (
4	"bufio"
5	"errors"
6	"fmt"
7	"hash/crc32"
8	"io"
9
10	"github.com/mreiferson/go-snappystream/snappy-go"
11)
12
13var errClosed = fmt.Errorf("closed")
14
15// BufferedWriter is an io.WriteCloser with behavior similar to writers
16// returned by NewWriter but it buffers written data, maximizing block size (to
17// improve the output compression ratio) at the cost of speed. Benefits over
18// NewWriter are most noticible when individual writes are small and when
19// streams are long.
20//
21// Failure to call a BufferedWriter's Close or Flush methods after it is done
22// being written to will likely result in missing data frames which will be
23// undetectable in the decoding process.
24//
25// NOTE: BufferedWriter cannot be instantiated via struct literal and must
26// use NewBufferedWriter (i.e. its zero value is not usable).
27type BufferedWriter struct {
28	err error
29	w   *writer
30	bw  *bufio.Writer
31}
32
33// NewBufferedWriter allocates and returns a BufferedWriter with an internal
34// buffer of MaxBlockSize bytes.  If an error occurs writing a block to w, all
35// future writes will fail with the same error.  After all data has been
36// written, the client should call the Flush method to guarantee all data has
37// been forwarded to the underlying io.Writer.
38func NewBufferedWriter(w io.Writer) *BufferedWriter {
39	_w := NewWriter(w).(*writer)
40	return &BufferedWriter{
41		w:  _w,
42		bw: bufio.NewWriterSize(_w, MaxBlockSize),
43	}
44}
45
46// ReadFrom implements the io.ReaderFrom interface used by io.Copy. It encodes
47// data read from r as a snappy framed stream that is written to the underlying
48// writer.  ReadFrom returns the number number of bytes read, along with any
49// error encountered (other than io.EOF).
50func (w *BufferedWriter) ReadFrom(r io.Reader) (int64, error) {
51	if w.err != nil {
52		return 0, w.err
53	}
54
55	var n int64
56	n, w.err = w.bw.ReadFrom(r)
57	return n, w.err
58}
59
60// Write buffers p internally, encoding and writing a block to the underlying
61// buffer if the buffer grows beyond MaxBlockSize bytes.  The returned int
62// will be 0 if there was an error and len(p) otherwise.
63func (w *BufferedWriter) Write(p []byte) (int, error) {
64	if w.err != nil {
65		return 0, w.err
66	}
67
68	_, w.err = w.bw.Write(p)
69	if w.err != nil {
70		return 0, w.err
71	}
72
73	return len(p), nil
74}
75
76// Flush encodes and writes a block with the contents of w's internal buffer to
77// the underlying writer even if the buffer does not contain a full block of
78// data (MaxBlockSize bytes).
79func (w *BufferedWriter) Flush() error {
80	if w.err == nil {
81		w.err = w.bw.Flush()
82	}
83
84	return w.err
85}
86
87// Close flushes w's internal buffer and tears down internal data structures.
88// After a successful call to Close method calls on w return an error.  Close
89// makes no attempt to close the underlying writer.
90func (w *BufferedWriter) Close() error {
91	if w.err != nil {
92		return w.err
93	}
94
95	w.err = w.bw.Flush()
96	w.w = nil
97	w.bw = nil
98
99	if w.err != nil {
100		return w.err
101	}
102
103	w.err = errClosed
104	return nil
105}
106
// writer encodes every Write as one or more snappy framed blocks on the
// wrapped io.Writer.  Once a write fails, the failure is kept in err and
// returned from every later call.
type writer struct {
	writer io.Writer // destination for the framed stream
	err    error     // sticky error from the first failed write

	hdr []byte // scratch space for each block header
	dst []byte // reusable output buffer for snappy encoding

	sentStreamID bool // set once the stream identifier has been written
}
116
117// NewWriter returns an io.Writer that writes its input to an underlying
118// io.Writer encoded as a snappy framed stream.  A stream identifier block is
119// written to w preceding the first data block.  The returned writer will never
120// emit a block with length in bytes greater than MaxBlockSize+4 nor one
121// containing more than MaxBlockSize bytes of (uncompressed) data.
122//
123// For each Write, the returned length will only ever be len(p) or 0,
124// regardless of the length of *compressed* bytes written to the wrapped
125// io.Writer.  If the returned length is 0 then error will be non-nil.  If
126// len(p) exceeds 65536, the slice will be automatically chunked into smaller
127// blocks which are all emitted before the call returns.
128func NewWriter(w io.Writer) io.Writer {
129	return &writer{
130		writer: w,
131
132		hdr: make([]byte, 8),
133		dst: make([]byte, 4096),
134	}
135}
136
137func (w *writer) Write(p []byte) (int, error) {
138	if w.err != nil {
139		return 0, w.err
140	}
141
142	total := 0
143	sz := MaxBlockSize
144	var n int
145	for i := 0; i < len(p); i += n {
146		if i+sz > len(p) {
147			sz = len(p) - i
148		}
149
150		n, w.err = w.write(p[i : i+sz])
151		if w.err != nil {
152			return 0, w.err
153		}
154		total += n
155	}
156	return total, nil
157}
158
159// write attempts to encode p as a block and write it to the underlying writer.
160// The returned int may not equal p's length if compression below
161// MaxBlockSize-4 could not be achieved.
162func (w *writer) write(p []byte) (int, error) {
163	var err error
164
165	if len(p) > MaxBlockSize {
166		return 0, errors.New(fmt.Sprintf("block too large %d > %d", len(p), MaxBlockSize))
167	}
168
169	w.dst = w.dst[:cap(w.dst)] // Encode does dumb resize w/o context. reslice avoids alloc.
170	w.dst, err = snappy.Encode(w.dst, p)
171	if err != nil {
172		return 0, err
173	}
174	block := w.dst
175	n := len(p)
176	compressed := true
177
178	// check for data which is better left uncompressed.  this is determined if
179	// the encoded content is longer than the source.
180	if len(w.dst) >= len(p) {
181		compressed = false
182		block = p[:n]
183	}
184
185	if !w.sentStreamID {
186		_, err := w.writer.Write(streamID)
187		if err != nil {
188			return 0, err
189		}
190		w.sentStreamID = true
191	}
192
193	// set the block type
194	if compressed {
195		writeHeader(w.hdr, blockCompressed, block, p[:n])
196	} else {
197		writeHeader(w.hdr, blockUncompressed, block, p[:n])
198	}
199
200	_, err = w.writer.Write(w.hdr)
201	if err != nil {
202		return 0, err
203	}
204
205	_, err = w.writer.Write(block)
206	if err != nil {
207		return 0, err
208	}
209
210	return n, nil
211}
212
213// writeHeader panics if len(hdr) is less than 8.
214func writeHeader(hdr []byte, btype byte, enc, dec []byte) {
215	hdr[0] = btype
216
217	// 3 byte little endian length of encoded content
218	length := uint32(len(enc)) + 4 // +4 for checksum
219	hdr[1] = byte(length)
220	hdr[2] = byte(length >> 8)
221	hdr[3] = byte(length >> 16)
222
223	// 4 byte little endian CRC32 checksum of decoded content
224	checksum := maskChecksum(crc32.Checksum(dec, crcTable))
225	hdr[4] = byte(checksum)
226	hdr[5] = byte(checksum >> 8)
227	hdr[6] = byte(checksum >> 16)
228	hdr[7] = byte(checksum >> 24)
229}
230