1package main 2 3import ( 4 "context" 5 "encoding/json" 6 "flag" 7 "fmt" 8 "log" 9 "net/http" 10 "net/url" 11 "os" 12 "os/signal" 13 "strconv" 14 "strings" 15 "sync" 16 "syscall" 17 "time" 18 19 "github.com/centrifugal/centrifuge/internal/cancelctx" 20 "github.com/gorilla/websocket" 21 22 _ "net/http/pprof" 23 24 "github.com/centrifugal/centrifuge" 25) 26 27var ( 28 port = flag.Int("port", 8000, "Port to bind app to") 29 redis = flag.Bool("redis", false, "Use Redis") 30) 31 32func handleLog(e centrifuge.LogEntry) { 33 log.Printf("%s: %v", e.Message, e.Fields) 34} 35 36func authMiddleware(h http.Handler) http.Handler { 37 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 38 ctx := r.Context() 39 newCtx := centrifuge.SetCredentials(ctx, ¢rifuge.Credentials{ 40 UserID: "42", 41 }) 42 r = r.WithContext(newCtx) 43 h.ServeHTTP(w, r) 44 }) 45} 46 47func waitExitSignal(n *centrifuge.Node) { 48 sigCh := make(chan os.Signal, 1) 49 done := make(chan bool, 1) 50 signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) 51 go func() { 52 <-sigCh 53 _ = n.Shutdown(context.Background()) 54 done <- true 55 }() 56 <-done 57} 58 59var exampleChannel = "unidirectional" 60 61func main() { 62 flag.Parse() 63 64 cfg := centrifuge.DefaultConfig 65 cfg.LogLevel = centrifuge.LogLevelDebug 66 cfg.LogHandler = handleLog 67 68 node, _ := centrifuge.New(cfg) 69 70 if *redis { 71 redisShardConfigs := []centrifuge.RedisShardConfig{ 72 {Address: "localhost:6379"}, 73 } 74 var redisShards []*centrifuge.RedisShard 75 for _, redisConf := range redisShardConfigs { 76 redisShard, err := centrifuge.NewRedisShard(node, redisConf) 77 if err != nil { 78 log.Fatal(err) 79 } 80 redisShards = append(redisShards, redisShard) 81 } 82 // Using Redis Broker here to scale nodes. 83 broker, err := centrifuge.NewRedisBroker(node, centrifuge.RedisBrokerConfig{ 84 Shards: redisShards, 85 }) 86 if err != nil { 87 log.Fatal(err) 88 } 89 node.SetBroker(broker) 90 91 presenceManager, err := centrifuge.NewRedisPresenceManager(node, centrifuge.RedisPresenceManagerConfig{ 92 Shards: redisShards, 93 }) 94 if err != nil { 95 log.Fatal(err) 96 } 97 node.SetPresenceManager(presenceManager) 98 } 99 100 node.OnConnecting(func(ctx context.Context, e centrifuge.ConnectEvent) (centrifuge.ConnectReply, error) { 101 return centrifuge.ConnectReply{ 102 Subscriptions: map[string]centrifuge.SubscribeOptions{ 103 exampleChannel: {}, 104 }, 105 }, nil 106 }) 107 108 node.OnConnect(func(client *centrifuge.Client) { 109 client.OnUnsubscribe(func(e centrifuge.UnsubscribeEvent) { 110 log.Printf("user %s unsubscribed from %s", client.UserID(), e.Channel) 111 }) 112 client.OnDisconnect(func(e centrifuge.DisconnectEvent) { 113 log.Printf("user %s disconnected, disconnect: %s", client.UserID(), e.Disconnect) 114 }) 115 transport := client.Transport() 116 log.Printf("user %s connected via %s", client.UserID(), transport.Name()) 117 }) 118 119 // Publish to a channel periodically. 120 go func() { 121 for { 122 currentTime := strconv.FormatInt(time.Now().Unix(), 10) 123 _, err := node.Publish(exampleChannel, []byte(`{"server_time": "`+currentTime+`"}`)) 124 if err != nil { 125 log.Println(err.Error()) 126 } 127 time.Sleep(5 * time.Second) 128 } 129 }() 130 131 if err := node.Run(); err != nil { 132 log.Fatal(err) 133 } 134 135 websocketHandler := NewWebsocketHandler(node, WebsocketConfig{ 136 ReadBufferSize: 1024, 137 UseWriteBufferPool: true, 138 }) 139 http.Handle("/connection/websocket", authMiddleware(websocketHandler)) 140 http.Handle("/subscribe", handleSubscribe(node)) 141 http.Handle("/unsubscribe", handleUnsubscribe(node)) 142 http.Handle("/", http.FileServer(http.Dir("./"))) 143 144 go func() { 145 if err := http.ListenAndServe(":"+strconv.Itoa(*port), nil); err != nil { 146 log.Fatal(err) 147 } 148 }() 149 150 waitExitSignal(node) 151 log.Println("bye!") 152} 153 154func handleSubscribe(node *centrifuge.Node) http.HandlerFunc { 155 return func(w http.ResponseWriter, req *http.Request) { 156 clientID := req.URL.Query().Get("client") 157 if clientID == "" { 158 w.WriteHeader(http.StatusBadRequest) 159 return 160 } 161 err := node.Subscribe("42", exampleChannel, centrifuge.WithSubscribeClient(clientID)) 162 if err != nil { 163 w.WriteHeader(http.StatusInternalServerError) 164 return 165 } 166 w.WriteHeader(http.StatusOK) 167 } 168} 169 170func handleUnsubscribe(node *centrifuge.Node) http.HandlerFunc { 171 return func(w http.ResponseWriter, req *http.Request) { 172 clientID := req.URL.Query().Get("client") 173 if clientID == "" { 174 w.WriteHeader(http.StatusBadRequest) 175 return 176 } 177 err := node.Unsubscribe("42", exampleChannel, centrifuge.WithUnsubscribeClient(clientID)) 178 if err != nil { 179 w.WriteHeader(http.StatusInternalServerError) 180 return 181 } 182 w.WriteHeader(http.StatusOK) 183 } 184} 185 186// websocketTransport is a wrapper struct over websocket connection to fit session 187// interface so client will accept it. 188type websocketTransport struct { 189 mu sync.RWMutex 190 writeMu sync.Mutex // sync general write with unidirectional ping write. 191 conn *websocket.Conn 192 closed bool 193 closeCh chan struct{} 194 graceCh chan struct{} 195 opts websocketTransportOptions 196 pingTimer *time.Timer 197} 198 199type websocketTransportOptions struct { 200 protoType centrifuge.ProtocolType 201 pingInterval time.Duration 202 writeTimeout time.Duration 203 compressionMinSize int 204} 205 206func newWebsocketTransport(conn *websocket.Conn, opts websocketTransportOptions, graceCh chan struct{}) *websocketTransport { 207 transport := &websocketTransport{ 208 conn: conn, 209 closeCh: make(chan struct{}), 210 graceCh: graceCh, 211 opts: opts, 212 } 213 if opts.pingInterval > 0 { 214 transport.addPing() 215 } 216 return transport 217} 218 219func (t *websocketTransport) ping() { 220 select { 221 case <-t.closeCh: 222 return 223 default: 224 err := t.writeData([]byte("")) 225 if err != nil { 226 _ = t.Close(centrifuge.DisconnectWriteError) 227 return 228 } 229 deadline := time.Now().Add(t.opts.pingInterval / 2) 230 err = t.conn.WriteControl(websocket.PingMessage, nil, deadline) 231 if err != nil { 232 _ = t.Close(centrifuge.DisconnectWriteError) 233 return 234 } 235 t.addPing() 236 } 237} 238 239func (t *websocketTransport) addPing() { 240 t.mu.Lock() 241 if t.closed { 242 t.mu.Unlock() 243 return 244 } 245 t.pingTimer = time.AfterFunc(t.opts.pingInterval, t.ping) 246 t.mu.Unlock() 247} 248 249// Name returns name of transport. 250func (t *websocketTransport) Name() string { 251 return "websocket" 252} 253 254// Protocol returns transport protocol. 255func (t *websocketTransport) Protocol() centrifuge.ProtocolType { 256 return t.opts.protoType 257} 258 259// Unidirectional returns whether transport is unidirectional. 260func (t *websocketTransport) Unidirectional() bool { 261 return true 262} 263 264// DisabledPushFlags ... 265func (t *websocketTransport) DisabledPushFlags() uint64 { 266 return 0 267} 268 269func (t *websocketTransport) writeData(data []byte) error { 270 if t.opts.compressionMinSize > 0 { 271 t.conn.EnableWriteCompression(len(data) > t.opts.compressionMinSize) 272 } 273 var messageType = websocket.TextMessage 274 if t.Protocol() == centrifuge.ProtocolTypeProtobuf { 275 messageType = websocket.BinaryMessage 276 } 277 278 t.writeMu.Lock() 279 if t.opts.writeTimeout > 0 { 280 _ = t.conn.SetWriteDeadline(time.Now().Add(t.opts.writeTimeout)) 281 } 282 err := t.conn.WriteMessage(messageType, data) 283 if err != nil { 284 t.writeMu.Unlock() 285 return err 286 } 287 if t.opts.writeTimeout > 0 { 288 _ = t.conn.SetWriteDeadline(time.Time{}) 289 } 290 t.writeMu.Unlock() 291 292 return nil 293} 294 295func (t *websocketTransport) Write(message []byte) error { 296 return t.WriteMany(message) 297} 298 299// Write data to transport. 300func (t *websocketTransport) WriteMany(messages ...[]byte) error { 301 select { 302 case <-t.closeCh: 303 return nil 304 default: 305 for i := 0; i < len(messages); i++ { 306 err := t.writeData(messages[i]) 307 if err != nil { 308 return err 309 } 310 } 311 return nil 312 } 313} 314 315const closeFrameWait = 5 * time.Second 316 317// Close closes transport. 318func (t *websocketTransport) Close(_ *centrifuge.Disconnect) error { 319 t.mu.Lock() 320 if t.closed { 321 t.mu.Unlock() 322 return nil 323 } 324 t.closed = true 325 if t.pingTimer != nil { 326 t.pingTimer.Stop() 327 } 328 close(t.closeCh) 329 t.mu.Unlock() 330 return t.conn.Close() 331} 332 333// Defaults. 334const ( 335 DefaultWebsocketPingInterval = 25 * time.Second 336 DefaultWebsocketWriteTimeout = 1 * time.Second 337 DefaultWebsocketMessageSizeLimit = 65536 // 64KB 338) 339 340// WebsocketConfig represents config for WebsocketHandler. 341type WebsocketConfig struct { 342 // CompressionLevel sets a level for websocket compression. 343 // See possible value description at https://golang.org/pkg/compress/flate/#NewWriter 344 CompressionLevel int 345 346 // CompressionMinSize allows to set minimal limit in bytes for 347 // message to use compression when writing it into client connection. 348 // By default it's 0 - i.e. all messages will be compressed when 349 // WebsocketCompression enabled and compression negotiated with client. 350 CompressionMinSize int 351 352 // ReadBufferSize is a parameter that is used for raw websocket Upgrader. 353 // If set to zero reasonable default value will be used. 354 ReadBufferSize int 355 356 // WriteBufferSize is a parameter that is used for raw websocket Upgrader. 357 // If set to zero reasonable default value will be used. 358 WriteBufferSize int 359 360 // MessageSizeLimit sets the maximum size in bytes of allowed message from client. 361 // By default DefaultWebsocketMaxMessageSize will be used. 362 MessageSizeLimit int 363 364 // CheckOrigin func to provide custom origin check logic. 365 // nil means allow all origins. 366 CheckOrigin func(r *http.Request) bool 367 368 // PingInterval sets interval server will send ping messages to clients. 369 // By default DefaultPingInterval will be used. 370 PingInterval time.Duration 371 372 // WriteTimeout is maximum time of write message operation. 373 // Slow client will be disconnected. 374 // By default DefaultWebsocketWriteTimeout will be used. 375 WriteTimeout time.Duration 376 377 // Compression allows to enable websocket permessage-deflate 378 // compression support for raw websocket connections. It does 379 // not guarantee that compression will be used - i.e. it only 380 // says that server will try to negotiate it with client. 381 Compression bool 382 383 // UseWriteBufferPool enables using buffer pool for writes. 384 UseWriteBufferPool bool 385} 386 387// WebsocketHandler handles WebSocket client connections. WebSocket protocol 388// is a bidirectional connection between a client an a server for low-latency 389// communication. 390type WebsocketHandler struct { 391 node *centrifuge.Node 392 upgrade *websocket.Upgrader 393 config WebsocketConfig 394} 395 396var writeBufferPool = &sync.Pool{} 397 398// NewWebsocketHandler creates new WebsocketHandler. 399func NewWebsocketHandler(n *centrifuge.Node, c WebsocketConfig) *WebsocketHandler { 400 upgrade := &websocket.Upgrader{ 401 ReadBufferSize: c.ReadBufferSize, 402 EnableCompression: c.Compression, 403 } 404 if c.UseWriteBufferPool { 405 upgrade.WriteBufferPool = writeBufferPool 406 } else { 407 upgrade.WriteBufferSize = c.WriteBufferSize 408 } 409 if c.CheckOrigin != nil { 410 upgrade.CheckOrigin = c.CheckOrigin 411 } else { 412 upgrade.CheckOrigin = sameHostOriginCheck() 413 } 414 return &WebsocketHandler{ 415 node: n, 416 config: c, 417 upgrade: upgrade, 418 } 419} 420 421type ConnectRequest struct { 422 Token string `json:"token,omitempty"` 423 Data json.RawMessage `json:"data,omitempty"` 424 Subs map[string]*SubscribeRequest `json:"subs,omitempty"` 425 Name string `json:"name,omitempty"` 426 Version string `json:"version,omitempty"` 427} 428 429type SubscribeRequest struct { 430 Recover bool `json:"recover,omitempty"` 431 Epoch string `json:"epoch,omitempty"` 432 Offset uint64 `json:"offset,omitempty"` 433} 434 435func (s *WebsocketHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { 436 compression := s.config.Compression 437 compressionLevel := s.config.CompressionLevel 438 compressionMinSize := s.config.CompressionMinSize 439 440 conn, err := s.upgrade.Upgrade(rw, r, nil) 441 if err != nil { 442 s.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelDebug, "websocket upgrade error", map[string]interface{}{"error": err.Error()})) 443 return 444 } 445 446 if compression { 447 err := conn.SetCompressionLevel(compressionLevel) 448 if err != nil { 449 s.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelError, "websocket error setting compression level", map[string]interface{}{"error": err.Error()})) 450 } 451 } 452 453 pingInterval := s.config.PingInterval 454 if pingInterval == 0 { 455 pingInterval = DefaultWebsocketPingInterval 456 } 457 writeTimeout := s.config.WriteTimeout 458 if writeTimeout == 0 { 459 writeTimeout = DefaultWebsocketWriteTimeout 460 } 461 messageSizeLimit := s.config.MessageSizeLimit 462 if messageSizeLimit == 0 { 463 messageSizeLimit = DefaultWebsocketMessageSizeLimit 464 } 465 466 if messageSizeLimit > 0 { 467 conn.SetReadLimit(int64(messageSizeLimit)) 468 } 469 if pingInterval > 0 { 470 pongWait := pingInterval * 10 / 9 471 _ = conn.SetReadDeadline(time.Now().Add(pongWait)) 472 conn.SetPongHandler(func(string) error { 473 _ = conn.SetReadDeadline(time.Now().Add(pongWait)) 474 return nil 475 }) 476 } 477 478 // Separate goroutine for better GC of caller's data. 479 go func() { 480 opts := websocketTransportOptions{ 481 pingInterval: pingInterval, 482 writeTimeout: writeTimeout, 483 compressionMinSize: compressionMinSize, 484 protoType: centrifuge.ProtocolTypeJSON, 485 } 486 487 graceCh := make(chan struct{}) 488 transport := newWebsocketTransport(conn, opts, graceCh) 489 490 select { 491 case <-s.node.NotifyShutdown(): 492 _ = transport.Close(centrifuge.DisconnectShutdown) 493 return 494 default: 495 } 496 497 ctxCh := make(chan struct{}) 498 defer close(ctxCh) 499 500 c, closeFn, err := centrifuge.NewClient(cancelctx.New(r.Context(), ctxCh), s.node, transport) 501 if err != nil { 502 s.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelError, "error creating client", map[string]interface{}{"transport": transport.Name()})) 503 return 504 } 505 defer func() { _ = closeFn() }() 506 507 s.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelDebug, "client connection established", map[string]interface{}{"client": c.ID(), "transport": transport.Name()})) 508 defer func(started time.Time) { 509 s.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelDebug, "client connection completed", map[string]interface{}{"client": c.ID(), "transport": transport.Name(), "duration": time.Since(started)})) 510 }(time.Now()) 511 512 _, data, err := conn.ReadMessage() 513 if err != nil { 514 return 515 } 516 517 var req ConnectRequest 518 err = json.Unmarshal(data, &req) 519 if err != nil { 520 return 521 } 522 523 connectRequest := centrifuge.ConnectRequest{ 524 Token: req.Token, 525 Data: req.Data, 526 Name: req.Name, 527 Version: req.Version, 528 } 529 if req.Subs != nil { 530 subs := make(map[string]centrifuge.SubscribeRequest) 531 for k, v := range connectRequest.Subs { 532 subs[k] = centrifuge.SubscribeRequest{ 533 Recover: v.Recover, 534 Offset: v.Offset, 535 Epoch: v.Epoch, 536 } 537 } 538 } 539 540 c.Connect(connectRequest) 541 542 for { 543 _, _, err := conn.ReadMessage() 544 if err != nil { 545 break 546 } 547 } 548 549 // https://github.com/gorilla/websocket/issues/448 550 conn.SetPingHandler(nil) 551 conn.SetPongHandler(nil) 552 conn.SetCloseHandler(nil) 553 _ = conn.SetReadDeadline(time.Now().Add(closeFrameWait)) 554 for { 555 if _, _, err := conn.NextReader(); err != nil { 556 close(graceCh) 557 break 558 } 559 } 560 }() 561} 562 563func sameHostOriginCheck() func(r *http.Request) bool { 564 return func(r *http.Request) bool { 565 err := checkSameHost(r) 566 if err != nil { 567 return false 568 } 569 return true 570 } 571} 572 573func checkSameHost(r *http.Request) error { 574 origin := r.Header.Get("Origin") 575 if origin == "" { 576 return nil 577 } 578 u, err := url.Parse(origin) 579 if err != nil { 580 return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) 581 } 582 if strings.EqualFold(r.Host, u.Host) { 583 return nil 584 } 585 return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) 586} 587