// Copyright 2015 Keybase, Inc. All rights reserved. Use of
// this source code is governed by the included BSD license.

// Encoding is a radix X encoding/decoding scheme, defined by an X-length
// character alphabet.
type Encoding struct {
	encode          []byte          // alphabet: digit value -> output byte
	decodeMap       [256](*big.Int) // input byte -> digit value; nil if the byte is not in the alphabet
	skipMap         [256]bool       // input bytes that are ignored while decoding
	base256BlockLen int             // number of raw bytes in a full block
	baseXBlockLen   int             // number of encoded characters in a full block
	base            int             // the radix, i.e. len(alphabet)
	logOfBase       float64         // log2(base): bits carried by one encoded character
	baseBig         *big.Int        // the radix as a big.Int, for big-number arithmetic
	skipBytes       string          // the skip set exactly as passed to NewEncoding
}

// NewEncoding returns a new Encoding defined by the given alphabet,
// which must be an x-byte string (one byte per digit, x being the radix).
// No padding options are currently allowed.
//
// base256BlockLen is the size, in bytes, of the input blocks to consider;
// it must be positive. skipBytes lists the bytes that Decode silently
// ignores (whitespace, for instance). It is interpreted byte by byte, so a
// multi-byte UTF-8 character contributes each of its bytes to the skip set.
// A byte that is in both the alphabet and the skip set is treated as a digit.
//
// For base 58, we recommend 19-byte
// input blocks, which encode to 26-byte output blocks with only .3 bits
// wasted per block. The name of the game is to find a good rational
// approximation of 8/log2(58), and 26/19 is pretty good!
func NewEncoding(encoder string, base256BlockLen int, skipBytes string) *Encoding {
	base := len(encoder)
	logOfBase := math.Log2(float64(base))

	// If input blocks are base256BlockLen size, compute the corresponding
	// output block length. We need to round up to fit the overflow.
	baseXBlockLen := int(math.Ceil(float64(8*base256BlockLen) / logOfBase))

	// Code adapted from encoding/base64/base64.go in the standard
	// Go libraries.
	e := &Encoding{
		encode:          []byte(encoder),
		base:            base,
		base256BlockLen: base256BlockLen,
		baseXBlockLen:   baseXBlockLen,
		logOfBase:       logOfBase,
		baseBig:         big.NewInt(int64(base)),
		skipBytes:       skipBytes,
	}

	// Index by byte, not by rune: ranging over the string would yield code
	// points, which can exceed 255 (out-of-range panic on the 256-entry
	// table) or name a byte that never appears in the UTF-8 input.
	for i := 0; i < len(skipBytes); i++ {
		e.skipMap[skipBytes[i]] = true
	}
	for i := 0; i < len(encoder); i++ {
		e.decodeMap[encoder[i]] = big.NewInt(int64(i))
	}
	return e
}

/*
 * Encoder
 */

// Encode encodes src using the encoding enc, writing
// EncodedLen(len(src)) bytes to dst. dst must be at least that long.
//
// The encoding aligns the input along base256BlockLen boundaries,
// so Encode is not appropriate for use on individual blocks
// of a large data stream. Use NewEncoder() instead.
func (enc *Encoding) Encode(dst, src []byte) {
	for sp, dp, sLim, dLim := 0, 0, 0, 0; sp < len(src); sp, dp = sLim, dLim {
		// Clamp both windows so that the final, possibly short, block
		// does not slice past the end of either buffer.
		sLim = sp + enc.base256BlockLen
		dLim = dp + enc.baseXBlockLen
		if sLim > len(src) {
			sLim = len(src)
		}
		if dLim > len(dst) {
			dLim = len(dst)
		}
		enc.encodeBlock(dst[dp:dLim], src[sp:sLim])
	}
}

// byteType classifies an input byte with respect to an Encoding.
type byteType int

const (
	normalByteType  byteType = 0 // the byte is a digit of the alphabet
	skipByteType    byteType = 1 // the byte is in the skip set
	invalidByteType byteType = 2 // the byte is neither
)

// getByteType classifies b. Membership in the alphabet takes precedence
// over membership in the skip set.
func (enc *Encoding) getByteType(b byte) byteType {
	if enc.decodeMap[b] != nil {
		return normalByteType
	}
	if enc.skipMap[b] {
		return skipByteType
	}
	return invalidByteType
}

// hasSkipBytes reports whether the encoding was created with a non-empty
// skip set.
func (enc *Encoding) hasSkipBytes() bool {
	return len(enc.skipBytes) > 0
}

// IsValidByte returns true if the given byte is valid in this
// decoding. Can be either from the main alphabet or the skip
// alphabet to be considered valid.
func (enc *Encoding) IsValidByte(b byte) bool {
	return enc.decodeMap[b] != nil || enc.skipMap[b]
}

// encodeBlock fills the dst buffer with the encoding of src.
// It is assumed the buffers are appropriately sized, and no
// bounds checks are performed. In particular, the first
// EncodedLen(len(src)) bytes of dst are written, and any leading
// positions not needed by the number are padded with the alphabet's
// zero digit.
func (enc *Encoding) encodeBlock(dst, src []byte) {
	// Interpret the block as a big-endian number (Go's default)
	num := new(big.Int).SetBytes(src)
	rem := new(big.Int)

	// Digits come out least-significant first, so fill from the right.
	p := enc.EncodedLen(len(src)) - 1
	for num.Sign() != 0 {
		num.QuoRem(num, enc.baseBig, rem)
		dst[p] = enc.encode[rem.Uint64()]
		p--
	}

	// Pad the remainder of the buffer with the zero digit
	for ; p >= 0; p-- {
		dst[p] = enc.encode[0]
	}
}

