1package zk
2
3import (
4	"fmt"
5	"io"
6	"io/ioutil"
7	"math/rand"
8	"os"
9	"path/filepath"
10	"strings"
11	"testing"
12	"time"
13)
14
15const (
16	_testConfigName   = "zoo.cfg"
17	_testMyIDFileName = "myid"
18)
19
// init seeds math/rand from the wall clock so that the random base port
// chosen in StartTestCluster varies from one test run to the next.
func init() {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}
23
// TestServer describes a single member of a TestCluster as created by
// StartTestCluster.
type TestServer struct {
	// Port is the client port this member listens on (its ServerConfig.ClientPort).
	Port int
	// Path is this member's data directory; StartTestCluster writes its
	// zoo.cfg and myid files there.
	Path string
	// Srv is the handle used to Start/Stop this member.
	Srv *server
	// Config is this member's own entry (ID, peer and leader-election ports)
	// from the ensemble's server list.
	Config ServerConfigServer
}
30
// TestCluster is a local ZooKeeper ensemble started by StartTestCluster.
type TestCluster struct {
	// Path is the temporary root directory holding one sub-directory per
	// server; Stop removes it.
	Path string
	// Config is the ServerConfig written for the last server created by
	// StartTestCluster (its ClientPort/DataDir are that server's; the Servers
	// list is the same for every member).
	Config ServerConfig
	// Servers lists the members in creation order (server ID 1..n).
	Servers []TestServer
}
36
37func StartTestCluster(t *testing.T, size int, stdout, stderr io.Writer) (*TestCluster, error) {
38	if testing.Short() {
39		t.Skip("ZK cluster tests skipped in short case.")
40	}
41	tmpPath, err := ioutil.TempDir("", "gozk")
42	requireNoError(t, err, "failed to create tmp dir for test server setup")
43
44	success := false
45	startPort := int(rand.Int31n(6000) + 10000)
46	cluster := &TestCluster{Path: tmpPath}
47
48	defer func() {
49		if !success {
50			cluster.Stop()
51		}
52	}()
53
54	for serverN := 0; serverN < size; serverN++ {
55		srvPath := filepath.Join(tmpPath, fmt.Sprintf("srv%d", serverN+1))
56		if err := os.Mkdir(srvPath, 0700); err != nil {
57			requireNoError(t, err, "failed to make server path")
58		}
59
60		port := startPort + serverN*3
61		cfg := ServerConfig{
62			ClientPort: port,
63			DataDir:    srvPath,
64		}
65
66		for i := 0; i < size; i++ {
67			serverNConfig := ServerConfigServer{
68				ID:                 i + 1,
69				Host:               "127.0.0.1",
70				PeerPort:           startPort + i*3 + 1,
71				LeaderElectionPort: startPort + i*3 + 2,
72			}
73
74			cfg.Servers = append(cfg.Servers, serverNConfig)
75		}
76
77		cfgPath := filepath.Join(srvPath, _testConfigName)
78		fi, err := os.Create(cfgPath)
79		requireNoError(t, err)
80
81		requireNoError(t, cfg.Marshall(fi))
82		fi.Close()
83
84		fi, err = os.Create(filepath.Join(srvPath, _testMyIDFileName))
85		requireNoError(t, err)
86
87		_, err = fmt.Fprintf(fi, "%d\n", serverN+1)
88		fi.Close()
89		requireNoError(t, err)
90
91		srv, err := NewIntegrationTestServer(t, cfgPath, stdout, stderr)
92		requireNoError(t, err)
93
94		if err := srv.Start(); err != nil {
95			return nil, err
96		}
97
98		cluster.Servers = append(cluster.Servers, TestServer{
99			Path:   srvPath,
100			Port:   cfg.ClientPort,
101			Srv:    srv,
102			Config: cfg.Servers[serverN],
103		})
104		cluster.Config = cfg
105	}
106
107	if err := cluster.waitForStart(30, time.Second); err != nil {
108		return nil, err
109	}
110
111	success = true
112
113	return cluster, nil
114}
115
116func (tc *TestCluster) Connect(idx int) (*Conn, <-chan Event, error) {
117	return Connect([]string{fmt.Sprintf("127.0.0.1:%d", tc.Servers[idx].Port)}, time.Second*15)
118}
119
120func (tc *TestCluster) ConnectAll() (*Conn, <-chan Event, error) {
121	return tc.ConnectAllTimeout(time.Second * 15)
122}
123
124func (tc *TestCluster) ConnectAllTimeout(sessionTimeout time.Duration) (*Conn, <-chan Event, error) {
125	return tc.ConnectWithOptions(sessionTimeout)
126}
127
128func (tc *TestCluster) ConnectWithOptions(sessionTimeout time.Duration, options ...connOption) (*Conn, <-chan Event, error) {
129	hosts := make([]string, len(tc.Servers))
130	for i, srv := range tc.Servers {
131		hosts[i] = fmt.Sprintf("127.0.0.1:%d", srv.Port)
132	}
133	zk, ch, err := Connect(hosts, sessionTimeout, options...)
134	return zk, ch, err
135}
136
137func (tc *TestCluster) Stop() error {
138	for _, srv := range tc.Servers {
139		srv.Srv.Stop()
140	}
141	defer os.RemoveAll(tc.Path)
142	return tc.waitForStop(5, time.Second)
143}
144
145// waitForStart blocks until the cluster is up
146func (tc *TestCluster) waitForStart(maxRetry int, interval time.Duration) error {
147	// verify that the servers are up with SRVR
148	serverAddrs := make([]string, len(tc.Servers))
149	for i, s := range tc.Servers {
150		serverAddrs[i] = fmt.Sprintf("127.0.0.1:%d", s.Port)
151	}
152
153	for i := 0; i < maxRetry; i++ {
154		_, ok := FLWSrvr(serverAddrs, time.Second)
155		if ok {
156			return nil
157		}
158		time.Sleep(interval)
159	}
160
161	return fmt.Errorf("unable to verify health of servers")
162}
163
164// waitForStop blocks until the cluster is down
165func (tc *TestCluster) waitForStop(maxRetry int, interval time.Duration) error {
166	// verify that the servers are up with RUOK
167	serverAddrs := make([]string, len(tc.Servers))
168	for i, s := range tc.Servers {
169		serverAddrs[i] = fmt.Sprintf("127.0.0.1:%d", s.Port)
170	}
171
172	var success bool
173	for i := 0; i < maxRetry && !success; i++ {
174		success = true
175		for _, ok := range FLWRuok(serverAddrs, time.Second) {
176			if ok {
177				success = false
178			}
179		}
180		if !success {
181			time.Sleep(interval)
182		}
183	}
184	if !success {
185		return fmt.Errorf("unable to verify servers are down")
186	}
187	return nil
188}
189
190func (tc *TestCluster) StartServer(server string) {
191	for _, s := range tc.Servers {
192		if strings.HasSuffix(server, fmt.Sprintf(":%d", s.Port)) {
193			s.Srv.Start()
194			return
195		}
196	}
197	panic(fmt.Sprintf("unknown server: %s", server))
198}
199
200func (tc *TestCluster) StopServer(server string) {
201	for _, s := range tc.Servers {
202		if strings.HasSuffix(server, fmt.Sprintf(":%d", s.Port)) {
203			s.Srv.Stop()
204			return
205		}
206	}
207	panic(fmt.Sprintf("unknown server: %s", server))
208}
209
210func (tc *TestCluster) StartAllServers() error {
211	for _, s := range tc.Servers {
212		if err := s.Srv.Start(); err != nil {
213			return fmt.Errorf("failed to start server listening on port `%d` : %+v", s.Port, err)
214		}
215	}
216
217	if err := tc.waitForStart(10, time.Second*2); err != nil {
218		return fmt.Errorf("failed to wait to startup zk servers: %v", err)
219	}
220
221	return nil
222}
223
224func (tc *TestCluster) StopAllServers() error {
225	var err error
226	for _, s := range tc.Servers {
227		if err := s.Srv.Stop(); err != nil {
228			err = fmt.Errorf("failed to stop server listening on port `%d` : %v", s.Port, err)
229		}
230	}
231	if err != nil {
232		return err
233	}
234
235	if err := tc.waitForStop(5, time.Second); err != nil {
236		return fmt.Errorf("failed to wait to startup zk servers: %v", err)
237	}
238
239	return nil
240}
241
// requireNoError fails the test immediately when err is non-nil. It logs the
// error, then calls t.Fatal with msgAndArgs as the failure message. It is
// marked as a test helper so failures are reported at the caller's line.
func requireNoError(t *testing.T, err error, msgAndArgs ...interface{}) {
	t.Helper()
	if err != nil {
		t.Logf("received unexpected error: %v", err)
		t.Fatal(msgAndArgs...)
	}
}
248