1// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package websocket
6
7import (
8	"crypto/rand"
9	"crypto/sha1"
10	"encoding/base64"
11	"io"
12	"net/http"
13	"strings"
14	"unicode/utf8"
15)
16
// keyGUID is the fixed GUID defined by RFC 6455. computeAcceptKey appends it
// to the challenge key before SHA-1 hashing to produce the accept key.
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
18
19func computeAcceptKey(challengeKey string) string {
20	h := sha1.New()
21	h.Write([]byte(challengeKey))
22	h.Write(keyGUID)
23	return base64.StdEncoding.EncodeToString(h.Sum(nil))
24}
25
// generateChallengeKey returns 16 bytes read from crypto/rand, encoded with
// standard base64. On a read failure it returns "" and the read error.
func generateChallengeKey() (string, error) {
	var nonce [16]byte
	if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(nonce[:]), nil
}
33
// isTokenOctet reports, for every byte value, whether that byte may appear in
// an RFC 2616 token. Token bytes are the visible US-ASCII characters
// (0x21-0x7E) other than the separators ( ) < > @ , ; : \ " / [ ] ? = { }.
var isTokenOctet = func() [256]bool {
	const tokenChars = "!#$%&'*+-.^_`|~" +
		"0123456789" +
		"ABCDEFGHIJKLMNOPQRSTUVWXYZ" +
		"abcdefghijklmnopqrstuvwxyz"

	var table [256]bool
	for i := 0; i < len(tokenChars); i++ {
		table[tokenChars[i]] = true
	}
	return table
}()
114
// skipSpace returns the suffix of s that remains after dropping leading
// spaces (SP) and horizontal tabs (HT). Only those two bytes are skipped. CR
// and LF are left in place, so RFC 2616 line folding is not consumed here.
func skipSpace(s string) (rest string) {
	return strings.TrimLeft(s, " \t")
}
126
127// nextToken returns the leading RFC 2616 token of s and the string following
128// the token.
129func nextToken(s string) (token, rest string) {
130	i := 0
131	for ; i < len(s); i++ {
132		if !isTokenOctet[s[i]] {
133			break
134		}
135	}
136	return s[:i], s[i:]
137}
138
139// nextTokenOrQuoted returns the leading token or quoted string per RFC 2616
140// and the string following the token or quoted string.
141func nextTokenOrQuoted(s string) (value string, rest string) {
142	if !strings.HasPrefix(s, "\"") {
143		return nextToken(s)
144	}
145	s = s[1:]
146	for i := 0; i < len(s); i++ {
147		switch s[i] {
148		case '"':
149			return s[:i], s[i+1:]
150		case '\\':
151			p := make([]byte, len(s)-1)
152			j := copy(p, s[:i])
153			escape := true
154			for i = i + 1; i < len(s); i++ {
155				b := s[i]
156				switch {
157				case escape:
158					escape = false
159					p[j] = b
160					j++
161				case b == '\\':
162					escape = true
163				case b == '"':
164					return string(p[:j]), s[i+1:]
165				default:
166					p[j] = b
167					j++
168				}
169			}
170			return "", ""
171		}
172	}
173	return "", ""
174}
175
// equalASCIIFold reports whether s and t are equal under the RFC 4790
// "i;ascii-casemap" comparison. The strings are compared octet by octet, and
// only the ASCII letters 'A'-'Z' are folded to 'a'-'z'. Every other byte,
// including non-ASCII bytes and invalid UTF-8, must match exactly.
//
// The comparison is byte-wise rather than rune-wise on purpose. Decoding runes
// would map every invalid UTF-8 byte to utf8.RuneError, making distinct byte
// sequences such as "\xff" and "\xfe" compare equal. For valid UTF-8 the
// byte-wise result is identical to a rune-wise one, because multi-byte
// sequences contain no ASCII bytes.
func equalASCIIFold(s, t string) bool {
	// ASCII folding never changes length, so differing lengths can't match.
	if len(s) != len(t) {
		return false
	}
	for i := 0; i < len(s); i++ {
		a, b := s[i], t[i]
		if a == b {
			continue
		}
		// Bytes outside ASCII are never case-folded, so a mismatch there is final.
		if a >= utf8.RuneSelf || b >= utf8.RuneSelf {
			return false
		}
		if 'A' <= a && a <= 'Z' {
			a += 'a' - 'A'
		}
		if 'A' <= b && b <= 'Z' {
			b += 'a' - 'A'
		}
		if a != b {
			return false
		}
	}
	return true
}
199
200// tokenListContainsValue returns true if the 1#token header with the given
201// name contains a token equal to value with ASCII case folding.
202func tokenListContainsValue(header http.Header, name string, value string) bool {
203headers:
204	for _, s := range header[name] {
205		for {
206			var t string
207			t, s = nextToken(skipSpace(s))
208			if t == "" {
209				continue headers
210			}
211			s = skipSpace(s)
212			if s != "" && s[0] != ',' {
213				continue headers
214			}
215			if equalASCIIFold(t, value) {
216				return true
217			}
218			if s == "" {
219				continue headers
220			}
221			s = s[1:]
222		}
223	}
224	return false
225}
226
227// parseExtensions parses WebSocket extensions from a header.
228func parseExtensions(header http.Header) []map[string]string {
229	// From RFC 6455:
230	//
231	//  Sec-WebSocket-Extensions = extension-list
232	//  extension-list = 1#extension
233	//  extension = extension-token *( ";" extension-param )
234	//  extension-token = registered-token
235	//  registered-token = token
236	//  extension-param = token [ "=" (token | quoted-string) ]
237	//     ;When using the quoted-string syntax variant, the value
238	//     ;after quoted-string unescaping MUST conform to the
239	//     ;'token' ABNF.
240
241	var result []map[string]string
242headers:
243	for _, s := range header["Sec-Websocket-Extensions"] {
244		for {
245			var t string
246			t, s = nextToken(skipSpace(s))
247			if t == "" {
248				continue headers
249			}
250			ext := map[string]string{"": t}
251			for {
252				s = skipSpace(s)
253				if !strings.HasPrefix(s, ";") {
254					break
255				}
256				var k string
257				k, s = nextToken(skipSpace(s[1:]))
258				if k == "" {
259					continue headers
260				}
261				s = skipSpace(s)
262				var v string
263				if strings.HasPrefix(s, "=") {
264					v, s = nextTokenOrQuoted(skipSpace(s[1:]))
265					s = skipSpace(s)
266				}
267				if s != "" && s[0] != ',' && s[0] != ';' {
268					continue headers
269				}
270				ext[k] = v
271			}
272			if s != "" && s[0] != ',' {
273				continue headers
274			}
275			result = append(result, ext)
276			if s == "" {
277				continue headers
278			}
279			s = s[1:]
280		}
281	}
282	return result
283}
284