1package vault
2
3import (
4	"context"
5	"sync"
6	"testing"
7	"time"
8
9	log "github.com/hashicorp/go-hclog"
10	uuid "github.com/hashicorp/go-uuid"
11
12	"github.com/hashicorp/vault/helper/namespace"
13	"github.com/hashicorp/vault/sdk/helper/logging"
14)
15
16// mockRollback returns a mock rollback manager
17func mockRollback(t *testing.T) (*RollbackManager, *NoopBackend) {
18	backend := new(NoopBackend)
19	mounts := new(MountTable)
20	router := NewRouter()
21	core, _, _ := TestCoreUnsealed(t)
22
23	_, barrier, _ := mockBarrier(t)
24	view := NewBarrierView(barrier, "logical/")
25
26	mounts.Entries = []*MountEntry{
27		&MountEntry{
28			Path:        "foo",
29			NamespaceID: namespace.RootNamespaceID,
30			namespace:   namespace.RootNamespace,
31		},
32	}
33	meUUID, err := uuid.GenerateUUID()
34	if err != nil {
35		t.Fatal(err)
36	}
37
38	if err := router.Mount(backend, "foo", &MountEntry{UUID: meUUID, Accessor: "noopaccessor", NamespaceID: namespace.RootNamespaceID, namespace: namespace.RootNamespace}, view); err != nil {
39		t.Fatalf("err: %s", err)
40	}
41
42	mountsFunc := func() []*MountEntry {
43		return mounts.Entries
44	}
45
46	logger := logging.NewVaultLogger(log.Trace)
47
48	rb := NewRollbackManager(context.Background(), logger, mountsFunc, router, core)
49	rb.period = 10 * time.Millisecond
50	return rb, backend
51}
52
53func TestRollbackManager(t *testing.T) {
54	m, backend := mockRollback(t)
55	if len(backend.Paths) > 0 {
56		t.Fatalf("bad: %#v", backend)
57	}
58
59	m.Start()
60	time.Sleep(50 * time.Millisecond)
61	m.Stop()
62
63	count := len(backend.Paths)
64	if count == 0 {
65		t.Fatalf("bad: %#v", backend)
66	}
67	if backend.Paths[0] != "" {
68		t.Fatalf("bad: %#v", backend)
69	}
70
71	time.Sleep(50 * time.Millisecond)
72
73	if count != len(backend.Paths) {
74		t.Fatalf("should stop requests: %#v", backend)
75	}
76}
77
78func TestRollbackManager_Join(t *testing.T) {
79	m, backend := mockRollback(t)
80	if len(backend.Paths) > 0 {
81		t.Fatalf("bad: %#v", backend)
82	}
83
84	m.Start()
85	defer m.Stop()
86
87	wg := &sync.WaitGroup{}
88	wg.Add(3)
89
90	errCh := make(chan error, 3)
91	go func() {
92		defer wg.Done()
93		err := m.Rollback(namespace.RootContext(nil), "foo")
94		if err != nil {
95			errCh <- err
96		}
97	}()
98
99	go func() {
100		defer wg.Done()
101		err := m.Rollback(namespace.RootContext(nil), "foo")
102		if err != nil {
103			errCh <- err
104		}
105	}()
106
107	go func() {
108		defer wg.Done()
109		err := m.Rollback(namespace.RootContext(nil), "foo")
110		if err != nil {
111			errCh <- err
112		}
113	}()
114	wg.Wait()
115	close(errCh)
116	err := <-errCh
117	if err != nil {
118		t.Fatalf("Error on rollback:%v", err)
119	}
120}
121