1package consul
2
3import (
4	"fmt"
5	"os"
6	"strings"
7	"testing"
8	"time"
9
10	"github.com/hashicorp/consul/agent/structs"
11	"github.com/hashicorp/consul/testrpc"
12	"github.com/hashicorp/consul/testutil/retry"
13	"github.com/hashicorp/go-uuid"
14	"github.com/hashicorp/net-rpc-msgpackrpc"
15)
16
17func generateUUID() (ret string) {
18	var err error
19	if ret, err = uuid.GenerateUUID(); err != nil {
20		panic(fmt.Sprintf("Unable to generate a UUID, %v", err))
21	}
22	return ret
23}
24
25func TestInitializeSessionTimers(t *testing.T) {
26	t.Parallel()
27	dir1, s1 := testServer(t)
28	defer os.RemoveAll(dir1)
29	defer s1.Shutdown()
30
31	testrpc.WaitForLeader(t, s1.RPC, "dc1")
32
33	state := s1.fsm.State()
34	if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
35		t.Fatalf("err: %s", err)
36	}
37	session := &structs.Session{
38		ID:   generateUUID(),
39		Node: "foo",
40		TTL:  "10s",
41	}
42	if err := state.SessionCreate(100, session); err != nil {
43		t.Fatalf("err: %v", err)
44	}
45
46	// Reset the session timers
47	err := s1.initializeSessionTimers()
48	if err != nil {
49		t.Fatalf("err: %v", err)
50	}
51
52	// Check that we have a timer
53	if s1.sessionTimers.Get(session.ID) == nil {
54		t.Fatalf("missing session timer")
55	}
56}
57
58func TestResetSessionTimer_Fault(t *testing.T) {
59	t.Parallel()
60	dir1, s1 := testServer(t)
61	defer os.RemoveAll(dir1)
62	defer s1.Shutdown()
63
64	testrpc.WaitForLeader(t, s1.RPC, "dc1")
65
66	// Should not exist
67	err := s1.resetSessionTimer(generateUUID(), nil)
68	if err == nil || !strings.Contains(err.Error(), "not found") {
69		t.Fatalf("err: %v", err)
70	}
71
72	// Create a session
73	state := s1.fsm.State()
74	if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
75		t.Fatalf("err: %s", err)
76	}
77	session := &structs.Session{
78		ID:   generateUUID(),
79		Node: "foo",
80		TTL:  "10s",
81	}
82	if err := state.SessionCreate(100, session); err != nil {
83		t.Fatalf("err: %v", err)
84	}
85
86	// Reset the session timer
87	err = s1.resetSessionTimer(session.ID, nil)
88	if err != nil {
89		t.Fatalf("err: %v", err)
90	}
91
92	// Check that we have a timer
93	if s1.sessionTimers.Get(session.ID) == nil {
94		t.Fatalf("missing session timer")
95	}
96}
97
98func TestResetSessionTimer_NoTTL(t *testing.T) {
99	t.Parallel()
100	dir1, s1 := testServer(t)
101	defer os.RemoveAll(dir1)
102	defer s1.Shutdown()
103
104	testrpc.WaitForLeader(t, s1.RPC, "dc1")
105
106	// Create a session
107	state := s1.fsm.State()
108	if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
109		t.Fatalf("err: %s", err)
110	}
111	session := &structs.Session{
112		ID:   generateUUID(),
113		Node: "foo",
114		TTL:  "0000s",
115	}
116	if err := state.SessionCreate(100, session); err != nil {
117		t.Fatalf("err: %v", err)
118	}
119
120	// Reset the session timer
121	err := s1.resetSessionTimer(session.ID, session)
122	if err != nil {
123		t.Fatalf("err: %v", err)
124	}
125
126	// Check that we have a timer
127	if s1.sessionTimers.Get(session.ID) != nil {
128		t.Fatalf("should not have session timer")
129	}
130}
131
132func TestResetSessionTimer_InvalidTTL(t *testing.T) {
133	t.Parallel()
134	dir1, s1 := testServer(t)
135	defer os.RemoveAll(dir1)
136	defer s1.Shutdown()
137
138	// Create a session
139	session := &structs.Session{
140		ID:   generateUUID(),
141		Node: "foo",
142		TTL:  "foo",
143	}
144
145	// Reset the session timer
146	err := s1.resetSessionTimer(session.ID, session)
147	if err == nil || !strings.Contains(err.Error(), "Invalid Session TTL") {
148		t.Fatalf("err: %v", err)
149	}
150}
151
152func TestResetSessionTimerLocked(t *testing.T) {
153	t.Parallel()
154	dir1, s1 := testServer(t)
155	defer os.RemoveAll(dir1)
156	defer s1.Shutdown()
157
158	testrpc.WaitForLeader(t, s1.RPC, "dc1")
159
160	s1.createSessionTimer("foo", 5*time.Millisecond)
161	if s1.sessionTimers.Get("foo") == nil {
162		t.Fatalf("missing timer")
163	}
164
165	time.Sleep(10 * time.Millisecond * structs.SessionTTLMultiplier)
166	if s1.sessionTimers.Get("foo") != nil {
167		t.Fatalf("timer should be gone")
168	}
169}
170
171func TestResetSessionTimerLocked_Renew(t *testing.T) {
172	t.Parallel()
173	dir1, s1 := testServer(t)
174	defer os.RemoveAll(dir1)
175	defer s1.Shutdown()
176
177	ttl := 100 * time.Millisecond
178
179	// create the timer
180	s1.createSessionTimer("foo", ttl)
181	if s1.sessionTimers.Get("foo") == nil {
182		t.Fatalf("missing timer")
183	}
184
185	// wait until it is "expired" but at this point
186	// the session still exists.
187	time.Sleep(ttl)
188	if s1.sessionTimers.Get("foo") == nil {
189		t.Fatal("missing timer")
190	}
191
192	// renew the session which will reset the TTL to 2*ttl
193	// since that is the current SessionTTLMultiplier
194	s1.createSessionTimer("foo", ttl)
195
196	// Watch for invalidation
197	renew := time.Now()
198	deadline := renew.Add(2 * structs.SessionTTLMultiplier * ttl)
199	for {
200		now := time.Now()
201		if now.After(deadline) {
202			t.Fatal("should have expired by now")
203		}
204
205		// timer still exists
206		if s1.sessionTimers.Get("foo") != nil {
207			time.Sleep(time.Millisecond)
208			continue
209		}
210
211		// timer gone
212		if now.Sub(renew) < ttl {
213			t.Fatalf("early invalidate")
214		}
215		break
216	}
217}
218
219func TestInvalidateSession(t *testing.T) {
220	t.Parallel()
221	dir1, s1 := testServer(t)
222	defer os.RemoveAll(dir1)
223	defer s1.Shutdown()
224
225	testrpc.WaitForLeader(t, s1.RPC, "dc1")
226
227	// Create a session
228	state := s1.fsm.State()
229	if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
230		t.Fatalf("err: %s", err)
231	}
232	session := &structs.Session{
233		ID:   generateUUID(),
234		Node: "foo",
235		TTL:  "10s",
236	}
237	if err := state.SessionCreate(100, session); err != nil {
238		t.Fatalf("err: %v", err)
239	}
240
241	// This should cause a destroy
242	s1.invalidateSession(session.ID)
243
244	// Check it is gone
245	_, sess, err := state.SessionGet(nil, session.ID)
246	if err != nil {
247		t.Fatalf("err: %v", err)
248	}
249	if sess != nil {
250		t.Fatalf("should destroy session")
251	}
252}
253
254func TestClearSessionTimer(t *testing.T) {
255	t.Parallel()
256	dir1, s1 := testServer(t)
257	defer os.RemoveAll(dir1)
258	defer s1.Shutdown()
259
260	s1.createSessionTimer("foo", 5*time.Millisecond)
261
262	err := s1.clearSessionTimer("foo")
263	if err != nil {
264		t.Fatalf("err: %v", err)
265	}
266
267	if s1.sessionTimers.Get("foo") != nil {
268		t.Fatalf("timer should be gone")
269	}
270}
271
272func TestClearAllSessionTimers(t *testing.T) {
273	t.Parallel()
274	dir1, s1 := testServer(t)
275	defer os.RemoveAll(dir1)
276	defer s1.Shutdown()
277
278	s1.createSessionTimer("foo", 10*time.Millisecond)
279	s1.createSessionTimer("bar", 10*time.Millisecond)
280	s1.createSessionTimer("baz", 10*time.Millisecond)
281
282	err := s1.clearAllSessionTimers()
283	if err != nil {
284		t.Fatalf("err: %v", err)
285	}
286
287	// sessionTimers is guarded by the lock
288	if s1.sessionTimers.Len() != 0 {
289		t.Fatalf("timers should be gone")
290	}
291}
292
293func TestServer_SessionTTL_Failover(t *testing.T) {
294	t.Parallel()
295	dir1, s1 := testServer(t)
296	defer os.RemoveAll(dir1)
297	defer s1.Shutdown()
298	testrpc.WaitForTestAgent(t, s1.RPC, "dc1")
299
300	dir2, s2 := testServerDCBootstrap(t, "dc1", false)
301	defer os.RemoveAll(dir2)
302	defer s2.Shutdown()
303
304	dir3, s3 := testServerDCBootstrap(t, "dc1", false)
305	defer os.RemoveAll(dir3)
306	defer s3.Shutdown()
307	servers := []*Server{s1, s2, s3}
308
309	// Try to join
310	joinLAN(t, s2, s1)
311	joinLAN(t, s3, s1)
312	retry.Run(t, func(r *retry.R) { r.Check(wantPeers(s1, 3)) })
313
314	// Find the leader
315	var leader *Server
316	for _, s := range servers {
317		// Check that s.sessionTimers is empty
318		if s.sessionTimers.Len() != 0 {
319			t.Fatalf("should have no sessionTimers")
320		}
321		// Find the leader too
322		if s.IsLeader() {
323			leader = s
324		}
325	}
326	if leader == nil {
327		t.Fatalf("Should have a leader")
328	}
329
330	codec := rpcClient(t, leader)
331	defer codec.Close()
332
333	// Register a node
334	node := structs.RegisterRequest{
335		Datacenter: s1.config.Datacenter,
336		Node:       "foo",
337		Address:    "127.0.0.1",
338	}
339	var out struct{}
340	if err := s1.RPC("Catalog.Register", &node, &out); err != nil {
341		t.Fatalf("err: %v", err)
342	}
343
344	// Create a TTL session
345	arg := structs.SessionRequest{
346		Datacenter: "dc1",
347		Op:         structs.SessionCreate,
348		Session: structs.Session{
349			Node: "foo",
350			TTL:  "10s",
351		},
352	}
353	var id1 string
354	if err := msgpackrpc.CallWithCodec(codec, "Session.Apply", &arg, &id1); err != nil {
355		t.Fatalf("err: %v", err)
356	}
357
358	// Check that sessionTimers has the session ID
359	if leader.sessionTimers.Get(id1) == nil {
360		t.Fatalf("missing session timer")
361	}
362
363	// Shutdown the leader!
364	leader.Shutdown()
365
366	// sessionTimers should be cleared on leader shutdown
367	if leader.sessionTimers.Len() != 0 {
368		t.Fatalf("session timers should be empty on the shutdown leader")
369	}
370	// Find the new leader
371	retry.Run(t, func(r *retry.R) {
372		leader = nil
373		for _, s := range servers {
374			if s.IsLeader() {
375				leader = s
376			}
377		}
378		if leader == nil {
379			r.Fatal("Should have a new leader")
380		}
381
382		// Ensure session timer is restored
383		if leader.sessionTimers.Get(id1) == nil {
384			r.Fatal("missing session timer")
385		}
386	})
387}
388