1package main
2
3import (
4	"flag"
5	"fmt"
6	"log"
7	"net/http"
8	"os"
9	"strings"
10	"sync"
11	"time"
12
13	"github.com/miekg/dns"
14	"github.com/naoina/toml"
15	"github.com/prometheus/client_golang/prometheus"
16	"github.com/prometheus/client_golang/prometheus/promhttp"
17)
18
// Command-line flags, registered in this order at package initialisation.
var (
	// addr is the listen address for the /metrics HTTP endpoint.
	addr = flag.String("listen-address", ":9204", "Prometheus metrics port")
	// conf is the path of the TOML file listing the records to check.
	conf = flag.String("config", "/usr/local/etc/prometheus-dnssec-checks", "Configuration file")
	// resolvers is a comma-separated list of resolver host:port addresses.
	resolvers = flag.String("resolvers", "8.8.8.8:53,1.1.1.1:53", "Resolvers to use (comma separated)")
	// timeout bounds each DNS network operation.
	timeout = flag.Duration("timeout", 10*time.Second, "Timeout for network operations")
)
23
// Records identifies one DNS record to monitor. Values are decoded from the
// configuration file, so the field names must stay as they are.
type Records struct {
	// Zone is the zone the record lives in; hostname appends the final dot.
	Zone string
	// Record is the label within Zone, or "@" for the zone apex.
	Record string
	// Type is a record type mnemonic (e.g. "A"), looked up in dns.StringToType.
	Type string
}
29
// Logger is the minimal logging surface the exporter needs; *log.Logger
// satisfies it.
type Logger interface {
	// Printf logs a formatted message.
	Printf(format string, v ...interface{})
	// Print logs its operands.
	Print(v ...interface{})
}
34
35type Exporter struct {
36	Records []Records
37
38	records  *prometheus.GaugeVec
39	resolves *prometheus.GaugeVec
40
41	resolvers []string
42	dnsClient *dns.Client
43
44	logger Logger
45}
46
47func NewDNSSECExporter(timeout time.Duration, resolvers []string, logger Logger) *Exporter {
48	return &Exporter{
49		records: prometheus.NewGaugeVec(
50			prometheus.GaugeOpts{
51				Namespace: "dnssec",
52				Subsystem: "zone",
53				Name:      "record_days_left",
54				Help:      "Number of days the signature will be valid",
55			},
56			[]string{
57				"zone",
58				"record",
59				"type",
60			},
61		),
62		resolves: prometheus.NewGaugeVec(
63			prometheus.GaugeOpts{
64				Namespace: "dnssec",
65				Subsystem: "zone",
66				Name:      "record_resolves",
67				Help:      "Does the record resolve using the specified DNSSEC enabled resolvers",
68			},
69			[]string{
70				"resolver",
71				"zone",
72				"record",
73				"type",
74			},
75		),
76		dnsClient: &dns.Client{
77			Net:     "tcp",
78			Timeout: timeout,
79		},
80		resolvers: resolvers,
81		logger:    logger,
82	}
83}
84
85func (e *Exporter) Describe(ch chan<- *prometheus.Desc) {
86	e.records.Describe(ch)
87	e.resolves.Describe(ch)
88}
89
90func (e *Exporter) Collect(ch chan<- prometheus.Metric) {
91
92	var wg sync.WaitGroup
93
94	wg.Add(len(e.Records) * (len(e.resolvers) + 1))
95
96	for _, rec := range e.Records {
97
98		rec := rec
99
100		// Check the expiration
101
102		go func() {
103
104			exp := e.expiration(rec.Zone, rec.Record, rec.Type)
105
106			e.records.WithLabelValues(
107				rec.Zone, rec.Record, rec.Type,
108			).Set(float64(time.Until(exp)/time.Hour) / 24)
109
110			wg.Done()
111
112		}()
113
114		// Check the configured resolvers
115
116		for _, resolver := range e.resolvers {
117
118			resolver := resolver
119
120			go func() {
121
122				resolves := e.resolve(rec.Zone, rec.Record, rec.Type, resolver)
123
124				e.resolves.WithLabelValues(
125					resolver, rec.Zone, rec.Record, rec.Type,
126				).Set(map[bool]float64{true: 1}[resolves])
127
128				wg.Done()
129
130			}()
131
132		}
133
134	}
135
136	wg.Wait()
137
138	e.records.Collect(ch)
139	e.resolves.Collect(ch)
140
141}
142
143func (e *Exporter) expiration(zone, record, recordType string) (exp time.Time) {
144
145	msg := &dns.Msg{}
146	msg.SetQuestion(hostname(zone, record), dns.TypeRRSIG)
147
148	response, _, err := e.dnsClient.Exchange(msg, e.resolvers[0])
149	if err != nil {
150		e.logger.Printf("while looking up RRSIG for %v: %v", hostname(zone, record), err)
151		return
152	}
153
154	var sig *dns.RRSIG
155
156	for _, rr := range response.Answer {
157
158		if rrsig, ok := rr.(*dns.RRSIG); ok &&
159			rrsig.TypeCovered == dns.StringToType[recordType] {
160
161			sig = rrsig
162			break
163
164		}
165	}
166
167	if sig == nil {
168		e.logger.Printf("didn't find RRSIG for %v covering type %v matching a tag of a DNSKEY", hostname(zone, record), recordType)
169		return
170	}
171
172	exp = time.Unix(int64(sig.Expiration), 0)
173	if exp.IsZero() {
174		e.logger.Printf("zero exp for RRSIG for %v covering type %v", hostname(zone, record), recordType)
175		return
176	}
177
178	return
179
180}
181
182func (e *Exporter) resolve(zone, record, recordType, resolver string) (resolves bool) {
183
184	msg := &dns.Msg{}
185	msg.SetQuestion(hostname(zone, record), dns.StringToType[recordType])
186	msg.SetEdns0(4096, true)
187
188	response, _, err := e.dnsClient.Exchange(msg, resolver)
189	if err != nil {
190		e.logger.Printf("while resolving for %v: %v", hostname(zone, record), err)
191		return
192	}
193
194	return response.AuthenticatedData &&
195		!response.CheckingDisabled &&
196		response.Rcode == dns.RcodeSuccess
197
198}
199
// hostname returns the dot-terminated fully qualified name of record within
// zone. A record of "@" or "" denotes the zone apex. A trailing dot on zone
// is tolerated, so "example.com" and "example.com." both work in the
// configuration instead of the latter producing "example.com..".
func hostname(zone, record string) string {
	zone = strings.TrimSuffix(zone, ".")

	if record == "@" || record == "" {
		return fmt.Sprintf("%s.", zone)
	}

	return fmt.Sprintf("%s.%s.", record, zone)
}
209
210func main() {
211
212	flag.Parse()
213
214	f, err := os.Open(*conf)
215	if err != nil {
216		log.Fatalf("couldn't open configuration file: %v", err)
217	}
218
219	logger := log.New(os.Stderr, "", log.LstdFlags)
220
221	r := strings.Split(*resolvers, ",")
222
223	exporter := NewDNSSECExporter(*timeout, r, logger)
224
225	if err := toml.NewDecoder(f).Decode(exporter); err != nil {
226		log.Fatalf("couldn't parse configuration file: %v", err)
227	}
228
229	prometheus.MustRegister(exporter)
230
231	http.Handle("/metrics", promhttp.Handler())
232
233	log.Fatal(http.ListenAndServe(*addr, nil))
234
235}
236