1// Copyright 2017 The Prometheus Authors
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14package promhttp
15
16import (
17	"crypto/tls"
18	"net/http"
19	"net/http/httptrace"
20	"time"
21
22	"github.com/prometheus/client_golang/prometheus"
23)
24
// The RoundTripperFunc type is an adapter to allow the use of ordinary
// functions as RoundTrippers. If f is a function with the appropriate
// signature, RoundTripperFunc(f) is a RoundTripper that calls f.
type RoundTripperFunc func(req *http.Request) (*http.Response, error)

// Compile-time check that RoundTripperFunc satisfies http.RoundTripper.
var _ http.RoundTripper = RoundTripperFunc(nil)

// RoundTrip implements the RoundTripper interface by calling rt(r) and
// returning its results unchanged.
func (rt RoundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	return rt(r)
}
34
35// InstrumentRoundTripperInFlight is a middleware that wraps the provided
36// http.RoundTripper. It sets the provided prometheus.Gauge to the number of
37// requests currently handled by the wrapped http.RoundTripper.
38//
39// See the example for ExampleInstrumentRoundTripperDuration for example usage.
40func InstrumentRoundTripperInFlight(gauge prometheus.Gauge, next http.RoundTripper) RoundTripperFunc {
41	return RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
42		gauge.Inc()
43		defer gauge.Dec()
44		return next.RoundTrip(r)
45	})
46}
47
48// InstrumentRoundTripperCounter is a middleware that wraps the provided
49// http.RoundTripper to observe the request result with the provided CounterVec.
50// The CounterVec must have zero, one, or two non-const non-curried labels. For
51// those, the only allowed label names are "code" and "method". The function
52// panics otherwise. Partitioning of the CounterVec happens by HTTP status code
53// and/or HTTP method if the respective instance label names are present in the
54// CounterVec. For unpartitioned counting, use a CounterVec with zero labels.
55//
56// If the wrapped RoundTripper panics or returns a non-nil error, the Counter
57// is not incremented.
58//
59// See the example for ExampleInstrumentRoundTripperDuration for example usage.
60func InstrumentRoundTripperCounter(counter *prometheus.CounterVec, next http.RoundTripper) RoundTripperFunc {
61	code, method := checkLabels(counter)
62
63	return RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
64		resp, err := next.RoundTrip(r)
65		if err == nil {
66			counter.With(labels(code, method, r.Method, resp.StatusCode)).Inc()
67		}
68		return resp, err
69	})
70}
71
72// InstrumentRoundTripperDuration is a middleware that wraps the provided
73// http.RoundTripper to observe the request duration with the provided
74// ObserverVec.  The ObserverVec must have zero, one, or two non-const
75// non-curried labels. For those, the only allowed label names are "code" and
76// "method". The function panics otherwise. The Observe method of the Observer
77// in the ObserverVec is called with the request duration in
78// seconds. Partitioning happens by HTTP status code and/or HTTP method if the
79// respective instance label names are present in the ObserverVec. For
80// unpartitioned observations, use an ObserverVec with zero labels. Note that
81// partitioning of Histograms is expensive and should be used judiciously.
82//
83// If the wrapped RoundTripper panics or returns a non-nil error, no values are
84// reported.
85//
86// Note that this method is only guaranteed to never observe negative durations
87// if used with Go1.9+.
88func InstrumentRoundTripperDuration(obs prometheus.ObserverVec, next http.RoundTripper) RoundTripperFunc {
89	code, method := checkLabels(obs)
90
91	return RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
92		start := time.Now()
93		resp, err := next.RoundTrip(r)
94		if err == nil {
95			obs.With(labels(code, method, r.Method, resp.StatusCode)).Observe(time.Since(start).Seconds())
96		}
97		return resp, err
98	})
99}
100
// InstrumentTrace is used to offer flexibility in instrumenting the available
// httptrace.ClientTrace hook functions. Each function is passed a float64
// representing the time in seconds since the start of the http request. A user
// may choose to use separately bucketed Histograms, or implement custom
// instance labels on a per function basis.
//
// Each field corresponds to the httptrace.ClientTrace hook of the same name and
// is invoked from it by InstrumentRoundTripperTrace. Every field is optional: a
// nil field is simply never called. PutIdleConn, ConnectDone, and
// TLSHandshakeDone are skipped when the corresponding ClientTrace hook reports
// a non-nil error.
type InstrumentTrace struct {
	GotConn              func(float64)
	PutIdleConn          func(float64)
	GotFirstResponseByte func(float64)
	Got100Continue       func(float64)
	DNSStart             func(float64)
	DNSDone              func(float64)
	ConnectStart         func(float64)
	ConnectDone          func(float64)
	TLSHandshakeStart    func(float64)
	TLSHandshakeDone     func(float64)
	WroteHeaders         func(float64)
	Wait100Continue      func(float64)
	WroteRequest         func(float64)
}
121
122// InstrumentRoundTripperTrace is a middleware that wraps the provided
123// RoundTripper and reports times to hook functions provided in the
124// InstrumentTrace struct. Hook functions that are not present in the provided
125// InstrumentTrace struct are ignored. Times reported to the hook functions are
126// time since the start of the request. Only with Go1.9+, those times are
127// guaranteed to never be negative. (Earlier Go versions are not using a
128// monotonic clock.) Note that partitioning of Histograms is expensive and
129// should be used judiciously.
130//
131// For hook functions that receive an error as an argument, no observations are
132// made in the event of a non-nil error value.
133//
134// See the example for ExampleInstrumentRoundTripperDuration for example usage.
135func InstrumentRoundTripperTrace(it *InstrumentTrace, next http.RoundTripper) RoundTripperFunc {
136	return RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
137		start := time.Now()
138
139		trace := &httptrace.ClientTrace{
140			GotConn: func(_ httptrace.GotConnInfo) {
141				if it.GotConn != nil {
142					it.GotConn(time.Since(start).Seconds())
143				}
144			},
145			PutIdleConn: func(err error) {
146				if err != nil {
147					return
148				}
149				if it.PutIdleConn != nil {
150					it.PutIdleConn(time.Since(start).Seconds())
151				}
152			},
153			DNSStart: func(_ httptrace.DNSStartInfo) {
154				if it.DNSStart != nil {
155					it.DNSStart(time.Since(start).Seconds())
156				}
157			},
158			DNSDone: func(_ httptrace.DNSDoneInfo) {
159				if it.DNSDone != nil {
160					it.DNSDone(time.Since(start).Seconds())
161				}
162			},
163			ConnectStart: func(_, _ string) {
164				if it.ConnectStart != nil {
165					it.ConnectStart(time.Since(start).Seconds())
166				}
167			},
168			ConnectDone: func(_, _ string, err error) {
169				if err != nil {
170					return
171				}
172				if it.ConnectDone != nil {
173					it.ConnectDone(time.Since(start).Seconds())
174				}
175			},
176			GotFirstResponseByte: func() {
177				if it.GotFirstResponseByte != nil {
178					it.GotFirstResponseByte(time.Since(start).Seconds())
179				}
180			},
181			Got100Continue: func() {
182				if it.Got100Continue != nil {
183					it.Got100Continue(time.Since(start).Seconds())
184				}
185			},
186			TLSHandshakeStart: func() {
187				if it.TLSHandshakeStart != nil {
188					it.TLSHandshakeStart(time.Since(start).Seconds())
189				}
190			},
191			TLSHandshakeDone: func(_ tls.ConnectionState, err error) {
192				if err != nil {
193					return
194				}
195				if it.TLSHandshakeDone != nil {
196					it.TLSHandshakeDone(time.Since(start).Seconds())
197				}
198			},
199			WroteHeaders: func() {
200				if it.WroteHeaders != nil {
201					it.WroteHeaders(time.Since(start).Seconds())
202				}
203			},
204			Wait100Continue: func() {
205				if it.Wait100Continue != nil {
206					it.Wait100Continue(time.Since(start).Seconds())
207				}
208			},
209			WroteRequest: func(_ httptrace.WroteRequestInfo) {
210				if it.WroteRequest != nil {
211					it.WroteRequest(time.Since(start).Seconds())
212				}
213			},
214		}
215		r = r.WithContext(httptrace.WithClientTrace(r.Context(), trace))
216
217		return next.RoundTrip(r)
218	})
219}
220