1package dns01
2
3import (
4	"errors"
5	"fmt"
6	"net"
7	"strings"
8	"sync"
9	"time"
10
11	"github.com/miekg/dns"
12)
13
14const defaultResolvConf = "/etc/resolv.conf"
15
16// dnsTimeout is used to override the default DNS timeout of 10 seconds.
17var dnsTimeout = 10 * time.Second
18
19var (
20	fqdnSoaCache   = map[string]*soaCacheEntry{}
21	muFqdnSoaCache sync.Mutex
22)
23
24var defaultNameservers = []string{
25	"google-public-dns-a.google.com:53",
26	"google-public-dns-b.google.com:53",
27}
28
29// recursiveNameservers are used to pre-check DNS propagation.
30var recursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers)
31
32// soaCacheEntry holds a cached SOA record (only selected fields).
33type soaCacheEntry struct {
34	zone      string    // zone apex (a domain name)
35	primaryNs string    // primary nameserver for the zone apex
36	expires   time.Time // time when this cache entry should be evicted
37}
38
39func newSoaCacheEntry(soa *dns.SOA) *soaCacheEntry {
40	return &soaCacheEntry{
41		zone:      soa.Hdr.Name,
42		primaryNs: soa.Ns,
43		expires:   time.Now().Add(time.Duration(soa.Refresh) * time.Second),
44	}
45}
46
47// isExpired checks whether a cache entry should be considered expired.
48func (cache *soaCacheEntry) isExpired() bool {
49	return time.Now().After(cache.expires)
50}
51
52// ClearFqdnCache clears the cache of fqdn to zone mappings. Primarily used in testing.
53func ClearFqdnCache() {
54	muFqdnSoaCache.Lock()
55	fqdnSoaCache = map[string]*soaCacheEntry{}
56	muFqdnSoaCache.Unlock()
57}
58
59func AddDNSTimeout(timeout time.Duration) ChallengeOption {
60	return func(_ *Challenge) error {
61		dnsTimeout = timeout
62		return nil
63	}
64}
65
66func AddRecursiveNameservers(nameservers []string) ChallengeOption {
67	return func(_ *Challenge) error {
68		recursiveNameservers = ParseNameservers(nameservers)
69		return nil
70	}
71}
72
73// getNameservers attempts to get systems nameservers before falling back to the defaults.
74func getNameservers(path string, defaults []string) []string {
75	config, err := dns.ClientConfigFromFile(path)
76	if err != nil || len(config.Servers) == 0 {
77		return defaults
78	}
79
80	return ParseNameservers(config.Servers)
81}
82
// ParseNameservers returns a copy of servers in which every entry carries a
// port number.
//
// Entries that net.SplitHostPort accepts are kept verbatim. Any other entry is
// treated as a bare host and joined with port 53. A bare IPv6 literal may be
// given with or without its square brackets ("[2001:db8::1]" or
// "2001:db8::1"); both forms yield "[2001:db8::1]:53".
//
// A nil or empty servers slice yields a nil slice.
func ParseNameservers(servers []string) []string {
	var resolvers []string

	for _, resolver := range servers {
		if _, _, err := net.SplitHostPort(resolver); err == nil {
			// Already in host:port form.
			resolvers = append(resolvers, resolver)
			continue
		}

		// net.JoinHostPort adds brackets itself whenever the host contains a
		// colon, so strip an existing pair first to avoid "[[::1]]:53".
		host := resolver
		if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
			host = host[1 : len(host)-1]
		}

		resolvers = append(resolvers, net.JoinHostPort(host, "53"))
	}

	return resolvers
}
95
96// lookupNameservers returns the authoritative nameservers for the given fqdn.
97func lookupNameservers(fqdn string) ([]string, error) {
98	var authoritativeNss []string
99
100	zone, err := FindZoneByFqdn(fqdn)
101	if err != nil {
102		return nil, fmt.Errorf("could not determine the zone: %w", err)
103	}
104
105	r, err := dnsQuery(zone, dns.TypeNS, recursiveNameservers, true)
106	if err != nil {
107		return nil, err
108	}
109
110	for _, rr := range r.Answer {
111		if ns, ok := rr.(*dns.NS); ok {
112			authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns))
113		}
114	}
115
116	if len(authoritativeNss) > 0 {
117		return authoritativeNss, nil
118	}
119	return nil, errors.New("could not determine authoritative nameservers")
120}
121
122// FindPrimaryNsByFqdn determines the primary nameserver of the zone apex for the given fqdn
123// by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
124func FindPrimaryNsByFqdn(fqdn string) (string, error) {
125	return FindPrimaryNsByFqdnCustom(fqdn, recursiveNameservers)
126}
127
128// FindPrimaryNsByFqdnCustom determines the primary nameserver of the zone apex for the given fqdn
129// by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
130func FindPrimaryNsByFqdnCustom(fqdn string, nameservers []string) (string, error) {
131	soa, err := lookupSoaByFqdn(fqdn, nameservers)
132	if err != nil {
133		return "", err
134	}
135	return soa.primaryNs, nil
136}
137
138// FindZoneByFqdn determines the zone apex for the given fqdn
139// by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
140func FindZoneByFqdn(fqdn string) (string, error) {
141	return FindZoneByFqdnCustom(fqdn, recursiveNameservers)
142}
143
144// FindZoneByFqdnCustom determines the zone apex for the given fqdn
145// by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
146func FindZoneByFqdnCustom(fqdn string, nameservers []string) (string, error) {
147	soa, err := lookupSoaByFqdn(fqdn, nameservers)
148	if err != nil {
149		return "", err
150	}
151	return soa.zone, nil
152}
153
154func lookupSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) {
155	muFqdnSoaCache.Lock()
156	defer muFqdnSoaCache.Unlock()
157
158	// Do we have it cached and is it still fresh?
159	if ent := fqdnSoaCache[fqdn]; ent != nil && !ent.isExpired() {
160		return ent, nil
161	}
162
163	ent, err := fetchSoaByFqdn(fqdn, nameservers)
164	if err != nil {
165		return nil, err
166	}
167
168	fqdnSoaCache[fqdn] = ent
169	return ent, nil
170}
171
172func fetchSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) {
173	var err error
174	var in *dns.Msg
175
176	labelIndexes := dns.Split(fqdn)
177	for _, index := range labelIndexes {
178		domain := fqdn[index:]
179
180		in, err = dnsQuery(domain, dns.TypeSOA, nameservers, true)
181		if err != nil {
182			continue
183		}
184
185		if in == nil {
186			continue
187		}
188
189		switch in.Rcode {
190		case dns.RcodeSuccess:
191			// Check if we got a SOA RR in the answer section
192			if len(in.Answer) == 0 {
193				continue
194			}
195
196			// CNAME records cannot/should not exist at the root of a zone.
197			// So we skip a domain when a CNAME is found.
198			if dnsMsgContainsCNAME(in) {
199				continue
200			}
201
202			for _, ans := range in.Answer {
203				if soa, ok := ans.(*dns.SOA); ok {
204					return newSoaCacheEntry(soa), nil
205				}
206			}
207		case dns.RcodeNameError:
208			// NXDOMAIN
209		default:
210			// Any response code other than NOERROR and NXDOMAIN is treated as error
211			return nil, fmt.Errorf("unexpected response code '%s' for %s", dns.RcodeToString[in.Rcode], domain)
212		}
213	}
214
215	return nil, fmt.Errorf("could not find the start of authority for %s%s", fqdn, formatDNSError(in, err))
216}
217
218// dnsMsgContainsCNAME checks for a CNAME answer in msg.
219func dnsMsgContainsCNAME(msg *dns.Msg) bool {
220	for _, ans := range msg.Answer {
221		if _, ok := ans.(*dns.CNAME); ok {
222			return true
223		}
224	}
225	return false
226}
227
228func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) {
229	m := createDNSMsg(fqdn, rtype, recursive)
230
231	var in *dns.Msg
232	var err error
233
234	for _, ns := range nameservers {
235		in, err = sendDNSQuery(m, ns)
236		if err == nil && len(in.Answer) > 0 {
237			break
238		}
239	}
240	return in, err
241}
242
243func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
244	m := new(dns.Msg)
245	m.SetQuestion(fqdn, rtype)
246	m.SetEdns0(4096, false)
247
248	if !recursive {
249		m.RecursionDesired = false
250	}
251
252	return m
253}
254
255func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) {
256	udp := &dns.Client{Net: "udp", Timeout: dnsTimeout}
257	in, _, err := udp.Exchange(m, ns)
258
259	if in != nil && in.Truncated {
260		tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
261		// If the TCP request succeeds, the err will reset to nil
262		in, _, err = tcp.Exchange(m, ns)
263	}
264
265	return in, err
266}
267
268func formatDNSError(msg *dns.Msg, err error) string {
269	var parts []string
270
271	if msg != nil {
272		parts = append(parts, dns.RcodeToString[msg.Rcode])
273	}
274
275	if err != nil {
276		parts = append(parts, err.Error())
277	}
278
279	if len(parts) > 0 {
280		return ": " + strings.Join(parts, " ")
281	}
282
283	return ""
284}
285