1package centrifuge 2 3import ( 4 "context" 5 "encoding/json" 6 "errors" 7 "fmt" 8 "io" 9 "strings" 10 "sync" 11 "testing" 12 "time" 13 14 "github.com/centrifugal/protocol" 15 "github.com/stretchr/testify/require" 16) 17 18func newClient(ctx context.Context, n *Node, t Transport) (*Client, error) { 19 c, _, err := NewClient(ctx, n, t) 20 if err != nil { 21 return nil, err 22 } 23 return c, nil 24} 25 26func newTestConnectedClient(t *testing.T, n *Node, userID string) *Client { 27 client := newTestClient(t, n, userID) 28 connectClient(t, client) 29 require.True(t, len(n.hub.UserConnections(userID)) > 0) 30 return client 31} 32 33func newTestSubscribedClient(t *testing.T, n *Node, userID, chanID string) *Client { 34 client := newTestConnectedClient(t, n, userID) 35 subscribeClient(t, client, chanID) 36 require.True(t, n.hub.NumSubscribers(chanID) > 0) 37 require.Contains(t, client.channels, chanID) 38 return client 39} 40 41func TestConnectRequestToProto(t *testing.T) { 42 r := ConnectRequest{ 43 Token: "token", 44 Subs: map[string]SubscribeRequest{ 45 "test": { 46 Recover: true, 47 Offset: 1, 48 Epoch: "epoch", 49 }, 50 }} 51 protoReq := r.toProto() 52 require.Equal(t, "token", protoReq.GetToken()) 53 require.Equal(t, uint64(1), protoReq.Subs["test"].Offset) 54 require.Equal(t, "epoch", protoReq.Subs["test"].Epoch) 55 require.True(t, protoReq.Subs["test"].Recover) 56} 57 58func TestSetCredentials(t *testing.T) { 59 ctx := context.Background() 60 newCtx := SetCredentials(ctx, &Credentials{}) 61 val := newCtx.Value(credentialsContextKey).(*Credentials) 62 require.NotNil(t, val) 63} 64 65func TestNewClient(t *testing.T) { 66 node := defaultTestNode() 67 transport := newTestTransport(func() {}) 68 client, err := newClient(context.Background(), node, transport) 69 require.NoError(t, err) 70 require.NotNil(t, client) 71} 72 73func TestClientInitialState(t *testing.T) { 74 node := defaultTestNode() 75 defer func() { _ = node.Shutdown(context.Background()) }() 76 transport 
:= newTestTransport(func() {}) 77 client, _ := newClient(context.Background(), node, transport) 78 require.Equal(t, client.uid, client.ID()) 79 require.NotNil(t, "", client.user) 80 require.Equal(t, 0, len(client.Channels())) 81 require.Equal(t, ProtocolTypeJSON, client.Transport().Protocol()) 82 require.Equal(t, "websocket", client.Transport().Name()) 83 require.True(t, client.status == statusConnecting) 84 require.False(t, client.authenticated) 85} 86 87func TestClientClosedState(t *testing.T) { 88 node := defaultTestNode() 89 defer func() { _ = node.Shutdown(context.Background()) }() 90 transport := newTestTransport(func() {}) 91 client, _ := newClient(context.Background(), node, transport) 92 err := client.close(nil) 93 require.NoError(t, err) 94 require.True(t, client.status == statusClosed) 95} 96 97func TestClientTimer(t *testing.T) { 98 node := defaultTestNode() 99 node.config.ClientStaleCloseDelay = 25 * time.Second 100 defer func() { _ = node.Shutdown(context.Background()) }() 101 transport := newTestTransport(func() {}) 102 client, _ := newClient(context.Background(), node, transport) 103 require.NotNil(t, client.timer) 104 node.config.ClientStaleCloseDelay = 0 105 client, _ = newClient(context.Background(), node, transport) 106 require.Nil(t, client.timer) 107} 108 109func TestClientOnTimerOpClosedClient(t *testing.T) { 110 node := defaultTestNode() 111 defer func() { _ = node.Shutdown(context.Background()) }() 112 client := newTestClient(t, node, "42") 113 err := client.close(DisconnectForceNoReconnect) 114 require.NoError(t, err) 115 client.onTimerOp() 116 require.False(t, client.timer.Stop()) 117} 118 119func TestClientUnsubscribeClosedClient(t *testing.T) { 120 node := defaultTestNode() 121 defer func() { _ = node.Shutdown(context.Background()) }() 122 client := newTestClient(t, node, "42") 123 connectClient(t, client) 124 subscribeClient(t, client, "test") 125 err := client.close(DisconnectForceNoReconnect) 126 require.NoError(t, err) 127 err = 
client.Unsubscribe("test") 128 require.NoError(t, err) 129} 130 131func TestClientTimerSchedule(t *testing.T) { 132 node := defaultTestNode() 133 defer func() { _ = node.Shutdown(context.Background()) }() 134 transport := newTestTransport(func() {}) 135 client, _ := newClient(context.Background(), node, transport) 136 client.mu.Lock() 137 defer client.mu.Unlock() 138 client.nextExpire = time.Now().Add(5 * time.Second).UnixNano() 139 client.nextPresence = time.Now().Add(10 * time.Second).UnixNano() 140 client.scheduleNextTimer() 141 require.NotNil(t, client.timer) 142 require.Equal(t, timerOpExpire, client.timerOp) 143 client.nextPresence = time.Now().Add(time.Second).UnixNano() 144 client.scheduleNextTimer() 145 require.NotNil(t, client.timer) 146 require.Equal(t, timerOpPresence, client.timerOp) 147} 148 149func TestClientConnectNoCredentialsNoToken(t *testing.T) { 150 node := defaultTestNode() 151 defer func() { _ = node.Shutdown(context.Background()) }() 152 transport := newTestTransport(func() {}) 153 client, _ := newClient(context.Background(), node, transport) 154 rwWrapper := testReplyWriterWrapper() 155 _, err := client.connectCmd(&protocol.ConnectRequest{}, rwWrapper.rw) 156 require.Equal(t, DisconnectBadRequest, err) 157} 158 159func TestClientConnectContextCredentials(t *testing.T) { 160 node := defaultTestNode() 161 defer func() { _ = node.Shutdown(context.Background()) }() 162 163 transport := newTestTransport(func() {}) 164 ctx := context.Background() 165 newCtx := SetCredentials(ctx, &Credentials{ 166 UserID: "42", 167 ExpireAt: time.Now().Unix() + 60, 168 }) 169 client, _ := newClient(newCtx, node, transport) 170 171 rwWrapper := testReplyWriterWrapper() 172 _, err := client.connectCmd(&protocol.ConnectRequest{}, rwWrapper.rw) 173 require.NoError(t, err) 174 result := extractConnectReply(rwWrapper.replies, client.Transport().Protocol()) 175 require.Equal(t, false, result.Expires) 176 require.Equal(t, uint32(0), result.Ttl) 177 require.True(t, 
client.authenticated) 178 require.Equal(t, "42", client.UserID()) 179} 180 181func TestClientRefreshHandlerClosingExpiredClient(t *testing.T) { 182 node := defaultTestNode() 183 defer func() { _ = node.Shutdown(context.Background()) }() 184 185 node.OnConnect(func(client *Client) { 186 client.OnRefresh(func(_ RefreshEvent, callback RefreshCallback) { 187 callback(RefreshReply{ 188 Expired: true, 189 }, nil) 190 }) 191 }) 192 193 transport := newTestTransport(func() {}) 194 ctx := context.Background() 195 newCtx := SetCredentials(ctx, &Credentials{ 196 UserID: "42", 197 ExpireAt: time.Now().Unix() + 60, 198 }) 199 client, _ := newClient(newCtx, node, transport) 200 201 rwWrapper := testReplyWriterWrapper() 202 _, err := client.connectCmd(&protocol.ConnectRequest{}, rwWrapper.rw) 203 require.NoError(t, err) 204 client.triggerConnect() 205 client.expire() 206 require.True(t, client.status == statusClosed) 207} 208 209func TestClientRefreshHandlerProlongsClientSession(t *testing.T) { 210 node := defaultTestNode() 211 defer func() { _ = node.Shutdown(context.Background()) }() 212 213 transport := newTestTransport(func() {}) 214 ctx := context.Background() 215 newCtx := SetCredentials(ctx, &Credentials{ 216 UserID: "42", 217 ExpireAt: time.Now().Unix() + 60, 218 }) 219 client, _ := newClient(newCtx, node, transport) 220 221 expireAt := time.Now().Unix() + 60 222 223 node.OnConnect(func(client *Client) { 224 client.OnRefresh(func(_ RefreshEvent, cb RefreshCallback) { 225 cb(RefreshReply{ 226 ExpireAt: expireAt, 227 }, nil) 228 }) 229 }) 230 231 rwWrapper := testReplyWriterWrapper() 232 _, err := client.connectCmd(&protocol.ConnectRequest{}, rwWrapper.rw) 233 require.NoError(t, err) 234 client.expire() 235 require.False(t, client.status == statusClosed) 236 require.Equal(t, expireAt, client.exp) 237} 238 239func TestClientConnectWithExpiredContextCredentials(t *testing.T) { 240 node := defaultTestNode() 241 defer func() { _ = node.Shutdown(context.Background()) }() 242 243 
transport := newTestTransport(func() {}) 244 ctx := context.Background() 245 newCtx := SetCredentials(ctx, &Credentials{ 246 UserID: "42", 247 ExpireAt: time.Now().Unix() - 60, 248 }) 249 client, _ := newClient(newCtx, node, transport) 250 251 node.OnConnect(func(client *Client) { 252 client.OnRefresh(func(_ RefreshEvent, cb RefreshCallback) { 253 cb(RefreshReply{}, nil) 254 }) 255 }) 256 257 rwWrapper := testReplyWriterWrapper() 258 _, err := client.connectCmd(&protocol.ConnectRequest{}, rwWrapper.rw) 259 require.Equal(t, ErrorExpired, err) 260} 261 262func connectClient(t testing.TB, client *Client) *protocol.ConnectResult { 263 rwWrapper := testReplyWriterWrapper() 264 _, err := client.connectCmd(&protocol.ConnectRequest{}, rwWrapper.rw) 265 require.NoError(t, err) 266 require.Nil(t, rwWrapper.replies[0].Error) 267 require.True(t, client.authenticated) 268 result := extractConnectReply(rwWrapper.replies, client.Transport().Protocol()) 269 require.Equal(t, client.uid, result.Client) 270 client.triggerConnect() 271 client.scheduleOnConnectTimers() 272 return result 273} 274 275func extractSubscribeResult(replies []*protocol.Reply, protoType ProtocolType) *protocol.SubscribeResult { 276 var res protocol.SubscribeResult 277 if protoType == ProtocolTypeJSON { 278 err := json.Unmarshal(replies[0].Result, &res) 279 if err != nil { 280 panic(err) 281 } 282 } else { 283 err := res.UnmarshalVT(replies[0].Result) 284 if err != nil { 285 panic(err) 286 } 287 } 288 return &res 289} 290 291func extractConnectReply(replies []*protocol.Reply, protoType ProtocolType) *protocol.ConnectResult { 292 var res protocol.ConnectResult 293 if protoType == ProtocolTypeJSON { 294 err := json.Unmarshal(replies[0].Result, &res) 295 if err != nil { 296 panic(err) 297 } 298 } else { 299 err := res.UnmarshalVT(replies[0].Result) 300 if err != nil { 301 panic(err) 302 } 303 } 304 return &res 305} 306 307func subscribeClient(t testing.TB, client *Client, ch string) *protocol.SubscribeResult { 308 
rwWrapper := testReplyWriterWrapper() 309 err := client.handleSubscribe(getJSONEncodedParams(t, &protocol.SubscribeRequest{ 310 Channel: ch, 311 }), rwWrapper.rw) 312 require.NoError(t, err) 313 require.Nil(t, rwWrapper.replies[0].Error) 314 return extractSubscribeResult(rwWrapper.replies, client.Transport().Protocol()) 315} 316 317func TestClientSubscribe(t *testing.T) { 318 node := defaultNodeNoHandlers() 319 defer func() { _ = node.Shutdown(context.Background()) }() 320 321 node.OnConnect(func(client *Client) { 322 client.OnSubscribe(func(e SubscribeEvent, cb SubscribeCallback) { 323 cb(SubscribeReply{ 324 Options: SubscribeOptions{ 325 JoinLeave: true, 326 Presence: true, 327 Position: true, 328 Recover: true, 329 ChannelInfo: []byte("{}"), 330 ExpireAt: time.Now().Unix() + 3600, 331 Data: []byte("{}"), 332 }, 333 }, nil) 334 }) 335 }) 336 337 client := newTestClient(t, node, "42") 338 connectClient(t, client) 339 340 require.Equal(t, 0, len(client.Channels())) 341 342 rwWrapper := testReplyWriterWrapper() 343 344 err := client.handleSubscribe(getJSONEncodedParams(t, &protocol.SubscribeRequest{ 345 Channel: "test1", 346 }), rwWrapper.rw) 347 require.NoError(t, err) 348 require.Equal(t, 1, len(rwWrapper.replies)) 349 require.Nil(t, rwWrapper.replies[0].Error) 350 res := extractSubscribeResult(rwWrapper.replies, client.Transport().Protocol()) 351 require.Empty(t, res.Offset) 352 require.False(t, res.Recovered) 353 require.Empty(t, res.Publications) 354 require.Equal(t, 1, len(client.Channels())) 355 356 rwWrapper = testReplyWriterWrapper() 357 err = client.handleSubscribe(getJSONEncodedParams(t, &protocol.SubscribeRequest{ 358 Channel: "test2", 359 }), rwWrapper.rw) 360 require.NoError(t, err) 361 require.Equal(t, 2, len(client.Channels())) 362 require.Equal(t, 1, node.Hub().NumClients()) 363 require.Equal(t, 2, node.Hub().NumChannels()) 364 365 rwWrapper = testReplyWriterWrapper() 366 err = client.handleSubscribe(getJSONEncodedParams(t, 
&protocol.SubscribeRequest{ 367 Channel: "test2", 368 }), rwWrapper.rw) 369 require.Equal(t, ErrorAlreadySubscribed, err) 370} 371 372func TestClientSubscribeBrokerErrorOnSubscribe(t *testing.T) { 373 broker := NewTestBroker() 374 broker.errorOnSubscribe = true 375 node := nodeWithBroker(broker) 376 defer func() { _ = node.Shutdown(context.Background()) }() 377 378 done := make(chan struct{}) 379 380 node.OnConnect(func(client *Client) { 381 client.OnSubscribe(func(event SubscribeEvent, callback SubscribeCallback) { 382 callback(SubscribeReply{}, nil) 383 }) 384 client.OnDisconnect(func(event DisconnectEvent) { 385 require.Equal(t, DisconnectServerError, event.Disconnect) 386 close(done) 387 }) 388 }) 389 390 client := newTestClient(t, node, "42") 391 connectClient(t, client) 392 393 rwWrapper := testReplyWriterWrapper() 394 err := client.handleSubscribe(getJSONEncodedParams(t, &protocol.SubscribeRequest{ 395 Channel: "test1", 396 }), rwWrapper.rw) 397 require.NoError(t, err) 398 399 select { 400 case <-time.After(time.Second): 401 require.Fail(t, "timeout waiting for channel close") 402 case <-done: 403 } 404} 405 406func TestClientSubscribeBrokerErrorOnStreamTop(t *testing.T) { 407 broker := NewTestBroker() 408 broker.errorOnHistory = true 409 node := nodeWithBroker(broker) 410 defer func() { _ = node.Shutdown(context.Background()) }() 411 412 done := make(chan struct{}) 413 414 node.OnConnect(func(client *Client) { 415 client.OnSubscribe(func(event SubscribeEvent, callback SubscribeCallback) { 416 callback(SubscribeReply{ 417 Options: SubscribeOptions{Recover: true}, 418 }, nil) 419 }) 420 client.OnDisconnect(func(event DisconnectEvent) { 421 require.Equal(t, DisconnectServerError, event.Disconnect) 422 close(done) 423 }) 424 }) 425 426 client := newTestClient(t, node, "42") 427 connectClient(t, client) 428 429 rwWrapper := testReplyWriterWrapper() 430 err := client.handleSubscribe(getJSONEncodedParams(t, &protocol.SubscribeRequest{ 431 Channel: "test1", 432 }), 
rwWrapper.rw) 433 require.NoError(t, err) 434 435 select { 436 case <-time.After(time.Second): 437 require.Fail(t, "timeout waiting for channel close") 438 case <-done: 439 } 440} 441 442func TestClientSubscribeUnrecoverablePosition(t *testing.T) { 443 broker := NewTestBroker() 444 node := nodeWithBroker(broker) 445 defer func() { _ = node.Shutdown(context.Background()) }() 446 447 node.OnConnect(func(client *Client) { 448 client.OnSubscribe(func(event SubscribeEvent, callback SubscribeCallback) { 449 callback(SubscribeReply{ 450 Options: SubscribeOptions{Recover: true}, 451 }, nil) 452 }) 453 }) 454 455 client := newTestClient(t, node, "42") 456 connectClient(t, client) 457 458 rwWrapper := testReplyWriterWrapper() 459 err := client.handleSubscribe(getJSONEncodedParams(t, &protocol.SubscribeRequest{ 460 Channel: "test1", 461 Recover: true, 462 Epoch: "xxx", 463 }), rwWrapper.rw) 464 require.NoError(t, err) 465 require.Equal(t, 1, len(rwWrapper.replies)) 466 require.Nil(t, rwWrapper.replies[0].Error) 467 res := extractSubscribeResult(rwWrapper.replies, client.Transport().Protocol()) 468 require.Empty(t, res.Offset) 469 require.Empty(t, res.Epoch) 470 require.False(t, res.Recovered) 471 require.Empty(t, res.Publications) 472} 473 474func TestClientSubscribePositionedError(t *testing.T) { 475 broker := NewTestBroker() 476 broker.errorOnHistory = true 477 node := nodeWithBroker(broker) 478 defer func() { _ = node.Shutdown(context.Background()) }() 479 480 done := make(chan struct{}) 481 482 node.OnConnect(func(client *Client) { 483 client.OnSubscribe(func(event SubscribeEvent, callback SubscribeCallback) { 484 callback(SubscribeReply{ 485 Options: SubscribeOptions{Position: true}, 486 }, nil) 487 }) 488 client.OnDisconnect(func(event DisconnectEvent) { 489 require.Equal(t, DisconnectServerError, event.Disconnect) 490 close(done) 491 }) 492 }) 493 494 client := newTestClient(t, node, "42") 495 connectClient(t, client) 496 497 rwWrapper := testReplyWriterWrapper() 498 
err := client.handleSubscribe(getJSONEncodedParams(t, &protocol.SubscribeRequest{ 499 Channel: "test1", 500 }), rwWrapper.rw) 501 require.NoError(t, err) 502 503 select { 504 case <-time.After(time.Second): 505 require.Fail(t, "timeout waiting for channel close") 506 case <-done: 507 } 508} 509 510func TestClientSubscribePositioned(t *testing.T) { 511 node := nodeWithTestBroker() 512 defer func() { _ = node.Shutdown(context.Background()) }() 513 514 node.OnConnect(func(client *Client) { 515 client.OnSubscribe(func(event SubscribeEvent, callback SubscribeCallback) { 516 callback(SubscribeReply{ 517 Options: SubscribeOptions{Position: true}, 518 }, nil) 519 }) 520 }) 521 522 client := newTestClient(t, node, "42") 523 connectClient(t, client) 524 525 rwWrapper := testReplyWriterWrapper() 526 err := client.handleSubscribe(getJSONEncodedParams(t, &protocol.SubscribeRequest{ 527 Channel: "test1", 528 }), rwWrapper.rw) 529 require.NoError(t, err) 530 531 var result protocol.SubscribeResult 532 require.Nil(t, rwWrapper.replies[0].Error) 533 err = json.Unmarshal(rwWrapper.replies[0].Result, &result) 534 require.NoError(t, err) 535 require.True(t, result.Positioned) 536} 537 538func TestClientSubscribeBrokerErrorOnRecoverHistory(t *testing.T) { 539 broker := NewTestBroker() 540 broker.errorOnHistory = true 541 node := nodeWithBroker(broker) 542 defer func() { _ = node.Shutdown(context.Background()) }() 543 544 done := make(chan struct{}) 545 546 node.OnConnect(func(client *Client) { 547 client.OnSubscribe(func(event SubscribeEvent, callback SubscribeCallback) { 548 callback(SubscribeReply{Options: SubscribeOptions{Recover: true}}, nil) 549 }) 550 client.OnDisconnect(func(event DisconnectEvent) { 551 require.Equal(t, DisconnectServerError, event.Disconnect) 552 close(done) 553 }) 554 }) 555 556 client := newTestClient(t, node, "42") 557 connectClient(t, client) 558 559 rwWrapper := testReplyWriterWrapper() 560 err := client.handleSubscribe(getJSONEncodedParams(t, 
&protocol.SubscribeRequest{ 561 Channel: "test1", 562 Recover: true, 563 }), rwWrapper.rw) 564 require.NoError(t, err) 565 566 select { 567 case <-time.After(time.Second): 568 require.Fail(t, "timeout waiting for channel close") 569 case <-done: 570 } 571} 572 573func testUnexpectedOffsetEpoch(t *testing.T, offset uint64, epoch string) { 574 broker := NewTestBroker() 575 node := nodeWithBroker(broker) 576 defer func() { _ = node.Shutdown(context.Background()) }() 577 578 done := make(chan struct{}) 579 580 node.OnConnect(func(client *Client) { 581 client.OnSubscribe(func(event SubscribeEvent, callback SubscribeCallback) { 582 callback(SubscribeReply{Options: SubscribeOptions{Recover: true}}, nil) 583 }) 584 client.OnDisconnect(func(event DisconnectEvent) { 585 require.Equal(t, DisconnectInsufficientState, event.Disconnect) 586 close(done) 587 }) 588 }) 589 590 client := newTestClient(t, node, "42") 591 connectClient(t, client) 592 593 rwWrapper := testReplyWriterWrapper() 594 err := client.handleSubscribe(getJSONEncodedParams(t, &protocol.SubscribeRequest{ 595 Channel: "test", 596 Recover: true, 597 }), rwWrapper.rw) 598 require.NoError(t, err) 599 600 err = node.handlePublication("test", &Publication{ 601 Offset: offset, 602 }, StreamPosition{offset, epoch}) 603 require.NoError(t, err) 604 605 select { 606 case <-time.After(time.Second): 607 require.Fail(t, "timeout waiting for channel close") 608 case <-done: 609 } 610} 611 612func TestClientUnexpectedOffsetEpoch(t *testing.T) { 613 tests := []struct { 614 Name string 615 Offset uint64 616 Epoch string 617 }{ 618 {"wrong_offset", 2, ""}, 619 {"wrong_epoch", 1, "xyz"}, 620 } 621 622 for _, tt := range tests { 623 t.Run(tt.Name, func(t *testing.T) { 624 testUnexpectedOffsetEpoch(t, tt.Offset, tt.Epoch) 625 }) 626 } 627} 628 629func TestClientSubscribeValidateErrors(t *testing.T) { 630 node := defaultTestNode() 631 node.config.ClientChannelLimit = 1 632 node.config.ChannelMaxLength = 10 633 defer func() { _ = 
node.Shutdown(context.Background()) }() 634 635 node.OnConnect(func(client *Client) { 636 client.OnSubscribe(func(event SubscribeEvent, callback SubscribeCallback) { 637 callback(SubscribeReply{}, nil) 638 }) 639 }) 640 641 client := newTestClient(t, node, "42") 642 connectClient(t, client) 643 644 rwWrapper := testReplyWriterWrapper() 645 err := client.handleSubscribe(getJSONEncodedParams(t, &protocol.SubscribeRequest{ 646 Channel: "test2_very_long_channel_name", 647 }), rwWrapper.rw) 648 require.Equal(t, ErrorBadRequest, err) 649 650 subscribeClient(t, client, "test1") 651 652 rwWrapper = testReplyWriterWrapper() 653 err = client.handleSubscribe(getJSONEncodedParams(t, &protocol.SubscribeRequest{ 654 Channel: "test2", 655 }), rwWrapper.rw) 656 require.Equal(t, ErrorLimitExceeded, err) 657 658 rwWrapper = testReplyWriterWrapper() 659 err = client.handleSubscribe(getJSONEncodedParams(t, &protocol.SubscribeRequest{ 660 Channel: "", 661 }), rwWrapper.rw) 662 require.Equal(t, DisconnectBadRequest, err) 663} 664 665func TestClientSubscribeReceivePublication(t *testing.T) { 666 node := defaultTestNode() 667 defer func() { _ = node.Shutdown(context.Background()) }() 668 transport := newTestTransport(func() {}) 669 transport.sink = make(chan []byte, 100) 670 ctx := context.Background() 671 newCtx := SetCredentials(ctx, &Credentials{UserID: "42"}) 672 client, _ := newClient(newCtx, node, transport) 673 674 connectClient(t, client) 675 676 rwWrapper := testReplyWriterWrapper() 677 678 subCtx := client.subscribeCmd(&protocol.SubscribeRequest{ 679 Channel: "test", 680 }, SubscribeReply{}, rwWrapper.rw, false) 681 require.Nil(t, subCtx.disconnect) 682 require.Nil(t, rwWrapper.replies[0].Error) 683 684 done := make(chan struct{}) 685 go func() { 686 for data := range transport.sink { 687 if strings.Contains(string(data), "test message") { 688 close(done) 689 } 690 } 691 }() 692 693 _, err := node.Publish("test", []byte(`{"text": "test message"}`)) 694 require.NoError(t, err) 
695 696 select { 697 case <-time.After(time.Second): 698 require.Fail(t, "timeout receiving publication") 699 case <-done: 700 } 701} 702 703func TestClientSubscribeReceivePublicationWithOffset(t *testing.T) { 704 node := defaultTestNode() 705 defer func() { _ = node.Shutdown(context.Background()) }() 706 707 transport := newTestTransport(func() {}) 708 transport.sink = make(chan []byte, 100) 709 ctx := context.Background() 710 newCtx := SetCredentials(ctx, &Credentials{UserID: "42"}) 711 client, _ := newClient(newCtx, node, transport) 712 713 connectClient(t, client) 714 715 rwWrapper := testReplyWriterWrapper() 716 717 subCtx := client.subscribeCmd(&protocol.SubscribeRequest{ 718 Channel: "test", 719 }, SubscribeReply{}, rwWrapper.rw, false) 720 require.Nil(t, subCtx.disconnect) 721 require.Nil(t, rwWrapper.replies[0].Error) 722 723 done := make(chan struct{}) 724 go func() { 725 var offset uint64 = 1 726 for data := range transport.sink { 727 if strings.Contains(string(data), "test message") { 728 dec := json.NewDecoder(strings.NewReader(string(data))) 729 for { 730 var push struct { 731 Result struct { 732 Channel string 733 Data struct { 734 Offset uint64 735 } 736 } 737 } 738 err := dec.Decode(&push) 739 if err == io.EOF { 740 break 741 } 742 require.NoError(t, err) 743 if push.Result.Data.Offset != offset { 744 require.Fail(t, fmt.Sprintf("wrong offset: %d != %d", push.Result.Data.Offset, offset)) 745 } 746 offset++ 747 if offset > 3 { 748 close(done) 749 } 750 } 751 } 752 } 753 }() 754 755 // Send 3 publications, expect client to receive them with 756 // incremental sequence numbers. 
757 _, err := node.Publish("test", []byte(`{"text": "test message 1"}`), WithHistory(10, time.Minute)) 758 require.NoError(t, err) 759 _, err = node.Publish("test", []byte(`{"text": "test message 2"}`), WithHistory(10, time.Minute)) 760 require.NoError(t, err) 761 _, err = node.Publish("test", []byte(`{"text": "test message 3"}`), WithHistory(10, time.Minute)) 762 require.NoError(t, err) 763 764 select { 765 case <-time.After(time.Second): 766 require.Fail(t, "timeout receiving publications") 767 case <-done: 768 } 769} 770 771func TestUserConnectionLimit(t *testing.T) { 772 node := defaultTestNode() 773 node.config.UserConnectionLimit = 1 774 defer func() { _ = node.Shutdown(context.Background()) }() 775 776 transport := newTestTransport(func() {}) 777 ctx := context.Background() 778 newCtx := SetCredentials(ctx, &Credentials{UserID: "42"}) 779 780 client, _ := newClient(newCtx, node, transport) 781 connectClient(t, client) 782 783 rwWrapper := testReplyWriterWrapper() 784 anotherClient, _ := newClient(newCtx, node, transport) 785 _, err := anotherClient.connectCmd(&protocol.ConnectRequest{}, rwWrapper.rw) 786 require.Equal(t, DisconnectConnectionLimit, err) 787} 788 789type testContextKey int 790 791var keyTest testContextKey = 1 792 793func TestConnectingReply(t *testing.T) { 794 node := defaultTestNode() 795 defer func() { _ = node.Shutdown(context.Background()) }() 796 797 node.OnConnecting(func(ctx context.Context, e ConnectEvent) (ConnectReply, error) { 798 newCtx := context.WithValue(ctx, keyTest, "val") 799 return ConnectReply{ 800 Context: newCtx, 801 Data: []byte("{}"), 802 Credentials: &Credentials{ 803 UserID: "12", 804 }, 805 }, nil 806 }) 807 808 done := make(chan struct{}) 809 810 node.OnConnect(func(c *Client) { 811 v, ok := c.Context().Value(keyTest).(string) 812 require.True(t, ok) 813 require.Equal(t, "val", v) 814 require.Equal(t, "12", c.UserID()) 815 close(done) 816 }) 817 818 client := newTestClient(t, node, "42") 819 connectClient(t, 
client) 820 821 select { 822 case <-time.After(time.Second): 823 require.Fail(t, "timeout waiting for channel close") 824 case <-done: 825 } 826} 827 828func TestServerSideSubscriptions(t *testing.T) { 829 testCases := []struct { 830 Name string 831 Unidirectional bool 832 }{ 833 {"bidi", false}, 834 {"uni", true}, 835 } 836 837 for _, tt := range testCases { 838 t.Run(tt.Name, func(t *testing.T) { 839 node := defaultTestNode() 840 defer func() { _ = node.Shutdown(context.Background()) }() 841 842 node.OnConnecting(func(context.Context, ConnectEvent) (ConnectReply, error) { 843 return ConnectReply{ 844 Subscriptions: map[string]SubscribeOptions{ 845 "server-side-1": {}, 846 "$server-side-2": {}, 847 }, 848 }, nil 849 }) 850 transport := newTestTransport(func() {}) 851 transport.setUnidirectional(tt.Unidirectional) 852 transport.sink = make(chan []byte, 100) 853 ctx := context.Background() 854 newCtx := SetCredentials(ctx, &Credentials{UserID: "42"}) 855 client, _ := newClient(newCtx, node, transport) 856 if !tt.Unidirectional { 857 connectClient(t, client) 858 } else { 859 client.Connect(ConnectRequest{ 860 Subs: map[string]SubscribeRequest{ 861 "server-side-1": { 862 Recover: true, 863 Epoch: "test", 864 Offset: 0, 865 }, 866 }, 867 }) 868 } 869 870 _ = client.Subscribe("server-side-3") 871 _, err := node.Publish("server-side-1", []byte(`{"text": "test message 1"}`)) 872 require.NoError(t, err) 873 _, err = node.Publish("$server-side-2", []byte(`{"text": "test message 2"}`)) 874 require.NoError(t, err) 875 _, err = node.Publish("server-side-3", []byte(`{"text": "test message 3"}`)) 876 require.NoError(t, err) 877 878 done := make(chan struct{}) 879 go func() { 880 var i int 881 for data := range transport.sink { 882 if strings.Contains(string(data), "test message 1") { 883 i++ 884 } 885 if strings.Contains(string(data), "test message 2") { 886 i++ 887 } 888 if strings.Contains(string(data), "test message 3") { 889 i++ 890 } 891 if i == 3 { 892 close(done) 893 
return 894 } 895 } 896 }() 897 898 select { 899 case <-time.After(time.Second): 900 require.Fail(t, "timeout receiving publication") 901 case <-done: 902 } 903 }) 904 } 905} 906 907func TestClientRefresh(t *testing.T) { 908 node := defaultTestNode() 909 defer func() { _ = node.Shutdown(context.Background()) }() 910 911 node.OnConnecting(func(context.Context, ConnectEvent) (ConnectReply, error) { 912 return ConnectReply{}, nil 913 }) 914 transport := newTestTransport(func() {}) 915 transport.setUnidirectional(true) 916 transport.sink = make(chan []byte, 100) 917 ctx := context.Background() 918 919 expireAt1 := time.Now().Unix() + 100 920 expireAt2 := time.Now().Unix() + 200 921 newCtx := SetCredentials(ctx, &Credentials{UserID: "42", ExpireAt: expireAt1}) 922 client, _ := newClient(newCtx, node, transport) 923 924 client.Connect(ConnectRequest{}) 925 926 require.Equal(t, expireAt1, client.exp) 927 err := client.Refresh() 928 require.NoError(t, err) 929 require.Zero(t, client.exp) 930 931 done := make(chan struct{}) 932 go func() { 933 for data := range transport.sink { 934 if strings.Contains(string(data), `"type":8`) { 935 close(done) 936 return 937 } 938 } 939 }() 940 941 select { 942 case <-time.After(time.Second): 943 require.Fail(t, "timeout waiting for done channel closed") 944 case <-done: 945 } 946 947 done = make(chan struct{}) 948 go func() { 949 for data := range transport.sink { 950 if strings.Contains(string(data), `"type":8`) && strings.Contains(string(data), `"ttl":`) && strings.Contains(string(data), `"expires":true`) { 951 close(done) 952 return 953 } 954 } 955 }() 956 957 err = client.Refresh(WithRefreshExpireAt(expireAt2), WithRefreshInfo([]byte("info"))) 958 require.NoError(t, err) 959 require.Equal(t, expireAt2, client.exp) 960 require.Equal(t, []byte("info"), client.info) 961 962 select { 963 case <-time.After(time.Second): 964 require.Fail(t, "timeout waiting for done channel closed") 965 case <-done: 966 } 967 968 done = make(chan struct{}) 
969 go func() { 970 for data := range transport.sink { 971 if strings.Contains(string(data), `"type":8`) && strings.Contains(string(data), `"ttl":`) && strings.Contains(string(data), `"expires":true`) { 972 close(done) 973 return 974 } 975 } 976 }() 977 978 err = node.Refresh("42", WithRefreshExpireAt(expireAt2), WithRefreshInfo([]byte("info"))) 979 require.NoError(t, err) 980 require.Equal(t, expireAt2, client.exp) 981 require.Equal(t, []byte("info"), client.info) 982 983 select { 984 case <-time.After(time.Second): 985 require.Fail(t, "timeout waiting for done channel closed") 986 case <-done: 987 } 988 989 done = make(chan struct{}) 990 go func() { 991 for data := range transport.sink { 992 if strings.Contains(string(data), `"code":3005`) { 993 // DisconnectExpired sent. 994 close(done) 995 return 996 } 997 } 998 }() 999 1000 err = client.Refresh(WithRefreshExpired(true)) 1001 require.NoError(t, err) 1002 1003 select { 1004 case <-time.After(time.Second): 1005 require.Fail(t, "timeout waiting for done channel closed") 1006 case <-done: 1007 } 1008} 1009 1010func TestClientRefreshExpireAtInThePast(t *testing.T) { 1011 node := defaultTestNode() 1012 defer func() { _ = node.Shutdown(context.Background()) }() 1013 1014 node.OnConnecting(func(context.Context, ConnectEvent) (ConnectReply, error) { 1015 return ConnectReply{}, nil 1016 }) 1017 transport := newTestTransport(func() {}) 1018 transport.setUnidirectional(true) 1019 transport.sink = make(chan []byte, 100) 1020 ctx := context.Background() 1021 1022 expireAt1 := time.Now().Unix() + 100 1023 expireAt2 := time.Now().Unix() - 200 1024 newCtx := SetCredentials(ctx, &Credentials{UserID: "42", ExpireAt: expireAt1}) 1025 client, _ := newClient(newCtx, node, transport) 1026 1027 client.Connect(ConnectRequest{}) 1028 1029 require.Equal(t, expireAt1, client.exp) 1030 1031 done := make(chan struct{}) 1032 go func() { 1033 for data := range transport.sink { 1034 if strings.Contains(string(data), `"code":3005`) { 1035 
close(done) 1036 return 1037 } 1038 } 1039 }() 1040 1041 err := client.Refresh(WithRefreshExpireAt(expireAt2)) 1042 require.NoError(t, err) 1043 1044 select { 1045 case <-time.After(time.Second): 1046 require.Fail(t, "timeout waiting for done channel closed") 1047 case <-done: 1048 } 1049} 1050 1051func TestClient_IsSubscribed(t *testing.T) { 1052 node := defaultTestNode() 1053 defer func() { _ = node.Shutdown(context.Background()) }() 1054 1055 transport := newTestTransport(func() {}) 1056 ctx := context.Background() 1057 newCtx := SetCredentials(ctx, &Credentials{UserID: "42"}) 1058 1059 client, _ := newClient(newCtx, node, transport) 1060 connectClient(t, client) 1061 1062 require.False(t, client.IsSubscribed("test")) 1063 _ = subscribeClient(t, client, "test") 1064 require.True(t, client.IsSubscribed("test")) 1065} 1066 1067func TestClientSubscribeLast(t *testing.T) { 1068 node := defaultNodeNoHandlers() 1069 defer func() { _ = node.Shutdown(context.Background()) }() 1070 1071 node.OnConnect(func(client *Client) { 1072 client.OnSubscribe(func(event SubscribeEvent, cb SubscribeCallback) { 1073 cb(SubscribeReply{Options: SubscribeOptions{Recover: true}}, nil) 1074 }) 1075 }) 1076 1077 client := newTestClient(t, node, "42") 1078 connectClient(t, client) 1079 1080 result := subscribeClient(t, client, "test") 1081 require.Equal(t, uint64(0), result.Offset) 1082 1083 for i := 0; i < 10; i++ { 1084 _, _ = node.Publish("test", []byte("{}"), WithHistory(10, time.Minute)) 1085 } 1086 1087 client = newTestClient(t, node, "42") 1088 connectClient(t, client) 1089 result = subscribeClient(t, client, "test") 1090 require.Equal(t, uint64(10), result.Offset, fmt.Sprintf("expected: 10, got %d", result.Offset)) 1091} 1092 1093func newTestClient(t testing.TB, node *Node, userID string) *Client { 1094 ctx, cancelFn := context.WithCancel(context.Background()) 1095 transport := newTestTransport(cancelFn) 1096 newCtx := SetCredentials(ctx, &Credentials{UserID: userID}) 1097 client, 
err := newClient(newCtx, node, transport) 1098 require.NoError(t, err) 1099 return client 1100} 1101 1102func getJSONEncodedParams(t testing.TB, request interface{}) []byte { 1103 paramsEncoder := protocol.NewJSONParamsEncoder() 1104 params, err := paramsEncoder.Encode(request) 1105 require.NoError(t, err) 1106 return params 1107} 1108 1109func TestClientUnsubscribeClientSide(t *testing.T) { 1110 node := defaultNodeNoHandlers() 1111 defer func() { _ = node.Shutdown(context.Background()) }() 1112 1113 client := newTestClient(t, node, "42") 1114 1115 unsubscribed := make(chan struct{}) 1116 1117 node.OnConnect(func(client *Client) { 1118 client.OnSubscribe(func(event SubscribeEvent, callback SubscribeCallback) { 1119 callback(SubscribeReply{}, nil) 1120 }) 1121 client.OnUnsubscribe(func(_ UnsubscribeEvent) { 1122 close(unsubscribed) 1123 }) 1124 }) 1125 1126 connectClient(t, client) 1127 subscribeClient(t, client, "test") 1128 1129 rwWrapper := testReplyWriterWrapper() 1130 params := getJSONEncodedParams(t, &protocol.UnsubscribeRequest{Channel: ""}) 1131 err := client.handleUnsubscribe(params, rwWrapper.rw) 1132 require.Equal(t, DisconnectBadRequest, err) 1133 1134 rwWrapper = testReplyWriterWrapper() 1135 params = getJSONEncodedParams(t, &protocol.UnsubscribeRequest{Channel: "test"}) 1136 err = client.handleUnsubscribe(params, rwWrapper.rw) 1137 require.NoError(t, err) 1138 1139 require.Equal(t, 0, len(client.Channels())) 1140 require.Equal(t, 1, node.Hub().NumClients()) 1141 require.Equal(t, 0, node.Hub().NumChannels()) 1142 1143 select { 1144 case <-unsubscribed: 1145 case <-time.After(time.Second): 1146 t.Fatal("unsubscribe handler not called") 1147 } 1148} 1149 1150func TestClientUnsubscribeServerSide(t *testing.T) { 1151 node := defaultTestNode() 1152 defer func() { _ = node.Shutdown(context.Background()) }() 1153 client := newTestClient(t, node, "42") 1154 1155 unsubscribed := make(chan struct{}) 1156 1157 node.OnConnect(func(client *Client) { 1158 
client.OnSubscribe(func(event SubscribeEvent, callback SubscribeCallback) { 1159 callback(SubscribeReply{}, nil) 1160 }) 1161 client.OnUnsubscribe(func(_ UnsubscribeEvent) { 1162 close(unsubscribed) 1163 }) 1164 }) 1165 1166 connectClient(t, client) 1167 subscribeClient(t, client, "test") 1168 require.Equal(t, 1, len(client.Channels())) 1169 1170 err := client.Unsubscribe("test") 1171 require.NoError(t, err) 1172 require.Equal(t, 0, len(client.Channels())) 1173 require.Equal(t, 1, node.Hub().NumClients()) 1174 require.Equal(t, 0, node.Hub().NumChannels()) 1175 1176 select { 1177 case <-unsubscribed: 1178 case <-time.After(time.Second): 1179 t.Fatal("unsubscribe handler not called") 1180 } 1181} 1182 1183func TestClientAliveHandler(t *testing.T) { 1184 node := defaultNodeNoHandlers() 1185 defer func() { _ = node.Shutdown(context.Background()) }() 1186 1187 node.config.ClientPresenceUpdateInterval = time.Millisecond 1188 1189 transport := newTestTransport(func() {}) 1190 ctx := context.Background() 1191 newCtx := SetCredentials(ctx, &Credentials{UserID: "42"}) 1192 client, _ := newClient(newCtx, node, transport) 1193 1194 done := make(chan struct{}) 1195 closed := false 1196 disconnected := make(chan struct{}) 1197 numCalls := 0 1198 1199 node.OnConnect(func(client *Client) { 1200 client.OnAlive(func() { 1201 numCalls++ 1202 if numCalls >= 50 && !closed { 1203 close(done) 1204 closed = true 1205 client.Disconnect(DisconnectForceNoReconnect) 1206 } 1207 }) 1208 1209 client.OnDisconnect(func(_ DisconnectEvent) { 1210 close(disconnected) 1211 }) 1212 1213 }) 1214 1215 connectClient(t, client) 1216 client.triggerConnect() 1217 client.scheduleOnConnectTimers() 1218 1219 select { 1220 case <-done: 1221 case <-time.After(time.Second): 1222 t.Fatal("alive handler not called") 1223 } 1224 select { 1225 case <-disconnected: 1226 case <-time.After(time.Second): 1227 t.Fatal("disconnect handler not called") 1228 } 1229} 1230 1231type sliceReplyWriter struct { 1232 replies 
[]*protocol.Reply 1233 rw *replyWriter 1234} 1235 1236func testReplyWriterWrapper() *sliceReplyWriter { 1237 replies := make([]*protocol.Reply, 0) 1238 wrapper := &sliceReplyWriter{ 1239 replies: replies, 1240 } 1241 wrapper.rw = &replyWriter{ 1242 write: func(rep *protocol.Reply) error { 1243 wrapper.replies = append(wrapper.replies, rep) 1244 return nil 1245 }, 1246 done: func() {}, 1247 } 1248 return wrapper 1249} 1250 1251func TestClientRefreshNotAvailable(t *testing.T) { 1252 node := defaultNodeNoHandlers() 1253 defer func() { _ = node.Shutdown(context.Background()) }() 1254 client := newTestClient(t, node, "42") 1255 connectClient(t, client) 1256 rwWrapper := testReplyWriterWrapper() 1257 1258 cmd := &protocol.RefreshRequest{} 1259 params := getJSONEncodedParams(t, cmd) 1260 1261 err := client.handleRefresh(params, rwWrapper.rw) 1262 require.Equal(t, ErrorNotAvailable, err) 1263} 1264 1265func TestClientRefreshEmptyToken(t *testing.T) { 1266 node := defaultNodeNoHandlers() 1267 defer func() { _ = node.Shutdown(context.Background()) }() 1268 1269 node.OnConnecting(func(ctx context.Context, event ConnectEvent) (ConnectReply, error) { 1270 return ConnectReply{ClientSideRefresh: true}, nil 1271 }) 1272 1273 node.OnConnect(func(client *Client) { 1274 client.OnRefresh(func(event RefreshEvent, callback RefreshCallback) { 1275 callback(RefreshReply{ 1276 ExpireAt: time.Now().Unix() + 300, 1277 }, nil) 1278 }) 1279 }) 1280 1281 client := newTestClient(t, node, "42") 1282 connectClient(t, client) 1283 rwWrapper := testReplyWriterWrapper() 1284 1285 cmd := &protocol.RefreshRequest{Token: ""} 1286 params := getJSONEncodedParams(t, cmd) 1287 1288 err := client.handleRefresh(params, rwWrapper.rw) 1289 require.Equal(t, DisconnectBadRequest, err) 1290} 1291 1292func TestClientRefreshUnexpected(t *testing.T) { 1293 node := defaultNodeNoHandlers() 1294 defer func() { _ = node.Shutdown(context.Background()) }() 1295 1296 node.OnConnecting(func(ctx context.Context, event 
ConnectEvent) (ConnectReply, error) { 1297 return ConnectReply{ClientSideRefresh: false}, nil // we do not want client side refresh here. 1298 }) 1299 1300 node.OnConnect(func(client *Client) { 1301 client.OnRefresh(func(event RefreshEvent, callback RefreshCallback) { 1302 callback(RefreshReply{ 1303 ExpireAt: time.Now().Unix() + 300, 1304 }, nil) 1305 }) 1306 }) 1307 1308 client := newTestClient(t, node, "42") 1309 connectClient(t, client) 1310 rwWrapper := testReplyWriterWrapper() 1311 1312 cmd := &protocol.RefreshRequest{Token: "xxx"} 1313 params := getJSONEncodedParams(t, cmd) 1314 1315 err := client.handleRefresh(params, rwWrapper.rw) 1316 require.Equal(t, DisconnectBadRequest, err) 1317} 1318 1319func TestClientPublishNotAvailable(t *testing.T) { 1320 node := defaultNodeNoHandlers() 1321 defer func() { _ = node.Shutdown(context.Background()) }() 1322 client := newTestClient(t, node, "42") 1323 connectClient(t, client) 1324 rwWrapper := testReplyWriterWrapper() 1325 1326 cmd := &protocol.PublishRequest{ 1327 Channel: "test", 1328 Data: []byte(`{}`), 1329 } 1330 params := getJSONEncodedParams(t, cmd) 1331 1332 err := client.handlePublish(params, rwWrapper.rw) 1333 require.Equal(t, ErrorNotAvailable, err) 1334} 1335 1336type testBrokerEventHandler struct { 1337 // Publication must register callback func to handle Publications received. 1338 HandlePublicationFunc func(ch string, pub *Publication, sp StreamPosition) error 1339 // Join must register callback func to handle Join messages received. 1340 HandleJoinFunc func(ch string, info *ClientInfo) error 1341 // Leave must register callback func to handle Leave messages received. 1342 HandleLeaveFunc func(ch string, info *ClientInfo) error 1343 // Control must register callback func to handle Control data received. 
1344 HandleControlFunc func([]byte) error 1345} 1346 1347func (b *testBrokerEventHandler) HandlePublication(ch string, pub *Publication, sp StreamPosition) error { 1348 if b.HandlePublicationFunc != nil { 1349 return b.HandlePublicationFunc(ch, pub, sp) 1350 } 1351 return nil 1352} 1353 1354func (b *testBrokerEventHandler) HandleJoin(ch string, info *ClientInfo) error { 1355 if b.HandleJoinFunc != nil { 1356 return b.HandleJoinFunc(ch, info) 1357 } 1358 return nil 1359} 1360 1361func (b *testBrokerEventHandler) HandleLeave(ch string, info *ClientInfo) error { 1362 if b.HandleLeaveFunc != nil { 1363 return b.HandleLeaveFunc(ch, info) 1364 } 1365 return nil 1366} 1367 1368func (b *testBrokerEventHandler) HandleControl(data []byte) error { 1369 if b.HandleControlFunc != nil { 1370 return b.HandleControlFunc(data) 1371 } 1372 return nil 1373} 1374 1375type testClientMessage struct { 1376 Input string `json:"input"` 1377 Timestamp int64 `json:"timestamp"` 1378} 1379 1380func TestClientPublishHandler(t *testing.T) { 1381 node := defaultNodeNoHandlers() 1382 defer func() { _ = node.Shutdown(context.Background()) }() 1383 1384 node.OnConnect(func(client *Client) { 1385 client.OnSubscribe(func(event SubscribeEvent, cb SubscribeCallback) { 1386 cb(SubscribeReply{}, nil) 1387 }) 1388 }) 1389 1390 client := newTestClient(t, node, "42") 1391 connectClient(t, client) 1392 1393 node.broker.(*MemoryBroker).eventHandler = &testBrokerEventHandler{ 1394 HandlePublicationFunc: func(ch string, pub *Publication, sp StreamPosition) error { 1395 var msg testClientMessage 1396 err := json.Unmarshal(pub.Data, &msg) 1397 require.NoError(t, err) 1398 if msg.Input == "with timestamp" { 1399 require.True(t, msg.Timestamp > 0) 1400 } else { 1401 require.Zero(t, msg.Timestamp) 1402 } 1403 return nil 1404 }, 1405 } 1406 1407 subscribeClient(t, client, "test") 1408 1409 client.eventHub.publishHandler = func(e PublishEvent, cb PublishCallback) { 1410 var msg testClientMessage 1411 err := 
json.Unmarshal(e.Data, &msg) 1412 require.NoError(t, err) 1413 if msg.Input == "with disconnect" { 1414 cb(PublishReply{}, DisconnectBadRequest) 1415 return 1416 } 1417 if msg.Input == "with error" { 1418 cb(PublishReply{}, ErrorBadRequest) 1419 return 1420 } 1421 if msg.Input == "with timestamp" { 1422 msg.Timestamp = time.Now().Unix() 1423 data, _ := json.Marshal(msg) 1424 res, err := node.Publish(e.Channel, data) 1425 require.NoError(t, err) 1426 cb(PublishReply{ 1427 Result: &res, 1428 }, nil) 1429 return 1430 } 1431 cb(PublishReply{}, nil) 1432 } 1433 1434 rwWrapper := testReplyWriterWrapper() 1435 err := client.handlePublish(getJSONEncodedParams(t, &protocol.PublishRequest{ 1436 Channel: "test", 1437 Data: []byte(`{"input": "no time"}`), 1438 }), rwWrapper.rw) 1439 require.NoError(t, err) 1440 require.Nil(t, rwWrapper.replies[0].Error) 1441 1442 rwWrapper = testReplyWriterWrapper() 1443 err = client.handlePublish(getJSONEncodedParams(t, &protocol.PublishRequest{ 1444 Channel: "test", 1445 Data: []byte(`{"input": "with timestamp"}`), 1446 }), rwWrapper.rw) 1447 require.NoError(t, err) 1448 require.Nil(t, rwWrapper.replies[0].Error) 1449 1450 rwWrapper = testReplyWriterWrapper() 1451 err = client.handlePublish(getJSONEncodedParams(t, &protocol.PublishRequest{ 1452 Channel: "test", 1453 Data: []byte(`{"input": "with error"}`), 1454 }), rwWrapper.rw) 1455 require.NoError(t, err) 1456 require.Equal(t, ErrorBadRequest.toProto(), rwWrapper.replies[0].Error) 1457 1458 rwWrapper = testReplyWriterWrapper() 1459 err = client.handlePublish(getJSONEncodedParams(t, &protocol.PublishRequest{ 1460 Channel: "test", 1461 Data: []byte(`{"input": "with disconnect"}`), 1462 }), rwWrapper.rw) 1463 require.NoError(t, err) 1464 select { 1465 case <-client.Context().Done(): 1466 case <-time.After(time.Second): 1467 require.Fail(t, "client not closed") 1468 } 1469} 1470 1471func TestClientPublishError(t *testing.T) { 1472 broker := NewTestBroker() 1473 broker.errorOnPublish = true 
1474 node := nodeWithBroker(broker) 1475 defer func() { _ = node.Shutdown(context.Background()) }() 1476 1477 node.OnConnect(func(client *Client) { 1478 client.OnPublish(func(event PublishEvent, cb PublishCallback) { 1479 require.Equal(t, "test", event.Channel) 1480 require.NotNil(t, event.ClientInfo) 1481 cb(PublishReply{}, nil) 1482 }) 1483 }) 1484 1485 client := newTestClient(t, node, "42") 1486 connectClient(t, client) 1487 1488 rwWrapper := testReplyWriterWrapper() 1489 err := client.handlePublish(getJSONEncodedParams(t, &protocol.PublishRequest{ 1490 Channel: "test", 1491 Data: []byte(`{"input": "no time"}`), 1492 }), rwWrapper.rw) 1493 require.NoError(t, err) 1494 require.Equal(t, ErrorInternal.toProto(), rwWrapper.replies[0].Error) 1495} 1496 1497func TestClientPing(t *testing.T) { 1498 node := defaultTestNode() 1499 defer func() { _ = node.Shutdown(context.Background()) }() 1500 client := newTestClient(t, node, "42") 1501 1502 connectClient(t, client) 1503 1504 rwWrapper := testReplyWriterWrapper() 1505 err := client.handlePing(getJSONEncodedParams(t, &protocol.PingRequest{}), rwWrapper.rw) 1506 require.NoError(t, err) 1507 require.Nil(t, rwWrapper.replies[0].Error) 1508 require.Empty(t, rwWrapper.replies[0].Result) 1509} 1510 1511func TestClientPresence(t *testing.T) { 1512 node := defaultNodeNoHandlers() 1513 defer func() { _ = node.Shutdown(context.Background()) }() 1514 1515 client := newTestClient(t, node, "42") 1516 1517 client.OnSubscribe(func(event SubscribeEvent, cb SubscribeCallback) { 1518 cb(SubscribeReply{Options: SubscribeOptions{Presence: true}}, nil) 1519 }) 1520 1521 client.OnPresence(func(e PresenceEvent, cb PresenceCallback) { 1522 cb(PresenceReply{}, nil) 1523 }) 1524 client.OnPresenceStats(func(e PresenceStatsEvent, cb PresenceStatsCallback) { 1525 cb(PresenceStatsReply{}, nil) 1526 }) 1527 1528 connectClient(t, client) 1529 subscribeClient(t, client, "test") 1530 1531 rwWrapper := testReplyWriterWrapper() 1532 err := 
client.handlePresence(getJSONEncodedParams(t, &protocol.PresenceRequest{ 1533 Channel: "", 1534 }), rwWrapper.rw) 1535 require.Equal(t, DisconnectBadRequest, err) 1536 1537 rwWrapper = testReplyWriterWrapper() 1538 err = client.handlePresence(getJSONEncodedParams(t, &protocol.PresenceRequest{ 1539 Channel: "test", 1540 }), rwWrapper.rw) 1541 require.NoError(t, err) 1542 require.Len(t, rwWrapper.replies, 1) 1543 require.Nil(t, rwWrapper.replies[0].Error) 1544 var result protocol.PresenceResult 1545 err = json.Unmarshal(rwWrapper.replies[0].Result, &result) 1546 require.NoError(t, err) 1547 require.Equal(t, 1, len(result.Presence)) 1548 1549 rwWrapper = testReplyWriterWrapper() 1550 err = client.handlePresenceStats(getJSONEncodedParams(t, &protocol.PresenceStatsRequest{ 1551 Channel: "", 1552 }), rwWrapper.rw) 1553 require.Equal(t, DisconnectBadRequest, err) 1554 1555 rwWrapper = testReplyWriterWrapper() 1556 err = client.handlePresenceStats(getJSONEncodedParams(t, &protocol.PresenceStatsRequest{ 1557 Channel: "test", 1558 }), rwWrapper.rw) 1559 require.NoError(t, err) 1560 require.Len(t, rwWrapper.replies, 1) 1561 require.Nil(t, rwWrapper.replies[0].Error) 1562} 1563 1564func TestClientPresenceTakeover(t *testing.T) { 1565 node := defaultNodeNoHandlers() 1566 defer func() { _ = node.Shutdown(context.Background()) }() 1567 1568 client := newTestClient(t, node, "42") 1569 1570 client.OnSubscribe(func(event SubscribeEvent, cb SubscribeCallback) { 1571 cb(SubscribeReply{Options: SubscribeOptions{Presence: true}}, nil) 1572 }) 1573 1574 client.OnPresence(func(e PresenceEvent, cb PresenceCallback) { 1575 res, err := node.Presence(e.Channel) 1576 require.NoError(t, err) 1577 cb(PresenceReply{ 1578 Result: &res, 1579 }, nil) 1580 }) 1581 client.OnPresenceStats(func(e PresenceStatsEvent, cb PresenceStatsCallback) { 1582 res, err := node.PresenceStats(e.Channel) 1583 require.NoError(t, err) 1584 cb(PresenceStatsReply{ 1585 Result: &res, 1586 }, nil) 1587 }) 1588 1589 
connectClient(t, client) 1590 subscribeClient(t, client, "test") 1591 1592 rwWrapper := testReplyWriterWrapper() 1593 err := client.handlePresence(getJSONEncodedParams(t, &protocol.PresenceRequest{ 1594 Channel: "test", 1595 }), rwWrapper.rw) 1596 require.NoError(t, err) 1597 require.Len(t, rwWrapper.replies, 1) 1598 require.Nil(t, rwWrapper.replies[0].Error) 1599 var result protocol.PresenceResult 1600 err = json.Unmarshal(rwWrapper.replies[0].Result, &result) 1601 require.NoError(t, err) 1602 require.Equal(t, 1, len(result.Presence)) 1603 1604 rwWrapper = testReplyWriterWrapper() 1605 err = client.handlePresenceStats(getJSONEncodedParams(t, &protocol.PresenceStatsRequest{ 1606 Channel: "test", 1607 }), rwWrapper.rw) 1608 require.NoError(t, err) 1609 require.Len(t, rwWrapper.replies, 1) 1610 require.Nil(t, rwWrapper.replies[0].Error) 1611} 1612 1613func TestClientPresenceError(t *testing.T) { 1614 presenceManager := NewTestPresenceManager() 1615 presenceManager.errorOnPresence = true 1616 node := nodeWithPresenceManager(presenceManager) 1617 defer func() { _ = node.Shutdown(context.Background()) }() 1618 1619 node.OnConnect(func(client *Client) { 1620 client.OnPresence(func(event PresenceEvent, cb PresenceCallback) { 1621 require.Equal(t, "test", event.Channel) 1622 cb(PresenceReply{}, nil) 1623 }) 1624 }) 1625 1626 client := newTestClient(t, node, "42") 1627 connectClient(t, client) 1628 1629 rwWrapper := testReplyWriterWrapper() 1630 err := client.handlePresence(getJSONEncodedParams(t, &protocol.PresenceRequest{ 1631 Channel: "test", 1632 }), rwWrapper.rw) 1633 require.NoError(t, err) 1634 require.Equal(t, ErrorInternal.toProto(), rwWrapper.replies[0].Error) 1635} 1636 1637func TestClientPresenceNotAvailable(t *testing.T) { 1638 node := defaultTestNode() 1639 defer func() { _ = node.Shutdown(context.Background()) }() 1640 1641 client := newTestClient(t, node, "42") 1642 1643 connectClient(t, client) 1644 subscribeClient(t, client, "test") 1645 1646 rwWrapper := 
testReplyWriterWrapper() 1647 err := client.handlePresence(getJSONEncodedParams(t, &protocol.PresenceRequest{ 1648 Channel: "test", 1649 }), rwWrapper.rw) 1650 require.Equal(t, ErrorNotAvailable, err) 1651} 1652 1653func TestClientSubscribeNotAvailable(t *testing.T) { 1654 node := defaultNodeNoHandlers() 1655 defer func() { _ = node.Shutdown(context.Background()) }() 1656 1657 client := newTestClient(t, node, "42") 1658 1659 connectClient(t, client) 1660 1661 rwWrapper := testReplyWriterWrapper() 1662 err := client.handleSubscribe(getJSONEncodedParams(t, &protocol.SubscribeRequest{ 1663 Channel: "test", 1664 }), rwWrapper.rw) 1665 require.Equal(t, ErrorNotAvailable, err) 1666} 1667 1668func TestClientPresenceStatsNotAvailable(t *testing.T) { 1669 node := defaultTestNode() 1670 defer func() { _ = node.Shutdown(context.Background()) }() 1671 1672 client := newTestClient(t, node, "42") 1673 1674 connectClient(t, client) 1675 subscribeClient(t, client, "test") 1676 1677 rwWrapper := testReplyWriterWrapper() 1678 err := client.handlePresenceStats(getJSONEncodedParams(t, &protocol.PresenceStatsRequest{ 1679 Channel: "test", 1680 }), rwWrapper.rw) 1681 require.Equal(t, ErrorNotAvailable, err) 1682} 1683 1684func TestClientPresenceStatsError(t *testing.T) { 1685 presenceManager := NewTestPresenceManager() 1686 presenceManager.errorOnPresenceStats = true 1687 node := nodeWithPresenceManager(presenceManager) 1688 defer func() { _ = node.Shutdown(context.Background()) }() 1689 1690 node.OnConnect(func(client *Client) { 1691 client.OnPresenceStats(func(event PresenceStatsEvent, cb PresenceStatsCallback) { 1692 require.Equal(t, "test", event.Channel) 1693 cb(PresenceStatsReply{}, nil) 1694 }) 1695 }) 1696 1697 client := newTestClient(t, node, "42") 1698 connectClient(t, client) 1699 1700 rwWrapper := testReplyWriterWrapper() 1701 err := client.handlePresenceStats(getJSONEncodedParams(t, &protocol.PresenceStatsRequest{ 1702 Channel: "test", 1703 }), rwWrapper.rw) 1704 
require.NoError(t, err) 1705 require.Equal(t, ErrorInternal.toProto(), rwWrapper.replies[0].Error) 1706} 1707 1708func TestClientHistoryNoFilter(t *testing.T) { 1709 node := defaultTestNode() 1710 defer func() { _ = node.Shutdown(context.Background()) }() 1711 1712 client := newTestClient(t, node, "42") 1713 1714 client.OnHistory(func(e HistoryEvent, cb HistoryCallback) { 1715 require.Nil(t, e.Filter.Since) 1716 require.Equal(t, 0, e.Filter.Limit) 1717 cb(HistoryReply{}, nil) 1718 }) 1719 1720 for i := 0; i < 10; i++ { 1721 _, _ = node.Publish("test", []byte(`{}`), WithHistory(10, time.Minute)) 1722 } 1723 1724 connectClient(t, client) 1725 subscribeClient(t, client, "test") 1726 1727 rwWrapper := testReplyWriterWrapper() 1728 err := client.handleHistory(getJSONEncodedParams(t, &protocol.HistoryRequest{ 1729 Channel: "", 1730 }), rwWrapper.rw) 1731 require.Equal(t, DisconnectBadRequest, err) 1732 1733 rwWrapper = testReplyWriterWrapper() 1734 err = client.handleHistory(getJSONEncodedParams(t, &protocol.HistoryRequest{ 1735 Channel: "test", 1736 }), rwWrapper.rw) 1737 require.NoError(t, err) 1738 require.Len(t, rwWrapper.replies, 1) 1739 require.Nil(t, rwWrapper.replies[0].Error) 1740 var result protocol.HistoryResult 1741 err = json.Unmarshal(rwWrapper.replies[0].Result, &result) 1742 require.NoError(t, err) 1743 require.Equal(t, 0, len(result.Publications)) 1744 require.Equal(t, uint64(10), result.Offset) 1745 require.NotZero(t, result.Epoch) 1746} 1747 1748func TestClientHistoryWithLimit(t *testing.T) { 1749 node := defaultTestNode() 1750 defer func() { _ = node.Shutdown(context.Background()) }() 1751 1752 client := newTestClient(t, node, "42") 1753 1754 client.OnHistory(func(e HistoryEvent, cb HistoryCallback) { 1755 require.Nil(t, e.Filter.Since) 1756 require.Equal(t, 3, e.Filter.Limit) 1757 cb(HistoryReply{}, nil) 1758 }) 1759 1760 for i := 0; i < 10; i++ { 1761 _, _ = node.Publish("test", []byte(`{}`), WithHistory(10, time.Minute)) 1762 } 1763 1764 
connectClient(t, client) 1765 subscribeClient(t, client, "test") 1766 1767 rwWrapper := testReplyWriterWrapper() 1768 err := client.handleHistory(getJSONEncodedParams(t, &protocol.HistoryRequest{ 1769 Channel: "test", 1770 Limit: 3, 1771 }), rwWrapper.rw) 1772 require.NoError(t, err) 1773 require.Len(t, rwWrapper.replies, 1) 1774 require.Nil(t, rwWrapper.replies[0].Error) 1775 var result protocol.HistoryResult 1776 err = json.Unmarshal(rwWrapper.replies[0].Result, &result) 1777 require.NoError(t, err) 1778 require.Equal(t, 3, len(result.Publications)) 1779 require.Equal(t, uint64(10), result.Offset) 1780 require.NotZero(t, result.Epoch) 1781} 1782 1783func TestClientHistoryWithSinceAndLimit(t *testing.T) { 1784 node := defaultTestNode() 1785 defer func() { _ = node.Shutdown(context.Background()) }() 1786 1787 client := newTestClient(t, node, "42") 1788 1789 client.OnHistory(func(e HistoryEvent, cb HistoryCallback) { 1790 require.NotNil(t, e.Filter.Since) 1791 require.Equal(t, 2, e.Filter.Limit) 1792 cb(HistoryReply{}, nil) 1793 }) 1794 1795 var pubRes PublishResult 1796 for i := 0; i < 10; i++ { 1797 pubRes, _ = node.Publish("test", []byte(`{}`), WithHistory(10, time.Minute)) 1798 } 1799 1800 connectClient(t, client) 1801 subscribeClient(t, client, "test") 1802 1803 rwWrapper := testReplyWriterWrapper() 1804 err := client.handleHistory(getJSONEncodedParams(t, &protocol.HistoryRequest{ 1805 Channel: "test", 1806 Limit: 2, 1807 Since: &protocol.StreamPosition{ 1808 Offset: 2, 1809 Epoch: pubRes.Epoch, 1810 }, 1811 }), rwWrapper.rw) 1812 require.NoError(t, err) 1813 require.Len(t, rwWrapper.replies, 1) 1814 require.Nil(t, rwWrapper.replies[0].Error) 1815 var result protocol.HistoryResult 1816 err = json.Unmarshal(rwWrapper.replies[0].Result, &result) 1817 require.NoError(t, err) 1818 require.Equal(t, 2, len(result.Publications)) 1819 require.Equal(t, uint64(4), result.Publications[1].Offset) 1820 require.Equal(t, uint64(10), result.Offset) 1821 require.NotZero(t, 
result.Epoch) 1822} 1823 1824func TestClientHistoryTakeover(t *testing.T) { 1825 node := defaultTestNode() 1826 node.config.HistoryMaxPublicationLimit = 2 1827 defer func() { _ = node.Shutdown(context.Background()) }() 1828 1829 client := newTestClient(t, node, "42") 1830 1831 client.OnHistory(func(e HistoryEvent, cb HistoryCallback) { 1832 require.Nil(t, e.Filter.Since) 1833 require.Equal(t, 2, e.Filter.Limit) 1834 // Change limit here, so 3 publications returned. 1835 res, err := node.History(e.Channel, WithLimit(e.Filter.Limit+1), WithSince(e.Filter.Since)) 1836 require.NoError(t, err) 1837 cb(HistoryReply{ 1838 Result: &res, 1839 }, nil) 1840 }) 1841 1842 for i := 0; i < 10; i++ { 1843 _, _ = node.Publish("test", []byte(`{}`), WithHistory(10, time.Minute)) 1844 } 1845 1846 connectClient(t, client) 1847 subscribeClient(t, client, "test") 1848 1849 rwWrapper := testReplyWriterWrapper() 1850 err := client.handleHistory(getJSONEncodedParams(t, &protocol.HistoryRequest{ 1851 Channel: "test", 1852 Limit: 3, 1853 }), rwWrapper.rw) 1854 require.NoError(t, err) 1855 require.Len(t, rwWrapper.replies, 1) 1856 require.Nil(t, rwWrapper.replies[0].Error) 1857 var result protocol.HistoryResult 1858 err = json.Unmarshal(rwWrapper.replies[0].Result, &result) 1859 require.NoError(t, err) 1860 require.Equal(t, 3, len(result.Publications)) 1861 require.Equal(t, uint64(10), result.Offset) 1862 require.NotZero(t, result.Epoch) 1863} 1864 1865func TestClientHistoryUnrecoverablePositionEpoch(t *testing.T) { 1866 node := defaultTestNode() 1867 defer func() { _ = node.Shutdown(context.Background()) }() 1868 1869 client := newTestClient(t, node, "42") 1870 1871 client.OnHistory(func(e HistoryEvent, cb HistoryCallback) { 1872 require.NotNil(t, e.Filter.Since) 1873 require.Equal(t, 2, e.Filter.Limit) 1874 result, err := node.History(e.Channel, WithLimit(e.Filter.Limit), WithSince(e.Filter.Since), WithReverse(e.Filter.Reverse)) 1875 if err != nil { 1876 cb(HistoryReply{}, err) 1877 return 
1878 } 1879 cb(HistoryReply{Result: &result}, nil) 1880 }) 1881 1882 for i := 0; i < 10; i++ { 1883 _, _ = node.Publish("test", []byte(`{}`), WithHistory(10, time.Minute)) 1884 } 1885 1886 connectClient(t, client) 1887 subscribeClient(t, client, "test") 1888 1889 rwWrapper := testReplyWriterWrapper() 1890 err := client.handleHistory(getJSONEncodedParams(t, &protocol.HistoryRequest{ 1891 Channel: "test", 1892 Limit: 2, 1893 Since: &protocol.StreamPosition{ 1894 Offset: 2, 1895 Epoch: "wrong_one", 1896 }, 1897 }), rwWrapper.rw) 1898 require.NoError(t, err) 1899 require.Equal(t, ErrorUnrecoverablePosition.toProto(), rwWrapper.replies[0].Error) 1900} 1901 1902func TestClientHistoryBrokerError(t *testing.T) { 1903 broker := NewTestBroker() 1904 broker.errorOnHistory = true 1905 node := nodeWithBroker(broker) 1906 defer func() { _ = node.Shutdown(context.Background()) }() 1907 1908 node.OnConnect(func(client *Client) { 1909 client.OnHistory(func(event HistoryEvent, cb HistoryCallback) { 1910 require.Equal(t, "test", event.Channel) 1911 cb(HistoryReply{}, nil) 1912 }) 1913 }) 1914 1915 client := newTestClient(t, node, "42") 1916 connectClient(t, client) 1917 1918 rwWrapper := testReplyWriterWrapper() 1919 err := client.handleHistory(getJSONEncodedParams(t, &protocol.HistoryRequest{ 1920 Channel: "test", 1921 }), rwWrapper.rw) 1922 require.NoError(t, err) 1923 require.Equal(t, ErrorInternal.toProto(), rwWrapper.replies[0].Error) 1924} 1925 1926func TestClientHistoryNotAvailable(t *testing.T) { 1927 node := defaultTestNode() 1928 defer func() { _ = node.Shutdown(context.Background()) }() 1929 1930 client := newTestClient(t, node, "42") 1931 1932 connectClient(t, client) 1933 subscribeClient(t, client, "test") 1934 1935 rwWrapper := testReplyWriterWrapper() 1936 err := client.handleHistory(getJSONEncodedParams(t, &protocol.HistoryRequest{ 1937 Channel: "test", 1938 }), rwWrapper.rw) 1939 require.Equal(t, ErrorNotAvailable, err) 1940} 1941 1942func 
TestClientCloseUnauthenticated(t *testing.T) { 1943 node := defaultTestNode() 1944 defer func() { _ = node.Shutdown(context.Background()) }() 1945 1946 node.config.ClientStaleCloseDelay = time.Millisecond 1947 1948 client := newTestClient(t, node, "42") 1949 select { 1950 case <-client.Context().Done(): 1951 case <-time.After(time.Second): 1952 require.Fail(t, "client not closed") 1953 } 1954 client.mu.Lock() 1955 require.True(t, client.status == statusClosed) 1956 client.mu.Unlock() 1957} 1958 1959func TestClientHandleUnidirectional(t *testing.T) { 1960 node := defaultTestNode() 1961 defer func() { _ = node.Shutdown(context.Background()) }() 1962 1963 client := newTestClient(t, node, "42") 1964 client.transport.(*testTransport).unidirectional = true 1965 proceed := client.Handle([]byte("test")) 1966 require.False(t, proceed) 1967 select { 1968 case <-client.Context().Done(): 1969 case <-time.After(time.Second): 1970 require.Fail(t, "client not closed") 1971 } 1972} 1973 1974func TestExtractUnidirectionalDisconnect(t *testing.T) { 1975 d := extractUnidirectionalDisconnect(nil) 1976 require.Nil(t, d) 1977 d = extractUnidirectionalDisconnect(errors.New("test")) 1978 require.Equal(t, DisconnectServerError, d) 1979 d = extractUnidirectionalDisconnect(ErrorLimitExceeded) 1980 require.Equal(t, DisconnectServerError, d) 1981 d = extractUnidirectionalDisconnect(DisconnectChannelLimit) 1982 require.Equal(t, DisconnectChannelLimit, d) 1983 d = extractUnidirectionalDisconnect(DisconnectServerError) 1984 require.Equal(t, DisconnectServerError, d) 1985 d = extractUnidirectionalDisconnect(ErrorExpired) 1986 require.Equal(t, DisconnectExpired, d) 1987} 1988 1989func TestClientHandleEmptyData(t *testing.T) { 1990 node := defaultTestNode() 1991 defer func() { _ = node.Shutdown(context.Background()) }() 1992 1993 client := newTestClient(t, node, "42") 1994 proceed := client.Handle(nil) 1995 require.False(t, proceed) 1996 select { 1997 case <-client.Context().Done(): 1998 case 
<-time.After(time.Second):
		require.Fail(t, "client not closed")
	}
	proceed = client.Handle([]byte("test"))
	require.False(t, proceed)
	disconnect := client.dispatchCommand(&protocol.Command{})
	require.Nil(t, disconnect)
}

// TestClientHandleBrokenData checks that undecodable input makes Handle
// return false and that the client gets closed.
func TestClientHandleBrokenData(t *testing.T) {
	node := defaultTestNode()
	defer func() { _ = node.Shutdown(context.Background()) }()

	client := newTestClient(t, node, "42")
	proceed := client.Handle([]byte(`nd3487yt734y38&**&**`))
	require.False(t, proceed)
	select {
	case <-client.Context().Done():
	case <-time.After(time.Second):
		require.Fail(t, "client not closed")
	}
}

// TestClientHandleCommandNotAuthenticated checks that a SUBSCRIBE command sent
// before CONNECT makes Handle return false and closes the client.
func TestClientHandleCommandNotAuthenticated(t *testing.T) {
	node := defaultTestNode()
	defer func() { _ = node.Shutdown(context.Background()) }()

	client := newTestClient(t, node, "42")
	params := getJSONEncodedParams(t, &protocol.SubscribeRequest{
		Channel: "test",
	})
	cmd := &protocol.Command{Id: 1, Method: protocol.Command_SUBSCRIBE, Params: params}
	data, err := json.Marshal(cmd)
	require.NoError(t, err)
	proceed := client.Handle(data)
	require.False(t, proceed)
	select {
	case <-client.Context().Done():
	case <-time.After(time.Second):
		require.Fail(t, "client not closed")
	}
}

// TestClientHandleUnknownMethod checks that dispatching a command with an
// unknown method value returns a nil disconnect.
func TestClientHandleUnknownMethod(t *testing.T) {
	node := defaultTestNode()
	defer func() { _ = node.Shutdown(context.Background()) }()

	client := newTestClient(t, node, "42")
	params := getJSONEncodedParams(t, &protocol.SubscribeRequest{
		Channel: "test",
	})
	cmd := &protocol.Command{Id: 1, Method: 10000, Params: params}
	disconnect := client.dispatchCommand(cmd)
	require.Nil(t, disconnect)
}

// TestClientHandleCommandWithBrokenParams checks that malformed params (`[]`)
// close the client. For CONNECT neither the connect nor the disconnect handler
// must fire. For every other method the client must be disconnected with
// DisconnectBadRequest.
func TestClientHandleCommandWithBrokenParams(t *testing.T) {
	node := defaultTestNode()
	defer func() { _ = node.Shutdown(context.Background()) }()

	// Methods checked after the CONNECT case; each must trigger one disconnect.
	methods := []protocol.Command_MethodType{
		protocol.Command_SUBSCRIBE,
		protocol.Command_PING,
		protocol.Command_PUBLISH,
		protocol.Command_UNSUBSCRIBE,
		protocol.Command_PRESENCE,
		protocol.Command_PRESENCE_STATS,
		protocol.Command_HISTORY,
		protocol.Command_REFRESH,
		protocol.Command_RPC,
		protocol.Command_SEND,
		protocol.Command_SUB_REFRESH,
	}

	var counterMu sync.Mutex
	var numDisconnectCalls int
	var numConnectCalls int
	var wg sync.WaitGroup
	// One disconnect per method; derived from the list so the two stay in sync.
	wg.Add(len(methods))

	node.OnConnect(func(client *Client) {
		counterMu.Lock()
		numConnectCalls++
		counterMu.Unlock()

		client.OnSubscribe(func(e SubscribeEvent, cb SubscribeCallback) {
			cb(SubscribeReply{}, nil)
		})

		client.OnRPC(func(e RPCEvent, cb RPCCallback) {
			cb(RPCReply{Data: []byte(`{"year": "2020"}`)}, nil)
		})

		client.OnMessage(func(event MessageEvent) {})

		client.OnHistory(func(e HistoryEvent, cb HistoryCallback) {
			cb(HistoryReply{}, nil)
		})

		client.OnPresence(func(e PresenceEvent, cb PresenceCallback) {
			cb(PresenceReply{}, nil)
		})

		client.OnPresenceStats(func(e PresenceStatsEvent, cb PresenceStatsCallback) {
			cb(PresenceStatsReply{}, nil)
		})

		client.OnRefresh(func(e RefreshEvent, cb RefreshCallback) {
			cb(RefreshReply{}, nil)
		})

		client.OnSubRefresh(func(e SubRefreshEvent, cb SubRefreshCallback) {
			cb(SubRefreshReply{}, nil)
		})

		client.OnPublish(func(e PublishEvent, cb PublishCallback) {
			cb(PublishReply{}, nil)
		})

		client.OnDisconnect(func(event DisconnectEvent) {
			counterMu.Lock()
			numDisconnectCalls++
			counterMu.Unlock()
			require.Equal(t, DisconnectBadRequest, event.Disconnect)
			wg.Done()
		})
	})

	client := newTestClient(t, node, "42")

	data, err := json.Marshal(&protocol.Command{
		Id: 1, Method: protocol.Command_CONNECT, Params: []byte("[]"),
	})
	require.NoError(t, err)
	proceed := client.Handle(data)
	require.False(t, proceed)
	select {
	case <-client.Context().Done():
	case <-time.After(time.Second):
		require.Fail(t, "client not closed")
	}
	// Check no connect and no disconnect event called.
	counterMu.Lock()
	require.Equal(t, 0, numDisconnectCalls)
	require.Equal(t, 0, numConnectCalls)
	counterMu.Unlock()

	// Now check other methods.
	for _, method := range methods {
		client = newTestClient(t, node, "42")
		connectClient(t, client)
		data, err := json.Marshal(&protocol.Command{
			Id: 1, Method: method, Params: []byte("[]"),
		})
		require.NoError(t, err)
		proceed := client.Handle(data)
		require.False(t, proceed)
		select {
		case <-client.Context().Done():
		case <-time.After(time.Second):
			require.Fail(t, "client not closed")
		}
	}

	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()

	select {
	case <-done:
	case <-time.After(5 * time.Second):
		require.Fail(t, "timeout waiting wait group done")
	}
}

// TestClientOnAlive checks that the OnAlive handler fires for a connected
// client when ClientPresenceUpdateInterval is configured.
func TestClientOnAlive(t *testing.T) {
	node := defaultTestNode()
	node.config.ClientPresenceUpdateInterval = time.Second
	defer func() { _ = node.Shutdown(context.Background()) }()

	done := make(chan struct{})
	var closeOnce sync.Once

	node.OnConnect(func(client *Client) {
		client.OnAlive(func() {
			// OnAlive may fire repeatedly; close done only once.
			closeOnce.Do(func() {
				close(done)
			})
		})
	})

	client := newTestClient(t, node, "42")
	connectClient(t, client)

	select {
	case <-done:
	case <-time.After(5 * time.Second):
		require.Fail(t, "timeout waiting done close")
	}
}

// TestClientHandleCommandWithoutID checks that a CONNECT command with no Id
// makes Handle return false and closes the client.
func TestClientHandleCommandWithoutID(t *testing.T) {
	node := defaultTestNode()
	defer func() { _ = node.Shutdown(context.Background()) }()

	client := newTestClient(t, node, "42")
	params := getJSONEncodedParams(t, &protocol.ConnectRequest{})
	cmd := &protocol.Command{Method: protocol.Command_CONNECT, Params: params}
	data, err := json.Marshal(cmd)
	require.NoError(t, err)
	proceed := client.Handle(data)
	require.False(t, proceed)
	select {
	case <-client.Context().Done():
	case <-time.After(time.Second):
		require.Fail(t, "client not closed")
	}
}

// TestErrorDisconnectContext checks that errorDisconnectContext stores the
// given error and disconnect as-is.
func TestErrorDisconnectContext(t *testing.T) {
	ctx := errorDisconnectContext(nil, DisconnectForceReconnect)
	require.Nil(t, ctx.err)
	require.Equal(t, DisconnectForceReconnect, ctx.disconnect)
	ctx = errorDisconnectContext(ErrorLimitExceeded, nil)
	require.Nil(t, ctx.disconnect)
	require.Equal(t, ErrorLimitExceeded, ctx.err)
}

// TestToClientError checks that arbitrary errors map to ErrorInternal while
// client errors pass through unchanged.
func TestToClientError(t *testing.T) {
	require.Equal(t, ErrorInternal, toClientErr(errors.New("boom")))
	require.Equal(t, ErrorLimitExceeded, toClientErr(ErrorLimitExceeded))
}

// TestClientAlreadyAuthenticated checks that a second CONNECT on an already
// connected client makes Handle return false and closes the client.
func TestClientAlreadyAuthenticated(t *testing.T) {
	node := defaultTestNode()
	defer func() { _ = node.Shutdown(context.Background()) }()

	client := newTestClient(t, node, "42")
	connectClient(t, client)

	params := getJSONEncodedParams(t, &protocol.ConnectRequest{})
	cmd := &protocol.Command{Id: 2, Method: protocol.Command_CONNECT, Params: params}
	data, err := json.Marshal(cmd)
	require.NoError(t, err)
	proceed := client.Handle(data)
	require.False(t, proceed)
	select {
	case <-client.Context().Done():
	case <-time.After(time.Second):
		require.Fail(t, "client not closed")
	}
}

// TestClientCloseExpired checks that a client whose credentials expire
// (ExpireAt two seconds ahead) ends up in statusClosed once the on-connect
// timers are scheduled.
func TestClientCloseExpired(t *testing.T) {
	node := defaultTestNode()
	defer func() { _ = node.Shutdown(context.Background()) }()

	ctx, cancelFn := context.WithCancel(context.Background())
	transport := newTestTransport(cancelFn)
	newCtx := SetCredentials(ctx, &Credentials{UserID: "42", ExpireAt: time.Now().Unix() + 2})
	client, _ := newClient(newCtx, node, transport)
	connectClient(t, client)
	client.scheduleOnConnectTimers()
	client.mu.RLock()
	require.False(t, client.status == statusClosed)
	client.mu.RUnlock()
	select {
	case <-client.Context().Done():
	case <-time.After(5 * time.Second):
		require.Fail(t, "client not closed")
	}
	client.mu.RLock()
	defer client.mu.RUnlock()
	require.True(t, client.status == statusClosed)
}

// TestClientInfo checks that Info returns the info passed in the credentials.
func TestClientInfo(t *testing.T) {
	node := defaultTestNode()
	defer func() { _ = node.Shutdown(context.Background()) }()

	ctx, cancelFn := context.WithCancel(context.Background())
	transport := newTestTransport(cancelFn)
	newCtx := SetCredentials(ctx, &Credentials{UserID: "42", Info: []byte("info")})
	client, _ := newClient(newCtx, node, transport)
	connectClient(t, client)
	require.Equal(t, []byte("info"), client.Info())
}

// TestClientConnectExpiredError checks that connecting with already expired
// credentials returns ErrorExpired and leaves the client unauthenticated.
func TestClientConnectExpiredError(t *testing.T) {
	node := defaultTestNode()
	defer func() { _ = node.Shutdown(context.Background()) }()

	transport := newTestTransport(func() {})
	ctx := context.Background()
	newCtx := SetCredentials(ctx, &Credentials{UserID: "42", ExpireAt: time.Now().Unix() - 2})
	client, _ := newClient(newCtx, node, transport)

	rwWrapper := testReplyWriterWrapper()
	_, err := client.connectCmd(&protocol.ConnectRequest{}, rwWrapper.rw)
	require.Equal(t, ErrorExpired, err)
	require.False(t, client.authenticated)
}

// TestClientPresenceUpdate checks that updateChannelPresence succeeds for a
// channel subscribed with presence enabled.
func TestClientPresenceUpdate(t *testing.T) {
	node := defaultNodeNoHandlers()
	defer func() { _ = node.Shutdown(context.Background()) }()

	node.OnConnect(func(client *Client) {
		client.OnSubscribe(func(event SubscribeEvent, cb SubscribeCallback) {
			cb(SubscribeReply{
				Options: SubscribeOptions{Presence: true},
			}, nil)
		})
	})

	client := newTestClient(t, node, "42")

	connectClient(t, client)
	subscribeClient(t, client, "test")

	client.mu.RLock()
	chCtx, ok := client.channels["test"]
	client.mu.RUnlock()
	require.True(t, ok)

	err := client.updateChannelPresence("test", chCtx)
	require.NoError(t, err)
}

// TestClientSubExpired checks that a subscription with a short ExpireAt
// eventually disconnects the client with DisconnectSubExpired.
func TestClientSubExpired(t *testing.T) {
	node := defaultNodeNoHandlers()
	defer func() { _ = node.Shutdown(context.Background()) }()

	node.config.ClientExpiredSubCloseDelay = 0
	node.config.ClientPresenceUpdateInterval = 10 * time.Millisecond

	doneCh := make(chan struct{})

	node.OnConnect(func(client *Client) {
		client.OnSubscribe(func(event SubscribeEvent, cb SubscribeCallback) {
			cb(SubscribeReply{
				Options: SubscribeOptions{
					ExpireAt: time.Now().Unix() + 1,
					Presence: true,
				},
			}, nil)
		})

		client.OnDisconnect(func(event DisconnectEvent) {
			if event.Disconnect == DisconnectSubExpired {
				close(doneCh)
			}
		})
	})

	client := newTestClient(t, node, "42")
	connectClient(t, client)
	subscribeClient(t, client, "test")

	select {
	case <-doneCh:
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for disconnect due to expired subscription")
	}
}

// TestClientSend checks that Send works on a connected client and returns
// io.EOF after the client is closed.
func TestClientSend(t *testing.T) {
	node := defaultTestNode()
	defer func() { _ = node.Shutdown(context.Background()) }()

	client := newTestClient(t, node, "42")

	connectClient(t, client)

	err := client.Send([]byte(`{}`))
	require.NoError(t, err)

	err = client.close(nil)
	require.NoError(t, err)

	err = client.Send([]byte(`{}`))
	require.Error(t, err)
	require.Equal(t, io.EOF, err)
}

// TestClientClose checks that close propagates the disconnect to the transport.
func TestClientClose(t *testing.T) {
	node := defaultTestNode()
	defer func() { _ = node.Shutdown(context.Background()) }()

	client := newTestClient(t, node, "42")
	connectClient(t, client)

	err := client.close(DisconnectShutdown)
	require.NoError(t, err)
	require.True(t, client.transport.(*testTransport).closed)
	require.Equal(t, DisconnectShutdown, client.transport.(*testTransport).disconnect)
}

// TestClientHandleRPCNotAvailable checks that an RPC without a registered
// handler returns ErrorNotAvailable.
func TestClientHandleRPCNotAvailable(t *testing.T) {
	node := defaultTestNode()
	defer func() { _ = node.Shutdown(context.Background()) }()

	client := newTestClient(t, node, "42")
	connectClient(t, client)

	rwWrapper := testReplyWriterWrapper()

	err := client.handleRPC(getJSONEncodedParams(t, &protocol.RPCRequest{
		Method: "xxx",
	}), rwWrapper.rw)
	require.Equal(t, ErrorNotAvailable, err)
}

// TestClientHandleRPC checks that the OnRPC handler receives the request data
// and that a successful reply carries no error.
func TestClientHandleRPC(t *testing.T) {
	node := defaultTestNode()
	defer func() { _ = node.Shutdown(context.Background()) }()

	client := newTestClient(t, node, "42")

	var rpcHandlerCalled bool

	node.OnConnect(func(client *Client) {
		client.OnRPC(func(event RPCEvent, cb RPCCallback) {
			rpcHandlerCalled = true
			expectedData := []byte("{}")
			require.Equal(t, expectedData, event.Data)
			cb(RPCReply{}, nil)
		})
	})

	connectClient(t, client)

	rwWrapper := testReplyWriterWrapper()

	err := client.handleRPC(getJSONEncodedParams(t, &protocol.RPCRequest{
		Data: []byte("{}"),
	}), rwWrapper.rw)
	require.NoError(t, err)
	require.Nil(t, rwWrapper.replies[0].Error)
	require.True(t, rpcHandlerCalled)
}

// TestClientHandleSendNoHandlerSet checks that SEND without an OnMessage
// handler does not return an error.
func TestClientHandleSendNoHandlerSet(t *testing.T) {
	node := defaultTestNode()
	defer func() { _ = node.Shutdown(context.Background()) }()
	client := newTestClient(t, node, "42")
	connectClient(t, client)

	rwWrapper := testReplyWriterWrapper()
	err := client.handleSend(getJSONEncodedParams(t, &protocol.SendRequest{
		Data: []byte(`{"data":"hello"}`),
	}), rwWrapper.rw)
	require.NoError(t, err)
}

// TestClientHandleSend checks that SEND delivers the payload to OnMessage.
func TestClientHandleSend(t *testing.T) {
	node := defaultTestNode()
	defer func() { _ = node.Shutdown(context.Background()) }()

	client := newTestClient(t, node, "42")

	var messageHandlerCalled bool

	client.OnMessage(func(event MessageEvent) {
		messageHandlerCalled = true
		expectedData := []byte(`{"data":"hello"}`)
		require.Equal(t, expectedData, event.Data)
	})
	connectClient(t, client)

	rwWrapper := testReplyWriterWrapper()

	err := client.handleSend(getJSONEncodedParams(t, &protocol.SendRequest{
		Data: []byte(`{"data":"hello"}`),
	}), rwWrapper.rw)
	require.NoError(t, err)
	require.True(t, messageHandlerCalled)
}

// TestClientHandlePublishNotAllowed checks that an error from OnPublish is
// written into the reply rather than returned from handlePublish.
func TestClientHandlePublishNotAllowed(t *testing.T) {
	node := defaultNodeNoHandlers()
	defer func() { _ = node.Shutdown(context.Background()) }()

	client := newTestClient(t, node, "42")

	node.OnConnect(func(client *Client) {
		client.OnPublish(func(_ PublishEvent, cb PublishCallback) {
			cb(PublishReply{}, ErrorPermissionDenied)
		})
	})

	connectClient(t, client)

	rwWrapper := testReplyWriterWrapper()

	err := client.handlePublish(getJSONEncodedParams(t, &protocol.PublishRequest{
		Data:    []byte(`{"hello": 1}`),
		Channel: "test",
	}), rwWrapper.rw)
	require.NoError(t, err)
	require.Equal(t, ErrorPermissionDenied.toProto(), rwWrapper.replies[0].Error)
}

// TestClientHandlePublish checks that publishing to an empty channel yields
// DisconnectBadRequest and that a valid publish reaches OnPublish.
func TestClientHandlePublish(t *testing.T) {
	node := defaultNodeNoHandlers()
	defer func() { _ = node.Shutdown(context.Background()) }()

	client := newTestClient(t, node, "42")

	node.OnConnect(func(client *Client) {
		client.OnPublish(func(event PublishEvent, cb PublishCallback) {
			expectedData := []byte(`{"hello":1}`)
			require.Equal(t, expectedData, event.Data)
			require.Equal(t, "test", event.Channel)
			cb(PublishReply{}, nil)
		})
	})

	connectClient(t, client)

	rwWrapper := testReplyWriterWrapper()
	err := client.handlePublish(getJSONEncodedParams(t, &protocol.PublishRequest{
		Data:    []byte(`{"hello":1}`),
		Channel: "",
	}), rwWrapper.rw)
	require.Equal(t, DisconnectBadRequest, err)

	rwWrapper = testReplyWriterWrapper()
	err = client.handlePublish(getJSONEncodedParams(t, &protocol.PublishRequest{
		Data:    []byte(`{"hello":1}`),
		Channel: "test",
	}), rwWrapper.rw)
	require.NoError(t, err)
	require.Nil(t, rwWrapper.replies[0].Error)
}

// TestClientSideRefresh checks that a client-initiated REFRESH passes the
// token to OnRefresh and replies without error.
func TestClientSideRefresh(t *testing.T) {
	node := defaultNodeNoHandlers()
	defer func() { _ = node.Shutdown(context.Background()) }()

	transport := newTestTransport(func() {})
	ctx := context.Background()
	newCtx := SetCredentials(ctx, &Credentials{
		UserID:   "42",
		ExpireAt: time.Now().Unix() + 60,
	})
	client, _ := newClient(newCtx, node, transport)

	node.OnConnecting(func(ctx context.Context, event ConnectEvent) (ConnectReply, error) {
		return ConnectReply{
			ClientSideRefresh: true,
		}, nil
	})

	expireAt := time.Now().Unix() + 60

	node.OnConnect(func(client *Client) {
		client.OnRefresh(func(e RefreshEvent, cb RefreshCallback) {
			require.Equal(t, "test", e.Token)
			cb(RefreshReply{
				ExpireAt: expireAt,
			}, nil)
		})
	})

	connectClient(t, client)

	rwWrapper := testReplyWriterWrapper()

	err := client.handleRefresh(getJSONEncodedParams(t, &protocol.RefreshRequest{
		Token: "test",
	}), rwWrapper.rw)
	require.NoError(t, err)
	require.Nil(t, rwWrapper.replies[0].Error)
}

// TestServerSideRefresh checks that server-side refresh calls OnRefresh
// without a token and applies the returned ExpireAt and Info.
func TestServerSideRefresh(t *testing.T) {
	node := defaultNodeNoHandlers()
	defer func() { _ = node.Shutdown(context.Background()) }()

	ctx, cancelFn := context.WithCancel(context.Background())
	transport := newTestTransport(cancelFn)

	startExpireAt := time.Now().Unix() + 1
	newCtx := SetCredentials(ctx, &Credentials{
		UserID:   "42",
		ExpireAt: startExpireAt,
	})
	client, _ := newClient(newCtx, node, transport)

	node.OnConnecting(func(ctx context.Context, event ConnectEvent) (ConnectReply, error) {
		return ConnectReply{
			ClientSideRefresh: false,
		}, nil
	})

	expireAt := time.Now().Unix() + 60

	done := make(chan struct{})

	node.OnConnect(func(client *Client) {
		client.OnRefresh(func(e RefreshEvent, cb RefreshCallback) {
			require.Equal(t, "", e.Token)
			require.False(t, e.ClientSideRefresh)
			cb(RefreshReply{
				ExpireAt: expireAt,
				Info:     []byte("{}"),
			}, nil)
			close(done)
		})
	})

	connectClient(t, client)

	select {
	case <-time.After(5 * time.Second):
		require.Fail(t, "timeout waiting for work done")
	case <-done:
	}

	require.True(t, client.nextExpire > startExpireAt)
	require.Equal(t, []byte("{}"), client.info)
}

// TestServerSideRefreshDisconnect checks that returning DisconnectExpired
// from a server-side OnRefresh disconnects the client with that reason.
func TestServerSideRefreshDisconnect(t *testing.T) {
	node := defaultNodeNoHandlers()
	defer func() { _ = node.Shutdown(context.Background()) }()

	ctx, cancelFn := context.WithCancel(context.Background())
	transport := newTestTransport(cancelFn)

	startExpireAt := time.Now().Unix() + 1
	newCtx := SetCredentials(ctx, &Credentials{
		UserID:   "42",
		ExpireAt: startExpireAt,
	})
	client, _ := newClient(newCtx, node, transport)

	node.OnConnecting(func(ctx context.Context, event ConnectEvent) (ConnectReply, error) {
		return ConnectReply{
			ClientSideRefresh: false,
		}, nil
	})

	done := make(chan struct{})
	disconnected := make(chan struct{})

	node.OnConnect(func(client *Client) {
		client.OnRefresh(func(e RefreshEvent, cb RefreshCallback) {
			require.Equal(t, "", e.Token)
			require.False(t, e.ClientSideRefresh)
			cb(RefreshReply{}, DisconnectExpired)
			close(done)
		})
		client.OnDisconnect(func(event DisconnectEvent) {
			require.Equal(t, DisconnectExpired, event.Disconnect)
			close(disconnected)
		})
	})

	connectClient(t, client)

	select {
	case <-time.After(5 * time.Second):
		require.Fail(t, "timeout waiting for work done")
	case <-done:
	}

	select {
	case <-time.After(5 * time.Second):
		require.Fail(t, "timeout waiting for client close")
	case <-disconnected:
	}
}

// TestServerSideRefreshCustomError checks that a non-client error returned
// from a server-side OnRefresh disconnects with DisconnectServerError.
func TestServerSideRefreshCustomError(t *testing.T) {
	node := defaultNodeNoHandlers()
	defer func() { _ = node.Shutdown(context.Background()) }()

	ctx, cancelFn := context.WithCancel(context.Background())
	transport := newTestTransport(cancelFn)

	startExpireAt := time.Now().Unix() + 1
	newCtx := SetCredentials(ctx, &Credentials{
		UserID:   "42",
		ExpireAt: startExpireAt,
	})
	client, _ := newClient(newCtx, node, transport)

	node.OnConnecting(func(ctx context.Context, event ConnectEvent) (ConnectReply, error) {
		return ConnectReply{
			ClientSideRefresh: false,
		}, nil
	})

	done := make(chan struct{})
	disconnected := make(chan struct{})

	node.OnConnect(func(client *Client) {
		client.OnRefresh(func(e RefreshEvent, cb RefreshCallback) {
			require.Equal(t, "", e.Token)
			require.False(t, e.ClientSideRefresh)
			cb(RefreshReply{}, errors.New("boom"))
			close(done)
		})
		client.OnDisconnect(func(event DisconnectEvent) {
			require.Equal(t, DisconnectServerError, event.Disconnect)
			close(disconnected)
		})
	})

	connectClient(t, client)

	select {
	case <-time.After(5 * time.Second):
		require.Fail(t, "timeout waiting for work done")
	case <-done:
	}

	select {
	case <-time.After(5 * time.Second):
		require.Fail(t, "timeout waiting for client close")
	case <-disconnected:
	}
}

// TestClientSideSubRefresh walks through handleSubRefresh outcomes: no
// handler, empty token, unknown channel, empty channel, and success.
func TestClientSideSubRefresh(t *testing.T) {
	node := defaultNodeNoHandlers()
	defer func() { _ = node.Shutdown(context.Background()) }()

	transport := newTestTransport(func() {})
	ctx := context.Background()
	newCtx := SetCredentials(ctx, &Credentials{
		UserID:   "42",
		ExpireAt: time.Now().Unix() + 60,
	})
	client, _ := newClient(newCtx, node, transport)

	node.OnConnecting(func(ctx context.Context, event ConnectEvent) (ConnectReply, error) {
		return ConnectReply{
			ClientSideRefresh: true,
		}, nil
	})

	expireAt := time.Now().Unix() + 60

	node.OnConnect(func(client *Client) {
		client.OnSubscribe(func(_ SubscribeEvent, cb SubscribeCallback) {
			cb(SubscribeReply{
				Options: SubscribeOptions{
					ExpireAt: time.Now().Unix() + 10,
				},
				ClientSideRefresh: true,
			}, nil)
		})
	})

	connectClient(t, client)
	subscribeClient(t, client, "test")

	rwWrapper := testReplyWriterWrapper()
	err := client.handleSubRefresh(getJSONEncodedParams(t, &protocol.SubRefreshRequest{
		Channel: "test",
		Token:   "test_token",
	}), rwWrapper.rw)
	require.Equal(t, ErrorNotAvailable, err)

	client.OnSubRefresh(func(e SubRefreshEvent, cb SubRefreshCallback) {
		require.Equal(t, "test_token", e.Token)
		cb(SubRefreshReply{
			ExpireAt: expireAt,
		}, nil)
	})

	rwWrapper = testReplyWriterWrapper()
	err = client.handleSubRefresh(getJSONEncodedParams(t, &protocol.SubRefreshRequest{
		Channel: "test",
		Token:   "",
	}), rwWrapper.rw)
	require.Equal(t, ErrorBadRequest, err)

	rwWrapper = testReplyWriterWrapper()
	err = client.handleSubRefresh(getJSONEncodedParams(t, &protocol.SubRefreshRequest{
		Channel: "test1",
		Token:   "test_token",
	}), rwWrapper.rw)
	require.Equal(t, ErrorPermissionDenied, err)

	rwWrapper = testReplyWriterWrapper()
	err = client.handleSubRefresh(getJSONEncodedParams(t, &protocol.SubRefreshRequest{
		Channel: "",
		Token:   "test_token",
	}), rwWrapper.rw)
	require.Equal(t, DisconnectBadRequest, err)

	rwWrapper = testReplyWriterWrapper()
	err = client.handleSubRefresh(getJSONEncodedParams(t, &protocol.SubRefreshRequest{
		Channel: "test",
		Token:   "test_token",
	}), rwWrapper.rw)
	require.NoError(t, err)
	require.Nil(t, rwWrapper.replies[0].Error)
}

// TestClientSideSubRefreshUnexpected checks that a SUB_REFRESH for a
// subscription made without ClientSideRefresh yields DisconnectBadRequest.
func TestClientSideSubRefreshUnexpected(t *testing.T) {
	node := defaultNodeNoHandlers()
	defer func() { _ = node.Shutdown(context.Background()) }()

	transport := newTestTransport(func() {})
	ctx := context.Background()
	newCtx := SetCredentials(ctx, &Credentials{
		UserID:   "42",
		ExpireAt: time.Now().Unix() + 60,
	})
	client, _ := newClient(newCtx, node, transport)

	node.OnConnecting(func(ctx context.Context, event ConnectEvent) (ConnectReply, error) {
		return ConnectReply{
			ClientSideRefresh: true,
		}, nil
	})

	expireAt := time.Now().Unix() + 60

	node.OnConnect(func(client *Client) {
		client.OnSubscribe(func(_ SubscribeEvent, cb SubscribeCallback) {
			cb(SubscribeReply{
				ClientSideRefresh: false,
			}, nil)
		})

		client.OnSubRefresh(func(e SubRefreshEvent, cb SubRefreshCallback) {
			require.Equal(t, "test_token", e.Token)
			cb(SubRefreshReply{
				ExpireAt: expireAt,
			}, nil)
		})
	})

	connectClient(t, client)
	subscribeClient(t, client, "test")

	rwWrapper := testReplyWriterWrapper()
	err := client.handleSubRefresh(getJSONEncodedParams(t, &protocol.SubRefreshRequest{
		Channel: "test",
		Token:   "test_token",
	}), rwWrapper.rw)
	require.Equal(t, DisconnectBadRequest, err)
}

// TestCloseNoRace calls Disconnect from inside OnConnect and registers
// OnDisconnect only afterwards. It expects that handler to still be invoked.
func TestCloseNoRace(t *testing.T) {
	node := defaultTestNode()
	defer func() { _ = node.Shutdown(context.Background()) }()

	done := make(chan struct{})

	node.OnConnect(func(client *Client) {
		client.Disconnect(DisconnectForceNoReconnect)
		time.Sleep(time.Second)
		client.OnDisconnect(func(_ DisconnectEvent) {
			close(done)
		})
	})

	client := newTestClient(t, node, "42")
	connectClient(t, client)

	select {
	case <-time.After(time.Second):
		require.Fail(t, "timeout waiting for work done")
	case <-done:
	}
}

// TestClientCheckSubscriptionExpiration drives checkSubscriptionExpiration
// with a controlled clock through its expired/refreshed branches.
func TestClientCheckSubscriptionExpiration(t *testing.T) {
	node := defaultTestNode()
	defer func() { _ = node.Shutdown(context.Background()) }()

	client := newTestClient(t, node, "42")

	var nowTime time.Time
	node.mu.Lock()
	node.nowTimeGetter = func() time.Time {
		return nowTime
	}
	node.mu.Unlock()

	chanCtx := channelContext{expireAt: 100}

	// not expired.
	nowTime = time.Unix(100, 0)
	client.checkSubscriptionExpiration("channel", chanCtx, 50*time.Second, func(b bool) {
		require.True(t, b)
	})

	// simple refresh unavailable.
	nowTime = time.Unix(200, 0)
	client.checkSubscriptionExpiration("channel", chanCtx, 50*time.Second, func(b bool) {
		require.False(t, b)
	})

	// refreshed but expired.
	client.eventHub.subRefreshHandler = func(event SubRefreshEvent, cb SubRefreshCallback) {
		require.Equal(t, "channel", event.Channel)
		cb(SubRefreshReply{Expired: true}, nil)
	}
	nowTime = time.Unix(200, 0)
	client.checkSubscriptionExpiration("channel", chanCtx, 50*time.Second, func(b bool) {
		require.False(t, b)
	})

	// refreshed but not really.
	client.eventHub.subRefreshHandler = func(event SubRefreshEvent, cb SubRefreshCallback) {
		require.Equal(t, "channel", event.Channel)
		cb(SubRefreshReply{ExpireAt: 150}, nil)
	}
	nowTime = time.Unix(200, 0)
	client.checkSubscriptionExpiration("channel", chanCtx, 50*time.Second, func(b bool) {
		require.False(t, b)
	})

	// refreshed but unknown channel.
	client.eventHub.subRefreshHandler = func(event SubRefreshEvent, cb SubRefreshCallback) {
		require.Equal(t, "channel", event.Channel)
		cb(SubRefreshReply{
			ExpireAt: 250,
			Info:     []byte("info"),
		}, nil)
	}
	nowTime = time.Unix(200, 0)
	client.checkSubscriptionExpiration("channel", chanCtx, 50*time.Second, func(b bool) {
		require.True(t, b)
	})
	require.NotContains(t, client.channels, "channel")

	// refreshed.
	client.channels["channel"] = channelContext{}
	client.eventHub.subRefreshHandler = func(event SubRefreshEvent, cb SubRefreshCallback) {
		require.Equal(t, "channel", event.Channel)
		cb(SubRefreshReply{
			ExpireAt: 250,
			Info:     []byte("info"),
		}, nil)
	}
	nowTime = time.Unix(200, 0)
	client.checkSubscriptionExpiration("channel", chanCtx, 50*time.Second, func(b bool) {
		require.True(t, b)
	})
	require.Contains(t, client.channels, "channel")
	require.EqualValues(t, 250, client.channels["channel"].expireAt)
	require.Equal(t, []byte("info"), client.channels["channel"].Info)

	// Error from handler.
	client.eventHub.subRefreshHandler = func(event SubRefreshEvent, cb SubRefreshCallback) {
		cb(SubRefreshReply{}, DisconnectExpired)
	}
	nowTime = time.Unix(200, 0)
	client.checkSubscriptionExpiration("channel", chanCtx, 50*time.Second, func(b bool) {
		require.False(t, b)
	})
}

// TestClientCheckPosition drives checkPosition with a fixed clock (unix 200).
// It covers the no-recover, not-yet-due, invalid-position and valid-position
// cases.
func TestClientCheckPosition(t *testing.T) {
	node := defaultTestNode()
	defer func() { _ = node.Shutdown(context.Background()) }()

	client := newTestClient(t, node, "42")

	node.mu.Lock()
	node.nowTimeGetter = func() time.Time {
		return time.Unix(200, 0)
	}
	node.mu.Unlock()

	// no recover.
	got := client.checkPosition(300*time.Second, "channel", channelContext{})
	require.True(t, got)

	// not initial, not time to check.
	got = client.checkPosition(300*time.Second, "channel", channelContext{positionCheckTime: 50, flags: flagRecover})
	require.True(t, got)

	// invalid position.
	client.channels["channel"] = channelContext{positionCheckFailures: 2, flags: flagRecover}
	got = client.checkPosition(50*time.Second, "channel", channelContext{
		positionCheckTime: 50, flags: flagRecover,
	})
	require.False(t, got)
	require.Contains(t, client.channels, "channel")
	require.EqualValues(t, 3, client.channels["channel"].positionCheckFailures)
	require.EqualValues(t, 200, client.channels["channel"].positionCheckTime)

	// valid position resets positionCheckFailures.
	require.NotZero(t, client.channels["channel"].positionCheckFailures)
	sp, _ := node.streamTop("channel")
	got = client.checkPosition(50*time.Second, "channel", channelContext{
		positionCheckTime: 50, flags: flagRecover, streamPosition: sp,
	})
	require.True(t, got)
	require.Zero(t, client.channels["channel"].positionCheckFailures)
}

// TestErrLogLevel checks that client errors log at info and others at error.
func TestErrLogLevel(t *testing.T) {
	require.Equal(t, LogLevelInfo, errLogLevel(ErrorNotAvailable))
	require.Equal(t, LogLevelError, errLogLevel(errors.New("boom")))
}

// TestClientTransportWriteError checks which disconnect is reported when the
// transport write fails. A *Disconnect error is passed through; any other
// error becomes DisconnectWriteError.
func TestClientTransportWriteError(t *testing.T) {
	testCases := []struct {
		Name               string
		Error              error
		ExpectedDisconnect *Disconnect
	}{
		{"disconnect", DisconnectSlow, DisconnectSlow},
		{"other", errors.New("boom"), DisconnectWriteError},
	}

	for _, tt := range testCases {
		t.Run(tt.Name, func(t *testing.T) {
			node := defaultTestNode()
			defer func() { _ = node.Shutdown(context.Background()) }()
			transport := newTestTransport(func() {})
			transport.sink = make(chan []byte, 100)
			transport.writeErr = tt.Error

			// Buffered so the OnDisconnect callback never blocks forever if the
			// test has already given up waiting on the timeout branch.
			done := make(chan *Disconnect, 1)

			node.OnConnect(func(client *Client) {
				client.OnDisconnect(func(event DisconnectEvent) {
					done <- event.Disconnect
				})
			})

			ctx := context.Background()
			newCtx := SetCredentials(ctx, &Credentials{UserID: "42"})
			client, _ := newClient(newCtx, node, transport)

			connectClient(t, client)

			rwWrapper := testReplyWriterWrapper()

			subCtx := client.subscribeCmd(&protocol.SubscribeRequest{
				Channel: "test",
			}, SubscribeReply{}, rwWrapper.rw, false)
			require.Nil(t, subCtx.disconnect)
			require.Nil(t, rwWrapper.replies[0].Error)

			_, err := node.Publish("test", []byte(`{"text": "test message"}`))
			require.NoError(t, err)

			select {
			case <-time.After(time.Second):
				require.Fail(t, "client not closed")
			case d := <-done:
				require.Equal(t, tt.ExpectedDisconnect, d)
			}
		})
	}
}

// TestFlagExists checks hasFlag on a set flag.
func TestFlagExists(t *testing.T) {
	flags := PushFlagDisconnect
	require.True(t, hasFlag(flags, PushFlagDisconnect))
}

// TestFlagNotExists checks hasFlag on an empty flag set.
func TestFlagNotExists(t *testing.T) {
	var flags uint64
	require.False(t, hasFlag(flags, PushFlagDisconnect))
}

// TestConcurrentSameChannelSubscribe fires concurrent subscribes to one
// channel from the same client. All but one must fail with
// "105: already subscribed".
func TestConcurrentSameChannelSubscribe(t *testing.T) {
	node := defaultNodeNoHandlers()
	defer func() { _ = node.Shutdown(context.Background()) }()

	var wg sync.WaitGroup
	concurrency := 10
	wg.Add(concurrency)

	onSubscribe := make(chan struct{})

	node.OnConnect(func(client *Client) {
		client.OnSubscribe(func(e SubscribeEvent, cb SubscribeCallback) {
			go func() {
				cb(SubscribeReply{
					Options: SubscribeOptions{
						Recover: true,
					},
				}, nil)
				// Only one subscribe is expected to reach this handler; a
				// second call would panic on the double close.
				close(onSubscribe)
			}()
		})
	})

	client := newTestClient(t, node, "42")
	connectClient(t, client)

	var subscribeErrors []string
	var mu sync.Mutex

	for i := 0; i < concurrency; i++ {
		go func() {
			defer wg.Done()
			rwWrapper := testReplyWriterWrapper()
			err := client.handleSubscribe(getJSONEncodedParams(t, &protocol.SubscribeRequest{
				Channel: "test1",
				Recover: true,
				Offset:  0,
			}), rwWrapper.rw)
			mu.Lock()
			defer mu.Unlock()
			if err != nil {
				subscribeErrors = append(subscribeErrors, err.Error())
			} else {
				subscribeErrors = append(subscribeErrors, "nil")
			}
		}()
	}

	wg.Wait()

	select {
	case <-onSubscribe:
	case <-time.After(5 * time.Second):
		require.Fail(t, "timeout waiting for subscribe callback")
	}

	var n int
	for _, e := range subscribeErrors {
		if e == "105: already subscribed" {
			n++
		}
	}
	require.Equal(t, concurrency-1, n)
}

// slowHistoryBroker wraps MemoryBroker so that a History call signals
// startPublishingCh and then blocks until stopPublishingCh is closed. This
// lets a test publish while a history read is in flight.
type slowHistoryBroker struct {
	startPublishingCh chan struct{}
	stopPublishingCh  chan struct{}
	*MemoryBroker
	err error
}

// setError makes subsequent History calls return err instead of the result.
func (b *slowHistoryBroker) setError(err error) {
	b.err = err
}

// History reads history from the embedded MemoryBroker and then waits on
// stopPublishingCh before returning. It closes startPublishingCh, so it must be
// called at most once per broker instance (a second call would panic).
func (b *slowHistoryBroker) History(ch string, filter HistoryFilter) ([]*Publication, StreamPosition, error) {
	close(b.startPublishingCh)
	res, sp, err := b.MemoryBroker.History(ch, filter)
	<-b.stopPublishingCh
	if b.err != nil {
		return nil, StreamPosition{}, b.err
	}
	return res, sp, err
}

// TestSubscribeWithBufferedPublications publishes 5 messages while the
// subscribe's history read is blocked. It checks that the subscription result
// still reports them as recovered publications.
func TestSubscribeWithBufferedPublications(t *testing.T) {
	c := DefaultConfig
	c.LogLevel = LogLevelTrace
	c.LogHandler = func(entry LogEntry) {}
	node, err := New(c)
	require.NoError(t, err)
	startPublishingCh := make(chan struct{})
	stopPublishingCh := make(chan struct{})
	broker, err := NewMemoryBroker(node, MemoryBrokerConfig{})
	require.NoError(t, err)
	node.SetBroker(&slowHistoryBroker{startPublishingCh: startPublishingCh, stopPublishingCh: stopPublishingCh, MemoryBroker: broker})
	err = node.Run()
	require.NoError(t, err)
	defer func() { _ = node.Shutdown(context.Background()) }()

	node.OnConnect(func(client *Client) {
		client.OnSubscribe(func(e SubscribeEvent, cb SubscribeCallback) {
			cb(SubscribeReply{
				Options: SubscribeOptions{
					Recover: true,
				},
			}, nil)
		})
	})

	client := newTestClient(t, node, "42")
	connectClient(t, client)

	rwWrapper := testReplyWriterWrapper()
	// require must not be called from a spawned goroutine (FailNow only works
	// on the test goroutine), so the publisher reports its result here.
	publishErrCh := make(chan error, 1)
	go func() {
		// Always release the blocked History call, even if a publish fails.
		defer close(stopPublishingCh)
		<-startPublishingCh
		for i := 0; i < 5; i++ {
			if _, err := node.Publish("test1", []byte(`{}`), WithHistory(100, 60*time.Second)); err != nil {
				publishErrCh <- err
				return
			}
		}
		publishErrCh <- nil
	}()
	err = client.handleSubscribe(getJSONEncodedParams(t, &protocol.SubscribeRequest{
		Channel: "test1",
		Recover: true,
		Offset:  0,
	}), rwWrapper.rw)
	require.NoError(t, err)
	select {
	case publishErr := <-publishErrCh:
		require.NoError(t, publishErr)
	case <-time.After(5 * time.Second):
		require.Fail(t, "timeout waiting for publications")
	}
	require.Equal(t, 1, len(rwWrapper.replies))
	require.Nil(t, rwWrapper.replies[0].Error)
	res := extractSubscribeResult(rwWrapper.replies, client.Transport().Protocol())
	require.Equal(t, uint64(5), res.Offset)
	require.True(t, res.Recovered)
	require.Len(t, res.Publications, 5)
	require.Equal(t, 1, len(client.Channels()))
}

// TestClientChannelsWhileSubscribing checks that a channel is not reported
// by Channels/IsSubscribed until the asynchronous subscribe callback runs.
func TestClientChannelsWhileSubscribing(t *testing.T) {
	node := defaultNodeNoHandlers()
	defer func() { _ = node.Shutdown(context.Background()) }()

	waitCh := make(chan struct{})
	doneCh := make(chan struct{})

	node.OnConnect(func(client *Client) {
		client.OnSubscribe(func(e SubscribeEvent, cb SubscribeCallback) {
			go func() {
				<-waitCh
				cb(SubscribeReply{
					Options: SubscribeOptions{},
				}, nil)
				close(doneCh)
			}()
		})
	})

	client := newTestClient(t, node, "42")
	connectClient(t, client)

	rwWrapper := testReplyWriterWrapper()
	err := client.handleSubscribe(getJSONEncodedParams(t, &protocol.SubscribeRequest{
		Channel: "test1",
	}), rwWrapper.rw)
	require.NoError(t, err)
	require.Equal(t, 0, len(client.Channels()))
	require.False(t, client.IsSubscribed("test1"))
	close(waitCh)
	<-doneCh
	require.Equal(t, 1, len(client.Channels()))
}

// TestClientChannelsCleanupOnSubscribeError checks that a failed subscribe
// leaves no entry in client.channels.
func TestClientChannelsCleanupOnSubscribeError(t *testing.T) {
	node := defaultNodeNoHandlers()
	defer func() { _ = node.Shutdown(context.Background()) }()

	node.OnConnect(func(client *Client) {
		client.OnSubscribe(func(e SubscribeEvent, cb SubscribeCallback) {
			cb(SubscribeReply{}, ErrorInternal)
		})
	})

	client := newTestClient(t, node, "42")
	connectClient(t, client)

	rwWrapper := testReplyWriterWrapper()
	err := client.handleSubscribe(getJSONEncodedParams(t, &protocol.SubscribeRequest{
		Channel: "test1",
	}), rwWrapper.rw)
	require.NoError(t, err)
	require.Len(t, client.channels, 0)
}

func TestClientChannelsCleanupOnSubscribeDisconnect(t *testing.T) {
node := defaultNodeNoHandlers() 3263 defer func() { _ = node.Shutdown(context.Background()) }() 3264 3265 node.OnConnect(func(client *Client) { 3266 client.OnSubscribe(func(e SubscribeEvent, cb SubscribeCallback) { 3267 cb(SubscribeReply{}, DisconnectChannelLimit) 3268 }) 3269 }) 3270 3271 client := newTestClient(t, node, "42") 3272 connectClient(t, client) 3273 3274 rwWrapper := testReplyWriterWrapper() 3275 err := client.handleSubscribe(getJSONEncodedParams(t, &protocol.SubscribeRequest{ 3276 Channel: "test1", 3277 }), rwWrapper.rw) 3278 require.NoError(t, err) 3279 require.Len(t, client.channels, 0) 3280} 3281 3282func TestClientSubscribingChannelsCleanupOnClientClose(t *testing.T) { 3283 c := DefaultConfig 3284 c.LogLevel = LogLevelTrace 3285 c.LogHandler = func(entry LogEntry) {} 3286 node, err := New(c) 3287 if err != nil { 3288 panic(err) 3289 } 3290 startPublishingCh := make(chan struct{}) 3291 stopPublishingCh := make(chan struct{}) 3292 disconnectedCh := make(chan struct{}) 3293 broker, err := NewMemoryBroker(node, MemoryBrokerConfig{}) 3294 require.NoError(t, err) 3295 node.SetBroker(&slowHistoryBroker{startPublishingCh: startPublishingCh, stopPublishingCh: stopPublishingCh, MemoryBroker: broker}) 3296 err = node.Run() 3297 require.NoError(t, err) 3298 defer func() { _ = node.Shutdown(context.Background()) }() 3299 3300 node.OnConnect(func(client *Client) { 3301 client.OnSubscribe(func(e SubscribeEvent, cb SubscribeCallback) { 3302 go func() { 3303 cb(SubscribeReply{ 3304 Options: SubscribeOptions{ 3305 Recover: true, 3306 }, 3307 }, nil) 3308 }() 3309 }) 3310 3311 client.OnDisconnect(func(event DisconnectEvent) { 3312 close(disconnectedCh) 3313 }) 3314 }) 3315 3316 client := newTestClient(t, node, "42") 3317 connectClient(t, client) 3318 3319 rwWrapper := testReplyWriterWrapper() 3320 err = client.handleSubscribe(getJSONEncodedParams(t, &protocol.SubscribeRequest{ 3321 Channel: "test1", 3322 }), rwWrapper.rw) 3323 require.NoError(t, err) 3324 3325 
<-startPublishingCh 3326 close(stopPublishingCh) 3327 client.Disconnect(DisconnectNormal) 3328 <-disconnectedCh 3329 require.Len(t, node.Hub().Channels(), 0, node.Hub().Channels()) 3330} 3331 3332func TestClientSubscribingChannelsCleanupOnHistoryError(t *testing.T) { 3333 c := DefaultConfig 3334 c.LogLevel = LogLevelTrace 3335 c.LogHandler = func(entry LogEntry) {} 3336 node, err := New(c) 3337 if err != nil { 3338 panic(err) 3339 } 3340 startPublishingCh := make(chan struct{}) 3341 stopPublishingCh := make(chan struct{}) 3342 broker, err := NewMemoryBroker(node, MemoryBrokerConfig{}) 3343 require.NoError(t, err) 3344 slowBroker := &slowHistoryBroker{startPublishingCh: startPublishingCh, stopPublishingCh: stopPublishingCh, MemoryBroker: broker} 3345 slowBroker.setError(ErrorNotAvailable) 3346 node.SetBroker(slowBroker) 3347 err = node.Run() 3348 require.NoError(t, err) 3349 defer func() { _ = node.Shutdown(context.Background()) }() 3350 3351 node.OnConnect(func(client *Client) { 3352 client.OnSubscribe(func(e SubscribeEvent, cb SubscribeCallback) { 3353 cb(SubscribeReply{ 3354 Options: SubscribeOptions{ 3355 Recover: true, 3356 }, 3357 }, nil) 3358 }) 3359 }) 3360 3361 client := newTestClient(t, node, "42") 3362 connectClient(t, client) 3363 3364 close(stopPublishingCh) 3365 3366 rwWrapper := testReplyWriterWrapper() 3367 err = client.handleSubscribe(getJSONEncodedParams(t, &protocol.SubscribeRequest{ 3368 Channel: "test1", 3369 }), rwWrapper.rw) 3370 require.NoError(t, err) 3371 require.Len(t, node.Hub().Channels(), 0, node.Hub().Channels()) 3372} 3373