1// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
2// See LICENSE.txt for license information.
3
4package app
5
6import (
7	"hash/maphash"
8	"runtime"
9	"runtime/debug"
10	"strconv"
11	"sync/atomic"
12	"time"
13
14	"github.com/mattermost/mattermost-server/v6/model"
15	"github.com/mattermost/mattermost-server/v6/shared/mlog"
16)
17
const (
	// broadcastQueueSize is the capacity of each hub's buffered broadcast
	// channel (4096 events).
	broadcastQueueSize = 1 << 12
	// inactiveConnReaperInterval is how often each hub purges stale inactive
	// connections; the same duration is used as the staleness threshold.
	inactiveConnReaperInterval = time.Minute * 5
)
22
23type webConnActivityMessage struct {
24	userID       string
25	sessionToken string
26	activityAt   int64
27}
28
29type webConnDirectMessage struct {
30	conn *WebConn
31	msg  model.WebSocketMessage
32}
33
34type webConnSessionMessage struct {
35	userID       string
36	sessionToken string
37	isRegistered chan bool
38}
39
40type webConnCheckMessage struct {
41	userID       string
42	connectionID string
43	result       chan *CheckConnResult
44}
45
// Hub is the central place to manage all websocket connections in the server.
// It handles different websocket events and sending messages to individual
// user connections.
//
// A Hub's connection state lives inside the event loop started by Start.
// The methods on Hub talk to that loop through the channels below:
// register/unregister add and retire connections, broadcast (buffered,
// broadcastQueueSize) carries outgoing events, invalidateUser/activity/
// directMsg carry per-user or per-connection updates, and checkRegistered/
// checkConn are request/reply queries. Closing stop shuts the loop down,
// and the loop closes didStop once it has finished.
//
// connectionCount is written with sync/atomic by the event loop and holds
// the number of active connections in this hub. connectionIndex is this
// hub's position in the server's hub slice and is used as its metrics
// label. explicitStop is set by the loop when it exits because of stop, so
// the recovery logic in Start does not restart it.
type Hub struct {
	// connectionCount should be kept first.
	// See https://github.com/mattermost/mattermost-server/pull/7281
	// (presumably because 64-bit atomic operations need 64-bit alignment,
	// which on 32-bit platforms is only guaranteed for the first word of an
	// allocated struct — see the sync/atomic package docs).
	connectionCount int64
	app             *App
	connectionIndex int
	register        chan *WebConn
	unregister      chan *WebConn
	broadcast       chan *model.WebSocketEvent
	stop            chan struct{}
	didStop         chan struct{}
	invalidateUser  chan string
	activity        chan *webConnActivityMessage
	directMsg       chan *webConnDirectMessage
	explicitStop    bool
	checkRegistered chan *webConnSessionMessage
	checkConn       chan *webConnCheckMessage
}
67
68// NewWebHub creates a new Hub.
69func (a *App) NewWebHub() *Hub {
70	return &Hub{
71		app:             a,
72		register:        make(chan *WebConn),
73		unregister:      make(chan *WebConn),
74		broadcast:       make(chan *model.WebSocketEvent, broadcastQueueSize),
75		stop:            make(chan struct{}),
76		didStop:         make(chan struct{}),
77		invalidateUser:  make(chan string),
78		activity:        make(chan *webConnActivityMessage),
79		directMsg:       make(chan *webConnDirectMessage),
80		checkRegistered: make(chan *webConnSessionMessage),
81		checkConn:       make(chan *webConnCheckMessage),
82	}
83}
84
85func (a *App) TotalWebsocketConnections() int {
86	return a.Srv().TotalWebsocketConnections()
87}
88
89// HubStart starts all the hubs.
90func (a *App) HubStart() {
91	// Total number of hubs is twice the number of CPUs.
92	numberOfHubs := runtime.NumCPU() * 2
93	mlog.Info("Starting websocket hubs", mlog.Int("number_of_hubs", numberOfHubs))
94
95	hubs := make([]*Hub, numberOfHubs)
96
97	for i := 0; i < numberOfHubs; i++ {
98		hubs[i] = a.NewWebHub()
99		hubs[i].connectionIndex = i
100		hubs[i].Start()
101	}
102	// Assigning to the hubs slice without any mutex is fine because it is only assigned once
103	// during the start of the program and always read from after that.
104	a.srv.hubs = hubs
105}
106
107func (a *App) invalidateCacheForWebhook(webhookID string) {
108	a.Srv().Store.Webhook().InvalidateWebhookCache(webhookID)
109}
110
111// HubStop stops all the hubs.
112func (s *Server) HubStop() {
113	mlog.Info("stopping websocket hub connections")
114
115	for _, hub := range s.hubs {
116		hub.Stop()
117	}
118}
119
120func (a *App) HubStop() {
121	a.Srv().HubStop()
122}
123
124// GetHubForUserId returns the hub for a given user id.
125func (s *Server) GetHubForUserId(userID string) *Hub {
126	// TODO: check if caching the userID -> hub mapping
127	// is worth the memory tradeoff.
128	// https://mattermost.atlassian.net/browse/MM-26629.
129	var hash maphash.Hash
130	hash.SetSeed(s.hashSeed)
131	hash.Write([]byte(userID))
132	index := hash.Sum64() % uint64(len(s.hubs))
133
134	return s.hubs[int(index)]
135}
136
137func (a *App) GetHubForUserId(userID string) *Hub {
138	return a.Srv().GetHubForUserId(userID)
139}
140
141// HubRegister registers a connection to a hub.
142func (a *App) HubRegister(webConn *WebConn) {
143	hub := a.GetHubForUserId(webConn.UserId)
144	if hub != nil {
145		if metrics := a.Metrics(); metrics != nil {
146			metrics.IncrementWebSocketBroadcastUsersRegistered(strconv.Itoa(hub.connectionIndex), 1)
147		}
148		hub.Register(webConn)
149	}
150}
151
152// HubUnregister unregisters a connection from a hub.
153func (a *App) HubUnregister(webConn *WebConn) {
154	hub := a.GetHubForUserId(webConn.UserId)
155	if hub != nil {
156		if metrics := a.Metrics(); metrics != nil {
157			metrics.DecrementWebSocketBroadcastUsersRegistered(strconv.Itoa(hub.connectionIndex), 1)
158		}
159		hub.Unregister(webConn)
160	}
161}
162
163func (s *Server) Publish(message *model.WebSocketEvent) {
164	if s.Metrics != nil {
165		s.Metrics.IncrementWebsocketEvent(message.EventType())
166	}
167
168	s.PublishSkipClusterSend(message)
169
170	if s.Cluster != nil {
171		data, err := message.ToJSON()
172		if err != nil {
173			mlog.Warn("Failed to encode message to JSON", mlog.Err(err))
174		}
175		cm := &model.ClusterMessage{
176			Event:    model.ClusterEventPublish,
177			SendType: model.ClusterSendBestEffort,
178			Data:     data,
179		}
180
181		if message.EventType() == model.WebsocketEventPosted ||
182			message.EventType() == model.WebsocketEventPostEdited ||
183			message.EventType() == model.WebsocketEventDirectAdded ||
184			message.EventType() == model.WebsocketEventGroupAdded ||
185			message.EventType() == model.WebsocketEventAddedToTeam {
186			cm.SendType = model.ClusterSendReliable
187		}
188
189		s.Cluster.SendClusterMessage(cm)
190	}
191}
192
193func (a *App) Publish(message *model.WebSocketEvent) {
194	a.Srv().Publish(message)
195}
196
197func (s *Server) PublishSkipClusterSend(event *model.WebSocketEvent) {
198	if event.GetBroadcast().UserId != "" {
199		hub := s.GetHubForUserId(event.GetBroadcast().UserId)
200		if hub != nil {
201			hub.Broadcast(event)
202		}
203	} else {
204		for _, hub := range s.hubs {
205			hub.Broadcast(event)
206		}
207	}
208
209	// Notify shared channel sync service
210	s.SharedChannelSyncHandler(event)
211}
212
213func (a *App) invalidateCacheForChannel(channel *model.Channel) {
214	a.Srv().Store.Channel().InvalidateChannel(channel.Id)
215	a.Srv().invalidateCacheForChannelByNameSkipClusterSend(channel.TeamId, channel.Name)
216
217	if a.Cluster() != nil {
218		nameMsg := &model.ClusterMessage{
219			Event:    model.ClusterEventInvalidateCacheForChannelByName,
220			SendType: model.ClusterSendBestEffort,
221			Props:    make(map[string]string),
222		}
223
224		nameMsg.Props["name"] = channel.Name
225		if channel.TeamId == "" {
226			nameMsg.Props["id"] = "dm"
227		} else {
228			nameMsg.Props["id"] = channel.TeamId
229		}
230
231		a.Cluster().SendClusterMessage(nameMsg)
232	}
233}
234
235func (a *App) invalidateCacheForChannelMembers(channelID string) {
236	a.Srv().Store.User().InvalidateProfilesInChannelCache(channelID)
237	a.Srv().Store.Channel().InvalidateMemberCount(channelID)
238	a.Srv().Store.Channel().InvalidateGuestCount(channelID)
239}
240
241func (a *App) invalidateCacheForChannelMembersNotifyProps(channelID string) {
242	a.Srv().invalidateCacheForChannelMembersNotifyPropsSkipClusterSend(channelID)
243
244	if a.Cluster() != nil {
245		msg := &model.ClusterMessage{
246			Event:    model.ClusterEventInvalidateCacheForChannelMembersNotifyProps,
247			SendType: model.ClusterSendBestEffort,
248			Data:     []byte(channelID),
249		}
250		a.Cluster().SendClusterMessage(msg)
251	}
252}
253
254func (a *App) invalidateCacheForChannelPosts(channelID string) {
255	a.Srv().Store.Channel().InvalidatePinnedPostCount(channelID)
256	a.Srv().Store.Post().InvalidateLastPostTimeCache(channelID)
257}
258
259func (a *App) InvalidateCacheForUser(userID string) {
260	a.Srv().invalidateCacheForUserSkipClusterSend(userID)
261
262	a.srv.userService.InvalidateCacheForUser(userID)
263}
264
265func (a *App) invalidateCacheForUserTeams(userID string) {
266	a.Srv().invalidateWebConnSessionCacheForUser(userID)
267	a.Srv().Store.Team().InvalidateAllTeamIdsForUser(userID)
268
269	if a.Cluster() != nil {
270		msg := &model.ClusterMessage{
271			Event:    model.ClusterEventInvalidateCacheForUserTeams,
272			SendType: model.ClusterSendBestEffort,
273			Data:     []byte(userID),
274		}
275		a.Cluster().SendClusterMessage(msg)
276	}
277}
278
279// UpdateWebConnUserActivity sets the LastUserActivityAt of the hub for the given session.
280func (a *App) UpdateWebConnUserActivity(session model.Session, activityAt int64) {
281	hub := a.GetHubForUserId(session.UserId)
282	if hub != nil {
283		hub.UpdateActivity(session.UserId, session.Token, activityAt)
284	}
285}
286
287// SessionIsRegistered determines if a specific session has been registered
288func (a *App) SessionIsRegistered(session model.Session) bool {
289	hub := a.GetHubForUserId(session.UserId)
290	if hub != nil {
291		return hub.IsRegistered(session.UserId, session.Token)
292	}
293	return false
294}
295
296func (a *App) CheckWebConn(userID, connectionID string) *CheckConnResult {
297	hub := a.GetHubForUserId(userID)
298	if hub != nil {
299		return hub.CheckConn(userID, connectionID)
300	}
301	return nil
302}
303
304// Register registers a connection to the hub.
305func (h *Hub) Register(webConn *WebConn) {
306	select {
307	case h.register <- webConn:
308	case <-h.stop:
309	}
310}
311
312// Unregister unregisters a connection from the hub.
313func (h *Hub) Unregister(webConn *WebConn) {
314	select {
315	case h.unregister <- webConn:
316	case <-h.stop:
317	}
318}
319
320// Determines if a user's session is registered a connection from the hub.
321func (h *Hub) IsRegistered(userID, sessionToken string) bool {
322	ws := &webConnSessionMessage{
323		userID:       userID,
324		sessionToken: sessionToken,
325		isRegistered: make(chan bool),
326	}
327	select {
328	case h.checkRegistered <- ws:
329		return <-ws.isRegistered
330	case <-h.stop:
331	}
332	return false
333}
334
335func (h *Hub) CheckConn(userID, connectionID string) *CheckConnResult {
336	req := &webConnCheckMessage{
337		userID:       userID,
338		connectionID: connectionID,
339		result:       make(chan *CheckConnResult),
340	}
341	select {
342	case h.checkConn <- req:
343		return <-req.result
344	case <-h.stop:
345	}
346	return nil
347}
348
349// Broadcast broadcasts the message to all connections in the hub.
350func (h *Hub) Broadcast(message *model.WebSocketEvent) {
351	// XXX: The hub nil check is because of the way we setup our tests. We call
352	// `app.NewServer()` which returns a server, but only after that, we call
353	// `wsapi.Init()` to initialize the hub.  But in the `NewServer` call
354	// itself proceeds to broadcast some messages happily.  This needs to be
355	// fixed once the wsapi cyclic dependency with server/app goes away.
356	// And possibly, we can look into doing the hub initialization inside
357	// NewServer itself.
358	if h != nil && message != nil {
359		if metrics := h.app.Metrics(); metrics != nil {
360			metrics.IncrementWebSocketBroadcastBufferSize(strconv.Itoa(h.connectionIndex), 1)
361		}
362		select {
363		case h.broadcast <- message:
364		case <-h.stop:
365		}
366	}
367}
368
369// InvalidateUser invalidates the cache for the given user.
370func (h *Hub) InvalidateUser(userID string) {
371	select {
372	case h.invalidateUser <- userID:
373	case <-h.stop:
374	}
375}
376
377// UpdateActivity sets the LastUserActivityAt field for the connection
378// of the user.
379func (h *Hub) UpdateActivity(userID, sessionToken string, activityAt int64) {
380	select {
381	case h.activity <- &webConnActivityMessage{
382		userID:       userID,
383		sessionToken: sessionToken,
384		activityAt:   activityAt,
385	}:
386	case <-h.stop:
387	}
388}
389
390// SendMessage sends the given message to the given connection.
391func (h *Hub) SendMessage(conn *WebConn, msg model.WebSocketMessage) {
392	select {
393	case h.directMsg <- &webConnDirectMessage{
394		conn: conn,
395		msg:  msg,
396	}:
397	case <-h.stop:
398	}
399}
400
401// Stop stops the hub.
402func (h *Hub) Stop() {
403	close(h.stop)
404	<-h.didStop
405}
406
// Start starts the hub.
//
// It launches the hub's event loop in a new goroutine and returns
// immediately. The loop owns a hubConnectionIndex that is local to it, so
// all connection state is touched only from that goroutine; the other Hub
// methods reach it through the hub's channels.
//
// If the loop panics, or returns without Stop having been called, the panic
// value (if any) and a stack trace are logged and the loop is started again
// in a new goroutine. The restarted loop builds a new, empty connection
// index, so connections indexed before the panic are not known to it.
func (h *Hub) Start() {
	var doStart func()
	var doRecoverableStart func()
	var doRecover func()

	// doStart is the event loop itself. Its only return path is the
	// h.stop case below.
	doStart = func() {
		mlog.Debug("Hub is starting", mlog.Int("index", h.connectionIndex))

		// Periodically purge inactive connections that have gone stale.
		ticker := time.NewTicker(inactiveConnReaperInterval)
		defer ticker.Stop()

		// The staleness threshold is the same duration as the reaper interval.
		connIndex := newHubConnectionIndex(inactiveConnReaperInterval)

		for {
			select {
			case webSessionMessage := <-h.checkRegistered:
				// Reply true if any *active* connection of the user uses
				// the queried session token.
				conns := connIndex.ForUser(webSessionMessage.userID)
				var isRegistered bool
				for _, conn := range conns {
					if !conn.active {
						continue
					}
					if conn.GetSessionToken() == webSessionMessage.sessionToken {
						isRegistered = true
					}
				}
				webSessionMessage.isRegistered <- isRegistered
			case req := <-h.checkConn:
				// Reply with the queues of the user's inactive connection
				// with this connection ID, or nil if there is none.
				var res *CheckConnResult
				conn := connIndex.GetInactiveByConnectionID(req.userID, req.connectionID)
				if conn != nil {
					res = &CheckConnResult{
						ConnectionID:     req.connectionID,
						UserID:           req.userID,
						ActiveQueue:      conn.send,
						DeadQueue:        conn.deadQueue,
						DeadQueuePointer: conn.deadQueuePointer,
					}
				}
				req.result <- res
			case <-ticker.C:
				connIndex.RemoveInactiveConnections()
			case webConn := <-h.register:
				// With reliable websockets, an inactive connection with the
				// same connection ID may still be indexed; it is removed so the
				// new WebConn takes its place, and its presence suppresses the
				// hello message below.
				var oldConn *WebConn
				if *h.app.Config().ServiceSettings.EnableReliableWebSockets {
					// Delete the old conn from connIndex if it exists.
					oldConn = connIndex.RemoveInactiveByConnectionID(
						webConn.GetSession().UserId,
						webConn.GetConnectionID())
				}

				// Mark the current one as active.
				// There is no need to check if it was inactive or not,
				// we will anyways need to make it active.
				webConn.active = true

				connIndex.Add(webConn)
				atomic.StoreInt64(&h.connectionCount, int64(connIndex.AllActive()))

				if webConn.IsAuthenticated() && oldConn == nil {
					// The hello message should only be sent when the conn wasn't found.
					// i.e in server restart, or long timeout, or fresh connection case.
					// In case of seq number not found in dead queue, it is handled by
					// the webconn write pump.
					webConn.send <- webConn.createHelloMessage()
				}
			case webConn := <-h.unregister:
				// If already removed (via queue full), then removing again becomes a noop.
				// But if not removed, mark inactive.
				// With reliable websockets the connection stays indexed as
				// inactive (so a later register with the same connection ID
				// can take it over); the ticker case purges it once stale.
				if *h.app.Config().ServiceSettings.EnableReliableWebSockets {
					webConn.active = false
				} else {
					connIndex.Remove(webConn)
				}

				atomic.StoreInt64(&h.connectionCount, int64(connIndex.AllActive()))

				// Connections without a user have no status to update.
				if webConn.UserId == "" {
					continue
				}

				// areAllInactive is defined elsewhere; presumably it reports
				// whether none of conns is active — verify.
				conns := connIndex.ForUser(webConn.UserId)
				if len(conns) == 0 || areAllInactive(conns) {
					// Status updates run on a separate goroutine so the hub
					// loop is not blocked.
					h.app.Srv().Go(func() {
						h.app.SetStatusOffline(webConn.UserId, false)
					})
					continue
				}
				// Otherwise use the most recent activity among the user's
				// remaining active connections.
				var latestActivity int64 = 0
				for _, conn := range conns {
					if !conn.active {
						continue
					}
					if conn.lastUserActivityAt > latestActivity {
						latestActivity = conn.lastUserActivityAt
					}
				}

				if h.app.IsUserAway(latestActivity) {
					h.app.Srv().Go(func() {
						h.app.SetStatusLastActivityAt(webConn.UserId, latestActivity)
					})
				}
			case userID := <-h.invalidateUser:
				// Invalidate every indexed connection of the user, active or not.
				for _, webConn := range connIndex.ForUser(userID) {
					webConn.InvalidateCache()
				}
			case activity := <-h.activity:
				// Record the activity time on the user's active connections
				// that use the reported session token.
				for _, webConn := range connIndex.ForUser(activity.userID) {
					if !webConn.active {
						continue
					}
					if webConn.GetSessionToken() == activity.sessionToken {
						webConn.lastUserActivityAt = activity.activityAt
					}
				}
			case directMsg := <-h.directMsg:
				if !connIndex.Has(directMsg.conn) {
					continue
				}
				// Non-blocking send: a connection whose send buffer is full
				// is closed and dropped rather than stalling the hub.
				select {
				case directMsg.conn.send <- directMsg.msg:
				default:
					mlog.Error("webhub.broadcast: cannot send, closing websocket for user", mlog.String("user_id", directMsg.conn.UserId))
					close(directMsg.conn.send)
					connIndex.Remove(directMsg.conn)
				}
			case msg := <-h.broadcast:
				// Mirror of the increment done in Hub.Broadcast.
				if metrics := h.app.Metrics(); metrics != nil {
					metrics.DecrementWebSocketBroadcastBufferSize(strconv.Itoa(h.connectionIndex), 1)
				}
				// NOTE(review): presumably serializes the event once so every
				// connection can reuse it — confirm in model.
				msg = msg.PrecomputeJSON()
				// broadcast delivers msg to one indexed connection if its
				// filter accepts it; a full send buffer closes and removes the
				// connection (non-blocking send, as for direct messages).
				broadcast := func(webConn *WebConn) {
					if !connIndex.Has(webConn) {
						return
					}
					if webConn.shouldSendEvent(msg) {
						select {
						case webConn.send <- msg:
						default:
							mlog.Error("webhub.broadcast: cannot send, closing websocket for user", mlog.String("user_id", webConn.UserId))
							close(webConn.send)
							connIndex.Remove(webConn)
						}
					}
				}
				// User-targeted events only go to that user's connections.
				// Remove may run mid-iteration here; it moves the user's
				// last connection into the vacated slot but leaves the old
				// tail slot in place, so this range still reaches it.
				if msg.GetBroadcast().UserId != "" {
					candidates := connIndex.ForUser(msg.GetBroadcast().UserId)
					for _, webConn := range candidates {
						broadcast(webConn)
					}
					continue
				}
				// Otherwise consider every indexed connection (deleting map
				// entries during range is permitted in Go).
				candidates := connIndex.All()
				for webConn := range candidates {
					broadcast(webConn)
				}
			case <-h.stop:
				// Close every indexed connection and mark its user offline
				// (synchronously here), then acknowledge Stop via didStop.
				for webConn := range connIndex.All() {
					webConn.Close()
					h.app.SetStatusOffline(webConn.UserId, false)
				}

				// Tells doRecover that this exit was intentional.
				h.explicitStop = true
				close(h.didStop)

				return
			}
		}
	}

	// doRecoverableStart runs the event loop with doRecover deferred, so
	// doRecover runs whenever the loop exits, normally or by panic.
	doRecoverableStart = func() {
		defer doRecover()
		doStart()
	}

	// doRecover is called directly by defer, so recover() here can catch a
	// panic from doStart. It restarts the loop unless the exit came from
	// Stop. Note that recover() is only reached when explicitStop is false,
	// so a panic raised after explicitStop is set is not recovered.
	doRecover = func() {
		if !h.explicitStop {
			if r := recover(); r != nil {
				mlog.Error("Recovering from Hub panic.", mlog.Any("panic", r))
			} else {
				mlog.Error("Webhub stopped unexpectedly. Recovering.")
			}

			mlog.Error(string(debug.Stack()))

			go doRecoverableStart()
		}
	}

	go doRecoverableStart()
}
600
// hubConnectionIndex provides fast addition, removal, and iteration of web connections.
// It requires 3 functionalities which need to be very fast:
// - check if a connection exists or not.
// - get all connections for a given userID.
// - get all connections.
//
// It has no internal locking. In this file it is created and used only
// inside a hub's event loop goroutine (see Hub.Start).
type hubConnectionIndex struct {
	// byUserId stores the list of connections for a given userID.
	// Removal swaps the last element into the vacated slot, so the order
	// within a user's slice is not preserved.
	byUserId map[string][]*WebConn
	// byConnection serves the dual purpose of storing the index of the webconn
	// in the value of byUserId map, and also to get all connections.
	byConnection map[*WebConn]int
	// staleThreshold is the limit beyond which inactive connections
	// will be deleted. RemoveInactiveConnections compares it against the
	// time elapsed since a connection's lastUserActivityAt.
	staleThreshold time.Duration
}
616
617func newHubConnectionIndex(interval time.Duration) *hubConnectionIndex {
618	return &hubConnectionIndex{
619		byUserId:       make(map[string][]*WebConn),
620		byConnection:   make(map[*WebConn]int),
621		staleThreshold: interval,
622	}
623}
624
625func (i *hubConnectionIndex) Add(wc *WebConn) {
626	i.byUserId[wc.UserId] = append(i.byUserId[wc.UserId], wc)
627	i.byConnection[wc] = len(i.byUserId[wc.UserId]) - 1
628}
629
630func (i *hubConnectionIndex) Remove(wc *WebConn) {
631	userConnIndex, ok := i.byConnection[wc]
632	if !ok {
633		return
634	}
635
636	// get the conn slice.
637	userConnections := i.byUserId[wc.UserId]
638	// get the last connection.
639	last := userConnections[len(userConnections)-1]
640	// set the slot that we are trying to remove to be the last connection.
641	userConnections[userConnIndex] = last
642	// remove the last connection from the slice.
643	i.byUserId[wc.UserId] = userConnections[:len(userConnections)-1]
644	// set the index of the connection that was moved to the new index.
645	i.byConnection[last] = userConnIndex
646
647	delete(i.byConnection, wc)
648}
649
650func (i *hubConnectionIndex) Has(wc *WebConn) bool {
651	_, ok := i.byConnection[wc]
652	return ok
653}
654
655// ForUser returns all connections for a user ID.
656func (i *hubConnectionIndex) ForUser(id string) []*WebConn {
657	return i.byUserId[id]
658}
659
660// All returns the full webConn index.
661func (i *hubConnectionIndex) All() map[*WebConn]int {
662	return i.byConnection
663}
664
665// GetInactiveByConnectionID returns an inactive connection for the given
666// userID and connectionID.
667func (i *hubConnectionIndex) GetInactiveByConnectionID(userID, connectionID string) *WebConn {
668	// To handle empty sessions.
669	if userID == "" {
670		return nil
671	}
672	for _, conn := range i.ForUser(userID) {
673		if conn.GetConnectionID() == connectionID && !conn.active {
674			return conn
675		}
676	}
677	return nil
678}
679
680// RemoveInactiveByConnectionID removes an inactive connection for the given
681// userID and connectionID.
682func (i *hubConnectionIndex) RemoveInactiveByConnectionID(userID, connectionID string) *WebConn {
683	// To handle empty sessions.
684	if userID == "" {
685		return nil
686	}
687	for _, conn := range i.ForUser(userID) {
688		if conn.GetConnectionID() == connectionID && !conn.active {
689			i.Remove(conn)
690			return conn
691		}
692	}
693	return nil
694}
695
696// RemoveInactiveConnections removes all inactive connections whose lastUserActivityAt
697// exceeded staleThreshold.
698func (i *hubConnectionIndex) RemoveInactiveConnections() {
699	now := model.GetMillis()
700	for conn := range i.byConnection {
701		if !conn.active && now-conn.lastUserActivityAt > i.staleThreshold.Milliseconds() {
702			i.Remove(conn)
703		}
704	}
705}
706
707// AllActive returns the number of active connections.
708// This is only called during register/unregister so we can take
709// a bit of perf hit here.
710func (i *hubConnectionIndex) AllActive() int {
711	cnt := 0
712	for conn := range i.byConnection {
713		if conn.active {
714			cnt++
715		}
716	}
717	return cnt
718}
719