1package structs
2
3import (
4	"encoding/json"
5	"fmt"
6	"testing"
7
8	"github.com/hashicorp/consul/api"
9	"github.com/stretchr/testify/require"
10)
11
12func TestConnectProxyConfig_ToAPI(t *testing.T) {
13	tests := []struct {
14		name string
15		in   ConnectProxyConfig
16		want *api.AgentServiceConnectProxyConfig
17	}{
18		{
19			name: "service",
20			in: ConnectProxyConfig{
21				DestinationServiceName: "web",
22				DestinationServiceID:   "web1",
23				LocalServiceAddress:    "127.0.0.2",
24				LocalServicePort:       5555,
25				Config: map[string]interface{}{
26					"foo": "bar",
27				},
28				MeshGateway: MeshGatewayConfig{
29					Mode: MeshGatewayModeLocal,
30				},
31				Upstreams: Upstreams{
32					{
33						DestinationType: UpstreamDestTypeService,
34						DestinationName: "foo",
35						Datacenter:      "dc1",
36						LocalBindPort:   1234,
37						MeshGateway: MeshGatewayConfig{
38							Mode: MeshGatewayModeLocal,
39						},
40					},
41					{
42						DestinationType:  UpstreamDestTypePreparedQuery,
43						DestinationName:  "foo",
44						Datacenter:       "dc1",
45						LocalBindPort:    2345,
46						LocalBindAddress: "127.10.10.10",
47					},
48				},
49			},
50			want: &api.AgentServiceConnectProxyConfig{
51				DestinationServiceName: "web",
52				DestinationServiceID:   "web1",
53				LocalServiceAddress:    "127.0.0.2",
54				LocalServicePort:       5555,
55				Config: map[string]interface{}{
56					"foo": "bar",
57				},
58				MeshGateway: api.MeshGatewayConfig{
59					Mode: api.MeshGatewayModeLocal,
60				},
61				Upstreams: []api.Upstream{
62					{
63						DestinationType: UpstreamDestTypeService,
64						DestinationName: "foo",
65						Datacenter:      "dc1",
66						LocalBindPort:   1234,
67						MeshGateway: api.MeshGatewayConfig{
68							Mode: api.MeshGatewayModeLocal,
69						},
70					},
71					{
72						DestinationType:  UpstreamDestTypePreparedQuery,
73						DestinationName:  "foo",
74						Datacenter:       "dc1",
75						LocalBindPort:    2345,
76						LocalBindAddress: "127.10.10.10",
77					},
78				},
79			},
80		},
81	}
82	for _, tt := range tests {
83		t.Run(tt.name, func(t *testing.T) {
84			require.Equal(t, tt.want, tt.in.ToAPI())
85		})
86	}
87}
88
89func TestUpstream_MarshalJSON(t *testing.T) {
90	tests := []struct {
91		name    string
92		in      Upstream
93		want    string
94		wantErr bool
95	}{
96		{
97			name: "service",
98			in: Upstream{
99				DestinationType: UpstreamDestTypeService,
100				DestinationName: "foo",
101				Datacenter:      "dc1",
102				LocalBindPort:   1234,
103			},
104			want: `{
105				"DestinationType": "service",
106				"DestinationName": "foo",
107				"Datacenter": "dc1",
108				"LocalBindPort": 1234,
109				"MeshGateway": {},
110				"Config": null
111			}`,
112			wantErr: false,
113		},
114		{
115			name: "pq",
116			in: Upstream{
117				DestinationType: UpstreamDestTypePreparedQuery,
118				DestinationName: "foo",
119				Datacenter:      "dc1",
120				LocalBindPort:   1234,
121			},
122			want: `{
123				"DestinationType": "prepared_query",
124				"DestinationName": "foo",
125				"Datacenter": "dc1",
126				"LocalBindPort": 1234,
127				"MeshGateway": {},
128				"Config": null
129			}`,
130			wantErr: false,
131		},
132	}
133	for _, tt := range tests {
134		t.Run(tt.name, func(t *testing.T) {
135			require := require.New(t)
136			got, err := json.Marshal(tt.in)
137			if tt.wantErr {
138				require.Error(err)
139				return
140			}
141			require.NoError(err)
142			require.JSONEq(tt.want, string(got))
143		})
144	}
145}
146
147func TestUpstream_UnmarshalJSON(t *testing.T) {
148	tests := []struct {
149		name    string
150		json    string
151		want    Upstream
152		wantErr bool
153	}{
154		{
155			name: "service",
156			json: `{
157				"DestinationType": "service",
158				"DestinationName": "foo",
159				"Datacenter": "dc1"
160			}`,
161			want: Upstream{
162				DestinationType: UpstreamDestTypeService,
163				DestinationName: "foo",
164				Datacenter:      "dc1",
165			},
166			wantErr: false,
167		},
168		{
169			name: "pq",
170			json: `{
171				"DestinationType": "prepared_query",
172				"DestinationName": "foo",
173				"Datacenter": "dc1"
174			}`,
175			want: Upstream{
176				DestinationType: UpstreamDestTypePreparedQuery,
177				DestinationName: "foo",
178				Datacenter:      "dc1",
179			},
180			wantErr: false,
181		},
182	}
183	for _, tt := range tests {
184		t.Run(tt.name, func(t *testing.T) {
185			require := require.New(t)
186			var got Upstream
187			err := json.Unmarshal([]byte(tt.json), &got)
188			if tt.wantErr {
189				require.Error(err)
190				return
191			}
192			require.NoError(err)
193			require.Equal(tt.want, got)
194		})
195	}
196}
197
198func TestMeshGatewayConfig_OverlayWith(t *testing.T) {
199	var (
200		D = MeshGatewayConfig{Mode: MeshGatewayModeDefault}
201		N = MeshGatewayConfig{Mode: MeshGatewayModeNone}
202		R = MeshGatewayConfig{Mode: MeshGatewayModeRemote}
203		L = MeshGatewayConfig{Mode: MeshGatewayModeLocal}
204	)
205
206	type testCase struct {
207		base, overlay, expect MeshGatewayConfig
208	}
209	cases := []testCase{
210		{D, D, D},
211		{D, N, N},
212		{D, R, R},
213		{D, L, L},
214		{N, D, N},
215		{N, N, N},
216		{N, R, R},
217		{N, L, L},
218		{R, D, R},
219		{R, N, N},
220		{R, R, R},
221		{R, L, L},
222		{L, D, L},
223		{L, N, N},
224		{L, R, R},
225		{L, L, L},
226	}
227
228	for _, tc := range cases {
229		tc := tc
230
231		t.Run(fmt.Sprintf("%s overlaid with %s", tc.base.Mode, tc.overlay.Mode),
232			func(t *testing.T) {
233				got := tc.base.OverlayWith(tc.overlay)
234				require.Equal(t, tc.expect, got)
235			})
236	}
237}
238
239func TestValidateMeshGatewayMode(t *testing.T) {
240	for _, tc := range []struct {
241		modeConstant string
242		modeExplicit string
243		expect       MeshGatewayMode
244		ok           bool
245	}{
246		{string(MeshGatewayModeNone), "none", MeshGatewayModeNone, true},
247		{string(MeshGatewayModeDefault), "", MeshGatewayModeDefault, true},
248		{string(MeshGatewayModeLocal), "local", MeshGatewayModeLocal, true},
249		{string(MeshGatewayModeRemote), "remote", MeshGatewayModeRemote, true},
250	} {
251		tc := tc
252
253		t.Run(tc.modeConstant+" (constant)", func(t *testing.T) {
254			got, err := ValidateMeshGatewayMode(tc.modeConstant)
255			if tc.ok {
256				require.NoError(t, err)
257				require.Equal(t, tc.expect, got)
258			} else {
259				require.Error(t, err)
260			}
261		})
262		t.Run(tc.modeExplicit+" (explicit)", func(t *testing.T) {
263			got, err := ValidateMeshGatewayMode(tc.modeExplicit)
264			if tc.ok {
265				require.NoError(t, err)
266				require.Equal(t, tc.expect, got)
267			} else {
268				require.Error(t, err)
269			}
270		})
271	}
272}
273