1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package websocket implements a client and server for the WebSocket protocol
6// as specified in RFC 6455.
7//
8// This package currently lacks some features found in an alternative
9// and more actively maintained WebSocket package:
10//
11//     https://godoc.org/github.com/gorilla/websocket
12//
13package websocket // import "golang.org/x/net/websocket"
14
15import (
16	"bufio"
17	"crypto/tls"
18	"encoding/json"
19	"errors"
20	"io"
21	"io/ioutil"
22	"net"
23	"net/http"
24	"net/url"
25	"sync"
26	"time"
27)
28
const (
	// WebSocket protocol versions. ProtocolVersionHybi is an alias for
	// ProtocolVersionHybi13, and SupportedProtocolVersion is the same
	// version number in string form.
	ProtocolVersionHybi13    = 13
	ProtocolVersionHybi      = ProtocolVersionHybi13
	SupportedProtocolVersion = "13"

	// Frame payload type codes. TextFrame and BinaryFrame are the types the
	// Message codec reports for string and []byte values; UnknownFrame is
	// reported together with ErrNotSupported for any other value.
	ContinuationFrame = 0
	TextFrame         = 1
	BinaryFrame       = 2
	CloseFrame        = 8
	PingFrame         = 9
	PongFrame         = 10
	UnknownFrame      = 255

	// DefaultMaxPayloadBytes is the payload size limit applied by Codec's
	// Receive method when Conn.MaxPayloadBytes is zero.
	DefaultMaxPayloadBytes = 32 << 20 // 32MB
)
44
// ProtocolError is the error type used to report WebSocket protocol errors.
type ProtocolError struct {
	// ErrorString is the message returned by Error.
	ErrorString string
}

// Error implements the error interface by returning ErrorString unchanged.
func (e *ProtocolError) Error() string {
	return e.ErrorString
}
51
52var (
53	ErrBadProtocolVersion   = &ProtocolError{"bad protocol version"}
54	ErrBadScheme            = &ProtocolError{"bad scheme"}
55	ErrBadStatus            = &ProtocolError{"bad status"}
56	ErrBadUpgrade           = &ProtocolError{"missing or bad upgrade"}
57	ErrBadWebSocketOrigin   = &ProtocolError{"missing or bad WebSocket-Origin"}
58	ErrBadWebSocketLocation = &ProtocolError{"missing or bad WebSocket-Location"}
59	ErrBadWebSocketProtocol = &ProtocolError{"missing or bad WebSocket-Protocol"}
60	ErrBadWebSocketVersion  = &ProtocolError{"missing or bad WebSocket Version"}
61	ErrChallengeResponse    = &ProtocolError{"mismatch challenge/response"}
62	ErrBadFrame             = &ProtocolError{"bad frame"}
63	ErrBadFrameBoundary     = &ProtocolError{"not on frame boundary"}
64	ErrNotWebSocket         = &ProtocolError{"not websocket protocol"}
65	ErrBadRequestMethod     = &ProtocolError{"bad method"}
66	ErrNotSupported         = &ProtocolError{"not supported"}
67)
68
// ErrFrameTooLarge is returned by Codec's Receive method if the frame payload
// size exceeds the limit set by Conn.MaxPayloadBytes (or
// DefaultMaxPayloadBytes when that field is zero).
var ErrFrameTooLarge = errors.New("websocket: frame payload size exceeds limit")
72
// Addr wraps a *url.URL so that it can serve as a net.Addr for WebSocket
// endpoints.
type Addr struct {
	*url.URL
}

