1// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package websocket 6 7import ( 8 "bufio" 9 "errors" 10 "net" 11 "net/http" 12 "net/url" 13 "strings" 14 "time" 15) 16 17// HandshakeError describes an error with the handshake from the peer. 18type HandshakeError struct { 19 message string 20} 21 22func (e HandshakeError) Error() string { return e.message } 23 24// Upgrader specifies parameters for upgrading an HTTP connection to a 25// WebSocket connection. 26type Upgrader struct { 27 // HandshakeTimeout specifies the duration for the handshake to complete. 28 HandshakeTimeout time.Duration 29 30 // ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer 31 // size is zero, then buffers allocated by the HTTP server are used. The 32 // I/O buffer sizes do not limit the size of the messages that can be sent 33 // or received. 34 ReadBufferSize, WriteBufferSize int 35 36 // Subprotocols specifies the server's supported protocols in order of 37 // preference. If this field is set, then the Upgrade method negotiates a 38 // subprotocol by selecting the first match in this list with a protocol 39 // requested by the client. 40 Subprotocols []string 41 42 // Error specifies the function for generating HTTP error responses. If Error 43 // is nil, then http.Error is used to generate the HTTP response. 44 Error func(w http.ResponseWriter, r *http.Request, status int, reason error) 45 46 // CheckOrigin returns true if the request Origin header is acceptable. If 47 // CheckOrigin is nil, the host in the Origin header must not be set or 48 // must match the host of the request. 49 CheckOrigin func(r *http.Request) bool 50 51 // EnableCompression specify if the server should attempt to negotiate per 52 // message compression (RFC 7692). Setting this value to true does not 53 // guarantee that compression will be supported. Currently only "no context 54 // takeover" modes are supported. 55 EnableCompression bool 56} 57 58func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { 59 err := HandshakeError{reason} 60 if u.Error != nil { 61 u.Error(w, r, status, err) 62 } else { 63 w.Header().Set("Sec-Websocket-Version", "13") 64 http.Error(w, http.StatusText(status), status) 65 } 66 return nil, err 67} 68 69// checkSameOrigin returns true if the origin is not set or is equal to the request host. 70func checkSameOrigin(r *http.Request) bool { 71 origin := r.Header["Origin"] 72 if len(origin) == 0 { 73 return true 74 } 75 u, err := url.Parse(origin[0]) 76 if err != nil { 77 return false 78 } 79 return u.Host == r.Host 80} 81 82func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { 83 if u.Subprotocols != nil { 84 clientProtocols := Subprotocols(r) 85 for _, serverProtocol := range u.Subprotocols { 86 for _, clientProtocol := range clientProtocols { 87 if clientProtocol == serverProtocol { 88 return clientProtocol 89 } 90 } 91 } 92 } else if responseHeader != nil { 93 return responseHeader.Get("Sec-Websocket-Protocol") 94 } 95 return "" 96} 97 98// Upgrade upgrades the HTTP server connection to the WebSocket protocol. 99// 100// The responseHeader is included in the response to the client's upgrade 101// request. Use the responseHeader to specify cookies (Set-Cookie) and the 102// application negotiated subprotocol (Sec-Websocket-Protocol). 103// 104// If the upgrade fails, then Upgrade replies to the client with an HTTP error 105// response. 106func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { 107 if r.Method != "GET" { 108 return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: not a websocket handshake: request method is not GET") 109 } 110 111 if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok { 112 return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-Websocket-Extensions' headers are unsupported") 113 } 114 115 if !tokenListContainsValue(r.Header, "Connection", "upgrade") { 116 return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'upgrade' token not found in 'Connection' header") 117 } 118 119 if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { 120 return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'websocket' token not found in 'Upgrade' header") 121 } 122 123 if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") { 124 return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header") 125 } 126 127 checkOrigin := u.CheckOrigin 128 if checkOrigin == nil { 129 checkOrigin = checkSameOrigin 130 } 131 if !checkOrigin(r) { 132 return u.returnError(w, r, http.StatusForbidden, "websocket: 'Origin' header value not allowed") 133 } 134 135 challengeKey := r.Header.Get("Sec-Websocket-Key") 136 if challengeKey == "" { 137 return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-Websocket-Key' header is missing or blank") 138 } 139 140 subprotocol := u.selectSubprotocol(r, responseHeader) 141 142 // Negotiate PMCE 143 var compress bool 144 if u.EnableCompression { 145 for _, ext := range parseExtensions(r.Header) { 146 if ext[""] != "permessage-deflate" { 147 continue 148 } 149 compress = true 150 break 151 } 152 } 153 154 var ( 155 netConn net.Conn 156 err error 157 ) 158 159 h, ok := w.(http.Hijacker) 160 if !ok { 161 return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") 162 } 163 var brw *bufio.ReadWriter 164 netConn, brw, err = h.Hijack() 165 if err != nil { 166 return u.returnError(w, r, http.StatusInternalServerError, err.Error()) 167 } 168 169 if brw.Reader.Buffered() > 0 { 170 netConn.Close() 171 return nil, errors.New("websocket: client sent data before handshake is complete") 172 } 173 174 c := newConnBRW(netConn, true, u.ReadBufferSize, u.WriteBufferSize, brw) 175 c.subprotocol = subprotocol 176 177 if compress { 178 c.newCompressionWriter = compressNoContextTakeover 179 c.newDecompressionReader = decompressNoContextTakeover 180 } 181 182 p := c.writeBuf[:0] 183 p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) 184 p = append(p, computeAcceptKey(challengeKey)...) 185 p = append(p, "\r\n"...) 186 if c.subprotocol != "" { 187 p = append(p, "Sec-Websocket-Protocol: "...) 188 p = append(p, c.subprotocol...) 189 p = append(p, "\r\n"...) 190 } 191 if compress { 192 p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) 193 } 194 for k, vs := range responseHeader { 195 if k == "Sec-Websocket-Protocol" { 196 continue 197 } 198 for _, v := range vs { 199 p = append(p, k...) 200 p = append(p, ": "...) 201 for i := 0; i < len(v); i++ { 202 b := v[i] 203 if b <= 31 { 204 // prevent response splitting. 205 b = ' ' 206 } 207 p = append(p, b) 208 } 209 p = append(p, "\r\n"...) 210 } 211 } 212 p = append(p, "\r\n"...) 213 214 // Clear deadlines set by HTTP server. 215 netConn.SetDeadline(time.Time{}) 216 217 if u.HandshakeTimeout > 0 { 218 netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) 219 } 220 if _, err = netConn.Write(p); err != nil { 221 netConn.Close() 222 return nil, err 223 } 224 if u.HandshakeTimeout > 0 { 225 netConn.SetWriteDeadline(time.Time{}) 226 } 227 228 return c, nil 229} 230 231// Upgrade upgrades the HTTP server connection to the WebSocket protocol. 232// 233// This function is deprecated, use websocket.Upgrader instead. 234// 235// The application is responsible for checking the request origin before 236// calling Upgrade. An example implementation of the same origin policy is: 237// 238// if req.Header.Get("Origin") != "http://"+req.Host { 239// http.Error(w, "Origin not allowed", 403) 240// return 241// } 242// 243// If the endpoint supports subprotocols, then the application is responsible 244// for negotiating the protocol used on the connection. Use the Subprotocols() 245// function to get the subprotocols requested by the client. Use the 246// Sec-Websocket-Protocol response header to specify the subprotocol selected 247// by the application. 248// 249// The responseHeader is included in the response to the client's upgrade 250// request. Use the responseHeader to specify cookies (Set-Cookie) and the 251// negotiated subprotocol (Sec-Websocket-Protocol). 252// 253// The connection buffers IO to the underlying network connection. The 254// readBufSize and writeBufSize parameters specify the size of the buffers to 255// use. Messages can be larger than the buffers. 256// 257// If the request is not a valid WebSocket handshake, then Upgrade returns an 258// error of type HandshakeError. Applications should handle this error by 259// replying to the client with an HTTP error response. 260func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) { 261 u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize} 262 u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) { 263 // don't return errors to maintain backwards compatibility 264 } 265 u.CheckOrigin = func(r *http.Request) bool { 266 // allow all connections by default 267 return true 268 } 269 return u.Upgrade(w, r, responseHeader) 270} 271 272// Subprotocols returns the subprotocols requested by the client in the 273// Sec-Websocket-Protocol header. 274func Subprotocols(r *http.Request) []string { 275 h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol")) 276 if h == "" { 277 return nil 278 } 279 protocols := strings.Split(h, ",") 280 for i := range protocols { 281 protocols[i] = strings.TrimSpace(protocols[i]) 282 } 283 return protocols 284} 285 286// IsWebSocketUpgrade returns true if the client requested upgrade to the 287// WebSocket protocol. 288func IsWebSocketUpgrade(r *http.Request) bool { 289 return tokenListContainsValue(r.Header, "Connection", "upgrade") && 290 tokenListContainsValue(r.Header, "Upgrade", "websocket") 291} 292