1package cachetype
2
3import (
4	"fmt"
5	"testing"
6	"time"
7
8	"github.com/hashicorp/consul/agent/cache"
9	"github.com/hashicorp/consul/agent/checks"
10	"github.com/hashicorp/consul/agent/local"
11	"github.com/hashicorp/consul/agent/structs"
12	"github.com/hashicorp/consul/agent/token"
13	"github.com/hashicorp/consul/types"
14	"github.com/hashicorp/go-memdb"
15	"github.com/stretchr/testify/require"
16)
17
18func TestServiceHTTPChecks_Fetch(t *testing.T) {
19	chkTypes := []*structs.CheckType{
20		{
21			CheckID:       "http-check",
22			HTTP:          "localhost:8080/health",
23			Interval:      5 * time.Second,
24			OutputMaxSize: checks.DefaultBufSize,
25		},
26		{
27			CheckID:  "grpc-check",
28			GRPC:     "localhost:9090/v1.Health",
29			Interval: 5 * time.Second,
30		},
31		{
32			CheckID: "ttl-check",
33			TTL:     10 * time.Second,
34		},
35	}
36
37	svcState := local.ServiceState{
38		Service: &structs.NodeService{
39			ID: "web",
40		},
41	}
42
43	// Create mockAgent and cache type
44	a := newMockAgent()
45	a.LocalState().SetServiceState(&svcState)
46	typ := ServiceHTTPChecks{Agent: a}
47
48	// Adding TTL check should not yield result from Fetch since TTL checks aren't tracked
49	if err := a.AddCheck(*chkTypes[2]); err != nil {
50		t.Fatalf("failed to add check: %v", err)
51	}
52	result, err := typ.Fetch(
53		cache.FetchOptions{},
54		&ServiceHTTPChecksRequest{ServiceID: svcState.Service.ID, MaxQueryTime: 100 * time.Millisecond},
55	)
56	if err != nil {
57		t.Fatalf("failed to fetch: %v", err)
58	}
59	got, ok := result.Value.([]structs.CheckType)
60	if !ok {
61		t.Fatalf("fetched value of wrong type, got %T, want []structs.CheckType", result.Value)
62	}
63	require.Empty(t, got)
64
65	// Adding HTTP check should yield check in Fetch
66	if err := a.AddCheck(*chkTypes[0]); err != nil {
67		t.Fatalf("failed to add check: %v", err)
68	}
69	result, err = typ.Fetch(
70		cache.FetchOptions{},
71		&ServiceHTTPChecksRequest{ServiceID: svcState.Service.ID},
72	)
73	if err != nil {
74		t.Fatalf("failed to fetch: %v", err)
75	}
76	if result.Index != 1 {
77		t.Fatalf("expected index of 1 after first cache hit, got %d", result.Index)
78	}
79	got, ok = result.Value.([]structs.CheckType)
80	if !ok {
81		t.Fatalf("fetched value of wrong type, got %T, want []structs.CheckType", result.Value)
82	}
83	want := chkTypes[0:1]
84	for i, c := range want {
85		require.Equal(t, *c, got[i])
86	}
87
88	// Adding GRPC check should yield both checks in Fetch
89	if err := a.AddCheck(*chkTypes[1]); err != nil {
90		t.Fatalf("failed to add check: %v", err)
91	}
92	result2, err := typ.Fetch(
93		cache.FetchOptions{LastResult: &result},
94		&ServiceHTTPChecksRequest{ServiceID: svcState.Service.ID},
95	)
96	if err != nil {
97		t.Fatalf("failed to fetch: %v", err)
98	}
99	if result2.Index != 2 {
100		t.Fatalf("expected index of 2 after second request, got %d", result2.Index)
101	}
102
103	got, ok = result2.Value.([]structs.CheckType)
104	if !ok {
105		t.Fatalf("fetched value of wrong type, got %T, want []structs.CheckType", got)
106	}
107	want = chkTypes[0:2]
108	for i, c := range want {
109		require.Equal(t, *c, got[i])
110	}
111
112	// Removing GRPC check should yield HTTP check in Fetch
113	if err := a.RemoveCheck(chkTypes[1].CheckID); err != nil {
114		t.Fatalf("failed to add check: %v", err)
115	}
116	result3, err := typ.Fetch(
117		cache.FetchOptions{LastResult: &result2},
118		&ServiceHTTPChecksRequest{ServiceID: svcState.Service.ID},
119	)
120	if err != nil {
121		t.Fatalf("failed to fetch: %v", err)
122	}
123	if result3.Index != 3 {
124		t.Fatalf("expected index of 3 after third request, got %d", result3.Index)
125	}
126
127	got, ok = result3.Value.([]structs.CheckType)
128	if !ok {
129		t.Fatalf("fetched value of wrong type, got %T, want []structs.CheckType", got)
130	}
131	want = chkTypes[0:1]
132	for i, c := range want {
133		require.Equal(t, *c, got[i])
134	}
135
136	// Fetching again should yield no change in result nor index
137	result4, err := typ.Fetch(
138		cache.FetchOptions{LastResult: &result3},
139		&ServiceHTTPChecksRequest{ServiceID: svcState.Service.ID, MaxQueryTime: 100 * time.Millisecond},
140	)
141	if err != nil {
142		t.Fatalf("failed to fetch: %v", err)
143	}
144	if result4.Index != 3 {
145		t.Fatalf("expected index of 3 after fetch timeout, got %d", result4.Index)
146	}
147
148	got, ok = result4.Value.([]structs.CheckType)
149	if !ok {
150		t.Fatalf("fetched value of wrong type, got %T, want []structs.CheckType", got)
151	}
152	want = chkTypes[0:1]
153	for i, c := range want {
154		require.Equal(t, *c, got[i])
155	}
156}
157
158func TestServiceHTTPChecks_badReqType(t *testing.T) {
159	a := newMockAgent()
160	typ := ServiceHTTPChecks{Agent: a}
161
162	// Fetch
163	_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
164		t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
165	require.Error(t, err)
166	require.Contains(t, err.Error(), "wrong request type")
167}
168
// mockAgent is a minimal stand-in for the Agent used by ServiceHTTPChecks in
// these tests. It pairs a real local.State with a flat, test-controlled list
// of check definitions that AddCheck and RemoveCheck mutate directly.
type mockAgent struct {
	// state is the local state returned by LocalState.
	state  *local.State
	// checks holds only HTTP/gRPC check definitions (AddCheck filters the rest);
	// it is returned as-is by ServiceHTTPBasedChecks and LocalBlockingQuery.
	checks []structs.CheckType
}
173
174func newMockAgent() *mockAgent {
175	m := mockAgent{
176		state:  local.NewState(local.Config{NodeID: "host"}, nil, new(token.Store)),
177		checks: make([]structs.CheckType, 0),
178	}
179	m.state.TriggerSyncChanges = func() {}
180	return &m
181}
182
183func (m *mockAgent) ServiceHTTPBasedChecks(id structs.ServiceID) []structs.CheckType {
184	return m.checks
185}
186
187func (m *mockAgent) LocalState() *local.State {
188	return m.state
189}
190
191func (m *mockAgent) LocalBlockingQuery(alwaysBlock bool, hash string, wait time.Duration,
192	fn func(ws memdb.WatchSet) (string, interface{}, error)) (string, interface{}, error) {
193
194	hash, err := hashChecks(m.checks)
195	if err != nil {
196		return "", nil, fmt.Errorf("failed to hash checks: %+v", m.checks)
197	}
198	return hash, m.checks, nil
199}
200
201func (m *mockAgent) AddCheck(check structs.CheckType) error {
202	if check.IsHTTP() || check.IsGRPC() {
203		m.checks = append(m.checks, check)
204	}
205	return nil
206}
207
208func (m *mockAgent) RemoveCheck(id types.CheckID) error {
209	new := make([]structs.CheckType, 0)
210	for _, c := range m.checks {
211		if c.CheckID != id {
212			new = append(new, c)
213		}
214	}
215	m.checks = new
216	return nil
217}
218