1// Copyright 2015 Keybase, Inc. All rights reserved. Use of
2// this source code is governed by the included BSD license.
3
4package basex
5
6import "io"
7
// Much of this code is adapted from Go's encoding/base64.
9
10// EncodeToString returns the baseX encoding of src.
11func (enc *Encoding) EncodeToString(src []byte) string {
12	buf := make([]byte, enc.EncodedLen(len(src)))
13	enc.Encode(buf, src)
14	return string(buf)
15}
16
// encoder is the io.WriteCloser returned by NewEncoder. Write accumulates
// input in buf until a whole base256BlockLen block is available, encodes
// it into out, and forwards the encoded bytes to w.
type encoder struct {
	err  error     // first error from w; returned by every later Write/Close
	enc  *Encoding // encoding applied to each block
	w    io.Writer // destination for encoded output
	buf  []byte    // buffered data waiting to be encoded (base256BlockLen bytes, see NewEncoder)
	nbuf int       // number of bytes in buf
	out  []byte    // output buffer, reused for every Encode call
}
25
26func (e *encoder) Write(p []byte) (n int, err error) {
27	if e.err != nil {
28		return 0, e.err
29	}
30
31	ibl := e.enc.base256BlockLen
32	obl := e.enc.baseXBlockLen
33
34	// Leading fringe.
35	if e.nbuf > 0 {
36		var i int
37		for i = 0; i < len(p) && e.nbuf < ibl; i++ {
38			e.buf[e.nbuf] = p[i]
39			e.nbuf++
40		}
41		n += i
42		p = p[i:]
43		if e.nbuf < ibl {
44			return
45		}
46		e.enc.Encode(e.out[:], e.buf[:])
47		if _, e.err = e.w.Write(e.out[:obl]); e.err != nil {
48			return n, e.err
49		}
50		e.nbuf = 0
51	}
52
53	// Large interior chunks.
54	for len(p) >= ibl {
55		nn := len(e.out) / obl * ibl
56		if nn > len(p) {
57			nn = len(p)
58			nn -= nn % ibl
59		}
60		e.enc.Encode(e.out[:], p[:nn])
61		if _, e.err = e.w.Write(e.out[0 : nn/ibl*obl]); e.err != nil {
62			return n, e.err
63		}
64		n += nn
65		p = p[nn:]
66	}
67
68	// Trailing fringe.
69	copy(e.buf[0:len(p)], p)
70	e.nbuf = len(p)
71	n += len(p)
72	return
73}
74
75// Close flushes any pending output from the encoder.
76// It is an error to call Write after calling Close.
77func (e *encoder) Close() error {
78	// If there's anything left in the buffer, flush it out
79	if e.err == nil && e.nbuf > 0 {
80		e.enc.Encode(e.out[:], e.buf[:e.nbuf])
81		_, e.err = e.w.Write(e.out[:e.enc.EncodedLen(e.nbuf)])
82		e.nbuf = 0
83	}
84	return e.err
85}
86
87// NewEncoder returns a new baseX stream encoder.  Data written to
88// the returned writer will be encoded using enc and then written to w.
89// Encodings operate in enc.baseXBlockLen-byte blocks; when finished
90// writing, the caller must Close the returned encoder to flush any
91// partially written blocks.
92func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
93	return &encoder{
94		enc: enc,
95		w:   w,
96		buf: make([]byte, enc.base256BlockLen),
97		out: make([]byte, 128*enc.baseXBlockLen),
98	}
99}
100
101// DecodeString returns the bytes represented by the baseX string s.
102// It uses the liberal decoding strategy, ignoring any non-baseX-characters
103func (enc *Encoding) DecodeString(s string) ([]byte, error) {
104	dbuf := make([]byte, enc.DecodedLen(len(s)))
105	n, err := enc.Decode(dbuf, []byte(s))
106	return dbuf[:n], err
107}
108
// decoder is the io.Reader built by newDecoder. Read pulls encoded bytes
// from r into buf, decodes whole baseXBlockLen blocks (or the trailing
// partial block once r reports io.EOF), and parks decoded bytes that did
// not fit in the caller's slice in out.
type decoder struct {
	err        error     // returned immediately by later Reads; io.EOF is reset to nil while buffered input remains
	enc        *Encoding // encoding used to decode each chunk
	r          io.Reader // source of encoded bytes (a filteringReader when enc has skip bytes)
	out        []byte    // leftover decoded output
	buf        []byte    // leftover input
	nbuf       int       // number of valid encoded bytes at the front of buf
	scratchbuf []byte    // a temporary scratch buf, for reuse; backs out when p is too small
}
118
119func (d *decoder) Read(p []byte) (int, error) {
120
121	if d.err != nil {
122		return 0, d.err
123	}
124
125	// Use leftover decoded output from last read.
126	if len(d.out) > 0 {
127		ret := copy(p, d.out)
128		d.out = d.out[ret:]
129		return ret, nil
130	}
131
132	ibl := d.enc.base256BlockLen
133	obl := d.enc.baseXBlockLen
134
135	nn := len(p) / ibl * obl
136	if nn < obl {
137		nn = obl
138	}
139	if nn > len(d.buf) {
140		nn = len(d.buf)
141	}
142
143	// Try to read up to the next full block.
144	for d.nbuf < obl && d.err == nil {
145		var n int
146		n, d.err = d.r.Read(d.buf[d.nbuf:nn])
147		d.nbuf += n
148	}
149
150	eof := false
151
152	if d.err == io.EOF {
153		if d.nbuf == 0 {
154			return 0, d.err
155		}
156		eof = true
157		d.err = nil
158	} else if d.err != nil {
159		return 0, d.err
160	}
161
162	// The num bytes to decode should be along obl-aligned boundaries, unless
163	// we're at the end of file.
164	numBytesToDecode := d.nbuf
165	if !eof {
166		numBytesToDecode = numBytesToDecode / obl * obl
167	}
168	numBytesToOutput := d.enc.DecodedLen(numBytesToDecode)
169
170	var ret int
171
172	// If we have too many bytes for the given buffer, we can buffer
173	// the rest internally
174	if numBytesToOutput > len(p) {
175		var n int
176		n, d.err = d.enc.Decode(d.scratchbuf[:], d.buf[:numBytesToDecode])
177		d.out = d.scratchbuf[:n]
178		ret = copy(p, d.out)
179		d.out = d.out[ret:]
180	} else {
181		ret, d.err = d.enc.Decode(p, d.buf[:numBytesToDecode])
182	}
183
184	// Shift the bytes in d.buf over from [numBytesToDecode:] to the start of the array
185	d.nbuf -= numBytesToDecode
186	copy(d.buf[0:d.nbuf], d.buf[numBytesToDecode:numBytesToDecode+d.nbuf])
187
188	if ret == 0 && d.err == nil && len(p) != 0 {
189		return 0, io.EOF
190	}
191
192	return ret, d.err
193}
194
// filteringReader wraps another reader and compacts each chunk it reads so
// that bytes enc classifies as skippable are removed. The first byte
// classified as invalid aborts the read with a CorruptInputError.
type filteringReader struct {
	wrapped io.Reader // source of raw encoded bytes
	enc     *Encoding // classifies each byte via getByteType
	nRead   int       // bytes from wrapped accepted so far (kept or skipped); offset reported in CorruptInputError
}
200
201func (r *filteringReader) Read(p []byte) (int, error) {
202	n, err := r.wrapped.Read(p)
203	for n > 0 {
204		offset := 0
205		for i, b := range p[:n] {
206			typ := r.enc.getByteType(b)
207			if typ == invalidByteType {
208				// TODO: Return n, i.e. partial results?
209				return 0, CorruptInputError(r.nRead)
210			}
211			r.nRead++
212			if typ == skipByteType {
213				continue
214			}
215
216			// We want this byte. We only need to rewrite it if
217			// offset is behind i (otherwise we'd just be writing the
218			// same byte again, over itself).
219			if i != offset {
220				p[offset] = b
221			}
222			offset++
223		}
224		if offset > 0 {
225			return offset, err
226		}
227		// Previous buffer entirely whitespace, read again
228		n, err = r.wrapped.Read(p)
229	}
230	return n, err
231}
232
233// NewDecoder constructs a new baseX stream decoder.
234func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
235	return newDecoder(enc, r)
236}
237
238func newDecoder(enc *Encoding, r io.Reader) io.Reader {
239	if enc.hasSkipBytes() {
240		r = &filteringReader{r, enc, 0}
241	}
242	return &decoder{
243		enc:        enc,
244		r:          r,
245		buf:        make([]byte, 8192*enc.base256BlockLen),
246		scratchbuf: make([]byte, 8192*enc.baseXBlockLen),
247	}
248}
249