1// Copyright 2017 The Prometheus Authors 2// Licensed under the Apache License, Version 2.0 (the "License"); 3// you may not use this file except in compliance with the License. 4// You may obtain a copy of the License at 5// 6// http://www.apache.org/licenses/LICENSE-2.0 7// 8// Unless required by applicable law or agreed to in writing, software 9// distributed under the License is distributed on an "AS IS" BASIS, 10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11// See the License for the specific language governing permissions and 12// limitations under the License. 13 14package promhttp 15 16import ( 17 "context" 18 "fmt" 19 "log" 20 "net/http" 21 "net/http/httptest" 22 "testing" 23 "time" 24 25 "github.com/prometheus/client_golang/prometheus" 26) 27 28func makeInstrumentedClient() (*http.Client, *prometheus.Registry) { 29 client := http.DefaultClient 30 client.Timeout = 1 * time.Second 31 32 reg := prometheus.NewRegistry() 33 34 inFlightGauge := prometheus.NewGauge(prometheus.GaugeOpts{ 35 Name: "client_in_flight_requests", 36 Help: "A gauge of in-flight requests for the wrapped client.", 37 }) 38 39 counter := prometheus.NewCounterVec( 40 prometheus.CounterOpts{ 41 Name: "client_api_requests_total", 42 Help: "A counter for requests from the wrapped client.", 43 }, 44 []string{"code", "method"}, 45 ) 46 47 dnsLatencyVec := prometheus.NewHistogramVec( 48 prometheus.HistogramOpts{ 49 Name: "dns_duration_seconds", 50 Help: "Trace dns latency histogram.", 51 Buckets: []float64{.005, .01, .025, .05}, 52 }, 53 []string{"event"}, 54 ) 55 56 tlsLatencyVec := prometheus.NewHistogramVec( 57 prometheus.HistogramOpts{ 58 Name: "tls_duration_seconds", 59 Help: "Trace tls latency histogram.", 60 Buckets: []float64{.05, .1, .25, .5}, 61 }, 62 []string{"event"}, 63 ) 64 65 histVec := prometheus.NewHistogramVec( 66 prometheus.HistogramOpts{ 67 Name: "request_duration_seconds", 68 Help: "A histogram of request latencies.", 69 Buckets: prometheus.DefBuckets, 70 }, 71 []string{"method"}, 72 ) 73 74 reg.MustRegister(counter, tlsLatencyVec, dnsLatencyVec, histVec, inFlightGauge) 75 76 trace := &InstrumentTrace{ 77 DNSStart: func(t float64) { 78 dnsLatencyVec.WithLabelValues("dns_start").Observe(t) 79 }, 80 DNSDone: func(t float64) { 81 dnsLatencyVec.WithLabelValues("dns_done").Observe(t) 82 }, 83 TLSHandshakeStart: func(t float64) { 84 tlsLatencyVec.WithLabelValues("tls_handshake_start").Observe(t) 85 }, 86 TLSHandshakeDone: func(t float64) { 87 tlsLatencyVec.WithLabelValues("tls_handshake_done").Observe(t) 88 }, 89 } 90 91 client.Transport = InstrumentRoundTripperInFlight(inFlightGauge, 92 InstrumentRoundTripperCounter(counter, 93 InstrumentRoundTripperTrace(trace, 94 InstrumentRoundTripperDuration(histVec, http.DefaultTransport), 95 ), 96 ), 97 ) 98 return client, reg 99} 100 101func TestClientMiddlewareAPI(t *testing.T) { 102 client, reg := makeInstrumentedClient() 103 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 104 w.WriteHeader(http.StatusOK) 105 })) 106 defer backend.Close() 107 108 resp, err := client.Get(backend.URL) 109 if err != nil { 110 t.Fatal(err) 111 } 112 defer resp.Body.Close() 113 114 mfs, err := reg.Gather() 115 if err != nil { 116 t.Fatal(err) 117 } 118 if want, got := 3, len(mfs); want != got { 119 t.Fatalf("unexpected number of metric families gathered, want %d, got %d", want, got) 120 } 121 for _, mf := range mfs { 122 if len(mf.Metric) == 0 { 123 t.Errorf("metric family %s must not be empty", mf.GetName()) 124 } 125 } 126} 127 128func TestClientMiddlewareAPIWithRequestContext(t *testing.T) { 129 client, reg := makeInstrumentedClient() 130 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 131 w.WriteHeader(http.StatusOK) 132 })) 133 defer backend.Close() 134 135 req, err := http.NewRequest("GET", backend.URL, nil) 136 if err != nil { 137 t.Fatalf("%v", err) 138 } 139 140 // Set a context with a long timeout. 141 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 142 defer cancel() 143 req = req.WithContext(ctx) 144 145 resp, err := client.Do(req) 146 if err != nil { 147 t.Fatal(err) 148 } 149 defer resp.Body.Close() 150 151 mfs, err := reg.Gather() 152 if err != nil { 153 t.Fatal(err) 154 } 155 if want, got := 3, len(mfs); want != got { 156 t.Fatalf("unexpected number of metric families gathered, want %d, got %d", want, got) 157 } 158 for _, mf := range mfs { 159 if len(mf.Metric) == 0 { 160 t.Errorf("metric family %s must not be empty", mf.GetName()) 161 } 162 } 163} 164 165func TestClientMiddlewareAPIWithRequestContextTimeout(t *testing.T) { 166 client, _ := makeInstrumentedClient() 167 168 // Slow testserver responding in 100ms. 169 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 170 time.Sleep(100 * time.Millisecond) 171 w.WriteHeader(http.StatusOK) 172 })) 173 defer backend.Close() 174 175 req, err := http.NewRequest("GET", backend.URL, nil) 176 if err != nil { 177 t.Fatalf("%v", err) 178 } 179 180 // Set a context with a short timeout. 181 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) 182 defer cancel() 183 req = req.WithContext(ctx) 184 185 _, err = client.Do(req) 186 if err == nil { 187 t.Fatal("did not get timeout error") 188 } 189 if want, got := fmt.Sprintf("Get %s: context deadline exceeded", backend.URL), err.Error(); want != got { 190 t.Fatalf("want error %q, got %q", want, got) 191 } 192} 193 194func ExampleInstrumentRoundTripperDuration() { 195 client := http.DefaultClient 196 client.Timeout = 1 * time.Second 197 198 inFlightGauge := prometheus.NewGauge(prometheus.GaugeOpts{ 199 Name: "client_in_flight_requests", 200 Help: "A gauge of in-flight requests for the wrapped client.", 201 }) 202 203 counter := prometheus.NewCounterVec( 204 prometheus.CounterOpts{ 205 Name: "client_api_requests_total", 206 Help: "A counter for requests from the wrapped client.", 207 }, 208 []string{"code", "method"}, 209 ) 210 211 // dnsLatencyVec uses custom buckets based on expected dns durations. 212 // It has an instance label "event", which is set in the 213 // DNSStart and DNSDonehook functions defined in the 214 // InstrumentTrace struct below. 215 dnsLatencyVec := prometheus.NewHistogramVec( 216 prometheus.HistogramOpts{ 217 Name: "dns_duration_seconds", 218 Help: "Trace dns latency histogram.", 219 Buckets: []float64{.005, .01, .025, .05}, 220 }, 221 []string{"event"}, 222 ) 223 224 // tlsLatencyVec uses custom buckets based on expected tls durations. 225 // It has an instance label "event", which is set in the 226 // TLSHandshakeStart and TLSHandshakeDone hook functions defined in the 227 // InstrumentTrace struct below. 228 tlsLatencyVec := prometheus.NewHistogramVec( 229 prometheus.HistogramOpts{ 230 Name: "tls_duration_seconds", 231 Help: "Trace tls latency histogram.", 232 Buckets: []float64{.05, .1, .25, .5}, 233 }, 234 []string{"event"}, 235 ) 236 237 // histVec has no labels, making it a zero-dimensional ObserverVec. 238 histVec := prometheus.NewHistogramVec( 239 prometheus.HistogramOpts{ 240 Name: "request_duration_seconds", 241 Help: "A histogram of request latencies.", 242 Buckets: prometheus.DefBuckets, 243 }, 244 []string{}, 245 ) 246 247 // Register all of the metrics in the standard registry. 248 prometheus.MustRegister(counter, tlsLatencyVec, dnsLatencyVec, histVec, inFlightGauge) 249 250 // Define functions for the available httptrace.ClientTrace hook 251 // functions that we want to instrument. 252 trace := &InstrumentTrace{ 253 DNSStart: func(t float64) { 254 dnsLatencyVec.WithLabelValues("dns_start").Observe(t) 255 }, 256 DNSDone: func(t float64) { 257 dnsLatencyVec.WithLabelValues("dns_done").Observe(t) 258 }, 259 TLSHandshakeStart: func(t float64) { 260 tlsLatencyVec.WithLabelValues("tls_handshake_start").Observe(t) 261 }, 262 TLSHandshakeDone: func(t float64) { 263 tlsLatencyVec.WithLabelValues("tls_handshake_done").Observe(t) 264 }, 265 } 266 267 // Wrap the default RoundTripper with middleware. 268 roundTripper := InstrumentRoundTripperInFlight(inFlightGauge, 269 InstrumentRoundTripperCounter(counter, 270 InstrumentRoundTripperTrace(trace, 271 InstrumentRoundTripperDuration(histVec, http.DefaultTransport), 272 ), 273 ), 274 ) 275 276 // Set the RoundTripper on our client. 277 client.Transport = roundTripper 278 279 resp, err := client.Get("http://google.com") 280 if err != nil { 281 log.Printf("error: %v", err) 282 } 283 defer resp.Body.Close() 284} 285