1package nethttp
2
3import (
4	"net/http"
5	"net/http/httptest"
6	"net/url"
7	"reflect"
8	"strings"
9	"testing"
10
11	"github.com/opentracing/opentracing-go"
12	"github.com/opentracing/opentracing-go/ext"
13	"github.com/opentracing/opentracing-go/mocktracer"
14)
15
16func TestOperationNameOption(t *testing.T) {
17	mux := http.NewServeMux()
18	mux.HandleFunc("/root", func(w http.ResponseWriter, r *http.Request) {})
19
20	fn := func(r *http.Request) string {
21		return "HTTP " + r.Method + ": /root"
22	}
23
24	tests := []struct {
25		options []MWOption
26		opName  string
27	}{
28		{nil, "HTTP GET"},
29		{[]MWOption{OperationNameFunc(fn)}, "HTTP GET: /root"},
30	}
31
32	for _, tt := range tests {
33		testCase := tt
34		t.Run(testCase.opName, func(t *testing.T) {
35			tr := &mocktracer.MockTracer{}
36			mw := Middleware(tr, mux, testCase.options...)
37			srv := httptest.NewServer(mw)
38			defer srv.Close()
39
40			_, err := http.Get(srv.URL)
41			if err != nil {
42				t.Fatalf("server returned error: %v", err)
43			}
44
45			spans := tr.FinishedSpans()
46			if got, want := len(spans), 1; got != want {
47				t.Fatalf("got %d spans, expected %d", got, want)
48			}
49
50			if got, want := spans[0].OperationName, testCase.opName; got != want {
51				t.Fatalf("got %s operation name, expected %s", got, want)
52			}
53		})
54	}
55}
56
57func TestSpanObserverOption(t *testing.T) {
58	mux := http.NewServeMux()
59	mux.HandleFunc("/root", func(w http.ResponseWriter, r *http.Request) {})
60
61	opNamefn := func(r *http.Request) string {
62		return "HTTP " + r.Method + ": /root"
63	}
64	spanObserverfn := func(sp opentracing.Span, r *http.Request) {
65		sp.SetTag("http.uri", r.URL.EscapedPath())
66	}
67	wantTags := map[string]interface{}{"http.uri": "/"}
68
69	tests := []struct {
70		options []MWOption
71		opName  string
72		Tags    map[string]interface{}
73	}{
74		{nil, "HTTP GET", nil},
75		{[]MWOption{OperationNameFunc(opNamefn)}, "HTTP GET: /root", nil},
76		{[]MWOption{MWSpanObserver(spanObserverfn)}, "HTTP GET", wantTags},
77		{[]MWOption{OperationNameFunc(opNamefn), MWSpanObserver(spanObserverfn)}, "HTTP GET: /root", wantTags},
78	}
79
80	for _, tt := range tests {
81		testCase := tt
82		t.Run(testCase.opName, func(t *testing.T) {
83			tr := &mocktracer.MockTracer{}
84			mw := Middleware(tr, mux, testCase.options...)
85			srv := httptest.NewServer(mw)
86			defer srv.Close()
87
88			_, err := http.Get(srv.URL)
89			if err != nil {
90				t.Fatalf("server returned error: %v", err)
91			}
92
93			spans := tr.FinishedSpans()
94			if got, want := len(spans), 1; got != want {
95				t.Fatalf("got %d spans, expected %d", got, want)
96			}
97
98			if got, want := spans[0].OperationName, testCase.opName; got != want {
99				t.Fatalf("got %s operation name, expected %s", got, want)
100			}
101
102			defaultLength := 5
103			if len(spans[0].Tags()) != len(testCase.Tags)+defaultLength {
104				t.Fatalf("got tag length %d, expected %d", len(spans[0].Tags()), len(testCase.Tags))
105			}
106			for k, v := range testCase.Tags {
107				if tag := spans[0].Tag(k); v != tag.(string) {
108					t.Fatalf("got %v tag, expected %v", tag, v)
109				}
110			}
111		})
112	}
113}
114
115func TestSpanFilterOption(t *testing.T) {
116	mux := http.NewServeMux()
117	mux.HandleFunc("/root", func(w http.ResponseWriter, r *http.Request) {})
118
119	spanFilterfn := func(r *http.Request) bool {
120		return !strings.HasPrefix(r.Header.Get("User-Agent"), "kube-probe")
121	}
122	noAgentReq, _ := http.NewRequest("GET", "/root", nil)
123	noAgentReq.Header.Del("User-Agent")
124	probeReq1, _ := http.NewRequest("GET", "/root", nil)
125	probeReq1.Header.Add("User-Agent", "kube-probe/1.12")
126	probeReq2, _ := http.NewRequest("GET", "/root", nil)
127	probeReq2.Header.Add("User-Agent", "kube-probe/9.99")
128	postmanReq, _ := http.NewRequest("GET", "/root", nil)
129	postmanReq.Header.Add("User-Agent", "PostmanRuntime/7.3.0")
130	tests := []struct {
131		options            []MWOption
132		request            *http.Request
133		opName             string
134		ExpectToCreateSpan bool
135	}{
136		{nil, noAgentReq, "No filter", true},
137		{[]MWOption{MWSpanFilter(spanFilterfn)}, noAgentReq, "No User-Agent", true},
138		{[]MWOption{MWSpanFilter(spanFilterfn)}, probeReq1, "User-Agent: kube-probe/1.12", false},
139		{[]MWOption{MWSpanFilter(spanFilterfn)}, probeReq2, "User-Agent: kube-probe/9.99", false},
140		{[]MWOption{MWSpanFilter(spanFilterfn)}, postmanReq, "User-Agent: PostmanRuntime/7.3.0", true},
141	}
142
143	for _, tt := range tests {
144		testCase := tt
145		t.Run(testCase.opName, func(t *testing.T) {
146			tr := &mocktracer.MockTracer{}
147			mw := Middleware(tr, mux, testCase.options...)
148			srv := httptest.NewServer(mw)
149			defer srv.Close()
150
151			client := &http.Client{}
152			testCase.request.URL, _ = url.Parse(srv.URL)
153			_, err := client.Do(testCase.request)
154			if err != nil {
155				t.Fatalf("server returned error: %v", err)
156			}
157
158			spans := tr.FinishedSpans()
159			if spanCreated := len(spans) == 1; spanCreated != testCase.ExpectToCreateSpan {
160				t.Fatalf("spanCreated %t, ExpectToCreateSpan %t", spanCreated, testCase.ExpectToCreateSpan)
161			}
162		})
163	}
164}
165
166func TestURLTagOption(t *testing.T) {
167	mux := http.NewServeMux()
168	mux.HandleFunc("/root", func(w http.ResponseWriter, r *http.Request) {})
169
170	fn := func(u *url.URL) string {
171		// Log path only (no query parameters etc)
172		return u.Path
173	}
174
175	tests := []struct {
176		options []MWOption
177		url     string
178		tag     string
179	}{
180		{[]MWOption{}, "/root?token=123", "/root?token=123"},
181		{[]MWOption{MWURLTagFunc(fn)}, "/root?token=123", "/root"},
182	}
183
184	for _, tt := range tests {
185		testCase := tt
186		t.Run(testCase.tag, func(t *testing.T) {
187			tr := &mocktracer.MockTracer{}
188			mw := Middleware(tr, mux, testCase.options...)
189			srv := httptest.NewServer(mw)
190			defer srv.Close()
191
192			_, err := http.Get(srv.URL + testCase.url)
193			if err != nil {
194				t.Fatalf("server returned error: %v", err)
195			}
196
197			spans := tr.FinishedSpans()
198			if got, want := len(spans), 1; got != want {
199				t.Fatalf("got %d spans, expected %d", got, want)
200			}
201
202			tag := spans[0].Tags()["http.url"]
203			if got, want := tag, testCase.tag; got != want {
204				t.Fatalf("got %s tag name, expected %s", got, want)
205			}
206		})
207	}
208}
209
210func TestSpanErrorAndStatusCode(t *testing.T) {
211	mux := http.NewServeMux()
212	mux.HandleFunc("/header-and-body", func(w http.ResponseWriter, r *http.Request) {
213		w.WriteHeader(200)
214		w.Write([]byte("OK"))
215	})
216	mux.HandleFunc("/body-only", func(w http.ResponseWriter, r *http.Request) {
217		w.Write([]byte("OK"))
218	})
219	mux.HandleFunc("/header-only", func(w http.ResponseWriter, r *http.Request) {
220		w.WriteHeader(200)
221	})
222	mux.HandleFunc("/empty", func(w http.ResponseWriter, r *http.Request) {
223		// no status header
224	})
225	mux.HandleFunc("/error", func(w http.ResponseWriter, r *http.Request) {
226		w.WriteHeader(500)
227	})
228
229	expStatusOK := map[string]interface{}{"http.status_code": uint16(200)}
230
231	tests := []struct {
232		url  string
233		tags map[string]interface{}
234	}{
235		{url: "/header-and-body", tags: expStatusOK},
236		{url: "/body-only", tags: expStatusOK},
237		{url: "/header-only", tags: expStatusOK},
238		{url: "/empty", tags: expStatusOK},
239		{url: "/error", tags: map[string]interface{}{"http.status_code": uint16(500), string(ext.Error): true}},
240	}
241
242	for _, tt := range tests {
243		testCase := tt
244		t.Run(testCase.url, func(t *testing.T) {
245			tr := &mocktracer.MockTracer{}
246			mw := Middleware(tr, mux)
247			srv := httptest.NewServer(mw)
248			defer srv.Close()
249
250			_, err := http.Get(srv.URL + testCase.url)
251			if err != nil {
252				t.Fatalf("server returned error: %v", err)
253			}
254
255			spans := tr.FinishedSpans()
256			if got, want := len(spans), 1; got != want {
257				t.Fatalf("got %d spans, expected %d", got, want)
258			}
259
260			for k, v := range testCase.tags {
261				if tag := spans[0].Tag(k); !reflect.DeepEqual(tag, v) {
262					t.Fatalf("tag %s: got %v, expected %v", k, tag, v)
263				}
264			}
265		})
266	}
267}
268
269func BenchmarkStatusCodeTrackingOverhead(b *testing.B) {
270	mux := http.NewServeMux()
271	mux.HandleFunc("/root", func(w http.ResponseWriter, r *http.Request) {})
272	tr := &mocktracer.MockTracer{}
273	mw := Middleware(tr, mux)
274	srv := httptest.NewServer(mw)
275	defer srv.Close()
276
277	b.RunParallel(func(pb *testing.PB) {
278		for pb.Next() {
279			resp, err := http.Get(srv.URL)
280			if err != nil {
281				b.Fatalf("server returned error: %v", err)
282			}
283			err = resp.Body.Close()
284			if err != nil {
285				b.Fatalf("failed to close response: %v", err)
286			}
287		}
288	})
289}
290
291func TestMiddlewareHandlerPanic(t *testing.T) {
292	tests := []struct {
293		handler func(w http.ResponseWriter, r *http.Request)
294		status  uint16
295		isError bool
296		name    string
297	}{
298		{
299			name: "OK",
300			handler: func(w http.ResponseWriter, r *http.Request) {
301				w.Write([]byte("OK"))
302			},
303			status:  200,
304			isError: false,
305		},
306		{
307			name: "Panic",
308			handler: func(w http.ResponseWriter, r *http.Request) {
309				panic("panic test")
310			},
311			status:  0,
312			isError: true,
313		},
314		{
315			name: "InternalServerError",
316			handler: func(w http.ResponseWriter, r *http.Request) {
317				w.WriteHeader(http.StatusInternalServerError)
318				w.Write([]byte("InternalServerError"))
319			},
320			status:  500,
321			isError: true,
322		},
323	}
324
325	for _, testCase := range tests {
326		t.Run(testCase.name, func(t *testing.T) {
327			mux := http.NewServeMux()
328			mux.HandleFunc("/root", testCase.handler)
329			tr := &mocktracer.MockTracer{}
330			srv := httptest.NewServer(MiddlewareFunc(tr, mux.ServeHTTP))
331			defer srv.Close()
332
333			_, err := http.Get(srv.URL + "/root")
334			if err != nil {
335				t.Logf("server returned error: %v", err)
336			}
337
338			spans := tr.FinishedSpans()
339			if got, want := len(spans), 1; got != want {
340				t.Fatalf("got %d spans, expected %d", got, want)
341			}
342			actualStatus := spans[0].Tag(string(ext.HTTPStatusCode))
343			if testCase.status > 0 && !reflect.DeepEqual(testCase.status, actualStatus) {
344				t.Fatalf("got status code %v, expected %d", actualStatus, testCase.status)
345			}
346			actualErr, ok := spans[0].Tag(string(ext.Error)).(bool)
347			if !ok {
348				actualErr = false
349			}
350			if testCase.isError != actualErr {
351				t.Fatalf("got span error %v, expected %v", actualErr, testCase.isError)
352			}
353		})
354	}
355}
356