1package consul
2
3import (
4	"errors"
5	"testing"
6	"time"
7
8	"github.com/hashicorp/consul/agent/consul/state"
9	"github.com/hashicorp/consul/agent/structs"
10	"github.com/hashicorp/consul/api"
11	"github.com/hashicorp/consul/sdk/testutil"
12	memdb "github.com/hashicorp/go-memdb"
13	"github.com/stretchr/testify/assert"
14	"github.com/stretchr/testify/require"
15)
16
17func TestGatewayLocator(t *testing.T) {
18	state, err := state.NewStateStore(nil)
19	require.NoError(t, err)
20
21	dc1 := &structs.FederationState{
22		Datacenter: "dc1",
23		MeshGateways: []structs.CheckServiceNode{
24			newTestMeshGatewayNode(
25				"dc1", "gateway1", "1.2.3.4", 5555, map[string]string{structs.MetaWANFederationKey: "1"}, api.HealthPassing,
26			),
27			newTestMeshGatewayNode(
28				"dc1", "gateway2", "4.3.2.1", 9999, map[string]string{structs.MetaWANFederationKey: "1"}, api.HealthPassing,
29			),
30		},
31		UpdatedAt: time.Now().UTC(),
32	}
33	dc2 := &structs.FederationState{
34		Datacenter: "dc2",
35		MeshGateways: []structs.CheckServiceNode{
36			newTestMeshGatewayNode(
37				"dc2", "gateway1", "5.6.7.8", 5555, map[string]string{structs.MetaWANFederationKey: "1"}, api.HealthPassing,
38			),
39			newTestMeshGatewayNode(
40				"dc2", "gateway2", "8.7.6.5", 9999, map[string]string{structs.MetaWANFederationKey: "1"}, api.HealthPassing,
41			),
42		},
43		UpdatedAt: time.Now().UTC(),
44	}
45
46	t.Run("primary - no data", func(t *testing.T) {
47		logger := testutil.Logger(t)
48		tsd := &testServerDelegate{State: state, isLeader: true}
49		g := NewGatewayLocator(
50			logger,
51			tsd,
52			"dc1",
53			"dc1",
54		)
55
56		idx, err := g.runOnce(0)
57		require.NoError(t, err)
58		assert.False(t, g.DialPrimaryThroughLocalGateway())
59		assert.Equal(t, uint64(1), idx)
60		assert.Len(t, tsd.Calls, 1)
61		assert.Equal(t, []string(nil), g.listGateways(false))
62		assert.Equal(t, []string(nil), g.listGateways(true))
63	})
64
65	t.Run("secondary - no data", func(t *testing.T) {
66		logger := testutil.Logger(t)
67		tsd := &testServerDelegate{State: state, isLeader: true}
68		g := NewGatewayLocator(
69			logger,
70			tsd,
71			"dc2",
72			"dc1",
73		)
74
75		idx, err := g.runOnce(0)
76		require.NoError(t, err)
77		assert.False(t, g.DialPrimaryThroughLocalGateway())
78		assert.Equal(t, uint64(1), idx)
79		assert.Len(t, tsd.Calls, 1)
80		assert.Equal(t, []string(nil), g.listGateways(false))
81		assert.Equal(t, []string(nil), g.listGateways(true))
82	})
83
84	t.Run("secondary - just fallback", func(t *testing.T) {
85		logger := testutil.Logger(t)
86		tsd := &testServerDelegate{State: state, isLeader: true}
87		g := NewGatewayLocator(
88			logger,
89			tsd,
90			"dc2",
91			"dc1",
92		)
93		g.RefreshPrimaryGatewayFallbackAddresses([]string{
94			"7.7.7.7:7777",
95			"8.8.8.8:8888",
96		})
97
98		idx, err := g.runOnce(0)
99		require.NoError(t, err)
100		assert.False(t, g.DialPrimaryThroughLocalGateway())
101		assert.Equal(t, uint64(1), idx)
102		assert.Len(t, tsd.Calls, 1)
103		assert.Equal(t, []string(nil), g.listGateways(false))
104		assert.Equal(t, []string{
105			"7.7.7.7:7777",
106			"8.8.8.8:8888",
107		}, g.listGateways(true))
108	})
109
110	// Insert data for the dcs
111	require.NoError(t, state.FederationStateSet(1, dc1))
112	require.NoError(t, state.FederationStateSet(2, dc2))
113
114	t.Run("primary - with data", func(t *testing.T) {
115		logger := testutil.Logger(t)
116		tsd := &testServerDelegate{State: state, isLeader: true}
117		g := NewGatewayLocator(
118			logger,
119			tsd,
120			"dc1",
121			"dc1",
122		)
123
124		idx, err := g.runOnce(0)
125		require.NoError(t, err)
126		assert.False(t, g.DialPrimaryThroughLocalGateway())
127		assert.Equal(t, uint64(2), idx)
128		assert.Len(t, tsd.Calls, 1)
129		assert.Equal(t, []string{
130			"1.2.3.4:5555",
131			"4.3.2.1:9999",
132		}, g.listGateways(false))
133		assert.Equal(t, []string{
134			"1.2.3.4:5555",
135			"4.3.2.1:9999",
136		}, g.listGateways(true))
137	})
138
139	t.Run("secondary - with data", func(t *testing.T) {
140		logger := testutil.Logger(t)
141		tsd := &testServerDelegate{State: state, isLeader: true}
142		g := NewGatewayLocator(
143			logger,
144			tsd,
145			"dc2",
146			"dc1",
147		)
148
149		idx, err := g.runOnce(0)
150		require.NoError(t, err)
151		assert.False(t, g.DialPrimaryThroughLocalGateway())
152		assert.Equal(t, uint64(2), idx)
153		assert.Len(t, tsd.Calls, 1)
154		assert.Equal(t, []string{
155			"5.6.7.8:5555",
156			"8.7.6.5:9999",
157		}, g.listGateways(false))
158		assert.Equal(t, []string{
159			"1.2.3.4:5555",
160			"4.3.2.1:9999",
161		}, g.listGateways(true))
162	})
163
164	t.Run("secondary - with data and fallback - no repl", func(t *testing.T) {
165		logger := testutil.Logger(t)
166		tsd := &testServerDelegate{State: state, isLeader: true}
167		g := NewGatewayLocator(
168			logger,
169			tsd,
170			"dc2",
171			"dc1",
172		)
173
174		g.RefreshPrimaryGatewayFallbackAddresses([]string{
175			"7.7.7.7:7777",
176			"8.8.8.8:8888",
177		})
178
179		idx, err := g.runOnce(0)
180		require.NoError(t, err)
181		assert.False(t, g.DialPrimaryThroughLocalGateway())
182		assert.Equal(t, uint64(2), idx)
183		assert.Len(t, tsd.Calls, 1)
184		assert.Equal(t, []string{
185			"5.6.7.8:5555",
186			"8.7.6.5:9999",
187		}, g.listGateways(false))
188		assert.Equal(t, []string{
189			"1.2.3.4:5555",
190			"4.3.2.1:9999",
191			"7.7.7.7:7777",
192			"8.8.8.8:8888",
193		}, g.listGateways(true))
194	})
195
196	t.Run("secondary - with data and fallback - repl ok", func(t *testing.T) {
197		logger := testutil.Logger(t)
198		tsd := &testServerDelegate{State: state, isLeader: true}
199		g := NewGatewayLocator(
200			logger,
201			tsd,
202			"dc2",
203			"dc1",
204		)
205
206		g.RefreshPrimaryGatewayFallbackAddresses([]string{
207			"7.7.7.7:7777",
208			"8.8.8.8:8888",
209		})
210
211		g.SetLastFederationStateReplicationError(nil)
212
213		idx, err := g.runOnce(0)
214		require.NoError(t, err)
215		assert.True(t, g.DialPrimaryThroughLocalGateway())
216		assert.Equal(t, uint64(2), idx)
217		assert.Len(t, tsd.Calls, 1)
218		assert.Equal(t, []string{
219			"5.6.7.8:5555",
220			"8.7.6.5:9999",
221		}, g.listGateways(false))
222		assert.Equal(t, []string{
223			"5.6.7.8:5555",
224			"8.7.6.5:9999",
225		}, g.listGateways(true))
226	})
227
228	t.Run("secondary - with data and fallback - repl ok then failed 2 times", func(t *testing.T) {
229		logger := testutil.Logger(t)
230		tsd := &testServerDelegate{State: state, isLeader: true}
231		g := NewGatewayLocator(
232			logger,
233			tsd,
234			"dc2",
235			"dc1",
236		)
237
238		g.RefreshPrimaryGatewayFallbackAddresses([]string{
239			"7.7.7.7:7777",
240			"8.8.8.8:8888",
241		})
242
243		g.SetLastFederationStateReplicationError(nil)
244		g.SetLastFederationStateReplicationError(errors.New("fake"))
245		g.SetLastFederationStateReplicationError(errors.New("fake"))
246
247		idx, err := g.runOnce(0)
248		require.NoError(t, err)
249		assert.True(t, g.DialPrimaryThroughLocalGateway())
250		assert.Equal(t, uint64(2), idx)
251		assert.Len(t, tsd.Calls, 1)
252		assert.Equal(t, []string{
253			"5.6.7.8:5555",
254			"8.7.6.5:9999",
255		}, g.listGateways(false))
256		assert.Equal(t, []string{
257			"5.6.7.8:5555",
258			"8.7.6.5:9999",
259		}, g.listGateways(true))
260	})
261
262	t.Run("secondary - with data and fallback - repl ok then failed 3 times", func(t *testing.T) {
263		logger := testutil.Logger(t)
264		tsd := &testServerDelegate{State: state, isLeader: true}
265		g := NewGatewayLocator(
266			logger,
267			tsd,
268			"dc2",
269			"dc1",
270		)
271
272		g.RefreshPrimaryGatewayFallbackAddresses([]string{
273			"7.7.7.7:7777",
274			"8.8.8.8:8888",
275		})
276
277		g.SetLastFederationStateReplicationError(nil)
278		g.SetLastFederationStateReplicationError(errors.New("fake"))
279		g.SetLastFederationStateReplicationError(errors.New("fake"))
280		g.SetLastFederationStateReplicationError(errors.New("fake"))
281
282		idx, err := g.runOnce(0)
283		require.NoError(t, err)
284		assert.False(t, g.DialPrimaryThroughLocalGateway())
285		assert.Equal(t, uint64(2), idx)
286		assert.Len(t, tsd.Calls, 1)
287		assert.Equal(t, []string{
288			"5.6.7.8:5555",
289			"8.7.6.5:9999",
290		}, g.listGateways(false))
291		assert.Equal(t, []string{
292			"1.2.3.4:5555",
293			"4.3.2.1:9999",
294			"7.7.7.7:7777",
295			"8.8.8.8:8888",
296		}, g.listGateways(true))
297	})
298
299	t.Run("secondary - with data and fallback - repl ok then failed 3 times then ok again", func(t *testing.T) {
300		logger := testutil.Logger(t)
301		tsd := &testServerDelegate{State: state, isLeader: true}
302		g := NewGatewayLocator(
303			logger,
304			tsd,
305			"dc2",
306			"dc1",
307		)
308
309		g.RefreshPrimaryGatewayFallbackAddresses([]string{
310			"7.7.7.7:7777",
311			"8.8.8.8:8888",
312		})
313
314		g.SetLastFederationStateReplicationError(nil)
315		g.SetLastFederationStateReplicationError(errors.New("fake"))
316		g.SetLastFederationStateReplicationError(errors.New("fake"))
317		g.SetLastFederationStateReplicationError(errors.New("fake"))
318		g.SetLastFederationStateReplicationError(nil)
319
320		idx, err := g.runOnce(0)
321		require.NoError(t, err)
322		assert.True(t, g.DialPrimaryThroughLocalGateway())
323		assert.Equal(t, uint64(2), idx)
324		assert.Len(t, tsd.Calls, 1)
325		assert.Equal(t, []string{
326			"5.6.7.8:5555",
327			"8.7.6.5:9999",
328		}, g.listGateways(false))
329		assert.Equal(t, []string{
330			"5.6.7.8:5555",
331			"8.7.6.5:9999",
332		}, g.listGateways(true))
333	})
334}
335
// testServerDelegate is a hand-rolled fake that TestGatewayLocator passes to
// NewGatewayLocator in place of a real server. It serves queries from an
// in-memory state store and returns canned leadership information.
type testServerDelegate struct {
	// State is the store handed to every query function run by blockingQuery.
	State *state.Store

	// Calls records the minimum query index of every blockingQuery
	// invocation, in call order, so tests can count how many queries ran.
	Calls []uint64

	// isLeader and lastContact are returned unchanged by IsLeader and
	// LeaderLastContact respectively.
	isLeader    bool
	lastContact time.Time
}
344
345// This is just enough to exercise the logic.
346func (d *testServerDelegate) blockingQuery(
347	queryOpts structs.QueryOptionsCompat,
348	queryMeta structs.QueryMetaCompat,
349	fn queryFn,
350) error {
351	minQueryIndex := queryOpts.GetMinQueryIndex()
352
353	d.Calls = append(d.Calls, minQueryIndex)
354
355	var ws memdb.WatchSet
356
357	err := fn(ws, d.State)
358	if err == nil && queryMeta.GetIndex() < 1 {
359		queryMeta.SetIndex(1)
360	}
361
362	return err
363}
364
365func newFakeStateStore() (*state.Store, error) {
366	return state.NewStateStore(nil)
367}
368
369func (d *testServerDelegate) IsLeader() bool {
370	return d.isLeader
371}
372
373func (d *testServerDelegate) LeaderLastContact() time.Time {
374	return d.lastContact
375}
376