1// Copyright 2010 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package textproto
6
7import (
8	"bufio"
9	"bytes"
10	"fmt"
11	"io"
12	"io/ioutil"
13	"strconv"
14	"strings"
15	"sync"
16)
17
18// A Reader implements convenience methods for reading requests
19// or responses from a text protocol network connection.
20type Reader struct {
21	R   *bufio.Reader
22	dot *dotReader
23	buf []byte // a re-usable buffer for readContinuedLineSlice
24}
25
26// NewReader returns a new Reader reading from r.
27//
28// To avoid denial of service attacks, the provided bufio.Reader
29// should be reading from an io.LimitReader or similar Reader to bound
30// the size of responses.
31func NewReader(r *bufio.Reader) *Reader {
32	commonHeaderOnce.Do(initCommonHeader)
33	return &Reader{R: r}
34}
35
36// ReadLine reads a single line from r,
37// eliding the final \n or \r\n from the returned string.
38func (r *Reader) ReadLine() (string, error) {
39	line, err := r.readLineSlice()
40	return string(line), err
41}
42
43// ReadLineBytes is like ReadLine but returns a []byte instead of a string.
44func (r *Reader) ReadLineBytes() ([]byte, error) {
45	line, err := r.readLineSlice()
46	if line != nil {
47		buf := make([]byte, len(line))
48		copy(buf, line)
49		line = buf
50	}
51	return line, err
52}
53
54func (r *Reader) readLineSlice() ([]byte, error) {
55	r.closeDot()
56	var line []byte
57	for {
58		l, more, err := r.R.ReadLine()
59		if err != nil {
60			return nil, err
61		}
62		// Avoid the copy if the first call produced a full line.
63		if line == nil && !more {
64			return l, nil
65		}
66		line = append(line, l...)
67		if !more {
68			break
69		}
70	}
71	return line, nil
72}
73
74// ReadContinuedLine reads a possibly continued line from r,
75// eliding the final trailing ASCII white space.
76// Lines after the first are considered continuations if they
77// begin with a space or tab character. In the returned data,
78// continuation lines are separated from the previous line
79// only by a single space: the newline and leading white space
80// are removed.
81//
82// For example, consider this input:
83//
84//	Line 1
85//	  continued...
86//	Line 2
87//
88// The first call to ReadContinuedLine will return "Line 1 continued..."
89// and the second will return "Line 2".
90//
91// A line consisting of only white space is never continued.
92//
93func (r *Reader) ReadContinuedLine() (string, error) {
94	line, err := r.readContinuedLineSlice(noValidation)
95	return string(line), err
96}
97
// trim returns the sub-slice of s that remains after dropping leading
// and trailing spaces and tabs. It works on raw bytes and makes no
// assumption about Unicode or UTF-8. The result aliases s.
func trim(s []byte) []byte {
	blank := func(c byte) bool { return c == ' ' || c == '\t' }

	lo, hi := 0, len(s)
	for lo < hi && blank(s[lo]) {
		lo++
	}
	for hi > lo && blank(s[hi-1]) {
		hi--
	}
	return s[lo:hi]
}
111
112// ReadContinuedLineBytes is like ReadContinuedLine but
113// returns a []byte instead of a string.
114func (r *Reader) ReadContinuedLineBytes() ([]byte, error) {
115	line, err := r.readContinuedLineSlice(noValidation)
116	if line != nil {
117		buf := make([]byte, len(line))
118		copy(buf, line)
119		line = buf
120	}
121	return line, err
122}
123
124// readContinuedLineSlice reads continued lines from the reader buffer,
125// returning a byte slice with all lines. The validateFirstLine function
126// is run on the first read line, and if it returns an error then this
127// error is returned from readContinuedLineSlice.
128func (r *Reader) readContinuedLineSlice(validateFirstLine func([]byte) error) ([]byte, error) {
129	if validateFirstLine == nil {
130		return nil, fmt.Errorf("missing validateFirstLine func")
131	}
132
133	// Read the first line.
134	line, err := r.readLineSlice()
135	if err != nil {
136		return nil, err
137	}
138	if len(line) == 0 { // blank line - no continuation
139		return line, nil
140	}
141
142	if err := validateFirstLine(line); err != nil {
143		return nil, err
144	}
145
146	// Optimistically assume that we have started to buffer the next line
147	// and it starts with an ASCII letter (the next header key), or a blank
148	// line, so we can avoid copying that buffered data around in memory
149	// and skipping over non-existent whitespace.
150	if r.R.Buffered() > 1 {
151		peek, _ := r.R.Peek(2)
152		if len(peek) > 0 && (isASCIILetter(peek[0]) || peek[0] == '\n') ||
153			len(peek) == 2 && peek[0] == '\r' && peek[1] == '\n' {
154			return trim(line), nil
155		}
156	}
157
158	// ReadByte or the next readLineSlice will flush the read buffer;
159	// copy the slice into buf.
160	r.buf = append(r.buf[:0], trim(line)...)
161
162	// Read continuation lines.
163	for r.skipSpace() > 0 {
164		line, err := r.readLineSlice()
165		if err != nil {
166			break
167		}
168		r.buf = append(r.buf, ' ')
169		r.buf = append(r.buf, trim(line)...)
170	}
171	return r.buf, nil
172}
173
174// skipSpace skips R over all spaces and returns the number of bytes skipped.
175func (r *Reader) skipSpace() int {
176	n := 0
177	for {
178		c, err := r.R.ReadByte()
179		if err != nil {
180			// Bufio will keep err until next read.
181			break
182		}
183		if c != ' ' && c != '\t' {
184			r.R.UnreadByte()
185			break
186		}
187		n++
188	}
189	return n
190}
191
192func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err error) {
193	line, err := r.ReadLine()
194	if err != nil {
195		return
196	}
197	return parseCodeLine(line, expectCode)
198}
199
200func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err error) {
201	if len(line) < 4 || line[3] != ' ' && line[3] != '-' {
202		err = ProtocolError("short response: " + line)
203		return
204	}
205	continued = line[3] == '-'
206	code, err = strconv.Atoi(line[0:3])
207	if err != nil || code < 100 {
208		err = ProtocolError("invalid response code: " + line)
209		return
210	}
211	message = line[4:]
212	if 1 <= expectCode && expectCode < 10 && code/100 != expectCode ||
213		10 <= expectCode && expectCode < 100 && code/10 != expectCode ||
214		100 <= expectCode && expectCode < 1000 && code != expectCode {
215		err = &Error{code, message}
216	}
217	return
218}
219
220// ReadCodeLine reads a response code line of the form
221//	code message
222// where code is a three-digit status code and the message
223// extends to the rest of the line. An example of such a line is:
224//	220 plan9.bell-labs.com ESMTP
225//
226// If the prefix of the status does not match the digits in expectCode,
227// ReadCodeLine returns with err set to &Error{code, message}.
228// For example, if expectCode is 31, an error will be returned if
229// the status is not in the range [310,319].
230//
231// If the response is multi-line, ReadCodeLine returns an error.
232//
233// An expectCode <= 0 disables the check of the status code.
234//
235func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err error) {
236	code, continued, message, err := r.readCodeLine(expectCode)
237	if err == nil && continued {
238		err = ProtocolError("unexpected multi-line response: " + message)
239	}
240	return
241}
242
243// ReadResponse reads a multi-line response of the form:
244//
245//	code-message line 1
246//	code-message line 2
247//	...
248//	code message line n
249//
250// where code is a three-digit status code. The first line starts with the
251// code and a hyphen. The response is terminated by a line that starts
252// with the same code followed by a space. Each line in message is
253// separated by a newline (\n).
254//
255// See page 36 of RFC 959 (https://www.ietf.org/rfc/rfc959.txt) for
256// details of another form of response accepted:
257//
258//  code-message line 1
259//  message line 2
260//  ...
261//  code message line n
262//
263// If the prefix of the status does not match the digits in expectCode,
264// ReadResponse returns with err set to &Error{code, message}.
265// For example, if expectCode is 31, an error will be returned if
266// the status is not in the range [310,319].
267//
268// An expectCode <= 0 disables the check of the status code.
269//
270func (r *Reader) ReadResponse(expectCode int) (code int, message string, err error) {
271	code, continued, message, err := r.readCodeLine(expectCode)
272	multi := continued
273	for continued {
274		line, err := r.ReadLine()
275		if err != nil {
276			return 0, "", err
277		}
278
279		var code2 int
280		var moreMessage string
281		code2, continued, moreMessage, err = parseCodeLine(line, 0)
282		if err != nil || code2 != code {
283			message += "\n" + strings.TrimRight(line, "\r\n")
284			continued = true
285			continue
286		}
287		message += "\n" + moreMessage
288	}
289	if err != nil && multi && message != "" {
290		// replace one line error message with all lines (full message)
291		err = &Error{code, message}
292	}
293	return
294}
295
296// DotReader returns a new Reader that satisfies Reads using the
297// decoded text of a dot-encoded block read from r.
298// The returned Reader is only valid until the next call
299// to a method on r.
300//
301// Dot encoding is a common framing used for data blocks
302// in text protocols such as SMTP.  The data consists of a sequence
303// of lines, each of which ends in "\r\n".  The sequence itself
304// ends at a line containing just a dot: ".\r\n".  Lines beginning
305// with a dot are escaped with an additional dot to avoid
306// looking like the end of the sequence.
307//
308// The decoded form returned by the Reader's Read method
309// rewrites the "\r\n" line endings into the simpler "\n",
310// removes leading dot escapes if present, and stops with error io.EOF
311// after consuming (and discarding) the end-of-sequence line.
312func (r *Reader) DotReader() io.Reader {
313	r.closeDot()
314	r.dot = &dotReader{r: r}
315	return r.dot
316}
317
318type dotReader struct {
319	r     *Reader
320	state int
321}
322
323// Read satisfies reads by decoding dot-encoded data read from d.r.
324func (d *dotReader) Read(b []byte) (n int, err error) {
325	// Run data through a simple state machine to
326	// elide leading dots, rewrite trailing \r\n into \n,
327	// and detect ending .\r\n line.
328	const (
329		stateBeginLine = iota // beginning of line; initial state; must be zero
330		stateDot              // read . at beginning of line
331		stateDotCR            // read .\r at beginning of line
332		stateCR               // read \r (possibly at end of line)
333		stateData             // reading data in middle of line
334		stateEOF              // reached .\r\n end marker line
335	)
336	br := d.r.R
337	for n < len(b) && d.state != stateEOF {
338		var c byte
339		c, err = br.ReadByte()
340		if err != nil {
341			if err == io.EOF {
342				err = io.ErrUnexpectedEOF
343			}
344			break
345		}
346		switch d.state {
347		case stateBeginLine:
348			if c == '.' {
349				d.state = stateDot
350				continue
351			}
352			if c == '\r' {
353				d.state = stateCR
354				continue
355			}
356			d.state = stateData
357
358		case stateDot:
359			if c == '\r' {
360				d.state = stateDotCR
361				continue
362			}
363			if c == '\n' {
364				d.state = stateEOF
365				continue
366			}
367			d.state = stateData
368
369		case stateDotCR:
370			if c == '\n' {
371				d.state = stateEOF
372				continue
373			}
374			// Not part of .\r\n.
375			// Consume leading dot and emit saved \r.
376			br.UnreadByte()
377			c = '\r'
378			d.state = stateData
379
380		case stateCR:
381			if c == '\n' {
382				d.state = stateBeginLine
383				break
384			}
385			// Not part of \r\n. Emit saved \r
386			br.UnreadByte()
387			c = '\r'
388			d.state = stateData
389
390		case stateData:
391			if c == '\r' {
392				d.state = stateCR
393				continue
394			}
395			if c == '\n' {
396				d.state = stateBeginLine
397			}
398		}
399		b[n] = c
400		n++
401	}
402	if err == nil && d.state == stateEOF {
403		err = io.EOF
404	}
405	if err != nil && d.r.dot == d {
406		d.r.dot = nil
407	}
408	return
409}
410
411// closeDot drains the current DotReader if any,
412// making sure that it reads until the ending dot line.
413func (r *Reader) closeDot() {
414	if r.dot == nil {
415		return
416	}
417	buf := make([]byte, 128)
418	for r.dot != nil {
419		// When Read reaches EOF or an error,
420		// it will set r.dot == nil.
421		r.dot.Read(buf)
422	}
423}
424
425// ReadDotBytes reads a dot-encoding and returns the decoded data.
426//
427// See the documentation for the DotReader method for details about dot-encoding.
428func (r *Reader) ReadDotBytes() ([]byte, error) {
429	return ioutil.ReadAll(r.DotReader())
430}
431
432// ReadDotLines reads a dot-encoding and returns a slice
433// containing the decoded lines, with the final \r\n or \n elided from each.
434//
435// See the documentation for the DotReader method for details about dot-encoding.
436func (r *Reader) ReadDotLines() ([]string, error) {
437	// We could use ReadDotBytes and then Split it,
438	// but reading a line at a time avoids needing a
439	// large contiguous block of memory and is simpler.
440	var v []string
441	var err error
442	for {
443		var line string
444		line, err = r.ReadLine()
445		if err != nil {
446			if err == io.EOF {
447				err = io.ErrUnexpectedEOF
448			}
449			break
450		}
451
452		// Dot by itself marks end; otherwise cut one dot.
453		if len(line) > 0 && line[0] == '.' {
454			if len(line) == 1 {
455				break
456			}
457			line = line[1:]
458		}
459		v = append(v, line)
460	}
461	return v, err
462}
463
464// ReadMIMEHeader reads a MIME-style header from r.
465// The header is a sequence of possibly continued Key: Value lines
466// ending in a blank line.
467// The returned map m maps CanonicalMIMEHeaderKey(key) to a
468// sequence of values in the same order encountered in the input.
469//
470// For example, consider this input:
471//
472//	My-Key: Value 1
473//	Long-Key: Even
474//	       Longer Value
475//	My-Key: Value 2
476//
477// Given that input, ReadMIMEHeader returns the map:
478//
479//	map[string][]string{
480//		"My-Key": {"Value 1", "Value 2"},
481//		"Long-Key": {"Even Longer Value"},
482//	}
483//
484func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) {
485	// Avoid lots of small slice allocations later by allocating one
486	// large one ahead of time which we'll cut up into smaller
487	// slices. If this isn't big enough later, we allocate small ones.
488	var strs []string
489	hint := r.upcomingHeaderNewlines()
490	if hint > 0 {
491		strs = make([]string, hint)
492	}
493
494	m := make(MIMEHeader, hint)
495
496	// The first line cannot start with a leading space.
497	if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') {
498		line, err := r.readLineSlice()
499		if err != nil {
500			return m, err
501		}
502		return m, ProtocolError("malformed MIME header initial line: " + string(line))
503	}
504
505	for {
506		kv, err := r.readContinuedLineSlice(mustHaveFieldNameColon)
507		if len(kv) == 0 {
508			return m, err
509		}
510
511		// Key ends at first colon.
512		i := bytes.IndexByte(kv, ':')
513		if i < 0 {
514			return m, ProtocolError("malformed MIME header line: " + string(kv))
515		}
516		key := canonicalMIMEHeaderKey(kv[:i])
517
518		// As per RFC 7230 field-name is a token, tokens consist of one or more chars.
519		// We could return a ProtocolError here, but better to be liberal in what we
520		// accept, so if we get an empty key, skip it.
521		if key == "" {
522			continue
523		}
524
525		// Skip initial spaces in value.
526		i++ // skip colon
527		for i < len(kv) && (kv[i] == ' ' || kv[i] == '\t') {
528			i++
529		}
530		value := string(kv[i:])
531
532		vv := m[key]
533		if vv == nil && len(strs) > 0 {
534			// More than likely this will be a single-element key.
535			// Most headers aren't multi-valued.
536			// Set the capacity on strs[0] to 1, so any future append
537			// won't extend the slice into the other strings.
538			vv, strs = strs[:1:1], strs[1:]
539			vv[0] = value
540			m[key] = vv
541		} else {
542			m[key] = append(vv, value)
543		}
544
545		if err != nil {
546			return m, err
547		}
548	}
549}
550
// noValidation is the validation func handed to readContinuedLineSlice
// when every first line is acceptable; it always reports success.
func noValidation([]byte) error {
	return nil
}
554
555// mustHaveFieldNameColon ensures that, per RFC 7230, the
556// field-name is on a single line, so the first line must
557// contain a colon.
558func mustHaveFieldNameColon(line []byte) error {
559	if bytes.IndexByte(line, ':') < 0 {
560		return ProtocolError(fmt.Sprintf("malformed MIME header: missing colon: %q" + string(line)))
561	}
562	return nil
563}
564
565// upcomingHeaderNewlines returns an approximation of the number of newlines
566// that will be in this header. If it gets confused, it returns 0.
567func (r *Reader) upcomingHeaderNewlines() (n int) {
568	// Try to determine the 'hint' size.
569	r.R.Peek(1) // force a buffer load if empty
570	s := r.R.Buffered()
571	if s == 0 {
572		return
573	}
574	peek, _ := r.R.Peek(s)
575	for len(peek) > 0 {
576		i := bytes.IndexByte(peek, '\n')
577		if i < 3 {
578			// Not present (-1) or found within the next few bytes,
579			// implying we're at the end ("\r\n\r\n" or "\n\n")
580			return
581		}
582		n++
583		peek = peek[i+1:]
584	}
585	return
586}
587
588// CanonicalMIMEHeaderKey returns the canonical format of the
589// MIME header key s. The canonicalization converts the first
590// letter and any letter following a hyphen to upper case;
591// the rest are converted to lowercase. For example, the
592// canonical key for "accept-encoding" is "Accept-Encoding".
593// MIME header keys are assumed to be ASCII only.
594// If s contains a space or invalid header field bytes, it is
595// returned without modifications.
596func CanonicalMIMEHeaderKey(s string) string {
597	commonHeaderOnce.Do(initCommonHeader)
598
599	// Quick check for canonical encoding.
600	upper := true
601	for i := 0; i < len(s); i++ {
602		c := s[i]
603		if !validHeaderFieldByte(c) {
604			return s
605		}
606		if upper && 'a' <= c && c <= 'z' {
607			return canonicalMIMEHeaderKey([]byte(s))
608		}
609		if !upper && 'A' <= c && c <= 'Z' {
610			return canonicalMIMEHeaderKey([]byte(s))
611		}
612		upper = c == '-'
613	}
614	return s
615}
616
// toLower is the distance between an ASCII upper-case letter and its
// lower-case form: adding it to 'A' gives 'a', and subtracting it
// from 'a' gives 'A'.
const toLower = 'a' - 'A'
618
619// validHeaderFieldByte reports whether b is a valid byte in a header
620// field name. RFC 7230 says:
621//   header-field   = field-name ":" OWS field-value OWS
622//   field-name     = token
623//   tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." /
624//           "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA
625//   token = 1*tchar
626func validHeaderFieldByte(b byte) bool {
627	return int(b) < len(isTokenTable) && isTokenTable[b]
628}
629
630// canonicalMIMEHeaderKey is like CanonicalMIMEHeaderKey but is
631// allowed to mutate the provided byte slice before returning the
632// string.
633//
634// For invalid inputs (if a contains spaces or non-token bytes), a
635// is unchanged and a string copy is returned.
636func canonicalMIMEHeaderKey(a []byte) string {
637	// See if a looks like a header key. If not, return it unchanged.
638	for _, c := range a {
639		if validHeaderFieldByte(c) {
640			continue
641		}
642		// Don't canonicalize.
643		return string(a)
644	}
645
646	upper := true
647	for i, c := range a {
648		// Canonicalize: first letter upper case
649		// and upper case after each dash.
650		// (Host, User-Agent, If-Modified-Since).
651		// MIME headers are ASCII only, so no Unicode issues.
652		if upper && 'a' <= c && c <= 'z' {
653			c -= toLower
654		} else if !upper && 'A' <= c && c <= 'Z' {
655			c += toLower
656		}
657		a[i] = c
658		upper = c == '-' // for next time
659	}
660	// The compiler recognizes m[string(byteSlice)] as a special
661	// case, so a copy of a's bytes into a new string does not
662	// happen in this map lookup:
663	if v := commonHeader[string(a)]; v != "" {
664		return v
665	}
666	return string(a)
667}
668
var (
	// commonHeader interns common header strings: it maps the
	// canonical form of a header name to a single shared copy of it.
	// It is filled in by initCommonHeader.
	commonHeader map[string]string

	// commonHeaderOnce guards the one-time call to initCommonHeader.
	commonHeaderOnce sync.Once
)
673
674func initCommonHeader() {
675	commonHeader = make(map[string]string)
676	for _, v := range []string{
677		"Accept",
678		"Accept-Charset",
679		"Accept-Encoding",
680		"Accept-Language",
681		"Accept-Ranges",
682		"Cache-Control",
683		"Cc",
684		"Connection",
685		"Content-Id",
686		"Content-Language",
687		"Content-Length",
688		"Content-Transfer-Encoding",
689		"Content-Type",
690		"Cookie",
691		"Date",
692		"Dkim-Signature",
693		"Etag",
694		"Expires",
695		"From",
696		"Host",
697		"If-Modified-Since",
698		"If-None-Match",
699		"In-Reply-To",
700		"Last-Modified",
701		"Location",
702		"Message-Id",
703		"Mime-Version",
704		"Pragma",
705		"Received",
706		"Return-Path",
707		"Server",
708		"Set-Cookie",
709		"Subject",
710		"To",
711		"User-Agent",
712		"Via",
713		"X-Forwarded-For",
714		"X-Imforwards",
715		"X-Powered-By",
716	} {
717		commonHeader[v] = v
718	}
719}
720
// isTokenTable is a copy of net/http/lex.go's isTokenTable.
// See https://httpwg.github.io/specs/rfc7230.html#rule.token.separators
//
// Entry c is true exactly when byte c is one of the punctuation
// characters, digits or ASCII letters listed in tchars below.
var isTokenTable = func() (table [127]bool) {
	const tchars = "!#$%&'*+-.^_`|~" +
		"0123456789" +
		"ABCDEFGHIJKLMNOPQRSTUVWXYZ" +
		"abcdefghijklmnopqrstuvwxyz"
	for i := 0; i < len(tchars); i++ {
		table[tchars[i]] = true
	}
	return table
}()
802