1package centrifuge 2 3import ( 4 "context" 5 "fmt" 6 "io" 7 "strconv" 8 "strings" 9 "sync" 10 "testing" 11 "time" 12 13 "github.com/centrifugal/centrifuge/internal/prepared" 14 15 "github.com/centrifugal/protocol" 16 "github.com/stretchr/testify/require" 17) 18 19type testTransport struct { 20 mu sync.Mutex 21 sink chan []byte 22 closed bool 23 closeCh chan struct{} 24 disconnect *Disconnect 25 protoType ProtocolType 26 cancelFn func() 27 unidirectional bool 28 writeErr error 29} 30 31func newTestTransport(cancelFn func()) *testTransport { 32 return &testTransport{ 33 cancelFn: cancelFn, 34 protoType: ProtocolTypeJSON, 35 closeCh: make(chan struct{}), 36 unidirectional: false, 37 } 38} 39 40func (t *testTransport) setProtocolType(pType ProtocolType) { 41 t.protoType = pType 42} 43 44func (t *testTransport) setUnidirectional(uni bool) { 45 t.unidirectional = uni 46} 47 48func (t *testTransport) setSink(sink chan []byte) { 49 t.sink = sink 50} 51 52func (t *testTransport) Write(message []byte) error { 53 if t.writeErr != nil { 54 return t.writeErr 55 } 56 t.mu.Lock() 57 defer t.mu.Unlock() 58 if t.closed { 59 return io.EOF 60 } 61 if t.sink != nil { 62 t.sink <- message 63 } 64 return nil 65} 66 67func (t *testTransport) WriteMany(messages ...[]byte) error { 68 if t.writeErr != nil { 69 return t.writeErr 70 } 71 t.mu.Lock() 72 defer t.mu.Unlock() 73 if t.closed { 74 return io.EOF 75 } 76 for _, buf := range messages { 77 if t.sink != nil { 78 t.sink <- buf 79 } 80 } 81 return nil 82} 83 84func (t *testTransport) Name() string { 85 return transportWebsocket 86} 87 88func (t *testTransport) Protocol() ProtocolType { 89 return t.protoType 90} 91 92func (t *testTransport) Unidirectional() bool { 93 return t.unidirectional 94} 95 96func (t *testTransport) DisabledPushFlags() uint64 { 97 if t.Unidirectional() { 98 return 0 99 } 100 return PushFlagDisconnect 101} 102 103func (t *testTransport) Close(disconnect *Disconnect) error { 104 t.mu.Lock() 105 defer t.mu.Unlock() 106 if t.closed { 107 return nil 108 } 109 t.disconnect = disconnect 110 t.closed = true 111 t.cancelFn() 112 close(t.closeCh) 113 return nil 114} 115 116func TestHub(t *testing.T) { 117 h := newHub() 118 c, err := newClient(context.Background(), defaultTestNode(), newTestTransport(func() {})) 119 require.NoError(t, err) 120 c.user = "test" 121 err = h.remove(c) 122 require.NoError(t, err) 123 err = h.add(c) 124 require.NoError(t, err) 125 conns := h.UserConnections("test") 126 require.Equal(t, 1, len(conns)) 127 require.Equal(t, 1, h.NumClients()) 128 require.Equal(t, 1, h.NumUsers()) 129 130 validUID := c.uid 131 c.uid = "invalid" 132 err = h.remove(c) 133 require.NoError(t, err) 134 require.Len(t, h.UserConnections("test"), 1) 135 136 c.uid = validUID 137 err = h.remove(c) 138 require.NoError(t, err) 139 require.Len(t, h.UserConnections("test"), 0) 140} 141 142func TestHubUnsubscribe(t *testing.T) { 143 n := defaultTestNode() 144 defer func() { _ = n.Shutdown(context.Background()) }() 145 146 client := newTestSubscribedClient(t, n, "42", "test_channel") 147 transport := client.transport.(*testTransport) 148 transport.sink = make(chan []byte, 100) 149 150 // Unsubscribe not existed user. 151 err := n.hub.unsubscribe("1", "test_channel", "") 152 require.NoError(t, err) 153 154 // Unsubscribe subscribed user. 155 err = n.hub.unsubscribe("42", "test_channel", "") 156 require.NoError(t, err) 157 select { 158 case data := <-transport.sink: 159 require.Equal(t, "{\"result\":{\"type\":3,\"channel\":\"test_channel\",\"data\":{}}}", string(data)) 160 case <-time.After(2 * time.Second): 161 t.Fatal("no data in sink") 162 } 163 require.Zero(t, n.hub.NumSubscribers("test_channel")) 164} 165 166func TestHubDisconnect(t *testing.T) { 167 n := defaultNodeNoHandlers() 168 defer func() { _ = n.Shutdown(context.Background()) }() 169 170 n.OnConnect(func(client *Client) { 171 client.OnSubscribe(func(event SubscribeEvent, cb SubscribeCallback) { 172 cb(SubscribeReply{}, nil) 173 }) 174 }) 175 176 client := newTestSubscribedClient(t, n, "42", "test_channel") 177 clientWithReconnect := newTestSubscribedClient(t, n, "24", "test_channel_reconnect") 178 require.Len(t, n.hub.UserConnections("42"), 1) 179 require.Len(t, n.hub.UserConnections("24"), 1) 180 require.Equal(t, 1, n.hub.NumSubscribers("test_channel")) 181 require.Equal(t, 1, n.hub.NumSubscribers("test_channel_reconnect")) 182 183 wg := sync.WaitGroup{} 184 wg.Add(2) 185 186 client.eventHub.disconnectHandler = func(e DisconnectEvent) { 187 defer wg.Done() 188 require.False(t, e.Disconnect.Reconnect) 189 } 190 191 clientWithReconnect.eventHub.disconnectHandler = func(e DisconnectEvent) { 192 defer wg.Done() 193 require.True(t, e.Disconnect.Reconnect) 194 } 195 196 // Disconnect not existed user. 197 err := n.hub.disconnect("1", DisconnectForceNoReconnect, "", nil) 198 require.NoError(t, err) 199 200 // Disconnect subscribed user. 201 err = n.hub.disconnect("42", DisconnectForceNoReconnect, "", nil) 202 require.NoError(t, err) 203 select { 204 case <-client.transport.(*testTransport).closeCh: 205 case <-time.After(2 * time.Second): 206 t.Fatal("no data in sink") 207 } 208 require.Len(t, n.hub.UserConnections("42"), 0) 209 require.Equal(t, 0, n.hub.NumSubscribers("test_channel")) 210 211 // Disconnect subscribed user with reconnect. 212 err = n.hub.disconnect("24", DisconnectForceReconnect, "", nil) 213 require.NoError(t, err) 214 select { 215 case <-clientWithReconnect.transport.(*testTransport).closeCh: 216 case <-time.After(2 * time.Second): 217 t.Fatal("no data in sink") 218 } 219 require.Len(t, n.hub.UserConnections("24"), 0) 220 require.Equal(t, 0, n.hub.NumSubscribers("test_channel_reconnect")) 221 222 wg.Wait() 223 224 require.Len(t, n.hub.UserConnections("24"), 0) 225 require.Len(t, n.hub.UserConnections("42"), 0) 226 require.Equal(t, 0, n.hub.NumSubscribers("test_channel")) 227 require.Equal(t, 0, n.hub.NumSubscribers("test_channel_reconnect")) 228} 229 230func TestHubDisconnect_ClientWhitelist(t *testing.T) { 231 n := defaultNodeNoHandlers() 232 defer func() { _ = n.Shutdown(context.Background()) }() 233 234 n.OnConnect(func(client *Client) { 235 client.OnSubscribe(func(event SubscribeEvent, cb SubscribeCallback) { 236 cb(SubscribeReply{}, nil) 237 }) 238 }) 239 240 client := newTestSubscribedClient(t, n, "12", "test_channel") 241 clientToKeep := newTestSubscribedClient(t, n, "12", "test_channel") 242 243 require.Len(t, n.hub.UserConnections("12"), 2) 244 require.Equal(t, 2, n.hub.NumSubscribers("test_channel")) 245 246 shouldBeClosed := make(chan struct{}) 247 shouldNotBeClosed := make(chan struct{}) 248 249 client.eventHub.disconnectHandler = func(e DisconnectEvent) { 250 close(shouldBeClosed) 251 } 252 253 clientToKeep.eventHub.disconnectHandler = func(e DisconnectEvent) { 254 close(shouldNotBeClosed) 255 } 256 257 whitelist := []string{clientToKeep.ID()} 258 259 // Disconnect not existed user. 260 err := n.hub.disconnect("12", DisconnectConnectionLimit, "", whitelist) 261 require.NoError(t, err) 262 263 select { 264 case <-shouldBeClosed: 265 select { 266 case <-shouldNotBeClosed: 267 require.Fail(t, "client should not be disconnected") 268 case <-time.After(time.Second): 269 require.Len(t, n.hub.UserConnections("12"), 1) 270 require.Equal(t, 1, n.hub.NumSubscribers("test_channel")) 271 } 272 case <-time.After(time.Second): 273 require.Fail(t, "timeout waiting for channel close") 274 } 275} 276 277func TestHubDisconnect_ClientID(t *testing.T) { 278 n := defaultNodeNoHandlers() 279 defer func() { _ = n.Shutdown(context.Background()) }() 280 281 n.OnConnect(func(client *Client) { 282 client.OnSubscribe(func(event SubscribeEvent, cb SubscribeCallback) { 283 cb(SubscribeReply{}, nil) 284 }) 285 }) 286 287 client := newTestSubscribedClient(t, n, "12", "test_channel") 288 clientToKeep := newTestSubscribedClient(t, n, "12", "test_channel") 289 290 require.Len(t, n.hub.UserConnections("12"), 2) 291 require.Equal(t, 2, n.hub.NumSubscribers("test_channel")) 292 293 shouldBeClosed := make(chan struct{}) 294 shouldNotBeClosed := make(chan struct{}) 295 296 client.eventHub.disconnectHandler = func(e DisconnectEvent) { 297 close(shouldBeClosed) 298 } 299 300 clientToKeep.eventHub.disconnectHandler = func(e DisconnectEvent) { 301 close(shouldNotBeClosed) 302 } 303 304 clientToDisconnect := client.ID() 305 306 // Disconnect not existed user. 307 err := n.hub.disconnect("12", DisconnectConnectionLimit, clientToDisconnect, nil) 308 require.NoError(t, err) 309 310 select { 311 case <-shouldBeClosed: 312 select { 313 case <-shouldNotBeClosed: 314 require.Fail(t, "client should not be disconnected") 315 case <-time.After(time.Second): 316 require.Len(t, n.hub.UserConnections("12"), 1) 317 require.Equal(t, 1, n.hub.NumSubscribers("test_channel")) 318 } 319 case <-time.After(time.Second): 320 require.Fail(t, "timeout waiting for channel close") 321 } 322} 323 324func TestHubBroadcastPublication(t *testing.T) { 325 tcs := []struct { 326 name string 327 protocolType ProtocolType 328 }{ 329 {name: "JSON", protocolType: ProtocolTypeJSON}, 330 {name: "Protobuf", protocolType: ProtocolTypeProtobuf}, 331 } 332 333 for _, tc := range tcs { 334 t.Run(tc.name, func(t *testing.T) { 335 n := defaultTestNode() 336 defer func() { _ = n.Shutdown(context.Background()) }() 337 338 client := newTestSubscribedClient(t, n, "42", "test_channel") 339 transport := client.transport.(*testTransport) 340 transport.sink = make(chan []byte, 100) 341 transport.protoType = tc.protocolType 342 343 // Broadcast to not existed channel. 344 err := n.hub.BroadcastPublication( 345 "not_test_channel", 346 &Publication{Data: []byte(`{"data": "broadcast_data"}`)}, 347 StreamPosition{}, 348 ) 349 require.NoError(t, err) 350 351 // Broadcast to existed channel. 352 err = n.hub.BroadcastPublication( 353 "test_channel", 354 &Publication{Data: []byte(`{"data": "broadcast_data"}`)}, 355 StreamPosition{}, 356 ) 357 require.NoError(t, err) 358 select { 359 case data := <-transport.sink: 360 require.True(t, strings.Contains(string(data), "broadcast_data")) 361 case <-time.After(2 * time.Second): 362 t.Fatal("no data in sink") 363 } 364 }) 365 } 366} 367 368func TestHubBroadcastJoin(t *testing.T) { 369 tcs := []struct { 370 name string 371 protocolType ProtocolType 372 }{ 373 {name: "JSON", protocolType: ProtocolTypeJSON}, 374 {name: "Protobuf", protocolType: ProtocolTypeProtobuf}, 375 } 376 377 for _, tc := range tcs { 378 t.Run(tc.name, func(t *testing.T) { 379 n := defaultTestNode() 380 defer func() { _ = n.Shutdown(context.Background()) }() 381 382 client := newTestSubscribedClient(t, n, "42", "test_channel") 383 transport := client.transport.(*testTransport) 384 transport.sink = make(chan []byte, 100) 385 transport.protoType = tc.protocolType 386 387 // Broadcast to not existed channel. 388 err := n.hub.broadcastJoin("not_test_channel", &ClientInfo{ClientID: "broadcast_client"}) 389 require.NoError(t, err) 390 391 // Broadcast to existed channel. 392 err = n.hub.broadcastJoin("test_channel", &ClientInfo{ClientID: "broadcast_client"}) 393 require.NoError(t, err) 394 select { 395 case data := <-transport.sink: 396 require.True(t, strings.Contains(string(data), "broadcast_client")) 397 case <-time.After(2 * time.Second): 398 t.Fatal("no data in sink") 399 } 400 }) 401 } 402} 403 404func TestHubBroadcastLeave(t *testing.T) { 405 tcs := []struct { 406 name string 407 protocolType ProtocolType 408 }{ 409 {name: "JSON", protocolType: ProtocolTypeJSON}, 410 {name: "Protobuf", protocolType: ProtocolTypeProtobuf}, 411 } 412 413 for _, tc := range tcs { 414 t.Run(tc.name, func(t *testing.T) { 415 n := defaultTestNode() 416 defer func() { _ = n.Shutdown(context.Background()) }() 417 418 client := newTestSubscribedClient(t, n, "42", "test_channel") 419 transport := client.transport.(*testTransport) 420 transport.sink = make(chan []byte, 100) 421 transport.protoType = tc.protocolType 422 423 // Broadcast to not existed channel. 424 err := n.hub.broadcastLeave("not_test_channel", &ClientInfo{ClientID: "broadcast_client"}) 425 require.NoError(t, err) 426 427 // Broadcast to existed channel. 428 err = n.hub.broadcastLeave("test_channel", &ClientInfo{ClientID: "broadcast_client"}) 429 require.NoError(t, err) 430 select { 431 case data := <-transport.sink: 432 require.Contains(t, string(data), "broadcast_client") 433 case <-time.After(2 * time.Second): 434 t.Fatal("no data in sink") 435 } 436 }) 437 } 438} 439 440func TestHubShutdown(t *testing.T) { 441 h := newHub() 442 err := h.shutdown(context.Background()) 443 require.NoError(t, err) 444 h = newHub() 445 c, err := newClient(context.Background(), defaultTestNode(), newTestTransport(func() {})) 446 require.NoError(t, err) 447 _ = h.add(c) 448 449 err = h.shutdown(context.Background()) 450 require.NoError(t, err) 451 452 ctxCanceled, cancel := context.WithCancel(context.Background()) 453 cancel() 454 err = h.shutdown(ctxCanceled) 455 require.EqualError(t, err, "context canceled") 456} 457 458func TestHubSubscriptions(t *testing.T) { 459 h := newHub() 460 c, err := newClient(context.Background(), defaultTestNode(), newTestTransport(func() {})) 461 require.NoError(t, err) 462 463 _, _ = h.addSub("test1", c) 464 _, _ = h.addSub("test2", c) 465 require.Equal(t, 2, h.NumChannels()) 466 require.Contains(t, h.Channels(), "test1") 467 require.Contains(t, h.Channels(), "test2") 468 require.NotZero(t, h.NumSubscribers("test1")) 469 require.NotZero(t, h.NumSubscribers("test2")) 470 471 // Not exited sub. 472 removed, err := h.removeSub("not_existed", c) 473 require.NoError(t, err) 474 require.True(t, removed) 475 476 // Exited sub with invalid uid. 477 validUID := c.uid 478 c.uid = "invalid" 479 removed, err = h.removeSub("test1", c) 480 require.NoError(t, err) 481 require.True(t, removed) 482 c.uid = validUID 483 484 // Exited sub. 485 removed, err = h.removeSub("test1", c) 486 require.NoError(t, err) 487 require.True(t, removed) 488 489 // Exited sub. 490 removed, err = h.removeSub("test2", c) 491 require.NoError(t, err) 492 require.True(t, removed) 493 494 require.Equal(t, h.NumChannels(), 0) 495 require.Zero(t, h.NumSubscribers("test1")) 496 require.Zero(t, h.NumSubscribers("test2")) 497} 498 499func TestPreparedReply(t *testing.T) { 500 reply := protocol.Reply{} 501 preparedReply := prepared.NewReply(&reply, protocol.TypeJSON) 502 data := preparedReply.Data() 503 require.NotNil(t, data) 504} 505 506func TestUserConnections(t *testing.T) { 507 h := newHub() 508 c, err := newClient(context.Background(), defaultTestNode(), newTestTransport(func() {})) 509 require.NoError(t, err) 510 _ = h.add(c) 511 512 connections := h.UserConnections(c.UserID()) 513 require.Equal(t, h.connShards[index(c.UserID(), numHubShards)].conns, connections) 514} 515 516func TestHubSharding(t *testing.T) { 517 numUsers := numHubShards * 10 518 numChannels := numHubShards * 10 519 520 channels := make([]string, 0, numChannels) 521 for i := 0; i < numChannels; i++ { 522 channels = append(channels, "ch"+strconv.Itoa(i)) 523 } 524 525 n := defaultTestNode() 526 defer func() { _ = n.Shutdown(context.Background()) }() 527 528 for j := 0; j < 2; j++ { // two connections from the same user. 529 for i := 0; i < numUsers; i++ { 530 c, err := newClient(context.Background(), n, newTestTransport(func() {})) 531 require.NoError(t, err) 532 c.user = strconv.Itoa(i) 533 require.NoError(t, err) 534 _ = n.hub.add(c) 535 for _, ch := range channels { 536 _, _ = n.hub.addSub(ch, c) 537 } 538 } 539 } 540 541 for i := range n.hub.connShards { 542 require.NotZero(t, n.hub.connShards[i].NumClients()) 543 require.NotZero(t, n.hub.connShards[i].NumUsers()) 544 } 545 for i := range n.hub.subShards { 546 require.True(t, len(n.hub.subShards[i].subs) > 0) 547 } 548 549 require.Equal(t, numUsers, n.Hub().NumUsers()) 550 require.Equal(t, 2*numUsers, n.Hub().NumClients()) 551 require.Equal(t, numChannels, n.Hub().NumChannels()) 552} 553 554// This benchmark allows to estimate the benefit from Hub sharding. 555// As we have a broadcasting goroutine here it's not very useful to look at 556// total allocations here - it's better to look at operation time. 557func BenchmarkHub_Contention(b *testing.B) { 558 numClients := 100 559 numChannels := 128 560 561 n := defaultTestNodeBenchmark(b) 562 563 var clients []*Client 564 var channels []string 565 566 for i := 0; i < numChannels; i++ { 567 channels = append(channels, "ch"+strconv.Itoa(i)) 568 } 569 570 for i := 0; i < numClients; i++ { 571 c, err := newClient(context.Background(), n, newTestTransport(func() {})) 572 require.NoError(b, err) 573 _ = n.hub.add(c) 574 clients = append(clients, c) 575 for _, ch := range channels { 576 _, _ = n.hub.addSub(ch, c) 577 } 578 } 579 580 pub := &Publication{ 581 Data: []byte(`{"input": "test"}`), 582 } 583 streamPosition := StreamPosition{} 584 585 b.ResetTimer() 586 b.RunParallel(func(pb *testing.PB) { 587 i := 0 588 for pb.Next() { 589 i++ 590 var wg sync.WaitGroup 591 wg.Add(1) 592 go func() { 593 defer wg.Done() 594 _ = n.hub.BroadcastPublication(channels[(i+numChannels/2)%numChannels], pub, streamPosition) 595 }() 596 _, _ = n.hub.addSub(channels[i%numChannels], clients[i%numClients]) 597 wg.Wait() 598 } 599 }) 600} 601 602var broadcastBenches = []struct { 603 NumSubscribers int 604}{ 605 {1000}, 606 {10000}, 607 {100000}, 608} 609 610// BenchmarkHub_MassiveBroadcast allows estimating time to broadcast 611// a single message to many subscribers inside one channel. 612func BenchmarkHub_MassiveBroadcast(b *testing.B) { 613 pub := &Publication{Data: []byte(`{"input": "test"}`)} 614 streamPosition := StreamPosition{} 615 616 for _, tt := range broadcastBenches { 617 numSubscribers := tt.NumSubscribers 618 b.Run(fmt.Sprintf("%d", numSubscribers), func(b *testing.B) { 619 b.ReportAllocs() 620 n := defaultTestNodeBenchmark(b) 621 622 numChannels := 64 623 channels := make([]string, 0, numChannels) 624 625 for i := 0; i < numChannels; i++ { 626 channels = append(channels, "broadcast"+strconv.Itoa(i)) 627 } 628 629 sink := make(chan []byte, 1024) 630 631 for i := 0; i < numSubscribers; i++ { 632 t := newTestTransport(func() {}) 633 t.setSink(sink) 634 c, err := newClient(context.Background(), n, t) 635 require.NoError(b, err) 636 _ = n.hub.add(c) 637 for _, ch := range channels { 638 _, _ = n.hub.addSub(ch, c) 639 } 640 } 641 642 b.ResetTimer() 643 for i := 0; i < b.N; i++ { 644 var wg sync.WaitGroup 645 wg.Add(1) 646 go func() { 647 defer wg.Done() 648 j := 0 649 for { 650 <-sink 651 j++ 652 if j == numSubscribers { 653 break 654 } 655 } 656 }() 657 _ = n.hub.BroadcastPublication(channels[i%numChannels], pub, streamPosition) 658 wg.Wait() 659 } 660 }) 661 } 662} 663