1// Licensed to Elasticsearch B.V. under one or more contributor
2// license agreements. See the NOTICE file distributed with
3// this work for additional information regarding copyright
4// ownership. Elasticsearch B.V. licenses this file to you under
5// the Apache License, Version 2.0 (the "License"); you may
6// not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9//     http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18package apmrestful
19
20import (
21	"net/http"
22
23	restful "github.com/emicklei/go-restful"
24
25	"go.elastic.co/apm"
26	"go.elastic.co/apm/module/apmhttp"
27)
28
29// Filter returns a new restful.Filter for tracing requests
30// and recovering and reporting panics to Elastic APM.
31//
32// By default, the filter will use apm.DefaultTracer.
33// Use WithTracer to specify an alternative tracer.
34func Filter(o ...Option) restful.FilterFunction {
35	opts := options{
36		tracer:         apm.DefaultTracer,
37		requestIgnorer: apmhttp.DefaultServerRequestIgnorer(),
38	}
39	for _, o := range o {
40		o(&opts)
41	}
42	return (&filter{
43		tracer:         opts.tracer,
44		requestIgnorer: opts.requestIgnorer,
45	}).filter
46}
47
48type filter struct {
49	tracer         *apm.Tracer
50	requestIgnorer apmhttp.RequestIgnorerFunc
51}
52
53func (f *filter) filter(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) {
54	if !f.tracer.Active() || f.requestIgnorer(req.Request) {
55		chain.ProcessFilter(req, resp)
56		return
57	}
58
59	var name string
60	if routePath := massageRoutePath(req.SelectedRoutePath()); routePath != "" {
61		name = req.Request.Method + " " + massageRoutePath(req.SelectedRoutePath())
62	} else {
63		name = apmhttp.UnknownRouteRequestName(req.Request)
64	}
65	tx, httpRequest := apmhttp.StartTransaction(f.tracer, name, req.Request)
66	defer tx.End()
67	req.Request = httpRequest
68	body := f.tracer.CaptureHTTPRequestBody(httpRequest)
69
70	const frameworkName = "go-restful"
71	const frameworkVersion = ""
72	if tx.Sampled() {
73		tx.Context.SetFramework(frameworkName, frameworkVersion)
74	}
75
76	origResponseWriter := resp.ResponseWriter
77	w, httpResp := apmhttp.WrapResponseWriter(origResponseWriter)
78	resp.ResponseWriter = w
79	defer func() {
80		resp.ResponseWriter = origResponseWriter
81		if v := recover(); v != nil {
82			if httpResp.StatusCode == 0 {
83				w.WriteHeader(http.StatusInternalServerError)
84			}
85			e := f.tracer.Recovered(v)
86			e.SetTransaction(tx)
87			apmhttp.SetContext(&e.Context, req.Request, httpResp, body)
88			e.Context.SetFramework(frameworkName, frameworkVersion)
89			e.Send()
90		}
91		apmhttp.SetTransactionContext(tx, req.Request, httpResp, body)
92		body.Discard()
93	}()
94	chain.ProcessFilter(req, resp)
95	if httpResp.StatusCode == 0 {
96		httpResp.StatusCode = http.StatusOK
97	}
98}
99
100type options struct {
101	tracer         *apm.Tracer
102	requestIgnorer apmhttp.RequestIgnorerFunc
103}
104
105// Option sets options for tracing.
106type Option func(*options)
107
108// WithTracer returns an Option which sets t as the tracer
109// to use for tracing server requests.
110func WithTracer(t *apm.Tracer) Option {
111	if t == nil {
112		panic("t == nil")
113	}
114	return func(o *options) {
115		o.tracer = t
116	}
117}
118