1// Copyright 2016 The Prometheus Authors
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14package prober
15
16import (
17	"context"
18	"errors"
19	"fmt"
20	"io"
21	"io/ioutil"
22	"net"
23	"net/http"
24	"net/http/cookiejar"
25	"net/http/httptrace"
26	"net/textproto"
27	"net/url"
28	"regexp"
29	"strconv"
30	"strings"
31	"sync"
32	"time"
33
34	"github.com/go-kit/kit/log"
35	"github.com/go-kit/kit/log/level"
36	"github.com/prometheus/client_golang/prometheus"
37	pconfig "github.com/prometheus/common/config"
38	"golang.org/x/net/publicsuffix"
39
40	"github.com/prometheus/blackbox_exporter/config"
41)
42
43func matchRegularExpressions(reader io.Reader, httpConfig config.HTTPProbe, logger log.Logger) bool {
44	body, err := ioutil.ReadAll(reader)
45	if err != nil {
46		level.Error(logger).Log("msg", "Error reading HTTP body", "err", err)
47		return false
48	}
49	for _, expression := range httpConfig.FailIfBodyMatchesRegexp {
50		re, err := regexp.Compile(expression)
51		if err != nil {
52			level.Error(logger).Log("msg", "Could not compile regular expression", "regexp", expression, "err", err)
53			return false
54		}
55		if re.Match(body) {
56			level.Error(logger).Log("msg", "Body matched regular expression", "regexp", expression)
57			return false
58		}
59	}
60	for _, expression := range httpConfig.FailIfBodyNotMatchesRegexp {
61		re, err := regexp.Compile(expression)
62		if err != nil {
63			level.Error(logger).Log("msg", "Could not compile regular expression", "regexp", expression, "err", err)
64			return false
65		}
66		if !re.Match(body) {
67			level.Error(logger).Log("msg", "Body did not match regular expression", "regexp", expression)
68			return false
69		}
70	}
71	return true
72}
73
74func matchRegularExpressionsOnHeaders(header http.Header, httpConfig config.HTTPProbe, logger log.Logger) bool {
75	for _, headerMatchSpec := range httpConfig.FailIfHeaderMatchesRegexp {
76		values := header[textproto.CanonicalMIMEHeaderKey(headerMatchSpec.Header)]
77		if len(values) == 0 {
78			if !headerMatchSpec.AllowMissing {
79				level.Error(logger).Log("msg", "Missing required header", "header", headerMatchSpec.Header)
80				return false
81			} else {
82				continue // No need to match any regex on missing headers.
83			}
84		}
85
86		re, err := regexp.Compile(headerMatchSpec.Regexp)
87		if err != nil {
88			level.Error(logger).Log("msg", "Could not compile regular expression", "regexp", headerMatchSpec.Regexp, "err", err)
89			return false
90		}
91
92		for _, val := range values {
93			if re.MatchString(val) {
94				level.Error(logger).Log("msg", "Header matched regular expression", "header", headerMatchSpec.Header,
95					"regexp", headerMatchSpec.Regexp, "value_count", len(values))
96				return false
97			}
98		}
99	}
100	for _, headerMatchSpec := range httpConfig.FailIfHeaderNotMatchesRegexp {
101		values := header[textproto.CanonicalMIMEHeaderKey(headerMatchSpec.Header)]
102		if len(values) == 0 {
103			if !headerMatchSpec.AllowMissing {
104				level.Error(logger).Log("msg", "Missing required header", "header", headerMatchSpec.Header)
105				return false
106			} else {
107				continue // No need to match any regex on missing headers.
108			}
109		}
110
111		re, err := regexp.Compile(headerMatchSpec.Regexp)
112		if err != nil {
113			level.Error(logger).Log("msg", "Could not compile regular expression", "regexp", headerMatchSpec.Regexp, "err", err)
114			return false
115		}
116
117		anyHeaderValueMatched := false
118
119		for _, val := range values {
120			if re.MatchString(val) {
121				anyHeaderValueMatched = true
122				break
123			}
124		}
125
126		if !anyHeaderValueMatched {
127			level.Error(logger).Log("msg", "Header did not match regular expression", "header", headerMatchSpec.Header,
128				"regexp", headerMatchSpec.Regexp, "value_count", len(values))
129			return false
130		}
131	}
132
133	return true
134}
135
// roundTripTrace collects the timestamps observed during one HTTP round trip,
// that is one request/response pair. Every request the transport sends,
// including each followed redirect, gets its own trace.
type roundTripTrace struct {
	// tls is set when the request URL scheme is "https".
	tls bool

	// Timestamps filled in by the transport's httptrace hooks (and, for end,
	// by the prober once the body has been read). A zero value means the
	// corresponding stage was never reached.
	start, dnsDone, connectDone, gotConn, responseStart, end time.Time
}
146
147// transport is a custom transport keeping traces for each HTTP roundtrip.
148type transport struct {
149	Transport             http.RoundTripper
150	NoServerNameTransport http.RoundTripper
151	firstHost             string
152	logger                log.Logger
153
154	mu      sync.Mutex
155	traces  []*roundTripTrace
156	current *roundTripTrace
157}
158
159func newTransport(rt, noServerName http.RoundTripper, logger log.Logger) *transport {
160	return &transport{
161		Transport:             rt,
162		NoServerNameTransport: noServerName,
163		logger:                logger,
164		traces:                []*roundTripTrace{},
165	}
166}
167
168// RoundTrip switches to a new trace, then runs embedded RoundTripper.
169func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
170	level.Info(t.logger).Log("msg", "Making HTTP request", "url", req.URL.String(), "host", req.Host)
171
172	trace := &roundTripTrace{}
173	if req.URL.Scheme == "https" {
174		trace.tls = true
175	}
176	t.current = trace
177	t.traces = append(t.traces, trace)
178
179	if t.firstHost == "" {
180		t.firstHost = req.URL.Host
181	}
182
183	if t.firstHost != req.URL.Host {
184		// This is a redirect to something other than the initial host,
185		// so TLS ServerName should not be set.
186		level.Info(t.logger).Log("msg", "Address does not match first address, not sending TLS ServerName", "first", t.firstHost, "address", req.URL.Host)
187		return t.NoServerNameTransport.RoundTrip(req)
188	}
189
190	return t.Transport.RoundTrip(req)
191}
192
193func (t *transport) DNSStart(_ httptrace.DNSStartInfo) {
194	t.mu.Lock()
195	defer t.mu.Unlock()
196	t.current.start = time.Now()
197}
198func (t *transport) DNSDone(_ httptrace.DNSDoneInfo) {
199	t.mu.Lock()
200	defer t.mu.Unlock()
201	t.current.dnsDone = time.Now()
202}
203func (ts *transport) ConnectStart(_, _ string) {
204	ts.mu.Lock()
205	defer ts.mu.Unlock()
206	t := ts.current
207	// No DNS resolution because we connected to IP directly.
208	if t.dnsDone.IsZero() {
209		t.start = time.Now()
210		t.dnsDone = t.start
211	}
212}
213func (t *transport) ConnectDone(net, addr string, err error) {
214	t.mu.Lock()
215	defer t.mu.Unlock()
216	t.current.connectDone = time.Now()
217}
218func (t *transport) GotConn(_ httptrace.GotConnInfo) {
219	t.mu.Lock()
220	defer t.mu.Unlock()
221	t.current.gotConn = time.Now()
222}
223func (t *transport) GotFirstResponseByte() {
224	t.mu.Lock()
225	defer t.mu.Unlock()
226	t.current.responseStart = time.Now()
227}
228
229func ProbeHTTP(ctx context.Context, target string, module config.Module, registry *prometheus.Registry, logger log.Logger) (success bool) {
230	var redirects int
231	var (
232		durationGaugeVec = prometheus.NewGaugeVec(prometheus.GaugeOpts{
233			Name: "probe_http_duration_seconds",
234			Help: "Duration of http request by phase, summed over all redirects",
235		}, []string{"phase"})
236		contentLengthGauge = prometheus.NewGauge(prometheus.GaugeOpts{
237			Name: "probe_http_content_length",
238			Help: "Length of http content response",
239		})
240		bodyUncompressedLengthGauge = prometheus.NewGauge(prometheus.GaugeOpts{
241			Name: "probe_http_uncompressed_body_length",
242			Help: "Length of uncompressed response body",
243		})
244		redirectsGauge = prometheus.NewGauge(prometheus.GaugeOpts{
245			Name: "probe_http_redirects",
246			Help: "The number of redirects",
247		})
248
249		isSSLGauge = prometheus.NewGauge(prometheus.GaugeOpts{
250			Name: "probe_http_ssl",
251			Help: "Indicates if SSL was used for the final redirect",
252		})
253
254		statusCodeGauge = prometheus.NewGauge(prometheus.GaugeOpts{
255			Name: "probe_http_status_code",
256			Help: "Response HTTP status code",
257		})
258
259		probeSSLEarliestCertExpiryGauge = prometheus.NewGauge(prometheus.GaugeOpts{
260			Name: "probe_ssl_earliest_cert_expiry",
261			Help: "Returns earliest SSL cert expiry in unixtime",
262		})
263
264		probeSSLLastChainExpiryTimestampSeconds = prometheus.NewGauge(prometheus.GaugeOpts{
265			Name: "probe_ssl_last_chain_expiry_timestamp_seconds",
266			Help: "Returns last SSL chain expiry in timestamp seconds",
267		})
268
269		probeTLSVersion = prometheus.NewGaugeVec(
270			prometheus.GaugeOpts{
271				Name: "probe_tls_version_info",
272				Help: "Contains the TLS version used",
273			},
274			[]string{"version"},
275		)
276
277		probeHTTPVersionGauge = prometheus.NewGauge(prometheus.GaugeOpts{
278			Name: "probe_http_version",
279			Help: "Returns the version of HTTP of the probe response",
280		})
281
282		probeFailedDueToRegex = prometheus.NewGauge(prometheus.GaugeOpts{
283			Name: "probe_failed_due_to_regex",
284			Help: "Indicates if probe failed due to regex",
285		})
286
287		probeHTTPLastModified = prometheus.NewGauge(prometheus.GaugeOpts{
288			Name: "probe_http_last_modified_timestamp_seconds",
289			Help: "Returns the Last-Modified HTTP response header in unixtime",
290		})
291	)
292
293	for _, lv := range []string{"resolve", "connect", "tls", "processing", "transfer"} {
294		durationGaugeVec.WithLabelValues(lv)
295	}
296
297	registry.MustRegister(durationGaugeVec)
298	registry.MustRegister(contentLengthGauge)
299	registry.MustRegister(bodyUncompressedLengthGauge)
300	registry.MustRegister(redirectsGauge)
301	registry.MustRegister(isSSLGauge)
302	registry.MustRegister(statusCodeGauge)
303	registry.MustRegister(probeHTTPVersionGauge)
304	registry.MustRegister(probeFailedDueToRegex)
305
306	httpConfig := module.HTTP
307
308	if !strings.HasPrefix(target, "http://") && !strings.HasPrefix(target, "https://") {
309		target = "http://" + target
310	}
311
312	targetURL, err := url.Parse(target)
313	if err != nil {
314		level.Error(logger).Log("msg", "Could not parse target URL", "err", err)
315		return false
316	}
317	targetHost, targetPort, err := net.SplitHostPort(targetURL.Host)
318	// If split fails, assuming it's a hostname without port part.
319	if err != nil {
320		targetHost = targetURL.Host
321	}
322
323	ip, lookupTime, err := chooseProtocol(ctx, module.HTTP.IPProtocol, module.HTTP.IPProtocolFallback, targetHost, registry, logger)
324	if err != nil {
325		level.Error(logger).Log("msg", "Error resolving address", "err", err)
326		return false
327	}
328	durationGaugeVec.WithLabelValues("resolve").Add(lookupTime)
329
330	httpClientConfig := module.HTTP.HTTPClientConfig
331	if len(httpClientConfig.TLSConfig.ServerName) == 0 {
332		// If there is no `server_name` in tls_config, use
333		// the hostname of the target.
334		httpClientConfig.TLSConfig.ServerName = targetHost
335	}
336	client, err := pconfig.NewClientFromConfig(httpClientConfig, "http_probe", true)
337	if err != nil {
338		level.Error(logger).Log("msg", "Error generating HTTP client", "err", err)
339		return false
340	}
341
342	httpClientConfig.TLSConfig.ServerName = ""
343	noServerName, err := pconfig.NewRoundTripperFromConfig(httpClientConfig, "http_probe", true)
344	if err != nil {
345		level.Error(logger).Log("msg", "Error generating HTTP client without ServerName", "err", err)
346		return false
347	}
348
349	jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
350	if err != nil {
351		level.Error(logger).Log("msg", "Error generating cookiejar", "err", err)
352		return false
353	}
354	client.Jar = jar
355
356	// Inject transport that tracks traces for each redirect,
357	// and does not set TLS ServerNames on redirect if needed.
358	tt := newTransport(client.Transport, noServerName, logger)
359	client.Transport = tt
360
361	client.CheckRedirect = func(r *http.Request, via []*http.Request) error {
362		level.Info(logger).Log("msg", "Received redirect", "location", r.Response.Header.Get("Location"))
363		redirects = len(via)
364		if redirects > 10 || httpConfig.NoFollowRedirects {
365			level.Info(logger).Log("msg", "Not following redirect")
366			return errors.New("don't follow redirects")
367		}
368		return nil
369	}
370
371	if httpConfig.Method == "" {
372		httpConfig.Method = "GET"
373	}
374
375	// Replace the host field in the URL with the IP we resolved.
376	origHost := targetURL.Host
377	if targetPort == "" {
378		if strings.Contains(ip.String(), ":") {
379			targetURL.Host = "[" + ip.String() + "]"
380		} else {
381			targetURL.Host = ip.String()
382		}
383	} else {
384		targetURL.Host = net.JoinHostPort(ip.String(), targetPort)
385	}
386
387	var body io.Reader
388	var respBodyBytes int64
389
390	// If a body is configured, add it to the request.
391	if httpConfig.Body != "" {
392		body = strings.NewReader(httpConfig.Body)
393	}
394
395	request, err := http.NewRequest(httpConfig.Method, targetURL.String(), body)
396	request.Host = origHost
397	request = request.WithContext(ctx)
398	if err != nil {
399		level.Error(logger).Log("msg", "Error creating request", "err", err)
400		return
401	}
402
403	for key, value := range httpConfig.Headers {
404		if strings.Title(key) == "Host" {
405			request.Host = value
406			continue
407		}
408		request.Header.Set(key, value)
409	}
410
411	trace := &httptrace.ClientTrace{
412		DNSStart:             tt.DNSStart,
413		DNSDone:              tt.DNSDone,
414		ConnectStart:         tt.ConnectStart,
415		ConnectDone:          tt.ConnectDone,
416		GotConn:              tt.GotConn,
417		GotFirstResponseByte: tt.GotFirstResponseByte,
418	}
419	request = request.WithContext(httptrace.WithClientTrace(request.Context(), trace))
420
421	resp, err := client.Do(request)
422	// Err won't be nil if redirects were turned off. See https://github.com/golang/go/issues/3795
423	if err != nil && resp == nil {
424		level.Error(logger).Log("msg", "Error for HTTP request", "err", err)
425	} else {
426		requestErrored := (err != nil)
427
428		level.Info(logger).Log("msg", "Received HTTP response", "status_code", resp.StatusCode)
429		if len(httpConfig.ValidStatusCodes) != 0 {
430			for _, code := range httpConfig.ValidStatusCodes {
431				if resp.StatusCode == code {
432					success = true
433					break
434				}
435			}
436			if !success {
437				level.Info(logger).Log("msg", "Invalid HTTP response status code", "status_code", resp.StatusCode,
438					"valid_status_codes", fmt.Sprintf("%v", httpConfig.ValidStatusCodes))
439			}
440		} else if 200 <= resp.StatusCode && resp.StatusCode < 300 {
441			success = true
442		} else {
443			level.Info(logger).Log("msg", "Invalid HTTP response status code, wanted 2xx", "status_code", resp.StatusCode)
444		}
445
446		if success && (len(httpConfig.FailIfHeaderMatchesRegexp) > 0 || len(httpConfig.FailIfHeaderNotMatchesRegexp) > 0) {
447			success = matchRegularExpressionsOnHeaders(resp.Header, httpConfig, logger)
448			if success {
449				probeFailedDueToRegex.Set(0)
450			} else {
451				probeFailedDueToRegex.Set(1)
452			}
453		}
454
455		if success && (len(httpConfig.FailIfBodyMatchesRegexp) > 0 || len(httpConfig.FailIfBodyNotMatchesRegexp) > 0) {
456			success = matchRegularExpressions(resp.Body, httpConfig, logger)
457			if success {
458				probeFailedDueToRegex.Set(0)
459			} else {
460				probeFailedDueToRegex.Set(1)
461			}
462		}
463
464		if resp != nil && !requestErrored {
465			respBodyBytes, err = io.Copy(ioutil.Discard, resp.Body)
466			if err != nil {
467				level.Info(logger).Log("msg", "Failed to read HTTP response body", "err", err)
468				success = false
469			}
470
471			resp.Body.Close()
472		}
473
474		// At this point body is fully read and we can write end time.
475		tt.current.end = time.Now()
476
477		// Check if there is a Last-Modified HTTP response header.
478		if t, err := http.ParseTime(resp.Header.Get("Last-Modified")); err == nil {
479			registry.MustRegister(probeHTTPLastModified)
480			probeHTTPLastModified.Set(float64(t.Unix()))
481		}
482
483		var httpVersionNumber float64
484		httpVersionNumber, err = strconv.ParseFloat(strings.TrimPrefix(resp.Proto, "HTTP/"), 64)
485		if err != nil {
486			level.Error(logger).Log("msg", "Error parsing version number from HTTP version", "err", err)
487		}
488		probeHTTPVersionGauge.Set(httpVersionNumber)
489
490		if len(httpConfig.ValidHTTPVersions) != 0 {
491			found := false
492			for _, version := range httpConfig.ValidHTTPVersions {
493				if version == resp.Proto {
494					found = true
495					break
496				}
497			}
498			if !found {
499				level.Error(logger).Log("msg", "Invalid HTTP version number", "version", httpVersionNumber)
500				success = false
501			}
502		}
503
504	}
505
506	if resp == nil {
507		resp = &http.Response{}
508	}
509	tt.mu.Lock()
510	defer tt.mu.Unlock()
511	for i, trace := range tt.traces {
512		level.Info(logger).Log(
513			"msg", "Response timings for roundtrip",
514			"roundtrip", i,
515			"start", trace.start,
516			"dnsDone", trace.dnsDone,
517			"connectDone", trace.connectDone,
518			"gotConn", trace.gotConn,
519			"responseStart", trace.responseStart,
520			"end", trace.end,
521		)
522		// We get the duration for the first request from chooseProtocol.
523		if i != 0 {
524			durationGaugeVec.WithLabelValues("resolve").Add(trace.dnsDone.Sub(trace.start).Seconds())
525		}
526		// Continue here if we never got a connection because a request failed.
527		if trace.gotConn.IsZero() {
528			continue
529		}
530		if trace.tls {
531			// dnsDone must be set if gotConn was set.
532			durationGaugeVec.WithLabelValues("connect").Add(trace.connectDone.Sub(trace.dnsDone).Seconds())
533			durationGaugeVec.WithLabelValues("tls").Add(trace.gotConn.Sub(trace.dnsDone).Seconds())
534		} else {
535			durationGaugeVec.WithLabelValues("connect").Add(trace.gotConn.Sub(trace.dnsDone).Seconds())
536		}
537
538		// Continue here if we never got a response from the server.
539		if trace.responseStart.IsZero() {
540			continue
541		}
542		durationGaugeVec.WithLabelValues("processing").Add(trace.responseStart.Sub(trace.gotConn).Seconds())
543
544		// Continue here if we never read the full response from the server.
545		// Usually this means that request either failed or was redirected.
546		if trace.end.IsZero() {
547			continue
548		}
549		durationGaugeVec.WithLabelValues("transfer").Add(trace.end.Sub(trace.responseStart).Seconds())
550	}
551
552	if resp.TLS != nil {
553		isSSLGauge.Set(float64(1))
554		registry.MustRegister(probeSSLEarliestCertExpiryGauge, probeTLSVersion, probeSSLLastChainExpiryTimestampSeconds)
555		probeSSLEarliestCertExpiryGauge.Set(float64(getEarliestCertExpiry(resp.TLS).Unix()))
556		probeTLSVersion.WithLabelValues(getTLSVersion(resp.TLS)).Set(1)
557		probeSSLLastChainExpiryTimestampSeconds.Set(float64(getLastChainExpiry(resp.TLS).Unix()))
558		if httpConfig.FailIfSSL {
559			level.Error(logger).Log("msg", "Final request was over SSL")
560			success = false
561		}
562	} else if httpConfig.FailIfNotSSL {
563		level.Error(logger).Log("msg", "Final request was not over SSL")
564		success = false
565	}
566
567	statusCodeGauge.Set(float64(resp.StatusCode))
568	contentLengthGauge.Set(float64(resp.ContentLength))
569	bodyUncompressedLengthGauge.Set(float64(respBodyBytes))
570	redirectsGauge.Set(float64(redirects))
571	return
572}
573