1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package base64 implements base64 encoding as specified by RFC 4648.
6package base64
7
8import (
9	"encoding/binary"
10	"io"
11	"strconv"
12)
13
14/*
15 * Encodings
16 */
17
// An Encoding is a radix 64 encoding/decoding scheme, defined by a
// 64-character alphabet. The most common encoding is the "base64"
// encoding defined in RFC 4648 and used in MIME (RFC 2045) and PEM
// (RFC 1421).  RFC 4648 also defines an alternate encoding, which is
// the standard encoding with - and _ substituted for + and /.
type Encoding struct {
	// encode is the alphabet: encode[v] is the output byte for the
	// 6-bit value v.
	encode    [64]byte
	// decodeMap is the inverse of encode: decodeMap[c] is the 6-bit value
	// of input byte c, or 0xFF when c is not part of the alphabet.
	decodeMap [256]byte
	// padChar is the padding character, or NoPadding when padding is
	// disabled.
	padChar   rune
	// strict is set by Strict; when true the decoder rejects input whose
	// unused trailing bits are not zero.
	strict    bool
}
29
// Padding values accepted by WithPadding.
const (
	StdPadding rune = '=' // Standard padding character
	NoPadding  rune = -1  // No padding
)

// encodeStd is the 64-character alphabet used to build StdEncoding.
const encodeStd = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"

// encodeURL is the 64-character alphabet used to build URLEncoding; it
// differs from encodeStd only in its last two characters ('-' and '_').
const encodeURL = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
37
38// NewEncoding returns a new padded Encoding defined by the given alphabet,
39// which must be a 64-byte string that does not contain the padding character
40// or CR / LF ('\r', '\n').
41// The resulting Encoding uses the default padding character ('='),
42// which may be changed or disabled via WithPadding.
43func NewEncoding(encoder string) *Encoding {
44	if len(encoder) != 64 {
45		panic("encoding alphabet is not 64-bytes long")
46	}
47	for i := 0; i < len(encoder); i++ {
48		if encoder[i] == '\n' || encoder[i] == '\r' {
49			panic("encoding alphabet contains newline character")
50		}
51	}
52
53	e := new(Encoding)
54	e.padChar = StdPadding
55	copy(e.encode[:], encoder)
56
57	for i := 0; i < len(e.decodeMap); i++ {
58		e.decodeMap[i] = 0xFF
59	}
60	for i := 0; i < len(encoder); i++ {
61		e.decodeMap[encoder[i]] = byte(i)
62	}
63	return e
64}
65
66// WithPadding creates a new encoding identical to enc except
67// with a specified padding character, or NoPadding to disable padding.
68// The padding character must not be '\r' or '\n', must not
69// be contained in the encoding's alphabet and must be a rune equal or
70// below '\xff'.
71func (enc Encoding) WithPadding(padding rune) *Encoding {
72	if padding == '\r' || padding == '\n' || padding > 0xff {
73		panic("invalid padding")
74	}
75
76	for i := 0; i < len(enc.encode); i++ {
77		if rune(enc.encode[i]) == padding {
78			panic("padding contained in alphabet")
79		}
80	}
81
82	enc.padChar = padding
83	return &enc
84}
85
86// Strict creates a new encoding identical to enc except with
87// strict decoding enabled. In this mode, the decoder requires that
88// trailing padding bits are zero, as described in RFC 4648 section 3.5.
89func (enc Encoding) Strict() *Encoding {
90	enc.strict = true
91	return &enc
92}
93
// StdEncoding is the standard base64 encoding, as defined in
// RFC 4648. It uses the encodeStd alphabet with '=' padding.
var StdEncoding = NewEncoding(encodeStd)

// URLEncoding is the alternate base64 encoding defined in RFC 4648.
// It is typically used in URLs and file names.
// It uses the encodeURL alphabet with '=' padding.
var URLEncoding = NewEncoding(encodeURL)

// RawStdEncoding is the standard raw, unpadded base64 encoding,
// as defined in RFC 4648 section 3.2.
// This is the same as StdEncoding but omits padding characters.
var RawStdEncoding = StdEncoding.WithPadding(NoPadding)

