1// Licensed to Elasticsearch B.V. under one or more contributor
2// license agreements. See the NOTICE file distributed with
3// this work for additional information regarding copyright
4// ownership. Elasticsearch B.V. licenses this file to you under
5// the Apache License, Version 2.0 (the "License"); you may
6// not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9//     http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18package apmhttp
19
20import (
21	"context"
22	"net/http"
23
24	"go.elastic.co/apm"
25)
26
// Wrap returns an http.Handler that serves requests with h and reports
// each request to Elastic APM as a transaction.
//
// Unless overridden by options, the returned handler reports to
// apm.DefaultTracer (see WithTracer), names transactions with
// ServerRequestName (see WithServerRequestName), and decides which requests
// to skip with DefaultServerRequestIgnorer (see WithServerRequestIgnorer).
// Panics raised by h are recovered and reported to the configured tracer;
// use WithRecovery to change this behaviour.
//
// Wrap panics if h is nil.
func Wrap(h http.Handler, o ...ServerOption) http.Handler {
	if h == nil {
		panic("h == nil")
	}
	wrapped := &handler{
		handler:        h,
		tracer:         apm.DefaultTracer,
		requestName:    ServerRequestName,
		requestIgnorer: DefaultServerRequestIgnorer(),
	}
	for i := range o {
		o[i](wrapped)
	}
	// The default recovery is built only after all options have run, so it
	// reports to whichever tracer was finally selected.
	if wrapped.recovery == nil {
		wrapped.recovery = NewTraceRecovery(wrapped.tracer)
	}
	return wrapped
}
54
55// handler wraps an http.Handler, reporting a new transaction for each request.
56//
57// The http.Request's context will be updated with the transaction.
58type handler struct {
59	handler        http.Handler
60	tracer         *apm.Tracer
61	recovery       RecoveryFunc
62	requestName    RequestNameFunc
63	requestIgnorer RequestIgnorerFunc
64}
65
// ServeHTTP serves req with the wrapped handler, reporting it as a
// transaction to h.tracer.
//
// If h.tracer is not active, or h.requestIgnorer reports that req should be
// ignored, req is passed straight to the wrapped handler and no transaction
// is started.
//
// Otherwise a transaction named by h.requestName is started, req's context is
// updated to carry it, and w is wrapped so the response status code and
// headers can be recorded. If the wrapped handler panics, the panic is
// recovered: a 500 status is written if no status has been recorded yet, and
// h.recovery is called with the recovered value. The panic is not re-raised.
func (h *handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	if !h.tracer.Active() || h.requestIgnorer(req) {
		h.handler.ServeHTTP(w, req)
		return
	}
	tx, req := StartTransaction(h.tracer, h.requestName(req), req)
	// Deferred first, so it runs last: the transaction is ended only after
	// the deferred function below has set its result and context.
	defer tx.End()

	body := h.tracer.CaptureHTTPRequestBody(req)
	w, resp := WrapResponseWriter(w)
	defer func() {
		if v := recover(); v != nil {
			// Neither Write nor WriteHeader has been called; report the
			// panic to the client as a server error.
			if resp.StatusCode == 0 {
				w.WriteHeader(http.StatusInternalServerError)
			}
			h.recovery(w, req, resp, body, tx, v)
		}
		SetTransactionContext(tx, req, resp, body)
		body.Discard()
	}()
	h.handler.ServeHTTP(w, req)
	// The handler returned without writing anything; record the default
	// 200 OK status.
	if resp.StatusCode == 0 {
		resp.StatusCode = http.StatusOK
	}
}
93
// StartTransaction starts a transaction of type "request" with the given name
// using tracer. It returns the transaction together with a copy of req whose
// context carries it.
//
// If req has exactly one non-empty TraceparentHeader value and that value
// parses successfully, the transaction continues the incoming trace.
// Otherwise, including when the header is malformed, a new trace is started.
func StartTransaction(tracer *apm.Tracer, name string, req *http.Request) (*apm.Transaction, *http.Request) {
	var opts apm.TransactionOptions
	if header := req.Header[TraceparentHeader]; len(header) == 1 && header[0] != "" {
		traceContext, err := ParseTraceparentHeader(header[0])
		if err == nil {
			opts.TraceContext = traceContext
		}
	}
	tx := tracer.StartTransactionOptions(name, "request", opts)
	return tx, RequestWithContext(apm.ContextWithTransaction(req.Context(), tx), req)
}
111
// SetTransactionContext records the outcome of an HTTP request on tx.
//
// tx.Result is always derived from resp.StatusCode. The request, request body
// and response details are copied into tx.Context only when tx is sampled.
func SetTransactionContext(tx *apm.Transaction, req *http.Request, resp *Response, body *apm.BodyCapturer) {
	tx.Result = StatusCodeResult(resp.StatusCode)
	if tx.Sampled() {
		SetContext(&tx.Context, req, resp, body)
	}
}
121
// SetContext populates ctx, the context of a transaction or error, with the
// details of an HTTP exchange. The request and its captured body come from
// req and body; the status code and headers come from resp.
func SetContext(ctx *apm.Context, req *http.Request, resp *Response, body *apm.BodyCapturer) {
	// Request side of the exchange.
	ctx.SetHTTPRequest(req)
	ctx.SetHTTPRequestBody(body)

	// Response side of the exchange.
	ctx.SetHTTPStatusCode(resp.StatusCode)
	ctx.SetHTTPResponseHeaders(resp.Headers)
}
130
// WrapResponseWriter returns an http.ResponseWriter that forwards to w. It
// records the response status code and headers in the returned *Response.
//
// Read the *Response only after the request has been handled; inspecting it
// while the handler may still be writing is a data race. StatusCode stays
// zero if neither Write nor WriteHeader is called on the returned writer.
//
// The returned writer implements http.Hijacker and http.Pusher exactly when
// w does.
func WrapResponseWriter(w http.ResponseWriter) (http.ResponseWriter, *Response) {
	base := responseWriter{
		ResponseWriter: w,
		resp:           Response{Headers: w.Header()},
	}
	hijacker, canHijack := w.(http.Hijacker)
	pusher, canPush := w.(http.Pusher)

	// Each combination of optional interfaces needs its own concrete type,
	// because the method set of a value is fixed by its static type.
	if canHijack && canPush {
		out := &responseWriterHijackerPusher{responseWriter: base, Hijacker: hijacker, Pusher: pusher}
		return out, &out.resp
	}
	if canHijack {
		out := &responseWriterHijacker{responseWriter: base, Hijacker: hijacker}
		return out, &out.resp
	}
	if canPush {
		out := &responseWriterPusher{responseWriter: base, Pusher: pusher}
		return out, &out.resp
	}
	return &base, &base.resp
}
172
173// Response records details of the HTTP response.
174type Response struct {
175	// StatusCode records the HTTP status code set via WriteHeader.
176	StatusCode int
177
178	// Headers holds the headers set in the ResponseWriter.
179	Headers http.Header
180}
181
182type responseWriter struct {
183	http.ResponseWriter
184	resp Response
185}
186
187// WriteHeader sets w.resp.StatusCode and calls through to the embedded
188// ResponseWriter.
189func (w *responseWriter) WriteHeader(statusCode int) {
190	w.ResponseWriter.WriteHeader(statusCode)
191	w.resp.StatusCode = statusCode
192}
193
194// Write calls through to the embedded ResponseWriter, setting
195// w.resp.StatusCode to http.StatusOK if WriteHeader has not already
196// been called.
197func (w *responseWriter) Write(data []byte) (int, error) {
198	n, err := w.ResponseWriter.Write(data)
199	if w.resp.StatusCode == 0 {
200		w.resp.StatusCode = http.StatusOK
201	}
202	return n, err
203}
204
205// CloseNotify returns w.closeNotify() if w.closeNotify is non-nil,
206// otherwise it returns nil.
207func (w *responseWriter) CloseNotify() <-chan bool {
208	if closeNotifier, ok := w.ResponseWriter.(http.CloseNotifier); ok {
209		return closeNotifier.CloseNotify()
210	}
211	return nil
212}
213
214// Flush calls w.flush() if w.flush is non-nil, otherwise
215// it does nothing.
216func (w *responseWriter) Flush() {
217	if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
218		flusher.Flush()
219	}
220}
221
222type responseWriterHijacker struct {
223	responseWriter
224	http.Hijacker
225}
226
227type responseWriterPusher struct {
228	responseWriter
229	http.Pusher
230}
231
232type responseWriterHijackerPusher struct {
233	responseWriter
234	http.Hijacker
235	http.Pusher
236}
237
238// ServerOption sets options for tracing server requests.
239type ServerOption func(*handler)
240
241// WithTracer returns a ServerOption which sets t as the tracer
242// to use for tracing server requests.
243func WithTracer(t *apm.Tracer) ServerOption {
244	if t == nil {
245		panic("t == nil")
246	}
247	return func(h *handler) {
248		h.tracer = t
249	}
250}
251
252// WithRecovery returns a ServerOption which sets r as the recovery
253// function to use for tracing server requests.
254func WithRecovery(r RecoveryFunc) ServerOption {
255	if r == nil {
256		panic("r == nil")
257	}
258	return func(h *handler) {
259		h.recovery = r
260	}
261}
262
263// RequestNameFunc is the type of a function for use in
264// WithServerRequestName.
265type RequestNameFunc func(*http.Request) string
266
267// WithServerRequestName returns a ServerOption which sets r as the function
268// to use to obtain the transaction name for the given server request.
269func WithServerRequestName(r RequestNameFunc) ServerOption {
270	if r == nil {
271		panic("r == nil")
272	}
273	return func(h *handler) {
274		h.requestName = r
275	}
276}
277
278// RequestIgnorerFunc is the type of a function for use in
279// WithServerRequestIgnorer.
280type RequestIgnorerFunc func(*http.Request) bool
281
282// WithServerRequestIgnorer returns a ServerOption which sets r as the
283// function to use to determine whether or not a server request should
284// be ignored. If r is nil, all requests will be reported.
285func WithServerRequestIgnorer(r RequestIgnorerFunc) ServerOption {
286	if r == nil {
287		r = IgnoreNone
288	}
289	return func(h *handler) {
290		h.requestIgnorer = r
291	}
292}
293
// RequestWithContext returns a shallow copy of req whose context is ctx,
// like req.WithContext. The difference is that the copy shares req's URL
// pointer rather than receiving a copy of the URL's contents.
//
// To achieve this, req.URL is temporarily cleared while req.WithContext runs
// and then restored, so req must not be accessed concurrently during the
// call. req.URL is restored even if req.WithContext panics (for example when
// ctx is nil).
func RequestWithContext(ctx context.Context, req *http.Request) *http.Request {
	url := req.URL
	// Hide the URL from WithContext so it has nothing to copy.
	req.URL = nil
	// Restore the caller's URL even if WithContext panics; otherwise a
	// recovered panic would leave req with a nil URL.
	defer func() { req.URL = url }()
	reqCopy := req.WithContext(ctx)
	reqCopy.URL = url
	return reqCopy
}
304