1package disco
2
3import (
4	"crypto/tls"
5	"net/http"
6	"net/http/httptest"
7	"net/url"
8	"os"
9	"strconv"
10	"testing"
11
12	"github.com/hashicorp/terraform-svchost"
13	"github.com/hashicorp/terraform-svchost/auth"
14)
15
16func TestMain(m *testing.M) {
17	// During all tests we override the HTTP transport we use for discovery
18	// so it'll tolerate the locally-generated TLS certificates we use
19	// for test URLs.
20	httpTransport = &http.Transport{
21		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
22	}
23
24	os.Exit(m.Run())
25}
26
27func TestDiscover(t *testing.T) {
28	t.Run("happy path", func(t *testing.T) {
29		portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
30			resp := []byte(`
31{
32"thingy.v1": "http://example.com/foo",
33"wotsit.v2": "http://example.net/bar"
34}
35`)
36			w.Header().Add("Content-Type", "application/json")
37			w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
38			w.Write(resp)
39		})
40		defer close()
41
42		givenHost := "localhost" + portStr
43		host, err := svchost.ForComparison(givenHost)
44		if err != nil {
45			t.Fatalf("test server hostname is invalid: %s", err)
46		}
47
48		d := New()
49		discovered, err := d.Discover(host)
50		if err != nil {
51			t.Fatalf("unexpected discovery error: %s", err)
52		}
53
54		gotURL, err := discovered.ServiceURL("thingy.v1")
55		if err != nil {
56			t.Fatalf("unexpected service URL error: %s", err)
57		}
58		if gotURL == nil {
59			t.Fatalf("found no URL for thingy.v1")
60		}
61		if got, want := gotURL.String(), "http://example.com/foo"; got != want {
62			t.Fatalf("wrong result %q; want %q", got, want)
63		}
64	})
65	t.Run("chunked encoding", func(t *testing.T) {
66		portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
67			resp := []byte(`
68{
69"thingy.v1": "http://example.com/foo",
70"wotsit.v2": "http://example.net/bar"
71}
72`)
73			w.Header().Add("Content-Type", "application/json")
74			// We're going to force chunked encoding here -- and thus prevent
75			// the server from predicting the length -- so we can make sure
76			// our client is tolerant of servers using this encoding.
77			w.Write(resp[:5])
78			w.(http.Flusher).Flush()
79			w.Write(resp[5:])
80			w.(http.Flusher).Flush()
81		})
82		defer close()
83
84		givenHost := "localhost" + portStr
85		host, err := svchost.ForComparison(givenHost)
86		if err != nil {
87			t.Fatalf("test server hostname is invalid: %s", err)
88		}
89
90		d := New()
91		discovered, err := d.Discover(host)
92		if err != nil {
93			t.Fatalf("unexpected discovery error: %s", err)
94		}
95
96		gotURL, err := discovered.ServiceURL("wotsit.v2")
97		if err != nil {
98			t.Fatalf("unexpected service URL error: %s", err)
99		}
100		if gotURL == nil {
101			t.Fatalf("found no URL for wotsit.v2")
102		}
103		if got, want := gotURL.String(), "http://example.net/bar"; got != want {
104			t.Fatalf("wrong result %q; want %q", got, want)
105		}
106	})
107	t.Run("with credentials", func(t *testing.T) {
108		var authHeaderText string
109		portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
110			resp := []byte(`{}`)
111			authHeaderText = r.Header.Get("Authorization")
112			w.Header().Add("Content-Type", "application/json")
113			w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
114			w.Write(resp)
115		})
116		defer close()
117
118		givenHost := "localhost" + portStr
119		host, err := svchost.ForComparison(givenHost)
120		if err != nil {
121			t.Fatalf("test server hostname is invalid: %s", err)
122		}
123
124		d := New()
125		d.SetCredentialsSource(auth.StaticCredentialsSource(map[svchost.Hostname]map[string]interface{}{
126			host: map[string]interface{}{
127				"token": "abc123",
128			},
129		}))
130		d.Discover(host)
131		if got, want := authHeaderText, "Bearer abc123"; got != want {
132			t.Fatalf("wrong Authorization header\ngot:  %s\nwant: %s", got, want)
133		}
134	})
135	t.Run("forced services override", func(t *testing.T) {
136		forced := map[string]interface{}{
137			"thingy.v1": "http://example.net/foo",
138			"wotsit.v2": "/foo",
139		}
140
141		d := New()
142		d.ForceHostServices(svchost.Hostname("example.com"), forced)
143
144		givenHost := "example.com"
145		host, err := svchost.ForComparison(givenHost)
146		if err != nil {
147			t.Fatalf("test server hostname is invalid: %s", err)
148		}
149
150		discovered, err := d.Discover(host)
151		if err != nil {
152			t.Fatalf("unexpected discovery error: %s", err)
153		}
154		{
155			gotURL, err := discovered.ServiceURL("thingy.v1")
156			if err != nil {
157				t.Fatalf("unexpected service URL error: %s", err)
158			}
159			if gotURL == nil {
160				t.Fatalf("found no URL for thingy.v1")
161			}
162			if got, want := gotURL.String(), "http://example.net/foo"; got != want {
163				t.Fatalf("wrong result %q; want %q", got, want)
164			}
165		}
166		{
167			gotURL, err := discovered.ServiceURL("wotsit.v2")
168			if err != nil {
169				t.Fatalf("unexpected service URL error: %s", err)
170			}
171			if gotURL == nil {
172				t.Fatalf("found no URL for wotsit.v2")
173			}
174			if got, want := gotURL.String(), "https://example.com/foo"; got != want {
175				t.Fatalf("wrong result %q; want %q", got, want)
176			}
177		}
178	})
179	t.Run("not JSON", func(t *testing.T) {
180		portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
181			resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
182			w.Header().Add("Content-Type", "application/octet-stream")
183			w.Write(resp)
184		})
185		defer close()
186
187		givenHost := "localhost" + portStr
188		host, err := svchost.ForComparison(givenHost)
189		if err != nil {
190			t.Fatalf("test server hostname is invalid: %s", err)
191		}
192
193		d := New()
194		discovered, err := d.Discover(host)
195		if err == nil {
196			t.Fatalf("expected a discovery error")
197		}
198
199		// Returned discovered should be nil.
200		if discovered != nil {
201			t.Errorf("discovered not nil; should be")
202		}
203	})
204	t.Run("malformed JSON", func(t *testing.T) {
205		portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
206			resp := []byte(`{"thingy.v1": "htt`) // truncated, for example...
207			w.Header().Add("Content-Type", "application/json")
208			w.Write(resp)
209		})
210		defer close()
211
212		givenHost := "localhost" + portStr
213		host, err := svchost.ForComparison(givenHost)
214		if err != nil {
215			t.Fatalf("test server hostname is invalid: %s", err)
216		}
217
218		d := New()
219		discovered, err := d.Discover(host)
220		if err == nil {
221			t.Fatalf("expected a discovery error")
222		}
223
224		// Returned discovered should be nil.
225		if discovered != nil {
226			t.Errorf("discovered not nil; should be")
227		}
228	})
229	t.Run("JSON with redundant charset", func(t *testing.T) {
230		// The JSON RFC defines no parameters for the application/json
231		// MIME type, but some servers have a weird tendency to just add
232		// "charset" to everything, so we'll make sure we ignore it successfully.
233		// (JSON uses content sniffing for encoding detection, not media type params.)
234		portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
235			resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
236			w.Header().Add("Content-Type", "application/json; charset=latin-1")
237			w.Write(resp)
238		})
239		defer close()
240
241		givenHost := "localhost" + portStr
242		host, err := svchost.ForComparison(givenHost)
243		if err != nil {
244			t.Fatalf("test server hostname is invalid: %s", err)
245		}
246
247		d := New()
248		discovered, err := d.Discover(host)
249		if err != nil {
250			t.Fatalf("unexpected discovery error: %s", err)
251		}
252
253		if discovered.services == nil {
254			t.Errorf("response is empty; shouldn't be")
255		}
256	})
257	t.Run("no discovery doc", func(t *testing.T) {
258		portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
259			w.WriteHeader(404)
260		})
261		defer close()
262
263		givenHost := "localhost" + portStr
264		host, err := svchost.ForComparison(givenHost)
265		if err != nil {
266			t.Fatalf("test server hostname is invalid: %s", err)
267		}
268
269		d := New()
270		discovered, err := d.Discover(host)
271		if err != nil {
272			t.Fatalf("unexpected discovery error: %s", err)
273		}
274
275		// Returned discovered.services should be nil (empty).
276		if discovered.services != nil {
277			t.Errorf("discovered.services not nil (empty); should be")
278		}
279	})
280	t.Run("redirect", func(t *testing.T) {
281		// For this test, we have two servers and one redirects to the other
282		portStr1, close1 := testServer(func(w http.ResponseWriter, r *http.Request) {
283			// This server is the one that returns a real response.
284			resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
285			w.Header().Add("Content-Type", "application/json")
286			w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
287			w.Write(resp)
288		})
289		portStr2, close2 := testServer(func(w http.ResponseWriter, r *http.Request) {
290			// This server is the one that redirects.
291			http.Redirect(w, r, "https://127.0.0.1"+portStr1+"/.well-known/terraform.json", 302)
292		})
293		defer close1()
294		defer close2()
295
296		givenHost := "localhost" + portStr2
297		host, err := svchost.ForComparison(givenHost)
298		if err != nil {
299			t.Fatalf("test server hostname is invalid: %s", err)
300		}
301
302		d := New()
303		discovered, err := d.Discover(host)
304		if err != nil {
305			t.Fatalf("unexpected discovery error: %s", err)
306		}
307
308		gotURL, err := discovered.ServiceURL("thingy.v1")
309		if err != nil {
310			t.Fatalf("unexpected service URL error: %s", err)
311		}
312		if gotURL == nil {
313			t.Fatalf("found no URL for thingy.v1")
314		}
315		if got, want := gotURL.String(), "http://example.com/foo"; got != want {
316			t.Fatalf("wrong result %q; want %q", got, want)
317		}
318
319		// The base URL for the host object should be the URL we redirected to,
320		// rather than the we redirected _from_.
321		gotBaseURL := discovered.discoURL.String()
322		wantBaseURL := "https://127.0.0.1" + portStr1 + "/.well-known/terraform.json"
323		if gotBaseURL != wantBaseURL {
324			t.Errorf("incorrect base url %s; want %s", gotBaseURL, wantBaseURL)
325		}
326
327	})
328}
329
// testServer starts a TLS test server that passes requests for the
// discovery document path "/.well-known/terraform.json" to h. Every other
// path gets a 404 "not found" response.
//
// It returns the server's port formatted as ":NNNN" so it can be appended
// to a hostname. If the server URL carries no explicit port, it returns
// "". It also returns a function that shuts the server down; callers
// should defer it.
//
// The server uses a locally-generated certificate, so clients must skip
// certificate verification to talk to it.
func testServer(h func(w http.ResponseWriter, r *http.Request)) (portStr string, closeFn func()) {
	server := httptest.NewTLSServer(http.HandlerFunc(
		func(w http.ResponseWriter, r *http.Request) {
			// Test server always returns 404 if the URL isn't what we expect,
			// so a client requesting the wrong path is noticed.
			if r.URL.Path != "/.well-known/terraform.json" {
				w.WriteHeader(http.StatusNotFound)
				w.Write([]byte("not found"))
				return
			}

			// If the URL is correct then the given handler decides the response
			h(w, r)
		},
	))

	serverURL, err := url.Parse(server.URL)
	if err != nil {
		// httptest is expected to always produce a parseable URL, so a
		// failure here means the test environment itself is broken. There is
		// no *testing.T to report through, so release the server and panic.
		server.Close()
		panic("testServer: cannot parse test server URL " + server.URL + ": " + err.Error())
	}

	if port := serverURL.Port(); port != "" {
		portStr = ":" + port
	}

	return portStr, server.Close
}
358