1package consul
2
3import (
4	"errors"
5	"sync/atomic"
6	"testing"
7	"time"
8
9	memdb "github.com/hashicorp/go-memdb"
10	"github.com/stretchr/testify/assert"
11	"github.com/stretchr/testify/require"
12
13	"github.com/hashicorp/consul/agent/consul/state"
14	"github.com/hashicorp/consul/agent/structs"
15	"github.com/hashicorp/consul/api"
16	"github.com/hashicorp/consul/sdk/testutil"
17)
18
19func TestGatewayLocator(t *testing.T) {
20	state := state.NewStateStore(nil)
21
22	serverRoles := []string{"leader", "follower"}
23	now := time.Now().UTC()
24
25	dc1 := &structs.FederationState{
26		Datacenter: "dc1",
27		MeshGateways: []structs.CheckServiceNode{
28			newTestMeshGatewayNode(
29				"dc1", "gateway1", "1.2.3.4", 5555, map[string]string{structs.MetaWANFederationKey: "1"}, api.HealthPassing,
30			),
31			newTestMeshGatewayNode(
32				"dc1", "gateway2", "4.3.2.1", 9999, map[string]string{structs.MetaWANFederationKey: "1"}, api.HealthPassing,
33			),
34		},
35		UpdatedAt: time.Now().UTC(),
36	}
37	dc2 := &structs.FederationState{
38		Datacenter: "dc2",
39		MeshGateways: []structs.CheckServiceNode{
40			newTestMeshGatewayNode(
41				"dc2", "gateway1", "5.6.7.8", 5555, map[string]string{structs.MetaWANFederationKey: "1"}, api.HealthPassing,
42			),
43			newTestMeshGatewayNode(
44				"dc2", "gateway2", "8.7.6.5", 9999, map[string]string{structs.MetaWANFederationKey: "1"}, api.HealthPassing,
45			),
46		},
47		UpdatedAt: time.Now().UTC(),
48	}
49
50	t.Run("primary - no data", func(t *testing.T) {
51		for _, role := range serverRoles {
52			t.Run(role, func(t *testing.T) {
53				isLeader := role == "leader"
54
55				logger := testutil.Logger(t)
56				tsd := &testServerDelegate{State: state, isLeader: isLeader}
57				if !isLeader {
58					tsd.lastContact = now
59				}
60				g := NewGatewayLocator(
61					logger,
62					tsd,
63					"dc1",
64					"dc1",
65				)
66				g.SetUseReplicationSignal(isLeader)
67
68				t.Run("before first run", func(t *testing.T) {
69					assert.False(t, g.DialPrimaryThroughLocalGateway()) // not important
70					assert.Len(t, tsd.Calls, 0)
71					assert.Equal(t, []string(nil), g.listGateways(false))
72					assert.Equal(t, []string(nil), g.listGateways(true))
73					assert.False(t, tsd.datacenterSupportsFederationStates())
74				})
75
76				idx, err := g.runOnce(0)
77				require.NoError(t, err)
78				assert.Equal(t, uint64(1), idx)
79
80				t.Run("after first run", func(t *testing.T) {
81					assert.False(t, g.DialPrimaryThroughLocalGateway()) // not important
82					assert.Len(t, tsd.Calls, 1)
83					assert.Equal(t, []string(nil), g.listGateways(false))
84					assert.Equal(t, []string(nil), g.listGateways(true))
85					assert.False(t, tsd.datacenterSupportsFederationStates()) // no results, so we don't flip the bit yet
86				})
87			})
88		}
89	})
90
91	t.Run("secondary - no data", func(t *testing.T) {
92		for _, role := range serverRoles {
93			t.Run(role, func(t *testing.T) {
94				isLeader := role == "leader"
95
96				logger := testutil.Logger(t)
97				tsd := &testServerDelegate{State: state, isLeader: isLeader}
98				if !isLeader {
99					tsd.lastContact = now
100				}
101				g := NewGatewayLocator(
102					logger,
103					tsd,
104					"dc2",
105					"dc1",
106				)
107				g.SetUseReplicationSignal(isLeader)
108
109				t.Run("before first run", func(t *testing.T) {
110					assert.True(t, g.DialPrimaryThroughLocalGateway()) // defaults to sure!
111					assert.Len(t, tsd.Calls, 0)
112					assert.Equal(t, []string(nil), g.listGateways(false))
113					assert.Equal(t, []string(nil), g.listGateways(true))
114					assert.False(t, tsd.datacenterSupportsFederationStates())
115				})
116
117				idx, err := g.runOnce(0)
118				require.NoError(t, err)
119				assert.Equal(t, uint64(1), idx)
120
121				t.Run("after first run", func(t *testing.T) {
122					assert.True(t, g.DialPrimaryThroughLocalGateway()) // defaults to sure!
123					assert.Len(t, tsd.Calls, 1)
124					assert.Equal(t, []string(nil), g.listGateways(false))
125					assert.Equal(t, []string(nil), g.listGateways(true))
126					assert.False(t, tsd.datacenterSupportsFederationStates()) // no results, so we don't flip the bit yet
127				})
128			})
129		}
130	})
131
132	t.Run("secondary - just fallback", func(t *testing.T) {
133		for _, role := range serverRoles {
134			t.Run(role, func(t *testing.T) {
135				isLeader := role == "leader"
136
137				logger := testutil.Logger(t)
138				tsd := &testServerDelegate{State: state, isLeader: isLeader}
139				if !isLeader {
140					tsd.lastContact = now
141				}
142				g := NewGatewayLocator(
143					logger,
144					tsd,
145					"dc2",
146					"dc1",
147				)
148				g.SetUseReplicationSignal(isLeader)
149				g.RefreshPrimaryGatewayFallbackAddresses([]string{
150					"7.7.7.7:7777",
151					"8.8.8.8:8888",
152				})
153
154				t.Run("before first run", func(t *testing.T) {
155					assert.True(t, g.DialPrimaryThroughLocalGateway()) // defaults to sure!
156					assert.Len(t, tsd.Calls, 0)
157					assert.Equal(t, []string(nil), g.listGateways(false))
158					assert.Equal(t, []string(nil), g.listGateways(true)) // don't return any data until we initialize
159					assert.False(t, tsd.datacenterSupportsFederationStates())
160				})
161
162				idx, err := g.runOnce(0)
163				require.NoError(t, err)
164				assert.Equal(t, uint64(1), idx)
165
166				t.Run("after first run", func(t *testing.T) {
167					assert.True(t, g.DialPrimaryThroughLocalGateway()) // defaults to sure!
168					assert.Len(t, tsd.Calls, 1)
169					assert.Equal(t, []string(nil), g.listGateways(false))
170					assert.Equal(t, []string{
171						"7.7.7.7:7777",
172						"8.8.8.8:8888",
173					}, g.listGateways(true))
174					assert.False(t, tsd.datacenterSupportsFederationStates()) // no results, so we don't flip the bit yet
175				})
176			})
177		}
178	})
179
180	// Insert data for the dcs
181	require.NoError(t, state.FederationStateSet(1, dc1))
182	require.NoError(t, state.FederationStateSet(2, dc2))
183
184	t.Run("primary - with data", func(t *testing.T) {
185		for _, role := range serverRoles {
186			t.Run(role, func(t *testing.T) {
187				isLeader := role == "leader"
188
189				logger := testutil.Logger(t)
190				tsd := &testServerDelegate{State: state, isLeader: isLeader}
191				if !isLeader {
192					tsd.lastContact = now
193				}
194				g := NewGatewayLocator(
195					logger,
196					tsd,
197					"dc1",
198					"dc1",
199				)
200				g.SetUseReplicationSignal(isLeader)
201
202				t.Run("before first run", func(t *testing.T) {
203					assert.False(t, g.DialPrimaryThroughLocalGateway()) // not important
204					assert.Len(t, tsd.Calls, 0)
205					assert.Equal(t, []string(nil), g.listGateways(false))
206					assert.Equal(t, []string(nil), g.listGateways(true)) // don't return any data until we initialize
207					assert.False(t, tsd.datacenterSupportsFederationStates())
208				})
209
210				idx, err := g.runOnce(0)
211				require.NoError(t, err)
212				assert.Equal(t, uint64(2), idx)
213
214				t.Run("after first run", func(t *testing.T) {
215					assert.False(t, g.DialPrimaryThroughLocalGateway()) // not important
216					assert.Len(t, tsd.Calls, 1)
217					assert.Equal(t, []string{
218						"1.2.3.4:5555",
219						"4.3.2.1:9999",
220					}, g.listGateways(false))
221					assert.Equal(t, []string{
222						"1.2.3.4:5555",
223						"4.3.2.1:9999",
224					}, g.listGateways(true))
225					assert.True(t, tsd.datacenterSupportsFederationStates()) // have results, so we flip the bit
226				})
227			})
228		}
229	})
230
231	t.Run("secondary - with data", func(t *testing.T) {
232		for _, role := range serverRoles {
233			t.Run(role, func(t *testing.T) {
234				isLeader := role == "leader"
235
236				logger := testutil.Logger(t)
237				tsd := &testServerDelegate{State: state, isLeader: isLeader}
238				if !isLeader {
239					tsd.lastContact = now
240				}
241				g := NewGatewayLocator(
242					logger,
243					tsd,
244					"dc2",
245					"dc1",
246				)
247				g.SetUseReplicationSignal(isLeader)
248
249				t.Run("before first run", func(t *testing.T) {
250					assert.True(t, g.DialPrimaryThroughLocalGateway()) // defaults to sure!
251					assert.Len(t, tsd.Calls, 0)
252					assert.Equal(t, []string(nil), g.listGateways(false))
253					assert.Equal(t, []string(nil), g.listGateways(true)) // don't return any data until we initialize
254					assert.False(t, tsd.datacenterSupportsFederationStates())
255				})
256
257				idx, err := g.runOnce(0)
258				require.NoError(t, err)
259				assert.Equal(t, uint64(2), idx)
260
261				t.Run("after first run", func(t *testing.T) {
262					assert.True(t, g.DialPrimaryThroughLocalGateway()) // defaults to sure!
263					assert.Len(t, tsd.Calls, 1)
264					assert.Equal(t, []string{
265						"5.6.7.8:5555",
266						"8.7.6.5:9999",
267					}, g.listGateways(false))
268					assert.Equal(t, []string{
269						"5.6.7.8:5555",
270						"8.7.6.5:9999",
271					}, g.listGateways(true))
272					assert.True(t, tsd.datacenterSupportsFederationStates()) // have results, so we flip the bit
273				})
274
275			})
276		}
277	})
278
279	t.Run("secondary - with data and fallback - repl ok", func(t *testing.T) {
280		// Only run for the leader.
281		logger := testutil.Logger(t)
282		tsd := &testServerDelegate{State: state, isLeader: true}
283		g := NewGatewayLocator(
284			logger,
285			tsd,
286			"dc2",
287			"dc1",
288		)
289		g.SetUseReplicationSignal(true)
290
291		g.RefreshPrimaryGatewayFallbackAddresses([]string{
292			"7.7.7.7:7777",
293			"8.8.8.8:8888",
294		})
295
296		g.SetLastFederationStateReplicationError(nil, true)
297
298		t.Run("before first run", func(t *testing.T) {
299			assert.True(t, g.DialPrimaryThroughLocalGateway()) // defaults to sure!
300			assert.Len(t, tsd.Calls, 0)
301			assert.Equal(t, []string(nil), g.listGateways(false))
302			assert.Equal(t, []string(nil), g.listGateways(true)) // don't return any data until we initialize
303			assert.False(t, tsd.datacenterSupportsFederationStates())
304		})
305
306		idx, err := g.runOnce(0)
307		require.NoError(t, err)
308		assert.Equal(t, uint64(2), idx)
309
310		t.Run("after first run", func(t *testing.T) {
311			assert.True(t, g.DialPrimaryThroughLocalGateway())
312			assert.Len(t, tsd.Calls, 1)
313			assert.Equal(t, []string{
314				"5.6.7.8:5555",
315				"8.7.6.5:9999",
316			}, g.listGateways(false))
317			assert.Equal(t, []string{
318				"5.6.7.8:5555",
319				"8.7.6.5:9999",
320			}, g.listGateways(true))
321			assert.True(t, tsd.datacenterSupportsFederationStates()) // have results, so we flip the bit
322		})
323	})
324
325	t.Run("secondary - with data and fallback - repl ok then failed 2 times", func(t *testing.T) {
326		// Only run for the leader.
327		logger := testutil.Logger(t)
328		tsd := &testServerDelegate{State: state, isLeader: true}
329		g := NewGatewayLocator(
330			logger,
331			tsd,
332			"dc2",
333			"dc1",
334		)
335		g.SetUseReplicationSignal(true)
336
337		g.RefreshPrimaryGatewayFallbackAddresses([]string{
338			"7.7.7.7:7777",
339			"8.8.8.8:8888",
340		})
341
342		g.SetLastFederationStateReplicationError(nil, true)
343		g.SetLastFederationStateReplicationError(errors.New("fake"), true)
344		g.SetLastFederationStateReplicationError(errors.New("fake"), true)
345
346		t.Run("before first run", func(t *testing.T) {
347			assert.True(t, g.DialPrimaryThroughLocalGateway()) // defaults to sure!
348			assert.Len(t, tsd.Calls, 0)
349			assert.Equal(t, []string(nil), g.listGateways(false))
350			assert.Equal(t, []string(nil), g.listGateways(true)) // don't return any data until we initialize
351			assert.False(t, tsd.datacenterSupportsFederationStates())
352		})
353
354		idx, err := g.runOnce(0)
355		require.NoError(t, err)
356		assert.Equal(t, uint64(2), idx)
357
358		t.Run("after first run", func(t *testing.T) {
359			assert.True(t, g.DialPrimaryThroughLocalGateway())
360			assert.Len(t, tsd.Calls, 1)
361			assert.Equal(t, []string{
362				"5.6.7.8:5555",
363				"8.7.6.5:9999",
364			}, g.listGateways(false))
365			assert.Equal(t, []string{
366				"5.6.7.8:5555",
367				"8.7.6.5:9999",
368			}, g.listGateways(true))
369			assert.True(t, tsd.datacenterSupportsFederationStates()) // have results, so we flip the bit
370		})
371	})
372
373	t.Run("secondary - with data and fallback - repl ok then failed 3 times", func(t *testing.T) {
374		// Only run for the leader.
375		logger := testutil.Logger(t)
376		tsd := &testServerDelegate{State: state, isLeader: true}
377		g := NewGatewayLocator(
378			logger,
379			tsd,
380			"dc2",
381			"dc1",
382		)
383		g.SetUseReplicationSignal(true)
384
385		g.RefreshPrimaryGatewayFallbackAddresses([]string{
386			"7.7.7.7:7777",
387			"8.8.8.8:8888",
388		})
389
390		g.SetLastFederationStateReplicationError(nil, true)
391		g.SetLastFederationStateReplicationError(errors.New("fake"), true)
392		g.SetLastFederationStateReplicationError(errors.New("fake"), true)
393		g.SetLastFederationStateReplicationError(errors.New("fake"), true)
394
395		t.Run("before first run", func(t *testing.T) {
396			assert.False(t, g.DialPrimaryThroughLocalGateway()) // too many errors
397			assert.Len(t, tsd.Calls, 0)
398			assert.Equal(t, []string(nil), g.listGateways(false))
399			assert.Equal(t, []string(nil), g.listGateways(true)) // don't return any data until we initialize
400			assert.False(t, tsd.datacenterSupportsFederationStates())
401		})
402
403		idx, err := g.runOnce(0)
404		require.NoError(t, err)
405		assert.Equal(t, uint64(2), idx)
406
407		t.Run("after first run", func(t *testing.T) {
408			assert.False(t, g.DialPrimaryThroughLocalGateway())
409			assert.Len(t, tsd.Calls, 1)
410			assert.Equal(t, []string{
411				"5.6.7.8:5555",
412				"8.7.6.5:9999",
413			}, g.listGateways(false))
414			assert.Equal(t, []string{
415				"1.2.3.4:5555",
416				"4.3.2.1:9999",
417				"7.7.7.7:7777",
418				"8.8.8.8:8888",
419			}, g.listGateways(true))
420			assert.True(t, tsd.datacenterSupportsFederationStates()) // have results, so we flip the bit
421		})
422	})
423
424	t.Run("secondary - with data and fallback - repl ok then failed 3 times then ok again", func(t *testing.T) {
425		// Only run for the leader.
426		logger := testutil.Logger(t)
427		tsd := &testServerDelegate{State: state, isLeader: true}
428		g := NewGatewayLocator(
429			logger,
430			tsd,
431			"dc2",
432			"dc1",
433		)
434		g.SetUseReplicationSignal(true)
435
436		g.RefreshPrimaryGatewayFallbackAddresses([]string{
437			"7.7.7.7:7777",
438			"8.8.8.8:8888",
439		})
440
441		g.SetLastFederationStateReplicationError(nil, true)
442		g.SetLastFederationStateReplicationError(errors.New("fake"), true)
443		g.SetLastFederationStateReplicationError(errors.New("fake"), true)
444		g.SetLastFederationStateReplicationError(errors.New("fake"), true)
445		g.SetLastFederationStateReplicationError(nil, true)
446
447		t.Run("before first run", func(t *testing.T) {
448			assert.True(t, g.DialPrimaryThroughLocalGateway()) // all better again
449			assert.Len(t, tsd.Calls, 0)
450			assert.Equal(t, []string(nil), g.listGateways(false))
451			assert.Equal(t, []string(nil), g.listGateways(true)) // don't return any data until we initialize
452			assert.False(t, tsd.datacenterSupportsFederationStates())
453		})
454
455		idx, err := g.runOnce(0)
456		require.NoError(t, err)
457		assert.Equal(t, uint64(2), idx)
458
459		t.Run("after first run", func(t *testing.T) {
460			assert.True(t, g.DialPrimaryThroughLocalGateway()) // all better again
461			assert.Len(t, tsd.Calls, 1)
462			assert.Equal(t, []string{
463				"5.6.7.8:5555",
464				"8.7.6.5:9999",
465			}, g.listGateways(false))
466			assert.Equal(t, []string{
467				"5.6.7.8:5555",
468				"8.7.6.5:9999",
469			}, g.listGateways(true))
470			assert.True(t, tsd.datacenterSupportsFederationStates()) // have results, so we flip the bit
471		})
472	})
473}
474
// testServerDelegate is the minimal fake server handed to NewGatewayLocator in
// these tests. It serves blocking queries straight from State, records the
// MinQueryIndex of every query it serves, and reports a fixed leadership
// status.
type testServerDelegate struct {
	dcSupportsFederationStates int32 // atomically accessed, at start to prevent alignment issues

	// State is the store passed to every query function run by blockingQuery.
	State *state.Store

	// Calls holds the MinQueryIndex of each blockingQuery invocation, in order.
	Calls []uint64

	// isLeader and lastContact are returned as-is by IsLeader and
	// LeaderLastContact respectively.
	isLeader    bool
	lastContact time.Time
}
485
486func (d *testServerDelegate) setDatacenterSupportsFederationStates() {
487	atomic.StoreInt32(&d.dcSupportsFederationStates, 1)
488}
489
490func (d *testServerDelegate) datacenterSupportsFederationStates() bool {
491	return atomic.LoadInt32(&d.dcSupportsFederationStates) != 0
492}
493
494// This is just enough to exercise the logic.
495func (d *testServerDelegate) blockingQuery(
496	queryOpts structs.QueryOptionsCompat,
497	queryMeta structs.QueryMetaCompat,
498	fn queryFn,
499) error {
500	minQueryIndex := queryOpts.GetMinQueryIndex()
501
502	d.Calls = append(d.Calls, minQueryIndex)
503
504	var ws memdb.WatchSet
505
506	err := fn(ws, d.State)
507	if err == nil && queryMeta.GetIndex() < 1 {
508		queryMeta.SetIndex(1)
509	}
510
511	return err
512}
513
514func (d *testServerDelegate) IsLeader() bool {
515	return d.isLeader
516}
517
518func (d *testServerDelegate) LeaderLastContact() time.Time {
519	return d.lastContact
520}
521