// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.

package app

import (
	"hash/maphash"
	"runtime"
	"runtime/debug"
	"strconv"
	"sync/atomic"
	"time"

	"github.com/mattermost/mattermost-server/v6/model"
	"github.com/mattermost/mattermost-server/v6/shared/mlog"
)

const (
	// broadcastQueueSize is the buffer capacity of each hub's broadcast channel.
	broadcastQueueSize = 4096
	// inactiveConnReaperInterval is the period of the ticker that triggers
	// removal of inactive connections; the same value is used as the
	// staleness threshold of the hub's connection index.
	inactiveConnReaperInterval = 5 * time.Minute
)

// webConnActivityMessage asks the hub to record activityAt as the last user
// activity on the active connections of userID that use sessionToken.
type webConnActivityMessage struct {
	userID       string
	sessionToken string
	activityAt   int64
}

// webConnDirectMessage asks the hub to deliver msg to a single connection.
type webConnDirectMessage struct {
	conn *WebConn
	msg  model.WebSocketMessage
}

// webConnSessionMessage asks the hub whether userID has an active connection
// using sessionToken. The hub writes the answer to isRegistered.
type webConnSessionMessage struct {
	userID       string
	sessionToken string
	isRegistered chan bool
}

// webConnCheckMessage asks the hub to look up an inactive connection by
// userID and connectionID. The hub writes the outcome to result; nil means
// no such connection was found.
type webConnCheckMessage struct {
	userID       string
	connectionID string
	result       chan *CheckConnResult
}

// Hub is the central place to manage all websocket connections in the server.
// It handles different websocket events and sending messages to individual
// user connections.
//
// All request channels are consumed by a single goroutine launched from
// Start, which owns the connection index.
type Hub struct {
	// connectionCount should be kept first.
	// See https://github.com/mattermost/mattermost-server/pull/7281
	// It is written atomically by the hub loop with the number of active
	// connections.
	connectionCount int64
	// app provides configuration, metrics and user status updates.
	app *App
	// connectionIndex is the position of this hub in the server's hub slice;
	// it is also used as the label for per-hub metrics.
	connectionIndex int
	// register receives connections to add to the hub.
	register chan *WebConn
	// unregister receives connections to remove (or mark inactive).
	unregister chan *WebConn
	// broadcast receives events to fan out; buffered to broadcastQueueSize.
	broadcast chan *model.WebSocketEvent
	// stop is closed by Stop to request shutdown.
	stop chan struct{}
	// didStop is closed by the hub loop once shutdown has been handled.
	didStop chan struct{}
	// invalidateUser receives IDs of users whose connection caches must be
	// invalidated.
	invalidateUser chan string
	// activity receives last-activity updates for user sessions.
	activity chan *webConnActivityMessage
	// directMsg receives messages addressed to one specific connection.
	directMsg chan *webConnDirectMessage
	// explicitStop is set by the hub loop when it exits because of Stop, so
	// that the recovery handler does not restart it.
	explicitStop bool
	// checkRegistered receives session registration queries.
	checkRegistered chan *webConnSessionMessage
	// checkConn receives inactive-connection lookups.
	checkConn chan *webConnCheckMessage
}

