1package consul 2 3import ( 4 "fmt" 5 "os" 6 "strings" 7 "testing" 8 "time" 9 10 "github.com/hashicorp/consul/agent/structs" 11 "github.com/hashicorp/consul/sdk/testutil/retry" 12 "github.com/hashicorp/consul/testrpc" 13 "github.com/hashicorp/go-uuid" 14 "github.com/hashicorp/net-rpc-msgpackrpc" 15) 16 17func generateUUID() (ret string) { 18 var err error 19 if ret, err = uuid.GenerateUUID(); err != nil { 20 panic(fmt.Sprintf("Unable to generate a UUID, %v", err)) 21 } 22 return ret 23} 24 25func TestInitializeSessionTimers(t *testing.T) { 26 if testing.Short() { 27 t.Skip("too slow for testing.Short") 28 } 29 30 t.Parallel() 31 dir1, s1 := testServer(t) 32 defer os.RemoveAll(dir1) 33 defer s1.Shutdown() 34 35 testrpc.WaitForLeader(t, s1.RPC, "dc1") 36 37 state := s1.fsm.State() 38 if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { 39 t.Fatalf("err: %s", err) 40 } 41 session := &structs.Session{ 42 ID: generateUUID(), 43 Node: "foo", 44 TTL: "10s", 45 } 46 if err := state.SessionCreate(100, session); err != nil { 47 t.Fatalf("err: %v", err) 48 } 49 50 // Reset the session timers 51 err := s1.initializeSessionTimers() 52 if err != nil { 53 t.Fatalf("err: %v", err) 54 } 55 56 // Check that we have a timer 57 if s1.sessionTimers.Get(session.ID) == nil { 58 t.Fatalf("missing session timer") 59 } 60} 61 62func TestResetSessionTimer_Fault(t *testing.T) { 63 if testing.Short() { 64 t.Skip("too slow for testing.Short") 65 } 66 67 t.Parallel() 68 dir1, s1 := testServer(t) 69 defer os.RemoveAll(dir1) 70 defer s1.Shutdown() 71 72 testrpc.WaitForLeader(t, s1.RPC, "dc1") 73 74 // Should not exist 75 err := s1.resetSessionTimer(generateUUID(), nil) 76 if err == nil || !strings.Contains(err.Error(), "not found") { 77 t.Fatalf("err: %v", err) 78 } 79 80 // Create a session 81 state := s1.fsm.State() 82 if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { 83 t.Fatalf("err: %s", err) 84 } 85 session := &structs.Session{ 86 ID: generateUUID(), 87 Node: "foo", 88 TTL: "10s", 89 } 90 if err := state.SessionCreate(100, session); err != nil { 91 t.Fatalf("err: %v", err) 92 } 93 94 // Reset the session timer 95 err = s1.resetSessionTimer(session.ID, nil) 96 if err != nil { 97 t.Fatalf("err: %v", err) 98 } 99 100 // Check that we have a timer 101 if s1.sessionTimers.Get(session.ID) == nil { 102 t.Fatalf("missing session timer") 103 } 104} 105 106func TestResetSessionTimer_NoTTL(t *testing.T) { 107 if testing.Short() { 108 t.Skip("too slow for testing.Short") 109 } 110 111 t.Parallel() 112 dir1, s1 := testServer(t) 113 defer os.RemoveAll(dir1) 114 defer s1.Shutdown() 115 116 testrpc.WaitForLeader(t, s1.RPC, "dc1") 117 118 // Create a session 119 state := s1.fsm.State() 120 if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { 121 t.Fatalf("err: %s", err) 122 } 123 session := &structs.Session{ 124 ID: generateUUID(), 125 Node: "foo", 126 TTL: "0000s", 127 } 128 if err := state.SessionCreate(100, session); err != nil { 129 t.Fatalf("err: %v", err) 130 } 131 132 // Reset the session timer 133 err := s1.resetSessionTimer(session.ID, session) 134 if err != nil { 135 t.Fatalf("err: %v", err) 136 } 137 138 // Check that we have a timer 139 if s1.sessionTimers.Get(session.ID) != nil { 140 t.Fatalf("should not have session timer") 141 } 142} 143 144func TestResetSessionTimer_InvalidTTL(t *testing.T) { 145 t.Parallel() 146 dir1, s1 := testServer(t) 147 defer os.RemoveAll(dir1) 148 defer s1.Shutdown() 149 150 // Create a session 151 session := &structs.Session{ 152 ID: generateUUID(), 153 Node: "foo", 154 TTL: "foo", 155 } 156 157 // Reset the session timer 158 err := s1.resetSessionTimer(session.ID, session) 159 if err == nil || !strings.Contains(err.Error(), "Invalid Session TTL") { 160 t.Fatalf("err: %v", err) 161 } 162} 163 164func TestResetSessionTimerLocked(t *testing.T) { 165 if testing.Short() { 166 t.Skip("too slow for testing.Short") 167 } 168 169 t.Parallel() 170 dir1, s1 := testServer(t) 171 defer os.RemoveAll(dir1) 172 defer s1.Shutdown() 173 174 testrpc.WaitForLeader(t, s1.RPC, "dc1") 175 176 s1.createSessionTimer("foo", 5*time.Millisecond, nil) 177 if s1.sessionTimers.Get("foo") == nil { 178 t.Fatalf("missing timer") 179 } 180 181 retry.Run(t, func(r *retry.R) { 182 if s1.sessionTimers.Get("foo") != nil { 183 r.Fatal("timer should be gone") 184 } 185 }) 186} 187 188func TestResetSessionTimerLocked_Renew(t *testing.T) { 189 if testing.Short() { 190 t.Skip("too slow for testing.Short") 191 } 192 193 dir1, s1 := testServer(t) 194 defer os.RemoveAll(dir1) 195 defer s1.Shutdown() 196 197 ttl := 100 * time.Millisecond 198 199 retry.Run(t, func(r *retry.R) { 200 // create the timer and make verify it was created 201 s1.createSessionTimer("foo", ttl, nil) 202 if s1.sessionTimers.Get("foo") == nil { 203 r.Fatalf("missing timer") 204 } 205 206 // wait until it is "expired" but still exists 207 // the session will exist until 2*ttl 208 time.Sleep(ttl) 209 if s1.sessionTimers.Get("foo") == nil { 210 r.Fatal("missing timer") 211 } 212 }) 213 214 retry.Run(t, func(r *retry.R) { 215 // renew the session which will reset the TTL to 2*ttl 216 // since that is the current SessionTTLMultiplier 217 s1.createSessionTimer("foo", ttl, nil) 218 if s1.sessionTimers.Get("foo") == nil { 219 r.Fatal("missing timer") 220 } 221 renew := time.Now() 222 223 // Ensure invalidation happens after ttl 224 for { 225 // if timer still exists, sleep and continue 226 if s1.sessionTimers.Get("foo") != nil { 227 time.Sleep(time.Millisecond) 228 continue 229 } 230 231 // fail if timer gone before ttl passes 232 now := time.Now() 233 if now.Sub(renew) < ttl { 234 r.Fatalf("early invalidate") 235 } 236 break 237 } 238 }) 239} 240 241func TestInvalidateSession(t *testing.T) { 242 if testing.Short() { 243 t.Skip("too slow for testing.Short") 244 } 245 246 t.Parallel() 247 dir1, s1 := testServer(t) 248 defer os.RemoveAll(dir1) 249 defer s1.Shutdown() 250 251 testrpc.WaitForLeader(t, s1.RPC, "dc1") 252 253 // Create a session 254 state := s1.fsm.State() 255 if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { 256 t.Fatalf("err: %s", err) 257 } 258 259 session := &structs.Session{ 260 ID: generateUUID(), 261 Node: "foo", 262 TTL: "10s", 263 } 264 if err := state.SessionCreate(100, session); err != nil { 265 t.Fatalf("err: %v", err) 266 } 267 268 // This should cause a destroy 269 s1.invalidateSession(session.ID, nil) 270 271 // Check it is gone 272 _, sess, err := state.SessionGet(nil, session.ID, nil) 273 if err != nil { 274 t.Fatalf("err: %v", err) 275 } 276 if sess != nil { 277 t.Fatalf("should destroy session") 278 } 279} 280 281func TestClearSessionTimer(t *testing.T) { 282 t.Parallel() 283 dir1, s1 := testServer(t) 284 defer os.RemoveAll(dir1) 285 defer s1.Shutdown() 286 287 s1.createSessionTimer("foo", 5*time.Millisecond, nil) 288 289 err := s1.clearSessionTimer("foo") 290 if err != nil { 291 t.Fatalf("err: %v", err) 292 } 293 294 if s1.sessionTimers.Get("foo") != nil { 295 t.Fatalf("timer should be gone") 296 } 297} 298 299func TestClearAllSessionTimers(t *testing.T) { 300 t.Parallel() 301 dir1, s1 := testServer(t) 302 defer os.RemoveAll(dir1) 303 defer s1.Shutdown() 304 305 s1.createSessionTimer("foo", 10*time.Millisecond, nil) 306 s1.createSessionTimer("bar", 10*time.Millisecond, nil) 307 s1.createSessionTimer("baz", 10*time.Millisecond, nil) 308 309 s1.clearAllSessionTimers() 310 311 // sessionTimers is guarded by the lock 312 if s1.sessionTimers.Len() != 0 { 313 t.Fatalf("timers should be gone") 314 } 315} 316 317func TestServer_SessionTTL_Failover(t *testing.T) { 318 if testing.Short() { 319 t.Skip("too slow for testing.Short") 320 } 321 322 t.Parallel() 323 dir1, s1 := testServer(t) 324 defer os.RemoveAll(dir1) 325 defer s1.Shutdown() 326 testrpc.WaitForTestAgent(t, s1.RPC, "dc1") 327 328 dir2, s2 := testServerDCBootstrap(t, "dc1", false) 329 defer os.RemoveAll(dir2) 330 defer s2.Shutdown() 331 332 dir3, s3 := testServerDCBootstrap(t, "dc1", false) 333 defer os.RemoveAll(dir3) 334 defer s3.Shutdown() 335 servers := []*Server{s1, s2, s3} 336 337 // Try to join 338 joinLAN(t, s2, s1) 339 joinLAN(t, s3, s1) 340 retry.Run(t, func(r *retry.R) { r.Check(wantPeers(s1, 3)) }) 341 342 // Find the leader 343 var leader *Server 344 for _, s := range servers { 345 // Check that s.sessionTimers is empty 346 if s.sessionTimers.Len() != 0 { 347 t.Fatalf("should have no sessionTimers") 348 } 349 // Find the leader too 350 if s.IsLeader() { 351 leader = s 352 } 353 } 354 if leader == nil { 355 t.Fatalf("Should have a leader") 356 } 357 358 codec := rpcClient(t, leader) 359 defer codec.Close() 360 361 // Register a node 362 node := structs.RegisterRequest{ 363 Datacenter: s1.config.Datacenter, 364 Node: "foo", 365 Address: "127.0.0.1", 366 } 367 var out struct{} 368 if err := s1.RPC("Catalog.Register", &node, &out); err != nil { 369 t.Fatalf("err: %v", err) 370 } 371 372 // Create a TTL session 373 arg := structs.SessionRequest{ 374 Datacenter: "dc1", 375 Op: structs.SessionCreate, 376 Session: structs.Session{ 377 Node: "foo", 378 TTL: "10s", 379 }, 380 } 381 var id1 string 382 if err := msgpackrpc.CallWithCodec(codec, "Session.Apply", &arg, &id1); err != nil { 383 t.Fatalf("err: %v", err) 384 } 385 386 // Check that sessionTimers has the session ID 387 if leader.sessionTimers.Get(id1) == nil { 388 t.Fatalf("missing session timer") 389 } 390 391 // Shutdown the leader! 392 leader.Shutdown() 393 394 // sessionTimers should be cleared on leader shutdown 395 if leader.sessionTimers.Len() != 0 { 396 t.Fatalf("session timers should be empty on the shutdown leader") 397 } 398 // Find the new leader 399 retry.Run(t, func(r *retry.R) { 400 leader = nil 401 for _, s := range servers { 402 if s.IsLeader() { 403 leader = s 404 } 405 } 406 if leader == nil { 407 r.Fatal("Should have a new leader") 408 } 409 410 // Ensure session timer is restored 411 if leader.sessionTimers.Get(id1) == nil { 412 r.Fatal("missing session timer") 413 } 414 }) 415} 416