package centrifuge

import (
	"context"
	"errors"
	"fmt"
	"io"
	"sync"
	"time"

	"github.com/centrifugal/centrifuge/internal/prepared"
	"github.com/centrifugal/centrifuge/internal/queue"
	"github.com/centrifugal/centrifuge/internal/recovery"

	"github.com/centrifugal/protocol"
	"github.com/google/uuid"
)

// clientEventHub holds the event handlers registered for a single client.
// Its fields are written by the Client On* setter methods without any
// locking, so those setters are not goroutine-safe and are supposed to be
// called once inside Node ConnectHandler. A nil handler means the
// corresponding event is not handled by the application.
type clientEventHub struct {
	aliveHandler         AliveHandler
	disconnectHandler    DisconnectHandler
	subscribeHandler     SubscribeHandler
	unsubscribeHandler   UnsubscribeHandler
	publishHandler       PublishHandler
	refreshHandler       RefreshHandler
	subRefreshHandler    SubRefreshHandler
	rpcHandler           RPCHandler
	messageHandler       MessageHandler
	presenceHandler      PresenceHandler
	presenceStatsHandler PresenceStatsHandler
	historyHandler       HistoryHandler
}

// OnAlive allows setting AliveHandler.
// AliveHandler is called periodically for an active client connection
// (on every presence update tick).
func (c *Client) OnAlive(h AliveHandler) {
	c.eventHub.aliveHandler = h
}

// OnRefresh allows setting RefreshHandler.
// RefreshHandler is called when it's time to refresh an expiring client
// connection, unless the connection uses client-side refresh.
func (c *Client) OnRefresh(h RefreshHandler) {
	c.eventHub.refreshHandler = h
}

// OnDisconnect allows setting DisconnectHandler.
// DisconnectHandler is called when the client is disconnected; it is only
// invoked for clients that reached the connected status before closing.
func (c *Client) OnDisconnect(h DisconnectHandler) {
	c.eventHub.disconnectHandler = h
}

// OnMessage allows setting MessageHandler.
// MessageHandler is called when the client sent an asynchronous message.
func (c *Client) OnMessage(h MessageHandler) {
	c.eventHub.messageHandler = h
}

// OnRPC allows setting RPCHandler.
// RPCHandler will be executed on every incoming RPC call.
func (c *Client) OnRPC(h RPCHandler) {
	// Stored without locking: handlers are expected to be set once, before
	// the client starts processing commands.
	c.eventHub.rpcHandler = h
}

// OnSubRefresh allows setting SubRefreshHandler.
// SubRefreshHandler is called when it's time to refresh a client
// subscription that does not use client-side refresh.
func (c *Client) OnSubRefresh(h SubRefreshHandler) {
	c.eventHub.subRefreshHandler = h
}

// OnSubscribe allows setting SubscribeHandler.
// SubscribeHandler is called when the client subscribes to a channel.
func (c *Client) OnSubscribe(h SubscribeHandler) {
	c.eventHub.subscribeHandler = h
}

// OnUnsubscribe allows setting UnsubscribeHandler.
// UnsubscribeHandler is called when the client unsubscribes from a channel.
func (c *Client) OnUnsubscribe(h UnsubscribeHandler) {
	c.eventHub.unsubscribeHandler = h
}

// OnPublish allows setting PublishHandler.
// PublishHandler is called when the client publishes a message into a channel.
func (c *Client) OnPublish(h PublishHandler) {
	c.eventHub.publishHandler = h
}

// OnPresence allows setting PresenceHandler.
// PresenceHandler is called when a Presence request from the client is received.
// At this moment you can only return a custom error or disconnect the client.
func (c *Client) OnPresence(h PresenceHandler) {
	c.eventHub.presenceHandler = h
}

// OnPresenceStats allows setting PresenceStatsHandler.
// PresenceStatsHandler is called when a Presence Stats request from the client
// is received.
// At this moment you can only return a custom error or disconnect the client.
func (c *Client) OnPresenceStats(h PresenceStatsHandler) {
	c.eventHub.presenceStatsHandler = h
}

// OnHistory allows setting HistoryHandler.
// HistoryHandler is called when a History request from the client is received.
// At this moment you can only return a custom error or disconnect the client.
func (c *Client) OnHistory(h HistoryHandler) {
	c.eventHub.historyHandler = h
}

