1package main
2
3import (
4	"context"
5	"encoding/json"
6	"flag"
7	"fmt"
8	"log"
9	"net/http"
10	"net/url"
11	"os"
12	"os/signal"
13	"strconv"
14	"strings"
15	"sync"
16	"syscall"
17	"time"
18
19	"github.com/centrifugal/centrifuge/internal/cancelctx"
20	"github.com/gorilla/websocket"
21
22	_ "net/http/pprof"
23
24	"github.com/centrifugal/centrifuge"
25)
26
// Command-line flags:
//   - port: TCP port the HTTP server in main listens on (default 8000).
//   - redis: when set, main installs a Redis broker and Redis presence
//     manager on the node.
var (
	port  = flag.Int("port", 8000, "Port to bind app to")
	redis = flag.Bool("redis", false, "Use Redis")
)
31
32func handleLog(e centrifuge.LogEntry) {
33	log.Printf("%s: %v", e.Message, e.Fields)
34}
35
36func authMiddleware(h http.Handler) http.Handler {
37	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
38		ctx := r.Context()
39		newCtx := centrifuge.SetCredentials(ctx, &centrifuge.Credentials{
40			UserID: "42",
41		})
42		r = r.WithContext(newCtx)
43		h.ServeHTTP(w, r)
44	})
45}
46
47func waitExitSignal(n *centrifuge.Node) {
48	sigCh := make(chan os.Signal, 1)
49	done := make(chan bool, 1)
50	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
51	go func() {
52		<-sigCh
53		_ = n.Shutdown(context.Background())
54		done <- true
55	}()
56	<-done
57}
58
// exampleChannel is the channel every connecting client is subscribed to, the
// channel main publishes the server time into, and the channel targeted by
// the /subscribe and /unsubscribe HTTP handlers.
var exampleChannel = "unidirectional"
60
61func main() {
62	flag.Parse()
63
64	cfg := centrifuge.DefaultConfig
65	cfg.LogLevel = centrifuge.LogLevelDebug
66	cfg.LogHandler = handleLog
67
68	node, _ := centrifuge.New(cfg)
69
70	if *redis {
71		redisShardConfigs := []centrifuge.RedisShardConfig{
72			{Address: "localhost:6379"},
73		}
74		var redisShards []*centrifuge.RedisShard
75		for _, redisConf := range redisShardConfigs {
76			redisShard, err := centrifuge.NewRedisShard(node, redisConf)
77			if err != nil {
78				log.Fatal(err)
79			}
80			redisShards = append(redisShards, redisShard)
81		}
82		// Using Redis Broker here to scale nodes.
83		broker, err := centrifuge.NewRedisBroker(node, centrifuge.RedisBrokerConfig{
84			Shards: redisShards,
85		})
86		if err != nil {
87			log.Fatal(err)
88		}
89		node.SetBroker(broker)
90
91		presenceManager, err := centrifuge.NewRedisPresenceManager(node, centrifuge.RedisPresenceManagerConfig{
92			Shards: redisShards,
93		})
94		if err != nil {
95			log.Fatal(err)
96		}
97		node.SetPresenceManager(presenceManager)
98	}
99
100	node.OnConnecting(func(ctx context.Context, e centrifuge.ConnectEvent) (centrifuge.ConnectReply, error) {
101		return centrifuge.ConnectReply{
102			Subscriptions: map[string]centrifuge.SubscribeOptions{
103				exampleChannel: {},
104			},
105		}, nil
106	})
107
108	node.OnConnect(func(client *centrifuge.Client) {
109		client.OnUnsubscribe(func(e centrifuge.UnsubscribeEvent) {
110			log.Printf("user %s unsubscribed from %s", client.UserID(), e.Channel)
111		})
112		client.OnDisconnect(func(e centrifuge.DisconnectEvent) {
113			log.Printf("user %s disconnected, disconnect: %s", client.UserID(), e.Disconnect)
114		})
115		transport := client.Transport()
116		log.Printf("user %s connected via %s", client.UserID(), transport.Name())
117	})
118
119	// Publish to a channel periodically.
120	go func() {
121		for {
122			currentTime := strconv.FormatInt(time.Now().Unix(), 10)
123			_, err := node.Publish(exampleChannel, []byte(`{"server_time": "`+currentTime+`"}`))
124			if err != nil {
125				log.Println(err.Error())
126			}
127			time.Sleep(5 * time.Second)
128		}
129	}()
130
131	if err := node.Run(); err != nil {
132		log.Fatal(err)
133	}
134
135	websocketHandler := NewWebsocketHandler(node, WebsocketConfig{
136		ReadBufferSize:     1024,
137		UseWriteBufferPool: true,
138	})
139	http.Handle("/connection/websocket", authMiddleware(websocketHandler))
140	http.Handle("/subscribe", handleSubscribe(node))
141	http.Handle("/unsubscribe", handleUnsubscribe(node))
142	http.Handle("/", http.FileServer(http.Dir("./")))
143
144	go func() {
145		if err := http.ListenAndServe(":"+strconv.Itoa(*port), nil); err != nil {
146			log.Fatal(err)
147		}
148	}()
149
150	waitExitSignal(node)
151	log.Println("bye!")
152}
153
154func handleSubscribe(node *centrifuge.Node) http.HandlerFunc {
155	return func(w http.ResponseWriter, req *http.Request) {
156		clientID := req.URL.Query().Get("client")
157		if clientID == "" {
158			w.WriteHeader(http.StatusBadRequest)
159			return
160		}
161		err := node.Subscribe("42", exampleChannel, centrifuge.WithSubscribeClient(clientID))
162		if err != nil {
163			w.WriteHeader(http.StatusInternalServerError)
164			return
165		}
166		w.WriteHeader(http.StatusOK)
167	}
168}
169
170func handleUnsubscribe(node *centrifuge.Node) http.HandlerFunc {
171	return func(w http.ResponseWriter, req *http.Request) {
172		clientID := req.URL.Query().Get("client")
173		if clientID == "" {
174			w.WriteHeader(http.StatusBadRequest)
175			return
176		}
177		err := node.Unsubscribe("42", exampleChannel, centrifuge.WithUnsubscribeClient(clientID))
178		if err != nil {
179			w.WriteHeader(http.StatusInternalServerError)
180			return
181		}
182		w.WriteHeader(http.StatusOK)
183	}
184}
185
// websocketTransport is a wrapper struct over websocket connection to fit session
// interface so client will accept it.
//
// Fields:
//   - mu guards closed and pingTimer.
//   - writeMu serializes message writes issued by writeData.
//   - conn is the underlying gorilla websocket connection.
//   - closed is set once Close has run.
//   - closeCh is closed by Close; ping and WriteMany check it before writing.
//   - graceCh is supplied by the creator of the transport; none of the
//     transport's methods read it.
//   - opts holds the per-connection settings.
//   - pingTimer is the pending timer that will fire the next ping, if any.
type websocketTransport struct {
	mu        sync.RWMutex
	writeMu   sync.Mutex // sync general write with unidirectional ping write.
	conn      *websocket.Conn
	closed    bool
	closeCh   chan struct{}
	graceCh   chan struct{}
	opts      websocketTransportOptions
	pingTimer *time.Timer
}
198
// websocketTransportOptions carries the per-connection settings of a
// websocketTransport:
//   - protoType selects text (default) or binary (protobuf) websocket frames.
//   - pingInterval is the delay between pings; pings are scheduled only when
//     it is positive.
//   - writeTimeout bounds each message write; no deadline is set when it is
//     not positive.
//   - compressionMinSize, when positive, enables write compression only for
//     messages longer than this many bytes.
type websocketTransportOptions struct {
	protoType          centrifuge.ProtocolType
	pingInterval       time.Duration
	writeTimeout       time.Duration
	compressionMinSize int
}
205
206func newWebsocketTransport(conn *websocket.Conn, opts websocketTransportOptions, graceCh chan struct{}) *websocketTransport {
207	transport := &websocketTransport{
208		conn:    conn,
209		closeCh: make(chan struct{}),
210		graceCh: graceCh,
211		opts:    opts,
212	}
213	if opts.pingInterval > 0 {
214		transport.addPing()
215	}
216	return transport
217}
218
219func (t *websocketTransport) ping() {
220	select {
221	case <-t.closeCh:
222		return
223	default:
224		err := t.writeData([]byte(""))
225		if err != nil {
226			_ = t.Close(centrifuge.DisconnectWriteError)
227			return
228		}
229		deadline := time.Now().Add(t.opts.pingInterval / 2)
230		err = t.conn.WriteControl(websocket.PingMessage, nil, deadline)
231		if err != nil {
232			_ = t.Close(centrifuge.DisconnectWriteError)
233			return
234		}
235		t.addPing()
236	}
237}
238
239func (t *websocketTransport) addPing() {
240	t.mu.Lock()
241	if t.closed {
242		t.mu.Unlock()
243		return
244	}
245	t.pingTimer = time.AfterFunc(t.opts.pingInterval, t.ping)
246	t.mu.Unlock()
247}
248
// Name returns the name of this transport, which is always "websocket".
func (t *websocketTransport) Name() string {
	return "websocket"
}
253
// Protocol returns the protocol type the transport was configured with
// (opts.protoType).
func (t *websocketTransport) Protocol() centrifuge.ProtocolType {
	return t.opts.protoType
}
258
// Unidirectional reports whether the transport is unidirectional. This
// transport always reports true.
func (t *websocketTransport) Unidirectional() bool {
	return true
}
263
// DisabledPushFlags always returns 0.
// NOTE(review): presumably the value is a bitmask of push types the client
// should not receive, so 0 disables nothing — confirm against the centrifuge
// Transport interface documentation.
func (t *websocketTransport) DisabledPushFlags() uint64 {
	return 0
}
268
269func (t *websocketTransport) writeData(data []byte) error {
270	if t.opts.compressionMinSize > 0 {
271		t.conn.EnableWriteCompression(len(data) > t.opts.compressionMinSize)
272	}
273	var messageType = websocket.TextMessage
274	if t.Protocol() == centrifuge.ProtocolTypeProtobuf {
275		messageType = websocket.BinaryMessage
276	}
277
278	t.writeMu.Lock()
279	if t.opts.writeTimeout > 0 {
280		_ = t.conn.SetWriteDeadline(time.Now().Add(t.opts.writeTimeout))
281	}
282	err := t.conn.WriteMessage(messageType, data)
283	if err != nil {
284		t.writeMu.Unlock()
285		return err
286	}
287	if t.opts.writeTimeout > 0 {
288		_ = t.conn.SetWriteDeadline(time.Time{})
289	}
290	t.writeMu.Unlock()
291
292	return nil
293}
294
// Write sends a single message to the transport by delegating to WriteMany.
func (t *websocketTransport) Write(message []byte) error {
	return t.WriteMany(message)
}
298
299// Write data to transport.
300func (t *websocketTransport) WriteMany(messages ...[]byte) error {
301	select {
302	case <-t.closeCh:
303		return nil
304	default:
305		for i := 0; i < len(messages); i++ {
306			err := t.writeData(messages[i])
307			if err != nil {
308				return err
309			}
310		}
311		return nil
312	}
313}
314
// closeFrameWait is the read deadline WebsocketHandler.ServeHTTP applies while
// draining the connection after its main read loop has ended.
const closeFrameWait = 5 * time.Second
316
317// Close closes transport.
318func (t *websocketTransport) Close(_ *centrifuge.Disconnect) error {
319	t.mu.Lock()
320	if t.closed {
321		t.mu.Unlock()
322		return nil
323	}
324	t.closed = true
325	if t.pingTimer != nil {
326		t.pingTimer.Stop()
327	}
328	close(t.closeCh)
329	t.mu.Unlock()
330	return t.conn.Close()
331}
332
// Defaults applied by WebsocketHandler.ServeHTTP when the corresponding
// WebsocketConfig field is zero:
//   - DefaultWebsocketPingInterval replaces a zero PingInterval.
//   - DefaultWebsocketWriteTimeout replaces a zero WriteTimeout.
//   - DefaultWebsocketMessageSizeLimit replaces a zero MessageSizeLimit
//     (value is in bytes).
const (
	DefaultWebsocketPingInterval     = 25 * time.Second
	DefaultWebsocketWriteTimeout     = 1 * time.Second
	DefaultWebsocketMessageSizeLimit = 65536 // 64KB
)
339
// WebsocketConfig represents config for WebsocketHandler.
type WebsocketConfig struct {
	// CompressionLevel sets a level for websocket compression. It is applied
	// to a connection only when Compression is enabled.
	// See possible value description at https://golang.org/pkg/compress/flate/#NewWriter
	CompressionLevel int

	// CompressionMinSize allows to set minimal limit in bytes for
	// message to use compression when writing it into client connection.
	// By default it's 0 - i.e. all messages will be compressed when
	// Compression is enabled and compression is negotiated with the client.
	CompressionMinSize int

	// ReadBufferSize is a parameter that is used for raw websocket Upgrader.
	// If set to zero reasonable default value will be used.
	ReadBufferSize int

	// WriteBufferSize is a parameter that is used for raw websocket Upgrader.
	// If set to zero reasonable default value will be used.
	// It is ignored when UseWriteBufferPool is true.
	WriteBufferSize int

	// MessageSizeLimit sets the maximum size in bytes of allowed message from client.
	// When zero, DefaultWebsocketMessageSizeLimit is used. A negative value
	// means no read limit is set on the connection.
	MessageSizeLimit int

	// CheckOrigin func to provide custom origin check logic.
	// When nil, NewWebsocketHandler installs a same-host check: requests
	// without an Origin header are accepted, otherwise the Origin host must
	// match the request Host (case-insensitively).
	CheckOrigin func(r *http.Request) bool

	// PingInterval sets interval server will send ping messages to clients.
	// When zero, DefaultWebsocketPingInterval is used. A negative value
	// disables pings.
	PingInterval time.Duration

	// WriteTimeout is maximum time of write message operation.
	// Slow client will be disconnected.
	// When zero, DefaultWebsocketWriteTimeout is used. A negative value means
	// no write deadline is set.
	WriteTimeout time.Duration

	// Compression allows to enable websocket permessage-deflate
	// compression support for raw websocket connections. It does
	// not guarantee that compression will be used - i.e. it only
	// says that server will try to negotiate it with client.
	Compression bool

	// UseWriteBufferPool enables using buffer pool for writes.
	UseWriteBufferPool bool
}
386
// WebsocketHandler handles WebSocket client connections. WebSocket protocol
// is a bidirectional connection between a client and a server for low-latency
// communication. This handler serves clients over a transport that reports
// itself as unidirectional: after the initial connect message, incoming
// client messages are read and discarded.
type WebsocketHandler struct {
	node    *centrifuge.Node
	upgrade *websocket.Upgrader
	config  WebsocketConfig
}
395
// writeBufferPool is the write buffer pool handed to the websocket upgrader
// by NewWebsocketHandler when WebsocketConfig.UseWriteBufferPool is set. It is
// package-level, so every handler created with that option uses the same pool.
var writeBufferPool = &sync.Pool{}
397
398// NewWebsocketHandler creates new WebsocketHandler.
399func NewWebsocketHandler(n *centrifuge.Node, c WebsocketConfig) *WebsocketHandler {
400	upgrade := &websocket.Upgrader{
401		ReadBufferSize:    c.ReadBufferSize,
402		EnableCompression: c.Compression,
403	}
404	if c.UseWriteBufferPool {
405		upgrade.WriteBufferPool = writeBufferPool
406	} else {
407		upgrade.WriteBufferSize = c.WriteBufferSize
408	}
409	if c.CheckOrigin != nil {
410		upgrade.CheckOrigin = c.CheckOrigin
411	} else {
412		upgrade.CheckOrigin = sameHostOriginCheck()
413	}
414	return &WebsocketHandler{
415		node:    n,
416		config:  c,
417		upgrade: upgrade,
418	}
419}
420
// ConnectRequest is the JSON payload WebsocketHandler.ServeHTTP expects as the
// first message on a new connection. All fields are optional. Token, Data,
// Name and Version are copied into the centrifuge connect request; Subs maps
// channel names to per-channel subscription state.
type ConnectRequest struct {
	Token   string                       `json:"token,omitempty"`
	Data    json.RawMessage              `json:"data,omitempty"`
	Subs    map[string]*SubscribeRequest `json:"subs,omitempty"`
	Name    string                       `json:"name,omitempty"`
	Version string                       `json:"version,omitempty"`
}
428
// SubscribeRequest is the JSON form of the per-channel entries of
// ConnectRequest.Subs.
// NOTE(review): Recover, Epoch and Offset presumably describe the stream
// position the client wants to recover from — confirm against the centrifuge
// SubscribeRequest documentation.
type SubscribeRequest struct {
	Recover bool   `json:"recover,omitempty"`
	Epoch   string `json:"epoch,omitempty"`
	Offset  uint64 `json:"offset,omitempty"`
}
434
435func (s *WebsocketHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
436	compression := s.config.Compression
437	compressionLevel := s.config.CompressionLevel
438	compressionMinSize := s.config.CompressionMinSize
439
440	conn, err := s.upgrade.Upgrade(rw, r, nil)
441	if err != nil {
442		s.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelDebug, "websocket upgrade error", map[string]interface{}{"error": err.Error()}))
443		return
444	}
445
446	if compression {
447		err := conn.SetCompressionLevel(compressionLevel)
448		if err != nil {
449			s.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelError, "websocket error setting compression level", map[string]interface{}{"error": err.Error()}))
450		}
451	}
452
453	pingInterval := s.config.PingInterval
454	if pingInterval == 0 {
455		pingInterval = DefaultWebsocketPingInterval
456	}
457	writeTimeout := s.config.WriteTimeout
458	if writeTimeout == 0 {
459		writeTimeout = DefaultWebsocketWriteTimeout
460	}
461	messageSizeLimit := s.config.MessageSizeLimit
462	if messageSizeLimit == 0 {
463		messageSizeLimit = DefaultWebsocketMessageSizeLimit
464	}
465
466	if messageSizeLimit > 0 {
467		conn.SetReadLimit(int64(messageSizeLimit))
468	}
469	if pingInterval > 0 {
470		pongWait := pingInterval * 10 / 9
471		_ = conn.SetReadDeadline(time.Now().Add(pongWait))
472		conn.SetPongHandler(func(string) error {
473			_ = conn.SetReadDeadline(time.Now().Add(pongWait))
474			return nil
475		})
476	}
477
478	// Separate goroutine for better GC of caller's data.
479	go func() {
480		opts := websocketTransportOptions{
481			pingInterval:       pingInterval,
482			writeTimeout:       writeTimeout,
483			compressionMinSize: compressionMinSize,
484			protoType:          centrifuge.ProtocolTypeJSON,
485		}
486
487		graceCh := make(chan struct{})
488		transport := newWebsocketTransport(conn, opts, graceCh)
489
490		select {
491		case <-s.node.NotifyShutdown():
492			_ = transport.Close(centrifuge.DisconnectShutdown)
493			return
494		default:
495		}
496
497		ctxCh := make(chan struct{})
498		defer close(ctxCh)
499
500		c, closeFn, err := centrifuge.NewClient(cancelctx.New(r.Context(), ctxCh), s.node, transport)
501		if err != nil {
502			s.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelError, "error creating client", map[string]interface{}{"transport": transport.Name()}))
503			return
504		}
505		defer func() { _ = closeFn() }()
506
507		s.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelDebug, "client connection established", map[string]interface{}{"client": c.ID(), "transport": transport.Name()}))
508		defer func(started time.Time) {
509			s.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelDebug, "client connection completed", map[string]interface{}{"client": c.ID(), "transport": transport.Name(), "duration": time.Since(started)}))
510		}(time.Now())
511
512		_, data, err := conn.ReadMessage()
513		if err != nil {
514			return
515		}
516
517		var req ConnectRequest
518		err = json.Unmarshal(data, &req)
519		if err != nil {
520			return
521		}
522
523		connectRequest := centrifuge.ConnectRequest{
524			Token:   req.Token,
525			Data:    req.Data,
526			Name:    req.Name,
527			Version: req.Version,
528		}
529		if req.Subs != nil {
530			subs := make(map[string]centrifuge.SubscribeRequest)
531			for k, v := range connectRequest.Subs {
532				subs[k] = centrifuge.SubscribeRequest{
533					Recover: v.Recover,
534					Offset:  v.Offset,
535					Epoch:   v.Epoch,
536				}
537			}
538		}
539
540		c.Connect(connectRequest)
541
542		for {
543			_, _, err := conn.ReadMessage()
544			if err != nil {
545				break
546			}
547		}
548
549		// https://github.com/gorilla/websocket/issues/448
550		conn.SetPingHandler(nil)
551		conn.SetPongHandler(nil)
552		conn.SetCloseHandler(nil)
553		_ = conn.SetReadDeadline(time.Now().Add(closeFrameWait))
554		for {
555			if _, _, err := conn.NextReader(); err != nil {
556				close(graceCh)
557				break
558			}
559		}
560	}()
561}
562
563func sameHostOriginCheck() func(r *http.Request) bool {
564	return func(r *http.Request) bool {
565		err := checkSameHost(r)
566		if err != nil {
567			return false
568		}
569		return true
570	}
571}
572
// checkSameHost verifies that the request's Origin header, when present,
// refers to the same host as the request itself.
//
// It returns nil when the Origin header is absent or empty, or when the host
// part of the Origin URL equals r.Host ignoring case. It returns an error
// when the Origin header cannot be parsed as a URL or names a different host.
func checkSameHost(r *http.Request) error {
	originHeader := r.Header.Get("Origin")
	if originHeader == "" {
		return nil
	}

	parsed, parseErr := url.Parse(originHeader)
	switch {
	case parseErr != nil:
		return fmt.Errorf("failed to parse Origin header %q: %w", originHeader, parseErr)
	case !strings.EqualFold(r.Host, parsed.Host):
		return fmt.Errorf("request Origin %q is not authorized for Host %q", originHeader, r.Host)
	default:
		return nil
	}
}
587