// NewWebHub creates a new Hub.
69func (a *App) NewWebHub() *Hub { 70 return &Hub{ 71 app: a, 72 register: make(chan *WebConn), 73 unregister: make(chan *WebConn), 74 broadcast: make(chan *model.WebSocketEvent, broadcastQueueSize), 75 stop: make(chan struct{}), 76 didStop: make(chan struct{}), 77 invalidateUser: make(chan string), 78 activity: make(chan *webConnActivityMessage), 79 directMsg: make(chan *webConnDirectMessage), 80 checkRegistered: make(chan *webConnSessionMessage), 81 checkConn: make(chan *webConnCheckMessage), 82 } 83} 84 85func (a *App) TotalWebsocketConnections() int { 86 return a.Srv().TotalWebsocketConnections() 87} 88 89// HubStart starts all the hubs. 90func (a *App) HubStart() { 91 // Total number of hubs is twice the number of CPUs. 92 numberOfHubs := runtime.NumCPU() * 2 93 mlog.Info("Starting websocket hubs", mlog.Int("number_of_hubs", numberOfHubs)) 94 95 hubs := make([]*Hub, numberOfHubs) 96 97 for i := 0; i < numberOfHubs; i++ { 98 hubs[i] = a.NewWebHub() 99 hubs[i].connectionIndex = i 100 hubs[i].Start() 101 } 102 // Assigning to the hubs slice without any mutex is fine because it is only assigned once 103 // during the start of the program and always read from after that. 104 a.srv.hubs = hubs 105} 106 107func (a *App) invalidateCacheForWebhook(webhookID string) { 108 a.Srv().Store.Webhook().InvalidateWebhookCache(webhookID) 109} 110 111// HubStop stops all the hubs. 112func (s *Server) HubStop() { 113 mlog.Info("stopping websocket hub connections") 114 115 for _, hub := range s.hubs { 116 hub.Stop() 117 } 118} 119 120func (a *App) HubStop() { 121 a.Srv().HubStop() 122} 123 124// GetHubForUserId returns the hub for a given user id. 125func (s *Server) GetHubForUserId(userID string) *Hub { 126 // TODO: check if caching the userID -> hub mapping 127 // is worth the memory tradeoff. 128 // https://mattermost.atlassian.net/browse/MM-26629. 
129 var hash maphash.Hash 130 hash.SetSeed(s.hashSeed) 131 hash.Write([]byte(userID)) 132 index := hash.Sum64() % uint64(len(s.hubs)) 133 134 return s.hubs[int(index)] 135} 136 137func (a *App) GetHubForUserId(userID string) *Hub { 138 return a.Srv().GetHubForUserId(userID) 139} 140 141// HubRegister registers a connection to a hub. 142func (a *App) HubRegister(webConn *WebConn) { 143 hub := a.GetHubForUserId(webConn.UserId) 144 if hub != nil { 145 if metrics := a.Metrics(); metrics != nil { 146 metrics.IncrementWebSocketBroadcastUsersRegistered(strconv.Itoa(hub.connectionIndex), 1) 147 } 148 hub.Register(webConn) 149 } 150} 151 152// HubUnregister unregisters a connection from a hub. 153func (a *App) HubUnregister(webConn *WebConn) { 154 hub := a.GetHubForUserId(webConn.UserId) 155 if hub != nil { 156 if metrics := a.Metrics(); metrics != nil { 157 metrics.DecrementWebSocketBroadcastUsersRegistered(strconv.Itoa(hub.connectionIndex), 1) 158 } 159 hub.Unregister(webConn) 160 } 161} 162 163func (s *Server) Publish(message *model.WebSocketEvent) { 164 if s.Metrics != nil { 165 s.Metrics.IncrementWebsocketEvent(message.EventType()) 166 } 167 168 s.PublishSkipClusterSend(message) 169 170 if s.Cluster != nil { 171 data, err := message.ToJSON() 172 if err != nil { 173 mlog.Warn("Failed to encode message to JSON", mlog.Err(err)) 174 } 175 cm := &model.ClusterMessage{ 176 Event: model.ClusterEventPublish, 177 SendType: model.ClusterSendBestEffort, 178 Data: data, 179 } 180 181 if message.EventType() == model.WebsocketEventPosted || 182 message.EventType() == model.WebsocketEventPostEdited || 183 message.EventType() == model.WebsocketEventDirectAdded || 184 message.EventType() == model.WebsocketEventGroupAdded || 185 message.EventType() == model.WebsocketEventAddedToTeam { 186 cm.SendType = model.ClusterSendReliable 187 } 188 189 s.Cluster.SendClusterMessage(cm) 190 } 191} 192 193func (a *App) Publish(message *model.WebSocketEvent) { 194 a.Srv().Publish(message) 195} 196 
197func (s *Server) PublishSkipClusterSend(event *model.WebSocketEvent) { 198 if event.GetBroadcast().UserId != "" { 199 hub := s.GetHubForUserId(event.GetBroadcast().UserId) 200 if hub != nil { 201 hub.Broadcast(event) 202 } 203 } else { 204 for _, hub := range s.hubs { 205 hub.Broadcast(event) 206 } 207 } 208 209 // Notify shared channel sync service 210 s.SharedChannelSyncHandler(event) 211} 212 213func (a *App) invalidateCacheForChannel(channel *model.Channel) { 214 a.Srv().Store.Channel().InvalidateChannel(channel.Id) 215 a.Srv().invalidateCacheForChannelByNameSkipClusterSend(channel.TeamId, channel.Name) 216 217 if a.Cluster() != nil { 218 nameMsg := &model.ClusterMessage{ 219 Event: model.ClusterEventInvalidateCacheForChannelByName, 220 SendType: model.ClusterSendBestEffort, 221 Props: make(map[string]string), 222 } 223 224 nameMsg.Props["name"] = channel.Name 225 if channel.TeamId == "" { 226 nameMsg.Props["id"] = "dm" 227 } else { 228 nameMsg.Props["id"] = channel.TeamId 229 } 230 231 a.Cluster().SendClusterMessage(nameMsg) 232 } 233} 234 235func (a *App) invalidateCacheForChannelMembers(channelID string) { 236 a.Srv().Store.User().InvalidateProfilesInChannelCache(channelID) 237 a.Srv().Store.Channel().InvalidateMemberCount(channelID) 238 a.Srv().Store.Channel().InvalidateGuestCount(channelID) 239} 240 241func (a *App) invalidateCacheForChannelMembersNotifyProps(channelID string) { 242 a.Srv().invalidateCacheForChannelMembersNotifyPropsSkipClusterSend(channelID) 243 244 if a.Cluster() != nil { 245 msg := &model.ClusterMessage{ 246 Event: model.ClusterEventInvalidateCacheForChannelMembersNotifyProps, 247 SendType: model.ClusterSendBestEffort, 248 Data: []byte(channelID), 249 } 250 a.Cluster().SendClusterMessage(msg) 251 } 252} 253 254func (a *App) invalidateCacheForChannelPosts(channelID string) { 255 a.Srv().Store.Channel().InvalidatePinnedPostCount(channelID) 256 a.Srv().Store.Post().InvalidateLastPostTimeCache(channelID) 257} 258 259func (a *App) 
InvalidateCacheForUser(userID string) { 260 a.Srv().invalidateCacheForUserSkipClusterSend(userID) 261 262 a.srv.userService.InvalidateCacheForUser(userID) 263} 264 265func (a *App) invalidateCacheForUserTeams(userID string) { 266 a.Srv().invalidateWebConnSessionCacheForUser(userID) 267 a.Srv().Store.Team().InvalidateAllTeamIdsForUser(userID) 268 269 if a.Cluster() != nil { 270 msg := &model.ClusterMessage{ 271 Event: model.ClusterEventInvalidateCacheForUserTeams, 272 SendType: model.ClusterSendBestEffort, 273 Data: []byte(userID), 274 } 275 a.Cluster().SendClusterMessage(msg) 276 } 277} 278 279// UpdateWebConnUserActivity sets the LastUserActivityAt of the hub for the given session. 280func (a *App) UpdateWebConnUserActivity(session model.Session, activityAt int64) { 281 hub := a.GetHubForUserId(session.UserId) 282 if hub != nil { 283 hub.UpdateActivity(session.UserId, session.Token, activityAt) 284 } 285} 286 287// SessionIsRegistered determines if a specific session has been registered 288func (a *App) SessionIsRegistered(session model.Session) bool { 289 hub := a.GetHubForUserId(session.UserId) 290 if hub != nil { 291 return hub.IsRegistered(session.UserId, session.Token) 292 } 293 return false 294} 295 296func (a *App) CheckWebConn(userID, connectionID string) *CheckConnResult { 297 hub := a.GetHubForUserId(userID) 298 if hub != nil { 299 return hub.CheckConn(userID, connectionID) 300 } 301 return nil 302} 303 304// Register registers a connection to the hub. 305func (h *Hub) Register(webConn *WebConn) { 306 select { 307 case h.register <- webConn: 308 case <-h.stop: 309 } 310} 311 312// Unregister unregisters a connection from the hub. 313func (h *Hub) Unregister(webConn *WebConn) { 314 select { 315 case h.unregister <- webConn: 316 case <-h.stop: 317 } 318} 319 320// Determines if a user's session is registered a connection from the hub. 
321func (h *Hub) IsRegistered(userID, sessionToken string) bool { 322 ws := &webConnSessionMessage{ 323 userID: userID, 324 sessionToken: sessionToken, 325 isRegistered: make(chan bool), 326 } 327 select { 328 case h.checkRegistered <- ws: 329 return <-ws.isRegistered 330 case <-h.stop: 331 } 332 return false 333} 334 335func (h *Hub) CheckConn(userID, connectionID string) *CheckConnResult { 336 req := &webConnCheckMessage{ 337 userID: userID, 338 connectionID: connectionID, 339 result: make(chan *CheckConnResult), 340 } 341 select { 342 case h.checkConn <- req: 343 return <-req.result 344 case <-h.stop: 345 } 346 return nil 347} 348 349// Broadcast broadcasts the message to all connections in the hub. 350func (h *Hub) Broadcast(message *model.WebSocketEvent) { 351 // XXX: The hub nil check is because of the way we setup our tests. We call 352 // `app.NewServer()` which returns a server, but only after that, we call 353 // `wsapi.Init()` to initialize the hub. But in the `NewServer` call 354 // itself proceeds to broadcast some messages happily. This needs to be 355 // fixed once the wsapi cyclic dependency with server/app goes away. 356 // And possibly, we can look into doing the hub initialization inside 357 // NewServer itself. 358 if h != nil && message != nil { 359 if metrics := h.app.Metrics(); metrics != nil { 360 metrics.IncrementWebSocketBroadcastBufferSize(strconv.Itoa(h.connectionIndex), 1) 361 } 362 select { 363 case h.broadcast <- message: 364 case <-h.stop: 365 } 366 } 367} 368 369// InvalidateUser invalidates the cache for the given user. 370func (h *Hub) InvalidateUser(userID string) { 371 select { 372 case h.invalidateUser <- userID: 373 case <-h.stop: 374 } 375} 376 377// UpdateActivity sets the LastUserActivityAt field for the connection 378// of the user. 
379func (h *Hub) UpdateActivity(userID, sessionToken string, activityAt int64) { 380 select { 381 case h.activity <- &webConnActivityMessage{ 382 userID: userID, 383 sessionToken: sessionToken, 384 activityAt: activityAt, 385 }: 386 case <-h.stop: 387 } 388} 389 390// SendMessage sends the given message to the given connection. 391func (h *Hub) SendMessage(conn *WebConn, msg model.WebSocketMessage) { 392 select { 393 case h.directMsg <- &webConnDirectMessage{ 394 conn: conn, 395 msg: msg, 396 }: 397 case <-h.stop: 398 } 399} 400 401// Stop stops the hub. 402func (h *Hub) Stop() { 403 close(h.stop) 404 <-h.didStop 405} 406 407// Start starts the hub. 408func (h *Hub) Start() { 409 var doStart func() 410 var doRecoverableStart func() 411 var doRecover func() 412 413 doStart = func() { 414 mlog.Debug("Hub is starting", mlog.Int("index", h.connectionIndex)) 415 416 ticker := time.NewTicker(inactiveConnReaperInterval) 417 defer ticker.Stop() 418 419 connIndex := newHubConnectionIndex(inactiveConnReaperInterval) 420 421 for { 422 select { 423 case webSessionMessage := <-h.checkRegistered: 424 conns := connIndex.ForUser(webSessionMessage.userID) 425 var isRegistered bool 426 for _, conn := range conns { 427 if !conn.active { 428 continue 429 } 430 if conn.GetSessionToken() == webSessionMessage.sessionToken { 431 isRegistered = true 432 } 433 } 434 webSessionMessage.isRegistered <- isRegistered 435 case req := <-h.checkConn: 436 var res *CheckConnResult 437 conn := connIndex.GetInactiveByConnectionID(req.userID, req.connectionID) 438 if conn != nil { 439 res = &CheckConnResult{ 440 ConnectionID: req.connectionID, 441 UserID: req.userID, 442 ActiveQueue: conn.send, 443 DeadQueue: conn.deadQueue, 444 DeadQueuePointer: conn.deadQueuePointer, 445 } 446 } 447 req.result <- res 448 case <-ticker.C: 449 connIndex.RemoveInactiveConnections() 450 case webConn := <-h.register: 451 var oldConn *WebConn 452 if *h.app.Config().ServiceSettings.EnableReliableWebSockets { 453 // Delete 
the old conn from connIndex if it exists. 454 oldConn = connIndex.RemoveInactiveByConnectionID( 455 webConn.GetSession().UserId, 456 webConn.GetConnectionID()) 457 } 458 459 // Mark the current one as active. 460 // There is no need to check if it was inactive or not, 461 // we will anyways need to make it active. 462 webConn.active = true 463 464 connIndex.Add(webConn) 465 atomic.StoreInt64(&h.connectionCount, int64(connIndex.AllActive())) 466 467 if webConn.IsAuthenticated() && oldConn == nil { 468 // The hello message should only be sent when the conn wasn't found. 469 // i.e in server restart, or long timeout, or fresh connection case. 470 // In case of seq number not found in dead queue, it is handled by 471 // the webconn write pump. 472 webConn.send <- webConn.createHelloMessage() 473 } 474 case webConn := <-h.unregister: 475 // If already removed (via queue full), then removing again becomes a noop. 476 // But if not removed, mark inactive. 477 if *h.app.Config().ServiceSettings.EnableReliableWebSockets { 478 webConn.active = false 479 } else { 480 connIndex.Remove(webConn) 481 } 482 483 atomic.StoreInt64(&h.connectionCount, int64(connIndex.AllActive())) 484 485 if webConn.UserId == "" { 486 continue 487 } 488 489 conns := connIndex.ForUser(webConn.UserId) 490 if len(conns) == 0 || areAllInactive(conns) { 491 h.app.Srv().Go(func() { 492 h.app.SetStatusOffline(webConn.UserId, false) 493 }) 494 continue 495 } 496 var latestActivity int64 = 0 497 for _, conn := range conns { 498 if !conn.active { 499 continue 500 } 501 if conn.lastUserActivityAt > latestActivity { 502 latestActivity = conn.lastUserActivityAt 503 } 504 } 505 506 if h.app.IsUserAway(latestActivity) { 507 h.app.Srv().Go(func() { 508 h.app.SetStatusLastActivityAt(webConn.UserId, latestActivity) 509 }) 510 } 511 case userID := <-h.invalidateUser: 512 for _, webConn := range connIndex.ForUser(userID) { 513 webConn.InvalidateCache() 514 } 515 case activity := <-h.activity: 516 for _, webConn := range 
connIndex.ForUser(activity.userID) { 517 if !webConn.active { 518 continue 519 } 520 if webConn.GetSessionToken() == activity.sessionToken { 521 webConn.lastUserActivityAt = activity.activityAt 522 } 523 } 524 case directMsg := <-h.directMsg: 525 if !connIndex.Has(directMsg.conn) { 526 continue 527 } 528 select { 529 case directMsg.conn.send <- directMsg.msg: 530 default: 531 mlog.Error("webhub.broadcast: cannot send, closing websocket for user", mlog.String("user_id", directMsg.conn.UserId)) 532 close(directMsg.conn.send) 533 connIndex.Remove(directMsg.conn) 534 } 535 case msg := <-h.broadcast: 536 if metrics := h.app.Metrics(); metrics != nil { 537 metrics.DecrementWebSocketBroadcastBufferSize(strconv.Itoa(h.connectionIndex), 1) 538 } 539 msg = msg.PrecomputeJSON() 540 broadcast := func(webConn *WebConn) { 541 if !connIndex.Has(webConn) { 542 return 543 } 544 if webConn.shouldSendEvent(msg) { 545 select { 546 case webConn.send <- msg: 547 default: 548 mlog.Error("webhub.broadcast: cannot send, closing websocket for user", mlog.String("user_id", webConn.UserId)) 549 close(webConn.send) 550 connIndex.Remove(webConn) 551 } 552 } 553 } 554 if msg.GetBroadcast().UserId != "" { 555 candidates := connIndex.ForUser(msg.GetBroadcast().UserId) 556 for _, webConn := range candidates { 557 broadcast(webConn) 558 } 559 continue 560 } 561 candidates := connIndex.All() 562 for webConn := range candidates { 563 broadcast(webConn) 564 } 565 case <-h.stop: 566 for webConn := range connIndex.All() { 567 webConn.Close() 568 h.app.SetStatusOffline(webConn.UserId, false) 569 } 570 571 h.explicitStop = true 572 close(h.didStop) 573 574 return 575 } 576 } 577 } 578 579 doRecoverableStart = func() { 580 defer doRecover() 581 doStart() 582 } 583 584 doRecover = func() { 585 if !h.explicitStop { 586 if r := recover(); r != nil { 587 mlog.Error("Recovering from Hub panic.", mlog.Any("panic", r)) 588 } else { 589 mlog.Error("Webhub stopped unexpectedly. 
Recovering.") 590 } 591 592 mlog.Error(string(debug.Stack())) 593 594 go doRecoverableStart() 595 } 596 } 597 598 go doRecoverableStart() 599} 600 601// hubConnectionIndex provides fast addition, removal, and iteration of web connections. 602// It requires 3 functionalities which need to be very fast: 603// - check if a connection exists or not. 604// - get all connections for a given userID. 605// - get all connections. 606type hubConnectionIndex struct { 607 // byUserId stores the list of connections for a given userID 608 byUserId map[string][]*WebConn 609 // byConnection serves the dual purpose of storing the index of the webconn 610 // in the value of byUserId map, and also to get all connections. 611 byConnection map[*WebConn]int 612 // staleThreshold is the limit beyond which inactive connections 613 // will be deleted. 614 staleThreshold time.Duration 615} 616 617func newHubConnectionIndex(interval time.Duration) *hubConnectionIndex { 618 return &hubConnectionIndex{ 619 byUserId: make(map[string][]*WebConn), 620 byConnection: make(map[*WebConn]int), 621 staleThreshold: interval, 622 } 623} 624 625func (i *hubConnectionIndex) Add(wc *WebConn) { 626 i.byUserId[wc.UserId] = append(i.byUserId[wc.UserId], wc) 627 i.byConnection[wc] = len(i.byUserId[wc.UserId]) - 1 628} 629 630func (i *hubConnectionIndex) Remove(wc *WebConn) { 631 userConnIndex, ok := i.byConnection[wc] 632 if !ok { 633 return 634 } 635 636 // get the conn slice. 637 userConnections := i.byUserId[wc.UserId] 638 // get the last connection. 639 last := userConnections[len(userConnections)-1] 640 // set the slot that we are trying to remove to be the last connection. 641 userConnections[userConnIndex] = last 642 // remove the last connection from the slice. 643 i.byUserId[wc.UserId] = userConnections[:len(userConnections)-1] 644 // set the index of the connection that was moved to the new index. 
645 i.byConnection[last] = userConnIndex 646 647 delete(i.byConnection, wc) 648} 649 650func (i *hubConnectionIndex) Has(wc *WebConn) bool { 651 _, ok := i.byConnection[wc] 652 return ok 653} 654 655// ForUser returns all connections for a user ID. 656func (i *hubConnectionIndex) ForUser(id string) []*WebConn { 657 return i.byUserId[id] 658} 659 660// All returns the full webConn index. 661func (i *hubConnectionIndex) All() map[*WebConn]int { 662 return i.byConnection 663} 664 665// GetInactiveByConnectionID returns an inactive connection for the given 666// userID and connectionID. 667func (i *hubConnectionIndex) GetInactiveByConnectionID(userID, connectionID string) *WebConn { 668 // To handle empty sessions. 669 if userID == "" { 670 return nil 671 } 672 for _, conn := range i.ForUser(userID) { 673 if conn.GetConnectionID() == connectionID && !conn.active { 674 return conn 675 } 676 } 677 return nil 678} 679 680// RemoveInactiveByConnectionID removes an inactive connection for the given 681// userID and connectionID. 682func (i *hubConnectionIndex) RemoveInactiveByConnectionID(userID, connectionID string) *WebConn { 683 // To handle empty sessions. 684 if userID == "" { 685 return nil 686 } 687 for _, conn := range i.ForUser(userID) { 688 if conn.GetConnectionID() == connectionID && !conn.active { 689 i.Remove(conn) 690 return conn 691 } 692 } 693 return nil 694} 695 696// RemoveInactiveConnections removes all inactive connections whose lastUserActivityAt 697// exceeded staleThreshold. 698func (i *hubConnectionIndex) RemoveInactiveConnections() { 699 now := model.GetMillis() 700 for conn := range i.byConnection { 701 if !conn.active && now-conn.lastUserActivityAt > i.staleThreshold.Milliseconds() { 702 i.Remove(conn) 703 } 704 } 705} 706 707// AllActive returns the number of active connections. 708// This is only called during register/unregister so we can take 709// a bit of perf hit here. 
710func (i *hubConnectionIndex) AllActive() int { 711 cnt := 0 712 for conn := range i.byConnection { 713 if conn.active { 714 cnt++ 715 } 716 } 717 return cnt 718} 719