1package httpmock
2
3import (
4	"bytes"
5	"context"
6	"encoding/json"
7	"errors"
8	"io/ioutil"
9	"net"
10	"net/http"
11	"net/url"
12	"reflect"
13	"regexp"
14	"strings"
15	"testing"
16	"time"
17)
18
19var testURL = "http://www.example.com/"
20
21func assertBody(t *testing.T, resp *http.Response, expected string) {
22	defer resp.Body.Close()
23
24	data, err := ioutil.ReadAll(resp.Body)
25	if err != nil {
26		t.Fatal(err)
27	}
28
29	got := string(data)
30
31	if got != expected {
32		t.Errorf("Expected body: %#v, got %#v", expected, got)
33	}
34}
35
// TestRouteKey checks the textual representation of route keys, including
// the special noResponder key.
func TestRouteKey(t *testing.T) {
	for _, tc := range []struct {
		got, expected string
	}{
		{noResponder.String(), "NO_RESPONDER"},
		{routeKey{Method: "GET", URL: "/foo"}.String(), "GET /foo"},
	} {
		if tc.got != tc.expected {
			t.Errorf("got: %v, expected: %v", tc.got, tc.expected)
		}
	}
}
47
// TestMockTransport checks that a registered responder serves its body both
// when the body is fully read (ioutil.ReadAll reaches io.EOF) and when it is
// only partially consumed by a JSON decoder before being closed, and that a
// URL without responder yields an error mentioning NoResponderFound.
func TestMockTransport(t *testing.T) {
	Activate()
	defer Deactivate()

	mockURL := "https://github.com/"
	body := `["hello world"]` + "\n"

	RegisterResponder("GET", mockURL, NewStringResponder(200, body))

	// Read it as a simple string (ioutil.ReadAll will trigger io.EOF)
	func() {
		resp, err := http.Get(mockURL)
		if err != nil {
			t.Fatal(err)
		}
		defer resp.Body.Close()

		data, err := ioutil.ReadAll(resp.Body)
		if err != nil {
			t.Fatal(err)
		}

		if string(data) != body {
			t.Fatalf("expected body %q, got %q", body, data)
		}

		// The http client wraps our NoResponderFound error, so we just try
		// and match on text. A nil error must be reported as a failure
		// rather than dereferenced.
		noRespResp, err := http.Get(testURL)
		if err == nil {
			noRespResp.Body.Close()
			t.Fatalf("expected an error for %s which has no responder", testURL)
		}
		if !strings.Contains(err.Error(), NoResponderFound.Error()) {
			t.Fatal(err)
		}
	}()

	// Do it again, but twice with json decoder (json Decode will not
	// reach EOF, but Close is called as the JSON response is complete)
	for i := 0; i < 2; i++ {
		func() {
			resp, err := http.Get(mockURL)
			if err != nil {
				t.Fatal(err)
			}
			defer resp.Body.Close()

			var res []string
			err = json.NewDecoder(resp.Body).Decode(&res)
			if err != nil {
				t.Fatal(err)
			}

			if len(res) != 1 || res[0] != "hello world" {
				t.Fatalf(`%v read instead of ["hello world"]`, res)
			}
		}()
	}
}
102
// TestMockTransportDefaultMethod checks that GET responders are found when
// using an http.Request with a default (zero-value) .Method.
func TestMockTransportDefaultMethod(t *testing.T) {
	Activate()
	defer Deactivate()

	const urlString = "https://github.com/"
	// Named u so that the net/url package is not shadowed.
	u, err := url.Parse(urlString)
	if err != nil {
		t.Fatal(err)
	}
	body := "hello world"

	RegisterResponder("GET", urlString, NewStringResponder(200, body))

	req := &http.Request{
		URL: u,
		// Note: Method unspecified (zero-value)
	}

	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	data, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}

	if string(data) != body {
		t.Fatalf("expected body %q, got %q", body, data)
	}
}
139
// TestMockTransportReset checks that Reset drops every registered responder.
func TestMockTransportReset(t *testing.T) {
	DeactivateAndReset()

	// count reports how many responders DefaultTransport currently holds.
	count := func() int { return len(DefaultTransport.responders) }

	if count() > 0 {
		t.Fatal("expected no responders at this point")
	}

	RegisterResponder("GET", testURL, nil)
	if count() != 1 {
		t.Fatal("expected one responder")
	}

	Reset()
	if count() > 0 {
		t.Fatal("expected no responders as they were just reset")
	}
}
159
// TestMockTransportNoResponder checks that requests fail while no responder
// matches, and succeed once a fallback responder is set with
// RegisterNoResponder.
func TestMockTransportNoResponder(t *testing.T) {
	Activate()
	defer DeactivateAndReset()

	Reset()

	if DefaultTransport.noResponder != nil {
		t.Fatal("expected noResponder to be nil")
	}

	if resp, err := http.Get(testURL); err == nil {
		resp.Body.Close()
		t.Fatal("expected to receive a connection error due to lack of responders")
	}

	RegisterNoResponder(NewStringResponder(200, "hello world"))

	resp, err := http.Get(testURL)
	if err != nil {
		t.Fatalf("expected request to succeed: %v", err)
	}
	defer resp.Body.Close()

	data, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}

	if string(data) != "hello world" {
		t.Fatalf("expected body to be 'hello world', got %q", data)
	}
}
190
// TestMockTransportQuerystringFallback checks that a responder registered
// without query string also serves requests carrying a query string and/or
// a fragment.
func TestMockTransportQuerystringFallback(t *testing.T) {
	Activate()
	defer DeactivateAndReset()

	// register the testURL responder
	RegisterResponder("GET", testURL, NewStringResponder(200, "hello world"))

	for _, suffix := range []string{"?", "?hello=world", "?hello=world#foo", "?hello=world&hello=all", "#foo"} {
		reqURL := testURL + suffix

		// make a request for the testURL with a querystring
		resp, err := http.Get(reqURL)
		if err != nil {
			t.Fatalf("expected request %s to succeed: %v", reqURL, err)
		}

		// Close right away (not deferred) so bodies do not pile up in the loop.
		data, err := ioutil.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			t.Fatalf("%s error: %s", reqURL, err)
		}

		if string(data) != "hello world" {
			t.Fatalf("expected body of %s to be 'hello world'", reqURL)
		}
	}
}
217
// TestMockTransportPathOnlyFallback checks which request URLs are served by
// responders registered with a path only (optionally with a query string
// and/or fragment), or with a "=~" regexp.
func TestMockTransportPathOnlyFallback(t *testing.T) {
	// Just in case a panic occurs
	defer DeactivateAndReset()

	for _, test := range []struct {
		Responder string
		Paths     []string
	}{
		{
			// unsorted query string matches exactly
			Responder: "/hello/world?query=string&abc=zz#fragment",
			Paths: []string{
				testURL + "hello/world?query=string&abc=zz#fragment",
			},
		},
		{
			// sorted query string matches all cases
			Responder: "/hello/world?abc=zz&query=string#fragment",
			Paths: []string{
				testURL + "hello/world?query=string&abc=zz#fragment",
				testURL + "hello/world?abc=zz&query=string#fragment",
			},
		},
		{
			// unsorted query string matches exactly
			Responder: "/hello/world?query=string&abc=zz",
			Paths: []string{
				testURL + "hello/world?query=string&abc=zz",
			},
		},
		{
			// sorted query string matches all cases
			Responder: "/hello/world?abc=zz&query=string",
			Paths: []string{
				testURL + "hello/world?query=string&abc=zz",
				testURL + "hello/world?abc=zz&query=string",
			},
		},
		{
			// unsorted query string matches exactly
			Responder: "/hello/world?query=string&query=string2&abc=zz",
			Paths: []string{
				testURL + "hello/world?query=string&query=string2&abc=zz",
			},
		},
		{
			// sorted query string matches all cases
			Responder: "/hello/world?abc=zz&query=string&query=string2",
			Paths: []string{
				testURL + "hello/world?query=string&query=string2&abc=zz",
				testURL + "hello/world?query=string2&query=string&abc=zz",
				testURL + "hello/world?abc=zz&query=string2&query=string",
			},
		},
		{
			Responder: "/hello/world?query",
			Paths: []string{
				testURL + "hello/world?query",
			},
		},
		{
			Responder: "/hello/world?query&abc",
			Paths: []string{
				testURL + "hello/world?query&abc",
				// testURL + "hello/world?abc&query" won't work as "=" is needed, see below
			},
		},
		{
			// In case the sorting does not matter for received params without
			// values, we must register params with "="
			Responder: "/hello/world?abc=&query=",
			Paths: []string{
				testURL + "hello/world?query&abc",
				testURL + "hello/world?abc&query",
			},
		},
		{
			Responder: "/hello/world#fragment",
			Paths: []string{
				testURL + "hello/world#fragment",
			},
		},
		{
			Responder: "/hello/world",
			Paths: []string{
				testURL + "hello/world?query=string&abc=zz#fragment",
				testURL + "hello/world?query=string&abc=zz",
				testURL + "hello/world#fragment",
				testURL + "hello/world",
			},
		},
		// Regexp cases
		{
			Responder: `=~^http://.*/hello/.*ld\z`,
			Paths: []string{
				testURL + "hello/world?query=string&abc=zz#fragment",
				testURL + "hello/world?query=string&abc=zz",
				testURL + "hello/world#fragment",
				testURL + "hello/world",
			},
		},
		{
			Responder: `=~^http://.*/hello/.*ld(\z|[?#])`,
			Paths: []string{
				testURL + "hello/world?query=string&abc=zz#fragment",
				testURL + "hello/world?query=string&abc=zz",
				testURL + "hello/world#fragment",
				testURL + "hello/world",
			},
		},
		{
			Responder: `=~^/hello/.*ld\z`,
			Paths: []string{
				testURL + "hello/world?query=string&abc=zz#fragment",
				testURL + "hello/world?query=string&abc=zz",
				testURL + "hello/world#fragment",
				testURL + "hello/world",
			},
		},
		{
			Responder: `=~^/hello/.*ld(\z|[?#])`,
			Paths: []string{
				testURL + "hello/world?query=string&abc=zz#fragment",
				testURL + "hello/world?query=string&abc=zz",
				testURL + "hello/world#fragment",
				testURL + "hello/world",
			},
		},
		{
			Responder: `=~abc=zz`,
			Paths: []string{
				testURL + "hello/world?query=string&abc=zz#fragment",
				testURL + "hello/world?query=string&abc=zz",
			},
		},
	} {
		Activate()

		// register the responder
		RegisterResponder("GET", test.Responder, NewStringResponder(200, "hello world"))

		for _, reqURL := range test.Paths {
			// make a request for the testURL with a querystring
			resp, err := http.Get(reqURL)
			if err != nil {
				t.Fatalf("%s: expected request %s to succeed: %v", test.Responder, reqURL, err)
			}

			// Close right away (not deferred) so bodies do not pile up in the loop.
			data, err := ioutil.ReadAll(resp.Body)
			resp.Body.Close()
			if err != nil {
				t.Fatalf("%s: %s error: %s", test.Responder, reqURL, err)
			}

			if string(data) != "hello world" {
				t.Fatalf("%s: expected body of %s to be 'hello world'", test.Responder, reqURL)
			}
		}

		DeactivateAndReset()
	}
}
379
380type dummyTripper struct{}
381
382func (d *dummyTripper) RoundTrip(*http.Request) (*http.Response, error) {
383	return nil, nil
384}
385
// TestMockTransportInitialTransport checks that Activate replaces whatever
// http.DefaultTransport currently is, and that Deactivate puts that same
// transport back.
func TestMockTransportInitialTransport(t *testing.T) {
	DeactivateAndReset()

	// Restore the real transport whatever happens (including a t.Fatal),
	// so the dummy tripper, which answers every request with (nil, nil),
	// does not leak into subsequent tests through the global.
	origTransport := http.DefaultTransport
	defer func() { http.DefaultTransport = origTransport }()

	tripper := &dummyTripper{}
	http.DefaultTransport = tripper

	Activate()

	if http.DefaultTransport == tripper {
		t.Fatal("expected http.DefaultTransport to be a mock transport")
	}

	Deactivate()

	if http.DefaultTransport != tripper {
		t.Fatal("expected http.DefaultTransport to be dummy")
	}
}
404
// TestMockTransportNonDefault checks that ActivateNonDefault mocks a custom
// http.Client, independently of http.DefaultTransport.
func TestMockTransportNonDefault(t *testing.T) {
	// create a custom http client w/ custom Roundtripper
	client := &http.Client{
		Transport: &http.Transport{
			Proxy: http.ProxyFromEnvironment,
			// DialContext replaces the deprecated Dial field.
			DialContext: (&net.Dialer{
				Timeout:   60 * time.Second,
				KeepAlive: 30 * time.Second,
			}).DialContext,
			TLSHandshakeTimeout: 60 * time.Second,
		},
	}

	// activate mocks for the client
	ActivateNonDefault(client)
	defer DeactivateAndReset()

	body := "hello world!"

	RegisterResponder("GET", testURL, NewStringResponder(200, body))

	req, err := http.NewRequest("GET", testURL, nil)
	if err != nil {
		t.Fatal(err)
	}

	resp, err := client.Do(req)
	if err != nil {
		t.Fatal(err)
	}

	defer resp.Body.Close()

	data, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}

	if string(data) != body {
		t.Fatalf("expected body %q, got %q", body, data)
	}
}
447
// TestMockTransportRespectsCancel checks that mocked requests honor both the
// legacy Request.Cancel channel and context cancellation, and that a panic in
// a responder is reported as an error instead of crashing the test binary.
func TestMockTransportRespectsCancel(t *testing.T) {
	Activate()
	defer DeactivateAndReset()

	const (
		cancelNone = iota
		cancelReq
		cancelCtx
	)

	cases := []struct {
		withCancel   int
		cancelNow    bool
		withPanic    bool
		expectedBody string
		expectedErr  error
	}{
		// No cancel specified at all. Falls back to normal behavior
		{cancelNone, false, false, "hello world", nil},

		// Cancel returns error
		{cancelReq, true, false, "", errors.New("request canceled")},

		// Cancel via context returns error
		{cancelCtx, true, false, "", errors.New("context canceled")},

		// Request can be cancelled but it is not cancelled.
		{cancelReq, false, false, "hello world", nil},

		// Request can be cancelled but it is not cancelled.
		{cancelCtx, false, false, "hello world", nil},

		// Panic in cancelled request is handled
		{cancelReq, false, true, "", errors.New(`panic in responder: got "oh no"`)},

		// Panic in cancelled request is handled
		{cancelCtx, false, true, "", errors.New(`panic in responder: got "oh no"`)},
	}

	for _, c := range cases {
		// Each case runs in its own function so that its deferred calls
		// (context cancellation) fire at the end of the case, not of the test.
		func() {
			Reset()
			if c.withPanic {
				RegisterResponder("GET", testURL, func(r *http.Request) (*http.Response, error) {
					time.Sleep(10 * time.Millisecond)
					panic("oh no")
				})
			} else {
				RegisterResponder("GET", testURL, func(r *http.Request) (*http.Response, error) {
					time.Sleep(10 * time.Millisecond)
					return NewStringResponse(http.StatusOK, "hello world"), nil
				})
			}

			req, err := http.NewRequest("GET", testURL, nil)
			if err != nil {
				t.Fatal(err)
			}

			switch c.withCancel {
			case cancelReq:
				cancel := make(chan struct{}, 1)
				req.Cancel = cancel // nolint: staticcheck
				if c.cancelNow {
					cancel <- struct{}{}
				}
			case cancelCtx:
				ctx, cancel := context.WithCancel(req.Context())
				defer cancel() // avoid ctx leak; calling cancel twice is harmless
				req = req.WithContext(ctx)
				if c.cancelNow {
					cancel()
				}
			}

			resp, err := http.DefaultClient.Do(req)
			if err != nil {
				// http.Client wraps transport errors in a *url.Error; check
				// the assertion rather than panicking on another type.
				urlErr, ok := err.(*url.Error)
				if !ok {
					t.Errorf("Expected a *url.Error, got %T: %v", err, err)
					return
				}
				// Do not use reflect.DeepEqual as go 1.13 includes stack frames
				// into errors issued by errors.New()
				if c.expectedErr == nil || urlErr.Err.Error() != c.expectedErr.Error() {
					t.Errorf("Expected error: %v, got: %v", c.expectedErr, urlErr.Err)
				}
				// resp is nil on error: it must never reach assertBody.
				return
			}

			// If we expect an error but none was returned, it's fatal for this test...
			if c.expectedErr != nil {
				resp.Body.Close()
				t.Fatal("Error should not be nil")
			}

			if c.expectedBody != "" {
				assertBody(t, resp, c.expectedBody) // also closes the body
			} else {
				resp.Body.Close()
			}
		}()
	}
}
544
// TestMockTransportRespectsTimeout checks that a responder slower than the
// client's Timeout makes the request fail.
func TestMockTransportRespectsTimeout(t *testing.T) {
	timeout := time.Millisecond
	client := &http.Client{
		Timeout: timeout,
	}

	ActivateNonDefault(client)
	defer DeactivateAndReset()

	RegisterResponder(
		"GET", testURL,
		func(r *http.Request) (*http.Response, error) {
			time.Sleep(100 * timeout)
			return NewStringResponse(http.StatusOK, ""), nil
		},
	)

	resp, err := client.Get(testURL)
	if err == nil {
		resp.Body.Close()
		t.Errorf("expected request to %s to fail after the %v client timeout", testURL, timeout)
	}
}
567
// callCountScenario registers a GET responder for getURL and a POST regexp
// responder "=~gitlab". It then performs GET getURL twice and POST postURL
// once, closing every response body, and checks the total and per-route call
// counters. postURL must match the "gitlab" regexp.
func callCountScenario(t *testing.T, getURL, postURL string) {
	t.Helper()

	RegisterResponder("GET", getURL, NewStringResponder(200, "body"))
	RegisterResponder("POST", "=~gitlab", NewStringResponder(200, "body"))

	// mustDiscard fails the test on error, otherwise closes the body so no
	// response is leaked.
	mustDiscard := func(resp *http.Response, err error) {
		t.Helper()
		if err != nil {
			t.Fatal(err)
		}
		resp.Body.Close()
	}

	mustDiscard(http.Get(getURL))

	buff := new(bytes.Buffer)
	if err := json.NewEncoder(buff).Encode("{}"); err != nil {
		t.Fatal(err)
	}
	mustDiscard(http.Post(postURL, "application/json", buff))

	mustDiscard(http.Get(getURL))

	totalCallCount := GetTotalCallCount()
	if totalCallCount != 3 {
		t.Fatalf("did not track the total count of calls correctly. expected it to be 3, but it was %v", totalCallCount)
	}

	info := GetCallCountInfo()
	expectedInfo := map[string]int{
		"GET " + getURL: 2,
		// Regexp match generates 2 entries:
		"POST " + postURL: 1, // the matched call
		"POST =~gitlab":   1, // the regexp responder
	}

	if !reflect.DeepEqual(info, expectedInfo) {
		t.Fatalf("did not correctly track the call count info. expected it to be \n %+v\n but it was \n %+v", expectedInfo, info)
	}
}

