1// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package websocket
6
7import (
8	"crypto/rand"
9	"crypto/sha1"
10	"encoding/base64"
11	"io"
12	"net/http"
13	"strings"
14	"unicode/utf8"
15)
16
// keyGUID is the fixed GUID from RFC 6455 that is concatenated with a
// client's challenge key to derive the handshake accept key.
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

// computeAcceptKey derives the Sec-WebSocket-Accept value for challengeKey:
// the standard base64 encoding of SHA-1(challengeKey || keyGUID).
func computeAcceptKey(challengeKey string) string {
	input := make([]byte, 0, len(challengeKey)+len(keyGUID))
	input = append(input, challengeKey...)
	input = append(input, keyGUID...)
	sum := sha1.Sum(input)
	return base64.StdEncoding.EncodeToString(sum[:])
}
25
// generateChallengeKey returns a fresh client challenge key: 16 bytes read
// from crypto/rand, encoded with standard base64. The only error it returns
// is one reported while reading from the random source.
func generateChallengeKey() (string, error) {
	var nonce [16]byte
	if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(nonce[:]), nil
}
33
// octetTypes maps every byte value to a bit set describing its role in the
// RFC 2616 grammar. The bits are the isXxxOctet constants below; the table is
// filled in by this file's init function.
var octetTypes [256]byte

