1package centrifuge 2 3import ( 4 "fmt" 5 "net/http" 6 "net/url" 7 "strings" 8 "sync" 9 "time" 10 11 "github.com/centrifugal/centrifuge/internal/cancelctx" 12 "github.com/centrifugal/centrifuge/internal/timers" 13 14 "github.com/centrifugal/protocol" 15 "github.com/gorilla/websocket" 16) 17 18const ( 19 transportWebsocket = "websocket" 20) 21 22// websocketTransport is a wrapper struct over websocket connection to fit session 23// interface so client will accept it. 24type websocketTransport struct { 25 mu sync.RWMutex 26 conn *websocket.Conn 27 closed bool 28 closeCh chan struct{} 29 graceCh chan struct{} 30 opts websocketTransportOptions 31 pingTimer *time.Timer 32} 33 34type websocketTransportOptions struct { 35 protoType ProtocolType 36 pingInterval time.Duration 37 writeTimeout time.Duration 38 compressionMinSize int 39} 40 41func newWebsocketTransport(conn *websocket.Conn, opts websocketTransportOptions, graceCh chan struct{}) *websocketTransport { 42 transport := &websocketTransport{ 43 conn: conn, 44 closeCh: make(chan struct{}), 45 graceCh: graceCh, 46 opts: opts, 47 } 48 if opts.pingInterval > 0 { 49 transport.addPing() 50 } 51 return transport 52} 53 54func (t *websocketTransport) ping() { 55 select { 56 case <-t.closeCh: 57 return 58 default: 59 deadline := time.Now().Add(t.opts.pingInterval / 2) 60 err := t.conn.WriteControl(websocket.PingMessage, nil, deadline) 61 if err != nil { 62 _ = t.Close(DisconnectWriteError) 63 return 64 } 65 t.addPing() 66 } 67} 68 69func (t *websocketTransport) addPing() { 70 t.mu.Lock() 71 if t.closed { 72 t.mu.Unlock() 73 return 74 } 75 t.pingTimer = time.AfterFunc(t.opts.pingInterval, t.ping) 76 t.mu.Unlock() 77} 78 79// Name returns name of transport. 80func (t *websocketTransport) Name() string { 81 return transportWebsocket 82} 83 84// Protocol returns transport protocol. 85func (t *websocketTransport) Protocol() ProtocolType { 86 return t.opts.protoType 87} 88 89// Unidirectional returns whether transport is unidirectional. 90func (t *websocketTransport) Unidirectional() bool { 91 return false 92} 93 94// DisabledPushFlags ... 95func (t *websocketTransport) DisabledPushFlags() uint64 { 96 return PushFlagDisconnect 97} 98 99func (t *websocketTransport) writeData(data []byte) error { 100 if t.opts.compressionMinSize > 0 { 101 t.conn.EnableWriteCompression(len(data) > t.opts.compressionMinSize) 102 } 103 var messageType = websocket.TextMessage 104 if t.Protocol() == ProtocolTypeProtobuf { 105 messageType = websocket.BinaryMessage 106 } 107 if t.opts.writeTimeout > 0 { 108 _ = t.conn.SetWriteDeadline(time.Now().Add(t.opts.writeTimeout)) 109 } 110 err := t.conn.WriteMessage(messageType, data) 111 if err != nil { 112 return err 113 } 114 if t.opts.writeTimeout > 0 { 115 _ = t.conn.SetWriteDeadline(time.Time{}) 116 } 117 return nil 118} 119 120// Write data to transport. 121func (t *websocketTransport) Write(message []byte) error { 122 select { 123 case <-t.closeCh: 124 return nil 125 default: 126 protoType := t.Protocol().toProto() 127 if protoType == protocol.TypeJSON { 128 // Fast path for one JSON message. 129 return t.writeData(message) 130 } 131 encoder := protocol.GetDataEncoder(protoType) 132 defer protocol.PutDataEncoder(protoType, encoder) 133 _ = encoder.Encode(message) 134 return t.writeData(encoder.Finish()) 135 } 136} 137 138// WriteMany data to transport. 139func (t *websocketTransport) WriteMany(messages ...[]byte) error { 140 select { 141 case <-t.closeCh: 142 return nil 143 default: 144 protoType := t.Protocol().toProto() 145 encoder := protocol.GetDataEncoder(protoType) 146 defer protocol.PutDataEncoder(protoType, encoder) 147 for i := range messages { 148 _ = encoder.Encode(messages[i]) 149 } 150 return t.writeData(encoder.Finish()) 151 } 152} 153 154const closeFrameWait = 5 * time.Second 155 156// Close closes transport. 157func (t *websocketTransport) Close(disconnect *Disconnect) error { 158 t.mu.Lock() 159 if t.closed { 160 t.mu.Unlock() 161 return nil 162 } 163 t.closed = true 164 if t.pingTimer != nil { 165 t.pingTimer.Stop() 166 } 167 close(t.closeCh) 168 t.mu.Unlock() 169 170 if disconnect != nil { 171 msg := websocket.FormatCloseMessage(int(disconnect.Code), disconnect.CloseText()) 172 err := t.conn.WriteControl(websocket.CloseMessage, msg, time.Now().Add(time.Second)) 173 if err != nil { 174 return t.conn.Close() 175 } 176 select { 177 case <-t.graceCh: 178 default: 179 // Wait for closing handshake completion. 180 tm := timers.AcquireTimer(closeFrameWait) 181 select { 182 case <-t.graceCh: 183 case <-tm.C: 184 } 185 timers.ReleaseTimer(tm) 186 } 187 return t.conn.Close() 188 } 189 return t.conn.Close() 190} 191 192// Defaults. 193const ( 194 DefaultWebsocketPingInterval = 25 * time.Second 195 DefaultWebsocketWriteTimeout = 1 * time.Second 196 DefaultWebsocketMessageSizeLimit = 65536 // 64KB 197) 198 199// WebsocketConfig represents config for WebsocketHandler. 200type WebsocketConfig struct { 201 // CompressionLevel sets a level for websocket compression. 202 // See possible value description at https://golang.org/pkg/compress/flate/#NewWriter 203 CompressionLevel int 204 205 // CompressionMinSize allows to set minimal limit in bytes for 206 // message to use compression when writing it into client connection. 207 // By default it's 0 - i.e. all messages will be compressed when 208 // WebsocketCompression enabled and compression negotiated with client. 209 CompressionMinSize int 210 211 // ReadBufferSize is a parameter that is used for raw websocket Upgrader. 212 // If set to zero reasonable default value will be used. 213 ReadBufferSize int 214 215 // WriteBufferSize is a parameter that is used for raw websocket Upgrader. 216 // If set to zero reasonable default value will be used. 217 WriteBufferSize int 218 219 // MessageSizeLimit sets the maximum size in bytes of allowed message from client. 220 // By default DefaultWebsocketMaxMessageSize will be used. 221 MessageSizeLimit int 222 223 // CheckOrigin func to provide custom origin check logic. 224 // nil means that sameHostOriginCheck function will be used which 225 // expects Origin host to match request Host. 226 CheckOrigin func(r *http.Request) bool 227 228 // PingInterval sets interval server will send ping messages to clients. 229 // By default DefaultPingInterval will be used. 230 PingInterval time.Duration 231 232 // WriteTimeout is maximum time of write message operation. 233 // Slow client will be disconnected. 234 // By default DefaultWebsocketWriteTimeout will be used. 235 WriteTimeout time.Duration 236 237 // Compression allows to enable websocket permessage-deflate 238 // compression support for raw websocket connections. It does 239 // not guarantee that compression will be used - i.e. it only 240 // says that server will try to negotiate it with client. 241 Compression bool 242 243 // UseWriteBufferPool enables using buffer pool for writes. 244 UseWriteBufferPool bool 245} 246 247// WebsocketHandler handles WebSocket client connections. WebSocket protocol 248// is a bidirectional connection between a client an a server for low-latency 249// communication. 250type WebsocketHandler struct { 251 node *Node 252 upgrade *websocket.Upgrader 253 config WebsocketConfig 254} 255 256var writeBufferPool = &sync.Pool{} 257 258// NewWebsocketHandler creates new WebsocketHandler. 259func NewWebsocketHandler(n *Node, c WebsocketConfig) *WebsocketHandler { 260 upgrade := &websocket.Upgrader{ 261 ReadBufferSize: c.ReadBufferSize, 262 EnableCompression: c.Compression, 263 Subprotocols: []string{"centrifuge-protobuf"}, 264 } 265 if c.UseWriteBufferPool { 266 upgrade.WriteBufferPool = writeBufferPool 267 } else { 268 upgrade.WriteBufferSize = c.WriteBufferSize 269 } 270 if c.CheckOrigin != nil { 271 upgrade.CheckOrigin = c.CheckOrigin 272 } else { 273 upgrade.CheckOrigin = sameHostOriginCheck(n) 274 } 275 return &WebsocketHandler{ 276 node: n, 277 config: c, 278 upgrade: upgrade, 279 } 280} 281 282func (s *WebsocketHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { 283 incTransportConnect(transportWebsocket) 284 285 compression := s.config.Compression 286 compressionLevel := s.config.CompressionLevel 287 compressionMinSize := s.config.CompressionMinSize 288 289 conn, err := s.upgrade.Upgrade(rw, r, nil) 290 if err != nil { 291 s.node.logger.log(newLogEntry(LogLevelDebug, "websocket upgrade error", map[string]interface{}{"error": err.Error()})) 292 return 293 } 294 295 if compression { 296 err := conn.SetCompressionLevel(compressionLevel) 297 if err != nil { 298 s.node.logger.log(newLogEntry(LogLevelError, "websocket error setting compression level", map[string]interface{}{"error": err.Error()})) 299 } 300 } 301 302 pingInterval := s.config.PingInterval 303 if pingInterval == 0 { 304 pingInterval = DefaultWebsocketPingInterval 305 } 306 writeTimeout := s.config.WriteTimeout 307 if writeTimeout == 0 { 308 writeTimeout = DefaultWebsocketWriteTimeout 309 } 310 messageSizeLimit := s.config.MessageSizeLimit 311 if messageSizeLimit == 0 { 312 messageSizeLimit = DefaultWebsocketMessageSizeLimit 313 } 314 315 if messageSizeLimit > 0 { 316 conn.SetReadLimit(int64(messageSizeLimit)) 317 } 318 if pingInterval > 0 { 319 pongWait := pingInterval * 10 / 9 320 _ = conn.SetReadDeadline(time.Now().Add(pongWait)) 321 conn.SetPongHandler(func(string) error { 322 _ = conn.SetReadDeadline(time.Now().Add(pongWait)) 323 return nil 324 }) 325 } 326 327 var protoType = ProtocolTypeJSON 328 329 subProtocol := conn.Subprotocol() 330 if subProtocol == "centrifuge-protobuf" { 331 protoType = ProtocolTypeProtobuf 332 } else { 333 // This is a deprecated way to get a protocol type. 334 if r.URL.Query().Get("format") == "protobuf" || r.URL.Query().Get("protocol") == "protobuf" { 335 protoType = ProtocolTypeProtobuf 336 } 337 } 338 339 // Separate goroutine for better GC of caller's data. 340 go func() { 341 opts := websocketTransportOptions{ 342 pingInterval: pingInterval, 343 writeTimeout: writeTimeout, 344 compressionMinSize: compressionMinSize, 345 protoType: protoType, 346 } 347 348 graceCh := make(chan struct{}) 349 transport := newWebsocketTransport(conn, opts, graceCh) 350 351 select { 352 case <-s.node.NotifyShutdown(): 353 _ = transport.Close(DisconnectShutdown) 354 return 355 default: 356 } 357 358 ctxCh := make(chan struct{}) 359 defer close(ctxCh) 360 361 c, closeFn, err := NewClient(cancelctx.New(r.Context(), ctxCh), s.node, transport) 362 if err != nil { 363 s.node.logger.log(newLogEntry(LogLevelError, "error creating client", map[string]interface{}{"transport": transportWebsocket})) 364 return 365 } 366 defer func() { _ = closeFn() }() 367 368 s.node.logger.log(newLogEntry(LogLevelDebug, "client connection established", map[string]interface{}{"client": c.ID(), "transport": transportWebsocket})) 369 defer func(started time.Time) { 370 s.node.logger.log(newLogEntry(LogLevelDebug, "client connection completed", map[string]interface{}{"client": c.ID(), "transport": transportWebsocket, "duration": time.Since(started)})) 371 }(time.Now()) 372 373 for { 374 _, data, err := conn.ReadMessage() 375 if err != nil { 376 break 377 } 378 closed := !c.Handle(data) 379 if closed { 380 break 381 } 382 } 383 384 // https://github.com/gorilla/websocket/issues/448 385 conn.SetPingHandler(nil) 386 conn.SetPongHandler(nil) 387 conn.SetCloseHandler(nil) 388 _ = conn.SetReadDeadline(time.Now().Add(closeFrameWait)) 389 for { 390 if _, _, err := conn.NextReader(); err != nil { 391 close(graceCh) 392 break 393 } 394 } 395 }() 396} 397 398func sameHostOriginCheck(n *Node) func(r *http.Request) bool { 399 return func(r *http.Request) bool { 400 err := checkSameHost(r) 401 if err != nil { 402 n.logger.log(newLogEntry(LogLevelInfo, "origin check failure", map[string]interface{}{"error": err.Error()})) 403 return false 404 } 405 return true 406 } 407} 408 409func checkSameHost(r *http.Request) error { 410 origin := r.Header.Get("Origin") 411 if origin == "" { 412 return nil 413 } 414 u, err := url.Parse(origin) 415 if err != nil { 416 return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) 417 } 418 if strings.EqualFold(r.Host, u.Host) { 419 return nil 420 } 421 return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) 422} 423