// We poll current position in channel from history storage periodically.
113// If client position is wrong maxCheckPositionFailures times in a row 114// then client will be disconnected with InsufficientState reason. Polling 115// not used in channels with high frequency updates since we can check position 116// comparing client offset with offset in incoming Publication. 117const maxCheckPositionFailures uint8 = 2 118 119// Note: up to 8 possible flags here. 120const ( 121 // flagSubscribed will be set upon successful Subscription to a channel. 122 // Until that moment channel exists in client Channels map only to track 123 // duplicate subscription requests. 124 flagSubscribed uint8 = 1 << iota 125 flagPresence 126 flagJoinLeave 127 flagPosition 128 flagRecover 129 flagServerSide 130 flagClientSideRefresh 131) 132 133// channelContext contains extra context for channel connection subscribed to. 134// Note: this struct is aligned to consume less memory. 135type channelContext struct { 136 Info []byte 137 expireAt int64 138 positionCheckTime int64 139 streamPosition StreamPosition 140 positionCheckFailures uint8 141 flags uint8 142} 143 144func channelHasFlag(flags, flag uint8) bool { 145 return flags&flag != 0 146} 147 148type timerOp uint8 149 150const ( 151 timerOpStale timerOp = 1 152 timerOpPresence timerOp = 2 153 timerOpExpire timerOp = 3 154) 155 156type status uint8 157 158const ( 159 statusConnecting status = 1 160 statusConnected status = 2 161 statusClosed status = 3 162) 163 164// ConnectRequest can be used in a unidirectional connection case to 165// pass initial connection information from a client-side. 166type ConnectRequest struct { 167 // Token is an optional token from a client. 168 Token string 169 // Data is an optional custom data from a client. 170 Data []byte 171 // Name of a client. 172 Name string 173 // Version of a client. 174 Version string 175 // Subs is a map with channel subscription state (for recovery on connect). 
176 Subs map[string]SubscribeRequest 177} 178 179// SubscribeRequest contains state of subscription to a channel. 180type SubscribeRequest struct { 181 // Recover enables publication recovery for a channel. 182 Recover bool 183 // Epoch last seen by a client. 184 Epoch string 185 // Offset last seen by a client. 186 Offset uint64 187} 188 189func (r *ConnectRequest) toProto() *protocol.ConnectRequest { 190 if r == nil { 191 return nil 192 } 193 req := &protocol.ConnectRequest{ 194 Token: r.Token, 195 Data: r.Data, 196 Name: r.Name, 197 Version: r.Version, 198 } 199 if len(r.Subs) > 0 { 200 subs := make(map[string]*protocol.SubscribeRequest, len(r.Subs)) 201 for k, v := range r.Subs { 202 subs[k] = &protocol.SubscribeRequest{ 203 Recover: v.Recover, 204 Epoch: v.Epoch, 205 Offset: v.Offset, 206 } 207 } 208 req.Subs = subs 209 } 210 return req 211} 212 213// Client represents client connection to server. 214type Client struct { 215 mu sync.RWMutex 216 connectMu sync.Mutex // allows to sync connect with disconnect. 217 presenceMu sync.Mutex // allows to sync presence routine with client closing. 218 ctx context.Context 219 transport Transport 220 node *Node 221 exp int64 222 channels map[string]channelContext 223 messageWriter *writer 224 pubSubSync *recovery.PubSubSync 225 uid string 226 user string 227 info []byte 228 authenticated bool 229 clientSideRefresh bool 230 status status 231 timerOp timerOp 232 nextPresence int64 233 nextExpire int64 234 eventHub *clientEventHub 235 timer *time.Timer 236} 237 238// ClientCloseFunc must be called on Transport handler close to clean up Client. 239type ClientCloseFunc func() error 240 241// NewClient initializes new Client. 
242func NewClient(ctx context.Context, n *Node, t Transport) (*Client, ClientCloseFunc, error) { 243 uuidObject, err := uuid.NewRandom() 244 if err != nil { 245 return nil, nil, err 246 } 247 248 client := &Client{ 249 ctx: ctx, 250 uid: uuidObject.String(), 251 node: n, 252 transport: t, 253 channels: make(map[string]channelContext), 254 pubSubSync: recovery.NewPubSubSync(), 255 status: statusConnecting, 256 eventHub: &clientEventHub{}, 257 } 258 259 messageWriterConf := writerConfig{ 260 MaxQueueSize: n.config.ClientQueueMaxSize, 261 WriteFn: func(item queue.Item) error { 262 if client.node.transportWriteHandler != nil { 263 pass := client.node.transportWriteHandler(client, TransportWriteEvent(item)) 264 if !pass { 265 return nil 266 } 267 } 268 if err := t.Write(item.Data); err != nil { 269 switch v := err.(type) { 270 case *Disconnect: 271 go func() { _ = client.close(v) }() 272 default: 273 go func() { _ = client.close(DisconnectWriteError) }() 274 } 275 return err 276 } 277 incTransportMessagesSent(t.Name()) 278 return nil 279 }, 280 WriteManyFn: func(items ...queue.Item) error { 281 messages := make([][]byte, 0, len(items)) 282 for i := 0; i < len(items); i++ { 283 if client.node.transportWriteHandler != nil { 284 pass := client.node.transportWriteHandler(client, TransportWriteEvent(items[i])) 285 if !pass { 286 continue 287 } 288 } 289 messages = append(messages, items[i].Data) 290 } 291 if err := t.WriteMany(messages...); err != nil { 292 switch v := err.(type) { 293 case *Disconnect: 294 go func() { _ = client.close(v) }() 295 default: 296 go func() { _ = client.close(DisconnectWriteError) }() 297 } 298 return err 299 } 300 addTransportMessagesSent(t.Name(), float64(len(items))) 301 return nil 302 }, 303 } 304 305 client.messageWriter = newWriter(messageWriterConf) 306 go client.messageWriter.run() 307 308 staleCloseDelay := n.config.ClientStaleCloseDelay 309 if staleCloseDelay > 0 && !client.authenticated { 310 client.mu.Lock() 311 client.timerOp = 
timerOpStale 312 client.timer = time.AfterFunc(staleCloseDelay, client.onTimerOp) 313 client.mu.Unlock() 314 } 315 return client, func() error { return client.close(nil) }, nil 316} 317 318func extractUnidirectionalDisconnect(err error) *Disconnect { 319 if err == nil { 320 return nil 321 } 322 var d *Disconnect 323 switch t := err.(type) { 324 case *Disconnect: 325 d = t 326 case *Error: 327 switch t.Code { 328 case ErrorExpired.Code: 329 d = DisconnectExpired 330 case ErrorTokenExpired.Code: 331 d = DisconnectExpired 332 default: 333 d = DisconnectServerError 334 } 335 default: 336 d = DisconnectServerError 337 } 338 return d 339} 340 341// Connect supposed to be called from unidirectional transport layer to pass 342// initial information about connection and thus initiate Node.OnConnecting 343// event. Bidirectional transport initiate connecting workflow automatically 344// since client passes Connect command upon successful connection establishment 345// with a server. 346func (c *Client) Connect(req ConnectRequest) { 347 err := c.unidirectionalConnect(req.toProto()) 348 if err != nil { 349 d := extractUnidirectionalDisconnect(err) 350 go func() { _ = c.close(d) }() 351 } 352} 353 354func (c *Client) encodeDisconnect(d *Disconnect) (*prepared.Reply, error) { 355 disconnect := &protocol.Disconnect{ 356 Code: d.Code, 357 Reason: d.Reason, 358 Reconnect: d.Reconnect, 359 } 360 pushBytes, err := protocol.EncodeDisconnectPush(c.transport.Protocol().toProto(), disconnect) 361 if err != nil { 362 return nil, err 363 } 364 return prepared.NewReply(&protocol.Reply{ 365 Result: pushBytes, 366 }, c.transport.Protocol().toProto()), nil 367} 368 369func (c *Client) encodeConnectPush(res *protocol.ConnectResult) ([]byte, error) { 370 p := &protocol.Connect{ 371 Version: res.GetVersion(), 372 Client: res.GetClient(), 373 Data: res.Data, 374 Subs: res.Subs, 375 Expires: res.Expires, 376 Ttl: res.Ttl, 377 } 378 return protocol.EncodeConnectPush(c.transport.Protocol().toProto(), 
p) 379} 380 381func hasFlag(flags, flag uint64) bool { 382 return flags&flag != 0 383} 384 385func (c *Client) unidirectionalConnect(connectRequest *protocol.ConnectRequest) error { 386 write := func(rep *protocol.Reply) error { 387 if hasFlag(c.transport.DisabledPushFlags(), PushFlagConnect) { 388 return nil 389 } 390 c.trace("-->", rep.Result) 391 disconnect := c.messageWriter.enqueue(queue.Item{Data: rep.Result, IsPush: false}) 392 if disconnect != nil { 393 if c.node.logger.enabled(LogLevelDebug) { 394 c.node.logger.log(newLogEntry(LogLevelDebug, "disconnect after connect push", map[string]interface{}{"client": c.ID(), "user": c.UserID(), "reason": disconnect.Reason})) 395 } 396 go func() { _ = c.close(disconnect) }() 397 } 398 return disconnect 399 } 400 401 rw := &replyWriter{write, func() {}} 402 403 _, err := c.connectCmd(connectRequest, rw) 404 if err != nil { 405 return err 406 } 407 c.triggerConnect() 408 c.scheduleOnConnectTimers() 409 return nil 410} 411 412func (c *Client) onTimerOp() { 413 c.mu.Lock() 414 if c.status == statusClosed { 415 c.mu.Unlock() 416 return 417 } 418 timerOp := c.timerOp 419 c.mu.Unlock() 420 switch timerOp { 421 case timerOpStale: 422 c.closeUnauthenticated() 423 case timerOpPresence: 424 c.updatePresence() 425 case timerOpExpire: 426 c.expire() 427 } 428} 429 430// Lock must be held outside. 
431func (c *Client) scheduleNextTimer() { 432 if c.status == statusClosed { 433 return 434 } 435 c.stopTimer() 436 var minEventTime int64 437 var nextTimerOp timerOp 438 var needTimer bool 439 if c.nextExpire > 0 { 440 nextTimerOp = timerOpExpire 441 minEventTime = c.nextExpire 442 needTimer = true 443 } 444 if c.nextPresence > 0 && (minEventTime == 0 || c.nextPresence < minEventTime) { 445 nextTimerOp = timerOpPresence 446 minEventTime = c.nextPresence 447 needTimer = true 448 } 449 if needTimer { 450 c.timerOp = nextTimerOp 451 afterDuration := time.Duration(minEventTime-time.Now().UnixNano()) * time.Nanosecond 452 c.timer = time.AfterFunc(afterDuration, c.onTimerOp) 453 } 454} 455 456// Lock must be held outside. 457func (c *Client) stopTimer() { 458 if c.timer != nil { 459 c.timer.Stop() 460 } 461} 462 463// Lock must be held outside. 464func (c *Client) addPresenceUpdate() { 465 config := c.node.config 466 presenceInterval := config.ClientPresenceUpdateInterval 467 c.nextPresence = time.Now().Add(presenceInterval).UnixNano() 468 c.scheduleNextTimer() 469} 470 471// Lock must be held outside. 472func (c *Client) addExpireUpdate(after time.Duration) { 473 c.nextExpire = time.Now().Add(after).UnixNano() 474 c.scheduleNextTimer() 475} 476 477// closeUnauthenticated closes connection if it's not authenticated yet. 478// At moment used to close client connections which have not sent valid 479// connect command in a reasonable time interval after established connection 480// with server. 
481func (c *Client) closeUnauthenticated() { 482 c.mu.RLock() 483 authenticated := c.authenticated 484 closed := c.status == statusClosed 485 c.mu.RUnlock() 486 if !authenticated && !closed { 487 _ = c.close(DisconnectStale) 488 } 489} 490 491func (c *Client) transportEnqueue(reply *prepared.Reply) error { 492 var data []byte 493 if c.transport.Unidirectional() { 494 data = reply.Reply.Result 495 } else { 496 data = reply.Data() 497 } 498 c.trace("-->", data) 499 disconnect := c.messageWriter.enqueue(queue.Item{ 500 Data: data, 501 IsPush: reply.Reply.Id == 0, 502 }) 503 if disconnect != nil { 504 // close in goroutine to not block message broadcast. 505 go func() { _ = c.close(disconnect) }() 506 return io.EOF 507 } 508 return nil 509} 510 511// updateChannelPresence updates client presence info for channel so it 512// won't expire until client disconnect. 513func (c *Client) updateChannelPresence(ch string, chCtx channelContext) error { 514 if !channelHasFlag(chCtx.flags, flagPresence) { 515 return nil 516 } 517 return c.node.addPresence(ch, c.uid, &ClientInfo{ 518 ClientID: c.uid, 519 UserID: c.user, 520 ConnInfo: c.info, 521 ChanInfo: chCtx.Info, 522 }) 523} 524 525// Context returns client Context. This context will be canceled 526// as soon as client connection closes. 527func (c *Client) Context() context.Context { 528 return c.ctx 529} 530 531func (c *Client) checkSubscriptionExpiration(channel string, channelContext channelContext, delay time.Duration, resultCB func(bool)) { 532 now := c.node.nowTimeGetter().Unix() 533 expireAt := channelContext.expireAt 534 clientSideRefresh := channelHasFlag(channelContext.flags, flagClientSideRefresh) 535 if expireAt > 0 && now > expireAt+int64(delay.Seconds()) { 536 // Subscription expired. 
537 if clientSideRefresh || c.eventHub.subRefreshHandler == nil { 538 // The only way subscription could be refreshed in this case is via 539 // SUB_REFRESH command sent from client but looks like that command 540 // with new refreshed token have not been received in configured window. 541 resultCB(false) 542 return 543 } 544 cb := func(reply SubRefreshReply, err error) { 545 if err != nil { 546 resultCB(false) 547 return 548 } 549 if reply.Expired || (reply.ExpireAt > 0 && reply.ExpireAt < now) { 550 resultCB(false) 551 return 552 } 553 c.mu.Lock() 554 if ctx, ok := c.channels[channel]; ok { 555 if len(reply.Info) > 0 { 556 ctx.Info = reply.Info 557 } 558 ctx.expireAt = reply.ExpireAt 559 c.channels[channel] = ctx 560 } 561 c.mu.Unlock() 562 resultCB(true) 563 } 564 // Give subscription a chance to be refreshed via SubRefreshHandler. 565 event := SubRefreshEvent{Channel: channel} 566 c.eventHub.subRefreshHandler(event, cb) 567 return 568 } 569 resultCB(true) 570} 571 572// updatePresence used for various periodic actions we need to do with client connections. 573func (c *Client) updatePresence() { 574 c.presenceMu.Lock() 575 defer c.presenceMu.Unlock() 576 config := c.node.config 577 c.mu.Lock() 578 if c.status == statusClosed { 579 c.mu.Unlock() 580 return 581 } 582 channels := make(map[string]channelContext, len(c.channels)) 583 for channel, channelContext := range c.channels { 584 if !channelHasFlag(channelContext.flags, flagSubscribed) { 585 continue 586 } 587 channels[channel] = channelContext 588 } 589 c.mu.Unlock() 590 if c.eventHub.aliveHandler != nil { 591 c.eventHub.aliveHandler() 592 } 593 for channel, channelContext := range channels { 594 c.checkSubscriptionExpiration(channel, channelContext, config.ClientExpiredSubCloseDelay, func(result bool) { 595 // Ideally we should deal with single expired subscription in this 596 // case - i.e. unsubscribe client from channel and give an advice 597 // to resubscribe. 
But there is scenario when browser goes online 598 // after computer was in sleeping mode which I have not managed to 599 // handle reliably on client side when unsubscribe with resubscribe 600 // flag was used. So I decided to stick with disconnect for now - 601 // it seems to work fine and drastically simplifies client code. 602 if !result { 603 go func() { _ = c.close(DisconnectSubExpired) }() 604 } 605 }) 606 607 checkDelay := config.ClientChannelPositionCheckDelay 608 if checkDelay > 0 && !c.checkPosition(checkDelay, channel, channelContext) { 609 go func() { _ = c.close(DisconnectInsufficientState) }() 610 // No need to proceed after close. 611 return 612 } 613 614 err := c.updateChannelPresence(channel, channelContext) 615 if err != nil { 616 c.node.logger.log(newLogEntry(LogLevelError, "error updating presence for channel", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid, "error": err.Error()})) 617 } 618 } 619 c.mu.Lock() 620 c.addPresenceUpdate() 621 c.mu.Unlock() 622} 623 624func (c *Client) checkPosition(checkDelay time.Duration, ch string, chCtx channelContext) bool { 625 if !channelHasFlag(chCtx.flags, flagRecover|flagPosition) { 626 return true 627 } 628 nowUnix := c.node.nowTimeGetter().Unix() 629 630 isInitialCheck := chCtx.positionCheckTime == 0 631 isTimeToCheck := nowUnix-chCtx.positionCheckTime > int64(checkDelay.Seconds()) 632 needCheckPosition := isInitialCheck || isTimeToCheck 633 634 if !needCheckPosition { 635 return true 636 } 637 position := chCtx.streamPosition 638 streamTop, err := c.node.streamTop(ch) 639 if err != nil { 640 return true 641 } 642 643 isValidPosition := streamTop.Offset == position.Offset && streamTop.Epoch == position.Epoch 644 keepConnection := true 645 c.mu.Lock() 646 if chContext, ok := c.channels[ch]; ok { 647 chContext.positionCheckTime = nowUnix 648 if !isValidPosition { 649 chContext.positionCheckFailures++ 650 keepConnection = chContext.positionCheckFailures < 
maxCheckPositionFailures 651 } else { 652 chContext.positionCheckFailures = 0 653 } 654 c.channels[ch] = chContext 655 } 656 c.mu.Unlock() 657 return keepConnection 658} 659 660// ID returns unique client connection id. 661func (c *Client) ID() string { 662 return c.uid 663} 664 665// UserID returns user id associated with client connection. 666func (c *Client) UserID() string { 667 return c.user 668} 669 670// Info returns connection info. 671func (c *Client) Info() []byte { 672 c.mu.Lock() 673 info := make([]byte, len(c.info)) 674 copy(info, c.info) 675 c.mu.Unlock() 676 return info 677} 678 679// Transport returns client connection transport information. 680func (c *Client) Transport() TransportInfo { 681 return c.transport 682} 683 684// Channels returns a slice of channels client connection currently subscribed to. 685func (c *Client) Channels() []string { 686 c.mu.RLock() 687 defer c.mu.RUnlock() 688 channels := make([]string, 0, len(c.channels)) 689 for ch, ctx := range c.channels { 690 if !channelHasFlag(ctx.flags, flagSubscribed) { 691 continue 692 } 693 channels = append(channels, ch) 694 } 695 return channels 696} 697 698// IsSubscribed returns true if client subscribed to a channel. 699func (c *Client) IsSubscribed(ch string) bool { 700 c.mu.RLock() 701 defer c.mu.RUnlock() 702 ctx, ok := c.channels[ch] 703 return ok && channelHasFlag(ctx.flags, flagSubscribed) 704} 705 706// Send data to client. This sends an asynchronous message – data will be 707// just written to connection. on client side this message can be handled 708// with Message handler. 
709func (c *Client) Send(data []byte) error { 710 if hasFlag(c.transport.DisabledPushFlags(), PushFlagMessage) { 711 return nil 712 } 713 p := &protocol.Message{ 714 Data: data, 715 } 716 pushBytes, err := protocol.EncodeMessagePush(c.transport.Protocol().toProto(), p) 717 if err != nil { 718 return err 719 } 720 reply := prepared.NewReply(&protocol.Reply{ 721 Result: pushBytes, 722 }, c.transport.Protocol().toProto()) 723 return c.transportEnqueue(reply) 724} 725 726// Unsubscribe allows to unsubscribe client from channel. 727func (c *Client) Unsubscribe(ch string) error { 728 c.mu.RLock() 729 if c.status == statusClosed { 730 c.mu.RUnlock() 731 return nil 732 } 733 c.mu.RUnlock() 734 735 err := c.unsubscribe(ch) 736 if err != nil { 737 return err 738 } 739 return c.sendUnsubscribe(ch) 740} 741 742func (c *Client) sendUnsubscribe(ch string) error { 743 if hasFlag(c.transport.DisabledPushFlags(), PushFlagUnsubscribe) { 744 return nil 745 } 746 pushBytes, err := protocol.EncodeUnsubscribePush(c.transport.Protocol().toProto(), ch, &protocol.Unsubscribe{}) 747 if err != nil { 748 return err 749 } 750 reply := prepared.NewReply(&protocol.Reply{ 751 Result: pushBytes, 752 }, c.transport.Protocol().toProto()) 753 754 _ = c.transportEnqueue(reply) 755 return nil 756} 757 758// Disconnect client connection with specific disconnect code and reason. 759// This method internally creates a new goroutine at moment to do 760// closing stuff. An extra goroutine is required to solve disconnect 761// and alive callback ordering/sync problems. Will be a noop if client 762// already closed. As this method runs a separate goroutine client 763// connection will be closed eventually (i.e. not immediately). 
func (c *Client) Disconnect(disconnect *Disconnect) {
	// The actual work happens in close; it is started in a separate
	// goroutine so this call returns immediately.
	go func() {
		_ = c.close(disconnect)
	}()
}