// RawURLEncoding is the unpadded alternate base64 encoding defined in RFC 4648.
// It is typically used in URLs and file names.
// This is the same as URLEncoding but omits padding characters.
var RawURLEncoding = URLEncoding.WithPadding(NoPadding)
111
112/*
113 * Encoder
114 */
115
116// Encode encodes src using the encoding enc, writing
117// EncodedLen(len(src)) bytes to dst.
118//
119// The encoding pads the output to a multiple of 4 bytes,
120// so Encode is not appropriate for use on individual blocks
121// of a large data stream. Use NewEncoder() instead.
122func (enc *Encoding) Encode(dst, src []byte) {
123	if len(src) == 0 {
124		return
125	}
126
127	di, si := 0, 0
128	n := (len(src) / 3) * 3
129	for si < n {
130		// Convert 3x 8bit source bytes into 4 bytes
131		val := uint(src[si+0])<<16 | uint(src[si+1])<<8 | uint(src[si+2])
132
133		dst[di+0] = enc.encode[val>>18&0x3F]
134		dst[di+1] = enc.encode[val>>12&0x3F]
135		dst[di+2] = enc.encode[val>>6&0x3F]
136		dst[di+3] = enc.encode[val&0x3F]
137
138		si += 3
139		di += 4
140	}
141
142	remain := len(src) - si
143	if remain == 0 {
144		return
145	}
146	// Add the remaining small block
147	val := uint(src[si+0]) << 16
148	if remain == 2 {
149		val |= uint(src[si+1]) << 8
150	}
151
152	dst[di+0] = enc.encode[val>>18&0x3F]
153	dst[di+1] = enc.encode[val>>12&0x3F]
154
155	switch remain {
156	case 2:
157		dst[di+2] = enc.encode[val>>6&0x3F]
158		if enc.padChar != NoPadding {
159			dst[di+3] = byte(enc.padChar)
160		}
161	case 1:
162		if enc.padChar != NoPadding {
163			dst[di+2] = byte(enc.padChar)
164			dst[di+3] = byte(enc.padChar)
165		}
166	}
167}
168
169// EncodeToString returns the base64 encoding of src.
170func (enc *Encoding) EncodeToString(src []byte) string {
171	buf := make([]byte, enc.EncodedLen(len(src)))
172	enc.Encode(buf, src)
173	return string(buf)
174}
175
// encoder is the streaming encoder returned by NewEncoder. It carries up
// to two input bytes between Write calls so that only complete 3-byte
// groups are encoded before Close.
type encoder struct {
	err  error      // first error returned by w.Write; once set, Write and Close return it
	enc  *Encoding  // encoding used to produce the output
	w    io.Writer  // destination for encoded output
	buf  [3]byte    // buffered data waiting to be encoded
	nbuf int        // number of bytes in buf
	out  [1024]byte // output buffer
}
184
185func (e *encoder) Write(p []byte) (n int, err error) {
186	if e.err != nil {
187		return 0, e.err
188	}
189
190	// Leading fringe.
191	if e.nbuf > 0 {
192		var i int
193		for i = 0; i < len(p) && e.nbuf < 3; i++ {
194			e.buf[e.nbuf] = p[i]
195			e.nbuf++
196		}
197		n += i
198		p = p[i:]
199		if e.nbuf < 3 {
200			return
201		}
202		e.enc.Encode(e.out[:], e.buf[:])
203		if _, e.err = e.w.Write(e.out[:4]); e.err != nil {
204			return n, e.err
205		}
206		e.nbuf = 0
207	}
208
209	// Large interior chunks.
210	for len(p) >= 3 {
211		nn := len(e.out) / 4 * 3
212		if nn > len(p) {
213			nn = len(p)
214			nn -= nn % 3
215		}
216		e.enc.Encode(e.out[:], p[:nn])
217		if _, e.err = e.w.Write(e.out[0 : nn/3*4]); e.err != nil {
218			return n, e.err
219		}
220		n += nn
221		p = p[nn:]
222	}
223
224	// Trailing fringe.
225	for i := 0; i < len(p); i++ {
226		e.buf[i] = p[i]
227	}
228	e.nbuf = len(p)
229	n += len(p)
230	return
231}
232
233// Close flushes any pending output from the encoder.
234// It is an error to call Write after calling Close.
235func (e *encoder) Close() error {
236	// If there's anything left in the buffer, flush it out
237	if e.err == nil && e.nbuf > 0 {
238		e.enc.Encode(e.out[:], e.buf[:e.nbuf])
239		_, e.err = e.w.Write(e.out[:e.enc.EncodedLen(e.nbuf)])
240		e.nbuf = 0
241	}
242	return e.err
243}
244
245// NewEncoder returns a new base64 stream encoder. Data written to
246// the returned writer will be encoded using enc and then written to w.
247// Base64 encodings operate in 4-byte blocks; when finished
248// writing, the caller must Close the returned encoder to flush any
249// partially written blocks.
250func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
251	return &encoder{enc: enc, w: w}
252}
253
254// EncodedLen returns the length in bytes of the base64 encoding
255// of an input buffer of length n.
256func (enc *Encoding) EncodedLen(n int) int {
257	if enc.padChar == NoPadding {
258		return (n*8 + 5) / 6 // minimum # chars at 6 bits per char
259	}
260	return (n + 2) / 3 * 4 // minimum # 4-char quanta, 3 bytes each
261}
262
263/*
264 * Decoder
265 */
266
// CorruptInputError is the error reported when the decoder meets invalid
// base64 input. Its integer value is the input byte offset named in the
// error message.
type CorruptInputError int64