// TestMockTransportCallCountReset checks that Reset clears both the call
// counters and the per-route call count entries.
func TestMockTransportCallCountReset(t *testing.T) {
	Reset()
	Activate()
	defer Deactivate()

	const (
		getURL  = "https://github.com/path?b=1&a=2"
		postURL = "https://gitlab.com/"
	)

	callCountScenario(t, getURL, postURL)

	Reset()

	afterResetTotalCallCount := GetTotalCallCount()
	if afterResetTotalCallCount != 0 {
		t.Fatalf("did not reset the total count of calls correctly. expected it to be 0 after reset, but it was %v", afterResetTotalCallCount)
	}

	info := GetCallCountInfo()
	if !reflect.DeepEqual(info, map[string]int{}) {
		t.Fatalf("did not correctly reset the call count info. expected it to be \n {}\n but it was \n %+v", info)
	}
}

// TestMockTransportCallCountZero checks that ZeroCallCounters sets every
// counter to 0 while keeping the per-route entries.
func TestMockTransportCallCountZero(t *testing.T) {
	Reset()
	Activate()
	defer Deactivate()

	const (
		getURL  = "https://github.com/path?b=1&a=2"
		postURL = "https://gitlab.com/"
	)

	callCountScenario(t, getURL, postURL)

	ZeroCallCounters()

	afterResetTotalCallCount := GetTotalCallCount()
	if afterResetTotalCallCount != 0 {
		t.Fatalf("did not reset the total count of calls correctly. expected it to be 0 after reset, but it was %v", afterResetTotalCallCount)
	}

	info := GetCallCountInfo()
	expectedInfo := map[string]int{
		"GET " + getURL: 0,
		// Regexp match generates 2 entries:
		"POST " + postURL: 0, // the matched call
		"POST =~gitlab":   0, // the regexp responder
	}
	if !reflect.DeepEqual(info, expectedInfo) {
		t.Fatalf("did not correctly reset the call count info. expected it to be \n %+v\n but it was \n %+v", expectedInfo, info)
	}
}
693
// TestRegisterResponderWithQuery checks that every supported query form
// (map[string]string, string, url.Values, nil) registers a responder matching
// the same request URLs, whatever the order of their query parameters.
func TestRegisterResponderWithQuery(t *testing.T) {
	// Just in case a panic occurs
	defer DeactivateAndReset()

	// create a custom http client w/ custom Roundtripper
	client := &http.Client{
		Transport: &http.Transport{
			Proxy: http.ProxyFromEnvironment,
			// DialContext replaces the deprecated Dial field.
			DialContext: (&net.Dialer{
				Timeout:   60 * time.Second,
				KeepAlive: 30 * time.Second,
			}).DialContext,
			TLSHandshakeTimeout: 60 * time.Second,
		},
	}

	body := "hello world!"
	testURLPath := "http://acme.test/api"

	for _, test := range []struct {
		Queries []interface{}
		URLs    []string
	}{
		{
			Queries: []interface{}{
				map[string]string{"a": "1", "b": "2"},
				"a=1&b=2",
				"b=2&a=1",
				url.Values{"a": []string{"1"}, "b": []string{"2"}},
			},
			URLs: []string{
				"http://acme.test/api?a=1&b=2",
				"http://acme.test/api?b=2&a=1",
			},
		},
		{
			Queries: []interface{}{
				url.Values{
					"a": []string{"3", "2", "1"},
					"b": []string{"4", "2"},
					"c": []string{""}, // is the net/url way to record params without values
					// Test:
					//   u, _ := url.Parse("/hello/world?query")
					//   fmt.Printf("%d<%s>\n", len(u.Query()["query"]), u.Query()["query"][0])
					//   // prints "1<>"
				},
				"a=1&b=2&a=3&c&b=4&a=2",
				"b=2&a=1&c=&b=4&a=2&a=3",
				nil,
			},
			URLs: []string{
				testURLPath + "?a=1&b=2&a=3&c&b=4&a=2",
				testURLPath + "?a=1&b=2&a=3&c=&b=4&a=2",
				testURLPath + "?b=2&a=1&c=&b=4&a=2&a=3",
				testURLPath + "?b=2&a=1&c&b=4&a=2&a=3",
			},
		},
	} {
		for _, query := range test.Queries {
			ActivateNonDefault(client)
			RegisterResponderWithQuery("GET", testURLPath, query, NewStringResponder(200, body))

			// reqURL (not url) so that the net/url package is not shadowed.
			for _, reqURL := range test.URLs {
				req, err := http.NewRequest("GET", reqURL, nil)
				if err != nil {
					t.Fatalf("query=%v URL=%s: %v", query, reqURL, err)
				}
				resp, err := client.Do(req)
				if err != nil {
					t.Fatalf("query=%v URL=%s: %v", query, reqURL, err)
				}
				data, err := ioutil.ReadAll(resp.Body)
				resp.Body.Close()
				if err != nil {
					t.Fatalf("query=%v URL=%s: %v", query, reqURL, err)
				}
				if string(data) != body {
					t.Fatalf("query=%v URL=%s: %s ≠ %s", query, reqURL, string(data), body)
				}
			}

			DeactivateAndReset()
		}
	}
}
780
781func TestRegisterResponderWithQueryPanic(t *testing.T) {
782	resp := NewStringResponder(200, "hello world!")
783
784	for _, test := range []struct {
785		Path        string
786		Query       interface{}
787		PanicPrefix string
788	}{
789		{
790			Path:        "foobar",
791			Query:       "%",
792			PanicPrefix: "RegisterResponderWithQuery bad query string: ",
793		},
794		{
795			Path:        "foobar",
796			Query:       1234,
797			PanicPrefix: "RegisterResponderWithQuery bad query type int. Only url.Values, map[string]string and string are allowed",
798		},
799		{
800			Path:        `=~regexp.*\z`,
801			Query:       "",
802			PanicPrefix: `path begins with "=~", RegisterResponder should be used instead of RegisterResponderWithQuery`,
803		},
804	} {
805		var (
806			didntPanic bool
807			panicVal   interface{}
808		)
809		func() {
810			defer func() {
811				panicVal = recover()
812			}()
813
814			RegisterResponderWithQuery("GET", test.Path, test.Query, resp)
815			didntPanic = true
816		}()
817
818		if didntPanic {
819			t.Fatalf("RegisterResponderWithQuery + query=%v did not panic", test.Query)
820		}
821
822		panicStr, ok := panicVal.(string)
823		if !ok || !strings.HasPrefix(panicStr, test.PanicPrefix) {
824			t.Fatalf(`RegisterResponderWithQuery + query=%v panic="%v" expected prefix="%v"`,
825				test.Query, panicVal, test.PanicPrefix)
826		}
827	}
828}
829
// TestRegisterRegexpResponder checks that registering a second responder for
// the same method and regexp overwrites the first one.
func TestRegisterRegexpResponder(t *testing.T) {
	Activate()
	defer DeactivateAndReset()

	rx := regexp.MustCompile("ex.mple")

	RegisterRegexpResponder("GET", rx, NewStringResponder(200, "first"))
	// Overwrite responder
	RegisterRegexpResponder("GET", rx, NewStringResponder(200, "second"))

	resp, err := http.Get(testURL)
	if err != nil {
		t.Fatalf("expected request %s to succeed: %v", testURL, err)
	}
	defer resp.Body.Close()

	data, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("%s error: %s", testURL, err)
	}

	if string(data) != "second" {
		t.Fatalf("expected body of %s to be 'second', got %q", testURL, data)
	}
}
854
// TestSubmatches exercises setSubmatches and the GetSubmatch*/MustGetSubmatch*
// accessors, first directly on requests, then through a regexp responder.
func TestSubmatches(t *testing.T) {
	req, err := http.NewRequest("GET", "/foo/bar", nil)
	if err != nil {
		t.Fatal(err)
	}

	// req2 carries the submatches read by the subtests below. It is built
	// here rather than inside the "setSubmatches" subtest so that each
	// subtest still works when run on its own (go test -run).
	req2 := setSubmatches(req, []string{"foo", "123", "-123", "12.3"})

	t.Run("setSubmatches", func(t *testing.T) {
		if got := setSubmatches(req, nil); got != req {
			t.Error("setSubmatches(req, nil) should return the same request")
		}

		if got := setSubmatches(req, []string{}); got != req {
			t.Error("setSubmatches(req, []string{}) should return the same request")
		}

		if req2 == req {
			t.Error("setSubmatches(req, []string{...}) should NOT return the same request")
		}
	})

	t.Run("GetSubmatch", func(t *testing.T) {
		_, err := GetSubmatch(req, 1)
		if err != ErrSubmatchNotFound {
			t.Errorf("Submatch should not be found in req: %v", err)
		}

		_, err = GetSubmatch(req2, 5)
		if err != ErrSubmatchNotFound {
			t.Errorf("Submatch #5 should not be found in req2: %v", err)
		}

		s, err := GetSubmatch(req2, 1)
		if err != nil {
			t.Errorf("GetSubmatch(req2, 1) failed: %v", err)
		}
		if s != "foo" {
			t.Errorf("GetSubmatch(req2, 1) failed, got: %v, expected: foo", s)
		}

		s, err = GetSubmatch(req2, 4)
		if err != nil {
			t.Errorf("GetSubmatch(req2, 4) failed: %v", err)
		}
		if s != "12.3" {
			t.Errorf("GetSubmatch(req2, 4) failed, got: %v, expected: 12.3", s)
		}

		s = MustGetSubmatch(req2, 4)
		if s != "12.3" {
			t.Errorf("MustGetSubmatch(req2, 4) failed, got: %v, expected: 12.3", s)
		}
	})

	t.Run("GetSubmatchAsInt", func(t *testing.T) {
		_, err := GetSubmatchAsInt(req, 1)
		if err != ErrSubmatchNotFound {
			t.Errorf("Submatch should not be found in req: %v", err)
		}

		_, err = GetSubmatchAsInt(req2, 4) // not an int
		if err == nil || err == ErrSubmatchNotFound {
			t.Errorf("Submatch should not be an int64: %v", err)
		}

		i, err := GetSubmatchAsInt(req2, 3)
		if err != nil {
			t.Errorf("GetSubmatchAsInt(req2, 3) failed: %v", err)
		}
		if i != -123 {
			t.Errorf("GetSubmatchAsInt(req2, 3) failed, got: %d, expected: -123", i)
		}

		i = MustGetSubmatchAsInt(req2, 3)
		if i != -123 {
			t.Errorf("MustGetSubmatchAsInt(req2, 3) failed, got: %d, expected: -123", i)
		}
	})

	t.Run("GetSubmatchAsUint", func(t *testing.T) {
		_, err := GetSubmatchAsUint(req, 1)
		if err != ErrSubmatchNotFound {
			t.Errorf("Submatch should not be found in req: %v", err)
		}

		_, err = GetSubmatchAsUint(req2, 3) // not a uint
		if err == nil || err == ErrSubmatchNotFound {
			t.Errorf("Submatch should not be an uint64: %v", err)
		}

		u, err := GetSubmatchAsUint(req2, 2)
		if err != nil {
			t.Errorf("GetSubmatchAsUint(req2, 2) failed: %v", err)
		}
		if u != 123 {
			t.Errorf("GetSubmatchAsUint(req2, 2) failed, got: %d, expected: 123", u)
		}

		u = MustGetSubmatchAsUint(req2, 2)
		if u != 123 {
			t.Errorf("MustGetSubmatchAsUint(req2, 2) failed, got: %d, expected: 123", u)
		}
	})

	t.Run("GetSubmatchAsFloat", func(t *testing.T) {
		_, err := GetSubmatchAsFloat(req, 1)
		if err != ErrSubmatchNotFound {
			t.Errorf("Submatch should not be found in req: %v", err)
		}

		_, err = GetSubmatchAsFloat(req2, 1) // not a float
		if err == nil || err == ErrSubmatchNotFound {
			t.Errorf("Submatch should not be an float64: %v", err)
		}

		f, err := GetSubmatchAsFloat(req2, 4)
		if err != nil {
			t.Errorf("GetSubmatchAsFloat(req2, 4) failed: %v", err)
		}
		if f != 12.3 {
			t.Errorf("GetSubmatchAsFloat(req2, 4) failed, got: %f, expected: 12.3", f)
		}

		f = MustGetSubmatchAsFloat(req2, 4)
		if f != 12.3 {
			t.Errorf("MustGetSubmatchAsFloat(req2, 4) failed, got: %f, expected: 12.3", f)
		}
	})

	t.Run("GetSubmatch* panics", func(t *testing.T) {
		for _, test := range []struct {
			Name        string
			Fn          func()
			PanicPrefix string
		}{
			{
				Name:        "GetSubmatch & n < 1",
				Fn:          func() { GetSubmatch(req, 0) }, // nolint: errcheck
				PanicPrefix: "getting submatches starts at 1, not 0",
			},
			{
				Name:        "MustGetSubmatch",
				Fn:          func() { MustGetSubmatch(req, 1) },
				PanicPrefix: "GetSubmatch failed: " + ErrSubmatchNotFound.Error(),
			},
			{
				Name:        "MustGetSubmatchAsInt",
				Fn:          func() { MustGetSubmatchAsInt(req2, 4) }, // not an int
				PanicPrefix: "GetSubmatchAsInt failed: ",
			},
			{
				Name:        "MustGetSubmatchAsUint",
				Fn:          func() { MustGetSubmatchAsUint(req2, 3) }, // not a uint
				PanicPrefix: "GetSubmatchAsUint failed: ",
			},
			{
				Name:        "MustGetSubmatchAsFloat",
				Fn:          func() { MustGetSubmatchAsFloat(req2, 1) }, // not a float
				PanicPrefix: "GetSubmatchAsFloat failed: ",
			},
		} {
			var (
				didntPanic bool
				panicVal   interface{}
			)
			func() {
				defer func() { panicVal = recover() }()
				test.Fn()
				didntPanic = true
			}()

			if didntPanic {
				t.Errorf("%s did not panic", test.Name)
			}

			panicStr, ok := panicVal.(string)
			if !ok || !strings.HasPrefix(panicStr, test.PanicPrefix) {
				t.Errorf(`%s panic="%v" expected prefix="%v"`, test.Name, panicVal, test.PanicPrefix)
			}
		}
	})

	t.Run("Full test", func(t *testing.T) {
		Activate()
		defer DeactivateAndReset()

		var (
			id       uint64
			delta    float64
			deltaStr string
			inc      int64
		)
		RegisterResponder("GET", `=~^/id/(\d+)\?delta=(\d+(?:\.\d*)?)&inc=(-?\d+)\z`,
			func(req *http.Request) (*http.Response, error) {
				id = MustGetSubmatchAsUint(req, 1)
				delta = MustGetSubmatchAsFloat(req, 2)
				deltaStr = MustGetSubmatch(req, 2)
				inc = MustGetSubmatchAsInt(req, 3)

				return NewStringResponse(http.StatusOK, "OK"), nil
			})

		resp, err := http.Get("http://example.tld/id/123?delta=1.2&inc=-5")
		if err != nil {
			t.Fatal(err)
		}
		assertBody(t, resp, "OK")

		// Check submatches
		if id != 123 {
			t.Errorf("seems MustGetSubmatchAsUint failed, got: %d, expected: 123", id)
		}
		if delta != 1.2 {
			t.Errorf("seems MustGetSubmatchAsFloat failed, got: %f, expected: 1.2", delta)
		}
		if deltaStr != "1.2" {
			t.Errorf("seems MustGetSubmatch failed, got: %v, expected: 1.2", deltaStr)
		}
		if inc != -5 {
			t.Errorf("seems MustGetSubmatchAsInt failed, got: %d, expected: -5", inc)
		}
	})
}
1082
// TestCheckStackTracer checks checkStackTracer with plain errors and with
// stackTracer values, with and without a custom function. It then checks the
// Trace() responder end to end through http.Get.
func TestCheckStackTracer(t *testing.T) {
	req, err := http.NewRequest("GET", "http://foo.bar/", nil)
	if err != nil {
		t.Fatal(err)
	}

	// no error
	gotErr := checkStackTracer(req, nil)
	if gotErr != nil {
		t.Errorf(`checkStackTracer(nil) should return nil, not %v`, gotErr)
	}

	// Classic error
	err = errors.New("error")
	gotErr = checkStackTracer(req, err)
	if err != gotErr {
		t.Errorf(`checkStackTracer(err) should return %v, not %v`, err, gotErr)
	}

	// stackTracer without customFn
	origErr := errors.New("foo")
	errTracer := stackTracer{
		err: origErr,
	}
	gotErr = checkStackTracer(req, errTracer)
	if gotErr != origErr {
		t.Errorf(`Returned error mismatch, expected: %v, got: %v`, origErr, gotErr)
	}

	// stackTracer with nil error & without customFn
	errTracer = stackTracer{}
	gotErr = checkStackTracer(req, errTracer)
	if gotErr != nil {
		t.Errorf(`Returned error mismatch, expected: nil, got: %v`, gotErr)
	}

	// stackTracer
	var mesg string
	errTracer = stackTracer{
		err: origErr,
		customFn: func(args ...interface{}) {
			mesg = args[0].(string)
		},
	}
	gotErr = checkStackTracer(req, errTracer)
	// mesg is passed as an argument, never concatenated into the format
	// string, so a "%" in it cannot corrupt the output.
	if !strings.HasPrefix(mesg, "foo\nCalled from ") || strings.HasSuffix(mesg, "\n") {
		t.Errorf(`mesg does not match "^foo\nCalled from .*[^\n]\z", it is %q`, mesg)
	}
	if gotErr != origErr {
		t.Errorf(`Returned error mismatch, expected: %v, got: %v`, origErr, gotErr)
	}

	// stackTracer with nil error but customFn
	mesg = ""
	errTracer = stackTracer{
		customFn: func(args ...interface{}) {
			mesg = args[0].(string)
		},
	}
	gotErr = checkStackTracer(req, errTracer)
	if !strings.HasPrefix(mesg, "GET http://foo.bar/\nCalled from ") || strings.HasSuffix(mesg, "\n") {
		t.Errorf(`mesg does not match "^GET http://foo.bar/\nCalled from .*[^\n]\z", it is %q`, mesg)
	}
	if gotErr != nil {
		t.Errorf(`Returned error mismatch, expected: nil, got: %v`, gotErr)
	}

	// Full test using Trace() Responder
	Activate()
	defer DeactivateAndReset()

	const tracedURL = "https://foo.bar/"
	mesg = ""
	RegisterResponder("GET", tracedURL,
		NewStringResponder(200, "{}").
			Trace(func(args ...interface{}) { mesg = args[0].(string) }))

	resp, err := http.Get(tracedURL)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	data, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}

	if string(data) != "{}" {
		t.Fatalf("expected body %q, got %q", "{}", data)
	}

	// Check that first frame is the net/http.Get() call
	if !strings.HasPrefix(mesg, "GET https://foo.bar/\nCalled from net/http.Get()\n    at ") ||
		strings.HasSuffix(mesg, "\n") {
		t.Errorf("Bad mesg: <%v>", mesg)
	}
}
1181