// close shuts the client connection down. It is idempotent: a call for an
// already closed client returns nil without doing anything.
//
// The steps, in order: mark the client closed and stop its timer,
// unsubscribe from every tracked channel, remove an authenticated client
// from the node, enqueue a disconnect push (when disconnect is non-nil and
// such pushes are not disabled for the transport), close the message writer,
// close the transport, and finally call the disconnect handler if the client
// had reached the connected status. It always returns nil; errors of the
// individual steps are logged or ignored.
//
// presenceMu and connectMu are held for the whole call, which serializes
// close with the presence routine and with connect. The client lock mu is
// only held while the state is switched and the channels are snapshotted.
func (c *Client) close(disconnect *Disconnect) error {
	c.presenceMu.Lock()
	defer c.presenceMu.Unlock()
	c.connectMu.Lock()
	defer c.connectMu.Unlock()
	c.mu.Lock()
	if c.status == statusClosed {
		c.mu.Unlock()
		return nil
	}
	// Remember the status we are leaving: the disconnect handler below is
	// only called for clients that were connected.
	prevStatus := c.status
	c.status = statusClosed

	c.stopTimer()

	// Snapshot the channels so unsubscribe can be called without holding mu.
	channels := make(map[string]channelContext, len(c.channels))
	for channel, channelContext := range c.channels {
		channels[channel] = channelContext
	}
	c.mu.Unlock()

	if len(channels) > 0 {
		// Unsubscribe from all channels.
		for channel := range channels {
			err := c.unsubscribe(channel)
			if err != nil {
				c.node.logger.log(newLogEntry(LogLevelError, "error unsubscribing client from channel", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid, "error": err.Error()}))
			}
		}
	}

	c.mu.RLock()
	authenticated := c.authenticated
	c.mu.RUnlock()

	if authenticated {
		err := c.node.removeClient(c)
		if err != nil {
			c.node.logger.log(newLogEntry(LogLevelError, "error removing client", map[string]interface{}{"user": c.user, "client": c.uid, "error": err.Error()}))
		}
	}

	// Best effort: an encoding or enqueue failure must not prevent closing.
	if disconnect != nil && !hasFlag(c.transport.DisabledPushFlags(), PushFlagDisconnect) {
		if reply, err := c.encodeDisconnect(disconnect); err == nil {
			_ = c.transportEnqueue(reply)
		}
	}

	// close writer and send messages remaining in writer queue if any.
	_ = c.messageWriter.close()

	_ = c.transport.Close(disconnect)

	if disconnect != nil && disconnect.Reason != "" {
		c.node.logger.log(newLogEntry(LogLevelDebug, "closing client connection", map[string]interface{}{"client": c.uid, "user": c.user, "reason": disconnect.Reason, "reconnect": disconnect.Reconnect}))
	}
	if disconnect != nil {
		incServerDisconnect(disconnect.Code)
	}
	if c.eventHub.disconnectHandler != nil && prevStatus == statusConnected {
		c.eventHub.disconnectHandler(DisconnectEvent{
			Disconnect: disconnect,
		})
	}
	return nil
}

// trace logs data flowing through the connection at trace level. msg tells
// the direction ("-->" or "<--" at the visible call sites). It returns early,
// before taking any lock or formatting anything, when trace logging is
// disabled.
func (c *Client) trace(msg string, data []byte) {
	if !c.node.LogEnabled(LogLevelTrace) {
		return
	}
	c.mu.RLock()
	user := c.user
	c.mu.RUnlock()
	c.node.logger.log(newLogEntry(LogLevelTrace, msg, map[string]interface{}{"client": c.ID(), "user": user, "data": fmt.Sprintf("%#v", string(data))}))
}

// clientInfo builds ClientInfo describing this client in channel ch. Channel
// info is only included when the client is actually subscribed to ch;
// otherwise ChanInfo is left nil.
// Lock must be held outside.
func (c *Client) clientInfo(ch string) *ClientInfo {
	var channelInfo protocol.Raw
	channelContext, ok := c.channels[ch]
	if ok && channelHasFlag(channelContext.flags, flagSubscribed) {
		channelInfo = channelContext.Info
	}
	return &ClientInfo{
		ClientID: c.uid,
		UserID:   c.user,
		ConnInfo: c.info,
		ChanInfo: channelInfo,
	}
}