// Error implements the error interface, naming the offending input byte
// offset in decimal.
func (e CorruptInputError) Error() string {
	msg := []byte("illegal base64 data at input byte ")
	return string(strconv.AppendInt(msg, int64(e), 10))
}
272
273// decodeQuantum decodes up to 4 base64 bytes. It takes for parameters
274// the destination buffer dst, the source buffer src and an index in the
275// source buffer si.
276// It returns the number of bytes read from src, the number of bytes written
277// to dst, and an error, if any.
278func (enc *Encoding) decodeQuantum(dst, src []byte, si int) (nsi, n int, err error) {
279	// Decode quantum using the base64 alphabet
280	var dbuf [4]byte
281	dinc, dlen := 3, 4
282
283	for j := 0; j < len(dbuf); j++ {
284		if len(src) == si {
285			switch {
286			case j == 0:
287				return si, 0, nil
288			case j == 1, enc.padChar != NoPadding:
289				return si, 0, CorruptInputError(si - j)
290			}
291			dinc, dlen = j-1, j
292			break
293		}
294		in := src[si]
295		si++
296
297		out := enc.decodeMap[in]
298		if out != 0xff {
299			dbuf[j] = out
300			continue
301		}
302
303		if in == '\n' || in == '\r' {
304			j--
305			continue
306		}
307
308		if rune(in) != enc.padChar {
309			return si, 0, CorruptInputError(si - 1)
310		}
311
312		// We've reached the end and there's padding
313		switch j {
314		case 0, 1:
315			// incorrect padding
316			return si, 0, CorruptInputError(si - 1)
317		case 2:
318			// "==" is expected, the first "=" is already consumed.
319			// skip over newlines
320			for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
321				si++
322			}
323			if si == len(src) {
324				// not enough padding
325				return si, 0, CorruptInputError(len(src))
326			}
327			if rune(src[si]) != enc.padChar {
328				// incorrect padding
329				return si, 0, CorruptInputError(si - 1)
330			}
331
332			si++
333		}
334
335		// skip over newlines
336		for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
337			si++
338		}
339		if si < len(src) {
340			// trailing garbage
341			err = CorruptInputError(si)
342		}
343		dinc, dlen = 3, j
344		break
345	}
346
347	// Convert 4x 6bit source bytes into 3 bytes
348	val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3])
349	dbuf[2], dbuf[1], dbuf[0] = byte(val>>0), byte(val>>8), byte(val>>16)
350	switch dlen {
351	case 4:
352		dst[2] = dbuf[2]
353		dbuf[2] = 0
354		fallthrough
355	case 3:
356		dst[1] = dbuf[1]
357		if enc.strict && dbuf[2] != 0 {
358			return si, 0, CorruptInputError(si - 1)
359		}
360		dbuf[1] = 0
361		fallthrough
362	case 2:
363		dst[0] = dbuf[0]
364		if enc.strict && (dbuf[1] != 0 || dbuf[2] != 0) {
365			return si, 0, CorruptInputError(si - 2)
366		}
367	}
368	dst = dst[dinc:]
369
370	return si, dlen - 1, err
371}
372
373// DecodeString returns the bytes represented by the base64 string s.
374func (enc *Encoding) DecodeString(s string) ([]byte, error) {
375	dbuf := make([]byte, enc.DecodedLen(len(s)))
376	n, err := enc.Decode(dbuf, []byte(s))
377	return dbuf[:n], err
378}
379
// decoder is the streaming decoder returned by NewDecoder. It buffers
// encoded input read from r and hands out decoded bytes through Read.
type decoder struct {
	err     error      // error returned by Read once no buffered output remains
	readErr error      // error from r.Read
	enc     *Encoding  // encoding used to decode the input
	r       io.Reader  // source of encoded input
	buf     [1024]byte // leftover input
	nbuf    int        // number of valid bytes in buf
	out     []byte     // leftover decoded output
	outbuf  [1024 / 4 * 3]byte // backing storage for out
}
390
// Read implements io.Reader. It copies decoded bytes into p, serving any
// output left over from an earlier call first, then reading and decoding
// more input from d.r. It returns the number of bytes placed in p and the
// decoder's recorded error, if any.
func (d *decoder) Read(p []byte) (n int, err error) {
	// Use leftover decoded output from last read.
	if len(d.out) > 0 {
		n = copy(p, d.out)
		d.out = d.out[n:]
		return n, nil
	}

	// A previously recorded error is only reported once all buffered
	// output has been handed out (checked above).
	if d.err != nil {
		return 0, d.err
	}

	// This code assumes that d.r strips supported whitespace ('\r' and '\n').

	// Refill buffer.
	for d.nbuf < 4 && d.readErr == nil {
		// Ask for roughly as much input as decodes into len(p) bytes,
		// but at least one 4-byte quantum and at most len(d.buf).
		nn := len(p) / 3 * 4
		if nn < 4 {
			nn = 4
		}
		if nn > len(d.buf) {
			nn = len(d.buf)
		}
		nn, d.readErr = d.r.Read(d.buf[d.nbuf:nn])
		d.nbuf += nn
	}

	// Fewer than 4 bytes buffered means the read loop stopped on readErr.
	if d.nbuf < 4 {
		if d.enc.padChar == NoPadding && d.nbuf > 0 {
			// Decode final fragment, without padding.
			var nw int
			nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:d.nbuf])
			d.nbuf = 0
			d.out = d.outbuf[:nw]
			n = copy(p, d.out)
			d.out = d.out[n:]
			// Hand back what was decoded now; any error recorded in d.err
			// is reported by a later call.
			if n > 0 || len(p) == 0 && len(d.out) > 0 {
				return n, nil
			}
			if d.err != nil {
				return 0, d.err
			}
		}
		d.err = d.readErr
		// Input ended in the middle of a quantum.
		if d.err == io.EOF && d.nbuf > 0 {
			d.err = io.ErrUnexpectedEOF
		}
		return 0, d.err
	}

	// Decode chunk into p, or d.out and then p if p is too small.
	nr := d.nbuf / 4 * 4 // input bytes consumed: whole quanta only
	nw := d.nbuf / 4 * 3 // upper bound on output bytes for those quanta
	if nw > len(p) {
		nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:nr])
		d.out = d.outbuf[:nw]
		n = copy(p, d.out)
		d.out = d.out[n:]
	} else {
		n, d.err = d.enc.Decode(p, d.buf[:nr])
	}
	// Move the undecoded remainder (fewer than 4 bytes) to the front of buf.
	d.nbuf -= nr
	copy(d.buf[:d.nbuf], d.buf[nr:])
	return n, d.err
}
456
457// Decode decodes src using the encoding enc. It writes at most
458// DecodedLen(len(src)) bytes to dst and returns the number of bytes
459// written. If src contains invalid base64 data, it will return the
460// number of bytes successfully written and CorruptInputError.
461// New line characters (\r and \n) are ignored.
462func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
463	if len(src) == 0 {
464		return 0, nil
465	}
466
467	si := 0
468	ilen := len(src)
469	olen := len(dst)
470	for strconv.IntSize >= 64 && ilen-si >= 8 && olen-n >= 8 {
471		if ok := enc.decode64(dst[n:], src[si:]); ok {
472			n += 6
473			si += 8
474		} else {
475			var ninc int
476			si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
477			n += ninc
478			if err != nil {
479				return n, err
480			}
481		}
482	}
483
484	for ilen-si >= 4 && olen-n >= 4 {
485		if ok := enc.decode32(dst[n:], src[si:]); ok {
486			n += 3
487			si += 4
488		} else {
489			var ninc int
490			si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
491			n += ninc
492			if err != nil {
493				return n, err
494			}
495		}
496	}
497
498	for si < len(src) {
499		var ninc int
500		si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
501		n += ninc
502		if err != nil {
503			return n, err
504		}
505	}
506	return n, err
507}
508
509// decode32 tries to decode 4 base64 char into 3 bytes.
510// len(dst) and len(src) must both be >= 4.
511// Returns true if decode succeeded.
512func (enc *Encoding) decode32(dst, src []byte) bool {
513	var dn, n uint32
514	if n = uint32(enc.decodeMap[src[0]]); n == 0xff {
515		return false
516	}
517	dn |= n << 26
518	if n = uint32(enc.decodeMap[src[1]]); n == 0xff {
519		return false
520	}
521	dn |= n << 20
522	if n = uint32(enc.decodeMap[src[2]]); n == 0xff {
523		return false
524	}
525	dn |= n << 14
526	if n = uint32(enc.decodeMap[src[3]]); n == 0xff {
527		return false
528	}
529	dn |= n << 8
530
531	binary.BigEndian.PutUint32(dst, dn)
532	return true
533}
534
535// decode64 tries to decode 8 base64 char into 6 bytes.
536// len(dst) and len(src) must both be >= 8.
537// Returns true if decode succeeded.
538func (enc *Encoding) decode64(dst, src []byte) bool {
539	var dn, n uint64
540	if n = uint64(enc.decodeMap[src[0]]); n == 0xff {
541		return false
542	}
543	dn |= n << 58
544	if n = uint64(enc.decodeMap[src[1]]); n == 0xff {
545		return false
546	}
547	dn |= n << 52
548	if n = uint64(enc.decodeMap[src[2]]); n == 0xff {
549		return false
550	}
551	dn |= n << 46
552	if n = uint64(enc.decodeMap[src[3]]); n == 0xff {
553		return false
554	}
555	dn |= n << 40
556	if n = uint64(enc.decodeMap[src[4]]); n == 0xff {
557		return false
558	}
559	dn |= n << 34
560	if n = uint64(enc.decodeMap[src[5]]); n == 0xff {
561		return false
562	}
563	dn |= n << 28
564	if n = uint64(enc.decodeMap[src[6]]); n == 0xff {
565		return false
566	}
567	dn |= n << 22
568	if n = uint64(enc.decodeMap[src[7]]); n == 0xff {
569		return false
570	}
571	dn |= n << 16
572
573	binary.BigEndian.PutUint64(dst, dn)
574	return true
575}
576
// newlineFilteringReader wraps an io.Reader and removes '\r' and '\n'
// bytes from everything read through it.
type newlineFilteringReader struct {
	wrapped io.Reader
}

