1package centrifuge
2
3import (
4	"fmt"
5	"net/http"
6	"net/url"
7	"strings"
8	"sync"
9	"time"
10
11	"github.com/centrifugal/centrifuge/internal/cancelctx"
12	"github.com/centrifugal/centrifuge/internal/timers"
13
14	"github.com/centrifugal/protocol"
15	"github.com/gorilla/websocket"
16)
17
// transportWebsocket is the transport name returned by websocketTransport.Name
// and passed to incTransportConnect and log entries for WebSocket connections.
const transportWebsocket = "websocket"
21
22// websocketTransport is a wrapper struct over websocket connection to fit session
23// interface so client will accept it.
24type websocketTransport struct {
25	mu        sync.RWMutex
26	conn      *websocket.Conn
27	closed    bool
28	closeCh   chan struct{}
29	graceCh   chan struct{}
30	opts      websocketTransportOptions
31	pingTimer *time.Timer
32}
33
34type websocketTransportOptions struct {
35	protoType          ProtocolType
36	pingInterval       time.Duration
37	writeTimeout       time.Duration
38	compressionMinSize int
39}
40
41func newWebsocketTransport(conn *websocket.Conn, opts websocketTransportOptions, graceCh chan struct{}) *websocketTransport {
42	transport := &websocketTransport{
43		conn:    conn,
44		closeCh: make(chan struct{}),
45		graceCh: graceCh,
46		opts:    opts,
47	}
48	if opts.pingInterval > 0 {
49		transport.addPing()
50	}
51	return transport
52}
53
54func (t *websocketTransport) ping() {
55	select {
56	case <-t.closeCh:
57		return
58	default:
59		deadline := time.Now().Add(t.opts.pingInterval / 2)
60		err := t.conn.WriteControl(websocket.PingMessage, nil, deadline)
61		if err != nil {
62			_ = t.Close(DisconnectWriteError)
63			return
64		}
65		t.addPing()
66	}
67}
68
69func (t *websocketTransport) addPing() {
70	t.mu.Lock()
71	if t.closed {
72		t.mu.Unlock()
73		return
74	}
75	t.pingTimer = time.AfterFunc(t.opts.pingInterval, t.ping)
76	t.mu.Unlock()
77}
78
79// Name returns name of transport.
80func (t *websocketTransport) Name() string {
81	return transportWebsocket
82}
83
84// Protocol returns transport protocol.
85func (t *websocketTransport) Protocol() ProtocolType {
86	return t.opts.protoType
87}
88
89// Unidirectional returns whether transport is unidirectional.
90func (t *websocketTransport) Unidirectional() bool {
91	return false
92}
93
94// DisabledPushFlags ...
95func (t *websocketTransport) DisabledPushFlags() uint64 {
96	return PushFlagDisconnect
97}
98
99func (t *websocketTransport) writeData(data []byte) error {
100	if t.opts.compressionMinSize > 0 {
101		t.conn.EnableWriteCompression(len(data) > t.opts.compressionMinSize)
102	}
103	var messageType = websocket.TextMessage
104	if t.Protocol() == ProtocolTypeProtobuf {
105		messageType = websocket.BinaryMessage
106	}
107	if t.opts.writeTimeout > 0 {
108		_ = t.conn.SetWriteDeadline(time.Now().Add(t.opts.writeTimeout))
109	}
110	err := t.conn.WriteMessage(messageType, data)
111	if err != nil {
112		return err
113	}
114	if t.opts.writeTimeout > 0 {
115		_ = t.conn.SetWriteDeadline(time.Time{})
116	}
117	return nil
118}
119
120// Write data to transport.
121func (t *websocketTransport) Write(message []byte) error {
122	select {
123	case <-t.closeCh:
124		return nil
125	default:
126		protoType := t.Protocol().toProto()
127		if protoType == protocol.TypeJSON {
128			// Fast path for one JSON message.
129			return t.writeData(message)
130		}
131		encoder := protocol.GetDataEncoder(protoType)
132		defer protocol.PutDataEncoder(protoType, encoder)
133		_ = encoder.Encode(message)
134		return t.writeData(encoder.Finish())
135	}
136}
137
138// WriteMany data to transport.
139func (t *websocketTransport) WriteMany(messages ...[]byte) error {
140	select {
141	case <-t.closeCh:
142		return nil
143	default:
144		protoType := t.Protocol().toProto()
145		encoder := protocol.GetDataEncoder(protoType)
146		defer protocol.PutDataEncoder(protoType, encoder)
147		for i := range messages {
148			_ = encoder.Encode(messages[i])
149		}
150		return t.writeData(encoder.Finish())
151	}
152}
153
// closeFrameWait bounds how long Close waits for the closing handshake, and how long the reader keeps draining frames after the read loop ends.
const closeFrameWait time.Duration = 5 * time.Second
155
156// Close closes transport.
157func (t *websocketTransport) Close(disconnect *Disconnect) error {
158	t.mu.Lock()
159	if t.closed {
160		t.mu.Unlock()
161		return nil
162	}
163	t.closed = true
164	if t.pingTimer != nil {
165		t.pingTimer.Stop()
166	}
167	close(t.closeCh)
168	t.mu.Unlock()
169
170	if disconnect != nil {
171		msg := websocket.FormatCloseMessage(int(disconnect.Code), disconnect.CloseText())
172		err := t.conn.WriteControl(websocket.CloseMessage, msg, time.Now().Add(time.Second))
173		if err != nil {
174			return t.conn.Close()
175		}
176		select {
177		case <-t.graceCh:
178		default:
179			// Wait for closing handshake completion.
180			tm := timers.AcquireTimer(closeFrameWait)
181			select {
182			case <-t.graceCh:
183			case <-tm.C:
184			}
185			timers.ReleaseTimer(tm)
186		}
187		return t.conn.Close()
188	}
189	return t.conn.Close()
190}
191
// Defaults applied by WebsocketHandler.ServeHTTP when the matching
// WebsocketConfig field is left at zero.
const (
	// DefaultWebsocketPingInterval is used when WebsocketConfig.PingInterval is 0.
	DefaultWebsocketPingInterval = 25 * time.Second

	// DefaultWebsocketWriteTimeout is used when WebsocketConfig.WriteTimeout is 0.
	DefaultWebsocketWriteTimeout = time.Second

	// DefaultWebsocketMessageSizeLimit (64 KiB) is used when
	// WebsocketConfig.MessageSizeLimit is 0.
	DefaultWebsocketMessageSizeLimit = 64 * 1024
)
198
// WebsocketConfig represents config for WebsocketHandler.
type WebsocketConfig struct {
	// CompressionLevel sets a level for websocket compression.
	// See possible value description at https://golang.org/pkg/compress/flate/#NewWriter
	CompressionLevel int

	// CompressionMinSize allows to set minimal limit in bytes for
	// message to use compression when writing it into client connection.
	// When positive, only messages longer than CompressionMinSize bytes
	// are compressed. By default it's 0 - i.e. all messages will be
	// compressed when Compression is enabled and compression is negotiated
	// with client.
	CompressionMinSize int

	// ReadBufferSize is a parameter that is used for raw websocket Upgrader.
	// If set to zero reasonable default value will be used.
	ReadBufferSize int

	// WriteBufferSize is a parameter that is used for raw websocket Upgrader.
	// If set to zero reasonable default value will be used.
	// It is ignored when UseWriteBufferPool is enabled.
	WriteBufferSize int

	// MessageSizeLimit sets the maximum size in bytes of allowed message from client.
	// By default DefaultWebsocketMessageSizeLimit will be used. A negative
	// value disables the limit.
	MessageSizeLimit int

	// CheckOrigin func to provide custom origin check logic.
	// nil means that sameHostOriginCheck function will be used which
	// expects Origin host to match request Host.
	CheckOrigin func(r *http.Request) bool

	// PingInterval sets interval server will send ping messages to clients.
	// By default DefaultWebsocketPingInterval will be used. A negative value
	// disables server pings (and the pong-based read deadline).
	PingInterval time.Duration

	// WriteTimeout is maximum time of write message operation.
	// Slow client will be disconnected.
	// By default DefaultWebsocketWriteTimeout will be used. A negative value
	// disables the write deadline.
	WriteTimeout time.Duration

	// Compression allows to enable websocket permessage-deflate
	// compression support for raw websocket connections. It does
	// not guarantee that compression will be used - i.e. it only
	// says that server will try to negotiate it with client.
	Compression bool

	// UseWriteBufferPool enables using a buffer pool, shared by all
	// handlers, for writes. When enabled, WriteBufferSize is ignored.
	UseWriteBufferPool bool
}
246
247// WebsocketHandler handles WebSocket client connections. WebSocket protocol
248// is a bidirectional connection between a client an a server for low-latency
249// communication.
250type WebsocketHandler struct {
251	node    *Node
252	upgrade *websocket.Upgrader
253	config  WebsocketConfig
254}
255
// writeBufferPool is the write buffer pool shared by every WebsocketHandler created with UseWriteBufferPool enabled.
var writeBufferPool = new(sync.Pool)
257
258// NewWebsocketHandler creates new WebsocketHandler.
259func NewWebsocketHandler(n *Node, c WebsocketConfig) *WebsocketHandler {
260	upgrade := &websocket.Upgrader{
261		ReadBufferSize:    c.ReadBufferSize,
262		EnableCompression: c.Compression,
263		Subprotocols:      []string{"centrifuge-protobuf"},
264	}
265	if c.UseWriteBufferPool {
266		upgrade.WriteBufferPool = writeBufferPool
267	} else {
268		upgrade.WriteBufferSize = c.WriteBufferSize
269	}
270	if c.CheckOrigin != nil {
271		upgrade.CheckOrigin = c.CheckOrigin
272	} else {
273		upgrade.CheckOrigin = sameHostOriginCheck(n)
274	}
275	return &WebsocketHandler{
276		node:    n,
277		config:  c,
278		upgrade: upgrade,
279	}
280}
281
282func (s *WebsocketHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
283	incTransportConnect(transportWebsocket)
284
285	compression := s.config.Compression
286	compressionLevel := s.config.CompressionLevel
287	compressionMinSize := s.config.CompressionMinSize
288
289	conn, err := s.upgrade.Upgrade(rw, r, nil)
290	if err != nil {
291		s.node.logger.log(newLogEntry(LogLevelDebug, "websocket upgrade error", map[string]interface{}{"error": err.Error()}))
292		return
293	}
294
295	if compression {
296		err := conn.SetCompressionLevel(compressionLevel)
297		if err != nil {
298			s.node.logger.log(newLogEntry(LogLevelError, "websocket error setting compression level", map[string]interface{}{"error": err.Error()}))
299		}
300	}
301
302	pingInterval := s.config.PingInterval
303	if pingInterval == 0 {
304		pingInterval = DefaultWebsocketPingInterval
305	}
306	writeTimeout := s.config.WriteTimeout
307	if writeTimeout == 0 {
308		writeTimeout = DefaultWebsocketWriteTimeout
309	}
310	messageSizeLimit := s.config.MessageSizeLimit
311	if messageSizeLimit == 0 {
312		messageSizeLimit = DefaultWebsocketMessageSizeLimit
313	}
314
315	if messageSizeLimit > 0 {
316		conn.SetReadLimit(int64(messageSizeLimit))
317	}
318	if pingInterval > 0 {
319		pongWait := pingInterval * 10 / 9
320		_ = conn.SetReadDeadline(time.Now().Add(pongWait))
321		conn.SetPongHandler(func(string) error {
322			_ = conn.SetReadDeadline(time.Now().Add(pongWait))
323			return nil
324		})
325	}
326
327	var protoType = ProtocolTypeJSON
328
329	subProtocol := conn.Subprotocol()
330	if subProtocol == "centrifuge-protobuf" {
331		protoType = ProtocolTypeProtobuf
332	} else {
333		// This is a deprecated way to get a protocol type.
334		if r.URL.Query().Get("format") == "protobuf" || r.URL.Query().Get("protocol") == "protobuf" {
335			protoType = ProtocolTypeProtobuf
336		}
337	}
338
339	// Separate goroutine for better GC of caller's data.
340	go func() {
341		opts := websocketTransportOptions{
342			pingInterval:       pingInterval,
343			writeTimeout:       writeTimeout,
344			compressionMinSize: compressionMinSize,
345			protoType:          protoType,
346		}
347
348		graceCh := make(chan struct{})
349		transport := newWebsocketTransport(conn, opts, graceCh)
350
351		select {
352		case <-s.node.NotifyShutdown():
353			_ = transport.Close(DisconnectShutdown)
354			return
355		default:
356		}
357
358		ctxCh := make(chan struct{})
359		defer close(ctxCh)
360
361		c, closeFn, err := NewClient(cancelctx.New(r.Context(), ctxCh), s.node, transport)
362		if err != nil {
363			s.node.logger.log(newLogEntry(LogLevelError, "error creating client", map[string]interface{}{"transport": transportWebsocket}))
364			return
365		}
366		defer func() { _ = closeFn() }()
367
368		s.node.logger.log(newLogEntry(LogLevelDebug, "client connection established", map[string]interface{}{"client": c.ID(), "transport": transportWebsocket}))
369		defer func(started time.Time) {
370			s.node.logger.log(newLogEntry(LogLevelDebug, "client connection completed", map[string]interface{}{"client": c.ID(), "transport": transportWebsocket, "duration": time.Since(started)}))
371		}(time.Now())
372
373		for {
374			_, data, err := conn.ReadMessage()
375			if err != nil {
376				break
377			}
378			closed := !c.Handle(data)
379			if closed {
380				break
381			}
382		}
383
384		// https://github.com/gorilla/websocket/issues/448
385		conn.SetPingHandler(nil)
386		conn.SetPongHandler(nil)
387		conn.SetCloseHandler(nil)
388		_ = conn.SetReadDeadline(time.Now().Add(closeFrameWait))
389		for {
390			if _, _, err := conn.NextReader(); err != nil {
391				close(graceCh)
392				break
393			}
394		}
395	}()
396}
397
398func sameHostOriginCheck(n *Node) func(r *http.Request) bool {
399	return func(r *http.Request) bool {
400		err := checkSameHost(r)
401		if err != nil {
402			n.logger.log(newLogEntry(LogLevelInfo, "origin check failure", map[string]interface{}{"error": err.Error()}))
403			return false
404		}
405		return true
406	}
407}
408
// checkSameHost accepts a request with no Origin header, or one whose Origin
// host equals the request Host (case-insensitively). It returns an error when
// the Origin header cannot be parsed or names a different host.
func checkSameHost(r *http.Request) error {
	origin := r.Header.Get("Origin")
	if len(origin) == 0 {
		return nil
	}

	parsed, parseErr := url.Parse(origin)
	switch {
	case parseErr != nil:
		return fmt.Errorf("failed to parse Origin header %q: %w", origin, parseErr)
	case !strings.EqualFold(r.Host, parsed.Host):
		return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)
	default:
		return nil
	}
}
423