1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package base64 implements base64 encoding as specified by RFC 4648.
6package base64
7
8import (
9	"io"
10	"strconv"
11)
12
13/*
14 * Encodings
15 */
16
// An Encoding is a radix 64 encoding/decoding scheme, defined by a
// 64-character alphabet.  The most common encoding is the "base64"
// encoding defined in RFC 4648 and used in MIME (RFC 2045) and PEM
// (RFC 1421).  RFC 4648 also defines an alternate encoding, which is
// the standard encoding with - and _ substituted for + and /.
type Encoding struct {
	// encode is the alphabet; the byte at index v is the output
	// character for the 6-bit value v.
	encode string

	// decodeMap is the inverse of encode, indexed by input byte.
	// NewEncoding sets entries for bytes outside the alphabet to 0xFF.
	decodeMap [256]byte
}
26
// Alphabets of the two encodings defined by RFC 4648.  They differ
// only in the characters used for the values 62 and 63.
const (
	encodeStd = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
	encodeURL = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
)
29
30// NewEncoding returns a new Encoding defined by the given alphabet,
31// which must be a 64-byte string.
32func NewEncoding(encoder string) *Encoding {
33	e := new(Encoding)
34	e.encode = encoder
35	for i := 0; i < len(e.decodeMap); i++ {
36		e.decodeMap[i] = 0xFF
37	}
38	for i := 0; i < len(encoder); i++ {
39		e.decodeMap[encoder[i]] = byte(i)
40	}
41	return e
42}
43
44// StdEncoding is the standard base64 encoding, as defined in
45// RFC 4648.
46var StdEncoding = NewEncoding(encodeStd)
47
48// URLEncoding is the alternate base64 encoding defined in RFC 4648.
49// It is typically used in URLs and file names.
50var URLEncoding = NewEncoding(encodeURL)
51
52/*
53 * Encoder
54 */
55
56// Encode encodes src using the encoding enc, writing
57// EncodedLen(len(src)) bytes to dst.
58//
59// The encoding pads the output to a multiple of 4 bytes,
60// so Encode is not appropriate for use on individual blocks
61// of a large data stream.  Use NewEncoder() instead.
62func (enc *Encoding) Encode(dst, src []byte) {
63	if len(src) == 0 {
64		return
65	}
66
67	for len(src) > 0 {
68		dst[0] = 0
69		dst[1] = 0
70		dst[2] = 0
71		dst[3] = 0
72
73		// Unpack 4x 6-bit source blocks into a 4 byte
74		// destination quantum
75		switch len(src) {
76		default:
77			dst[3] |= src[2] & 0x3F
78			dst[2] |= src[2] >> 6
79			fallthrough
80		case 2:
81			dst[2] |= (src[1] << 2) & 0x3F
82			dst[1] |= src[1] >> 4
83			fallthrough
84		case 1:
85			dst[1] |= (src[0] << 4) & 0x3F
86			dst[0] |= src[0] >> 2
87		}
88
89		// Encode 6-bit blocks using the base64 alphabet
90		for j := 0; j < 4; j++ {
91			dst[j] = enc.encode[dst[j]]
92		}
93
94		// Pad the final quantum
95		if len(src) < 3 {
96			dst[3] = '='
97			if len(src) < 2 {
98				dst[2] = '='
99			}
100			break
101		}
102
103		src = src[3:]
104		dst = dst[4:]
105	}
106}
107
108// EncodeToString returns the base64 encoding of src.
109func (enc *Encoding) EncodeToString(src []byte) string {
110	buf := make([]byte, enc.EncodedLen(len(src)))
111	enc.Encode(buf, src)
112	return string(buf)
113}
114
115type encoder struct {
116	err  error
117	enc  *Encoding
118	w    io.Writer
119	buf  [3]byte    // buffered data waiting to be encoded
120	nbuf int        // number of bytes in buf
121	out  [1024]byte // output buffer
122}
123
124func (e *encoder) Write(p []byte) (n int, err error) {
125	if e.err != nil {
126		return 0, e.err
127	}
128
129	// Leading fringe.
130	if e.nbuf > 0 {
131		var i int
132		for i = 0; i < len(p) && e.nbuf < 3; i++ {
133			e.buf[e.nbuf] = p[i]
134			e.nbuf++
135		}
136		n += i
137		p = p[i:]
138		if e.nbuf < 3 {
139			return
140		}
141		e.enc.Encode(e.out[0:], e.buf[0:])
142		if _, e.err = e.w.Write(e.out[0:4]); e.err != nil {
143			return n, e.err
144		}
145		e.nbuf = 0
146	}
147
148	// Large interior chunks.
149	for len(p) >= 3 {
150		nn := len(e.out) / 4 * 3
151		if nn > len(p) {
152			nn = len(p)
153		}
154		nn -= nn % 3
155		if nn > 0 {
156			e.enc.Encode(e.out[0:], p[0:nn])
157			if _, e.err = e.w.Write(e.out[0 : nn/3*4]); e.err != nil {
158				return n, e.err
159			}
160		}
161		n += nn
162		p = p[nn:]
163	}
164
165	// Trailing fringe.
166	for i := 0; i < len(p); i++ {
167		e.buf[i] = p[i]
168	}
169	e.nbuf = len(p)
170	n += len(p)
171	return
172}
173
174// Close flushes any pending output from the encoder.
175// It is an error to call Write after calling Close.
176func (e *encoder) Close() error {
177	// If there's anything left in the buffer, flush it out
178	if e.err == nil && e.nbuf > 0 {
179		e.enc.Encode(e.out[0:], e.buf[0:e.nbuf])
180		e.nbuf = 0
181		_, e.err = e.w.Write(e.out[0:4])
182	}
183	return e.err
184}
185
186// NewEncoder returns a new base64 stream encoder.  Data written to
187// the returned writer will be encoded using enc and then written to w.
188// Base64 encodings operate in 4-byte blocks; when finished
189// writing, the caller must Close the returned encoder to flush any
190// partially written blocks.
191func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
192	return &encoder{enc: enc, w: w}
193}
194
195// EncodedLen returns the length in bytes of the base64 encoding
196// of an input buffer of length n.
197func (enc *Encoding) EncodedLen(n int) int { return (n + 2) / 3 * 4 }
198
199/*
200 * Decoder
201 */
202
// CorruptInputError is the error reported when base64 input cannot be
// decoded.  Its value is the byte offset in the input that the error
// refers to.
type CorruptInputError int64

