1// +build go1.7
2
3package nethttp
4
5import (
6	"net/http"
7	"net/url"
8
9	opentracing "github.com/opentracing/opentracing-go"
10	"github.com/opentracing/opentracing-go/ext"
11)
12
13type mwOptions struct {
14	opNameFunc    func(r *http.Request) string
15	spanFilter    func(r *http.Request) bool
16	spanObserver  func(span opentracing.Span, r *http.Request)
17	urlTagFunc    func(u *url.URL) string
18	componentName string
19}
20
21// MWOption controls the behavior of the Middleware.
22type MWOption func(*mwOptions)
23
24// OperationNameFunc returns a MWOption that uses given function f
25// to generate operation name for each server-side span.
26func OperationNameFunc(f func(r *http.Request) string) MWOption {
27	return func(options *mwOptions) {
28		options.opNameFunc = f
29	}
30}
31
32// MWComponentName returns a MWOption that sets the component name
33// for the server-side span.
34func MWComponentName(componentName string) MWOption {
35	return func(options *mwOptions) {
36		options.componentName = componentName
37	}
38}
39
40// MWSpanFilter returns a MWOption that filters requests from creating a span
41// for the server-side span.
42// Span won't be created if it returns false.
43func MWSpanFilter(f func(r *http.Request) bool) MWOption {
44	return func(options *mwOptions) {
45		options.spanFilter = f
46	}
47}
48
49// MWSpanObserver returns a MWOption that observe the span
50// for the server-side span.
51func MWSpanObserver(f func(span opentracing.Span, r *http.Request)) MWOption {
52	return func(options *mwOptions) {
53		options.spanObserver = f
54	}
55}
56
57// MWURLTagFunc returns a MWOption that uses given function f
58// to set the span's http.url tag. Can be used to change the default
59// http.url tag, eg to redact sensitive information.
60func MWURLTagFunc(f func(u *url.URL) string) MWOption {
61	return func(options *mwOptions) {
62		options.urlTagFunc = f
63	}
64}
65
66// Middleware wraps an http.Handler and traces incoming requests.
67// Additionally, it adds the span to the request's context.
68//
69// By default, the operation name of the spans is set to "HTTP {method}".
70// This can be overriden with options.
71//
72// Example:
73// 	 http.ListenAndServe("localhost:80", nethttp.Middleware(tracer, http.DefaultServeMux))
74//
75// The options allow fine tuning the behavior of the middleware.
76//
77// Example:
78//   mw := nethttp.Middleware(
79//      tracer,
80//      http.DefaultServeMux,
81//      nethttp.OperationNameFunc(func(r *http.Request) string {
82//	        return "HTTP " + r.Method + ":/api/customers"
83//      }),
84//      nethttp.MWSpanObserver(func(sp opentracing.Span, r *http.Request) {
85//			sp.SetTag("http.uri", r.URL.EscapedPath())
86//		}),
87//   )
88func Middleware(tr opentracing.Tracer, h http.Handler, options ...MWOption) http.Handler {
89	return MiddlewareFunc(tr, h.ServeHTTP, options...)
90}
91
92// MiddlewareFunc wraps an http.HandlerFunc and traces incoming requests.
93// It behaves identically to the Middleware function above.
94//
95// Example:
96//   http.ListenAndServe("localhost:80", nethttp.MiddlewareFunc(tracer, MyHandler))
97func MiddlewareFunc(tr opentracing.Tracer, h http.HandlerFunc, options ...MWOption) http.HandlerFunc {
98	opts := mwOptions{
99		opNameFunc: func(r *http.Request) string {
100			return "HTTP " + r.Method
101		},
102		spanFilter:   func(r *http.Request) bool { return true },
103		spanObserver: func(span opentracing.Span, r *http.Request) {},
104		urlTagFunc: func(u *url.URL) string {
105			return u.String()
106		},
107	}
108	for _, opt := range options {
109		opt(&opts)
110	}
111	// set component name, use "net/http" if caller does not specify
112	componentName := opts.componentName
113	if componentName == "" {
114		componentName = defaultComponentName
115	}
116
117	fn := func(w http.ResponseWriter, r *http.Request) {
118		if !opts.spanFilter(r) {
119			h(w, r)
120			return
121		}
122		ctx, _ := tr.Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(r.Header))
123		sp := tr.StartSpan(opts.opNameFunc(r), ext.RPCServerOption(ctx))
124		ext.HTTPMethod.Set(sp, r.Method)
125		ext.HTTPUrl.Set(sp, opts.urlTagFunc(r.URL))
126		ext.Component.Set(sp, componentName)
127		opts.spanObserver(sp, r)
128
129		sct := &statusCodeTracker{ResponseWriter: w}
130		r = r.WithContext(opentracing.ContextWithSpan(r.Context(), sp))
131
132		defer func() {
133			panicErr := recover()
134			didPanic := panicErr != nil
135
136			if sct.status == 0 && !didPanic {
137				// Standard behavior of http.Server is to assume status code 200 if one was not written by a handler that returned successfully.
138				// https://github.com/golang/go/blob/fca286bed3ed0e12336532cc711875ae5b3cb02a/src/net/http/server.go#L120
139				sct.status = 200
140			}
141			if sct.status > 0 {
142				ext.HTTPStatusCode.Set(sp, uint16(sct.status))
143			}
144			if sct.status >= http.StatusInternalServerError || didPanic {
145				ext.Error.Set(sp, true)
146			}
147			sp.Finish()
148
149			if didPanic {
150				panic(panicErr)
151			}
152		}()
153
154		h(sct.wrappedResponseWriter(), r)
155	}
156	return http.HandlerFunc(fn)
157}
158