1// Copyright 2017 The Prometheus Authors
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14package promhttp
15
16import (
17	"context"
18	"fmt"
19	"log"
20	"net/http"
21	"net/http/httptest"
22	"testing"
23	"time"
24
25	"github.com/prometheus/client_golang/prometheus"
26)
27
28func makeInstrumentedClient() (*http.Client, *prometheus.Registry) {
29	client := http.DefaultClient
30	client.Timeout = 1 * time.Second
31
32	reg := prometheus.NewRegistry()
33
34	inFlightGauge := prometheus.NewGauge(prometheus.GaugeOpts{
35		Name: "client_in_flight_requests",
36		Help: "A gauge of in-flight requests for the wrapped client.",
37	})
38
39	counter := prometheus.NewCounterVec(
40		prometheus.CounterOpts{
41			Name: "client_api_requests_total",
42			Help: "A counter for requests from the wrapped client.",
43		},
44		[]string{"code", "method"},
45	)
46
47	dnsLatencyVec := prometheus.NewHistogramVec(
48		prometheus.HistogramOpts{
49			Name:    "dns_duration_seconds",
50			Help:    "Trace dns latency histogram.",
51			Buckets: []float64{.005, .01, .025, .05},
52		},
53		[]string{"event"},
54	)
55
56	tlsLatencyVec := prometheus.NewHistogramVec(
57		prometheus.HistogramOpts{
58			Name:    "tls_duration_seconds",
59			Help:    "Trace tls latency histogram.",
60			Buckets: []float64{.05, .1, .25, .5},
61		},
62		[]string{"event"},
63	)
64
65	histVec := prometheus.NewHistogramVec(
66		prometheus.HistogramOpts{
67			Name:    "request_duration_seconds",
68			Help:    "A histogram of request latencies.",
69			Buckets: prometheus.DefBuckets,
70		},
71		[]string{"method"},
72	)
73
74	reg.MustRegister(counter, tlsLatencyVec, dnsLatencyVec, histVec, inFlightGauge)
75
76	trace := &InstrumentTrace{
77		DNSStart: func(t float64) {
78			dnsLatencyVec.WithLabelValues("dns_start").Observe(t)
79		},
80		DNSDone: func(t float64) {
81			dnsLatencyVec.WithLabelValues("dns_done").Observe(t)
82		},
83		TLSHandshakeStart: func(t float64) {
84			tlsLatencyVec.WithLabelValues("tls_handshake_start").Observe(t)
85		},
86		TLSHandshakeDone: func(t float64) {
87			tlsLatencyVec.WithLabelValues("tls_handshake_done").Observe(t)
88		},
89	}
90
91	client.Transport = InstrumentRoundTripperInFlight(inFlightGauge,
92		InstrumentRoundTripperCounter(counter,
93			InstrumentRoundTripperTrace(trace,
94				InstrumentRoundTripperDuration(histVec, http.DefaultTransport),
95			),
96		),
97	)
98	return client, reg
99}
100
101func TestClientMiddlewareAPI(t *testing.T) {
102	client, reg := makeInstrumentedClient()
103	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
104		w.WriteHeader(http.StatusOK)
105	}))
106	defer backend.Close()
107
108	resp, err := client.Get(backend.URL)
109	if err != nil {
110		t.Fatal(err)
111	}
112	defer resp.Body.Close()
113
114	mfs, err := reg.Gather()
115	if err != nil {
116		t.Fatal(err)
117	}
118	if want, got := 3, len(mfs); want != got {
119		t.Fatalf("unexpected number of metric families gathered, want %d, got %d", want, got)
120	}
121	for _, mf := range mfs {
122		if len(mf.Metric) == 0 {
123			t.Errorf("metric family %s must not be empty", mf.GetName())
124		}
125	}
126}
127
128func TestClientMiddlewareAPIWithRequestContext(t *testing.T) {
129	client, reg := makeInstrumentedClient()
130	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
131		w.WriteHeader(http.StatusOK)
132	}))
133	defer backend.Close()
134
135	req, err := http.NewRequest("GET", backend.URL, nil)
136	if err != nil {
137		t.Fatalf("%v", err)
138	}
139
140	// Set a context with a long timeout.
141	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
142	defer cancel()
143	req = req.WithContext(ctx)
144
145	resp, err := client.Do(req)
146	if err != nil {
147		t.Fatal(err)
148	}
149	defer resp.Body.Close()
150
151	mfs, err := reg.Gather()
152	if err != nil {
153		t.Fatal(err)
154	}
155	if want, got := 3, len(mfs); want != got {
156		t.Fatalf("unexpected number of metric families gathered, want %d, got %d", want, got)
157	}
158	for _, mf := range mfs {
159		if len(mf.Metric) == 0 {
160			t.Errorf("metric family %s must not be empty", mf.GetName())
161		}
162	}
163}
164
165func TestClientMiddlewareAPIWithRequestContextTimeout(t *testing.T) {
166	client, _ := makeInstrumentedClient()
167
168	// Slow testserver responding in 100ms.
169	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
170		time.Sleep(100 * time.Millisecond)
171		w.WriteHeader(http.StatusOK)
172	}))
173	defer backend.Close()
174
175	req, err := http.NewRequest("GET", backend.URL, nil)
176	if err != nil {
177		t.Fatalf("%v", err)
178	}
179
180	// Set a context with a short timeout.
181	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
182	defer cancel()
183	req = req.WithContext(ctx)
184
185	_, err = client.Do(req)
186	if err == nil {
187		t.Fatal("did not get timeout error")
188	}
189	if want, got := fmt.Sprintf("Get %s: context deadline exceeded", backend.URL), err.Error(); want != got {
190		t.Fatalf("want error %q, got %q", want, got)
191	}
192}
193
194func ExampleInstrumentRoundTripperDuration() {
195	client := http.DefaultClient
196	client.Timeout = 1 * time.Second
197
198	inFlightGauge := prometheus.NewGauge(prometheus.GaugeOpts{
199		Name: "client_in_flight_requests",
200		Help: "A gauge of in-flight requests for the wrapped client.",
201	})
202
203	counter := prometheus.NewCounterVec(
204		prometheus.CounterOpts{
205			Name: "client_api_requests_total",
206			Help: "A counter for requests from the wrapped client.",
207		},
208		[]string{"code", "method"},
209	)
210
211	// dnsLatencyVec uses custom buckets based on expected dns durations.
212	// It has an instance label "event", which is set in the
213	// DNSStart and DNSDonehook functions defined in the
214	// InstrumentTrace struct below.
215	dnsLatencyVec := prometheus.NewHistogramVec(
216		prometheus.HistogramOpts{
217			Name:    "dns_duration_seconds",
218			Help:    "Trace dns latency histogram.",
219			Buckets: []float64{.005, .01, .025, .05},
220		},
221		[]string{"event"},
222	)
223
224	// tlsLatencyVec uses custom buckets based on expected tls durations.
225	// It has an instance label "event", which is set in the
226	// TLSHandshakeStart and TLSHandshakeDone hook functions defined in the
227	// InstrumentTrace struct below.
228	tlsLatencyVec := prometheus.NewHistogramVec(
229		prometheus.HistogramOpts{
230			Name:    "tls_duration_seconds",
231			Help:    "Trace tls latency histogram.",
232			Buckets: []float64{.05, .1, .25, .5},
233		},
234		[]string{"event"},
235	)
236
237	// histVec has no labels, making it a zero-dimensional ObserverVec.
238	histVec := prometheus.NewHistogramVec(
239		prometheus.HistogramOpts{
240			Name:    "request_duration_seconds",
241			Help:    "A histogram of request latencies.",
242			Buckets: prometheus.DefBuckets,
243		},
244		[]string{},
245	)
246
247	// Register all of the metrics in the standard registry.
248	prometheus.MustRegister(counter, tlsLatencyVec, dnsLatencyVec, histVec, inFlightGauge)
249
250	// Define functions for the available httptrace.ClientTrace hook
251	// functions that we want to instrument.
252	trace := &InstrumentTrace{
253		DNSStart: func(t float64) {
254			dnsLatencyVec.WithLabelValues("dns_start").Observe(t)
255		},
256		DNSDone: func(t float64) {
257			dnsLatencyVec.WithLabelValues("dns_done").Observe(t)
258		},
259		TLSHandshakeStart: func(t float64) {
260			tlsLatencyVec.WithLabelValues("tls_handshake_start").Observe(t)
261		},
262		TLSHandshakeDone: func(t float64) {
263			tlsLatencyVec.WithLabelValues("tls_handshake_done").Observe(t)
264		},
265	}
266
267	// Wrap the default RoundTripper with middleware.
268	roundTripper := InstrumentRoundTripperInFlight(inFlightGauge,
269		InstrumentRoundTripperCounter(counter,
270			InstrumentRoundTripperTrace(trace,
271				InstrumentRoundTripperDuration(histVec, http.DefaultTransport),
272			),
273		),
274	)
275
276	// Set the RoundTripper on our client.
277	client.Transport = roundTripper
278
279	resp, err := client.Get("http://google.com")
280	if err != nil {
281		log.Printf("error: %v", err)
282	}
283	defer resp.Body.Close()
284}
285