1package vault 2 3import ( 4 "context" 5 "sync" 6 "testing" 7 "time" 8 9 log "github.com/hashicorp/go-hclog" 10 uuid "github.com/hashicorp/go-uuid" 11 12 "github.com/hashicorp/vault/helper/namespace" 13 "github.com/hashicorp/vault/sdk/helper/logging" 14) 15 16// mockRollback returns a mock rollback manager 17func mockRollback(t *testing.T) (*RollbackManager, *NoopBackend) { 18 backend := new(NoopBackend) 19 mounts := new(MountTable) 20 router := NewRouter() 21 core, _, _ := TestCoreUnsealed(t) 22 23 _, barrier, _ := mockBarrier(t) 24 view := NewBarrierView(barrier, "logical/") 25 26 mounts.Entries = []*MountEntry{ 27 &MountEntry{ 28 Path: "foo", 29 NamespaceID: namespace.RootNamespaceID, 30 namespace: namespace.RootNamespace, 31 }, 32 } 33 meUUID, err := uuid.GenerateUUID() 34 if err != nil { 35 t.Fatal(err) 36 } 37 38 if err := router.Mount(backend, "foo", &MountEntry{UUID: meUUID, Accessor: "noopaccessor", NamespaceID: namespace.RootNamespaceID, namespace: namespace.RootNamespace}, view); err != nil { 39 t.Fatalf("err: %s", err) 40 } 41 42 mountsFunc := func() []*MountEntry { 43 return mounts.Entries 44 } 45 46 logger := logging.NewVaultLogger(log.Trace) 47 48 rb := NewRollbackManager(context.Background(), logger, mountsFunc, router, core) 49 rb.period = 10 * time.Millisecond 50 return rb, backend 51} 52 53func TestRollbackManager(t *testing.T) { 54 m, backend := mockRollback(t) 55 if len(backend.Paths) > 0 { 56 t.Fatalf("bad: %#v", backend) 57 } 58 59 m.Start() 60 time.Sleep(50 * time.Millisecond) 61 m.Stop() 62 63 count := len(backend.Paths) 64 if count == 0 { 65 t.Fatalf("bad: %#v", backend) 66 } 67 if backend.Paths[0] != "" { 68 t.Fatalf("bad: %#v", backend) 69 } 70 71 time.Sleep(50 * time.Millisecond) 72 73 if count != len(backend.Paths) { 74 t.Fatalf("should stop requests: %#v", backend) 75 } 76} 77 78func TestRollbackManager_Join(t *testing.T) { 79 m, backend := mockRollback(t) 80 if len(backend.Paths) > 0 { 81 t.Fatalf("bad: %#v", backend) 82 } 83 84 m.Start() 85 defer m.Stop() 86 87 wg := &sync.WaitGroup{} 88 wg.Add(3) 89 90 errCh := make(chan error, 3) 91 go func() { 92 defer wg.Done() 93 err := m.Rollback(namespace.RootContext(nil), "foo") 94 if err != nil { 95 errCh <- err 96 } 97 }() 98 99 go func() { 100 defer wg.Done() 101 err := m.Rollback(namespace.RootContext(nil), "foo") 102 if err != nil { 103 errCh <- err 104 } 105 }() 106 107 go func() { 108 defer wg.Done() 109 err := m.Rollback(namespace.RootContext(nil), "foo") 110 if err != nil { 111 errCh <- err 112 } 113 }() 114 wg.Wait() 115 close(errCh) 116 err := <-errCh 117 if err != nil { 118 t.Fatalf("Error on rollback:%v", err) 119 } 120} 121