1/*
2Copyright 2018 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package connrotation
18
19import (
20	"context"
21	"net"
22	"sync"
23	"sync/atomic"
24	"testing"
25	"time"
26)
27
28func TestCloseAll(t *testing.T) {
29	closed := make(chan struct{}, 50)
30	dialFn := func(ctx context.Context, network, address string) (net.Conn, error) {
31		return closeOnlyConn{onClose: func() { closed <- struct{}{} }}, nil
32	}
33	dialer := NewDialer(dialFn)
34
35	const numConns = 10
36
37	// Outer loop to ensure Dialer is re-usable after CloseAll.
38	for i := 0; i < 5; i++ {
39		for j := 0; j < numConns; j++ {
40			if _, err := dialer.Dial("", ""); err != nil {
41				t.Fatal(err)
42			}
43		}
44		dialer.CloseAll()
45		deadline := time.After(time.Second)
46		for j := 0; j < numConns; j++ {
47			select {
48			case <-closed:
49			case <-deadline:
50				t.Fatalf("iteration %d: 1s after CloseAll only %d/%d connections closed", i, j, numConns)
51			}
52		}
53	}
54}
55
56// TestCloseAllRace ensures CloseAll works with connections being simultaneously dialed
57func TestCloseAllRace(t *testing.T) {
58	conns := int64(0)
59	dialer := NewDialer(func(ctx context.Context, network, address string) (net.Conn, error) {
60		return closeOnlyConn{onClose: func() { atomic.AddInt64(&conns, -1) }}, nil
61	})
62
63	const raceCount = 5000
64	begin := &sync.WaitGroup{}
65	begin.Add(1)
66
67	wg := &sync.WaitGroup{}
68
69	// Close all as fast as we can
70	wg.Add(1)
71	go func() {
72		begin.Wait()
73		defer wg.Done()
74		for i := 0; i < raceCount; i++ {
75			dialer.CloseAll()
76		}
77	}()
78
79	// Dial as fast as we can
80	wg.Add(1)
81	go func() {
82		begin.Wait()
83		defer wg.Done()
84		for i := 0; i < raceCount; i++ {
85			if _, err := dialer.Dial("", ""); err != nil {
86				t.Error(err)
87				return
88			}
89			atomic.AddInt64(&conns, 1)
90		}
91	}()
92
93	// Trigger both goroutines as close to the same time as possible
94	begin.Done()
95
96	// Wait for goroutines
97	wg.Wait()
98
99	// Ensure CloseAll ran after all dials
100	dialer.CloseAll()
101
102	// Expect all connections to close within 5 seconds
103	for start := time.Now(); time.Now().Sub(start) < 5*time.Second; time.Sleep(10 * time.Millisecond) {
104		// Ensure all connections were closed
105		if c := atomic.LoadInt64(&conns); c == 0 {
106			break
107		} else {
108			t.Logf("got %d open connections, want 0, will retry", c)
109		}
110	}
111	// Ensure all connections were closed
112	if c := atomic.LoadInt64(&conns); c != 0 {
113		t.Fatalf("got %d open connections, want 0", c)
114	}
115}
116
// closeOnlyConn is a stub net.Conn used by these tests. Only Close is
// implemented. The tests here never set the embedded net.Conn, so calling
// any other net.Conn method on such a value dereferences a nil interface
// and panics.
type closeOnlyConn struct {
	net.Conn

	// onClose is run on a fresh goroutine each time Close is called.
	onClose func()
}

// Close hands the notification off to a new goroutine, so Close never waits
// for onClose to finish, and it always returns nil.
func (conn closeOnlyConn) Close() error {
	notify := conn.onClose
	go notify()
	return nil
}
126