1// Copyright 2015 Keybase, Inc. All rights reserved. Use of
2// this source code is governed by the included BSD license.
3
4package basex
5
6import (
7	"bytes"
8	"errors"
9	"fmt"
10	"math"
11	"math/big"
12)
13
// Encoding is a radix-X encoding/decoding scheme whose digits are the X
// characters of an alphabet. Binary input is processed in blocks of
// base256BlockLen bytes, each of which maps to baseXBlockLen characters.
type Encoding struct {
	encode          []byte        // alphabet: encode[d] is the character for digit d
	decodeMap       [256]*big.Int // digit value of each character; nil if not in the alphabet
	skipMap         [256]bool     // true for bytes the decoder ignores
	base256BlockLen int           // bytes per binary block
	baseXBlockLen   int           // characters per encoded block
	base            int           // radix X, i.e. len(encode)
	logOfBase       float64       // log2(base): bits carried by one character
	baseBig         *big.Int      // base as a *big.Int for block arithmetic
	skipBytes       string        // the skip set exactly as given to NewEncoding
}
27
28// NewEncoding returns a new Encoding defined by the given alphabet,
29// which must a x-byte string. No padding options are currently allowed.
30// inBlock is the size of input blocks to consider.
31//
32// For base 58, we recommend 19-byte
33// input blocks, which encode to 26-byte output blocks with only .3 bits
34// wasted per block. The name of the game is to find a good rational
35// approximation of 8/log2(58), and 26/19 is pretty good!
36func NewEncoding(encoder string, base256BlockLen int, skipBytes string) *Encoding {
37
38	base := len(encoder)
39
40	logOfBase := math.Log2(float64(base))
41
42	// If input blocks are base256BlockLen size, compute the corresponding
43	// output block length.  We need to round up to fit the overflow.
44	baseXBlockLen := int(math.Ceil(float64(8*base256BlockLen) / logOfBase))
45
46	// Code adapted from encoding/base64/base64.go in the standard
47	// Go libraries.
48	e := &Encoding{
49		encode:          make([]byte, base),
50		base:            base,
51		base256BlockLen: base256BlockLen,
52		baseXBlockLen:   baseXBlockLen,
53		logOfBase:       logOfBase,
54		baseBig:         big.NewInt(int64(base)),
55		skipBytes:       skipBytes,
56	}
57	copy(e.encode[:], encoder)
58
59	for _, c := range skipBytes {
60		e.skipMap[c] = true
61	}
62	for i := 0; i < len(encoder); i++ {
63		e.decodeMap[encoder[i]] = big.NewInt(int64(i))
64	}
65	return e
66}
67
68/*
69 * Encoder
70 */
71
72// Encode encodes src using the encoding enc, writing
73// EncodedLen(len(src)) bytes to dst.
74//
75// The encoding aligns the input along base256BlockLen boundaries.
76// so Encode is not appropriate for use on individual blocks
77// of a large data stream.  Use NewEncoder() instead.
78func (enc *Encoding) Encode(dst, src []byte) {
79	for sp, dp, sLim, dLim := 0, 0, 0, 0; sp < len(src); sp, dp = sLim, dLim {
80		sLim = sp + enc.base256BlockLen
81		dLim = dp + enc.baseXBlockLen
82		if sLim > len(src) {
83			sLim = len(src)
84		}
85		if dLim > len(dst) {
86			dLim = len(dst)
87		}
88		enc.encodeBlock(dst[dp:dLim], src[sp:sLim])
89	}
90}
91
// byteType classifies a single input byte for the decoder.
type byteType int

