1package nethttp
2
3import (
4	"net/http"
5	"net/http/httptest"
6	"net/url"
7	"testing"
8
9	opentracing "github.com/opentracing/opentracing-go"
10	"github.com/opentracing/opentracing-go/ext"
11	"github.com/opentracing/opentracing-go/mocktracer"
12)
13
14func makeRequest(t *testing.T, url string, options ...ClientOption) []*mocktracer.MockSpan {
15	tr := &mocktracer.MockTracer{}
16	span := tr.StartSpan("toplevel")
17	client := &http.Client{Transport: &Transport{}}
18	req, err := http.NewRequest("GET", url, nil)
19	if err != nil {
20		t.Fatal(err)
21	}
22	req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span))
23	req, ht := TraceRequest(tr, req, options...)
24	resp, err := client.Do(req)
25	if err != nil {
26		t.Fatal(err)
27	}
28	_ = resp.Body.Close()
29	ht.Finish()
30	span.Finish()
31
32	return tr.FinishedSpans()
33}
34
35func TestClientTrace(t *testing.T) {
36	mux := http.NewServeMux()
37	mux.HandleFunc("/ok", func(w http.ResponseWriter, r *http.Request) {})
38	mux.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) {
39		http.Redirect(w, r, "/ok", http.StatusTemporaryRedirect)
40	})
41	mux.HandleFunc("/fail", func(w http.ResponseWriter, r *http.Request) {
42		http.Error(w, "failure", http.StatusInternalServerError)
43	})
44	srv := httptest.NewServer(mux)
45	defer srv.Close()
46
47	helloWorldObserver := func(s opentracing.Span, r *http.Request) {
48		s.SetTag("hello", "world")
49	}
50
51	tests := []struct {
52		url          string
53		num          int
54		opts         []ClientOption
55		opName       string
56		expectedTags map[string]interface{}
57	}{
58		{url: "/ok", num: 3, opts: nil, opName: "HTTP Client"},
59		{url: "/redirect", num: 4, opts: []ClientOption{OperationName("client-span")}, opName: "client-span"},
60		{url: "/fail", num: 3, opts: nil, opName: "HTTP Client", expectedTags: makeTags(string(ext.Error), true)},
61		{url: "/ok", num: 3, opts: []ClientOption{ClientSpanObserver(helloWorldObserver)}, opName: "HTTP Client", expectedTags: makeTags("hello", "world")},
62	}
63
64	for _, tt := range tests {
65		t.Log(tt.opName)
66		spans := makeRequest(t, srv.URL+tt.url, tt.opts...)
67		if got, want := len(spans), tt.num; got != want {
68			t.Fatalf("got %d spans, expected %d", got, want)
69		}
70		var rootSpan *mocktracer.MockSpan
71		for _, span := range spans {
72			if span.ParentID == 0 {
73				rootSpan = span
74				break
75			}
76		}
77		if rootSpan == nil {
78			t.Fatal("cannot find root span with ParentID==0")
79		}
80
81		foundClientSpan := false
82		for _, span := range spans {
83			if span.ParentID == rootSpan.SpanContext.SpanID {
84				foundClientSpan = true
85				if got, want := span.OperationName, tt.opName; got != want {
86					t.Fatalf("got %s operation name, expected %s", got, want)
87				}
88			}
89			if span.OperationName == "HTTP GET" {
90				logs := span.Logs()
91				if len(logs) < 6 {
92					t.Fatalf("got %d, expected at least %d log events", len(logs), 6)
93				}
94
95				key := logs[0].Fields[0].Key
96				if key != "event" {
97					t.Fatalf("got %s, expected %s", key, "event")
98				}
99				v := logs[0].Fields[0].ValueString
100				if v != "GetConn" {
101					t.Fatalf("got %s, expected %s", v, "GetConn")
102				}
103
104				for k, expected := range tt.expectedTags {
105					result := span.Tag(k)
106					if expected != result {
107						t.Fatalf("got %v, expected %v, for key %s", result, expected, k)
108					}
109				}
110			}
111		}
112		if !foundClientSpan {
113			t.Fatal("cannot find client span")
114		}
115	}
116}
117
118func TestTracerFromRequest(t *testing.T) {
119	req, err := http.NewRequest("GET", "foobar", nil)
120	if err != nil {
121		t.Fatal(err)
122	}
123
124	ht := TracerFromRequest(req)
125	if ht != nil {
126		t.Fatal("request should not have a tracer yet")
127	}
128
129	tr := &mocktracer.MockTracer{}
130	req, expected := TraceRequest(tr, req)
131
132	ht = TracerFromRequest(req)
133	if ht != expected {
134		t.Fatalf("got %v, expected %v", ht, expected)
135	}
136}
137
138func TestInjectSpanContext(t *testing.T) {
139	tests := []struct {
140		name                     string
141		expectContextPropagation bool
142		opts                     []ClientOption
143	}{
144		{name: "Default", expectContextPropagation: true, opts: nil},
145		{name: "True", expectContextPropagation: true, opts: []ClientOption{InjectSpanContext(true)}},
146		{name: "False", expectContextPropagation: false, opts: []ClientOption{InjectSpanContext(false)}},
147	}
148
149	for _, tt := range tests {
150		t.Run(tt.name, func(t *testing.T) {
151			var handlerCalled bool
152			srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
153				handlerCalled = true
154				srvTr := mocktracer.New()
155				ctx, err := srvTr.Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(r.Header))
156
157				if err != nil && tt.expectContextPropagation {
158					t.Fatal(err)
159				}
160
161				if tt.expectContextPropagation {
162					if err != nil || ctx == nil {
163						t.Fatal("expected propagation but unable to extract")
164					}
165				} else {
166					// Expect "opentracing: SpanContext not found in Extract carrier" when not injected
167					// Can't check ctx directly, because it gets set to emptyContext
168					if err == nil {
169						t.Fatal("unexpected propagation")
170					}
171				}
172			}))
173
174			tr := mocktracer.New()
175			span := tr.StartSpan("root")
176
177			req, err := http.NewRequest("GET", srv.URL, nil)
178			if err != nil {
179				t.Fatal(err)
180			}
181			req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span))
182
183			req, ht := TraceRequest(tr, req, tt.opts...)
184
185			client := &http.Client{Transport: &Transport{}}
186			resp, err := client.Do(req)
187			if err != nil {
188				t.Fatal(err)
189			}
190			_ = resp.Body.Close()
191
192			ht.Finish()
193			span.Finish()
194
195			srv.Close()
196
197			if !handlerCalled {
198				t.Fatal("server handler never called")
199			}
200		})
201	}
202}
203
// makeTags builds a tag map from alternating key/value arguments. For example,
// makeTags("hello", "world") yields {"hello": "world"}.
//
// Every key position must hold a string; any other type panics on the type
// assertion. A trailing key with no value is ignored.
func makeTags(keyVals ...interface{}) map[string]interface{} {
	tags := make(map[string]interface{}, len(keyVals)/2)
	for rest := keyVals; len(rest) >= 2; rest = rest[2:] {
		tags[rest[0].(string)] = rest[1]
	}
	return tags
}
212
213func TestClientCustomURL(t *testing.T) {
214	mux := http.NewServeMux()
215	mux.HandleFunc("/ok", func(w http.ResponseWriter, r *http.Request) {})
216	srv := httptest.NewServer(mux)
217	defer srv.Close()
218
219	fn := func(u *url.URL) string {
220		// Simulate redacting token
221		return srv.URL + u.Path + "?token=*"
222	}
223
224	tests := []struct {
225		opts []ClientOption
226		url  string
227		tag  string
228	}{
229		// These first cases fail early
230		{[]ClientOption{}, "/ok?token=a", srv.Listener.Addr().String()},
231		{[]ClientOption{URLTagFunc(fn)}, "/ok?token=c", srv.Listener.Addr().String()},
232		// Disable ClientTrace to fire RoundTrip
233		{[]ClientOption{ClientTrace(false)}, "/ok?token=b", srv.URL + "/ok?token=b"},
234		{[]ClientOption{ClientTrace(false), URLTagFunc(fn)}, "/ok?token=c", srv.URL + "/ok?token=*"},
235	}
236
237	for _, tt := range tests {
238		var clientSpan *mocktracer.MockSpan
239
240		spans := makeRequest(t, srv.URL+tt.url, tt.opts...)
241		for _, span := range spans {
242			if span.OperationName == "HTTP GET" {
243				clientSpan = span
244				break
245			}
246		}
247		if clientSpan == nil {
248			t.Fatal("cannot find client span")
249		}
250		tag := clientSpan.Tags()["http.url"]
251		if got, want := tag, tt.tag; got != want {
252			t.Fatalf("got %s tag name, expected %s", got, want)
253		}
254	}
255}