1package consul
2
3import (
4	"net/rpc"
5	"os"
6	"testing"
7
8	"github.com/hashicorp/consul/testrpc"
9	"github.com/stretchr/testify/require"
10)
11
12type testClusterConfig struct {
13	Datacenter string
14	Servers    int
15	Clients    int
16	ServerConf func(*Config)
17	ClientConf func(*Config)
18
19	ServerWait func(*testing.T, *Server)
20	ClientWait func(*testing.T, *Client)
21}
22
// testCluster holds the agents started by newTestCluster. Their shutdown,
// RPC codec close, and data-directory removal are all registered with the
// test's t.Cleanup, so callers do not need to tear them down manually.
type testCluster struct {
	// Servers in creation order; Servers[0] is the node that every other
	// server (and every client) LAN-joins.
	Servers      []*Server
	// ServerCodecs[i] is the codec returned by rpcClient for Servers[i].
	ServerCodecs []rpc.ClientCodec
	// Clients in creation order.
	Clients      []*Client
}
28
29func newTestCluster(t *testing.T, conf *testClusterConfig) *testCluster {
30	t.Helper()
31
32	require.NotNil(t, conf)
33	cluster := testCluster{}
34
35	// create the servers
36	for i := 0; i < conf.Servers; i++ {
37		dir, srv := testServerWithConfig(t, func(c *Config) {
38			if conf.Datacenter != "" {
39				c.Datacenter = conf.Datacenter
40			}
41			c.Bootstrap = false
42			c.BootstrapExpect = conf.Servers
43
44			if conf.ServerConf != nil {
45				conf.ServerConf(c)
46			}
47		})
48		t.Cleanup(func() { os.RemoveAll(dir) })
49		t.Cleanup(func() { srv.Shutdown() })
50
51		cluster.Servers = append(cluster.Servers, srv)
52
53		codec := rpcClient(t, srv)
54
55		cluster.ServerCodecs = append(cluster.ServerCodecs, codec)
56		t.Cleanup(func() { codec.Close() })
57
58		if i > 0 {
59			joinLAN(t, srv, cluster.Servers[0])
60		}
61	}
62
63	waitForLeaderEstablishment(t, cluster.Servers...)
64	if conf.ServerWait != nil {
65		for _, srv := range cluster.Servers {
66			conf.ServerWait(t, srv)
67		}
68	}
69
70	// create the clients
71	for i := 0; i < conf.Clients; i++ {
72		dir, client := testClientWithConfig(t, func(c *Config) {
73			if conf.Datacenter != "" {
74				c.Datacenter = conf.Datacenter
75			}
76			if conf.ClientConf != nil {
77				conf.ClientConf(c)
78			}
79		})
80
81		t.Cleanup(func() { os.RemoveAll(dir) })
82		t.Cleanup(func() { client.Shutdown() })
83
84		if len(cluster.Servers) > 0 {
85			joinLAN(t, client, cluster.Servers[0])
86		}
87
88		cluster.Clients = append(cluster.Clients, client)
89	}
90
91	for _, client := range cluster.Clients {
92		if conf.ClientWait != nil {
93			conf.ClientWait(t, client)
94		} else {
95			testrpc.WaitForTestAgent(t, client.RPC, client.config.Datacenter)
96		}
97	}
98
99	return &cluster
100}
101