const (
	// normalByteType: the byte is a character of the alphabet.
	normalByteType byteType = iota
	// skipByteType: the byte is not in the alphabet but is in the skip set.
	skipByteType
	// invalidByteType: the byte is in neither the alphabet nor the skip set.
	invalidByteType
)
99
100func (enc *Encoding) getByteType(b byte) byteType {
101	if enc.decodeMap[b] != nil {
102		return normalByteType
103	}
104	if enc.skipMap[b] {
105		return skipByteType
106	}
107	return invalidByteType
108
109}
110
111func (enc *Encoding) hasSkipBytes() bool {
112	return len(enc.skipBytes) > 0
113}
114
115// IsValidByte returns true if the given byte is valid in this
116// decoding. Can be either from the main alphabet or the skip
117// alphabet to be considered valid.
118func (enc *Encoding) IsValidByte(b byte) bool {
119	return enc.decodeMap[b] != nil || enc.skipMap[b]
120}
121
122// encodeBlock fills the dst buffer with the encoding of src.
123// It is assumed the buffers are appropriately sized, and no
124// bounds checks are performed.  In particular, the dst buffer will
125// be zero-padded from right to left in all remaining bytes.
126func (enc *Encoding) encodeBlock(dst, src []byte) {
127
128	// Interpret the block as a big-endian number (Go's default)
129	num := new(big.Int).SetBytes(src)
130	rem := new(big.Int)
131	quo := new(big.Int)
132
133	encodedLen := enc.EncodedLen(len(src))
134
135	p := encodedLen - 1
136
137	for num.Sign() != 0 {
138		num, rem = quo.QuoRem(num, enc.baseBig, rem)
139		dst[p] = enc.encode[rem.Uint64()]
140		p--
141	}
142
143	// Pad the remainder of the buffer with 0s
144	for p >= 0 {
145		dst[p] = enc.encode[0]
146		p--
147	}
148}
149
150func (enc *Encoding) decode(dst []byte, src []byte) (n int, err error) {
151	dp, sp := 0, 0
152	for sp < len(src) {
153		di, si, err := enc.decodeBlock(dst[dp:], src[sp:], sp)
154		if err != nil {
155			return dp, err
156		}
157		sp += si
158		dp += di
159	}
160	return dp, nil
161}
162
163// Decode decodes src using the encoding enc.  It writes at most
164// DecodedLen(len(src)) bytes to dst and returns the number of bytes
165// written.  If src contains invalid baseX data, it will return the
166// number of bytes successfully written and CorruptInputError.  It can
167// also return an ErrInvalidEncodingLength error if there is a non-standard
168// number of bytes in this encoding
169func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
170	return enc.decode(dst, src)
171}
172
// CorruptInputError is returned by Decode when the input is not valid baseX
// data. Its value is the byte offset in the input at which the problem was
// detected.
type CorruptInputError int

// Error implements the error interface.
func (e CorruptInputError) Error() string {
	offset := int(e)
	return fmt.Sprintf("illegal data at input byte %d", offset)
}
180
var (
	// ErrInvalidEncodingLength is returned when a block's character count
	// is not a minimal encoding length, meaning a shorter encoding would
	// have produced the same number of bytes.
	ErrInvalidEncodingLength error = errors.New("invalid encoding length; either truncated or has trailing garbage")
)
183
184func (enc *Encoding) decodeBlock(dst []byte, src []byte, baseOffset int) (int, int, error) {
185	si := 0 // source index
186	numGoodChars := 0
187	res := new(big.Int)
188	res.SetUint64(0)
189
190	for i, b := range src {
191		v := enc.decodeMap[b]
192		si++
193
194		if v == nil {
195			if enc.skipMap[b] {
196				continue
197			}
198			return 0, 0, CorruptInputError(i + baseOffset)
199		}
200
201		numGoodChars++
202		res.Mul(res, enc.baseBig)
203		res.Add(res, v)
204
205		if numGoodChars == enc.baseXBlockLen {
206			break
207		}
208	}
209
210	if !enc.IsValidEncodingLength(numGoodChars) {
211		return 0, 0, ErrInvalidEncodingLength
212	}
213
214	paddedLen := enc.DecodedLen(numGoodChars)
215
216	// Use big-endian representation (the default with Go's library)
217	raw := res.Bytes()
218	p := 0
219	if len(raw) < paddedLen {
220		p = paddedLen - len(raw)
221		copy(dst, bytes.Repeat([]byte{0}, p))
222	}
223	copy(dst[p:paddedLen], raw)
224	return paddedLen, si, nil
225}
226
227// EncodedLen returns the length in bytes of the baseX encoding
228// of an input buffer of length n
229func (enc *Encoding) EncodedLen(n int) int {
230
231	// Fast path!
232	if n == enc.base256BlockLen {
233		return enc.baseXBlockLen
234	}
235
236	nblocks := n / enc.base256BlockLen
237	out := nblocks * enc.baseXBlockLen
238	rem := n % enc.base256BlockLen
239	if rem > 0 {
240		out += int(math.Ceil(float64(rem*8) / enc.logOfBase))
241	}
242	return out
243}
244
245// DecodedLen returns the length in bytes of the baseX decoding
246// of an input buffer of length n
247func (enc *Encoding) DecodedLen(n int) int {
248
249	// Fast path!
250	if n == enc.baseXBlockLen {
251		return enc.base256BlockLen
252	}
253
254	nblocks := n / enc.baseXBlockLen
255	out := nblocks * enc.base256BlockLen
256	rem := n % enc.baseXBlockLen
257	if rem > 0 {
258		out += int(math.Floor(float64(rem) * enc.logOfBase / float64(8)))
259	}
260	return out
261}
262
263// IsValidEncodingLength returns true if this block has a valid encoding length.
264// An encoding length is invalid if a short encoding would have sufficed.
265func (enc *Encoding) IsValidEncodingLength(n int) bool {
266	// Fast path!
267	if n == enc.baseXBlockLen {
268		return true
269	}
270	f := func(n int) int {
271		return int(math.Floor(float64(n) * enc.logOfBase / float64(8)))
272	}
273	return f(n) != f(n-1)
274}
275