// Handle raw data encoded with Centrifuge protocol.
// Not goroutine-safe. Supposed to be called only from a transport connection reader.
864func (c *Client) Handle(data []byte) bool { 865 c.mu.Lock() 866 if c.status == statusClosed { 867 c.mu.Unlock() 868 return false 869 } 870 c.mu.Unlock() 871 872 if c.transport.Unidirectional() { 873 c.node.logger.log(newLogEntry(LogLevelInfo, "can't handle data for unidirectional client", map[string]interface{}{"client": c.ID(), "user": c.UserID()})) 874 go func() { _ = c.close(DisconnectBadRequest) }() 875 return false 876 } 877 878 if len(data) == 0 { 879 c.node.logger.log(newLogEntry(LogLevelInfo, "empty client request received", map[string]interface{}{"client": c.ID(), "user": c.UserID()})) 880 go func() { _ = c.close(DisconnectBadRequest) }() 881 return false 882 } 883 884 c.trace("<--", data) 885 886 protoType := c.transport.Protocol().toProto() 887 decoder := protocol.GetCommandDecoder(protoType, data) 888 defer protocol.PutCommandDecoder(protoType, decoder) 889 890 for { 891 cmd, err := decoder.Decode() 892 if err != nil && err != io.EOF { 893 c.node.logger.log(newLogEntry(LogLevelInfo, "error decoding command", map[string]interface{}{"data": string(data), "client": c.ID(), "user": c.UserID(), "error": err.Error()})) 894 go func() { _ = c.close(DisconnectBadRequest) }() 895 return false 896 } 897 if cmd != nil { 898 ok := c.handleCommand(cmd) 899 if !ok { 900 return false 901 } 902 } 903 if err == io.EOF { 904 break 905 } 906 } 907 return true 908} 909 910// handleCommand processes a single protocol.Command. 911func (c *Client) handleCommand(cmd *protocol.Command) bool { 912 if cmd.Method != protocol.Command_CONNECT && !c.authenticated { 913 // Client must send connect command to authenticate itself first. 
914 c.node.logger.log(newLogEntry(LogLevelInfo, "client not authenticated to handle command", map[string]interface{}{"client": c.ID(), "user": c.UserID(), "command": fmt.Sprintf("%v", cmd)})) 915 go func() { _ = c.close(DisconnectBadRequest) }() 916 return false 917 } 918 919 if cmd.Id == 0 && cmd.Method != protocol.Command_SEND { 920 // Only send command from client can be sent without incremental ID. 921 c.node.logger.log(newLogEntry(LogLevelInfo, "command ID required for commands with reply expected", map[string]interface{}{"client": c.ID(), "user": c.UserID()})) 922 go func() { _ = c.close(DisconnectBadRequest) }() 923 return false 924 } 925 926 select { 927 case <-c.ctx.Done(): 928 return false 929 default: 930 } 931 932 disconnect := c.dispatchCommand(cmd) 933 934 select { 935 case <-c.ctx.Done(): 936 return false 937 default: 938 } 939 if disconnect != nil { 940 if disconnect != DisconnectNormal { 941 c.node.logger.log(newLogEntry(LogLevelInfo, "disconnect after handling command", map[string]interface{}{"command": fmt.Sprintf("%v", cmd), "client": c.ID(), "user": c.UserID(), "reason": disconnect.Reason})) 942 } 943 go func() { _ = c.close(disconnect) }() 944 return false 945 } 946 return true 947} 948 949type replyWriter struct { 950 write func(*protocol.Reply) error 951 done func() 952} 953 954// dispatchCommand dispatches Command into correct command handler. 
955func (c *Client) dispatchCommand(cmd *protocol.Command) *Disconnect { 956 c.mu.Lock() 957 if c.status == statusClosed { 958 c.mu.Unlock() 959 return nil 960 } 961 c.mu.Unlock() 962 963 method := cmd.Method 964 params := cmd.Params 965 966 protoType := c.transport.Protocol().toProto() 967 replyEncoder := protocol.GetReplyEncoder(protoType) 968 969 var encodeErr error 970 971 started := time.Now() 972 973 write := func(rep *protocol.Reply) error { 974 rep.Id = cmd.Id 975 if rep.Error != nil { 976 if c.node.LogEnabled(LogLevelInfo) { 977 c.node.logger.log(newLogEntry(LogLevelInfo, "client command error", map[string]interface{}{"reply": fmt.Sprintf("%v", rep), "command": fmt.Sprintf("%v", cmd), "client": c.ID(), "user": c.UserID(), "error": rep.Error.Message, "code": rep.Error.Code})) 978 } 979 incReplyError(cmd.Method, rep.Error.Code) 980 } 981 982 var replyData []byte 983 replyData, encodeErr = replyEncoder.Encode(rep) 984 if encodeErr != nil { 985 c.node.logger.log(newLogEntry(LogLevelError, "error encoding reply", map[string]interface{}{"reply": fmt.Sprintf("%v", rep), "client": c.ID(), "user": c.UserID(), "error": encodeErr.Error()})) 986 return encodeErr 987 } 988 c.trace("-->", replyData) 989 disconnect := c.messageWriter.enqueue(queue.Item{Data: replyData, IsPush: false}) 990 if disconnect != nil { 991 if c.node.logger.enabled(LogLevelDebug) { 992 c.node.logger.log(newLogEntry(LogLevelDebug, "disconnect after sending reply", map[string]interface{}{"client": c.ID(), "user": c.UserID(), "reason": disconnect.Reason})) 993 } 994 go func() { _ = c.close(disconnect) }() 995 } 996 return disconnect 997 } 998 999 // done should be called after command fully processed. 
1000 done := func() { 1001 observeCommandDuration(method, time.Since(started)) 1002 } 1003 1004 // The rule is as follows: if command handler returns an 1005 // error then we handle it here: write error into connection 1006 // or return disconnect further to caller and call rw.done() 1007 // in the end. 1008 // If handler returned nil error then we assume that all 1009 // rw operations will be executed inside handler itself. 1010 rw := &replyWriter{write, done} 1011 1012 var handleErr error 1013 1014 switch method { 1015 case protocol.Command_CONNECT: 1016 handleErr = c.handleConnect(params, rw) 1017 case protocol.Command_PING: 1018 handleErr = c.handlePing(params, rw) 1019 case protocol.Command_SUBSCRIBE: 1020 handleErr = c.handleSubscribe(params, rw) 1021 case protocol.Command_UNSUBSCRIBE: 1022 handleErr = c.handleUnsubscribe(params, rw) 1023 case protocol.Command_PUBLISH: 1024 handleErr = c.handlePublish(params, rw) 1025 case protocol.Command_PRESENCE: 1026 handleErr = c.handlePresence(params, rw) 1027 case protocol.Command_PRESENCE_STATS: 1028 handleErr = c.handlePresenceStats(params, rw) 1029 case protocol.Command_HISTORY: 1030 handleErr = c.handleHistory(params, rw) 1031 case protocol.Command_RPC: 1032 handleErr = c.handleRPC(params, rw) 1033 case protocol.Command_SEND: 1034 handleErr = c.handleSend(params, rw) 1035 case protocol.Command_REFRESH: 1036 handleErr = c.handleRefresh(params, rw) 1037 case protocol.Command_SUB_REFRESH: 1038 handleErr = c.handleSubRefresh(params, rw) 1039 default: 1040 handleErr = ErrorMethodNotFound 1041 } 1042 if encodeErr != nil { 1043 return DisconnectServerError 1044 } 1045 if handleErr != nil { 1046 defer rw.done() 1047 switch t := handleErr.(type) { 1048 case *Disconnect: 1049 return t 1050 default: 1051 c.writeError(rw, toClientErr(handleErr)) 1052 } 1053 } 1054 return nil 1055} 1056 1057func (c *Client) checkExpired() { 1058 c.mu.RLock() 1059 closed := c.status == statusClosed 1060 clientSideRefresh := c.clientSideRefresh 
1061 exp := c.exp 1062 c.mu.RUnlock() 1063 if closed || exp == 0 { 1064 return 1065 } 1066 now := time.Now().Unix() 1067 ttl := exp - now 1068 1069 if !clientSideRefresh && c.eventHub.refreshHandler != nil { 1070 if ttl > 0 { 1071 c.mu.Lock() 1072 if c.status != statusClosed { 1073 c.addExpireUpdate(time.Duration(ttl) * time.Second) 1074 } 1075 c.mu.Unlock() 1076 } 1077 } 1078 1079 if ttl > 0 { 1080 // Connection was successfully refreshed. 1081 return 1082 } 1083 1084 _ = c.close(DisconnectExpired) 1085} 1086 1087func (c *Client) expire() { 1088 c.mu.RLock() 1089 closed := c.status == statusClosed 1090 clientSideRefresh := c.clientSideRefresh 1091 exp := c.exp 1092 c.mu.RUnlock() 1093 if closed || exp == 0 { 1094 return 1095 } 1096 if !clientSideRefresh && c.eventHub.refreshHandler != nil { 1097 cb := func(reply RefreshReply, err error) { 1098 if err != nil { 1099 switch t := err.(type) { 1100 case *Disconnect: 1101 _ = c.close(t) 1102 return 1103 default: 1104 _ = c.close(DisconnectServerError) 1105 return 1106 } 1107 } 1108 if reply.Expired { 1109 _ = c.close(DisconnectExpired) 1110 return 1111 } 1112 if reply.ExpireAt > 0 { 1113 c.mu.Lock() 1114 c.exp = reply.ExpireAt 1115 if reply.Info != nil { 1116 c.info = reply.Info 1117 } 1118 c.mu.Unlock() 1119 } 1120 c.checkExpired() 1121 } 1122 c.eventHub.refreshHandler(RefreshEvent{}, cb) 1123 } else { 1124 c.checkExpired() 1125 } 1126} 1127 1128func (c *Client) handleConnect(params protocol.Raw, rw *replyWriter) error { 1129 cmd, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodeConnect(params) 1130 if err != nil { 1131 return c.logDisconnectBadRequestWithError(err, "error decoding connect") 1132 } 1133 _, disconnect := c.connectCmd(cmd, rw) 1134 if disconnect != nil { 1135 return disconnect 1136 } 1137 c.triggerConnect() 1138 c.scheduleOnConnectTimers() 1139 return nil 1140} 1141 1142func (c *Client) triggerConnect() { 1143 c.connectMu.Lock() 1144 defer c.connectMu.Unlock() 1145 if c.status != 
statusConnecting { 1146 return 1147 } 1148 if c.node.clientEvents.connectHandler == nil { 1149 c.status = statusConnected 1150 return 1151 } 1152 c.node.clientEvents.connectHandler(c) 1153 c.status = statusConnected 1154} 1155 1156func (c *Client) scheduleOnConnectTimers() { 1157 // Make presence and refresh handlers always run after client connect event. 1158 c.mu.Lock() 1159 c.addPresenceUpdate() 1160 if c.exp > 0 { 1161 expireAfter := time.Duration(c.exp-time.Now().Unix()) * time.Second 1162 if c.clientSideRefresh { 1163 conf := c.node.config 1164 expireAfter += conf.ClientExpiredCloseDelay 1165 } 1166 c.addExpireUpdate(expireAfter) 1167 } 1168 c.mu.Unlock() 1169} 1170 1171func (c *Client) Refresh(opts ...RefreshOption) error { 1172 refreshOptions := &RefreshOptions{} 1173 for _, opt := range opts { 1174 opt(refreshOptions) 1175 } 1176 if refreshOptions.Expired { 1177 go func() { _ = c.close(DisconnectExpired) }() 1178 return nil 1179 } 1180 1181 expireAt := refreshOptions.ExpireAt 1182 info := refreshOptions.Info 1183 1184 res := &protocol.Refresh{ 1185 Expires: expireAt > 0, 1186 } 1187 1188 ttl := expireAt - time.Now().Unix() 1189 1190 if ttl > 0 { 1191 res.Ttl = uint32(ttl) 1192 } 1193 1194 if expireAt > 0 { 1195 // connection check enabled 1196 if ttl > 0 { 1197 // connection refreshed, update client timestamp and set new expiration timeout 1198 c.mu.Lock() 1199 c.exp = expireAt 1200 if len(info) > 0 { 1201 c.info = info 1202 } 1203 duration := time.Duration(ttl)*time.Second + c.node.config.ClientExpiredCloseDelay 1204 c.addExpireUpdate(duration) 1205 c.mu.Unlock() 1206 } else { 1207 go func() { _ = c.close(DisconnectExpired) }() 1208 return nil 1209 } 1210 } else { 1211 c.mu.Lock() 1212 c.exp = 0 1213 c.mu.Unlock() 1214 } 1215 1216 pushBytes, err := protocol.EncodeRefreshPush(c.transport.Protocol().toProto(), res) 1217 if err != nil { 1218 return err 1219 } 1220 reply := prepared.NewReply(&protocol.Reply{ 1221 Result: pushBytes, 1222 }, 
c.transport.Protocol().toProto()) 1223 1224 return c.transportEnqueue(reply) 1225} 1226 1227func (c *Client) handleRefresh(params protocol.Raw, rw *replyWriter) error { 1228 if c.eventHub.refreshHandler == nil { 1229 return ErrorNotAvailable 1230 } 1231 1232 cmd, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodeRefresh(params) 1233 if err != nil { 1234 return c.logDisconnectBadRequestWithError(err, "error decoding refresh") 1235 } 1236 1237 if cmd.Token == "" { 1238 return c.logDisconnectBadRequest("client token required to refresh") 1239 } 1240 1241 c.mu.RLock() 1242 clientSideRefresh := c.clientSideRefresh 1243 c.mu.RUnlock() 1244 1245 if !clientSideRefresh { 1246 // Client not supposed to send refresh command in case of server-side refresh mechanism. 1247 return c.logDisconnectBadRequest("server-side refresh expected") 1248 } 1249 1250 event := RefreshEvent{ 1251 ClientSideRefresh: true, 1252 Token: cmd.Token, 1253 } 1254 1255 cb := func(reply RefreshReply, err error) { 1256 defer rw.done() 1257 1258 if err != nil { 1259 c.writeDisconnectOrErrorFlush(rw, err) 1260 return 1261 } 1262 1263 if reply.Expired { 1264 c.Disconnect(DisconnectExpired) 1265 return 1266 } 1267 1268 expireAt := reply.ExpireAt 1269 info := reply.Info 1270 1271 res := &protocol.RefreshResult{ 1272 Version: c.node.config.Version, 1273 Expires: expireAt > 0, 1274 Client: c.uid, 1275 } 1276 1277 ttl := expireAt - time.Now().Unix() 1278 1279 if ttl > 0 { 1280 res.Ttl = uint32(ttl) 1281 } 1282 1283 if expireAt > 0 { 1284 // connection check enabled 1285 if ttl > 0 { 1286 // connection refreshed, update client timestamp and set new expiration timeout 1287 c.mu.Lock() 1288 c.exp = expireAt 1289 if len(info) > 0 { 1290 c.info = info 1291 } 1292 duration := time.Duration(ttl)*time.Second + c.node.config.ClientExpiredCloseDelay 1293 c.addExpireUpdate(duration) 1294 c.mu.Unlock() 1295 } else { 1296 c.writeError(rw, ErrorExpired) 1297 return 1298 } 1299 } 1300 1301 replyRes, err 
:= protocol.GetResultEncoder(c.transport.Protocol().toProto()).EncodeRefreshResult(res) 1302 if err != nil { 1303 c.logWriteInternalErrorFlush(rw, err, "error encoding refresh") 1304 return 1305 } 1306 1307 _ = writeReply(rw, &protocol.Reply{Result: replyRes}) 1308 } 1309 1310 c.eventHub.refreshHandler(event, cb) 1311 return nil 1312} 1313 1314// onSubscribeError cleans up a channel from client channels if an error during subscribe happened. 1315// Channel kept in a map during subscribe request to check for duplicate subscription attempts. 1316func (c *Client) onSubscribeError(channel string) { 1317 c.mu.Lock() 1318 _, ok := c.channels[channel] 1319 delete(c.channels, channel) 1320 c.mu.Unlock() 1321 if ok { 1322 _ = c.node.removeSubscription(channel, c) 1323 } 1324} 1325 1326func (c *Client) handleSubscribe(params protocol.Raw, rw *replyWriter) error { 1327 if c.eventHub.subscribeHandler == nil { 1328 return ErrorNotAvailable 1329 } 1330 1331 cmd, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodeSubscribe(params) 1332 if err != nil { 1333 return c.logDisconnectBadRequestWithError(err, "error decoding subscribe") 1334 } 1335 1336 replyError, disconnect := c.validateSubscribeRequest(cmd) 1337 if disconnect != nil || replyError != nil { 1338 if disconnect != nil { 1339 return disconnect 1340 } 1341 return replyError 1342 } 1343 1344 event := SubscribeEvent{ 1345 Channel: cmd.Channel, 1346 Token: cmd.Token, 1347 } 1348 1349 cb := func(reply SubscribeReply, err error) { 1350 defer rw.done() 1351 1352 if err != nil { 1353 c.onSubscribeError(cmd.Channel) 1354 c.writeDisconnectOrErrorFlush(rw, err) 1355 return 1356 } 1357 1358 ctx := c.subscribeCmd(cmd, reply, rw, false) 1359 1360 if ctx.disconnect != nil { 1361 c.onSubscribeError(cmd.Channel) 1362 c.Disconnect(ctx.disconnect) 1363 return 1364 } 1365 if ctx.err != nil { 1366 c.onSubscribeError(cmd.Channel) 1367 c.writeDisconnectOrErrorFlush(rw, ctx.err) 1368 return 1369 } 1370 1371 if 
channelHasFlag(ctx.channelContext.flags, flagJoinLeave) && ctx.clientInfo != nil { 1372 go func() { _ = c.node.publishJoin(cmd.Channel, ctx.clientInfo) }() 1373 } 1374 } 1375 c.eventHub.subscribeHandler(event, cb) 1376 return nil 1377} 1378 1379func (c *Client) getSubscribedChannelContext(channel string) (channelContext, bool) { 1380 c.mu.RLock() 1381 ctx, okChannel := c.channels[channel] 1382 c.mu.RUnlock() 1383 if !okChannel || !channelHasFlag(ctx.flags, flagSubscribed) { 1384 return channelContext{}, false 1385 } 1386 return ctx, true 1387} 1388 1389func (c *Client) handleSubRefresh(params protocol.Raw, rw *replyWriter) error { 1390 if c.eventHub.subRefreshHandler == nil { 1391 return ErrorNotAvailable 1392 } 1393 1394 cmd, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodeSubRefresh(params) 1395 if err != nil { 1396 return c.logDisconnectBadRequestWithError(err, "error decoding sub refresh") 1397 } 1398 1399 channel := cmd.Channel 1400 if channel == "" { 1401 return c.logDisconnectBadRequest("channel required for sub refresh") 1402 } 1403 1404 ctx, okChannel := c.getSubscribedChannelContext(channel) 1405 if !okChannel { 1406 // Must be subscribed to refresh subscription. 1407 return ErrorPermissionDenied 1408 } 1409 1410 clientSideRefresh := channelHasFlag(ctx.flags, flagClientSideRefresh) 1411 if !clientSideRefresh { 1412 // Client not supposed to send sub refresh command in case of server-side 1413 // subscription refresh mechanism. 
1414 return c.logDisconnectBadRequest("server-side sub refresh expected") 1415 } 1416 1417 if cmd.Token == "" { 1418 c.node.logger.log(newLogEntry(LogLevelInfo, "subscription refresh token required", map[string]interface{}{"client": c.uid, "user": c.UserID()})) 1419 return ErrorBadRequest 1420 } 1421 1422 event := SubRefreshEvent{ 1423 ClientSideRefresh: true, 1424 Channel: cmd.Channel, 1425 Token: cmd.Token, 1426 } 1427 1428 cb := func(reply SubRefreshReply, err error) { 1429 defer rw.done() 1430 1431 if err != nil { 1432 c.writeDisconnectOrErrorFlush(rw, err) 1433 return 1434 } 1435 1436 res := &protocol.SubRefreshResult{} 1437 1438 if reply.ExpireAt > 0 { 1439 res.Expires = true 1440 now := time.Now().Unix() 1441 if reply.ExpireAt < now { 1442 c.writeError(rw, ErrorExpired) 1443 return 1444 } 1445 res.Ttl = uint32(reply.ExpireAt - now) 1446 } 1447 1448 c.mu.Lock() 1449 channelContext, okChan := c.channels[channel] 1450 if okChan && channelHasFlag(channelContext.flags, flagSubscribed) { 1451 channelContext.Info = reply.Info 1452 channelContext.expireAt = reply.ExpireAt 1453 c.channels[channel] = channelContext 1454 } 1455 c.mu.Unlock() 1456 1457 replyRes, err := protocol.GetResultEncoder(c.transport.Protocol().toProto()).EncodeSubRefreshResult(res) 1458 if err != nil { 1459 c.logWriteInternalErrorFlush(rw, err, "error encoding sub refresh") 1460 return 1461 } 1462 _ = writeReply(rw, &protocol.Reply{Result: replyRes}) 1463 } 1464 1465 c.eventHub.subRefreshHandler(event, cb) 1466 return nil 1467} 1468 1469func (c *Client) handleUnsubscribe(params protocol.Raw, rw *replyWriter) error { 1470 cmd, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodeUnsubscribe(params) 1471 if err != nil { 1472 return c.logDisconnectBadRequestWithError(err, "error decoding unsubscribe") 1473 } 1474 1475 channel := cmd.Channel 1476 if channel == "" { 1477 return c.logDisconnectBadRequest("channel required for unsubscribe") 1478 } 1479 1480 if err := 
c.unsubscribe(channel); err != nil { 1481 return err 1482 } 1483 1484 replyRes, err := protocol.GetResultEncoder(c.transport.Protocol().toProto()).EncodeUnsubscribeResult(&protocol.UnsubscribeResult{}) 1485 if err != nil { 1486 c.node.logger.log(newLogEntry(LogLevelError, "error encoding unsubscribe", map[string]interface{}{"error": err.Error()})) 1487 return DisconnectServerError 1488 } 1489 1490 _ = writeReply(rw, &protocol.Reply{Result: replyRes}) 1491 rw.done() 1492 return nil 1493} 1494 1495func (c *Client) handlePublish(params protocol.Raw, rw *replyWriter) error { 1496 if c.eventHub.publishHandler == nil { 1497 return ErrorNotAvailable 1498 } 1499 1500 cmd, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodePublish(params) 1501 if err != nil { 1502 return c.logDisconnectBadRequestWithError(err, "error decoding publish") 1503 } 1504 1505 channel := cmd.Channel 1506 data := cmd.Data 1507 1508 if channel == "" || len(data) == 0 { 1509 return c.logDisconnectBadRequest("channel and data required for publish") 1510 } 1511 1512 c.mu.RLock() 1513 info := c.clientInfo(channel) 1514 c.mu.RUnlock() 1515 1516 event := PublishEvent{ 1517 Channel: channel, 1518 Data: data, 1519 ClientInfo: info, 1520 } 1521 1522 cb := func(reply PublishReply, err error) { 1523 defer rw.done() 1524 1525 if err != nil { 1526 c.writeDisconnectOrErrorFlush(rw, err) 1527 return 1528 } 1529 1530 if reply.Result == nil { 1531 _, err := c.node.Publish( 1532 event.Channel, event.Data, 1533 WithHistory(reply.Options.HistorySize, reply.Options.HistoryTTL), 1534 WithClientInfo(reply.Options.ClientInfo), 1535 ) 1536 if err != nil { 1537 c.logWriteInternalErrorFlush(rw, err, "error publish") 1538 return 1539 } 1540 } 1541 1542 replyRes, err := protocol.GetResultEncoder(c.transport.Protocol().toProto()).EncodePublishResult(&protocol.PublishResult{}) 1543 if err != nil { 1544 c.logWriteInternalErrorFlush(rw, err, "error encoding publish") 1545 return 1546 } 1547 _ = writeReply(rw, 
&protocol.Reply{Result: replyRes}) 1548 } 1549 1550 c.eventHub.publishHandler(event, cb) 1551 return nil 1552} 1553 1554func (c *Client) handlePresence(params protocol.Raw, rw *replyWriter) error { 1555 if c.eventHub.presenceHandler == nil { 1556 return ErrorNotAvailable 1557 } 1558 1559 cmd, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodePresence(params) 1560 if err != nil { 1561 return c.logDisconnectBadRequestWithError(err, "error decoding presence") 1562 } 1563 1564 channel := cmd.Channel 1565 if channel == "" { 1566 return c.logDisconnectBadRequest("channel required for presence") 1567 } 1568 1569 event := PresenceEvent{ 1570 Channel: channel, 1571 } 1572 1573 cb := func(reply PresenceReply, err error) { 1574 defer rw.done() 1575 if err != nil { 1576 c.writeDisconnectOrErrorFlush(rw, err) 1577 return 1578 } 1579 1580 var presence map[string]*ClientInfo 1581 if reply.Result == nil { 1582 result, err := c.node.Presence(event.Channel) 1583 if err != nil { 1584 c.logWriteInternalErrorFlush(rw, err, "error getting presence") 1585 return 1586 } 1587 presence = result.Presence 1588 } else { 1589 presence = reply.Result.Presence 1590 } 1591 1592 protoPresence := make(map[string]*protocol.ClientInfo, len(presence)) 1593 for k, v := range presence { 1594 protoPresence[k] = infoToProto(v) 1595 } 1596 1597 replyRes, err := protocol.GetResultEncoder(c.transport.Protocol().toProto()).EncodePresenceResult(&protocol.PresenceResult{ 1598 Presence: protoPresence, 1599 }) 1600 if err != nil { 1601 c.logWriteInternalErrorFlush(rw, err, "error encoding presence") 1602 return 1603 } 1604 _ = writeReply(rw, &protocol.Reply{Result: replyRes}) 1605 } 1606 1607 c.eventHub.presenceHandler(event, cb) 1608 return nil 1609} 1610 1611func (c *Client) handlePresenceStats(params protocol.Raw, rw *replyWriter) error { 1612 if c.eventHub.presenceStatsHandler == nil { 1613 return ErrorNotAvailable 1614 } 1615 1616 cmd, err := 
protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodePresenceStats(params) 1617 if err != nil { 1618 return c.logDisconnectBadRequestWithError(err, "error decoding presence stats") 1619 } 1620 1621 channel := cmd.Channel 1622 if channel == "" { 1623 return c.logDisconnectBadRequest("channel required for presence stats") 1624 } 1625 1626 event := PresenceStatsEvent{ 1627 Channel: channel, 1628 } 1629 1630 cb := func(reply PresenceStatsReply, err error) { 1631 defer rw.done() 1632 if err != nil { 1633 c.writeDisconnectOrErrorFlush(rw, err) 1634 return 1635 } 1636 1637 var presenceStats PresenceStats 1638 if reply.Result == nil { 1639 result, err := c.node.PresenceStats(event.Channel) 1640 if err != nil { 1641 c.logWriteInternalErrorFlush(rw, err, "error getting presence stats") 1642 return 1643 } 1644 presenceStats = result.PresenceStats 1645 } else { 1646 presenceStats = reply.Result.PresenceStats 1647 } 1648 1649 replyRes, err := protocol.GetResultEncoder(c.transport.Protocol().toProto()).EncodePresenceStatsResult(&protocol.PresenceStatsResult{ 1650 NumClients: uint32(presenceStats.NumClients), 1651 NumUsers: uint32(presenceStats.NumUsers), 1652 }) 1653 if err != nil { 1654 c.logWriteInternalErrorFlush(rw, err, "error encoding presence stats") 1655 return 1656 } 1657 _ = writeReply(rw, &protocol.Reply{Result: replyRes}) 1658 } 1659 1660 c.eventHub.presenceStatsHandler(event, cb) 1661 return nil 1662} 1663 1664func (c *Client) handleHistory(params protocol.Raw, rw *replyWriter) error { 1665 if c.eventHub.historyHandler == nil { 1666 return ErrorNotAvailable 1667 } 1668 1669 cmd, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodeHistory(params) 1670 if err != nil { 1671 return c.logDisconnectBadRequestWithError(err, "error decoding history") 1672 } 1673 1674 channel := cmd.Channel 1675 if channel == "" { 1676 return c.logDisconnectBadRequest("channel required for history") 1677 } 1678 1679 var filter HistoryFilter 1680 if cmd.Since 
!= nil { 1681 filter.Since = &StreamPosition{ 1682 Offset: cmd.Since.Offset, 1683 Epoch: cmd.Since.Epoch, 1684 } 1685 } 1686 filter.Limit = int(cmd.Limit) 1687 1688 maxPublicationLimit := c.node.config.HistoryMaxPublicationLimit 1689 if maxPublicationLimit > 0 && (filter.Limit < 0 || filter.Limit > maxPublicationLimit) { 1690 filter.Limit = maxPublicationLimit 1691 } 1692 1693 filter.Reverse = cmd.Reverse 1694 1695 event := HistoryEvent{ 1696 Channel: channel, 1697 Filter: filter, 1698 } 1699 1700 cb := func(reply HistoryReply, err error) { 1701 defer rw.done() 1702 if err != nil { 1703 c.writeDisconnectOrErrorFlush(rw, err) 1704 return 1705 } 1706 1707 var pubs []*Publication 1708 var offset uint64 1709 var epoch string 1710 if reply.Result == nil { 1711 result, err := c.node.History(event.Channel, WithLimit(event.Filter.Limit), WithSince(event.Filter.Since), WithReverse(event.Filter.Reverse)) 1712 if err != nil { 1713 c.logWriteInternalErrorFlush(rw, err, "error getting history") 1714 return 1715 } 1716 pubs = result.Publications 1717 offset = result.Offset 1718 epoch = result.Epoch 1719 } else { 1720 pubs = reply.Result.Publications 1721 offset = reply.Result.Offset 1722 epoch = reply.Result.Epoch 1723 } 1724 1725 protoPubs := make([]*protocol.Publication, 0, len(pubs)) 1726 for _, pub := range pubs { 1727 protoPub := pubToProto(pub) 1728 protoPubs = append(protoPubs, protoPub) 1729 } 1730 1731 replyRes, err := protocol.GetResultEncoder(c.transport.Protocol().toProto()).EncodeHistoryResult(&protocol.HistoryResult{ 1732 Publications: protoPubs, 1733 Offset: offset, 1734 Epoch: epoch, 1735 }) 1736 if err != nil { 1737 c.logWriteInternalErrorFlush(rw, err, "error encoding history") 1738 return 1739 } 1740 _ = writeReply(rw, &protocol.Reply{Result: replyRes}) 1741 } 1742 1743 c.eventHub.historyHandler(event, cb) 1744 return nil 1745} 1746 1747func (c *Client) handlePing(params protocol.Raw, rw *replyWriter) error { 1748 _, err := 
protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodePing(params) 1749 if err != nil { 1750 return c.logDisconnectBadRequestWithError(err, "error decoding ping") 1751 } 1752 _ = writeReply(rw, &protocol.Reply{}) 1753 defer rw.done() 1754 return nil 1755} 1756 1757func (c *Client) writeError(rw *replyWriter, error *Error) { 1758 _ = rw.write(&protocol.Reply{Error: error.toProto()}) 1759} 1760 1761func (c *Client) writeDisconnectOrErrorFlush(rw *replyWriter, replyError error) { 1762 switch t := replyError.(type) { 1763 case *Disconnect: 1764 go func() { _ = c.close(t) }() 1765 return 1766 default: 1767 c.writeError(rw, toClientErr(replyError)) 1768 } 1769} 1770 1771func writeReply(rw *replyWriter, reply *protocol.Reply) error { 1772 return rw.write(reply) 1773} 1774 1775func (c *Client) handleRPC(params protocol.Raw, rw *replyWriter) error { 1776 if c.eventHub.rpcHandler == nil { 1777 return ErrorNotAvailable 1778 } 1779 cmd, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodeRPC(params) 1780 if err != nil { 1781 return c.logDisconnectBadRequestWithError(err, "error decoding rpc") 1782 } 1783 1784 event := RPCEvent{ 1785 Method: cmd.Method, 1786 Data: cmd.Data, 1787 } 1788 1789 cb := func(reply RPCReply, err error) { 1790 defer rw.done() 1791 if err != nil { 1792 c.writeDisconnectOrErrorFlush(rw, err) 1793 return 1794 } 1795 result := &protocol.RPCResult{ 1796 Data: reply.Data, 1797 } 1798 var replyRes []byte 1799 replyRes, err = protocol.GetResultEncoder(c.transport.Protocol().toProto()).EncodeRPCResult(result) 1800 if err != nil { 1801 c.logWriteInternalErrorFlush(rw, err, "error encoding rpc") 1802 return 1803 } 1804 _ = writeReply(rw, &protocol.Reply{Result: replyRes}) 1805 } 1806 1807 c.eventHub.rpcHandler(event, cb) 1808 return nil 1809} 1810 1811func (c *Client) handleSend(params protocol.Raw, rw *replyWriter) error { 1812 if c.eventHub.messageHandler == nil { 1813 // send handler is a bit special since it is only one way 
1814 // request: client does not expect any reply. 1815 rw.done() 1816 return nil 1817 } 1818 cmd, err := protocol.GetParamsDecoder(c.transport.Protocol().toProto()).DecodeSend(params) 1819 if err != nil { 1820 return c.logDisconnectBadRequestWithError(err, "error decoding message") 1821 } 1822 defer rw.done() 1823 c.eventHub.messageHandler(MessageEvent{ 1824 Data: cmd.Data, 1825 }) 1826 return nil 1827} 1828 1829func (c *Client) unlockServerSideSubscriptions(subCtxMap map[string]subscribeContext) { 1830 for channel := range subCtxMap { 1831 c.pubSubSync.StopBuffering(channel) 1832 } 1833} 1834 1835// connectCmd handles connect command from client - client must send connect 1836// command immediately after establishing connection with server. 1837func (c *Client) connectCmd(cmd *protocol.ConnectRequest, rw *replyWriter) (*protocol.ConnectResult, error) { 1838 c.mu.RLock() 1839 authenticated := c.authenticated 1840 closed := c.status == statusClosed 1841 c.mu.RUnlock() 1842 1843 if closed { 1844 return nil, DisconnectNormal 1845 } 1846 1847 if authenticated { 1848 return nil, c.logDisconnectBadRequest("client already authenticated") 1849 } 1850 1851 config := c.node.config 1852 version := config.Version 1853 userConnectionLimit := config.UserConnectionLimit 1854 channelLimit := config.ClientChannelLimit 1855 1856 var ( 1857 credentials *Credentials 1858 authData protocol.Raw 1859 subscriptions map[string]SubscribeOptions 1860 clientSideRefresh bool 1861 ) 1862 1863 if c.node.clientEvents.connectingHandler != nil { 1864 e := ConnectEvent{ 1865 ClientID: c.ID(), 1866 Data: cmd.Data, 1867 Token: cmd.Token, 1868 Name: cmd.Name, 1869 Version: cmd.Version, 1870 Transport: c.transport, 1871 } 1872 if len(cmd.Subs) > 0 { 1873 channels := make([]string, 0, len(cmd.Subs)) 1874 for ch := range cmd.Subs { 1875 channels = append(channels, ch) 1876 } 1877 e.Channels = channels 1878 } 1879 reply, err := c.node.clientEvents.connectingHandler(c.ctx, e) 1880 if err != nil { 1881 
return nil, err 1882 } 1883 if reply.Credentials != nil { 1884 credentials = reply.Credentials 1885 } 1886 if reply.Context != nil { 1887 c.mu.Lock() 1888 c.ctx = reply.Context 1889 c.mu.Unlock() 1890 } 1891 if reply.Data != nil { 1892 authData = reply.Data 1893 } 1894 clientSideRefresh = reply.ClientSideRefresh 1895 if len(reply.Subscriptions) > 0 { 1896 subscriptions = make(map[string]SubscribeOptions, len(reply.Subscriptions)) 1897 for ch, opts := range reply.Subscriptions { 1898 if ch == "" { 1899 continue 1900 } 1901 subscriptions[ch] = opts 1902 } 1903 } 1904 } 1905 1906 if channelLimit > 0 && len(subscriptions) > channelLimit { 1907 return nil, DisconnectChannelLimit 1908 } 1909 1910 if credentials == nil { 1911 // Try to find Credentials in context. 1912 if cred, ok := GetCredentials(c.ctx); ok { 1913 credentials = cred 1914 } 1915 } 1916 1917 var ( 1918 expires bool 1919 ttl uint32 1920 ) 1921 1922 c.mu.Lock() 1923 c.clientSideRefresh = clientSideRefresh 1924 c.mu.Unlock() 1925 1926 if credentials == nil { 1927 return nil, c.logDisconnectBadRequest("client credentials not found") 1928 } 1929 1930 c.mu.Lock() 1931 c.user = credentials.UserID 1932 c.info = credentials.Info 1933 c.exp = credentials.ExpireAt 1934 1935 user := c.user 1936 exp := c.exp 1937 closed = c.status == statusClosed 1938 c.mu.Unlock() 1939 1940 if closed { 1941 return nil, DisconnectNormal 1942 } 1943 1944 c.node.logger.log(newLogEntry(LogLevelDebug, "client authenticated", map[string]interface{}{"client": c.uid, "user": c.user})) 1945 1946 if userConnectionLimit > 0 && user != "" && len(c.node.hub.UserConnections(user)) >= userConnectionLimit { 1947 c.node.logger.log(newLogEntry(LogLevelInfo, "limit of connections for user reached", map[string]interface{}{"user": user, "client": c.uid, "limit": userConnectionLimit})) 1948 return nil, DisconnectConnectionLimit 1949 } 1950 1951 c.mu.RLock() 1952 if exp > 0 { 1953 expires = true 1954 now := time.Now().Unix() 1955 if exp < now { 1956 
c.mu.RUnlock() 1957 c.node.logger.log(newLogEntry(LogLevelInfo, "connection expiration must be greater than now", map[string]interface{}{"client": c.uid, "user": c.UserID()})) 1958 return nil, ErrorExpired 1959 } 1960 ttl = uint32(exp - now) 1961 } 1962 c.mu.RUnlock() 1963 1964 res := &protocol.ConnectResult{ 1965 Version: version, 1966 Expires: expires, 1967 Ttl: ttl, 1968 } 1969 1970 // Client successfully connected. 1971 c.mu.Lock() 1972 c.authenticated = true 1973 c.mu.Unlock() 1974 1975 err := c.node.addClient(c) 1976 if err != nil { 1977 c.node.logger.log(newLogEntry(LogLevelError, "error adding client", map[string]interface{}{"client": c.uid, "error": err.Error()})) 1978 return nil, DisconnectServerError 1979 } 1980 1981 if !clientSideRefresh { 1982 // Server will do refresh itself. 1983 res.Expires = false 1984 res.Ttl = 0 1985 } 1986 1987 res.Client = c.uid 1988 if authData != nil { 1989 res.Data = authData 1990 } 1991 1992 var subCtxMap map[string]subscribeContext 1993 if len(subscriptions) > 0 { 1994 var subMu sync.Mutex 1995 subCtxMap = make(map[string]subscribeContext, len(subscriptions)) 1996 subs := make(map[string]*protocol.SubscribeResult, len(subscriptions)) 1997 var subDisconnect *Disconnect 1998 var subError *Error 1999 var wg sync.WaitGroup 2000 2001 wg.Add(len(subscriptions)) 2002 for ch, opts := range subscriptions { 2003 go func(ch string, opts SubscribeOptions) { 2004 defer wg.Done() 2005 subCmd := &protocol.SubscribeRequest{ 2006 Channel: ch, 2007 } 2008 if subReq, ok := cmd.Subs[ch]; ok { 2009 subCmd.Recover = subReq.Recover 2010 subCmd.Offset = subReq.Offset 2011 subCmd.Epoch = subReq.Epoch 2012 } 2013 subCtx := c.subscribeCmd(subCmd, SubscribeReply{Options: opts}, rw, true) 2014 subMu.Lock() 2015 subs[ch] = subCtx.result 2016 subCtxMap[ch] = subCtx 2017 if subCtx.disconnect != nil { 2018 subDisconnect = subCtx.disconnect 2019 } 2020 if subCtx.err != nil { 2021 subError = subCtx.err 2022 } 2023 subMu.Unlock() 2024 }(ch, opts) 2025 } 2026 
wg.Wait() 2027 2028 if subDisconnect != nil || subError != nil { 2029 c.unlockServerSideSubscriptions(subCtxMap) 2030 for channel := range subCtxMap { 2031 c.onSubscribeError(channel) 2032 } 2033 if subDisconnect != nil { 2034 return nil, subDisconnect 2035 } 2036 return nil, subError 2037 } 2038 res.Subs = subs 2039 } 2040 2041 if c.transport.Unidirectional() { 2042 connectPushBytes, err := c.encodeConnectPush(res) 2043 if err != nil { 2044 c.unlockServerSideSubscriptions(subCtxMap) 2045 c.node.logger.log(newLogEntry(LogLevelError, "error encoding connect", map[string]interface{}{"error": err.Error()})) 2046 return nil, DisconnectServerError 2047 } 2048 _ = writeReply(rw, &protocol.Reply{Result: connectPushBytes}) 2049 defer rw.done() 2050 } else { 2051 replyRes, err := protocol.GetResultEncoder(c.transport.Protocol().toProto()).EncodeConnectResult(res) 2052 if err != nil { 2053 c.unlockServerSideSubscriptions(subCtxMap) 2054 c.node.logger.log(newLogEntry(LogLevelError, "error encoding connect", map[string]interface{}{"error": err.Error()})) 2055 return nil, DisconnectServerError 2056 } 2057 _ = writeReply(rw, &protocol.Reply{Result: replyRes}) 2058 defer rw.done() 2059 } 2060 2061 c.mu.Lock() 2062 for channel, subCtx := range subCtxMap { 2063 c.channels[channel] = subCtx.channelContext 2064 } 2065 c.mu.Unlock() 2066 2067 c.unlockServerSideSubscriptions(subCtxMap) 2068 2069 if len(subCtxMap) > 0 { 2070 for channel, subCtx := range subCtxMap { 2071 go func(channel string, subCtx subscribeContext) { 2072 if channelHasFlag(subCtx.channelContext.flags, flagJoinLeave) && subCtx.clientInfo != nil { 2073 _ = c.node.publishJoin(channel, subCtx.clientInfo) 2074 } 2075 }(channel, subCtx) 2076 } 2077 } 2078 2079 return res, nil 2080} 2081 2082// Subscribe client to a channel. 
2083func (c *Client) Subscribe(channel string, opts ...SubscribeOption) error { 2084 if channel == "" { 2085 return fmt.Errorf("channel is empty") 2086 } 2087 channelLimit := c.node.config.ClientChannelLimit 2088 c.mu.RLock() 2089 numChannels := len(c.channels) 2090 c.mu.RUnlock() 2091 if channelLimit > 0 && numChannels >= channelLimit { 2092 go func() { _ = c.close(DisconnectChannelLimit) }() 2093 return nil 2094 } 2095 2096 subCmd := &protocol.SubscribeRequest{ 2097 Channel: channel, 2098 } 2099 subscribeOpts := &SubscribeOptions{} 2100 for _, opt := range opts { 2101 opt(subscribeOpts) 2102 } 2103 if subscribeOpts.RecoverSince != nil { 2104 subCmd.Recover = true 2105 subCmd.Offset = subscribeOpts.RecoverSince.Offset 2106 subCmd.Epoch = subscribeOpts.RecoverSince.Epoch 2107 } 2108 subCtx := c.subscribeCmd(subCmd, SubscribeReply{ 2109 Options: *subscribeOpts, 2110 }, nil, true) 2111 if subCtx.err != nil { 2112 c.onSubscribeError(subCmd.Channel) 2113 return subCtx.err 2114 } 2115 defer c.pubSubSync.StopBuffering(channel) 2116 c.mu.Lock() 2117 c.channels[channel] = subCtx.channelContext 2118 c.mu.Unlock() 2119 if hasFlag(c.transport.DisabledPushFlags(), PushFlagSubscribe) { 2120 return nil 2121 } 2122 sub := &protocol.Subscribe{ 2123 Offset: subCtx.result.GetOffset(), 2124 Epoch: subCtx.result.GetEpoch(), 2125 Recoverable: subCtx.result.GetRecoverable(), 2126 Positioned: subCtx.result.GetPositioned(), 2127 Data: subCtx.result.Data, 2128 } 2129 pushBytes, err := protocol.EncodeSubscribePush(c.transport.Protocol().toProto(), channel, sub) 2130 if err != nil { 2131 return err 2132 } 2133 reply := prepared.NewReply(&protocol.Reply{ 2134 Result: pushBytes, 2135 }, c.transport.Protocol().toProto()) 2136 return c.transportEnqueue(reply) 2137} 2138 2139func (c *Client) validateSubscribeRequest(cmd *protocol.SubscribeRequest) (*Error, *Disconnect) { 2140 channel := cmd.Channel 2141 if channel == "" { 2142 c.node.logger.log(newLogEntry(LogLevelInfo, "channel required for 
subscribe", map[string]interface{}{"user": c.user, "client": c.uid})) 2143 return nil, DisconnectBadRequest 2144 } 2145 2146 config := c.node.config 2147 channelMaxLength := config.ChannelMaxLength 2148 channelLimit := config.ClientChannelLimit 2149 2150 if channelMaxLength > 0 && len(channel) > channelMaxLength { 2151 c.node.logger.log(newLogEntry(LogLevelInfo, "channel too long", map[string]interface{}{"max": channelMaxLength, "channel": channel, "user": c.user, "client": c.uid})) 2152 return ErrorBadRequest, nil 2153 } 2154 2155 c.mu.Lock() 2156 numChannels := len(c.channels) 2157 _, ok := c.channels[channel] 2158 if ok { 2159 c.mu.Unlock() 2160 c.node.logger.log(newLogEntry(LogLevelInfo, "client already subscribed on channel", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid})) 2161 return ErrorAlreadySubscribed, nil 2162 } 2163 if channelLimit > 0 && numChannels >= channelLimit { 2164 c.mu.Unlock() 2165 c.node.logger.log(newLogEntry(LogLevelInfo, "maximum limit of channels per client reached", map[string]interface{}{"limit": channelLimit, "user": c.user, "client": c.uid})) 2166 return ErrorLimitExceeded, nil 2167 } 2168 // Put channel to a map to track duplicate subscriptions. This channel should 2169 // be removed from a map upon an error during subscribe. 
2170 c.channels[channel] = channelContext{} 2171 c.mu.Unlock() 2172 2173 return nil, nil 2174} 2175 2176func errorDisconnectContext(replyError *Error, disconnect *Disconnect) subscribeContext { 2177 ctx := subscribeContext{} 2178 if disconnect != nil { 2179 ctx.disconnect = disconnect 2180 return ctx 2181 } 2182 ctx.err = replyError 2183 return ctx 2184} 2185 2186type subscribeContext struct { 2187 result *protocol.SubscribeResult 2188 clientInfo *ClientInfo 2189 err *Error 2190 disconnect *Disconnect 2191 channelContext channelContext 2192} 2193 2194func isRecovered(historyResult HistoryResult, cmdOffset uint64, cmdEpoch string) ([]*protocol.Publication, bool) { 2195 latestOffset := historyResult.Offset 2196 latestEpoch := historyResult.Epoch 2197 2198 recoveredPubs := make([]*protocol.Publication, 0, len(historyResult.Publications)) 2199 for _, pub := range historyResult.Publications { 2200 protoPub := pubToProto(pub) 2201 recoveredPubs = append(recoveredPubs, protoPub) 2202 } 2203 2204 nextOffset := cmdOffset + 1 2205 var recovered bool 2206 if len(recoveredPubs) == 0 { 2207 recovered = latestOffset == cmdOffset && (cmdEpoch == "" || latestEpoch == cmdEpoch) 2208 } else { 2209 recovered = recoveredPubs[0].Offset == nextOffset && 2210 recoveredPubs[len(recoveredPubs)-1].Offset == latestOffset && 2211 (cmdEpoch == "" || latestEpoch == cmdEpoch) 2212 } 2213 2214 return recoveredPubs, recovered 2215} 2216 2217// subscribeCmd handles subscribe command - clients send this when subscribe 2218// on channel, if channel if private then we must validate provided sign here before 2219// actually subscribe client on channel. Optionally we can send missed messages to 2220// client if it provided last message id seen in channel. 
2221func (c *Client) subscribeCmd(cmd *protocol.SubscribeRequest, reply SubscribeReply, rw *replyWriter, serverSide bool) subscribeContext { 2222 2223 ctx := subscribeContext{} 2224 res := &protocol.SubscribeResult{} 2225 2226 if reply.Options.ExpireAt > 0 { 2227 ttl := reply.Options.ExpireAt - time.Now().Unix() 2228 if ttl <= 0 { 2229 c.node.logger.log(newLogEntry(LogLevelInfo, "subscription expiration must be greater than now", map[string]interface{}{"client": c.uid, "user": c.UserID()})) 2230 return errorDisconnectContext(ErrorExpired, nil) 2231 } 2232 if reply.ClientSideRefresh { 2233 res.Expires = true 2234 res.Ttl = uint32(ttl) 2235 } 2236 } 2237 2238 if reply.Options.Data != nil { 2239 res.Data = reply.Options.Data 2240 } 2241 2242 channel := cmd.Channel 2243 2244 info := &ClientInfo{ 2245 ClientID: c.uid, 2246 UserID: c.user, 2247 ConnInfo: c.info, 2248 ChanInfo: reply.Options.ChannelInfo, 2249 } 2250 2251 if reply.Options.Recover { 2252 // Start syncing recovery and PUB/SUB. 2253 // The important thing is to call StopBuffering for this channel 2254 // after response with Publications written to connection. 
2255 c.pubSubSync.StartBuffering(channel) 2256 } 2257 2258 err := c.node.addSubscription(channel, c) 2259 if err != nil { 2260 c.node.logger.log(newLogEntry(LogLevelError, "error adding subscription", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid, "error": err.Error()})) 2261 c.pubSubSync.StopBuffering(channel) 2262 if clientErr, ok := err.(*Error); ok && clientErr != ErrorInternal { 2263 return errorDisconnectContext(clientErr, nil) 2264 } 2265 ctx.disconnect = DisconnectServerError 2266 return ctx 2267 } 2268 2269 if reply.Options.Presence { 2270 err = c.node.addPresence(channel, c.uid, info) 2271 if err != nil { 2272 c.node.logger.log(newLogEntry(LogLevelError, "error adding presence", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid, "error": err.Error()})) 2273 c.pubSubSync.StopBuffering(channel) 2274 ctx.disconnect = DisconnectServerError 2275 return ctx 2276 } 2277 } 2278 2279 var ( 2280 latestOffset uint64 2281 latestEpoch string 2282 recoveredPubs []*protocol.Publication 2283 ) 2284 2285 if reply.Options.Recover { 2286 res.Recoverable = true 2287 res.Positioned = true // recoverable subscriptions are automatically positioned. 2288 if cmd.Recover { 2289 cmdOffset := cmd.Offset 2290 2291 // Client provided subscribe request with recover flag on. Try to recover missed 2292 // publications automatically from history (we suppose here that history configured wisely). 2293 historyResult, err := c.node.recoverHistory(channel, StreamPosition{cmdOffset, cmd.Epoch}) 2294 if err != nil { 2295 if errors.Is(err, ErrorUnrecoverablePosition) { 2296 // Result contains stream position in case of ErrorUnrecoverablePosition 2297 // during recovery. 
2298 latestOffset = historyResult.Offset 2299 latestEpoch = historyResult.Epoch 2300 res.Recovered = false 2301 incRecover(res.Recovered) 2302 } else { 2303 c.node.logger.log(newLogEntry(LogLevelError, "error on recover", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid, "error": err.Error()})) 2304 c.pubSubSync.StopBuffering(channel) 2305 if clientErr, ok := err.(*Error); ok && clientErr != ErrorInternal { 2306 return errorDisconnectContext(clientErr, nil) 2307 } 2308 ctx.disconnect = DisconnectServerError 2309 return ctx 2310 } 2311 } else { 2312 latestOffset = historyResult.Offset 2313 latestEpoch = historyResult.Epoch 2314 var recovered bool 2315 recoveredPubs, recovered = isRecovered(historyResult, cmdOffset, cmd.Epoch) 2316 res.Recovered = recovered 2317 incRecover(res.Recovered) 2318 } 2319 } else { 2320 streamTop, err := c.node.streamTop(channel) 2321 if err != nil { 2322 c.node.logger.log(newLogEntry(LogLevelError, "error getting recovery state for channel", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid, "error": err.Error()})) 2323 c.pubSubSync.StopBuffering(channel) 2324 if clientErr, ok := err.(*Error); ok && clientErr != ErrorInternal { 2325 return errorDisconnectContext(clientErr, nil) 2326 } 2327 ctx.disconnect = DisconnectServerError 2328 return ctx 2329 } 2330 latestOffset = streamTop.Offset 2331 latestEpoch = streamTop.Epoch 2332 } 2333 2334 res.Epoch = latestEpoch 2335 res.Offset = latestOffset 2336 2337 bufferedPubs := c.pubSubSync.LockBufferAndReadBuffered(channel) 2338 var okMerge bool 2339 recoveredPubs, okMerge = recovery.MergePublications(recoveredPubs, bufferedPubs) 2340 if !okMerge { 2341 c.pubSubSync.StopBuffering(channel) 2342 ctx.disconnect = DisconnectInsufficientState 2343 return ctx 2344 } 2345 } else if reply.Options.Position { 2346 streamTop, err := c.node.streamTop(channel) 2347 if err != nil { 2348 c.node.logger.log(newLogEntry(LogLevelError, "error getting stream top for 
channel", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid, "error": err.Error()})) 2349 if clientErr, ok := err.(*Error); ok && clientErr != ErrorInternal { 2350 return errorDisconnectContext(clientErr, nil) 2351 } 2352 ctx.disconnect = DisconnectServerError 2353 return ctx 2354 } 2355 2356 latestOffset = streamTop.Offset 2357 latestEpoch = streamTop.Epoch 2358 2359 res.Positioned = true 2360 res.Offset = streamTop.Offset 2361 res.Epoch = streamTop.Epoch 2362 } 2363 2364 if len(recoveredPubs) > 0 { 2365 lastPubOffset := recoveredPubs[len(recoveredPubs)-1].Offset 2366 if lastPubOffset > res.Offset { 2367 // There can be a case when recovery returned a limited set of publications 2368 // thus last publication offset will be smaller than history current offset. 2369 // In this case res.Recovered will be false. So we take a maximum here. 2370 latestOffset = recoveredPubs[len(recoveredPubs)-1].Offset 2371 res.Offset = latestOffset 2372 } 2373 } 2374 2375 res.Publications = recoveredPubs 2376 2377 if !serverSide { 2378 // Write subscription reply only if initiated by client. 2379 replyRes, err := protocol.GetResultEncoder(c.transport.Protocol().toProto()).EncodeSubscribeResult(res) 2380 if err != nil { 2381 c.node.logger.log(newLogEntry(LogLevelError, "error encoding subscribe", map[string]interface{}{"error": err.Error()})) 2382 if !serverSide { 2383 // Will be called later in case of server side sub. 2384 c.pubSubSync.StopBuffering(channel) 2385 } 2386 ctx.disconnect = DisconnectServerError 2387 return ctx 2388 } 2389 // Need to flush data from writer so subscription response is 2390 // sent before any subscription publication. 
2391 _ = writeReply(rw, &protocol.Reply{Result: replyRes}) 2392 } 2393 2394 var channelFlags uint8 2395 channelFlags |= flagSubscribed 2396 if serverSide { 2397 channelFlags |= flagServerSide 2398 } 2399 if reply.ClientSideRefresh { 2400 channelFlags |= flagClientSideRefresh 2401 } 2402 if reply.Options.Recover { 2403 channelFlags |= flagRecover 2404 } 2405 if reply.Options.Position { 2406 channelFlags |= flagPosition 2407 } 2408 if reply.Options.Presence { 2409 channelFlags |= flagPresence 2410 } 2411 if reply.Options.JoinLeave { 2412 channelFlags |= flagJoinLeave 2413 } 2414 2415 channelContext := channelContext{ 2416 Info: reply.Options.ChannelInfo, 2417 flags: channelFlags, 2418 expireAt: reply.Options.ExpireAt, 2419 streamPosition: StreamPosition{ 2420 Offset: latestOffset, 2421 Epoch: latestEpoch, 2422 }, 2423 } 2424 if reply.Options.Recover || reply.Options.Position { 2425 channelContext.positionCheckTime = time.Now().Unix() 2426 } 2427 2428 if !serverSide { 2429 // In case of server-side sub this will be done later by the caller. 2430 c.mu.Lock() 2431 c.channels[channel] = channelContext 2432 c.mu.Unlock() 2433 // Stop syncing recovery and PUB/SUB. 2434 // In case of server side subscription we will do this later. 
2435 c.pubSubSync.StopBuffering(channel) 2436 } 2437 2438 if c.node.logger.enabled(LogLevelDebug) { 2439 c.node.logger.log(newLogEntry(LogLevelDebug, "client subscribed to channel", map[string]interface{}{"client": c.uid, "user": c.user, "channel": cmd.Channel})) 2440 } 2441 2442 ctx.result = res 2443 ctx.clientInfo = info 2444 ctx.channelContext = channelContext 2445 return ctx 2446} 2447 2448func (c *Client) writePublicationUpdatePosition(ch string, pub *protocol.Publication, reply *prepared.Reply, sp StreamPosition) error { 2449 c.mu.Lock() 2450 channelContext, ok := c.channels[ch] 2451 if !ok || !channelHasFlag(channelContext.flags, flagSubscribed) { 2452 c.mu.Unlock() 2453 return nil 2454 } 2455 if !channelHasFlag(channelContext.flags, flagRecover|flagPosition) { 2456 if hasFlag(c.transport.DisabledPushFlags(), PushFlagPublication) { 2457 c.mu.Unlock() 2458 return nil 2459 } 2460 c.mu.Unlock() 2461 return c.transportEnqueue(reply) 2462 } 2463 currentPositionOffset := channelContext.streamPosition.Offset 2464 nextExpectedOffset := currentPositionOffset + 1 2465 pubOffset := pub.Offset 2466 pubEpoch := sp.Epoch 2467 if pubEpoch != channelContext.streamPosition.Epoch { 2468 if c.node.logger.enabled(LogLevelDebug) { 2469 c.node.logger.log(newLogEntry(LogLevelDebug, "client insufficient state", map[string]interface{}{"channel": ch, "user": c.user, "client": c.uid, "epoch": pubEpoch, "expectedEpoch": channelContext.streamPosition.Epoch})) 2470 } 2471 // Oops: sth lost, let client reconnect to recover its state. 
2472 go func() { _ = c.close(DisconnectInsufficientState) }() 2473 c.mu.Unlock() 2474 return nil 2475 } 2476 if pubOffset != nextExpectedOffset { 2477 if c.node.logger.enabled(LogLevelDebug) { 2478 c.node.logger.log(newLogEntry(LogLevelDebug, "client insufficient state", map[string]interface{}{"channel": ch, "user": c.user, "client": c.uid, "offset": pubOffset, "expectedOffset": nextExpectedOffset})) 2479 } 2480 // Oops: sth lost, let client reconnect to recover its state. 2481 go func() { _ = c.close(DisconnectInsufficientState) }() 2482 c.mu.Unlock() 2483 return nil 2484 } 2485 channelContext.positionCheckTime = time.Now().Unix() 2486 channelContext.positionCheckFailures = 0 2487 channelContext.streamPosition.Offset = pub.Offset 2488 c.channels[ch] = channelContext 2489 c.mu.Unlock() 2490 if hasFlag(c.transport.DisabledPushFlags(), PushFlagPublication) { 2491 return nil 2492 } 2493 return c.transportEnqueue(reply) 2494} 2495 2496func (c *Client) writePublication(ch string, pub *protocol.Publication, reply *prepared.Reply, sp StreamPosition) error { 2497 if pub.Offset == 0 { 2498 if hasFlag(c.transport.DisabledPushFlags(), PushFlagPublication) { 2499 return nil 2500 } 2501 return c.transportEnqueue(reply) 2502 } 2503 c.pubSubSync.SyncPublication(ch, pub, func() { 2504 _ = c.writePublicationUpdatePosition(ch, pub, reply, sp) 2505 }) 2506 return nil 2507} 2508 2509func (c *Client) writeJoin(_ string, reply *prepared.Reply) error { 2510 if hasFlag(c.transport.DisabledPushFlags(), PushFlagJoin) { 2511 return nil 2512 } 2513 return c.transportEnqueue(reply) 2514} 2515 2516func (c *Client) writeLeave(_ string, reply *prepared.Reply) error { 2517 if hasFlag(c.transport.DisabledPushFlags(), PushFlagLeave) { 2518 return nil 2519 } 2520 return c.transportEnqueue(reply) 2521} 2522 2523// Lock must be held outside. 
2524func (c *Client) unsubscribe(channel string) error { 2525 c.mu.RLock() 2526 info := c.clientInfo(channel) 2527 chCtx, ok := c.channels[channel] 2528 serverSide := channelHasFlag(chCtx.flags, flagServerSide) 2529 c.mu.RUnlock() 2530 2531 if ok { 2532 c.mu.Lock() 2533 delete(c.channels, channel) 2534 c.mu.Unlock() 2535 2536 if channelHasFlag(chCtx.flags, flagPresence) && channelHasFlag(chCtx.flags, flagSubscribed) { 2537 err := c.node.removePresence(channel, c.uid) 2538 if err != nil { 2539 c.node.logger.log(newLogEntry(LogLevelError, "error removing channel presence", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid, "error": err.Error()})) 2540 } 2541 } 2542 2543 if channelHasFlag(chCtx.flags, flagJoinLeave) && channelHasFlag(chCtx.flags, flagSubscribed) { 2544 _ = c.node.publishLeave(channel, info) 2545 } 2546 2547 if err := c.node.removeSubscription(channel, c); err != nil { 2548 c.node.logger.log(newLogEntry(LogLevelError, "error removing subscription", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid, "error": err.Error()})) 2549 return err 2550 } 2551 2552 if channelHasFlag(chCtx.flags, flagSubscribed) { 2553 if c.eventHub.unsubscribeHandler != nil { 2554 c.eventHub.unsubscribeHandler(UnsubscribeEvent{ 2555 Channel: channel, 2556 ServerSide: serverSide, 2557 }) 2558 } 2559 } 2560 } 2561 if c.node.logger.enabled(LogLevelDebug) { 2562 c.node.logger.log(newLogEntry(LogLevelDebug, "client unsubscribed from channel", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid})) 2563 } 2564 return nil 2565} 2566 2567func (c *Client) logDisconnectBadRequest(message string) *Disconnect { 2568 c.node.logger.log(newLogEntry(LogLevelInfo, message, map[string]interface{}{"user": c.user, "client": c.uid})) 2569 return DisconnectBadRequest 2570} 2571 2572func (c *Client) logDisconnectBadRequestWithError(err error, message string) *Disconnect { 2573 c.node.logger.log(newLogEntry(LogLevelInfo, message, 
map[string]interface{}{"error": err.Error(), "user": c.user, "client": c.uid})) 2574 return DisconnectBadRequest 2575} 2576 2577func (c *Client) logWriteInternalErrorFlush(rw *replyWriter, err error, message string) { 2578 if clientErr, ok := err.(*Error); ok { 2579 c.writeError(rw, clientErr) 2580 return 2581 } 2582 c.node.logger.log(newLogEntry(LogLevelError, message, map[string]interface{}{"error": err.Error()})) 2583 c.writeError(rw, ErrorInternal) 2584} 2585 2586func toClientErr(err error) *Error { 2587 if clientErr, ok := err.(*Error); ok { 2588 return clientErr 2589 } 2590 return ErrorInternal 2591} 2592 2593func errLogLevel(err error) LogLevel { 2594 logLevel := LogLevelInfo 2595 if err != ErrorNotAvailable { 2596 logLevel = LogLevelError 2597 } 2598 return logLevel 2599} 2600