1package consul
2
3import (
4	"fmt"
5	"os"
6	"strings"
7	"testing"
8	"time"
9
10	"github.com/hashicorp/consul/agent/structs"
11	"github.com/hashicorp/consul/sdk/testutil/retry"
12	"github.com/hashicorp/consul/testrpc"
13	"github.com/hashicorp/go-uuid"
14	"github.com/hashicorp/net-rpc-msgpackrpc"
15)
16
17func generateUUID() (ret string) {
18	var err error
19	if ret, err = uuid.GenerateUUID(); err != nil {
20		panic(fmt.Sprintf("Unable to generate a UUID, %v", err))
21	}
22	return ret
23}
24
25func TestInitializeSessionTimers(t *testing.T) {
26	t.Parallel()
27	dir1, s1 := testServer(t)
28	defer os.RemoveAll(dir1)
29	defer s1.Shutdown()
30
31	testrpc.WaitForLeader(t, s1.RPC, "dc1")
32
33	state := s1.fsm.State()
34	if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
35		t.Fatalf("err: %s", err)
36	}
37	session := &structs.Session{
38		ID:   generateUUID(),
39		Node: "foo",
40		TTL:  "10s",
41	}
42	if err := state.SessionCreate(100, session); err != nil {
43		t.Fatalf("err: %v", err)
44	}
45
46	// Reset the session timers
47	err := s1.initializeSessionTimers()
48	if err != nil {
49		t.Fatalf("err: %v", err)
50	}
51
52	// Check that we have a timer
53	if s1.sessionTimers.Get(session.ID) == nil {
54		t.Fatalf("missing session timer")
55	}
56}
57
58func TestResetSessionTimer_Fault(t *testing.T) {
59	t.Parallel()
60	dir1, s1 := testServer(t)
61	defer os.RemoveAll(dir1)
62	defer s1.Shutdown()
63
64	testrpc.WaitForLeader(t, s1.RPC, "dc1")
65
66	// Should not exist
67	err := s1.resetSessionTimer(generateUUID(), nil)
68	if err == nil || !strings.Contains(err.Error(), "not found") {
69		t.Fatalf("err: %v", err)
70	}
71
72	// Create a session
73	state := s1.fsm.State()
74	if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
75		t.Fatalf("err: %s", err)
76	}
77	session := &structs.Session{
78		ID:   generateUUID(),
79		Node: "foo",
80		TTL:  "10s",
81	}
82	if err := state.SessionCreate(100, session); err != nil {
83		t.Fatalf("err: %v", err)
84	}
85
86	// Reset the session timer
87	err = s1.resetSessionTimer(session.ID, nil)
88	if err != nil {
89		t.Fatalf("err: %v", err)
90	}
91
92	// Check that we have a timer
93	if s1.sessionTimers.Get(session.ID) == nil {
94		t.Fatalf("missing session timer")
95	}
96}
97
98func TestResetSessionTimer_NoTTL(t *testing.T) {
99	t.Parallel()
100	dir1, s1 := testServer(t)
101	defer os.RemoveAll(dir1)
102	defer s1.Shutdown()
103
104	testrpc.WaitForLeader(t, s1.RPC, "dc1")
105
106	// Create a session
107	state := s1.fsm.State()
108	if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
109		t.Fatalf("err: %s", err)
110	}
111	session := &structs.Session{
112		ID:   generateUUID(),
113		Node: "foo",
114		TTL:  "0000s",
115	}
116	if err := state.SessionCreate(100, session); err != nil {
117		t.Fatalf("err: %v", err)
118	}
119
120	// Reset the session timer
121	err := s1.resetSessionTimer(session.ID, session)
122	if err != nil {
123		t.Fatalf("err: %v", err)
124	}
125
126	// Check that we have a timer
127	if s1.sessionTimers.Get(session.ID) != nil {
128		t.Fatalf("should not have session timer")
129	}
130}
131
132func TestResetSessionTimer_InvalidTTL(t *testing.T) {
133	t.Parallel()
134	dir1, s1 := testServer(t)
135	defer os.RemoveAll(dir1)
136	defer s1.Shutdown()
137
138	// Create a session
139	session := &structs.Session{
140		ID:   generateUUID(),
141		Node: "foo",
142		TTL:  "foo",
143	}
144
145	// Reset the session timer
146	err := s1.resetSessionTimer(session.ID, session)
147	if err == nil || !strings.Contains(err.Error(), "Invalid Session TTL") {
148		t.Fatalf("err: %v", err)
149	}
150}
151
152func TestResetSessionTimerLocked(t *testing.T) {
153	t.Parallel()
154	dir1, s1 := testServer(t)
155	defer os.RemoveAll(dir1)
156	defer s1.Shutdown()
157
158	testrpc.WaitForLeader(t, s1.RPC, "dc1")
159
160	s1.createSessionTimer("foo", 5*time.Millisecond, nil)
161	if s1.sessionTimers.Get("foo") == nil {
162		t.Fatalf("missing timer")
163	}
164
165	retry.Run(t, func(r *retry.R) {
166		if s1.sessionTimers.Get("foo") != nil {
167			r.Fatal("timer should be gone")
168		}
169	})
170}
171
172func TestResetSessionTimerLocked_Renew(t *testing.T) {
173	dir1, s1 := testServer(t)
174	defer os.RemoveAll(dir1)
175	defer s1.Shutdown()
176
177	ttl := 100 * time.Millisecond
178
179	retry.Run(t, func(r *retry.R) {
180		// create the timer and make verify it was created
181		s1.createSessionTimer("foo", ttl, nil)
182		if s1.sessionTimers.Get("foo") == nil {
183			r.Fatalf("missing timer")
184		}
185
186		// wait until it is "expired" but still exists
187		// the session will exist until 2*ttl
188		time.Sleep(ttl)
189		if s1.sessionTimers.Get("foo") == nil {
190			r.Fatal("missing timer")
191		}
192	})
193
194	retry.Run(t, func(r *retry.R) {
195		// renew the session which will reset the TTL to 2*ttl
196		// since that is the current SessionTTLMultiplier
197		s1.createSessionTimer("foo", ttl, nil)
198		if s1.sessionTimers.Get("foo") == nil {
199			r.Fatal("missing timer")
200		}
201		renew := time.Now()
202
203		// Ensure invalidation happens after ttl
204		for {
205			// if timer still exists, sleep and continue
206			if s1.sessionTimers.Get("foo") != nil {
207				time.Sleep(time.Millisecond)
208				continue
209			}
210
211			// fail if timer gone before ttl passes
212			now := time.Now()
213			if now.Sub(renew) < ttl {
214				r.Fatalf("early invalidate")
215			}
216			break
217		}
218	})
219}
220
221func TestInvalidateSession(t *testing.T) {
222	t.Parallel()
223	dir1, s1 := testServer(t)
224	defer os.RemoveAll(dir1)
225	defer s1.Shutdown()
226
227	testrpc.WaitForLeader(t, s1.RPC, "dc1")
228
229	// Create a session
230	state := s1.fsm.State()
231	if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
232		t.Fatalf("err: %s", err)
233	}
234
235	session := &structs.Session{
236		ID:   generateUUID(),
237		Node: "foo",
238		TTL:  "10s",
239	}
240	if err := state.SessionCreate(100, session); err != nil {
241		t.Fatalf("err: %v", err)
242	}
243
244	// This should cause a destroy
245	s1.invalidateSession(session.ID, nil)
246
247	// Check it is gone
248	_, sess, err := state.SessionGet(nil, session.ID, nil)
249	if err != nil {
250		t.Fatalf("err: %v", err)
251	}
252	if sess != nil {
253		t.Fatalf("should destroy session")
254	}
255}
256
257func TestClearSessionTimer(t *testing.T) {
258	t.Parallel()
259	dir1, s1 := testServer(t)
260	defer os.RemoveAll(dir1)
261	defer s1.Shutdown()
262
263	s1.createSessionTimer("foo", 5*time.Millisecond, nil)
264
265	err := s1.clearSessionTimer("foo")
266	if err != nil {
267		t.Fatalf("err: %v", err)
268	}
269
270	if s1.sessionTimers.Get("foo") != nil {
271		t.Fatalf("timer should be gone")
272	}
273}
274
275func TestClearAllSessionTimers(t *testing.T) {
276	t.Parallel()
277	dir1, s1 := testServer(t)
278	defer os.RemoveAll(dir1)
279	defer s1.Shutdown()
280
281	s1.createSessionTimer("foo", 10*time.Millisecond, nil)
282	s1.createSessionTimer("bar", 10*time.Millisecond, nil)
283	s1.createSessionTimer("baz", 10*time.Millisecond, nil)
284
285	s1.clearAllSessionTimers()
286
287	// sessionTimers is guarded by the lock
288	if s1.sessionTimers.Len() != 0 {
289		t.Fatalf("timers should be gone")
290	}
291}
292
293func TestServer_SessionTTL_Failover(t *testing.T) {
294	t.Parallel()
295	dir1, s1 := testServer(t)
296	defer os.RemoveAll(dir1)
297	defer s1.Shutdown()
298	testrpc.WaitForTestAgent(t, s1.RPC, "dc1")
299
300	dir2, s2 := testServerDCBootstrap(t, "dc1", false)
301	defer os.RemoveAll(dir2)
302	defer s2.Shutdown()
303
304	dir3, s3 := testServerDCBootstrap(t, "dc1", false)
305	defer os.RemoveAll(dir3)
306	defer s3.Shutdown()
307	servers := []*Server{s1, s2, s3}
308
309	// Try to join
310	joinLAN(t, s2, s1)
311	joinLAN(t, s3, s1)
312	retry.Run(t, func(r *retry.R) { r.Check(wantPeers(s1, 3)) })
313
314	// Find the leader
315	var leader *Server
316	for _, s := range servers {
317		// Check that s.sessionTimers is empty
318		if s.sessionTimers.Len() != 0 {
319			t.Fatalf("should have no sessionTimers")
320		}
321		// Find the leader too
322		if s.IsLeader() {
323			leader = s
324		}
325	}
326	if leader == nil {
327		t.Fatalf("Should have a leader")
328	}
329
330	codec := rpcClient(t, leader)
331	defer codec.Close()
332
333	// Register a node
334	node := structs.RegisterRequest{
335		Datacenter: s1.config.Datacenter,
336		Node:       "foo",
337		Address:    "127.0.0.1",
338	}
339	var out struct{}
340	if err := s1.RPC("Catalog.Register", &node, &out); err != nil {
341		t.Fatalf("err: %v", err)
342	}
343
344	// Create a TTL session
345	arg := structs.SessionRequest{
346		Datacenter: "dc1",
347		Op:         structs.SessionCreate,
348		Session: structs.Session{
349			Node: "foo",
350			TTL:  "10s",
351		},
352	}
353	var id1 string
354	if err := msgpackrpc.CallWithCodec(codec, "Session.Apply", &arg, &id1); err != nil {
355		t.Fatalf("err: %v", err)
356	}
357
358	// Check that sessionTimers has the session ID
359	if leader.sessionTimers.Get(id1) == nil {
360		t.Fatalf("missing session timer")
361	}
362
363	// Shutdown the leader!
364	leader.Shutdown()
365
366	// sessionTimers should be cleared on leader shutdown
367	if leader.sessionTimers.Len() != 0 {
368		t.Fatalf("session timers should be empty on the shutdown leader")
369	}
370	// Find the new leader
371	retry.Run(t, func(r *retry.R) {
372		leader = nil
373		for _, s := range servers {
374			if s.IsLeader() {
375				leader = s
376			}
377		}
378		if leader == nil {
379			r.Fatal("Should have a new leader")
380		}
381
382		// Ensure session timer is restored
383		if leader.sessionTimers.Get(id1) == nil {
384			r.Fatal("missing session timer")
385		}
386	})
387}
388