// Bit flags stored in octetTypes.
const (
	isTokenOctet = 1 << 0 // byte may appear in an RFC 2616 token
	isSpaceOctet = 1 << 1 // byte is SP, HT, CR or LF
)
41
42func init() {
43	// From RFC 2616
44	//
45	// OCTET      = <any 8-bit sequence of data>
46	// CHAR       = <any US-ASCII character (octets 0 - 127)>
47	// CTL        = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
48	// CR         = <US-ASCII CR, carriage return (13)>
49	// LF         = <US-ASCII LF, linefeed (10)>
50	// SP         = <US-ASCII SP, space (32)>
51	// HT         = <US-ASCII HT, horizontal-tab (9)>
52	// <">        = <US-ASCII double-quote mark (34)>
53	// CRLF       = CR LF
54	// LWS        = [CRLF] 1*( SP | HT )
55	// TEXT       = <any OCTET except CTLs, but including LWS>
56	// separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <">
57	//              | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT
58	// token      = 1*<any CHAR except CTLs or separators>
59	// qdtext     = <any TEXT except <">>
60
61	for c := 0; c < 256; c++ {
62		var t byte
63		isCtl := c <= 31 || c == 127
64		isChar := 0 <= c && c <= 127
65		isSeparator := strings.IndexRune(" \t\"(),/:;<=>?@[]\\{}", rune(c)) >= 0
66		if strings.IndexRune(" \t\r\n", rune(c)) >= 0 {
67			t |= isSpaceOctet
68		}
69		if isChar && !isCtl && !isSeparator {
70			t |= isTokenOctet
71		}
72		octetTypes[c] = t
73	}
74}
75
76func skipSpace(s string) (rest string) {
77	i := 0
78	for ; i < len(s); i++ {
79		if octetTypes[s[i]]&isSpaceOctet == 0 {
80			break
81		}
82	}
83	return s[i:]
84}
85
86func nextToken(s string) (token, rest string) {
87	i := 0
88	for ; i < len(s); i++ {
89		if octetTypes[s[i]]&isTokenOctet == 0 {
90			break
91		}
92	}
93	return s[:i], s[i:]
94}
95
96func nextTokenOrQuoted(s string) (value string, rest string) {
97	if !strings.HasPrefix(s, "\"") {
98		return nextToken(s)
99	}
100	s = s[1:]
101	for i := 0; i < len(s); i++ {
102		switch s[i] {
103		case '"':
104			return s[:i], s[i+1:]
105		case '\\':
106			p := make([]byte, len(s)-1)
107			j := copy(p, s[:i])
108			escape := true
109			for i = i + 1; i < len(s); i++ {
110				b := s[i]
111				switch {
112				case escape:
113					escape = false
114					p[j] = b
115					j++
116				case b == '\\':
117					escape = true
118				case b == '"':
119					return string(p[:j]), s[i+1:]
120				default:
121					p[j] = b
122					j++
123				}
124			}
125			return "", ""
126		}
127	}
128	return "", ""
129}
130
// equalASCIIFold returns true if s is equal to t with ASCII case folding.
//
// Only the letters A-Z/a-z are folded. Every other byte, including each byte
// of a multi-byte or invalid UTF-8 sequence, must match exactly. Comparing
// bytes rather than decoded runes matters for invalid input: every invalid
// byte decodes to utf8.RuneError, so a rune-based comparison would treat
// "\xff" and "\xfe" (or "\xff" and the encoded U+FFFD) as equal.
func equalASCIIFold(s, t string) bool {
	// ASCII folding never changes length, so differing lengths cannot match.
	if len(s) != len(t) {
		return false
	}
	for i := 0; i < len(s); i++ {
		a, b := s[i], t[i]
		if a == b {
			continue
		}
		// Non-ASCII bytes are never folded.
		if a >= utf8.RuneSelf || b >= utf8.RuneSelf {
			return false
		}
		if 'A' <= a && a <= 'Z' {
			a += 'a' - 'A'
		}
		if 'A' <= b && b <= 'Z' {
			b += 'a' - 'A'
		}
		if a != b {
			return false
		}
	}
	return true
}
153
154// tokenListContainsValue returns true if the 1#token header with the given
155// name contains a token equal to value with ASCII case folding.
156func tokenListContainsValue(header http.Header, name string, value string) bool {
157headers:
158	for _, s := range header[name] {
159		for {
160			var t string
161			t, s = nextToken(skipSpace(s))
162			if t == "" {
163				continue headers
164			}
165			s = skipSpace(s)
166			if s != "" && s[0] != ',' {
167				continue headers
168			}
169			if equalASCIIFold(t, value) {
170				return true
171			}
172			if s == "" {
173				continue headers
174			}
175			s = s[1:]
176		}
177	}
178	return false
179}
180
181// parseExtensions parses WebSocket extensions from a header.
182func parseExtensions(header http.Header) []map[string]string {
183	// From RFC 6455:
184	//
185	//  Sec-WebSocket-Extensions = extension-list
186	//  extension-list = 1#extension
187	//  extension = extension-token *( ";" extension-param )
188	//  extension-token = registered-token
189	//  registered-token = token
190	//  extension-param = token [ "=" (token | quoted-string) ]
191	//     ;When using the quoted-string syntax variant, the value
192	//     ;after quoted-string unescaping MUST conform to the
193	//     ;'token' ABNF.
194
195	var result []map[string]string
196headers:
197	for _, s := range header["Sec-Websocket-Extensions"] {
198		for {
199			var t string
200			t, s = nextToken(skipSpace(s))
201			if t == "" {
202				continue headers
203			}
204			ext := map[string]string{"": t}
205			for {
206				s = skipSpace(s)
207				if !strings.HasPrefix(s, ";") {
208					break
209				}
210				var k string
211				k, s = nextToken(skipSpace(s[1:]))
212				if k == "" {
213					continue headers
214				}
215				s = skipSpace(s)
216				var v string
217				if strings.HasPrefix(s, "=") {
218					v, s = nextTokenOrQuoted(skipSpace(s[1:]))
219					s = skipSpace(s)
220				}
221				if s != "" && s[0] != ',' && s[0] != ';' {
222					continue headers
223				}
224				ext[k] = v
225			}
226			if s != "" && s[0] != ',' {
227				continue headers
228			}
229			result = append(result, ext)
230			if s == "" {
231				continue headers
232			}
233			s = s[1:]
234		}
235	}
236	return result
237}
238