1package dns01 2 3import ( 4 "errors" 5 "fmt" 6 "net" 7 "strings" 8 "sync" 9 "time" 10 11 "github.com/miekg/dns" 12) 13 14const defaultResolvConf = "/etc/resolv.conf" 15 16// dnsTimeout is used to override the default DNS timeout of 10 seconds. 17var dnsTimeout = 10 * time.Second 18 19var ( 20 fqdnSoaCache = map[string]*soaCacheEntry{} 21 muFqdnSoaCache sync.Mutex 22) 23 24var defaultNameservers = []string{ 25 "google-public-dns-a.google.com:53", 26 "google-public-dns-b.google.com:53", 27} 28 29// recursiveNameservers are used to pre-check DNS propagation. 30var recursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers) 31 32// soaCacheEntry holds a cached SOA record (only selected fields). 33type soaCacheEntry struct { 34 zone string // zone apex (a domain name) 35 primaryNs string // primary nameserver for the zone apex 36 expires time.Time // time when this cache entry should be evicted 37} 38 39func newSoaCacheEntry(soa *dns.SOA) *soaCacheEntry { 40 return &soaCacheEntry{ 41 zone: soa.Hdr.Name, 42 primaryNs: soa.Ns, 43 expires: time.Now().Add(time.Duration(soa.Refresh) * time.Second), 44 } 45} 46 47// isExpired checks whether a cache entry should be considered expired. 48func (cache *soaCacheEntry) isExpired() bool { 49 return time.Now().After(cache.expires) 50} 51 52// ClearFqdnCache clears the cache of fqdn to zone mappings. Primarily used in testing. 53func ClearFqdnCache() { 54 muFqdnSoaCache.Lock() 55 fqdnSoaCache = map[string]*soaCacheEntry{} 56 muFqdnSoaCache.Unlock() 57} 58 59func AddDNSTimeout(timeout time.Duration) ChallengeOption { 60 return func(_ *Challenge) error { 61 dnsTimeout = timeout 62 return nil 63 } 64} 65 66func AddRecursiveNameservers(nameservers []string) ChallengeOption { 67 return func(_ *Challenge) error { 68 recursiveNameservers = ParseNameservers(nameservers) 69 return nil 70 } 71} 72 73// getNameservers attempts to get systems nameservers before falling back to the defaults. 74func getNameservers(path string, defaults []string) []string { 75 config, err := dns.ClientConfigFromFile(path) 76 if err != nil || len(config.Servers) == 0 { 77 return defaults 78 } 79 80 return ParseNameservers(config.Servers) 81} 82 83func ParseNameservers(servers []string) []string { 84 var resolvers []string 85 for _, resolver := range servers { 86 // ensure all servers have a port number 87 if _, _, err := net.SplitHostPort(resolver); err != nil { 88 resolvers = append(resolvers, net.JoinHostPort(resolver, "53")) 89 } else { 90 resolvers = append(resolvers, resolver) 91 } 92 } 93 return resolvers 94} 95 96// lookupNameservers returns the authoritative nameservers for the given fqdn. 97func lookupNameservers(fqdn string) ([]string, error) { 98 var authoritativeNss []string 99 100 zone, err := FindZoneByFqdn(fqdn) 101 if err != nil { 102 return nil, fmt.Errorf("could not determine the zone: %w", err) 103 } 104 105 r, err := dnsQuery(zone, dns.TypeNS, recursiveNameservers, true) 106 if err != nil { 107 return nil, err 108 } 109 110 for _, rr := range r.Answer { 111 if ns, ok := rr.(*dns.NS); ok { 112 authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns)) 113 } 114 } 115 116 if len(authoritativeNss) > 0 { 117 return authoritativeNss, nil 118 } 119 return nil, errors.New("could not determine authoritative nameservers") 120} 121 122// FindPrimaryNsByFqdn determines the primary nameserver of the zone apex for the given fqdn 123// by recursing up the domain labels until the nameserver returns a SOA record in the answer section. 124func FindPrimaryNsByFqdn(fqdn string) (string, error) { 125 return FindPrimaryNsByFqdnCustom(fqdn, recursiveNameservers) 126} 127 128// FindPrimaryNsByFqdnCustom determines the primary nameserver of the zone apex for the given fqdn 129// by recursing up the domain labels until the nameserver returns a SOA record in the answer section. 130func FindPrimaryNsByFqdnCustom(fqdn string, nameservers []string) (string, error) { 131 soa, err := lookupSoaByFqdn(fqdn, nameservers) 132 if err != nil { 133 return "", err 134 } 135 return soa.primaryNs, nil 136} 137 138// FindZoneByFqdn determines the zone apex for the given fqdn 139// by recursing up the domain labels until the nameserver returns a SOA record in the answer section. 140func FindZoneByFqdn(fqdn string) (string, error) { 141 return FindZoneByFqdnCustom(fqdn, recursiveNameservers) 142} 143 144// FindZoneByFqdnCustom determines the zone apex for the given fqdn 145// by recursing up the domain labels until the nameserver returns a SOA record in the answer section. 146func FindZoneByFqdnCustom(fqdn string, nameservers []string) (string, error) { 147 soa, err := lookupSoaByFqdn(fqdn, nameservers) 148 if err != nil { 149 return "", err 150 } 151 return soa.zone, nil 152} 153 154func lookupSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) { 155 muFqdnSoaCache.Lock() 156 defer muFqdnSoaCache.Unlock() 157 158 // Do we have it cached and is it still fresh? 159 if ent := fqdnSoaCache[fqdn]; ent != nil && !ent.isExpired() { 160 return ent, nil 161 } 162 163 ent, err := fetchSoaByFqdn(fqdn, nameservers) 164 if err != nil { 165 return nil, err 166 } 167 168 fqdnSoaCache[fqdn] = ent 169 return ent, nil 170} 171 172func fetchSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) { 173 var err error 174 var in *dns.Msg 175 176 labelIndexes := dns.Split(fqdn) 177 for _, index := range labelIndexes { 178 domain := fqdn[index:] 179 180 in, err = dnsQuery(domain, dns.TypeSOA, nameservers, true) 181 if err != nil { 182 continue 183 } 184 185 if in == nil { 186 continue 187 } 188 189 switch in.Rcode { 190 case dns.RcodeSuccess: 191 // Check if we got a SOA RR in the answer section 192 if len(in.Answer) == 0 { 193 continue 194 } 195 196 // CNAME records cannot/should not exist at the root of a zone. 197 // So we skip a domain when a CNAME is found. 198 if dnsMsgContainsCNAME(in) { 199 continue 200 } 201 202 for _, ans := range in.Answer { 203 if soa, ok := ans.(*dns.SOA); ok { 204 return newSoaCacheEntry(soa), nil 205 } 206 } 207 case dns.RcodeNameError: 208 // NXDOMAIN 209 default: 210 // Any response code other than NOERROR and NXDOMAIN is treated as error 211 return nil, fmt.Errorf("unexpected response code '%s' for %s", dns.RcodeToString[in.Rcode], domain) 212 } 213 } 214 215 return nil, fmt.Errorf("could not find the start of authority for %s%s", fqdn, formatDNSError(in, err)) 216} 217 218// dnsMsgContainsCNAME checks for a CNAME answer in msg. 219func dnsMsgContainsCNAME(msg *dns.Msg) bool { 220 for _, ans := range msg.Answer { 221 if _, ok := ans.(*dns.CNAME); ok { 222 return true 223 } 224 } 225 return false 226} 227 228func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) { 229 m := createDNSMsg(fqdn, rtype, recursive) 230 231 var in *dns.Msg 232 var err error 233 234 for _, ns := range nameservers { 235 in, err = sendDNSQuery(m, ns) 236 if err == nil && len(in.Answer) > 0 { 237 break 238 } 239 } 240 return in, err 241} 242 243func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg { 244 m := new(dns.Msg) 245 m.SetQuestion(fqdn, rtype) 246 m.SetEdns0(4096, false) 247 248 if !recursive { 249 m.RecursionDesired = false 250 } 251 252 return m 253} 254 255func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) { 256 udp := &dns.Client{Net: "udp", Timeout: dnsTimeout} 257 in, _, err := udp.Exchange(m, ns) 258 259 if in != nil && in.Truncated { 260 tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout} 261 // If the TCP request succeeds, the err will reset to nil 262 in, _, err = tcp.Exchange(m, ns) 263 } 264 265 return in, err 266} 267 268func formatDNSError(msg *dns.Msg, err error) string { 269 var parts []string 270 271 if msg != nil { 272 parts = append(parts, dns.RcodeToString[msg.Rcode]) 273 } 274 275 if err != nil { 276 parts = append(parts, err.Error()) 277 } 278 279 if len(parts) > 0 { 280 return ": " + strings.Join(parts, " ") 281 } 282 283 return "" 284} 285