1// +build go1.7
2
3package nethttp
4
5import (
6	"context"
7	"io"
8	"net/http"
9	"net/http/httptrace"
10	"net/url"
11
12	"github.com/opentracing/opentracing-go"
13	"github.com/opentracing/opentracing-go/ext"
14	"github.com/opentracing/opentracing-go/log"
15)
16
// contextKey is the unexported type used for the context keys of this
// package, so its keys do not collide with keys defined by other packages.
type contextKey int

const (
	// keyTracer is the context key under which TraceRequest stores the
	// request's *Tracer.
	keyTracer contextKey = 0

	// defaultComponentName is the component tag value used when no
	// ComponentName option was given.
	defaultComponentName = "net/http"
)
24
// Transport is an http.RoundTripper for requests prepared with TraceRequest.
// For such a request it starts a span, sets HTTP-related tags on it and,
// unless disabled, injects the span context into the request headers.
// Requests that carry no Tracer are forwarded without tracing.
type Transport struct {
	// RoundTripper performs the actual HTTP exchange. When it is nil,
	// http.DefaultTransport is used instead.
	http.RoundTripper
}
33
34type clientOptions struct {
35	operationName            string
36	componentName            string
37	urlTagFunc         func(u *url.URL) string
38	disableClientTrace       bool
39	disableInjectSpanContext bool
40	spanObserver             func(span opentracing.Span, r *http.Request)
41}
42
43// ClientOption contols the behavior of TraceRequest.
44type ClientOption func(*clientOptions)
45
46// OperationName returns a ClientOption that sets the operation
47// name for the client-side span.
48func OperationName(operationName string) ClientOption {
49	return func(options *clientOptions) {
50		options.operationName = operationName
51	}
52}
53
54// URLTagFunc returns a ClientOption that uses given function f
55// to set the span's http.url tag. Can be used to change the default
56// http.url tag, eg to redact sensitive information.
57func URLTagFunc(f func(u *url.URL) string) ClientOption {
58	return func(options *clientOptions) {
59		options.urlTagFunc = f
60	}
61}
62
63// ComponentName returns a ClientOption that sets the component
64// name for the client-side span.
65func ComponentName(componentName string) ClientOption {
66	return func(options *clientOptions) {
67		options.componentName = componentName
68	}
69}
70
71// ClientTrace returns a ClientOption that turns on or off
72// extra instrumentation via httptrace.WithClientTrace.
73func ClientTrace(enabled bool) ClientOption {
74	return func(options *clientOptions) {
75		options.disableClientTrace = !enabled
76	}
77}
78
79// InjectSpanContext returns a ClientOption that turns on or off
80// injection of the Span context in the request HTTP headers.
81// If this option is not used, the default behaviour is to
82// inject the span context.
83func InjectSpanContext(enabled bool) ClientOption {
84	return func(options *clientOptions) {
85		options.disableInjectSpanContext = !enabled
86	}
87}
88
89// ClientSpanObserver returns a ClientOption that observes the span
90// for the client-side span.
91func ClientSpanObserver(f func(span opentracing.Span, r *http.Request)) ClientOption {
92	return func(options *clientOptions) {
93		options.spanObserver = f
94	}
95}
96
97// TraceRequest adds a ClientTracer to req, tracing the request and
98// all requests caused due to redirects. When tracing requests this
99// way you must also use Transport.
100//
101// Example:
102//
103// 	func AskGoogle(ctx context.Context) error {
104// 		client := &http.Client{Transport: &nethttp.Transport{}}
105// 		req, err := http.NewRequest("GET", "http://google.com", nil)
106// 		if err != nil {
107// 			return err
108// 		}
109// 		req = req.WithContext(ctx) // extend existing trace, if any
110//
111// 		req, ht := nethttp.TraceRequest(tracer, req)
112// 		defer ht.Finish()
113//
114// 		res, err := client.Do(req)
115// 		if err != nil {
116// 			return err
117// 		}
118// 		res.Body.Close()
119// 		return nil
120// 	}
121func TraceRequest(tr opentracing.Tracer, req *http.Request, options ...ClientOption) (*http.Request, *Tracer) {
122	opts := &clientOptions{
123		urlTagFunc: func(u *url.URL) string {
124			return u.String()
125		},
126		spanObserver: func(_ opentracing.Span, _ *http.Request) {},
127	}
128	for _, opt := range options {
129		opt(opts)
130	}
131	ht := &Tracer{tr: tr, opts: opts}
132	ctx := req.Context()
133	if !opts.disableClientTrace {
134		ctx = httptrace.WithClientTrace(ctx, ht.clientTrace())
135	}
136	req = req.WithContext(context.WithValue(ctx, keyTracer, ht))
137	return req, ht
138}
139
140type closeTracker struct {
141	io.ReadCloser
142	sp opentracing.Span
143}
144
145func (c closeTracker) Close() error {
146	err := c.ReadCloser.Close()
147	c.sp.LogFields(log.String("event", "ClosedBody"))
148	c.sp.Finish()
149	return err
150}
151
152// TracerFromRequest retrieves the Tracer from the request. If the request does
153// not have a Tracer it will return nil.
154func TracerFromRequest(req *http.Request) *Tracer {
155	tr, ok := req.Context().Value(keyTracer).(*Tracer)
156	if !ok {
157		return nil
158	}
159	return tr
160}
161
162// RoundTrip implements the RoundTripper interface.
163func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
164	rt := t.RoundTripper
165	if rt == nil {
166		rt = http.DefaultTransport
167	}
168	tracer := TracerFromRequest(req)
169	if tracer == nil {
170		return rt.RoundTrip(req)
171	}
172
173	tracer.start(req)
174
175	ext.HTTPMethod.Set(tracer.sp, req.Method)
176	ext.HTTPUrl.Set(tracer.sp, tracer.opts.urlTagFunc(req.URL))
177	tracer.opts.spanObserver(tracer.sp, req)
178
179	if !tracer.opts.disableInjectSpanContext {
180		carrier := opentracing.HTTPHeadersCarrier(req.Header)
181		tracer.sp.Tracer().Inject(tracer.sp.Context(), opentracing.HTTPHeaders, carrier)
182	}
183
184	resp, err := rt.RoundTrip(req)
185
186	if err != nil {
187		tracer.sp.Finish()
188		return resp, err
189	}
190	ext.HTTPStatusCode.Set(tracer.sp, uint16(resp.StatusCode))
191	if resp.StatusCode >= http.StatusInternalServerError {
192		ext.Error.Set(tracer.sp, true)
193	}
194	if req.Method == "HEAD" {
195		tracer.sp.Finish()
196	} else {
197		resp.Body = closeTracker{resp.Body, tracer.sp}
198	}
199	return resp, nil
200}
201
202// Tracer holds tracing details for one HTTP request.
203type Tracer struct {
204	tr   opentracing.Tracer
205	root opentracing.Span
206	sp   opentracing.Span
207	opts *clientOptions
208}
209
210func (h *Tracer) start(req *http.Request) opentracing.Span {
211	if h.root == nil {
212		parent := opentracing.SpanFromContext(req.Context())
213		var spanctx opentracing.SpanContext
214		if parent != nil {
215			spanctx = parent.Context()
216		}
217		operationName := h.opts.operationName
218		if operationName == "" {
219			operationName = "HTTP Client"
220		}
221		root := h.tr.StartSpan(operationName, opentracing.ChildOf(spanctx))
222		h.root = root
223	}
224
225	ctx := h.root.Context()
226	h.sp = h.tr.StartSpan("HTTP "+req.Method, opentracing.ChildOf(ctx))
227	ext.SpanKindRPCClient.Set(h.sp)
228
229	componentName := h.opts.componentName
230	if componentName == "" {
231		componentName = defaultComponentName
232	}
233	ext.Component.Set(h.sp, componentName)
234
235	return h.sp
236}
237
238// Finish finishes the span of the traced request.
239func (h *Tracer) Finish() {
240	if h.root != nil {
241		h.root.Finish()
242	}
243}
244
245// Span returns the root span of the traced request. This function
246// should only be called after the request has been executed.
247func (h *Tracer) Span() opentracing.Span {
248	return h.root
249}
250
251func (h *Tracer) clientTrace() *httptrace.ClientTrace {
252	return &httptrace.ClientTrace{
253		GetConn:              h.getConn,
254		GotConn:              h.gotConn,
255		PutIdleConn:          h.putIdleConn,
256		GotFirstResponseByte: h.gotFirstResponseByte,
257		Got100Continue:       h.got100Continue,
258		DNSStart:             h.dnsStart,
259		DNSDone:              h.dnsDone,
260		ConnectStart:         h.connectStart,
261		ConnectDone:          h.connectDone,
262		WroteHeaders:         h.wroteHeaders,
263		Wait100Continue:      h.wait100Continue,
264		WroteRequest:         h.wroteRequest,
265	}
266}
267
268func (h *Tracer) getConn(hostPort string) {
269	ext.HTTPUrl.Set(h.sp, hostPort)
270	h.sp.LogFields(log.String("event", "GetConn"))
271}
272
273func (h *Tracer) gotConn(info httptrace.GotConnInfo) {
274	h.sp.SetTag("net/http.reused", info.Reused)
275	h.sp.SetTag("net/http.was_idle", info.WasIdle)
276	h.sp.LogFields(log.String("event", "GotConn"))
277}
278
279func (h *Tracer) putIdleConn(error) {
280	h.sp.LogFields(log.String("event", "PutIdleConn"))
281}
282
283func (h *Tracer) gotFirstResponseByte() {
284	h.sp.LogFields(log.String("event", "GotFirstResponseByte"))
285}
286
287func (h *Tracer) got100Continue() {
288	h.sp.LogFields(log.String("event", "Got100Continue"))
289}
290
291func (h *Tracer) dnsStart(info httptrace.DNSStartInfo) {
292	h.sp.LogFields(
293		log.String("event", "DNSStart"),
294		log.String("host", info.Host),
295	)
296}
297
298func (h *Tracer) dnsDone(info httptrace.DNSDoneInfo) {
299	fields := []log.Field{log.String("event", "DNSDone")}
300	for _, addr := range info.Addrs {
301		fields = append(fields, log.String("addr", addr.String()))
302	}
303	if info.Err != nil {
304		fields = append(fields, log.Error(info.Err))
305	}
306	h.sp.LogFields(fields...)
307}
308
309func (h *Tracer) connectStart(network, addr string) {
310	h.sp.LogFields(
311		log.String("event", "ConnectStart"),
312		log.String("network", network),
313		log.String("addr", addr),
314	)
315}
316
317func (h *Tracer) connectDone(network, addr string, err error) {
318	if err != nil {
319		h.sp.LogFields(
320			log.String("message", "ConnectDone"),
321			log.String("network", network),
322			log.String("addr", addr),
323			log.String("event", "error"),
324			log.Error(err),
325		)
326	} else {
327		h.sp.LogFields(
328			log.String("event", "ConnectDone"),
329			log.String("network", network),
330			log.String("addr", addr),
331		)
332	}
333}
334
335func (h *Tracer) wroteHeaders() {
336	h.sp.LogFields(log.String("event", "WroteHeaders"))
337}
338
339func (h *Tracer) wait100Continue() {
340	h.sp.LogFields(log.String("event", "Wait100Continue"))
341}
342
343func (h *Tracer) wroteRequest(info httptrace.WroteRequestInfo) {
344	if info.Err != nil {
345		h.sp.LogFields(
346			log.String("message", "WroteRequest"),
347			log.String("event", "error"),
348			log.Error(info.Err),
349		)
350		ext.Error.Set(h.sp, true)
351	} else {
352		h.sp.LogFields(log.String("event", "WroteRequest"))
353	}
354}
355