1// Copyright 2016 The Prometheus Authors
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14package prober
15
16import (
17	"context"
18	"net"
19	"regexp"
20	"time"
21
22	"github.com/go-kit/kit/log"
23	"github.com/go-kit/kit/log/level"
24	"github.com/miekg/dns"
25	"github.com/prometheus/client_golang/prometheus"
26	pconfig "github.com/prometheus/common/config"
27
28	"github.com/prometheus/blackbox_exporter/config"
29)
30
31// validRRs checks a slice of RRs received from the server against a DNSRRValidator.
32func validRRs(rrs *[]dns.RR, v *config.DNSRRValidator, logger log.Logger) bool {
33	var anyMatch bool = false
34	var allMatch bool = true
35	// Fail the probe if there are no RRs of a given type, but a regexp match is required
36	// (i.e. FailIfNotMatchesRegexp or FailIfNoneMatchesRegexp is set).
37	if len(*rrs) == 0 && len(v.FailIfNotMatchesRegexp) > 0 {
38		level.Error(logger).Log("msg", "fail_if_not_matches_regexp specified but no RRs returned")
39		return false
40	}
41	if len(*rrs) == 0 && len(v.FailIfNoneMatchesRegexp) > 0 {
42		level.Error(logger).Log("msg", "fail_if_none_matches_regexp specified but no RRs returned")
43		return false
44	}
45	for _, rr := range *rrs {
46		level.Info(logger).Log("msg", "Validating RR", "rr", rr)
47		for _, re := range v.FailIfMatchesRegexp {
48			match, err := regexp.MatchString(re, rr.String())
49			if err != nil {
50				level.Error(logger).Log("msg", "Error matching regexp", "regexp", re, "err", err)
51				return false
52			}
53			if match {
54				level.Error(logger).Log("msg", "At least one RR matched regexp", "regexp", re, "rr", rr)
55				return false
56			}
57		}
58		for _, re := range v.FailIfAllMatchRegexp {
59			match, err := regexp.MatchString(re, rr.String())
60			if err != nil {
61				level.Error(logger).Log("msg", "Error matching regexp", "regexp", re, "err", err)
62				return false
63			}
64			if !match {
65				allMatch = false
66			}
67		}
68		for _, re := range v.FailIfNotMatchesRegexp {
69			match, err := regexp.MatchString(re, rr.String())
70			if err != nil {
71				level.Error(logger).Log("msg", "Error matching regexp", "regexp", re, "err", err)
72				return false
73			}
74			if !match {
75				level.Error(logger).Log("msg", "At least one RR did not match regexp", "regexp", re, "rr", rr)
76				return false
77			}
78		}
79		for _, re := range v.FailIfNoneMatchesRegexp {
80			match, err := regexp.MatchString(re, rr.String())
81			if err != nil {
82				level.Error(logger).Log("msg", "Error matching regexp", "regexp", re, "err", err)
83				return false
84			}
85			if match {
86				anyMatch = true
87			}
88		}
89	}
90	if len(v.FailIfAllMatchRegexp) > 0 && !allMatch {
91		level.Error(logger).Log("msg", "Not all RRs matched regexp")
92		return false
93	}
94	if len(v.FailIfNoneMatchesRegexp) > 0 && !anyMatch {
95		level.Error(logger).Log("msg", "None of the RRs did matched any regexp")
96		return false
97	}
98	return true
99}
100
101// validRcode checks rcode in the response against a list of valid rcodes.
102func validRcode(rcode int, valid []string, logger log.Logger) bool {
103	var validRcodes []int
104	// If no list of valid rcodes is specified, only NOERROR is considered valid.
105	if valid == nil {
106		validRcodes = append(validRcodes, dns.StringToRcode["NOERROR"])
107	} else {
108		for _, rcode := range valid {
109			rc, ok := dns.StringToRcode[rcode]
110			if !ok {
111				level.Error(logger).Log("msg", "Invalid rcode", "rcode", rcode, "known_rcode", dns.RcodeToString)
112				return false
113			}
114			validRcodes = append(validRcodes, rc)
115		}
116	}
117	for _, rc := range validRcodes {
118		if rcode == rc {
119			level.Info(logger).Log("msg", "Rcode is valid", "rcode", rcode, "string_rcode", dns.RcodeToString[rcode])
120			return true
121		}
122	}
123	level.Error(logger).Log("msg", "Rcode is not one of the valid rcodes", "rcode", rcode, "string_rcode", dns.RcodeToString[rcode], "valid_rcodes", validRcodes)
124	return false
125}
126
127func ProbeDNS(ctx context.Context, target string, module config.Module, registry *prometheus.Registry, logger log.Logger) bool {
128	var dialProtocol string
129	probeDNSAnswerRRSGauge := prometheus.NewGauge(prometheus.GaugeOpts{
130		Name: "probe_dns_answer_rrs",
131		Help: "Returns number of entries in the answer resource record list",
132	})
133	probeDNSAuthorityRRSGauge := prometheus.NewGauge(prometheus.GaugeOpts{
134		Name: "probe_dns_authority_rrs",
135		Help: "Returns number of entries in the authority resource record list",
136	})
137	probeDNSAdditionalRRSGauge := prometheus.NewGauge(prometheus.GaugeOpts{
138		Name: "probe_dns_additional_rrs",
139		Help: "Returns number of entries in the additional resource record list",
140	})
141	registry.MustRegister(probeDNSAnswerRRSGauge)
142	registry.MustRegister(probeDNSAuthorityRRSGauge)
143	registry.MustRegister(probeDNSAdditionalRRSGauge)
144
145	qc := uint16(dns.ClassINET)
146	if module.DNS.QueryClass != "" {
147		var ok bool
148		qc, ok = dns.StringToClass[module.DNS.QueryClass]
149		if !ok {
150			level.Error(logger).Log("msg", "Invalid query class", "Class seen", module.DNS.QueryClass, "Existing classes", dns.ClassToString)
151			return false
152		}
153	}
154
155	qt := dns.TypeANY
156	if module.DNS.QueryType != "" {
157		var ok bool
158		qt, ok = dns.StringToType[module.DNS.QueryType]
159		if !ok {
160			level.Error(logger).Log("msg", "Invalid query type", "Type seen", module.DNS.QueryType, "Existing types", dns.TypeToString)
161			return false
162		}
163	}
164	var probeDNSSOAGauge prometheus.Gauge
165
166	var ip *net.IPAddr
167	if module.DNS.TransportProtocol == "" {
168		module.DNS.TransportProtocol = "udp"
169	}
170	if !(module.DNS.TransportProtocol == "udp" || module.DNS.TransportProtocol == "tcp") {
171		level.Error(logger).Log("msg", "Configuration error: Expected transport protocol udp or tcp", "protocol", module.DNS.TransportProtocol)
172		return false
173	}
174
175	targetAddr, port, err := net.SplitHostPort(target)
176	if err != nil {
177		// Target only contains host so fallback to default port and set targetAddr as target.
178		port = "53"
179		targetAddr = target
180	}
181	ip, _, err = chooseProtocol(ctx, module.DNS.IPProtocol, module.DNS.IPProtocolFallback, targetAddr, registry, logger)
182	if err != nil {
183		level.Error(logger).Log("msg", "Error resolving address", "err", err)
184		return false
185	}
186	targetIP := net.JoinHostPort(ip.String(), port)
187
188	if ip.IP.To4() == nil {
189		dialProtocol = module.DNS.TransportProtocol + "6"
190	} else {
191		dialProtocol = module.DNS.TransportProtocol + "4"
192	}
193
194	if module.DNS.DNSOverTLS {
195		if module.DNS.TransportProtocol == "tcp" {
196			dialProtocol += "-tls"
197		} else {
198			level.Error(logger).Log("msg", "Configuration error: Expected transport protocol tcp for DoT", "protocol", module.DNS.TransportProtocol)
199			return false
200		}
201	}
202
203	client := new(dns.Client)
204	client.Net = dialProtocol
205
206	if module.DNS.DNSOverTLS {
207		tlsConfig, err := pconfig.NewTLSConfig(&module.DNS.TLSConfig)
208		if err != nil {
209			level.Error(logger).Log("msg", "Failed to create TLS configuration", "err", err)
210			return false
211		}
212		if tlsConfig.ServerName == "" {
213			// Use target-hostname as default for TLS-servername.
214			tlsConfig.ServerName = targetAddr
215		}
216
217		client.TLSConfig = tlsConfig
218	}
219
220	// Use configured SourceIPAddress.
221	if len(module.DNS.SourceIPAddress) > 0 {
222		srcIP := net.ParseIP(module.DNS.SourceIPAddress)
223		if srcIP == nil {
224			level.Error(logger).Log("msg", "Error parsing source ip address", "srcIP", module.DNS.SourceIPAddress)
225			return false
226		}
227		level.Info(logger).Log("msg", "Using local address", "srcIP", srcIP)
228		client.Dialer = &net.Dialer{}
229		if module.DNS.TransportProtocol == "tcp" {
230			client.Dialer.LocalAddr = &net.TCPAddr{IP: srcIP}
231		} else {
232			client.Dialer.LocalAddr = &net.UDPAddr{IP: srcIP}
233		}
234	}
235
236	msg := new(dns.Msg)
237	msg.Id = dns.Id()
238	msg.RecursionDesired = true
239	msg.Question = make([]dns.Question, 1)
240	msg.Question[0] = dns.Question{dns.Fqdn(module.DNS.QueryName), qt, qc}
241
242	level.Info(logger).Log("msg", "Making DNS query", "target", targetIP, "dial_protocol", dialProtocol, "query", module.DNS.QueryName, "type", qt, "class", qc)
243	timeoutDeadline, _ := ctx.Deadline()
244	client.Timeout = time.Until(timeoutDeadline)
245	response, _, err := client.Exchange(msg, targetIP)
246	if err != nil {
247		level.Error(logger).Log("msg", "Error while sending a DNS query", "err", err)
248		return false
249	}
250	level.Info(logger).Log("msg", "Got response", "response", response)
251
252	probeDNSAnswerRRSGauge.Set(float64(len(response.Answer)))
253	probeDNSAuthorityRRSGauge.Set(float64(len(response.Ns)))
254	probeDNSAdditionalRRSGauge.Set(float64(len(response.Extra)))
255
256	if qt == dns.TypeSOA {
257		probeDNSSOAGauge = prometheus.NewGauge(prometheus.GaugeOpts{
258			Name: "probe_dns_serial",
259			Help: "Returns the serial number of the zone",
260		})
261		registry.MustRegister(probeDNSSOAGauge)
262
263		for _, a := range response.Answer {
264			if soa, ok := a.(*dns.SOA); ok {
265				probeDNSSOAGauge.Set(float64(soa.Serial))
266			}
267		}
268	}
269
270	if !validRcode(response.Rcode, module.DNS.ValidRcodes, logger) {
271		return false
272	}
273	level.Info(logger).Log("msg", "Validating Answer RRs")
274	if !validRRs(&response.Answer, &module.DNS.ValidateAnswer, logger) {
275		level.Error(logger).Log("msg", "Answer RRs validation failed")
276		return false
277	}
278	level.Info(logger).Log("msg", "Validating Authority RRs")
279	if !validRRs(&response.Ns, &module.DNS.ValidateAuthority, logger) {
280		level.Error(logger).Log("msg", "Authority RRs validation failed")
281		return false
282	}
283	level.Info(logger).Log("msg", "Validating Additional RRs")
284	if !validRRs(&response.Extra, &module.DNS.ValidateAdditional, logger) {
285		level.Error(logger).Log("msg", "Additional RRs validation failed")
286		return false
287	}
288	return true
289}
290