1// Copyright 2017 The Prometheus Authors
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14package promhttp
15
16import (
17	"errors"
18	"net/http"
19	"strconv"
20	"strings"
21	"time"
22
23	dto "github.com/prometheus/client_model/go"
24
25	"github.com/prometheus/client_golang/prometheus"
26)
27
// magicString is used for the hacky label test in checkLabels. Remove once fixed.
//
// checkLabels passes it as the value of every variable label when building a
// probe ConstMetric from a Desc; afterwards, any label whose value equals
// magicString is a variable label, and any other label is a const label.
// NOTE(review): presumably chosen as a random-looking value so that it cannot
// collide with a real const label value.
const magicString = "zZgWfBxLqvG8kc8IMv3POi2Bb0tZI3vAnBx+gBaFi9FyPzB/CzKUer1yufDa"
30
// InstrumentHandlerInFlight is a middleware that wraps the provided
// http.Handler and keeps the provided prometheus.Gauge equal to the number of
// requests the wrapped http.Handler is currently serving. The Gauge is
// incremented when a request enters and decremented when it leaves; the
// decrement is deferred, so it also happens if the wrapped handler panics.
//
// See the example for InstrumentHandlerDuration for example usage.
func InstrumentHandlerInFlight(g prometheus.Gauge, next http.Handler) http.Handler {
	track := func(w http.ResponseWriter, r *http.Request) {
		g.Inc()
		defer g.Dec()
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(track)
}
43
// InstrumentHandlerDuration is a middleware that wraps the provided
// http.Handler to observe the request duration with the provided ObserverVec.
// The ObserverVec must have zero, one, or two non-const non-curried labels,
// and the only label names allowed for those are "code" and "method"; any
// other configuration makes this function panic. The Observe method of the
// Observer in the ObserverVec is called with the request duration in seconds.
// Observations are partitioned by HTTP status code and/or HTTP method when the
// respective instance label names are present in the ObserverVec; for
// unpartitioned observations, use an ObserverVec with zero labels. Note that
// partitioning of Histograms is expensive and should be used judiciously.
//
// If the wrapped Handler does not set a status code, a status code of 200 is assumed.
//
// If the wrapped Handler panics, no values are reported.
//
// Note that this method is only guaranteed to never observe negative durations
// if used with Go1.9+.
func InstrumentHandlerDuration(obs prometheus.ObserverVec, next http.Handler) http.HandlerFunc {
	code, method := checkLabels(obs)

	if !code {
		// No "code" label: the status is never read, so the ResponseWriter
		// is handed to the wrapped handler unwrapped.
		return func(w http.ResponseWriter, r *http.Request) {
			start := time.Now()
			next.ServeHTTP(w, r)
			obs.With(labels(code, method, r.Method, 0)).Observe(time.Since(start).Seconds())
		}
	}

	// A "code" label is present: wrap the ResponseWriter so the status
	// written by the handler can be read back afterwards.
	return func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()
		rec := newDelegator(w, nil)
		next.ServeHTTP(rec, r)
		obs.With(labels(code, method, r.Method, rec.Status())).Observe(time.Since(start).Seconds())
	}
}
80
// InstrumentHandlerCounter is a middleware that wraps the provided http.Handler
// to count requests with the provided CounterVec. The CounterVec must have
// zero, one, or two non-const non-curried labels, and the only label names
// allowed for those are "code" and "method"; any other configuration makes
// this function panic. Counts are partitioned by HTTP status code and/or HTTP
// method when the respective instance label names are present in the
// CounterVec; for unpartitioned counting, use a CounterVec with zero labels.
//
// If the wrapped Handler does not set a status code, a status code of 200 is assumed.
//
// If the wrapped Handler panics, the Counter is not incremented.
//
// See the example for InstrumentHandlerDuration for example usage.
func InstrumentHandlerCounter(counter *prometheus.CounterVec, next http.Handler) http.HandlerFunc {
	code, method := checkLabels(counter)

	if !code {
		// Without a "code" label the status is never needed, so w is passed
		// through untouched.
		return func(w http.ResponseWriter, r *http.Request) {
			next.ServeHTTP(w, r)
			counter.With(labels(code, method, r.Method, 0)).Inc()
		}
	}

	return func(w http.ResponseWriter, r *http.Request) {
		rec := newDelegator(w, nil)
		next.ServeHTTP(rec, r)
		counter.With(labels(code, method, r.Method, rec.Status())).Inc()
	}
}
110
// InstrumentHandlerTimeToWriteHeader is a middleware that wraps the provided
// http.Handler to observe, with the provided ObserverVec, the time elapsed
// from the start of the request until the response headers are written. The
// ObserverVec must have zero, one, or two non-const non-curried labels, and
// the only label names allowed for those are "code" and "method"; any other
// configuration makes this function panic. The Observe method of the Observer
// in the ObserverVec is called with that duration in seconds. Observations are
// partitioned by HTTP status code and/or HTTP method when the respective
// instance label names are present in the ObserverVec; for unpartitioned
// observations, use an ObserverVec with zero labels. Note that partitioning of
// Histograms is expensive and should be used judiciously.
//
// If the wrapped Handler panics before calling WriteHeader, no value is
// reported.
//
// Note that this method is only guaranteed to never observe negative durations
// if used with Go1.9+.
//
// See the example for InstrumentHandlerDuration for example usage.
func InstrumentHandlerTimeToWriteHeader(obs prometheus.ObserverVec, next http.Handler) http.HandlerFunc {
	code, method := checkLabels(obs)

	return func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()
		// The observation happens from inside the delegator's WriteHeader
		// hook, with the status the handler is writing.
		// NOTE(review): if the handler returns without writing anything,
		// presumably the hook never fires and nothing is observed — confirm
		// against newDelegator.
		onWriteHeader := func(status int) {
			obs.With(labels(code, method, r.Method, status)).Observe(time.Since(start).Seconds())
		}
		rec := newDelegator(w, onWriteHeader)
		next.ServeHTTP(rec, r)
	}
}
140
// InstrumentHandlerRequestSize is a middleware that wraps the provided
// http.Handler to observe the approximate request size with the provided
// ObserverVec. The ObserverVec must have zero, one, or two non-const
// non-curried labels, and the only label names allowed for those are "code"
// and "method"; any other configuration makes this function panic. The Observe
// method of the Observer in the ObserverVec is called with the request size in
// bytes. Observations are partitioned by HTTP status code and/or HTTP method
// when the respective instance label names are present in the ObserverVec;
// for unpartitioned observations, use an ObserverVec with zero labels. Note
// that partitioning of Histograms is expensive and should be used judiciously.
//
// If the wrapped Handler does not set a status code, a status code of 200 is assumed.
//
// If the wrapped Handler panics, no values are reported.
//
// See the example for InstrumentHandlerDuration for example usage.
func InstrumentHandlerRequestSize(obs prometheus.ObserverVec, next http.Handler) http.HandlerFunc {
	code, method := checkLabels(obs)

	if !code {
		return func(w http.ResponseWriter, r *http.Request) {
			next.ServeHTTP(w, r)
			reqBytes := computeApproximateRequestSize(r)
			obs.With(labels(code, method, r.Method, 0)).Observe(float64(reqBytes))
		}
	}

	// The status is only known once the handler has run, so the size is
	// also measured after ServeHTTP returns.
	return func(w http.ResponseWriter, r *http.Request) {
		rec := newDelegator(w, nil)
		next.ServeHTTP(rec, r)
		reqBytes := computeApproximateRequestSize(r)
		obs.With(labels(code, method, r.Method, rec.Status())).Observe(float64(reqBytes))
	}
}
175
// InstrumentHandlerResponseSize is a middleware that wraps the provided
// http.Handler to observe the response size with the provided ObserverVec.
// The ObserverVec must have zero, one, or two non-const non-curried labels,
// and the only label names allowed for those are "code" and "method"; any
// other configuration makes this function panic. The Observe method of the
// Observer in the ObserverVec is called with the response size in bytes, as
// counted by the wrapping ResponseWriter. Observations are partitioned by HTTP
// status code and/or HTTP method when the respective instance label names are
// present in the ObserverVec; for unpartitioned observations, use an
// ObserverVec with zero labels. Note that partitioning of Histograms is
// expensive and should be used judiciously.
//
// If the wrapped Handler does not set a status code, a status code of 200 is assumed.
//
// If the wrapped Handler panics, no values are reported.
//
// See the example for InstrumentHandlerDuration for example usage.
func InstrumentHandlerResponseSize(obs prometheus.ObserverVec, next http.Handler) http.Handler {
	code, method := checkLabels(obs)

	// The ResponseWriter is always wrapped here (even without a "code"
	// label) because the written byte count is needed.
	measure := func(w http.ResponseWriter, r *http.Request) {
		rec := newDelegator(w, nil)
		next.ServeHTTP(rec, r)
		obs.With(labels(code, method, r.Method, rec.Status())).Observe(float64(rec.Written()))
	}
	return http.HandlerFunc(measure)
}
200
// checkLabels inspects the single Desc provided by c and reports whether its
// non-const, non-curried variable labels include "code" and/or "method".
//
// It panics if c provides no Desc or more than one Desc, if the Desc is
// invalid, or if any non-const non-curried label other than "code" and
// "method" is present.
func checkLabels(c prometheus.Collector) (code bool, method bool) {
	// TODO(beorn7): Remove this hacky way to check for instance labels
	// once Descriptors can have their dimensionality queried.
	var (
		desc *prometheus.Desc
		m    prometheus.Metric
		pm   dto.Metric
		lvs  []string
	)

	// Get the Desc from the Collector. Describe runs in its own goroutine and
	// the channel is drained completely. With a small buffered channel and a
	// synchronous Describe call, a Collector sending more than one Desc would
	// block forever instead of reaching the panic below.
	descc := make(chan *prometheus.Desc)
	go func() {
		c.Describe(descc)
		close(descc)
	}()
	var descs []*prometheus.Desc
	for d := range descc {
		descs = append(descs, d)
	}
	switch len(descs) {
	case 0:
		panic("no description provided by collector")
	case 1:
		desc = descs[0]
	default:
		panic("more than one description provided by collector")
	}

	// Make sure the Desc is valid by registering the Collector with a
	// throw-away registry; MustRegister panics otherwise. NewConstMetric
	// fails for every label count on an invalid Desc, so without this check
	// the probing loop below would never terminate.
	prometheus.NewRegistry().MustRegister(c)

	// Create a ConstMetric with the Desc. Since we don't know how many
	// variable labels there are, keep adding one magicString value until
	// the label-value count matches.
	for err := errors.New("dummy"); err != nil; lvs = append(lvs, magicString) {
		m, err = prometheus.NewConstMetric(desc, prometheus.UntypedValue, 0, lvs...)
	}

	// Write out the metric into a proto message and look at the labels.
	// If the value is not the magicString, it is a constLabel, which doesn't interest us.
	// If the label is curried, it doesn't interest us.
	// In all other cases, only "code" or "method" is allowed.
	if err := m.Write(&pm); err != nil {
		panic("error checking metric for labels")
	}
	for _, label := range pm.Label {
		name, value := label.GetName(), label.GetValue()
		if value != magicString || isLabelCurried(c, name) {
			continue
		}
		switch name {
		case "code":
			code = true
		case "method":
			method = true
		default:
			panic("metric partitioned with non-supported labels")
		}
	}
	return code, method
}
257
258func isLabelCurried(c prometheus.Collector, label string) bool {
259	// This is even hackier than the label test above.
260	// We essentially try to curry again and see if it works.
261	// But for that, we need to type-convert to the two
262	// types we use here, ObserverVec or *CounterVec.
263	switch v := c.(type) {
264	case *prometheus.CounterVec:
265		if _, err := v.CurryWith(prometheus.Labels{label: "dummy"}); err == nil {
266			return false
267		}
268	case prometheus.ObserverVec:
269		if _, err := v.CurryWith(prometheus.Labels{label: "dummy"}); err == nil {
270			return false
271		}
272	default:
273		panic("unsupported metric vec type")
274	}
275	return true
276}
277
278// emptyLabels is a one-time allocation for non-partitioned metrics to avoid
279// unnecessary allocations on each request.
280var emptyLabels = prometheus.Labels{}
281
282func labels(code, method bool, reqMethod string, status int) prometheus.Labels {
283	if !(code || method) {
284		return emptyLabels
285	}
286	labels := prometheus.Labels{}
287
288	if code {
289		labels["code"] = sanitizeCode(status)
290	}
291	if method {
292		labels["method"] = sanitizeMethod(reqMethod)
293	}
294
295	return labels
296}
297
// computeApproximateRequestSize estimates the size of r in bytes. It sums the
// lengths of the method, protocol, host, the serialized URL (if any), every
// header name and value, and ContentLength unless it is -1 (unknown).
func computeApproximateRequestSize(r *http.Request) int {
	size := len(r.Method) + len(r.Proto) + len(r.Host)
	if r.URL != nil {
		size += len(r.URL.String())
	}

	for name, values := range r.Header {
		size += len(name)
		for _, v := range values {
			size += len(v)
		}
	}

	// N.B. r.Form and r.MultipartForm are assumed to be included in r.URL.
	if r.ContentLength != -1 {
		size += int(r.ContentLength)
	}
	return size
}
321
// sanitizeMethod maps an HTTP request method to the value reported in the
// "method" label.
//
// The request method is chosen by the client. Reporting it verbatim would let
// any client create an unbounded number of label values, and therefore time
// series (CVE-2022-21698).
//
// Only the standard methods handled below, plus any listed in extraMethods,
// are reported. They are matched case-insensitively and returned lower-cased.
// Every other method is reported as "unknown". extraMethods is optional, so
// existing sanitizeMethod(m) call sites are unaffected.
func sanitizeMethod(m string, extraMethods ...string) string {
	// Fast path: the exact upper- and lower-case spellings of the standard
	// methods return a constant without allocating.
	switch m {
	case "GET", "get":
		return "get"
	case "PUT", "put":
		return "put"
	case "HEAD", "head":
		return "head"
	case "POST", "post":
		return "post"
	case "DELETE", "delete":
		return "delete"
	case "CONNECT", "connect":
		return "connect"
	case "OPTIONS", "options":
		return "options"
	case "NOTIFY", "notify":
		return "notify"
	case "TRACE", "trace":
		return "trace"
	case "PATCH", "patch":
		return "patch"
	}

	// Other spellings of the same methods (e.g. "Get") still map to their
	// lower-case name, as they did when the method was simply lower-cased.
	for _, known := range [...]string{
		"get", "put", "head", "post", "delete",
		"connect", "options", "notify", "trace", "patch",
	} {
		if strings.EqualFold(m, known) {
			return known
		}
	}

	// Methods explicitly allowed by the caller. The configured name is
	// returned lower-cased, rather than the client's spelling, so the set of
	// label values stays bounded by extraMethods.
	for _, extra := range extraMethods {
		if strings.EqualFold(m, extra) {
			return strings.ToLower(extra)
		}
	}
	return "unknown"
}
344
// sanitizeCode renders the HTTP status code s as the value of the "code"
// label. A status of 0 means the wrapped http.Handler has not set a status
// code; in that case sanitizeCode returns "200", for consistency with the
// behavior of the stdlib. Common codes are returned as string constants so the
// hot path avoids strconv; any other value is formatted with strconv.Itoa.
func sanitizeCode(s int) string {
	if s == 0 {
		return "200"
	}

	switch s {
	// 1xx
	case http.StatusContinue:
		return "100"
	case http.StatusSwitchingProtocols:
		return "101"

	// 2xx
	case http.StatusOK:
		return "200"
	case http.StatusCreated:
		return "201"
	case http.StatusAccepted:
		return "202"
	case http.StatusNonAuthoritativeInfo:
		return "203"
	case http.StatusNoContent:
		return "204"
	case http.StatusResetContent:
		return "205"
	case http.StatusPartialContent:
		return "206"

	// 3xx
	case http.StatusMultipleChoices:
		return "300"
	case http.StatusMovedPermanently:
		return "301"
	case http.StatusFound:
		return "302"
	case http.StatusNotModified:
		return "304"
	case http.StatusUseProxy:
		return "305"
	case http.StatusTemporaryRedirect:
		return "307"

	// 4xx
	case http.StatusBadRequest:
		return "400"
	case http.StatusUnauthorized:
		return "401"
	case http.StatusPaymentRequired:
		return "402"
	case http.StatusForbidden:
		return "403"
	case http.StatusNotFound:
		return "404"
	case http.StatusMethodNotAllowed:
		return "405"
	case http.StatusNotAcceptable:
		return "406"
	case http.StatusProxyAuthRequired:
		return "407"
	case http.StatusRequestTimeout:
		return "408"
	case http.StatusConflict:
		return "409"
	case http.StatusGone:
		return "410"
	case http.StatusLengthRequired:
		return "411"
	case http.StatusPreconditionFailed:
		return "412"
	case http.StatusRequestEntityTooLarge:
		return "413"
	case http.StatusRequestURITooLong:
		return "414"
	case http.StatusUnsupportedMediaType:
		return "415"
	case http.StatusRequestedRangeNotSatisfiable:
		return "416"
	case http.StatusExpectationFailed:
		return "417"
	case http.StatusTeapot:
		return "418"
	case http.StatusPreconditionRequired:
		return "428"
	case http.StatusTooManyRequests:
		return "429"
	case http.StatusRequestHeaderFieldsTooLarge:
		return "431"

	// 5xx
	case http.StatusInternalServerError:
		return "500"
	case http.StatusNotImplemented:
		return "501"
	case http.StatusBadGateway:
		return "502"
	case http.StatusServiceUnavailable:
		return "503"
	case http.StatusGatewayTimeout:
		return "504"
	case http.StatusHTTPVersionNotSupported:
		return "505"
	case http.StatusNetworkAuthenticationRequired:
		return "511"
	}

	return strconv.Itoa(s)
}
448