1// Copyright (c) The Thanos Authors.
2// Licensed under the Apache License 2.0.
3
4package http
5
6import (
7	"fmt"
8	"net/http"
9	"strings"
10	"time"
11
12	"github.com/opentracing/opentracing-go"
13	"github.com/prometheus/client_golang/prometheus"
14	"github.com/prometheus/client_golang/prometheus/promauto"
15	"github.com/prometheus/client_golang/prometheus/promhttp"
16	"github.com/uber/jaeger-client-go"
17)
18
// InstrumentationMiddleware wraps HTTP handlers so that the requests they
// serve can be instrumented (the default implementation records Prometheus
// metrics; the nop implementation records nothing).
type InstrumentationMiddleware interface {
	// NewHandler returns an http.HandlerFunc that serves requests through
	// handler. handlerName identifies the wrapped handler to the
	// implementation (for example as a metric label value).
	NewHandler(handlerName string, handler http.Handler) http.HandlerFunc
}
25
// nopInstrumentationMiddleware is an InstrumentationMiddleware that records
// nothing: the handlers it produces only forward to the wrapped handler.
type nopInstrumentationMiddleware struct{}

// NewHandler returns a handler that passes every request straight to handler.
// The handler name is ignored.
func (nopInstrumentationMiddleware) NewHandler(_ string, handler http.Handler) http.HandlerFunc {
	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		handler.ServeHTTP(rw, req)
	})
}
33
34// NewNopInstrumentationMiddleware provides a InstrumentationMiddleware which does nothing.
35func NewNopInstrumentationMiddleware() InstrumentationMiddleware {
36	return nopInstrumentationMiddleware{}
37}
38
39type defaultInstrumentationMiddleware struct {
40	requestDuration *prometheus.HistogramVec
41	requestSize     *prometheus.SummaryVec
42	requestsTotal   *prometheus.CounterVec
43	responseSize    *prometheus.SummaryVec
44}
45
46// NewInstrumentationMiddleware provides default InstrumentationMiddleware.
47// Passing nil as buckets uses the default buckets.
48func NewInstrumentationMiddleware(reg prometheus.Registerer, buckets []float64) InstrumentationMiddleware {
49	if buckets == nil {
50		buckets = []float64{0.001, 0.01, 0.1, 0.3, 0.6, 1, 3, 6, 9, 20, 30, 60, 90, 120, 240, 360, 720}
51	}
52
53	ins := defaultInstrumentationMiddleware{
54		requestDuration: promauto.With(reg).NewHistogramVec(
55			prometheus.HistogramOpts{
56				Name:    "http_request_duration_seconds",
57				Help:    "Tracks the latencies for HTTP requests.",
58				Buckets: buckets,
59			},
60			[]string{"code", "handler", "method"},
61		),
62
63		requestSize: promauto.With(reg).NewSummaryVec(
64			prometheus.SummaryOpts{
65				Name: "http_request_size_bytes",
66				Help: "Tracks the size of HTTP requests.",
67			},
68			[]string{"code", "handler", "method"},
69		),
70
71		requestsTotal: promauto.With(reg).NewCounterVec(
72			prometheus.CounterOpts{
73				Name: "http_requests_total",
74				Help: "Tracks the number of HTTP requests.",
75			}, []string{"code", "handler", "method"},
76		),
77
78		responseSize: promauto.With(reg).NewSummaryVec(
79			prometheus.SummaryOpts{
80				Name: "http_response_size_bytes",
81				Help: "Tracks the size of HTTP responses.",
82			},
83			[]string{"code", "handler", "method"},
84		),
85	}
86	return &ins
87}
88
89// NewHandler wraps the given HTTP handler for instrumentation. It
90// registers four metric collectors (if not already done) and reports HTTP
91// metrics to the (newly or already) registered collectors: http_requests_total
92// (CounterVec), http_request_duration_seconds (Histogram),
93// http_request_size_bytes (Summary), http_response_size_bytes (Summary). Each
94// has a constant label named "handler" with the provided handlerName as
95// value. http_requests_total is a metric vector partitioned by HTTP method
96// (label name "method") and HTTP status code (label name "code").
97func (ins *defaultInstrumentationMiddleware) NewHandler(handlerName string, handler http.Handler) http.HandlerFunc {
98	return promhttp.InstrumentHandlerRequestSize(
99		ins.requestSize.MustCurryWith(prometheus.Labels{"handler": handlerName}),
100		promhttp.InstrumentHandlerCounter(
101			ins.requestsTotal.MustCurryWith(prometheus.Labels{"handler": handlerName}),
102			promhttp.InstrumentHandlerResponseSize(
103				ins.responseSize.MustCurryWith(prometheus.Labels{"handler": handlerName}),
104				http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
105					now := time.Now()
106
107					wd := &responseWriterDelegator{w: w}
108					handler.ServeHTTP(wd, r)
109
110					observer := ins.requestDuration.WithLabelValues(
111						wd.Status(),
112						handlerName,
113						strings.ToLower(r.Method),
114					)
115					observer.Observe(time.Since(now).Seconds())
116
117					// If we find a tracingID we'll expose it as Exemplar.
118					span := opentracing.SpanFromContext(r.Context())
119					if span != nil {
120						spanCtx, ok := span.Context().(jaeger.SpanContext)
121						if ok && spanCtx.IsSampled() {
122							observer.(prometheus.ExemplarObserver).ObserveWithExemplar(
123								time.Since(now).Seconds(),
124								prometheus.Labels{
125									"traceID": spanCtx.TraceID().String(),
126								},
127							)
128						}
129					}
130				}),
131			),
132		),
133	)
134}
135
// responseWriterDelegator wraps an http.ResponseWriter and remembers the
// status code passed to WriteHeader so it can be used as a metric label.
type responseWriterDelegator struct {
	w          http.ResponseWriter
	written    bool // set once WriteHeader has been called
	statusCode int  // last code passed to WriteHeader; only valid when written is true
}

// Header returns the header map of the wrapped writer.
func (wd *responseWriterDelegator) Header() http.Header {
	hdr := wd.w.Header()
	return hdr
}

// Write forwards p to the wrapped writer unchanged. It does not update the
// recorded status; StatusCode reports 200 when WriteHeader was never called.
func (wd *responseWriterDelegator) Write(p []byte) (int, error) {
	n, err := wd.w.Write(p)
	return n, err
}

// WriteHeader records statusCode (a later call overwrites an earlier one) and
// forwards it to the wrapped writer.
func (wd *responseWriterDelegator) WriteHeader(statusCode int) {
	wd.written, wd.statusCode = true, statusCode
	wd.w.WriteHeader(statusCode)
}

// StatusCode returns the last status code passed to WriteHeader, or
// http.StatusOK when WriteHeader was never called (matching net/http's
// implicit 200 on the first Write).
func (wd *responseWriterDelegator) StatusCode() int {
	if wd.written {
		return wd.statusCode
	}
	return http.StatusOK
}

// Status returns StatusCode as a decimal string, suitable as a "code" label
// value.
func (wd *responseWriterDelegator) Status() string {
	code := wd.StatusCode()
	return fmt.Sprint(code)
}
167