// Network reports the network name, which is always "websocket".
func (a *Addr) Network() string {
	const network = "websocket"
	return network
}
80
// Config is a WebSocket configuration.
type Config struct {
	// A WebSocket server address. Conn reports it as the remote address on
	// client connections and as the local address on server connections.
	Location *url.URL

	// A WebSocket client origin. Conn reports it as the local address on
	// client connections and as the remote address on server connections.
	Origin *url.URL

	// WebSocket subprotocols.
	Protocol []string

	// WebSocket protocol version.
	Version int

	// TLS config for secure WebSocket (wss).
	TlsConfig *tls.Config

	// Additional header fields to be sent in WebSocket opening handshake.
	Header http.Header

	// Dialer used when opening websocket connections.
	Dialer *net.Dialer

	// handshakeData is unexported state that is not touched in this part of
	// the file. NOTE(review): presumably values collected during the opening
	// handshake; confirm against the handshake code.
	handshakeData map[string]string
}
106
// serverHandshaker is an interface to handle the server side of the
// WebSocket opening handshake.
type serverHandshaker interface {
	// ReadHandshake reads the handshake request message from the client.
	// It returns an HTTP response code and an error, if any.
	ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error)

	// AcceptHandshake accepts the client handshake request and sends
	// the handshake response back to the client.
	AcceptHandshake(buf *bufio.Writer) (err error)

	// NewServerConn creates a new WebSocket connection.
	NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) (conn *Conn)
}
120
// frameReader is an interface to read a single WebSocket frame.
type frameReader interface {
	// Reader reads the payload of the frame. Conn.Read treats io.EOF from
	// this reader as the end of the current frame.
	io.Reader

	// PayloadType returns the payload type of the frame.
	PayloadType() byte

	// HeaderReader returns a reader to read the header of the frame.
	HeaderReader() io.Reader

	// TrailerReader returns a reader to read the trailer of the frame.
	// If it returns nil, there is no trailer in the frame.
	TrailerReader() io.Reader

	// Len returns the total length of the frame, including header and trailer.
	Len() int
}
139
// frameReaderFactory is an interface to create a new frame reader.
// Conn.Read and Codec.Receive call NewFrameReader each time they need the
// next frame from the connection.
type frameReaderFactory interface {
	NewFrameReader() (r frameReader, err error)
}
144
// frameWriter is an interface to write a WebSocket frame.
type frameWriter interface {
	// Writer writes the payload of the frame. Conn.Write and Codec.Send
	// call Close after writing the payload.
	io.WriteCloser
}
150
// frameWriterFactory is an interface to create a new frame writer for a
// frame of the given payload type.
type frameWriterFactory interface {
	NewFrameWriter(payloadType byte) (w frameWriter, err error)
}
155
// frameHandler is an interface to process frames as they are read and to
// write the close message for a connection.
type frameHandler interface {
	// HandleFrame is given each newly read frame and returns the frame whose
	// payload should be delivered to the caller. When it returns a nil
	// frame without an error, Conn.Read and Codec.Receive move on to the
	// next frame.
	HandleFrame(frame frameReader) (r frameReader, err error)

	// WriteClose is called by Conn.Close with the connection's default
	// close status before the underlying connection is closed.
	WriteClose(status int) (err error)
}
160
// Conn represents a WebSocket connection.
//
// Multiple goroutines may invoke methods on a Conn simultaneously.
type Conn struct {
	// config is returned by Config and supplies the addresses reported by
	// LocalAddr and RemoteAddr. request is returned by Request and is nil
	// on client connections.
	config  *Config
	request *http.Request

	// rwc is the underlying connection: Close closes it and the deadline
	// methods forward to it when it is a net.Conn. buf is not used in this
	// part of the file. NOTE(review): presumably the buffered view of rwc
	// used by the frame readers and writers; confirm.
	buf *bufio.ReadWriter
	rwc io.ReadWriteCloser

	// rio is held by Read and Codec.Receive while they use the frame reader
	// factory and the current frameReader. A non-nil frameReader is a frame
	// whose payload has not been fully consumed yet.
	rio sync.Mutex
	frameReaderFactory
	frameReader

	// wio is held by Write and Codec.Send while a frame is written.
	wio sync.Mutex
	frameWriterFactory

	// PayloadType is the payload type Write uses for the frames it creates.
	// defaultCloseStatus is the status Close passes to WriteClose.
	frameHandler
	PayloadType        byte
	defaultCloseStatus int

	// MaxPayloadBytes limits the size of frame payload received over Conn
	// by Codec's Receive method. If zero, DefaultMaxPayloadBytes is used.
	MaxPayloadBytes int
}
186
187// Read implements the io.Reader interface:
188// it reads data of a frame from the WebSocket connection.
189// if msg is not large enough for the frame data, it fills the msg and next Read
190// will read the rest of the frame data.
191// it reads Text frame or Binary frame.
192func (ws *Conn) Read(msg []byte) (n int, err error) {
193	ws.rio.Lock()
194	defer ws.rio.Unlock()
195again:
196	if ws.frameReader == nil {
197		frame, err := ws.frameReaderFactory.NewFrameReader()
198		if err != nil {
199			return 0, err
200		}
201		ws.frameReader, err = ws.frameHandler.HandleFrame(frame)
202		if err != nil {
203			return 0, err
204		}
205		if ws.frameReader == nil {
206			goto again
207		}
208	}
209	n, err = ws.frameReader.Read(msg)
210	if err == io.EOF {
211		if trailer := ws.frameReader.TrailerReader(); trailer != nil {
212			io.Copy(ioutil.Discard, trailer)
213		}
214		ws.frameReader = nil
215		goto again
216	}
217	return n, err
218}
219
220// Write implements the io.Writer interface:
221// it writes data as a frame to the WebSocket connection.
222func (ws *Conn) Write(msg []byte) (n int, err error) {
223	ws.wio.Lock()
224	defer ws.wio.Unlock()
225	w, err := ws.frameWriterFactory.NewFrameWriter(ws.PayloadType)
226	if err != nil {
227		return 0, err
228	}
229	n, err = w.Write(msg)
230	w.Close()
231	return n, err
232}
233
234// Close implements the io.Closer interface.
235func (ws *Conn) Close() error {
236	err := ws.frameHandler.WriteClose(ws.defaultCloseStatus)
237	err1 := ws.rwc.Close()
238	if err != nil {
239		return err
240	}
241	return err1
242}
243
244func (ws *Conn) IsClientConn() bool { return ws.request == nil }
245func (ws *Conn) IsServerConn() bool { return ws.request != nil }
246
247// LocalAddr returns the WebSocket Origin for the connection for client, or
248// the WebSocket location for server.
249func (ws *Conn) LocalAddr() net.Addr {
250	if ws.IsClientConn() {
251		return &Addr{ws.config.Origin}
252	}
253	return &Addr{ws.config.Location}
254}
255
256// RemoteAddr returns the WebSocket location for the connection for client, or
257// the Websocket Origin for server.
258func (ws *Conn) RemoteAddr() net.Addr {
259	if ws.IsClientConn() {
260		return &Addr{ws.config.Location}
261	}
262	return &Addr{ws.config.Origin}
263}
264
265var errSetDeadline = errors.New("websocket: cannot set deadline: not using a net.Conn")
266
267// SetDeadline sets the connection's network read & write deadlines.
268func (ws *Conn) SetDeadline(t time.Time) error {
269	if conn, ok := ws.rwc.(net.Conn); ok {
270		return conn.SetDeadline(t)
271	}
272	return errSetDeadline
273}
274
275// SetReadDeadline sets the connection's network read deadline.
276func (ws *Conn) SetReadDeadline(t time.Time) error {
277	if conn, ok := ws.rwc.(net.Conn); ok {
278		return conn.SetReadDeadline(t)
279	}
280	return errSetDeadline
281}
282
283// SetWriteDeadline sets the connection's network write deadline.
284func (ws *Conn) SetWriteDeadline(t time.Time) error {
285	if conn, ok := ws.rwc.(net.Conn); ok {
286		return conn.SetWriteDeadline(t)
287	}
288	return errSetDeadline
289}
290
291// Config returns the WebSocket config.
292func (ws *Conn) Config() *Config { return ws.config }
293
294// Request returns the http request upgraded to the WebSocket.
295// It is nil for client side.
296func (ws *Conn) Request() *http.Request { return ws.request }
297
// Codec represents a symmetric pair of functions that implement a codec.
//
// Marshal converts a value into the payload and payload type of the frame
// that Send writes. Unmarshal is given the payload and payload type of the
// frame read by Receive and stores the decoded result in v.
type Codec struct {
	Marshal   func(v interface{}) (data []byte, payloadType byte, err error)
	Unmarshal func(data []byte, payloadType byte, v interface{}) (err error)
}
303
304// Send sends v marshaled by cd.Marshal as single frame to ws.
305func (cd Codec) Send(ws *Conn, v interface{}) (err error) {
306	data, payloadType, err := cd.Marshal(v)
307	if err != nil {
308		return err
309	}
310	ws.wio.Lock()
311	defer ws.wio.Unlock()
312	w, err := ws.frameWriterFactory.NewFrameWriter(payloadType)
313	if err != nil {
314		return err
315	}
316	_, err = w.Write(data)
317	w.Close()
318	return err
319}
320
321// Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores
322// in v. The whole frame payload is read to an in-memory buffer; max size of
323// payload is defined by ws.MaxPayloadBytes. If frame payload size exceeds
324// limit, ErrFrameTooLarge is returned; in this case frame is not read off wire
325// completely. The next call to Receive would read and discard leftover data of
326// previous oversized frame before processing next frame.
327func (cd Codec) Receive(ws *Conn, v interface{}) (err error) {
328	ws.rio.Lock()
329	defer ws.rio.Unlock()
330	if ws.frameReader != nil {
331		_, err = io.Copy(ioutil.Discard, ws.frameReader)
332		if err != nil {
333			return err
334		}
335		ws.frameReader = nil
336	}
337again:
338	frame, err := ws.frameReaderFactory.NewFrameReader()
339	if err != nil {
340		return err
341	}
342	frame, err = ws.frameHandler.HandleFrame(frame)
343	if err != nil {
344		return err
345	}
346	if frame == nil {
347		goto again
348	}
349	maxPayloadBytes := ws.MaxPayloadBytes
350	if maxPayloadBytes == 0 {
351		maxPayloadBytes = DefaultMaxPayloadBytes
352	}
353	if hf, ok := frame.(*hybiFrameReader); ok && hf.header.Length > int64(maxPayloadBytes) {
354		// payload size exceeds limit, no need to call Unmarshal
355		//
356		// set frameReader to current oversized frame so that
357		// the next call to this function can drain leftover
358		// data before processing the next frame
359		ws.frameReader = frame
360		return ErrFrameTooLarge
361	}
362	payloadType := frame.PayloadType()
363	data, err := ioutil.ReadAll(frame)
364	if err != nil {
365		return err
366	}
367	return cd.Unmarshal(data, payloadType, v)
368}
369
370func marshal(v interface{}) (msg []byte, payloadType byte, err error) {
371	switch data := v.(type) {
372	case string:
373		return []byte(data), TextFrame, nil
374	case []byte:
375		return data, BinaryFrame, nil
376	}
377	return nil, UnknownFrame, ErrNotSupported
378}
379
380func unmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
381	switch data := v.(type) {
382	case *string:
383		*data = string(msg)
384		return nil
385	case *[]byte:
386		*data = msg
387		return nil
388	}
389	return ErrNotSupported
390}
391
/*
Message is a codec to send/receive text/binary data in a frame on a WebSocket connection.
To send/receive a text frame, use the string type.
To send/receive a binary frame, use the []byte type.
Sending a value of any other type, or receiving into anything other than a
*string or *[]byte, fails with ErrNotSupported.

Trivial usage:

	import "golang.org/x/net/websocket"

	// receive text frame
	var message string
	websocket.Message.Receive(ws, &message)

	// send text frame
	message = "hello"
	websocket.Message.Send(ws, message)

	// receive binary frame
	var data []byte
	websocket.Message.Receive(ws, &data)

	// send binary frame
	data = []byte{0, 1, 2}
	websocket.Message.Send(ws, data)

*/
var Message = Codec{marshal, unmarshal}
419
420func jsonMarshal(v interface{}) (msg []byte, payloadType byte, err error) {
421	msg, err = json.Marshal(v)
422	return msg, TextFrame, err
423}
424
// jsonUnmarshal is the Unmarshal half of the JSON codec. It decodes msg into
// v with json.Unmarshal; the payload type is not consulted.
func jsonUnmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
	err = json.Unmarshal(msg, v)
	return err
}
428
/*
JSON is a codec to send/receive JSON data in a frame from a WebSocket connection.
Values are encoded with encoding/json and always sent as text frames; received
payloads are decoded with encoding/json regardless of their payload type.

Trivial usage:

	import "golang.org/x/net/websocket"

	type T struct {
		Msg string
		Count int
	}

	// receive JSON type T
	var data T
	websocket.JSON.Receive(ws, &data)

	// send JSON type T
	websocket.JSON.Send(ws, data)
*/
var JSON = Codec{jsonMarshal, jsonUnmarshal}
449