1package centrifuge
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"io"
8	"sync"
9	"time"
10
11	"github.com/centrifugal/centrifuge/internal/prepared"
12	"github.com/centrifugal/centrifuge/internal/queue"
13	"github.com/centrifugal/centrifuge/internal/recovery"
14
15	"github.com/centrifugal/protocol"
16	"github.com/google/uuid"
17)
18
19// clientEventHub allows to deal with client event handlers.
20// All its methods are not goroutine-safe and supposed to be called
21// once inside Node ConnectHandler.
22type clientEventHub struct {
23	aliveHandler         AliveHandler
24	disconnectHandler    DisconnectHandler
25	subscribeHandler     SubscribeHandler
26	unsubscribeHandler   UnsubscribeHandler
27	publishHandler       PublishHandler
28	refreshHandler       RefreshHandler
29	subRefreshHandler    SubRefreshHandler
30	rpcHandler           RPCHandler
31	messageHandler       MessageHandler
32	presenceHandler      PresenceHandler
33	presenceStatsHandler PresenceStatsHandler
34	historyHandler       HistoryHandler
35}
36
37// OnAlive allows setting AliveHandler.
38// AliveHandler called periodically for active client connection.
39func (c *Client) OnAlive(h AliveHandler) {
40	c.eventHub.aliveHandler = h
41}
42
43// OnRefresh allows setting RefreshHandler.
44// RefreshHandler called when it's time to refresh expiring client connection.
45func (c *Client) OnRefresh(h RefreshHandler) {
46	c.eventHub.refreshHandler = h
47}
48
49// OnDisconnect allows setting DisconnectHandler.
50// DisconnectHandler called when client disconnected.
51func (c *Client) OnDisconnect(h DisconnectHandler) {
52	c.eventHub.disconnectHandler = h
53}
54
55// OnMessage allows setting MessageHandler.
56// MessageHandler called when client sent asynchronous message.
57func (c *Client) OnMessage(h MessageHandler) {
58	c.eventHub.messageHandler = h
59}
60
61// OnRPC allows setting RPCHandler.
62// RPCHandler will be executed on every incoming RPC call.
63func (c *Client) OnRPC(h RPCHandler) {
64	c.eventHub.rpcHandler = h
65}
66
67// OnSubRefresh allows setting SubRefreshHandler.
68// SubRefreshHandler called when it's time to refresh client subscription.
69func (c *Client) OnSubRefresh(h SubRefreshHandler) {
70	c.eventHub.subRefreshHandler = h
71}
72
73// OnSubscribe allows setting SubscribeHandler.
74// SubscribeHandler called when client subscribes on a channel.
75func (c *Client) OnSubscribe(h SubscribeHandler) {
76	c.eventHub.subscribeHandler = h
77}
78
79// OnUnsubscribe allows setting UnsubscribeHandler.
80// UnsubscribeHandler called when client unsubscribes from channel.
81func (c *Client) OnUnsubscribe(h UnsubscribeHandler) {
82	c.eventHub.unsubscribeHandler = h
83}
84
85// OnPublish allows setting PublishHandler.
86// PublishHandler called when client publishes message into channel.
87func (c *Client) OnPublish(h PublishHandler) {
88	c.eventHub.publishHandler = h
89}
90
91// OnPresence allows setting PresenceHandler.
92// PresenceHandler called when Presence request from client received.
93// At this moment you can only return a custom error or disconnect client.
94func (c *Client) OnPresence(h PresenceHandler) {
95	c.eventHub.presenceHandler = h
96}
97
98// OnPresenceStats allows settings PresenceStatsHandler.
99// PresenceStatsHandler called when Presence Stats request from client received.
100// At this moment you can only return a custom error or disconnect client.
101func (c *Client) OnPresenceStats(h PresenceStatsHandler) {
102	c.eventHub.presenceStatsHandler = h
103}
104
105// OnHistory allows settings HistoryHandler.
106// HistoryHandler called when History request from client received.
107// At this moment you can only return a custom error or disconnect client.
108func (c *Client) OnHistory(h HistoryHandler) {
109	c.eventHub.historyHandler = h
110}
111
112// We poll current position in channel from history storage periodically.
113// If client position is wrong maxCheckPositionFailures times in a row
114// then client will be disconnected with InsufficientState reason. Polling
115// not used in channels with high frequency updates since we can check position
116// comparing client offset with offset in incoming Publication.
117const maxCheckPositionFailures uint8 = 2
118
119// Note: up to 8 possible flags here.
120const (
121	// flagSubscribed will be set upon successful Subscription to a channel.
122	// Until that moment channel exists in client Channels map only to track
123	// duplicate subscription requests.
124	flagSubscribed uint8 = 1 << iota
125	flagPresence
126	flagJoinLeave
127	flagPosition
128	flagRecover
129	flagServerSide
130	flagClientSideRefresh
131)
132
133// channelContext contains extra context for channel connection subscribed to.
134// Note: this struct is aligned to consume less memory.
135type channelContext struct {
136	Info                  []byte
137	expireAt              int64
138	positionCheckTime     int64
139	streamPosition        StreamPosition
140	positionCheckFailures uint8
141	flags                 uint8
142}
143
144func channelHasFlag(flags, flag uint8) bool {
145	return flags&flag != 0
146}
147
148type timerOp uint8
149
150const (
151	timerOpStale    timerOp = 1
152	timerOpPresence timerOp = 2
153	timerOpExpire   timerOp = 3
154)
155
156type status uint8
157
158const (
159	statusConnecting status = 1
160	statusConnected  status = 2
161	statusClosed     status = 3
162)
163
164// ConnectRequest can be used in a unidirectional connection case to
165// pass initial connection information from a client-side.
166type ConnectRequest struct {
167	// Token is an optional token from a client.
168	Token string
169	// Data is an optional custom data from a client.
170	Data []byte
171	// Name of a client.
172	Name string
173	// Version of a client.
174	Version string
175	// Subs is a map with channel subscription state (for recovery on connect).
176	Subs map[string]SubscribeRequest
177}
178
179// SubscribeRequest contains state of subscription to a channel.
180type SubscribeRequest struct {
181	// Recover enables publication recovery for a channel.
182	Recover bool
183	// Epoch last seen by a client.
184	Epoch string
185	// Offset last seen by a client.
186	Offset uint64
187}
188
189func (r *ConnectRequest) toProto() *protocol.ConnectRequest {
190	if r == nil {
191		return nil
192	}
193	req := &protocol.ConnectRequest{
194		Token:   r.Token,
195		Data:    r.Data,
196		Name:    r.Name,
197		Version: r.Version,
198	}
199	if len(r.Subs) > 0 {
200		subs := make(map[string]*protocol.SubscribeRequest, len(r.Subs))
201		for k, v := range r.Subs {
202			subs[k] = &protocol.SubscribeRequest{
203				Recover: v.Recover,
204				Epoch:   v.Epoch,
205				Offset:  v.Offset,
206			}
207		}
208		req.Subs = subs
209	}
210	return req
211}
212
213// Client represents client connection to server.
214type Client struct {
215	mu                sync.RWMutex
216	connectMu         sync.Mutex // allows to sync connect with disconnect.
217	presenceMu        sync.Mutex // allows to sync presence routine with client closing.
218	ctx               context.Context
219	transport         Transport
220	node              *Node
221	exp               int64
222	channels          map[string]channelContext
223	messageWriter     *writer
224	pubSubSync        *recovery.PubSubSync
225	uid               string
226	user              string
227	info              []byte
228	authenticated     bool
229	clientSideRefresh bool
230	status            status
231	timerOp           timerOp
232	nextPresence      int64
233	nextExpire        int64
234	eventHub          *clientEventHub
235	timer             *time.Timer
236}
237
238// ClientCloseFunc must be called on Transport handler close to clean up Client.
239type ClientCloseFunc func() error
240
241// NewClient initializes new Client.
242func NewClient(ctx context.Context, n *Node, t Transport) (*Client, ClientCloseFunc, error) {
243	uuidObject, err := uuid.NewRandom()
244	if err != nil {
245		return nil, nil, err
246	}
247
248	client := &Client{
249		ctx:        ctx,
250		uid:        uuidObject.String(),
251		node:       n,
252		transport:  t,
253		channels:   make(map[string]channelContext),
254		pubSubSync: recovery.NewPubSubSync(),
255		status:     statusConnecting,
256		eventHub:   &clientEventHub{},
257	}
258
259	messageWriterConf := writerConfig{
260		MaxQueueSize: n.config.ClientQueueMaxSize,
261		WriteFn: func(item queue.Item) error {
262			if client.node.transportWriteHandler != nil {
263				pass := client.node.transportWriteHandler(client, TransportWriteEvent(item))
264				if !pass {
265					return nil
266				}
267			}
268			if err := t.Write(item.Data); err != nil {
269				switch v := err.(type) {
270				case *Disconnect:
271					go func() { _ = client.close(v) }()
272				default:
273					go func() { _ = client.close(DisconnectWriteError) }()
274				}
275				return err
276			}
277			incTransportMessagesSent(t.Name())
278			return nil
279		},
280		WriteManyFn: func(items ...queue.Item) error {
281			messages := make([][]byte, 0, len(items))
282			for i := 0; i < len(items); i++ {
283				if client.node.transportWriteHandler != nil {
284					pass := client.node.transportWriteHandler(client, TransportWriteEvent(items[i]))
285					if !pass {
286						continue
287					}
288				}
289				messages = append(messages, items[i].Data)
290			}
291			if err := t.WriteMany(messages...); err != nil {
292				switch v := err.(type) {
293				case *Disconnect:
294					go func() { _ = client.close(v) }()
295				default:
296					go func() { _ = client.close(DisconnectWriteError) }()
297				}
298				return err
299			}
300			addTransportMessagesSent(t.Name(), float64(len(items)))
301			return nil
302		},
303	}
304
305	client.messageWriter = newWriter(messageWriterConf)
306	go client.messageWriter.run()
307
308	staleCloseDelay := n.config.ClientStaleCloseDelay
309	if staleCloseDelay > 0 && !client.authenticated {
310		client.mu.Lock()
311		client.timerOp = timerOpStale
312		client.timer = time.AfterFunc(staleCloseDelay, client.onTimerOp)
313		client.mu.Unlock()
314	}
315	return client, func() error { return client.close(nil) }, nil
316}
317
318func extractUnidirectionalDisconnect(err error) *Disconnect {
319	if err == nil {
320		return nil
321	}
322	var d *Disconnect
323	switch t := err.(type) {
324	case *Disconnect:
325		d = t
326	case *Error:
327		switch t.Code {
328		case ErrorExpired.Code:
329			d = DisconnectExpired
330		case ErrorTokenExpired.Code:
331			d = DisconnectExpired
332		default:
333			d = DisconnectServerError
334		}
335	default:
336		d = DisconnectServerError
337	}
338	return d
339}
340
341// Connect supposed to be called from unidirectional transport layer to pass
342// initial information about connection and thus initiate Node.OnConnecting
343// event. Bidirectional transport initiate connecting workflow automatically
344// since client passes Connect command upon successful connection establishment
345// with a server.
346func (c *Client) Connect(req ConnectRequest) {
347	err := c.unidirectionalConnect(req.toProto())
348	if err != nil {
349		d := extractUnidirectionalDisconnect(err)
350		go func() { _ = c.close(d) }()
351	}
352}
353
354func (c *Client) encodeDisconnect(d *Disconnect) (*prepared.Reply, error) {
355	disconnect := &protocol.Disconnect{
356		Code:      d.Code,
357		Reason:    d.Reason,
358		Reconnect: d.Reconnect,
359	}
360	pushBytes, err := protocol.EncodeDisconnectPush(c.transport.Protocol().toProto(), disconnect)
361	if err != nil {
362		return nil, err
363	}
364	return prepared.NewReply(&protocol.Reply{
365		Result: pushBytes,
366	}, c.transport.Protocol().toProto()), nil
367}
368
369func (c *Client) encodeConnectPush(res *protocol.ConnectResult) ([]byte, error) {
370	p := &protocol.Connect{
371		Version: res.GetVersion(),
372		Client:  res.GetClient(),
373		Data:    res.Data,
374		Subs:    res.Subs,
375		Expires: res.Expires,
376		Ttl:     res.Ttl,
377	}
378	return protocol.EncodeConnectPush(c.transport.Protocol().toProto(), p)
379}
380
381func hasFlag(flags, flag uint64) bool {
382	return flags&flag != 0
383}
384
385func (c *Client) unidirectionalConnect(connectRequest *protocol.ConnectRequest) error {
386	write := func(rep *protocol.Reply) error {
387		if hasFlag(c.transport.DisabledPushFlags(), PushFlagConnect) {
388			return nil
389		}
390		c.trace("-->", rep.Result)
391		disconnect := c.messageWriter.enqueue(queue.Item{Data: rep.Result, IsPush: false})
392		if disconnect != nil {
393			if c.node.logger.enabled(LogLevelDebug) {
394				c.node.logger.log(newLogEntry(LogLevelDebug, "disconnect after connect push", map[string]interface{}{"client": c.ID(), "user": c.UserID(), "reason": disconnect.Reason}))
395			}
396			go func() { _ = c.close(disconnect) }()
397		}
398		return disconnect
399	}
400
401	rw := &replyWriter{write, func() {}}
402
403	_, err := c.connectCmd(connectRequest, rw)
404	if err != nil {
405		return err
406	}
407	c.triggerConnect()
408	c.scheduleOnConnectTimers()
409	return nil
410}
411
412func (c *Client) onTimerOp() {
413	c.mu.Lock()
414	if c.status == statusClosed {
415		c.mu.Unlock()
416		return
417	}
418	timerOp := c.timerOp
419	c.mu.Unlock()
420	switch timerOp {
421	case timerOpStale:
422		c.closeUnauthenticated()
423	case timerOpPresence:
424		c.updatePresence()
425	case timerOpExpire:
426		c.expire()
427	}
428}
429
430// Lock must be held outside.
431func (c *Client) scheduleNextTimer() {
432	if c.status == statusClosed {
433		return
434	}
435	c.stopTimer()
436	var minEventTime int64
437	var nextTimerOp timerOp
438	var needTimer bool
439	if c.nextExpire > 0 {
440		nextTimerOp = timerOpExpire
441		minEventTime = c.nextExpire
442		needTimer = true
443	}
444	if c.nextPresence > 0 && (minEventTime == 0 || c.nextPresence < minEventTime) {
445		nextTimerOp = timerOpPresence
446		minEventTime = c.nextPresence
447		needTimer = true
448	}
449	if needTimer {
450		c.timerOp = nextTimerOp
451		afterDuration := time.Duration(minEventTime-time.Now().UnixNano()) * time.Nanosecond
452		c.timer = time.AfterFunc(afterDuration, c.onTimerOp)
453	}
454}
455
456// Lock must be held outside.
457func (c *Client) stopTimer() {
458	if c.timer != nil {
459		c.timer.Stop()
460	}
461}
462
463// Lock must be held outside.
464func (c *Client) addPresenceUpdate() {
465	config := c.node.config
466	presenceInterval := config.ClientPresenceUpdateInterval
467	c.nextPresence = time.Now().Add(presenceInterval).UnixNano()
468	c.scheduleNextTimer()
469}
470
471// Lock must be held outside.
472func (c *Client) addExpireUpdate(after time.Duration) {
473	c.nextExpire = time.Now().Add(after).UnixNano()
474	c.scheduleNextTimer()
475}
476
477// closeUnauthenticated closes connection if it's not authenticated yet.
478// At moment used to close client connections which have not sent valid
479// connect command in a reasonable time interval after established connection
480// with server.
481func (c *Client) closeUnauthenticated() {
482	c.mu.RLock()
483	authenticated := c.authenticated
484	closed := c.status == statusClosed
485	c.mu.RUnlock()
486	if !authenticated && !closed {
487		_ = c.close(DisconnectStale)
488	}
489}
490
491func (c *Client) transportEnqueue(reply *prepared.Reply) error {
492	var data []byte
493	if c.transport.Unidirectional() {
494		data = reply.Reply.Result
495	} else {
496		data = reply.Data()
497	}
498	c.trace("-->", data)
499	disconnect := c.messageWriter.enqueue(queue.Item{
500		Data:   data,
501		IsPush: reply.Reply.Id == 0,
502	})
503	if disconnect != nil {
504		// close in goroutine to not block message broadcast.
505		go func() { _ = c.close(disconnect) }()
506		return io.EOF
507	}
508	return nil
509}
510
511// updateChannelPresence updates client presence info for channel so it
512// won't expire until client disconnect.
513func (c *Client) updateChannelPresence(ch string, chCtx channelContext) error {
514	if !channelHasFlag(chCtx.flags, flagPresence) {
515		return nil
516	}
517	return c.node.addPresence(ch, c.uid, &ClientInfo{
518		ClientID: c.uid,
519		UserID:   c.user,
520		ConnInfo: c.info,
521		ChanInfo: chCtx.Info,
522	})
523}
524
525// Context returns client Context. This context will be canceled
526// as soon as client connection closes.
527func (c *Client) Context() context.Context {
528	return c.ctx
529}
530
531func (c *Client) checkSubscriptionExpiration(channel string, channelContext channelContext, delay time.Duration, resultCB func(bool)) {
532	now := c.node.nowTimeGetter().Unix()
533	expireAt := channelContext.expireAt
534	clientSideRefresh := channelHasFlag(channelContext.flags, flagClientSideRefresh)
535	if expireAt > 0 && now > expireAt+int64(delay.Seconds()) {
536		// Subscription expired.
537		if clientSideRefresh || c.eventHub.subRefreshHandler == nil {
538			// The only way subscription could be refreshed in this case is via
539			// SUB_REFRESH command sent from client but looks like that command
540			// with new refreshed token have not been received in configured window.
541			resultCB(false)
542			return
543		}
544		cb := func(reply SubRefreshReply, err error) {
545			if err != nil {
546				resultCB(false)
547				return
548			}
549			if reply.Expired || (reply.ExpireAt > 0 && reply.ExpireAt < now) {
550				resultCB(false)
551				return
552			}
553			c.mu.Lock()
554			if ctx, ok := c.channels[channel]; ok {
555				if len(reply.Info) > 0 {
556					ctx.Info = reply.Info
557				}
558				ctx.expireAt = reply.ExpireAt
559				c.channels[channel] = ctx
560			}
561			c.mu.Unlock()
562			resultCB(true)
563		}
564		// Give subscription a chance to be refreshed via SubRefreshHandler.
565		event := SubRefreshEvent{Channel: channel}
566		c.eventHub.subRefreshHandler(event, cb)
567		return
568	}
569	resultCB(true)
570}
571
572// updatePresence used for various periodic actions we need to do with client connections.
573func (c *Client) updatePresence() {
574	c.presenceMu.Lock()
575	defer c.presenceMu.Unlock()
576	config := c.node.config
577	c.mu.Lock()
578	if c.status == statusClosed {
579		c.mu.Unlock()
580		return
581	}
582	channels := make(map[string]channelContext, len(c.channels))
583	for channel, channelContext := range c.channels {
584		if !channelHasFlag(channelContext.flags, flagSubscribed) {
585			continue
586		}
587		channels[channel] = channelContext
588	}
589	c.mu.Unlock()
590	if c.eventHub.aliveHandler != nil {
591		c.eventHub.aliveHandler()
592	}
593	for channel, channelContext := range channels {
594		c.checkSubscriptionExpiration(channel, channelContext, config.ClientExpiredSubCloseDelay, func(result bool) {
595			// Ideally we should deal with single expired subscription in this
596			// case - i.e. unsubscribe client from channel and give an advice
597			// to resubscribe. But there is scenario when browser goes online
598			// after computer was in sleeping mode which I have not managed to
599			// handle reliably on client side when unsubscribe with resubscribe
600			// flag was used. So I decided to stick with disconnect for now -
601			// it seems to work fine and drastically simplifies client code.
602			if !result {
603				go func() { _ = c.close(DisconnectSubExpired) }()
604			}
605		})
606
607		checkDelay := config.ClientChannelPositionCheckDelay
608		if checkDelay > 0 && !c.checkPosition(checkDelay, channel, channelContext) {
609			go func() { _ = c.close(DisconnectInsufficientState) }()
610			// No need to proceed after close.
611			return
612		}
613
614		err := c.updateChannelPresence(channel, channelContext)
615		if err != nil {
616			c.node.logger.log(newLogEntry(LogLevelError, "error updating presence for channel", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid, "error": err.Error()}))
617		}
618	}
619	c.mu.Lock()
620	c.addPresenceUpdate()
621	c.mu.Unlock()
622}
623
624func (c *Client) checkPosition(checkDelay time.Duration, ch string, chCtx channelContext) bool {
625	if !channelHasFlag(chCtx.flags, flagRecover|flagPosition) {
626		return true
627	}
628	nowUnix := c.node.nowTimeGetter().Unix()
629
630	isInitialCheck := chCtx.positionCheckTime == 0
631	isTimeToCheck := nowUnix-chCtx.positionCheckTime > int64(checkDelay.Seconds())
632	needCheckPosition := isInitialCheck || isTimeToCheck
633
634	if !needCheckPosition {
635		return true
636	}
637	position := chCtx.streamPosition
638	streamTop, err := c.node.streamTop(ch)
639	if err != nil {
640		return true
641	}
642
643	isValidPosition := streamTop.Offset == position.Offset && streamTop.Epoch == position.Epoch
644	keepConnection := true
645	c.mu.Lock()
646	if chContext, ok := c.channels[ch]; ok {
647		chContext.positionCheckTime = nowUnix
648		if !isValidPosition {
649			chContext.positionCheckFailures++
650			keepConnection = chContext.positionCheckFailures < maxCheckPositionFailures
651		} else {
652			chContext.positionCheckFailures = 0
653		}
654		c.channels[ch] = chContext
655	}
656	c.mu.Unlock()
657	return keepConnection
658}
659
660// ID returns unique client connection id.
661func (c *Client) ID() string {
662	return c.uid
663}
664
665// UserID returns user id associated with client connection.
666func (c *Client) UserID() string {
667	return c.user
668}
669
670// Info returns connection info.
671func (c *Client) Info() []byte {
672	c.mu.Lock()
673	info := make([]byte, len(c.info))
674	copy(info, c.info)
675	c.mu.Unlock()
676	return info
677}
678
679// Transport returns client connection transport information.
680func (c *Client) Transport() TransportInfo {
681	return c.transport
682}
683
684// Channels returns a slice of channels client connection currently subscribed to.
685func (c *Client) Channels() []string {
686	c.mu.RLock()
687	defer c.mu.RUnlock()
688	channels := make([]string, 0, len(c.channels))
689	for ch, ctx := range c.channels {
690		if !channelHasFlag(ctx.flags, flagSubscribed) {
691			continue
692		}
693		channels = append(channels, ch)
694	}
695	return channels
696}
697
698// IsSubscribed returns true if client subscribed to a channel.
699func (c *Client) IsSubscribed(ch string) bool {
700	c.mu.RLock()
701	defer c.mu.RUnlock()
702	ctx, ok := c.channels[ch]
703	return ok && channelHasFlag(ctx.flags, flagSubscribed)
704}
705
706// Send data to client. This sends an asynchronous message – data will be
707// just written to connection. on client side this message can be handled
708// with Message handler.
709func (c *Client) Send(data []byte) error {
710	if hasFlag(c.transport.DisabledPushFlags(), PushFlagMessage) {
711		return nil
712	}
713	p := &protocol.Message{
714		Data: data,
715	}
716	pushBytes, err := protocol.EncodeMessagePush(c.transport.Protocol().toProto(), p)
717	if err != nil {
718		return err
719	}
720	reply := prepared.NewReply(&protocol.Reply{
721		Result: pushBytes,
722	}, c.transport.Protocol().toProto())
723	return c.transportEnqueue(reply)
724}
725
726// Unsubscribe allows to unsubscribe client from channel.
727func (c *Client) Unsubscribe(ch string) error {
728	c.mu.RLock()
729	if c.status == statusClosed {
730		c.mu.RUnlock()
731		return nil
732	}
733	c.mu.RUnlock()
734
735	err := c.unsubscribe(ch)
736	if err != nil {
737		return err
738	}
739	return c.sendUnsubscribe(ch)
740}
741
742func (c *Client) sendUnsubscribe(ch string) error {
743	if hasFlag(c.transport.DisabledPushFlags(), PushFlagUnsubscribe) {
744		return nil
745	}
746	pushBytes, err := protocol.EncodeUnsubscribePush(c.transport.Protocol().toProto(), ch, &protocol.Unsubscribe{})
747	if err != nil {
748		return err
749	}
750	reply := prepared.NewReply(&protocol.Reply{
751		Result: pushBytes,
752	}, c.transport.Protocol().toProto())
753
754	_ = c.transportEnqueue(reply)
755	return nil
756}
757
758// Disconnect client connection with specific disconnect code and reason.
759// This method internally creates a new goroutine at moment to do
760// closing stuff. An extra goroutine is required to solve disconnect
761// and alive callback ordering/sync problems. Will be a noop if client
762// already closed. As this method runs a separate goroutine client
763// connection will be closed eventually (i.e. not immediately).
764func (c *Client) Disconnect(disconnect *Disconnect) {
765	go func() {
766		_ = c.close(disconnect)
767	}()
768}
769
770func (c *Client) close(disconnect *Disconnect) error {
771	c.presenceMu.Lock()
772	defer c.presenceMu.Unlock()
773	c.connectMu.Lock()
774	defer c.connectMu.Unlock()
775	c.mu.Lock()
776	if c.status == statusClosed {
777		c.mu.Unlock()
778		return nil
779	}
780	prevStatus := c.status
781	c.status = statusClosed
782
783	c.stopTimer()
784
785	channels := make(map[string]channelContext, len(c.channels))
786	for channel, channelContext := range c.channels {
787		channels[channel] = channelContext
788	}
789	c.mu.Unlock()
790
791	if len(channels) > 0 {
792		// Unsubscribe from all channels.
793		for channel := range channels {
794			err := c.unsubscribe(channel)
795			if err != nil {
796				c.node.logger.log(newLogEntry(LogLevelError, "error unsubscribing client from channel", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid, "error": err.Error()}))
797			}
798		}
799	}
800
801	c.mu.RLock()
802	authenticated := c.authenticated
803	c.mu.RUnlock()
804
805	if authenticated {
806		err := c.node.removeClient(c)
807		if err != nil {
808			c.node.logger.log(newLogEntry(LogLevelError, "error removing client", map[string]interface{}{"user": c.user, "client": c.uid, "error": err.Error()}))
809		}
810	}
811
812	if disconnect != nil && !hasFlag(c.transport.DisabledPushFlags(), PushFlagDisconnect) {
813		if reply, err := c.encodeDisconnect(disconnect); err == nil {
814			_ = c.transportEnqueue(reply)
815		}
816	}
817
818	// close writer and send messages remaining in writer queue if any.
819	_ = c.messageWriter.close()
820
821	_ = c.transport.Close(disconnect)
822
823	if disconnect != nil && disconnect.Reason != "" {
824		c.node.logger.log(newLogEntry(LogLevelDebug, "closing client connection", map[string]interface{}{"client": c.uid, "user": c.user, "reason": disconnect.Reason, "reconnect": disconnect.Reconnect}))
825	}
826	if disconnect != nil {
827		incServerDisconnect(disconnect.Code)
828	}
829	if c.eventHub.disconnectHandler != nil && prevStatus == statusConnected {
830		c.eventHub.disconnectHandler(DisconnectEvent{
831			Disconnect: disconnect,
832		})
833	}
834	return nil
835}
836
837func (c *Client) trace(msg string, data []byte) {
838	if !c.node.LogEnabled(LogLevelTrace) {
839		return
840	}
841	c.mu.RLock()
842	user := c.user
843	c.mu.RUnlock()
844	c.node.logger.log(newLogEntry(LogLevelTrace, msg, map[string]interface{}{"client": c.ID(), "user": user, "data": fmt.Sprintf("%#v", string(data))}))
845}
846
847// Lock must be held outside.
848func (c *Client) clientInfo(ch string) *ClientInfo {
849	var channelInfo protocol.Raw
850	channelContext, ok := c.channels[ch]
851	if ok && channelHasFlag(channelContext.flags, flagSubscribed) {
852		channelInfo = channelContext.Info
853	}
854	return &ClientInfo{
855		ClientID: c.uid,
856		UserID:   c.user,
857		ConnInfo: c.info,
858		ChanInfo: channelInfo,
859	}
860}
861
862// Handle raw data encoded with Centrifuge protocol.
863// Not goroutine-safe. Supposed to be called only from a transport connection reader.
864func (c *Client) Handle(data []byte) bool {
865	c.mu.Lock()
866	if c.status == statusClosed {
867		c.mu.Unlock()
868		return false
869	}
870	c.mu.Unlock()
871
872	if c.transport.Unidirectional() {
873		c.node.logger.log(newLogEntry(LogLevelInfo, "can't handle data for unidirectional client", map[string]interface{}{"client": c.ID(), "user": c.UserID()}))
874		go func() { _ = c.close(DisconnectBadRequest) }()
875		return false
876	}
877
878	if len(data) == 0 {
879		c.node.logger.log(newLogEntry(LogLevelInfo, "empty client request received", map[string]interface{}{"client": c.ID(), "user": c.UserID()}))
880		go func() { _ = c.close(DisconnectBadRequest) }()
881		return false
882	}
883
884	c.trace("<--", data)
885
886	protoType := c.transport.Protocol().toProto()
887	decoder := protocol.GetCommandDecoder(protoType, data)
888	defer protocol.PutCommandDecoder(protoType, decoder)
889
890	for {
891		cmd, err := decoder.Decode()
892		if err != nil && err != io.EOF {
893			c.node.logger.log(newLogEntry(LogLevelInfo, "error decoding command", map[string]interface{}{"data": string(data), "client": c.ID(), "user": c.UserID(), "error": err.Error()}))
894			go func() { _ = c.close(DisconnectBadRequest) }()
895			return false
896		}
897		if cmd != nil {
898			ok := c.handleCommand(cmd)
899			if !ok {
900				return false
901			}
902		}
903		if err == io.EOF {
904			break
905		}
906	}
907	return true
908}
909
910// handleCommand processes a single protocol.Command.
911func (c *Client) handleCommand(cmd *protocol.Command) bool {
912	if cmd.Method != protocol.Command_CONNECT && !c.authenticated {
913		// Client must send connect command to authenticate itself first.
914		c.node.logger.log(newLogEntry(LogLevelInfo, "client not authenticated to handle command", map[string]interface{}{"client": c.ID(), "user": c.UserID(), "command": fmt.Sprintf("%v", cmd)}))
915		go func() { _ = c.close(DisconnectBadRequest) }()
916		return false
917	}
918
919	if cmd.Id == 0 && cmd.Method != protocol.Command_SEND {
920		// Only send command from client can be sent without incremental ID.
921		c.node.logger.log(newLogEntry(LogLevelInfo, "command ID required for commands with reply expected", map[string]interface{}{"client": c.ID(), "user": c.UserID()}))
922		go func() { _ = c.close(DisconnectBadRequest) }()
923		return false
924	}
925
926	select {
927	case <-c.ctx.Done():
928		return false
929	default:
930	}
931
932	disconnect := c.dispatchCommand(cmd)
933
934	select {
935	case <-c.ctx.Done():
936		return false
937	default:
938	}
939	if disconnect != nil {
940		if disconnect != DisconnectNormal {
941			c.node.logger.log(newLogEntry(LogLevelInfo, "disconnect after handling command", map[string]interface{}{"command": fmt.Sprintf("%v", cmd), "client": c.ID(), "user": c.UserID(), "reason": disconnect.Reason}))
942		}
943		go func() { _ = c.close(disconnect) }()
944		return false
945	}
946	return true
947}
948
949type replyWriter struct {
950	write func(*protocol.Reply) error
951	done  func()
952}
953
954// dispatchCommand dispatches Command into correct command handler.
955func (c *Client) dispatchCommand(cmd *protocol.Command) *Disconnect {
956	c.mu.Lock()
957	if c.status == statusClosed {
958		c.mu.Unlock()
959		return nil
960	}
961	c.mu.Unlock()
962
963	method := cmd.Method
964	params := cmd.Params
965
966	protoType := c.transport.Protocol().toProto()
967	replyEncoder := protocol.GetReplyEncoder(protoType)
968
969	var encodeErr error
970
971	started := time.Now()
972
973	write := func(rep *protocol.Reply) error {
974		rep.Id = cmd.Id
975		if rep.Error != nil {
976			if c.node.LogEnabled(LogLevelInfo) {
977				c.node.logger.log(newLogEntry(LogLevelInfo, "client command error", map[string]interface{}{"reply": fmt.Sprintf("%v", rep), "command": fmt.Sprintf("%v", cmd), "client": c.ID(), "user": c.UserID(), "error": rep.Error.Message, "code": rep.Error.Code}))
978			}
979			incReplyError(cmd.Method, rep.Error.Code)
980		}
981
982		var replyData []byte
983		replyData, encodeErr = replyEncoder.Encode(rep)
984		if encodeErr != nil {
985			c.node.logger.log(newLogEntry(LogLevelError, "error encoding reply", map[string]interface{}{"reply": fmt.Sprintf("%v", rep), "client": c.ID(), "user": c.UserID(), "error": encodeErr.Error()}))
986			return encodeErr
987		}
988		c.trace("-->", replyData)
989		disconnect := c.messageWriter.enqueue(queue.Item{Data: replyData, IsPush: false})
990		if disconnect != nil {
991			if c.node.logger.enabled(LogLevelDebug) {
992				c.node.logger.log(newLogEntry(LogLevelDebug, "disconnect after sending reply", map[string]interface{}{"client": c.ID(), "user": c.UserID(), "reason": disconnect.Reason}))
993			}
994			go func() { _ = c.close(disconnect) }()
995		}
996		return disconnect
997	}
998
999	// done should be called after command fully processed.
1000	done := func() {
1001		observeCommandDuration(method, time.Since(started))
1002	}
1003
1004	// The rule is as follows: if command handler returns an
1005	// error then we handle it here: write error into connection
1006	// or return disconnect further to caller and call rw.done()
1007	// in the end.
1008	// If handler returned nil error then we assume that all
1009	// rw operations will be executed inside handler itself.
1010	rw := &replyWriter{write, done}
1011
1012	var handleErr error
1013
1014	switch method {
1015	case protocol.Command_CONNECT:
1016		handleErr = c.handleConnect(params, rw)
1017	case protocol.Command_PING:
1018		handleErr = c.handlePing(params, rw)
1019	case protocol.Command_SUBSCRIBE:
1020		handleErr = c.handleSubscribe(params, rw)
1021	case protocol.Command_UNSUBSCRIBE:
1022		handleErr = c.handleUnsubscribe(params, rw)
1023	case protocol.Command_PUBLISH:
1024		handleErr = c.handlePublish(params, rw)
1025	case protocol.Command_PRESENCE:
1026		handleErr = c.handlePresence(params, rw)
1027	case protocol.Command_PRESENCE_STATS:
1028		handleErr = c.handlePresenceStats(params, rw)
1029	case protocol.Command_HISTORY:
1030		handleErr = c.handleHistory(params, rw)
1031	case protocol.Command_RPC:
1032		handleErr = c.handleRPC(params, rw)
1033	case protocol.Command_SEND:
1034		handleErr = c.handleSend(params, rw)
1035	case protocol.Command_REFRESH:
1036		handleErr = c.handleRefresh(params, rw)
1037	case protocol.Command_SUB_REFRESH:
1038		handleErr = c.handleSubRefresh(params, rw)
1039	default:
1040		handleErr = ErrorMethodNotFound
1041	}
1042	if encodeErr != nil {
1043		return DisconnectServerError
1044	}
1045	if handleErr != nil {
1046		defer rw.done()
1047		switch t := handleErr.(type) {
1048		case *Disconnect:
1049			return t
1050		default:
1051			c.writeError(rw, toClientErr(handleErr))
1052		}
1053	}
1054	return nil
1055}
1056
1057func (c *Client) checkExpired() {
1058	c.mu.RLock()
1059	closed := c.status == statusClosed
1060	clientSideRefresh := c.clientSideRefresh
1061	exp := c.exp
1062	c.mu.RUnlock()
1063	if closed || exp == 0 {
1064		return
1065	}
1066	now := time.Now().Unix()
1067	ttl := exp - now
1068
1069	if !clientSideRefresh && c.eventHub.refreshHandler != nil {
1070		if ttl > 0 {
1071			c.mu.Lock()
1072			if c.status != statusClosed {
1073				c.addExpireUpdate(time.Duration(ttl) * time.Second)
1074			}
1075			c.mu.Unlock()
1076		}
1077	}
1078
1079	if ttl > 0 {
1080		// Connection was successfully refreshed.
1081		return
1082	}
1083
1084	_ = c.close(DisconnectExpired)
1085}
1086
1087func (c *Client) expire() {
1088	c.mu.RLock()
1089	closed := c.status == statusClosed
1090	clientSideRefresh := c.clientSideRefresh
1091	exp := c.exp
1092	c.mu.RUnlock()
1093	if closed || exp == 0 {
1094		return
1095	}
1096	if !clientSideRefresh && c.eventHub.refreshHandler != nil {
1097		cb := func(reply RefreshReply, err error) {
1098			if err != nil {
1099				switch t := err.(type) {
1100				case *Disconnect:
1101					_ = c.close(t)
1102					return
1103				default:
1104					_ = c.close(DisconnectServerError)
1105					return
1106				}
1107			}
1108			if reply.Expired {
1109				_ = c.close(DisconnectExpired)
1110				return
1111			}
1112			if reply.ExpireAt > 0 {
1113				c.mu.Lock()
1114				c.exp = reply.ExpireAt
1115				if reply.Info != nil {
1116					c.info = reply.Info
1117				}
1118				c.mu.Unlock()
1119			}
1120			c.checkExpired()
1121		}
1122		c.eventHub.refreshHandler(RefreshEvent{}, cb)
1123	} else {
1124		c.checkExpired()
1125	}
1126}
1127
1128func (c *Client) handleConnect(params protocol.Raw, rw *replyWriter) error {
1129	cmd, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodeConnect(params)
1130	if err != nil {
1131		return c.logDisconnectBadRequestWithError(err, "error decoding connect")
1132	}
1133	_, disconnect := c.connectCmd(cmd, rw)
1134	if disconnect != nil {
1135		return disconnect
1136	}
1137	c.triggerConnect()
1138	c.scheduleOnConnectTimers()
1139	return nil
1140}
1141
1142func (c *Client) triggerConnect() {
1143	c.connectMu.Lock()
1144	defer c.connectMu.Unlock()
1145	if c.status != statusConnecting {
1146		return
1147	}
1148	if c.node.clientEvents.connectHandler == nil {
1149		c.status = statusConnected
1150		return
1151	}
1152	c.node.clientEvents.connectHandler(c)
1153	c.status = statusConnected
1154}
1155
1156func (c *Client) scheduleOnConnectTimers() {
1157	// Make presence and refresh handlers always run after client connect event.
1158	c.mu.Lock()
1159	c.addPresenceUpdate()
1160	if c.exp > 0 {
1161		expireAfter := time.Duration(c.exp-time.Now().Unix()) * time.Second
1162		if c.clientSideRefresh {
1163			conf := c.node.config
1164			expireAfter += conf.ClientExpiredCloseDelay
1165		}
1166		c.addExpireUpdate(expireAfter)
1167	}
1168	c.mu.Unlock()
1169}
1170
1171func (c *Client) Refresh(opts ...RefreshOption) error {
1172	refreshOptions := &RefreshOptions{}
1173	for _, opt := range opts {
1174		opt(refreshOptions)
1175	}
1176	if refreshOptions.Expired {
1177		go func() { _ = c.close(DisconnectExpired) }()
1178		return nil
1179	}
1180
1181	expireAt := refreshOptions.ExpireAt
1182	info := refreshOptions.Info
1183
1184	res := &protocol.Refresh{
1185		Expires: expireAt > 0,
1186	}
1187
1188	ttl := expireAt - time.Now().Unix()
1189
1190	if ttl > 0 {
1191		res.Ttl = uint32(ttl)
1192	}
1193
1194	if expireAt > 0 {
1195		// connection check enabled
1196		if ttl > 0 {
1197			// connection refreshed, update client timestamp and set new expiration timeout
1198			c.mu.Lock()
1199			c.exp = expireAt
1200			if len(info) > 0 {
1201				c.info = info
1202			}
1203			duration := time.Duration(ttl)*time.Second + c.node.config.ClientExpiredCloseDelay
1204			c.addExpireUpdate(duration)
1205			c.mu.Unlock()
1206		} else {
1207			go func() { _ = c.close(DisconnectExpired) }()
1208			return nil
1209		}
1210	} else {
1211		c.mu.Lock()
1212		c.exp = 0
1213		c.mu.Unlock()
1214	}
1215
1216	pushBytes, err := protocol.EncodeRefreshPush(c.transport.Protocol().toProto(), res)
1217	if err != nil {
1218		return err
1219	}
1220	reply := prepared.NewReply(&protocol.Reply{
1221		Result: pushBytes,
1222	}, c.transport.Protocol().toProto())
1223
1224	return c.transportEnqueue(reply)
1225}
1226
1227func (c *Client) handleRefresh(params protocol.Raw, rw *replyWriter) error {
1228	if c.eventHub.refreshHandler == nil {
1229		return ErrorNotAvailable
1230	}
1231
1232	cmd, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodeRefresh(params)
1233	if err != nil {
1234		return c.logDisconnectBadRequestWithError(err, "error decoding refresh")
1235	}
1236
1237	if cmd.Token == "" {
1238		return c.logDisconnectBadRequest("client token required to refresh")
1239	}
1240
1241	c.mu.RLock()
1242	clientSideRefresh := c.clientSideRefresh
1243	c.mu.RUnlock()
1244
1245	if !clientSideRefresh {
1246		// Client not supposed to send refresh command in case of server-side refresh mechanism.
1247		return c.logDisconnectBadRequest("server-side refresh expected")
1248	}
1249
1250	event := RefreshEvent{
1251		ClientSideRefresh: true,
1252		Token:             cmd.Token,
1253	}
1254
1255	cb := func(reply RefreshReply, err error) {
1256		defer rw.done()
1257
1258		if err != nil {
1259			c.writeDisconnectOrErrorFlush(rw, err)
1260			return
1261		}
1262
1263		if reply.Expired {
1264			c.Disconnect(DisconnectExpired)
1265			return
1266		}
1267
1268		expireAt := reply.ExpireAt
1269		info := reply.Info
1270
1271		res := &protocol.RefreshResult{
1272			Version: c.node.config.Version,
1273			Expires: expireAt > 0,
1274			Client:  c.uid,
1275		}
1276
1277		ttl := expireAt - time.Now().Unix()
1278
1279		if ttl > 0 {
1280			res.Ttl = uint32(ttl)
1281		}
1282
1283		if expireAt > 0 {
1284			// connection check enabled
1285			if ttl > 0 {
1286				// connection refreshed, update client timestamp and set new expiration timeout
1287				c.mu.Lock()
1288				c.exp = expireAt
1289				if len(info) > 0 {
1290					c.info = info
1291				}
1292				duration := time.Duration(ttl)*time.Second + c.node.config.ClientExpiredCloseDelay
1293				c.addExpireUpdate(duration)
1294				c.mu.Unlock()
1295			} else {
1296				c.writeError(rw, ErrorExpired)
1297				return
1298			}
1299		}
1300
1301		replyRes, err := protocol.GetResultEncoder(c.transport.Protocol().toProto()).EncodeRefreshResult(res)
1302		if err != nil {
1303			c.logWriteInternalErrorFlush(rw, err, "error encoding refresh")
1304			return
1305		}
1306
1307		_ = writeReply(rw, &protocol.Reply{Result: replyRes})
1308	}
1309
1310	c.eventHub.refreshHandler(event, cb)
1311	return nil
1312}
1313
1314// onSubscribeError cleans up a channel from client channels if an error during subscribe happened.
1315// Channel kept in a map during subscribe request to check for duplicate subscription attempts.
1316func (c *Client) onSubscribeError(channel string) {
1317	c.mu.Lock()
1318	_, ok := c.channels[channel]
1319	delete(c.channels, channel)
1320	c.mu.Unlock()
1321	if ok {
1322		_ = c.node.removeSubscription(channel, c)
1323	}
1324}
1325
1326func (c *Client) handleSubscribe(params protocol.Raw, rw *replyWriter) error {
1327	if c.eventHub.subscribeHandler == nil {
1328		return ErrorNotAvailable
1329	}
1330
1331	cmd, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodeSubscribe(params)
1332	if err != nil {
1333		return c.logDisconnectBadRequestWithError(err, "error decoding subscribe")
1334	}
1335
1336	replyError, disconnect := c.validateSubscribeRequest(cmd)
1337	if disconnect != nil || replyError != nil {
1338		if disconnect != nil {
1339			return disconnect
1340		}
1341		return replyError
1342	}
1343
1344	event := SubscribeEvent{
1345		Channel: cmd.Channel,
1346		Token:   cmd.Token,
1347	}
1348
1349	cb := func(reply SubscribeReply, err error) {
1350		defer rw.done()
1351
1352		if err != nil {
1353			c.onSubscribeError(cmd.Channel)
1354			c.writeDisconnectOrErrorFlush(rw, err)
1355			return
1356		}
1357
1358		ctx := c.subscribeCmd(cmd, reply, rw, false)
1359
1360		if ctx.disconnect != nil {
1361			c.onSubscribeError(cmd.Channel)
1362			c.Disconnect(ctx.disconnect)
1363			return
1364		}
1365		if ctx.err != nil {
1366			c.onSubscribeError(cmd.Channel)
1367			c.writeDisconnectOrErrorFlush(rw, ctx.err)
1368			return
1369		}
1370
1371		if channelHasFlag(ctx.channelContext.flags, flagJoinLeave) && ctx.clientInfo != nil {
1372			go func() { _ = c.node.publishJoin(cmd.Channel, ctx.clientInfo) }()
1373		}
1374	}
1375	c.eventHub.subscribeHandler(event, cb)
1376	return nil
1377}
1378
1379func (c *Client) getSubscribedChannelContext(channel string) (channelContext, bool) {
1380	c.mu.RLock()
1381	ctx, okChannel := c.channels[channel]
1382	c.mu.RUnlock()
1383	if !okChannel || !channelHasFlag(ctx.flags, flagSubscribed) {
1384		return channelContext{}, false
1385	}
1386	return ctx, true
1387}
1388
1389func (c *Client) handleSubRefresh(params protocol.Raw, rw *replyWriter) error {
1390	if c.eventHub.subRefreshHandler == nil {
1391		return ErrorNotAvailable
1392	}
1393
1394	cmd, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodeSubRefresh(params)
1395	if err != nil {
1396		return c.logDisconnectBadRequestWithError(err, "error decoding sub refresh")
1397	}
1398
1399	channel := cmd.Channel
1400	if channel == "" {
1401		return c.logDisconnectBadRequest("channel required for sub refresh")
1402	}
1403
1404	ctx, okChannel := c.getSubscribedChannelContext(channel)
1405	if !okChannel {
1406		// Must be subscribed to refresh subscription.
1407		return ErrorPermissionDenied
1408	}
1409
1410	clientSideRefresh := channelHasFlag(ctx.flags, flagClientSideRefresh)
1411	if !clientSideRefresh {
1412		// Client not supposed to send sub refresh command in case of server-side
1413		// subscription refresh mechanism.
1414		return c.logDisconnectBadRequest("server-side sub refresh expected")
1415	}
1416
1417	if cmd.Token == "" {
1418		c.node.logger.log(newLogEntry(LogLevelInfo, "subscription refresh token required", map[string]interface{}{"client": c.uid, "user": c.UserID()}))
1419		return ErrorBadRequest
1420	}
1421
1422	event := SubRefreshEvent{
1423		ClientSideRefresh: true,
1424		Channel:           cmd.Channel,
1425		Token:             cmd.Token,
1426	}
1427
1428	cb := func(reply SubRefreshReply, err error) {
1429		defer rw.done()
1430
1431		if err != nil {
1432			c.writeDisconnectOrErrorFlush(rw, err)
1433			return
1434		}
1435
1436		res := &protocol.SubRefreshResult{}
1437
1438		if reply.ExpireAt > 0 {
1439			res.Expires = true
1440			now := time.Now().Unix()
1441			if reply.ExpireAt < now {
1442				c.writeError(rw, ErrorExpired)
1443				return
1444			}
1445			res.Ttl = uint32(reply.ExpireAt - now)
1446		}
1447
1448		c.mu.Lock()
1449		channelContext, okChan := c.channels[channel]
1450		if okChan && channelHasFlag(channelContext.flags, flagSubscribed) {
1451			channelContext.Info = reply.Info
1452			channelContext.expireAt = reply.ExpireAt
1453			c.channels[channel] = channelContext
1454		}
1455		c.mu.Unlock()
1456
1457		replyRes, err := protocol.GetResultEncoder(c.transport.Protocol().toProto()).EncodeSubRefreshResult(res)
1458		if err != nil {
1459			c.logWriteInternalErrorFlush(rw, err, "error encoding sub refresh")
1460			return
1461		}
1462		_ = writeReply(rw, &protocol.Reply{Result: replyRes})
1463	}
1464
1465	c.eventHub.subRefreshHandler(event, cb)
1466	return nil
1467}
1468
1469func (c *Client) handleUnsubscribe(params protocol.Raw, rw *replyWriter) error {
1470	cmd, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodeUnsubscribe(params)
1471	if err != nil {
1472		return c.logDisconnectBadRequestWithError(err, "error decoding unsubscribe")
1473	}
1474
1475	channel := cmd.Channel
1476	if channel == "" {
1477		return c.logDisconnectBadRequest("channel required for unsubscribe")
1478	}
1479
1480	if err := c.unsubscribe(channel); err != nil {
1481		return err
1482	}
1483
1484	replyRes, err := protocol.GetResultEncoder(c.transport.Protocol().toProto()).EncodeUnsubscribeResult(&protocol.UnsubscribeResult{})
1485	if err != nil {
1486		c.node.logger.log(newLogEntry(LogLevelError, "error encoding unsubscribe", map[string]interface{}{"error": err.Error()}))
1487		return DisconnectServerError
1488	}
1489
1490	_ = writeReply(rw, &protocol.Reply{Result: replyRes})
1491	rw.done()
1492	return nil
1493}
1494
1495func (c *Client) handlePublish(params protocol.Raw, rw *replyWriter) error {
1496	if c.eventHub.publishHandler == nil {
1497		return ErrorNotAvailable
1498	}
1499
1500	cmd, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodePublish(params)
1501	if err != nil {
1502		return c.logDisconnectBadRequestWithError(err, "error decoding publish")
1503	}
1504
1505	channel := cmd.Channel
1506	data := cmd.Data
1507
1508	if channel == "" || len(data) == 0 {
1509		return c.logDisconnectBadRequest("channel and data required for publish")
1510	}
1511
1512	c.mu.RLock()
1513	info := c.clientInfo(channel)
1514	c.mu.RUnlock()
1515
1516	event := PublishEvent{
1517		Channel:    channel,
1518		Data:       data,
1519		ClientInfo: info,
1520	}
1521
1522	cb := func(reply PublishReply, err error) {
1523		defer rw.done()
1524
1525		if err != nil {
1526			c.writeDisconnectOrErrorFlush(rw, err)
1527			return
1528		}
1529
1530		if reply.Result == nil {
1531			_, err := c.node.Publish(
1532				event.Channel, event.Data,
1533				WithHistory(reply.Options.HistorySize, reply.Options.HistoryTTL),
1534				WithClientInfo(reply.Options.ClientInfo),
1535			)
1536			if err != nil {
1537				c.logWriteInternalErrorFlush(rw, err, "error publish")
1538				return
1539			}
1540		}
1541
1542		replyRes, err := protocol.GetResultEncoder(c.transport.Protocol().toProto()).EncodePublishResult(&protocol.PublishResult{})
1543		if err != nil {
1544			c.logWriteInternalErrorFlush(rw, err, "error encoding publish")
1545			return
1546		}
1547		_ = writeReply(rw, &protocol.Reply{Result: replyRes})
1548	}
1549
1550	c.eventHub.publishHandler(event, cb)
1551	return nil
1552}
1553
1554func (c *Client) handlePresence(params protocol.Raw, rw *replyWriter) error {
1555	if c.eventHub.presenceHandler == nil {
1556		return ErrorNotAvailable
1557	}
1558
1559	cmd, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodePresence(params)
1560	if err != nil {
1561		return c.logDisconnectBadRequestWithError(err, "error decoding presence")
1562	}
1563
1564	channel := cmd.Channel
1565	if channel == "" {
1566		return c.logDisconnectBadRequest("channel required for presence")
1567	}
1568
1569	event := PresenceEvent{
1570		Channel: channel,
1571	}
1572
1573	cb := func(reply PresenceReply, err error) {
1574		defer rw.done()
1575		if err != nil {
1576			c.writeDisconnectOrErrorFlush(rw, err)
1577			return
1578		}
1579
1580		var presence map[string]*ClientInfo
1581		if reply.Result == nil {
1582			result, err := c.node.Presence(event.Channel)
1583			if err != nil {
1584				c.logWriteInternalErrorFlush(rw, err, "error getting presence")
1585				return
1586			}
1587			presence = result.Presence
1588		} else {
1589			presence = reply.Result.Presence
1590		}
1591
1592		protoPresence := make(map[string]*protocol.ClientInfo, len(presence))
1593		for k, v := range presence {
1594			protoPresence[k] = infoToProto(v)
1595		}
1596
1597		replyRes, err := protocol.GetResultEncoder(c.transport.Protocol().toProto()).EncodePresenceResult(&protocol.PresenceResult{
1598			Presence: protoPresence,
1599		})
1600		if err != nil {
1601			c.logWriteInternalErrorFlush(rw, err, "error encoding presence")
1602			return
1603		}
1604		_ = writeReply(rw, &protocol.Reply{Result: replyRes})
1605	}
1606
1607	c.eventHub.presenceHandler(event, cb)
1608	return nil
1609}
1610
1611func (c *Client) handlePresenceStats(params protocol.Raw, rw *replyWriter) error {
1612	if c.eventHub.presenceStatsHandler == nil {
1613		return ErrorNotAvailable
1614	}
1615
1616	cmd, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodePresenceStats(params)
1617	if err != nil {
1618		return c.logDisconnectBadRequestWithError(err, "error decoding presence stats")
1619	}
1620
1621	channel := cmd.Channel
1622	if channel == "" {
1623		return c.logDisconnectBadRequest("channel required for presence stats")
1624	}
1625
1626	event := PresenceStatsEvent{
1627		Channel: channel,
1628	}
1629
1630	cb := func(reply PresenceStatsReply, err error) {
1631		defer rw.done()
1632		if err != nil {
1633			c.writeDisconnectOrErrorFlush(rw, err)
1634			return
1635		}
1636
1637		var presenceStats PresenceStats
1638		if reply.Result == nil {
1639			result, err := c.node.PresenceStats(event.Channel)
1640			if err != nil {
1641				c.logWriteInternalErrorFlush(rw, err, "error getting presence stats")
1642				return
1643			}
1644			presenceStats = result.PresenceStats
1645		} else {
1646			presenceStats = reply.Result.PresenceStats
1647		}
1648
1649		replyRes, err := protocol.GetResultEncoder(c.transport.Protocol().toProto()).EncodePresenceStatsResult(&protocol.PresenceStatsResult{
1650			NumClients: uint32(presenceStats.NumClients),
1651			NumUsers:   uint32(presenceStats.NumUsers),
1652		})
1653		if err != nil {
1654			c.logWriteInternalErrorFlush(rw, err, "error encoding presence stats")
1655			return
1656		}
1657		_ = writeReply(rw, &protocol.Reply{Result: replyRes})
1658	}
1659
1660	c.eventHub.presenceStatsHandler(event, cb)
1661	return nil
1662}
1663
1664func (c *Client) handleHistory(params protocol.Raw, rw *replyWriter) error {
1665	if c.eventHub.historyHandler == nil {
1666		return ErrorNotAvailable
1667	}
1668
1669	cmd, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodeHistory(params)
1670	if err != nil {
1671		return c.logDisconnectBadRequestWithError(err, "error decoding history")
1672	}
1673
1674	channel := cmd.Channel
1675	if channel == "" {
1676		return c.logDisconnectBadRequest("channel required for history")
1677	}
1678
1679	var filter HistoryFilter
1680	if cmd.Since != nil {
1681		filter.Since = &StreamPosition{
1682			Offset: cmd.Since.Offset,
1683			Epoch:  cmd.Since.Epoch,
1684		}
1685	}
1686	filter.Limit = int(cmd.Limit)
1687
1688	maxPublicationLimit := c.node.config.HistoryMaxPublicationLimit
1689	if maxPublicationLimit > 0 && (filter.Limit < 0 || filter.Limit > maxPublicationLimit) {
1690		filter.Limit = maxPublicationLimit
1691	}
1692
1693	filter.Reverse = cmd.Reverse
1694
1695	event := HistoryEvent{
1696		Channel: channel,
1697		Filter:  filter,
1698	}
1699
1700	cb := func(reply HistoryReply, err error) {
1701		defer rw.done()
1702		if err != nil {
1703			c.writeDisconnectOrErrorFlush(rw, err)
1704			return
1705		}
1706
1707		var pubs []*Publication
1708		var offset uint64
1709		var epoch string
1710		if reply.Result == nil {
1711			result, err := c.node.History(event.Channel, WithLimit(event.Filter.Limit), WithSince(event.Filter.Since), WithReverse(event.Filter.Reverse))
1712			if err != nil {
1713				c.logWriteInternalErrorFlush(rw, err, "error getting history")
1714				return
1715			}
1716			pubs = result.Publications
1717			offset = result.Offset
1718			epoch = result.Epoch
1719		} else {
1720			pubs = reply.Result.Publications
1721			offset = reply.Result.Offset
1722			epoch = reply.Result.Epoch
1723		}
1724
1725		protoPubs := make([]*protocol.Publication, 0, len(pubs))
1726		for _, pub := range pubs {
1727			protoPub := pubToProto(pub)
1728			protoPubs = append(protoPubs, protoPub)
1729		}
1730
1731		replyRes, err := protocol.GetResultEncoder(c.transport.Protocol().toProto()).EncodeHistoryResult(&protocol.HistoryResult{
1732			Publications: protoPubs,
1733			Offset:       offset,
1734			Epoch:        epoch,
1735		})
1736		if err != nil {
1737			c.logWriteInternalErrorFlush(rw, err, "error encoding history")
1738			return
1739		}
1740		_ = writeReply(rw, &protocol.Reply{Result: replyRes})
1741	}
1742
1743	c.eventHub.historyHandler(event, cb)
1744	return nil
1745}
1746
1747func (c *Client) handlePing(params protocol.Raw, rw *replyWriter) error {
1748	_, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodePing(params)
1749	if err != nil {
1750		return c.logDisconnectBadRequestWithError(err, "error decoding ping")
1751	}
1752	_ = writeReply(rw, &protocol.Reply{})
1753	defer rw.done()
1754	return nil
1755}
1756
1757func (c *Client) writeError(rw *replyWriter, error *Error) {
1758	_ = rw.write(&protocol.Reply{Error: error.toProto()})
1759}
1760
1761func (c *Client) writeDisconnectOrErrorFlush(rw *replyWriter, replyError error) {
1762	switch t := replyError.(type) {
1763	case *Disconnect:
1764		go func() { _ = c.close(t) }()
1765		return
1766	default:
1767		c.writeError(rw, toClientErr(replyError))
1768	}
1769}
1770
1771func writeReply(rw *replyWriter, reply *protocol.Reply) error {
1772	return rw.write(reply)
1773}
1774
1775func (c *Client) handleRPC(params protocol.Raw, rw *replyWriter) error {
1776	if c.eventHub.rpcHandler == nil {
1777		return ErrorNotAvailable
1778	}
1779	cmd, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodeRPC(params)
1780	if err != nil {
1781		return c.logDisconnectBadRequestWithError(err, "error decoding rpc")
1782	}
1783
1784	event := RPCEvent{
1785		Method: cmd.Method,
1786		Data:   cmd.Data,
1787	}
1788
1789	cb := func(reply RPCReply, err error) {
1790		defer rw.done()
1791		if err != nil {
1792			c.writeDisconnectOrErrorFlush(rw, err)
1793			return
1794		}
1795		result := &protocol.RPCResult{
1796			Data: reply.Data,
1797		}
1798		var replyRes []byte
1799		replyRes, err = protocol.GetResultEncoder(c.transport.Protocol().toProto()).EncodeRPCResult(result)
1800		if err != nil {
1801			c.logWriteInternalErrorFlush(rw, err, "error encoding rpc")
1802			return
1803		}
1804		_ = writeReply(rw, &protocol.Reply{Result: replyRes})
1805	}
1806
1807	c.eventHub.rpcHandler(event, cb)
1808	return nil
1809}
1810
1811func (c *Client) handleSend(params protocol.Raw, rw *replyWriter) error {
1812	if c.eventHub.messageHandler == nil {
1813		// send handler is a bit special since it is only one way
1814		// request: client does not expect any reply.
1815		rw.done()
1816		return nil
1817	}
1818	cmd, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodeSend(params)
1819	if err != nil {
1820		return c.logDisconnectBadRequestWithError(err, "error decoding message")
1821	}
1822	defer rw.done()
1823	c.eventHub.messageHandler(MessageEvent{
1824		Data: cmd.Data,
1825	})
1826	return nil
1827}
1828
1829func (c *Client) unlockServerSideSubscriptions(subCtxMap map[string]subscribeContext) {
1830	for channel := range subCtxMap {
1831		c.pubSubSync.StopBuffering(channel)
1832	}
1833}
1834
1835// connectCmd handles connect command from client - client must send connect
1836// command immediately after establishing connection with server.
1837func (c *Client) connectCmd(cmd *protocol.ConnectRequest, rw *replyWriter) (*protocol.ConnectResult, error) {
1838	c.mu.RLock()
1839	authenticated := c.authenticated
1840	closed := c.status == statusClosed
1841	c.mu.RUnlock()
1842
1843	if closed {
1844		return nil, DisconnectNormal
1845	}
1846
1847	if authenticated {
1848		return nil, c.logDisconnectBadRequest("client already authenticated")
1849	}
1850
1851	config := c.node.config
1852	version := config.Version
1853	userConnectionLimit := config.UserConnectionLimit
1854	channelLimit := config.ClientChannelLimit
1855
1856	var (
1857		credentials       *Credentials
1858		authData          protocol.Raw
1859		subscriptions     map[string]SubscribeOptions
1860		clientSideRefresh bool
1861	)
1862
1863	if c.node.clientEvents.connectingHandler != nil {
1864		e := ConnectEvent{
1865			ClientID:  c.ID(),
1866			Data:      cmd.Data,
1867			Token:     cmd.Token,
1868			Name:      cmd.Name,
1869			Version:   cmd.Version,
1870			Transport: c.transport,
1871		}
1872		if len(cmd.Subs) > 0 {
1873			channels := make([]string, 0, len(cmd.Subs))
1874			for ch := range cmd.Subs {
1875				channels = append(channels, ch)
1876			}
1877			e.Channels = channels
1878		}
1879		reply, err := c.node.clientEvents.connectingHandler(c.ctx, e)
1880		if err != nil {
1881			return nil, err
1882		}
1883		if reply.Credentials != nil {
1884			credentials = reply.Credentials
1885		}
1886		if reply.Context != nil {
1887			c.mu.Lock()
1888			c.ctx = reply.Context
1889			c.mu.Unlock()
1890		}
1891		if reply.Data != nil {
1892			authData = reply.Data
1893		}
1894		clientSideRefresh = reply.ClientSideRefresh
1895		if len(reply.Subscriptions) > 0 {
1896			subscriptions = make(map[string]SubscribeOptions, len(reply.Subscriptions))
1897			for ch, opts := range reply.Subscriptions {
1898				if ch == "" {
1899					continue
1900				}
1901				subscriptions[ch] = opts
1902			}
1903		}
1904	}
1905
1906	if channelLimit > 0 && len(subscriptions) > channelLimit {
1907		return nil, DisconnectChannelLimit
1908	}
1909
1910	if credentials == nil {
1911		// Try to find Credentials in context.
1912		if cred, ok := GetCredentials(c.ctx); ok {
1913			credentials = cred
1914		}
1915	}
1916
1917	var (
1918		expires bool
1919		ttl     uint32
1920	)
1921
1922	c.mu.Lock()
1923	c.clientSideRefresh = clientSideRefresh
1924	c.mu.Unlock()
1925
1926	if credentials == nil {
1927		return nil, c.logDisconnectBadRequest("client credentials not found")
1928	}
1929
1930	c.mu.Lock()
1931	c.user = credentials.UserID
1932	c.info = credentials.Info
1933	c.exp = credentials.ExpireAt
1934
1935	user := c.user
1936	exp := c.exp
1937	closed = c.status == statusClosed
1938	c.mu.Unlock()
1939
1940	if closed {
1941		return nil, DisconnectNormal
1942	}
1943
1944	c.node.logger.log(newLogEntry(LogLevelDebug, "client authenticated", map[string]interface{}{"client": c.uid, "user": c.user}))
1945
1946	if userConnectionLimit > 0 && user != "" && len(c.node.hub.UserConnections(user)) >= userConnectionLimit {
1947		c.node.logger.log(newLogEntry(LogLevelInfo, "limit of connections for user reached", map[string]interface{}{"user": user, "client": c.uid, "limit": userConnectionLimit}))
1948		return nil, DisconnectConnectionLimit
1949	}
1950
1951	c.mu.RLock()
1952	if exp > 0 {
1953		expires = true
1954		now := time.Now().Unix()
1955		if exp < now {
1956			c.mu.RUnlock()
1957			c.node.logger.log(newLogEntry(LogLevelInfo, "connection expiration must be greater than now", map[string]interface{}{"client": c.uid, "user": c.UserID()}))
1958			return nil, ErrorExpired
1959		}
1960		ttl = uint32(exp - now)
1961	}
1962	c.mu.RUnlock()
1963
1964	res := &protocol.ConnectResult{
1965		Version: version,
1966		Expires: expires,
1967		Ttl:     ttl,
1968	}
1969
1970	// Client successfully connected.
1971	c.mu.Lock()
1972	c.authenticated = true
1973	c.mu.Unlock()
1974
1975	err := c.node.addClient(c)
1976	if err != nil {
1977		c.node.logger.log(newLogEntry(LogLevelError, "error adding client", map[string]interface{}{"client": c.uid, "error": err.Error()}))
1978		return nil, DisconnectServerError
1979	}
1980
1981	if !clientSideRefresh {
1982		// Server will do refresh itself.
1983		res.Expires = false
1984		res.Ttl = 0
1985	}
1986
1987	res.Client = c.uid
1988	if authData != nil {
1989		res.Data = authData
1990	}
1991
1992	var subCtxMap map[string]subscribeContext
1993	if len(subscriptions) > 0 {
1994		var subMu sync.Mutex
1995		subCtxMap = make(map[string]subscribeContext, len(subscriptions))
1996		subs := make(map[string]*protocol.SubscribeResult, len(subscriptions))
1997		var subDisconnect *Disconnect
1998		var subError *Error
1999		var wg sync.WaitGroup
2000
2001		wg.Add(len(subscriptions))
2002		for ch, opts := range subscriptions {
2003			go func(ch string, opts SubscribeOptions) {
2004				defer wg.Done()
2005				subCmd := &protocol.SubscribeRequest{
2006					Channel: ch,
2007				}
2008				if subReq, ok := cmd.Subs[ch]; ok {
2009					subCmd.Recover = subReq.Recover
2010					subCmd.Offset = subReq.Offset
2011					subCmd.Epoch = subReq.Epoch
2012				}
2013				subCtx := c.subscribeCmd(subCmd, SubscribeReply{Options: opts}, rw, true)
2014				subMu.Lock()
2015				subs[ch] = subCtx.result
2016				subCtxMap[ch] = subCtx
2017				if subCtx.disconnect != nil {
2018					subDisconnect = subCtx.disconnect
2019				}
2020				if subCtx.err != nil {
2021					subError = subCtx.err
2022				}
2023				subMu.Unlock()
2024			}(ch, opts)
2025		}
2026		wg.Wait()
2027
2028		if subDisconnect != nil || subError != nil {
2029			c.unlockServerSideSubscriptions(subCtxMap)
2030			for channel := range subCtxMap {
2031				c.onSubscribeError(channel)
2032			}
2033			if subDisconnect != nil {
2034				return nil, subDisconnect
2035			}
2036			return nil, subError
2037		}
2038		res.Subs = subs
2039	}
2040
2041	if c.transport.Unidirectional() {
2042		connectPushBytes, err := c.encodeConnectPush(res)
2043		if err != nil {
2044			c.unlockServerSideSubscriptions(subCtxMap)
2045			c.node.logger.log(newLogEntry(LogLevelError, "error encoding connect", map[string]interface{}{"error": err.Error()}))
2046			return nil, DisconnectServerError
2047		}
2048		_ = writeReply(rw, &protocol.Reply{Result: connectPushBytes})
2049		defer rw.done()
2050	} else {
2051		replyRes, err := protocol.GetResultEncoder(c.transport.Protocol().toProto()).EncodeConnectResult(res)
2052		if err != nil {
2053			c.unlockServerSideSubscriptions(subCtxMap)
2054			c.node.logger.log(newLogEntry(LogLevelError, "error encoding connect", map[string]interface{}{"error": err.Error()}))
2055			return nil, DisconnectServerError
2056		}
2057		_ = writeReply(rw, &protocol.Reply{Result: replyRes})
2058		defer rw.done()
2059	}
2060
2061	c.mu.Lock()
2062	for channel, subCtx := range subCtxMap {
2063		c.channels[channel] = subCtx.channelContext
2064	}
2065	c.mu.Unlock()
2066
2067	c.unlockServerSideSubscriptions(subCtxMap)
2068
2069	if len(subCtxMap) > 0 {
2070		for channel, subCtx := range subCtxMap {
2071			go func(channel string, subCtx subscribeContext) {
2072				if channelHasFlag(subCtx.channelContext.flags, flagJoinLeave) && subCtx.clientInfo != nil {
2073					_ = c.node.publishJoin(channel, subCtx.clientInfo)
2074				}
2075			}(channel, subCtx)
2076		}
2077	}
2078
2079	return res, nil
2080}
2081
2082// Subscribe client to a channel.
2083func (c *Client) Subscribe(channel string, opts ...SubscribeOption) error {
2084	if channel == "" {
2085		return fmt.Errorf("channel is empty")
2086	}
2087	channelLimit := c.node.config.ClientChannelLimit
2088	c.mu.RLock()
2089	numChannels := len(c.channels)
2090	c.mu.RUnlock()
2091	if channelLimit > 0 && numChannels >= channelLimit {
2092		go func() { _ = c.close(DisconnectChannelLimit) }()
2093		return nil
2094	}
2095
2096	subCmd := &protocol.SubscribeRequest{
2097		Channel: channel,
2098	}
2099	subscribeOpts := &SubscribeOptions{}
2100	for _, opt := range opts {
2101		opt(subscribeOpts)
2102	}
2103	if subscribeOpts.RecoverSince != nil {
2104		subCmd.Recover = true
2105		subCmd.Offset = subscribeOpts.RecoverSince.Offset
2106		subCmd.Epoch = subscribeOpts.RecoverSince.Epoch
2107	}
2108	subCtx := c.subscribeCmd(subCmd, SubscribeReply{
2109		Options: *subscribeOpts,
2110	}, nil, true)
2111	if subCtx.err != nil {
2112		c.onSubscribeError(subCmd.Channel)
2113		return subCtx.err
2114	}
2115	defer c.pubSubSync.StopBuffering(channel)
2116	c.mu.Lock()
2117	c.channels[channel] = subCtx.channelContext
2118	c.mu.Unlock()
2119	if hasFlag(c.transport.DisabledPushFlags(), PushFlagSubscribe) {
2120		return nil
2121	}
2122	sub := &protocol.Subscribe{
2123		Offset:      subCtx.result.GetOffset(),
2124		Epoch:       subCtx.result.GetEpoch(),
2125		Recoverable: subCtx.result.GetRecoverable(),
2126		Positioned:  subCtx.result.GetPositioned(),
2127		Data:        subCtx.result.Data,
2128	}
2129	pushBytes, err := protocol.EncodeSubscribePush(c.transport.Protocol().toProto(), channel, sub)
2130	if err != nil {
2131		return err
2132	}
2133	reply := prepared.NewReply(&protocol.Reply{
2134		Result: pushBytes,
2135	}, c.transport.Protocol().toProto())
2136	return c.transportEnqueue(reply)
2137}
2138
2139func (c *Client) validateSubscribeRequest(cmd *protocol.SubscribeRequest) (*Error, *Disconnect) {
2140	channel := cmd.Channel
2141	if channel == "" {
2142		c.node.logger.log(newLogEntry(LogLevelInfo, "channel required for subscribe", map[string]interface{}{"user": c.user, "client": c.uid}))
2143		return nil, DisconnectBadRequest
2144	}
2145
2146	config := c.node.config
2147	channelMaxLength := config.ChannelMaxLength
2148	channelLimit := config.ClientChannelLimit
2149
2150	if channelMaxLength > 0 && len(channel) > channelMaxLength {
2151		c.node.logger.log(newLogEntry(LogLevelInfo, "channel too long", map[string]interface{}{"max": channelMaxLength, "channel": channel, "user": c.user, "client": c.uid}))
2152		return ErrorBadRequest, nil
2153	}
2154
2155	c.mu.Lock()
2156	numChannels := len(c.channels)
2157	_, ok := c.channels[channel]
2158	if ok {
2159		c.mu.Unlock()
2160		c.node.logger.log(newLogEntry(LogLevelInfo, "client already subscribed on channel", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid}))
2161		return ErrorAlreadySubscribed, nil
2162	}
2163	if channelLimit > 0 && numChannels >= channelLimit {
2164		c.mu.Unlock()
2165		c.node.logger.log(newLogEntry(LogLevelInfo, "maximum limit of channels per client reached", map[string]interface{}{"limit": channelLimit, "user": c.user, "client": c.uid}))
2166		return ErrorLimitExceeded, nil
2167	}
2168	// Put channel to a map to track duplicate subscriptions. This channel should
2169	// be removed from a map upon an error during subscribe.
2170	c.channels[channel] = channelContext{}
2171	c.mu.Unlock()
2172
2173	return nil, nil
2174}
2175
2176func errorDisconnectContext(replyError *Error, disconnect *Disconnect) subscribeContext {
2177	ctx := subscribeContext{}
2178	if disconnect != nil {
2179		ctx.disconnect = disconnect
2180		return ctx
2181	}
2182	ctx.err = replyError
2183	return ctx
2184}
2185
2186type subscribeContext struct {
2187	result         *protocol.SubscribeResult
2188	clientInfo     *ClientInfo
2189	err            *Error
2190	disconnect     *Disconnect
2191	channelContext channelContext
2192}
2193
2194func isRecovered(historyResult HistoryResult, cmdOffset uint64, cmdEpoch string) ([]*protocol.Publication, bool) {
2195	latestOffset := historyResult.Offset
2196	latestEpoch := historyResult.Epoch
2197
2198	recoveredPubs := make([]*protocol.Publication, 0, len(historyResult.Publications))
2199	for _, pub := range historyResult.Publications {
2200		protoPub := pubToProto(pub)
2201		recoveredPubs = append(recoveredPubs, protoPub)
2202	}
2203
2204	nextOffset := cmdOffset + 1
2205	var recovered bool
2206	if len(recoveredPubs) == 0 {
2207		recovered = latestOffset == cmdOffset && (cmdEpoch == "" || latestEpoch == cmdEpoch)
2208	} else {
2209		recovered = recoveredPubs[0].Offset == nextOffset &&
2210			recoveredPubs[len(recoveredPubs)-1].Offset == latestOffset &&
2211			(cmdEpoch == "" || latestEpoch == cmdEpoch)
2212	}
2213
2214	return recoveredPubs, recovered
2215}
2216
2217// subscribeCmd handles subscribe command - clients send this when subscribe
2218// on channel, if channel if private then we must validate provided sign here before
2219// actually subscribe client on channel. Optionally we can send missed messages to
2220// client if it provided last message id seen in channel.
2221func (c *Client) subscribeCmd(cmd *protocol.SubscribeRequest, reply SubscribeReply, rw *replyWriter, serverSide bool) subscribeContext {
2222
2223	ctx := subscribeContext{}
2224	res := &protocol.SubscribeResult{}
2225
2226	if reply.Options.ExpireAt > 0 {
2227		ttl := reply.Options.ExpireAt - time.Now().Unix()
2228		if ttl <= 0 {
2229			c.node.logger.log(newLogEntry(LogLevelInfo, "subscription expiration must be greater than now", map[string]interface{}{"client": c.uid, "user": c.UserID()}))
2230			return errorDisconnectContext(ErrorExpired, nil)
2231		}
2232		if reply.ClientSideRefresh {
2233			res.Expires = true
2234			res.Ttl = uint32(ttl)
2235		}
2236	}
2237
2238	if reply.Options.Data != nil {
2239		res.Data = reply.Options.Data
2240	}
2241
2242	channel := cmd.Channel
2243
2244	info := &ClientInfo{
2245		ClientID: c.uid,
2246		UserID:   c.user,
2247		ConnInfo: c.info,
2248		ChanInfo: reply.Options.ChannelInfo,
2249	}
2250
2251	if reply.Options.Recover {
2252		// Start syncing recovery and PUB/SUB.
2253		// The important thing is to call StopBuffering for this channel
2254		// after response with Publications written to connection.
2255		c.pubSubSync.StartBuffering(channel)
2256	}
2257
2258	err := c.node.addSubscription(channel, c)
2259	if err != nil {
2260		c.node.logger.log(newLogEntry(LogLevelError, "error adding subscription", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid, "error": err.Error()}))
2261		c.pubSubSync.StopBuffering(channel)
2262		if clientErr, ok := err.(*Error); ok && clientErr != ErrorInternal {
2263			return errorDisconnectContext(clientErr, nil)
2264		}
2265		ctx.disconnect = DisconnectServerError
2266		return ctx
2267	}
2268
2269	if reply.Options.Presence {
2270		err = c.node.addPresence(channel, c.uid, info)
2271		if err != nil {
2272			c.node.logger.log(newLogEntry(LogLevelError, "error adding presence", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid, "error": err.Error()}))
2273			c.pubSubSync.StopBuffering(channel)
2274			ctx.disconnect = DisconnectServerError
2275			return ctx
2276		}
2277	}
2278
2279	var (
2280		latestOffset  uint64
2281		latestEpoch   string
2282		recoveredPubs []*protocol.Publication
2283	)
2284
2285	if reply.Options.Recover {
2286		res.Recoverable = true
2287		res.Positioned = true // recoverable subscriptions are automatically positioned.
2288		if cmd.Recover {
2289			cmdOffset := cmd.Offset
2290
2291			// Client provided subscribe request with recover flag on. Try to recover missed
2292			// publications automatically from history (we suppose here that history configured wisely).
2293			historyResult, err := c.node.recoverHistory(channel, StreamPosition{cmdOffset, cmd.Epoch})
2294			if err != nil {
2295				if errors.Is(err, ErrorUnrecoverablePosition) {
2296					// Result contains stream position in case of ErrorUnrecoverablePosition
2297					// during recovery.
2298					latestOffset = historyResult.Offset
2299					latestEpoch = historyResult.Epoch
2300					res.Recovered = false
2301					incRecover(res.Recovered)
2302				} else {
2303					c.node.logger.log(newLogEntry(LogLevelError, "error on recover", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid, "error": err.Error()}))
2304					c.pubSubSync.StopBuffering(channel)
2305					if clientErr, ok := err.(*Error); ok && clientErr != ErrorInternal {
2306						return errorDisconnectContext(clientErr, nil)
2307					}
2308					ctx.disconnect = DisconnectServerError
2309					return ctx
2310				}
2311			} else {
2312				latestOffset = historyResult.Offset
2313				latestEpoch = historyResult.Epoch
2314				var recovered bool
2315				recoveredPubs, recovered = isRecovered(historyResult, cmdOffset, cmd.Epoch)
2316				res.Recovered = recovered
2317				incRecover(res.Recovered)
2318			}
2319		} else {
2320			streamTop, err := c.node.streamTop(channel)
2321			if err != nil {
2322				c.node.logger.log(newLogEntry(LogLevelError, "error getting recovery state for channel", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid, "error": err.Error()}))
2323				c.pubSubSync.StopBuffering(channel)
2324				if clientErr, ok := err.(*Error); ok && clientErr != ErrorInternal {
2325					return errorDisconnectContext(clientErr, nil)
2326				}
2327				ctx.disconnect = DisconnectServerError
2328				return ctx
2329			}
2330			latestOffset = streamTop.Offset
2331			latestEpoch = streamTop.Epoch
2332		}
2333
2334		res.Epoch = latestEpoch
2335		res.Offset = latestOffset
2336
2337		bufferedPubs := c.pubSubSync.LockBufferAndReadBuffered(channel)
2338		var okMerge bool
2339		recoveredPubs, okMerge = recovery.MergePublications(recoveredPubs, bufferedPubs)
2340		if !okMerge {
2341			c.pubSubSync.StopBuffering(channel)
2342			ctx.disconnect = DisconnectInsufficientState
2343			return ctx
2344		}
2345	} else if reply.Options.Position {
2346		streamTop, err := c.node.streamTop(channel)
2347		if err != nil {
2348			c.node.logger.log(newLogEntry(LogLevelError, "error getting stream top for channel", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid, "error": err.Error()}))
2349			if clientErr, ok := err.(*Error); ok && clientErr != ErrorInternal {
2350				return errorDisconnectContext(clientErr, nil)
2351			}
2352			ctx.disconnect = DisconnectServerError
2353			return ctx
2354		}
2355
2356		latestOffset = streamTop.Offset
2357		latestEpoch = streamTop.Epoch
2358
2359		res.Positioned = true
2360		res.Offset = streamTop.Offset
2361		res.Epoch = streamTop.Epoch
2362	}
2363
2364	if len(recoveredPubs) > 0 {
2365		lastPubOffset := recoveredPubs[len(recoveredPubs)-1].Offset
2366		if lastPubOffset > res.Offset {
2367			// There can be a case when recovery returned a limited set of publications
2368			// thus last publication offset will be smaller than history current offset.
2369			// In this case res.Recovered will be false. So we take a maximum here.
2370			latestOffset = recoveredPubs[len(recoveredPubs)-1].Offset
2371			res.Offset = latestOffset
2372		}
2373	}
2374
2375	res.Publications = recoveredPubs
2376
2377	if !serverSide {
2378		// Write subscription reply only if initiated by client.
2379		replyRes, err := protocol.GetResultEncoder(c.transport.Protocol().toProto()).EncodeSubscribeResult(res)
2380		if err != nil {
2381			c.node.logger.log(newLogEntry(LogLevelError, "error encoding subscribe", map[string]interface{}{"error": err.Error()}))
2382			if !serverSide {
2383				// Will be called later in case of server side sub.
2384				c.pubSubSync.StopBuffering(channel)
2385			}
2386			ctx.disconnect = DisconnectServerError
2387			return ctx
2388		}
2389		// Need to flush data from writer so subscription response is
2390		// sent before any subscription publication.
2391		_ = writeReply(rw, &protocol.Reply{Result: replyRes})
2392	}
2393
2394	var channelFlags uint8
2395	channelFlags |= flagSubscribed
2396	if serverSide {
2397		channelFlags |= flagServerSide
2398	}
2399	if reply.ClientSideRefresh {
2400		channelFlags |= flagClientSideRefresh
2401	}
2402	if reply.Options.Recover {
2403		channelFlags |= flagRecover
2404	}
2405	if reply.Options.Position {
2406		channelFlags |= flagPosition
2407	}
2408	if reply.Options.Presence {
2409		channelFlags |= flagPresence
2410	}
2411	if reply.Options.JoinLeave {
2412		channelFlags |= flagJoinLeave
2413	}
2414
2415	channelContext := channelContext{
2416		Info:     reply.Options.ChannelInfo,
2417		flags:    channelFlags,
2418		expireAt: reply.Options.ExpireAt,
2419		streamPosition: StreamPosition{
2420			Offset: latestOffset,
2421			Epoch:  latestEpoch,
2422		},
2423	}
2424	if reply.Options.Recover || reply.Options.Position {
2425		channelContext.positionCheckTime = time.Now().Unix()
2426	}
2427
2428	if !serverSide {
2429		// In case of server-side sub this will be done later by the caller.
2430		c.mu.Lock()
2431		c.channels[channel] = channelContext
2432		c.mu.Unlock()
2433		// Stop syncing recovery and PUB/SUB.
2434		// In case of server side subscription we will do this later.
2435		c.pubSubSync.StopBuffering(channel)
2436	}
2437
2438	if c.node.logger.enabled(LogLevelDebug) {
2439		c.node.logger.log(newLogEntry(LogLevelDebug, "client subscribed to channel", map[string]interface{}{"client": c.uid, "user": c.user, "channel": cmd.Channel}))
2440	}
2441
2442	ctx.result = res
2443	ctx.clientInfo = info
2444	ctx.channelContext = channelContext
2445	return ctx
2446}
2447
2448func (c *Client) writePublicationUpdatePosition(ch string, pub *protocol.Publication, reply *prepared.Reply, sp StreamPosition) error {
2449	c.mu.Lock()
2450	channelContext, ok := c.channels[ch]
2451	if !ok || !channelHasFlag(channelContext.flags, flagSubscribed) {
2452		c.mu.Unlock()
2453		return nil
2454	}
2455	if !channelHasFlag(channelContext.flags, flagRecover|flagPosition) {
2456		if hasFlag(c.transport.DisabledPushFlags(), PushFlagPublication) {
2457			c.mu.Unlock()
2458			return nil
2459		}
2460		c.mu.Unlock()
2461		return c.transportEnqueue(reply)
2462	}
2463	currentPositionOffset := channelContext.streamPosition.Offset
2464	nextExpectedOffset := currentPositionOffset + 1
2465	pubOffset := pub.Offset
2466	pubEpoch := sp.Epoch
2467	if pubEpoch != channelContext.streamPosition.Epoch {
2468		if c.node.logger.enabled(LogLevelDebug) {
2469			c.node.logger.log(newLogEntry(LogLevelDebug, "client insufficient state", map[string]interface{}{"channel": ch, "user": c.user, "client": c.uid, "epoch": pubEpoch, "expectedEpoch": channelContext.streamPosition.Epoch}))
2470		}
2471		// Oops: sth lost, let client reconnect to recover its state.
2472		go func() { _ = c.close(DisconnectInsufficientState) }()
2473		c.mu.Unlock()
2474		return nil
2475	}
2476	if pubOffset != nextExpectedOffset {
2477		if c.node.logger.enabled(LogLevelDebug) {
2478			c.node.logger.log(newLogEntry(LogLevelDebug, "client insufficient state", map[string]interface{}{"channel": ch, "user": c.user, "client": c.uid, "offset": pubOffset, "expectedOffset": nextExpectedOffset}))
2479		}
2480		// Oops: sth lost, let client reconnect to recover its state.
2481		go func() { _ = c.close(DisconnectInsufficientState) }()
2482		c.mu.Unlock()
2483		return nil
2484	}
2485	channelContext.positionCheckTime = time.Now().Unix()
2486	channelContext.positionCheckFailures = 0
2487	channelContext.streamPosition.Offset = pub.Offset
2488	c.channels[ch] = channelContext
2489	c.mu.Unlock()
2490	if hasFlag(c.transport.DisabledPushFlags(), PushFlagPublication) {
2491		return nil
2492	}
2493	return c.transportEnqueue(reply)
2494}
2495
2496func (c *Client) writePublication(ch string, pub *protocol.Publication, reply *prepared.Reply, sp StreamPosition) error {
2497	if pub.Offset == 0 {
2498		if hasFlag(c.transport.DisabledPushFlags(), PushFlagPublication) {
2499			return nil
2500		}
2501		return c.transportEnqueue(reply)
2502	}
2503	c.pubSubSync.SyncPublication(ch, pub, func() {
2504		_ = c.writePublicationUpdatePosition(ch, pub, reply, sp)
2505	})
2506	return nil
2507}
2508
2509func (c *Client) writeJoin(_ string, reply *prepared.Reply) error {
2510	if hasFlag(c.transport.DisabledPushFlags(), PushFlagJoin) {
2511		return nil
2512	}
2513	return c.transportEnqueue(reply)
2514}
2515
2516func (c *Client) writeLeave(_ string, reply *prepared.Reply) error {
2517	if hasFlag(c.transport.DisabledPushFlags(), PushFlagLeave) {
2518		return nil
2519	}
2520	return c.transportEnqueue(reply)
2521}
2522
2523// Lock must be held outside.
2524func (c *Client) unsubscribe(channel string) error {
2525	c.mu.RLock()
2526	info := c.clientInfo(channel)
2527	chCtx, ok := c.channels[channel]
2528	serverSide := channelHasFlag(chCtx.flags, flagServerSide)
2529	c.mu.RUnlock()
2530
2531	if ok {
2532		c.mu.Lock()
2533		delete(c.channels, channel)
2534		c.mu.Unlock()
2535
2536		if channelHasFlag(chCtx.flags, flagPresence) && channelHasFlag(chCtx.flags, flagSubscribed) {
2537			err := c.node.removePresence(channel, c.uid)
2538			if err != nil {
2539				c.node.logger.log(newLogEntry(LogLevelError, "error removing channel presence", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid, "error": err.Error()}))
2540			}
2541		}
2542
2543		if channelHasFlag(chCtx.flags, flagJoinLeave) && channelHasFlag(chCtx.flags, flagSubscribed) {
2544			_ = c.node.publishLeave(channel, info)
2545		}
2546
2547		if err := c.node.removeSubscription(channel, c); err != nil {
2548			c.node.logger.log(newLogEntry(LogLevelError, "error removing subscription", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid, "error": err.Error()}))
2549			return err
2550		}
2551
2552		if channelHasFlag(chCtx.flags, flagSubscribed) {
2553			if c.eventHub.unsubscribeHandler != nil {
2554				c.eventHub.unsubscribeHandler(UnsubscribeEvent{
2555					Channel:    channel,
2556					ServerSide: serverSide,
2557				})
2558			}
2559		}
2560	}
2561	if c.node.logger.enabled(LogLevelDebug) {
2562		c.node.logger.log(newLogEntry(LogLevelDebug, "client unsubscribed from channel", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid}))
2563	}
2564	return nil
2565}
2566
2567func (c *Client) logDisconnectBadRequest(message string) *Disconnect {
2568	c.node.logger.log(newLogEntry(LogLevelInfo, message, map[string]interface{}{"user": c.user, "client": c.uid}))
2569	return DisconnectBadRequest
2570}
2571
2572func (c *Client) logDisconnectBadRequestWithError(err error, message string) *Disconnect {
2573	c.node.logger.log(newLogEntry(LogLevelInfo, message, map[string]interface{}{"error": err.Error(), "user": c.user, "client": c.uid}))
2574	return DisconnectBadRequest
2575}
2576
2577func (c *Client) logWriteInternalErrorFlush(rw *replyWriter, err error, message string) {
2578	if clientErr, ok := err.(*Error); ok {
2579		c.writeError(rw, clientErr)
2580		return
2581	}
2582	c.node.logger.log(newLogEntry(LogLevelError, message, map[string]interface{}{"error": err.Error()}))
2583	c.writeError(rw, ErrorInternal)
2584}
2585
2586func toClientErr(err error) *Error {
2587	if clientErr, ok := err.(*Error); ok {
2588		return clientErr
2589	}
2590	return ErrorInternal
2591}
2592
2593func errLogLevel(err error) LogLevel {
2594	logLevel := LogLevelInfo
2595	if err != ErrorNotAvailable {
2596		logLevel = LogLevelError
2597	}
2598	return logLevel
2599}
2600