1package centrifuge
2
3import (
4	"net/http"
5	"sync"
6	"time"
7
8	"github.com/centrifugal/centrifuge/internal/cancelctx"
9
10	"github.com/centrifugal/protocol"
11	"github.com/gorilla/websocket"
12	"github.com/igm/sockjs-go/v3/sockjs"
13)
14
// transportSockJS is the transport name reported for SockJS connections.
const transportSockJS = "sockjs"
18
// sockjsTransport adapts a sockjs.Session to the transport methods used by
// this package (Name, Protocol, Write, WriteMany, Close, ...).
type sockjsTransport struct {
	// mu guards closed; Close takes the write lock.
	mu      sync.RWMutex
	// closed is set by the first Close call so later calls become no-ops.
	closed  bool
	// closeCh is closed exactly once by Close; Write and WriteMany check it
	// and drop messages after the transport has been closed.
	closeCh chan struct{}
	// session is the underlying SockJS session that data is sent over.
	session sockjs.Session
}
25
26func newSockjsTransport(s sockjs.Session) *sockjsTransport {
27	t := &sockjsTransport{
28		session: s,
29		closeCh: make(chan struct{}),
30	}
31	return t
32}
33
34// Name returns name of transport.
35func (t *sockjsTransport) Name() string {
36	return transportSockJS
37}
38
39// Protocol returns transport protocol.
40func (t *sockjsTransport) Protocol() ProtocolType {
41	return ProtocolTypeJSON
42}
43
44// Unidirectional returns whether transport is unidirectional.
45func (t *sockjsTransport) Unidirectional() bool {
46	return false
47}
48
49// DisabledPushFlags ...
50func (t *sockjsTransport) DisabledPushFlags() uint64 {
51	if !t.Unidirectional() {
52		return PushFlagDisconnect
53	}
54	return 0
55}
56
57// Write data to transport.
58func (t *sockjsTransport) Write(message []byte) error {
59	select {
60	case <-t.closeCh:
61		return nil
62	default:
63		// No need to use protocol encoders here since
64		// SockJS only supports JSON.
65		return t.session.Send(string(message))
66	}
67}
68
69// Write data to transport.
70func (t *sockjsTransport) WriteMany(messages ...[]byte) error {
71	select {
72	case <-t.closeCh:
73		return nil
74	default:
75		encoder := protocol.GetDataEncoder(ProtocolTypeJSON.toProto())
76		defer protocol.PutDataEncoder(ProtocolTypeJSON.toProto(), encoder)
77		for i := range messages {
78			_ = encoder.Encode(messages[i])
79		}
80		return t.session.Send(string(encoder.Finish()))
81	}
82}
83
84// Close closes transport.
85func (t *sockjsTransport) Close(disconnect *Disconnect) error {
86	t.mu.Lock()
87	if t.closed {
88		// Already closed, noop.
89		t.mu.Unlock()
90		return nil
91	}
92	t.closed = true
93	close(t.closeCh)
94	t.mu.Unlock()
95
96	if disconnect == nil {
97		disconnect = DisconnectNormal
98	}
99	return t.session.Close(disconnect.Code, disconnect.CloseText())
100}
101
// SockjsConfig represents config for SockJS handler created by
// NewSockjsHandler.
type SockjsConfig struct {
	// HandlerPrefix sets the URL path prefix under which the SockJS handler
	// serves its endpoints.
	HandlerPrefix string

	// URL is an address of the SockJS client javascript library. It should
	// point to the same SockJS library version the clients use: iframe-based
	// SockJS transports raise a version mismatch error otherwise.
	URL string

	// HeartbeatDelay sets how often to send heartbeat frames to clients.
	HeartbeatDelay time.Duration

	// CheckOrigin allows to decide whether to use CORS or not in the XHR case.
	// When it returns false, CORS headers won't be set. When nil, a same-host
	// origin check is used.
	CheckOrigin func(*http.Request) bool

	// WebsocketCheckOrigin allows to set a custom CheckOrigin func for the
	// underlying Gorilla Websocket based websocket.Upgrader. When nil, a
	// same-host origin check is used.
	WebsocketCheckOrigin func(*http.Request) bool

	// WebsocketReadBufferSize is the ReadBufferSize passed to the underlying
	// websocket.Upgrader. If set to zero reasonable default value will be used.
	WebsocketReadBufferSize int

	// WebsocketWriteBufferSize is the WriteBufferSize passed to the underlying
	// websocket.Upgrader. If set to zero reasonable default value will be used.
	WebsocketWriteBufferSize int

	// WebsocketUseWriteBufferPool enables using a shared buffer pool for
	// writes in the Websocket transport (set as the upgrader's
	// WriteBufferPool).
	WebsocketUseWriteBufferPool bool

	// WebsocketWriteTimeout is the maximum time of a write message operation.
	// Slow clients will be disconnected.
	// When zero, DefaultWebsocketWriteTimeout is used.
	WebsocketWriteTimeout time.Duration
}
137
// SockjsHandler accepts SockJS connections. SockJS has a bunch of fallback
// transports when WebSocket connection is not supported. It comes with additional
// costs though: small protocol framing overhead, lack of binary support, more
// goroutines per connection, and you need to use sticky session mechanism on
// your load balancer in case you are using HTTP-based SockJS fallbacks and have
// more than one Centrifuge Node on a backend (so SockJS to be able to emulate
// bidirectional protocol). So if you can afford it - use WebsocketHandler only.
type SockjsHandler struct {
	// node is the Centrifuge node new clients are attached to.
	node    *Node
	// config is the configuration the handler was created with.
	config  SockjsConfig
	// handler is the SockJS library handler that ServeHTTP delegates to.
	handler http.Handler
}
150
151// NewSockjsHandler creates new SockjsHandler.
152func NewSockjsHandler(n *Node, c SockjsConfig) *SockjsHandler {
153	options := sockjs.DefaultOptions
154	wsUpgrader := &websocket.Upgrader{
155		ReadBufferSize:  c.WebsocketReadBufferSize,
156		WriteBufferSize: c.WebsocketWriteBufferSize,
157		Error:           func(w http.ResponseWriter, r *http.Request, status int, reason error) {},
158	}
159	if c.WebsocketCheckOrigin != nil {
160		wsUpgrader.CheckOrigin = c.WebsocketCheckOrigin
161	} else {
162		wsUpgrader.CheckOrigin = sameHostOriginCheck(n)
163	}
164	if c.WebsocketUseWriteBufferPool {
165		wsUpgrader.WriteBufferPool = writeBufferPool
166	} else {
167		wsUpgrader.WriteBufferSize = c.WebsocketWriteBufferSize
168	}
169	options.WebsocketUpgrader = wsUpgrader
170	// Override sockjs url. It's important to use the same SockJS
171	// library version on client and server sides when using iframe
172	// based SockJS transports, otherwise SockJS will raise error
173	// about version mismatch.
174	options.SockJSURL = c.URL
175	if c.CheckOrigin != nil {
176		options.CheckOrigin = c.CheckOrigin
177	} else {
178		options.CheckOrigin = sameHostOriginCheck(n)
179	}
180
181	options.HeartbeatDelay = c.HeartbeatDelay
182	wsWriteTimeout := c.WebsocketWriteTimeout
183	if wsWriteTimeout == 0 {
184		wsWriteTimeout = DefaultWebsocketWriteTimeout
185	}
186	options.WebsocketWriteTimeout = wsWriteTimeout
187
188	s := &SockjsHandler{
189		node:   n,
190		config: c,
191	}
192
193	handler := newSockJSHandler(s, c.HandlerPrefix, options)
194	s.handler = handler
195	return s
196}
197
198func (s *SockjsHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
199	s.handler.ServeHTTP(rw, r)
200}
201
202// newSockJSHandler returns SockJS handler bind to sockjsPrefix url prefix.
203// SockJS handler has several handlers inside responsible for various tasks
204// according to SockJS protocol.
205func newSockJSHandler(s *SockjsHandler, sockjsPrefix string, sockjsOpts sockjs.Options) http.Handler {
206	return sockjs.NewHandler(sockjsPrefix, sockjsOpts, s.sockJSHandler)
207}
208
209// sockJSHandler called when new client connection comes to SockJS endpoint.
210func (s *SockjsHandler) sockJSHandler(sess sockjs.Session) {
211	incTransportConnect(transportSockJS)
212
213	// Separate goroutine for better GC of caller's data.
214	go func() {
215		transport := newSockjsTransport(sess)
216
217		select {
218		case <-s.node.NotifyShutdown():
219			_ = transport.Close(DisconnectShutdown)
220			return
221		default:
222		}
223
224		ctxCh := make(chan struct{})
225		defer close(ctxCh)
226		c, closeFn, err := NewClient(cancelctx.New(sess.Request().Context(), ctxCh), s.node, transport)
227		if err != nil {
228			s.node.logger.log(newLogEntry(LogLevelError, "error creating client", map[string]interface{}{"transport": transportSockJS}))
229			return
230		}
231		defer func() { _ = closeFn() }()
232		s.node.logger.log(newLogEntry(LogLevelDebug, "client connection established", map[string]interface{}{"client": c.ID(), "transport": transportSockJS}))
233		defer func(started time.Time) {
234			s.node.logger.log(newLogEntry(LogLevelDebug, "client connection completed", map[string]interface{}{"client": c.ID(), "transport": transportSockJS, "duration": time.Since(started)}))
235		}(time.Now())
236
237		var needWaitLoop bool
238
239		for {
240			if msg, err := sess.Recv(); err == nil {
241				if ok := c.Handle([]byte(msg)); !ok {
242					needWaitLoop = true
243					break
244				}
245				continue
246			}
247			break
248		}
249
250		if needWaitLoop {
251			// One extra loop till we get an error from session,
252			// this is required to wait until close frame will be sent
253			// into connection inside Client implementation and transport
254			// closed with proper disconnect reason.
255			for {
256				if _, err := sess.Recv(); err != nil {
257					break
258				}
259			}
260		}
261	}()
262}
263