// Error implements the error interface.
func (e CorruptInputError) Error() string {
	const prefix = "illegal base64 data at input byte "
	offset := strconv.FormatInt(int64(e), 10)
	return prefix + offset
}
208
209// decode is like Decode but returns an additional 'end' value, which
210// indicates if end-of-message padding was encountered and thus any
211// additional data is an error.
212func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
213	osrc := src
214	for len(src) > 0 && !end {
215		// Decode quantum using the base64 alphabet
216		var dbuf [4]byte
217		dlen := 4
218
219		for j := 0; j < 4; {
220			if len(src) == 0 {
221				return n, false, CorruptInputError(len(osrc) - len(src) - j)
222			}
223			in := src[0]
224			src = src[1:]
225			if in == '\r' || in == '\n' {
226				// Ignore this character.
227				continue
228			}
229			if in == '=' && j >= 2 && len(src) < 4 {
230				// We've reached the end and there's
231				// padding
232				if len(src) == 0 && j == 2 {
233					// not enough padding
234					return n, false, CorruptInputError(len(osrc))
235				}
236				if len(src) > 0 && src[0] != '=' {
237					// incorrect padding
238					return n, false, CorruptInputError(len(osrc) - len(src) - 1)
239				}
240				dlen = j
241				end = true
242				break
243			}
244			dbuf[j] = enc.decodeMap[in]
245			if dbuf[j] == 0xFF {
246				return n, false, CorruptInputError(len(osrc) - len(src) - 1)
247			}
248			j++
249		}
250
251		// Pack 4x 6-bit source blocks into 3 byte destination
252		// quantum
253		switch dlen {
254		case 4:
255			dst[2] = dbuf[2]<<6 | dbuf[3]
256			fallthrough
257		case 3:
258			dst[1] = dbuf[1]<<4 | dbuf[2]>>2
259			fallthrough
260		case 2:
261			dst[0] = dbuf[0]<<2 | dbuf[1]>>4
262		}
263		dst = dst[3:]
264		n += dlen - 1
265	}
266
267	return n, end, nil
268}
269
270// Decode decodes src using the encoding enc.  It writes at most
271// DecodedLen(len(src)) bytes to dst and returns the number of bytes
272// written.  If src contains invalid base64 data, it will return the
273// number of bytes successfully written and CorruptInputError.
274// New line characters (\r and \n) are ignored.
275func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
276	n, _, err = enc.decode(dst, src)
277	return
278}
279
280// DecodeString returns the bytes represented by the base64 string s.
281func (enc *Encoding) DecodeString(s string) ([]byte, error) {
282	dbuf := make([]byte, enc.DecodedLen(len(s)))
283	n, err := enc.Decode(dbuf, []byte(s))
284	return dbuf[:n], err
285}
286
287type decoder struct {
288	err    error
289	enc    *Encoding
290	r      io.Reader
291	end    bool       // saw end of message
292	buf    [1024]byte // leftover input
293	nbuf   int
294	out    []byte // leftover decoded output
295	outbuf [1024 / 4 * 3]byte
296}
297
298func (d *decoder) Read(p []byte) (n int, err error) {
299	if d.err != nil {
300		return 0, d.err
301	}
302
303	// Use leftover decoded output from last read.
304	if len(d.out) > 0 {
305		n = copy(p, d.out)
306		d.out = d.out[n:]
307		return n, nil
308	}
309
310	// Read a chunk.
311	nn := len(p) / 3 * 4
312	if nn < 4 {
313		nn = 4
314	}
315	if nn > len(d.buf) {
316		nn = len(d.buf)
317	}
318	nn, d.err = io.ReadAtLeast(d.r, d.buf[d.nbuf:nn], 4-d.nbuf)
319	d.nbuf += nn
320	if d.err != nil || d.nbuf < 4 {
321		return 0, d.err
322	}
323
324	// Decode chunk into p, or d.out and then p if p is too small.
325	nr := d.nbuf / 4 * 4
326	nw := d.nbuf / 4 * 3
327	if nw > len(p) {
328		nw, d.end, d.err = d.enc.decode(d.outbuf[0:], d.buf[0:nr])
329		d.out = d.outbuf[0:nw]
330		n = copy(p, d.out)
331		d.out = d.out[n:]
332	} else {
333		n, d.end, d.err = d.enc.decode(p, d.buf[0:nr])
334	}
335	d.nbuf -= nr
336	for i := 0; i < d.nbuf; i++ {
337		d.buf[i] = d.buf[i+nr]
338	}
339
340	if d.err == nil {
341		d.err = err
342	}
343	return n, d.err
344}
345
346// NewDecoder constructs a new base64 stream decoder.
347func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
348	return &decoder{enc: enc, r: r}
349}
350
351// DecodedLen returns the maximum length in bytes of the decoded data
352// corresponding to n bytes of base64-encoded data.
353func (enc *Encoding) DecodedLen(n int) int { return n / 4 * 3 }
354