1// Copyright 2010 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package mime
6
7import (
8	"errors"
9	"fmt"
10	"sort"
11	"strings"
12	"unicode"
13)
14
15// FormatMediaType serializes mediatype t and the parameters
16// param as a media type conforming to RFC 2045 and RFC 2616.
17// The type and parameter names are written in lower-case.
18// When any of the arguments result in a standard violation then
19// FormatMediaType returns the empty string.
20func FormatMediaType(t string, param map[string]string) string {
21	var b strings.Builder
22	if slash := strings.IndexByte(t, '/'); slash == -1 {
23		if !isToken(t) {
24			return ""
25		}
26		b.WriteString(strings.ToLower(t))
27	} else {
28		major, sub := t[:slash], t[slash+1:]
29		if !isToken(major) || !isToken(sub) {
30			return ""
31		}
32		b.WriteString(strings.ToLower(major))
33		b.WriteByte('/')
34		b.WriteString(strings.ToLower(sub))
35	}
36
37	attrs := make([]string, 0, len(param))
38	for a := range param {
39		attrs = append(attrs, a)
40	}
41	sort.Strings(attrs)
42
43	for _, attribute := range attrs {
44		value := param[attribute]
45		b.WriteByte(';')
46		b.WriteByte(' ')
47		if !isToken(attribute) {
48			return ""
49		}
50		b.WriteString(strings.ToLower(attribute))
51
52		needEnc := needsEncoding(value)
53		if needEnc {
54			// RFC 2231 section 4
55			b.WriteByte('*')
56		}
57		b.WriteByte('=')
58
59		if needEnc {
60			b.WriteString("utf-8''")
61
62			offset := 0
63			for index := 0; index < len(value); index++ {
64				ch := value[index]
65				// {RFC 2231 section 7}
66				// attribute-char := <any (US-ASCII) CHAR except SPACE, CTLs, "*", "'", "%", or tspecials>
67				if ch <= ' ' || ch >= 0x7F ||
68					ch == '*' || ch == '\'' || ch == '%' ||
69					isTSpecial(rune(ch)) {
70
71					b.WriteString(value[offset:index])
72					offset = index + 1
73
74					b.WriteByte('%')
75					b.WriteByte(upperhex[ch>>4])
76					b.WriteByte(upperhex[ch&0x0F])
77				}
78			}
79			b.WriteString(value[offset:])
80			continue
81		}
82
83		if isToken(value) {
84			b.WriteString(value)
85			continue
86		}
87
88		b.WriteByte('"')
89		offset := 0
90		for index := 0; index < len(value); index++ {
91			character := value[index]
92			if character == '"' || character == '\\' {
93				b.WriteString(value[offset:index])
94				offset = index
95				b.WriteByte('\\')
96			}
97		}
98		b.WriteString(value[offset:])
99		b.WriteByte('"')
100	}
101	return b.String()
102}
103
104func checkMediaTypeDisposition(s string) error {
105	typ, rest := consumeToken(s)
106	if typ == "" {
107		return errors.New("mime: no media type")
108	}
109	if rest == "" {
110		return nil
111	}
112	if !strings.HasPrefix(rest, "/") {
113		return errors.New("mime: expected slash after first token")
114	}
115	subtype, rest := consumeToken(rest[1:])
116	if subtype == "" {
117		return errors.New("mime: expected token after slash")
118	}
119	if rest != "" {
120		return errors.New("mime: unexpected content after media subtype")
121	}
122	return nil
123}
124
// ErrInvalidMediaParameter is returned by ParseMediaType if
// the media type value was found but there was an error parsing
// the optional parameters.
var ErrInvalidMediaParameter = errors.New("mime: invalid media parameter")
129
130// ParseMediaType parses a media type value and any optional
131// parameters, per RFC 1521.  Media types are the values in
132// Content-Type and Content-Disposition headers (RFC 2183).
133// On success, ParseMediaType returns the media type converted
134// to lowercase and trimmed of white space and a non-nil map.
135// If there is an error parsing the optional parameter,
136// the media type will be returned along with the error
137// ErrInvalidMediaParameter.
138// The returned map, params, maps from the lowercase
139// attribute to the attribute value with its case preserved.
140func ParseMediaType(v string) (mediatype string, params map[string]string, err error) {
141	i := strings.Index(v, ";")
142	if i == -1 {
143		i = len(v)
144	}
145	mediatype = strings.TrimSpace(strings.ToLower(v[0:i]))
146
147	err = checkMediaTypeDisposition(mediatype)
148	if err != nil {
149		return "", nil, err
150	}
151
152	params = make(map[string]string)
153
154	// Map of base parameter name -> parameter name -> value
155	// for parameters containing a '*' character.
156	// Lazily initialized.
157	var continuation map[string]map[string]string
158
159	v = v[i:]
160	for len(v) > 0 {
161		v = strings.TrimLeftFunc(v, unicode.IsSpace)
162		if len(v) == 0 {
163			break
164		}
165		key, value, rest := consumeMediaParam(v)
166		if key == "" {
167			if strings.TrimSpace(rest) == ";" {
168				// Ignore trailing semicolons.
169				// Not an error.
170				return
171			}
172			// Parse error.
173			return mediatype, nil, ErrInvalidMediaParameter
174		}
175
176		pmap := params
177		if idx := strings.Index(key, "*"); idx != -1 {
178			baseName := key[:idx]
179			if continuation == nil {
180				continuation = make(map[string]map[string]string)
181			}
182			var ok bool
183			if pmap, ok = continuation[baseName]; !ok {
184				continuation[baseName] = make(map[string]string)
185				pmap = continuation[baseName]
186			}
187		}
188		if _, exists := pmap[key]; exists {
189			// Duplicate parameter name is bogus.
190			return "", nil, errors.New("mime: duplicate parameter name")
191		}
192		pmap[key] = value
193		v = rest
194	}
195
196	// Stitch together any continuations or things with stars
197	// (i.e. RFC 2231 things with stars: "foo*0" or "foo*")
198	var buf strings.Builder
199	for key, pieceMap := range continuation {
200		singlePartKey := key + "*"
201		if v, ok := pieceMap[singlePartKey]; ok {
202			if decv, ok := decode2231Enc(v); ok {
203				params[key] = decv
204			}
205			continue
206		}
207
208		buf.Reset()
209		valid := false
210		for n := 0; ; n++ {
211			simplePart := fmt.Sprintf("%s*%d", key, n)
212			if v, ok := pieceMap[simplePart]; ok {
213				valid = true
214				buf.WriteString(v)
215				continue
216			}
217			encodedPart := simplePart + "*"
218			v, ok := pieceMap[encodedPart]
219			if !ok {
220				break
221			}
222			valid = true
223			if n == 0 {
224				if decv, ok := decode2231Enc(v); ok {
225					buf.WriteString(decv)
226				}
227			} else {
228				decv, _ := percentHexUnescape(v)
229				buf.WriteString(decv)
230			}
231		}
232		if valid {
233			params[key] = buf.String()
234		}
235	}
236
237	return
238}
239
240func decode2231Enc(v string) (string, bool) {
241	sv := strings.SplitN(v, "'", 3)
242	if len(sv) != 3 {
243		return "", false
244	}
245	// TODO: ignoring lang in sv[1] for now. If anybody needs it we'll
246	// need to decide how to expose it in the API. But I'm not sure
247	// anybody uses it in practice.
248	charset := strings.ToLower(sv[0])
249	if len(charset) == 0 {
250		return "", false
251	}
252	if charset != "us-ascii" && charset != "utf-8" {
253		// TODO: unsupported encoding
254		return "", false
255	}
256	encv, err := percentHexUnescape(sv[2])
257	if err != nil {
258		return "", false
259	}
260	return encv, true
261}
262
// isNotTokenChar reports whether r is not a token character; it is
// the negation of isTokenChar. Having it as a named function lets it
// be passed to strings.IndexFunc to locate the end of a token.
func isNotTokenChar(r rune) bool {
	return !isTokenChar(r)
}
266
267// consumeToken consumes a token from the beginning of provided
268// string, per RFC 2045 section 5.1 (referenced from 2183), and return
269// the token consumed and the rest of the string. Returns ("", v) on
270// failure to consume at least one character.
271func consumeToken(v string) (token, rest string) {
272	notPos := strings.IndexFunc(v, isNotTokenChar)
273	if notPos == -1 {
274		return v, ""
275	}
276	if notPos == 0 {
277		return "", v
278	}
279	return v[0:notPos], v[notPos:]
280}
281
282// consumeValue consumes a "value" per RFC 2045, where a value is
283// either a 'token' or a 'quoted-string'.  On success, consumeValue
284// returns the value consumed (and de-quoted/escaped, if a
285// quoted-string) and the rest of the string. On failure, returns
286// ("", v).
287func consumeValue(v string) (value, rest string) {
288	if v == "" {
289		return
290	}
291	if v[0] != '"' {
292		return consumeToken(v)
293	}
294
295	// parse a quoted-string
296	buffer := new(strings.Builder)
297	for i := 1; i < len(v); i++ {
298		r := v[i]
299		if r == '"' {
300			return buffer.String(), v[i+1:]
301		}
302		// When MSIE sends a full file path (in "intranet mode"), it does not
303		// escape backslashes: "C:\dev\go\foo.txt", not "C:\\dev\\go\\foo.txt".
304		//
305		// No known MIME generators emit unnecessary backslash escapes
306		// for simple token characters like numbers and letters.
307		//
308		// If we see an unnecessary backslash escape, assume it is from MSIE
309		// and intended as a literal backslash. This makes Go servers deal better
310		// with MSIE without affecting the way they handle conforming MIME
311		// generators.
312		if r == '\\' && i+1 < len(v) && isTSpecial(rune(v[i+1])) {
313			buffer.WriteByte(v[i+1])
314			i++
315			continue
316		}
317		if r == '\r' || r == '\n' {
318			return "", v
319		}
320		buffer.WriteByte(v[i])
321	}
322	// Did not find end quote.
323	return "", v
324}
325
326func consumeMediaParam(v string) (param, value, rest string) {
327	rest = strings.TrimLeftFunc(v, unicode.IsSpace)
328	if !strings.HasPrefix(rest, ";") {
329		return "", "", v
330	}
331
332	rest = rest[1:] // consume semicolon
333	rest = strings.TrimLeftFunc(rest, unicode.IsSpace)
334	param, rest = consumeToken(rest)
335	param = strings.ToLower(param)
336	if param == "" {
337		return "", "", v
338	}
339
340	rest = strings.TrimLeftFunc(rest, unicode.IsSpace)
341	if !strings.HasPrefix(rest, "=") {
342		return "", "", v
343	}
344	rest = rest[1:] // consume equals sign
345	rest = strings.TrimLeftFunc(rest, unicode.IsSpace)
346	value, rest2 := consumeValue(rest)
347	if value == "" && rest2 == rest {
348		return "", "", v
349	}
350	rest = rest2
351	return param, value, rest
352}
353
354func percentHexUnescape(s string) (string, error) {
355	// Count %, check that they're well-formed.
356	percents := 0
357	for i := 0; i < len(s); {
358		if s[i] != '%' {
359			i++
360			continue
361		}
362		percents++
363		if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
364			s = s[i:]
365			if len(s) > 3 {
366				s = s[0:3]
367			}
368			return "", fmt.Errorf("mime: bogus characters after %%: %q", s)
369		}
370		i += 3
371	}
372	if percents == 0 {
373		return s, nil
374	}
375
376	t := make([]byte, len(s)-2*percents)
377	j := 0
378	for i := 0; i < len(s); {
379		switch s[i] {
380		case '%':
381			t[j] = unhex(s[i+1])<<4 | unhex(s[i+2])
382			j++
383			i += 3
384		default:
385			t[j] = s[i]
386			j++
387			i++
388		}
389	}
390	return string(t), nil
391}
392
// ishex reports whether c is an ASCII hexadecimal digit:
// '0'-'9', 'a'-'f' or 'A'-'F'.
func ishex(c byte) bool {
	isDigit := '0' <= c && c <= '9'
	isLower := 'a' <= c && c <= 'f'
	isUpper := 'A' <= c && c <= 'F'
	return isDigit || isLower || isUpper
}
404
// unhex returns the numeric value (0-15) of the ASCII hexadecimal
// digit c, accepting both upper- and lower-case letters. It
// returns 0 for any byte that is not a hexadecimal digit.
func unhex(c byte) byte {
	if '0' <= c && c <= '9' {
		return c - '0'
	}
	// Setting bit 0x20 maps 'A'-'F' onto 'a'-'f' and leaves 'a'-'f'
	// unchanged; no other byte lands in that range.
	if lower := c | 0x20; 'a' <= lower && lower <= 'f' {
		return lower - 'a' + 10
	}
	return 0
}
416