// Read reads from the wrapped reader into p and compacts the result in
// place, dropping '\r' and '\n'. If a read yields only newline bytes it
// reads again, so it only returns zero bytes when the wrapped reader does.
// The error returned is the one from the most recent wrapped Read.
func (r *newlineFilteringReader) Read(p []byte) (int, error) {
	for {
		n, err := r.wrapped.Read(p)
		if n <= 0 {
			return n, err
		}

		kept := 0
		for _, c := range p[:n] {
			if c == '\r' || c == '\n' {
				continue
			}
			p[kept] = c
			kept++
		}
		if kept > 0 {
			return kept, err
		}
		// The chunk was entirely newlines; try the next one.
	}
}
601
602// NewDecoder constructs a new base64 stream decoder.
603func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
604	return &decoder{enc: enc, r: &newlineFilteringReader{r}}
605}
606
607// DecodedLen returns the maximum length in bytes of the decoded data
608// corresponding to n bytes of base64-encoded data.
609func (enc *Encoding) DecodedLen(n int) int {
610	if enc.padChar == NoPadding {
611		// Unpadded data may end with partial block of 2-3 characters.
612		return n * 6 / 8
613	}
614	// Padded base64 should always be a multiple of 4 characters in length.
615	return n / 4 * 3
616}
617