// decode decodes src block by block into dst. It returns the number of
// bytes written to dst and the first error encountered, if any.
func (enc *Encoding) decode(dst []byte, src []byte) (n int, err error) {
	dp, sp := 0, 0
	for sp < len(src) {
		di, si, err := enc.decodeBlock(dst[dp:], src[sp:], sp)
		if err != nil {
			return dp, err
		}
		sp += si
		dp += di
	}
	return dp, nil
}

// Decode decodes src using the encoding enc. It writes at most
// DecodedLen(len(src)) bytes to dst and returns the number of bytes
// written. If src contains invalid baseX data, it will return the
// number of bytes successfully written and CorruptInputError. It can
// also return an ErrInvalidEncodingLength error if there is a non-standard
// number of bytes in this encoding.
func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
	return enc.decode(dst, src)
}

// CorruptInputError is returned when Decode() finds a non-alphabet
// character, or a block whose value is too large for its decoded length.
// Its value is the offset of the offending byte, or of the start of the
// offending block, within the input.
type CorruptInputError int

// Error fits the error interface
func (e CorruptInputError) Error() string {
	return fmt.Sprintf("illegal data at input byte %d", int(e))
}

// ErrInvalidEncodingLength is returned when a non-minimal encoding length is found
var ErrInvalidEncodingLength = errors.New("invalid encoding length; either truncated or has trailing garbage")

// decodeBlock decodes one block from the front of src into dst. It consumes
// bytes until it has seen baseXBlockLen alphabet characters or src runs out;
// bytes from the skip set are consumed but otherwise ignored. baseOffset is
// the position of src[0] in the overall input and is only used to report
// error offsets.
//
// It returns the number of bytes written to dst, the number of bytes
// consumed from src, and an error. On error both counts are zero.
func (enc *Encoding) decodeBlock(dst []byte, src []byte, baseOffset int) (int, int, error) {
	si := 0 // source index
	numGoodChars := 0
	res := new(big.Int)

	for i, b := range src {
		v := enc.decodeMap[b]
		si++

		if v == nil {
			if enc.skipMap[b] {
				continue
			}
			return 0, 0, CorruptInputError(i + baseOffset)
		}

		numGoodChars++
		res.Mul(res, enc.baseBig)
		res.Add(res, v)

		if numGoodChars == enc.baseXBlockLen {
			break
		}
	}

	if !enc.IsValidEncodingLength(numGoodChars) {
		return 0, 0, ErrInvalidEncodingLength
	}

	paddedLen := enc.DecodedLen(numGoodChars)

	// A block of k characters can spell numbers that need more than
	// DecodedLen(k) bytes (base^k is generally not a power of 256). Encode
	// never produces those, so reject them instead of silently dropping
	// the low-order bytes.
	if res.BitLen() > 8*paddedLen {
		return 0, 0, CorruptInputError(baseOffset)
	}

	// Use big-endian representation (the default with Go's library)
	raw := res.Bytes()
	p := 0
	if len(raw) < paddedLen {
		// Bytes() strips leading zeros; put them back.
		p = paddedLen - len(raw)
		copy(dst, bytes.Repeat([]byte{0}, p))
	}
	copy(dst[p:paddedLen], raw)
	return paddedLen, si, nil
}

// EncodedLen returns the length in bytes of the baseX encoding
// of an input buffer of length n
func (enc *Encoding) EncodedLen(n int) int {

	// Fast path!
	if n == enc.base256BlockLen {
		return enc.baseXBlockLen
	}

	nblocks := n / enc.base256BlockLen
	out := nblocks * enc.baseXBlockLen
	rem := n % enc.base256BlockLen
	if rem > 0 {
		// A trailing short block needs enough characters to hold rem*8 bits.
		out += int(math.Ceil(float64(rem*8) / enc.logOfBase))
	}
	return out
}

// DecodedLen returns the length in bytes of the baseX decoding
// of an input buffer of length n
func (enc *Encoding) DecodedLen(n int) int {

	// Fast path!
	if n == enc.baseXBlockLen {
		return enc.base256BlockLen
	}

	nblocks := n / enc.baseXBlockLen
	out := nblocks * enc.base256BlockLen
	rem := n % enc.baseXBlockLen
	if rem > 0 {
		// A trailing short block yields as many whole bytes as its
		// characters can carry.
		out += int(math.Floor(float64(rem) * enc.logOfBase / float64(8)))
	}
	return out
}

// IsValidEncodingLength returns true if this block has a valid encoding length.
// An encoding length is invalid if a shorter encoding would have sufficed,
// i.e. if dropping one character would still decode to the same number of bytes.
func (enc *Encoding) IsValidEncodingLength(n int) bool {
	// Fast path!
	if n == enc.baseXBlockLen {
		return true
	}
	f := func(n int) int {
		return int(math.Floor(float64(n) * enc.logOfBase / float64(8)))
	}
	return f(n) != f(n-1)
}