1package consul
2
3import (
4	"fmt"
5	"os"
6	"strings"
7	"testing"
8	"time"
9
10	"github.com/hashicorp/consul/agent/structs"
11	"github.com/hashicorp/consul/sdk/testutil/retry"
12	"github.com/hashicorp/consul/testrpc"
13	"github.com/hashicorp/go-uuid"
14	"github.com/hashicorp/net-rpc-msgpackrpc"
15)
16
17func generateUUID() (ret string) {
18	var err error
19	if ret, err = uuid.GenerateUUID(); err != nil {
20		panic(fmt.Sprintf("Unable to generate a UUID, %v", err))
21	}
22	return ret
23}
24
25func TestInitializeSessionTimers(t *testing.T) {
26	if testing.Short() {
27		t.Skip("too slow for testing.Short")
28	}
29
30	t.Parallel()
31	dir1, s1 := testServer(t)
32	defer os.RemoveAll(dir1)
33	defer s1.Shutdown()
34
35	testrpc.WaitForLeader(t, s1.RPC, "dc1")
36
37	state := s1.fsm.State()
38	if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
39		t.Fatalf("err: %s", err)
40	}
41	session := &structs.Session{
42		ID:   generateUUID(),
43		Node: "foo",
44		TTL:  "10s",
45	}
46	if err := state.SessionCreate(100, session); err != nil {
47		t.Fatalf("err: %v", err)
48	}
49
50	// Reset the session timers
51	err := s1.initializeSessionTimers()
52	if err != nil {
53		t.Fatalf("err: %v", err)
54	}
55
56	// Check that we have a timer
57	if s1.sessionTimers.Get(session.ID) == nil {
58		t.Fatalf("missing session timer")
59	}
60}
61
62func TestResetSessionTimer_Fault(t *testing.T) {
63	if testing.Short() {
64		t.Skip("too slow for testing.Short")
65	}
66
67	t.Parallel()
68	dir1, s1 := testServer(t)
69	defer os.RemoveAll(dir1)
70	defer s1.Shutdown()
71
72	testrpc.WaitForLeader(t, s1.RPC, "dc1")
73
74	// Should not exist
75	err := s1.resetSessionTimer(generateUUID(), nil)
76	if err == nil || !strings.Contains(err.Error(), "not found") {
77		t.Fatalf("err: %v", err)
78	}
79
80	// Create a session
81	state := s1.fsm.State()
82	if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
83		t.Fatalf("err: %s", err)
84	}
85	session := &structs.Session{
86		ID:   generateUUID(),
87		Node: "foo",
88		TTL:  "10s",
89	}
90	if err := state.SessionCreate(100, session); err != nil {
91		t.Fatalf("err: %v", err)
92	}
93
94	// Reset the session timer
95	err = s1.resetSessionTimer(session.ID, nil)
96	if err != nil {
97		t.Fatalf("err: %v", err)
98	}
99
100	// Check that we have a timer
101	if s1.sessionTimers.Get(session.ID) == nil {
102		t.Fatalf("missing session timer")
103	}
104}
105
106func TestResetSessionTimer_NoTTL(t *testing.T) {
107	if testing.Short() {
108		t.Skip("too slow for testing.Short")
109	}
110
111	t.Parallel()
112	dir1, s1 := testServer(t)
113	defer os.RemoveAll(dir1)
114	defer s1.Shutdown()
115
116	testrpc.WaitForLeader(t, s1.RPC, "dc1")
117
118	// Create a session
119	state := s1.fsm.State()
120	if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
121		t.Fatalf("err: %s", err)
122	}
123	session := &structs.Session{
124		ID:   generateUUID(),
125		Node: "foo",
126		TTL:  "0000s",
127	}
128	if err := state.SessionCreate(100, session); err != nil {
129		t.Fatalf("err: %v", err)
130	}
131
132	// Reset the session timer
133	err := s1.resetSessionTimer(session.ID, session)
134	if err != nil {
135		t.Fatalf("err: %v", err)
136	}
137
138	// Check that we have a timer
139	if s1.sessionTimers.Get(session.ID) != nil {
140		t.Fatalf("should not have session timer")
141	}
142}
143
144func TestResetSessionTimer_InvalidTTL(t *testing.T) {
145	t.Parallel()
146	dir1, s1 := testServer(t)
147	defer os.RemoveAll(dir1)
148	defer s1.Shutdown()
149
150	// Create a session
151	session := &structs.Session{
152		ID:   generateUUID(),
153		Node: "foo",
154		TTL:  "foo",
155	}
156
157	// Reset the session timer
158	err := s1.resetSessionTimer(session.ID, session)
159	if err == nil || !strings.Contains(err.Error(), "Invalid Session TTL") {
160		t.Fatalf("err: %v", err)
161	}
162}
163
164func TestResetSessionTimerLocked(t *testing.T) {
165	if testing.Short() {
166		t.Skip("too slow for testing.Short")
167	}
168
169	t.Parallel()
170	dir1, s1 := testServer(t)
171	defer os.RemoveAll(dir1)
172	defer s1.Shutdown()
173
174	testrpc.WaitForLeader(t, s1.RPC, "dc1")
175
176	s1.createSessionTimer("foo", 5*time.Millisecond, nil)
177	if s1.sessionTimers.Get("foo") == nil {
178		t.Fatalf("missing timer")
179	}
180
181	retry.Run(t, func(r *retry.R) {
182		if s1.sessionTimers.Get("foo") != nil {
183			r.Fatal("timer should be gone")
184		}
185	})
186}
187
188func TestResetSessionTimerLocked_Renew(t *testing.T) {
189	if testing.Short() {
190		t.Skip("too slow for testing.Short")
191	}
192
193	dir1, s1 := testServer(t)
194	defer os.RemoveAll(dir1)
195	defer s1.Shutdown()
196
197	ttl := 100 * time.Millisecond
198
199	retry.Run(t, func(r *retry.R) {
200		// create the timer and make verify it was created
201		s1.createSessionTimer("foo", ttl, nil)
202		if s1.sessionTimers.Get("foo") == nil {
203			r.Fatalf("missing timer")
204		}
205
206		// wait until it is "expired" but still exists
207		// the session will exist until 2*ttl
208		time.Sleep(ttl)
209		if s1.sessionTimers.Get("foo") == nil {
210			r.Fatal("missing timer")
211		}
212	})
213
214	retry.Run(t, func(r *retry.R) {
215		// renew the session which will reset the TTL to 2*ttl
216		// since that is the current SessionTTLMultiplier
217		s1.createSessionTimer("foo", ttl, nil)
218		if s1.sessionTimers.Get("foo") == nil {
219			r.Fatal("missing timer")
220		}
221		renew := time.Now()
222
223		// Ensure invalidation happens after ttl
224		for {
225			// if timer still exists, sleep and continue
226			if s1.sessionTimers.Get("foo") != nil {
227				time.Sleep(time.Millisecond)
228				continue
229			}
230
231			// fail if timer gone before ttl passes
232			now := time.Now()
233			if now.Sub(renew) < ttl {
234				r.Fatalf("early invalidate")
235			}
236			break
237		}
238	})
239}
240
241func TestInvalidateSession(t *testing.T) {
242	if testing.Short() {
243		t.Skip("too slow for testing.Short")
244	}
245
246	t.Parallel()
247	dir1, s1 := testServer(t)
248	defer os.RemoveAll(dir1)
249	defer s1.Shutdown()
250
251	testrpc.WaitForLeader(t, s1.RPC, "dc1")
252
253	// Create a session
254	state := s1.fsm.State()
255	if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
256		t.Fatalf("err: %s", err)
257	}
258
259	session := &structs.Session{
260		ID:   generateUUID(),
261		Node: "foo",
262		TTL:  "10s",
263	}
264	if err := state.SessionCreate(100, session); err != nil {
265		t.Fatalf("err: %v", err)
266	}
267
268	// This should cause a destroy
269	s1.invalidateSession(session.ID, nil)
270
271	// Check it is gone
272	_, sess, err := state.SessionGet(nil, session.ID, nil)
273	if err != nil {
274		t.Fatalf("err: %v", err)
275	}
276	if sess != nil {
277		t.Fatalf("should destroy session")
278	}
279}
280
281func TestClearSessionTimer(t *testing.T) {
282	t.Parallel()
283	dir1, s1 := testServer(t)
284	defer os.RemoveAll(dir1)
285	defer s1.Shutdown()
286
287	s1.createSessionTimer("foo", 5*time.Millisecond, nil)
288
289	err := s1.clearSessionTimer("foo")
290	if err != nil {
291		t.Fatalf("err: %v", err)
292	}
293
294	if s1.sessionTimers.Get("foo") != nil {
295		t.Fatalf("timer should be gone")
296	}
297}
298
299func TestClearAllSessionTimers(t *testing.T) {
300	t.Parallel()
301	dir1, s1 := testServer(t)
302	defer os.RemoveAll(dir1)
303	defer s1.Shutdown()
304
305	s1.createSessionTimer("foo", 10*time.Millisecond, nil)
306	s1.createSessionTimer("bar", 10*time.Millisecond, nil)
307	s1.createSessionTimer("baz", 10*time.Millisecond, nil)
308
309	s1.clearAllSessionTimers()
310
311	// sessionTimers is guarded by the lock
312	if s1.sessionTimers.Len() != 0 {
313		t.Fatalf("timers should be gone")
314	}
315}
316
317func TestServer_SessionTTL_Failover(t *testing.T) {
318	if testing.Short() {
319		t.Skip("too slow for testing.Short")
320	}
321
322	t.Parallel()
323	dir1, s1 := testServer(t)
324	defer os.RemoveAll(dir1)
325	defer s1.Shutdown()
326	testrpc.WaitForTestAgent(t, s1.RPC, "dc1")
327
328	dir2, s2 := testServerDCBootstrap(t, "dc1", false)
329	defer os.RemoveAll(dir2)
330	defer s2.Shutdown()
331
332	dir3, s3 := testServerDCBootstrap(t, "dc1", false)
333	defer os.RemoveAll(dir3)
334	defer s3.Shutdown()
335	servers := []*Server{s1, s2, s3}
336
337	// Try to join
338	joinLAN(t, s2, s1)
339	joinLAN(t, s3, s1)
340	retry.Run(t, func(r *retry.R) { r.Check(wantPeers(s1, 3)) })
341
342	// Find the leader
343	var leader *Server
344	for _, s := range servers {
345		// Check that s.sessionTimers is empty
346		if s.sessionTimers.Len() != 0 {
347			t.Fatalf("should have no sessionTimers")
348		}
349		// Find the leader too
350		if s.IsLeader() {
351			leader = s
352		}
353	}
354	if leader == nil {
355		t.Fatalf("Should have a leader")
356	}
357
358	codec := rpcClient(t, leader)
359	defer codec.Close()
360
361	// Register a node
362	node := structs.RegisterRequest{
363		Datacenter: s1.config.Datacenter,
364		Node:       "foo",
365		Address:    "127.0.0.1",
366	}
367	var out struct{}
368	if err := s1.RPC("Catalog.Register", &node, &out); err != nil {
369		t.Fatalf("err: %v", err)
370	}
371
372	// Create a TTL session
373	arg := structs.SessionRequest{
374		Datacenter: "dc1",
375		Op:         structs.SessionCreate,
376		Session: structs.Session{
377			Node: "foo",
378			TTL:  "10s",
379		},
380	}
381	var id1 string
382	if err := msgpackrpc.CallWithCodec(codec, "Session.Apply", &arg, &id1); err != nil {
383		t.Fatalf("err: %v", err)
384	}
385
386	// Check that sessionTimers has the session ID
387	if leader.sessionTimers.Get(id1) == nil {
388		t.Fatalf("missing session timer")
389	}
390
391	// Shutdown the leader!
392	leader.Shutdown()
393
394	// sessionTimers should be cleared on leader shutdown
395	if leader.sessionTimers.Len() != 0 {
396		t.Fatalf("session timers should be empty on the shutdown leader")
397	}
398	// Find the new leader
399	retry.Run(t, func(r *retry.R) {
400		leader = nil
401		for _, s := range servers {
402			if s.IsLeader() {
403				leader = s
404			}
405		}
406		if leader == nil {
407			r.Fatal("Should have a new leader")
408		}
409
410		// Ensure session timer is restored
411		if leader.sessionTimers.Get(id1) == nil {
412			r.Fatal("missing session timer")
413		}
414	})
415}
416