1package acme
2
3import (
4	"crypto/sha256"
5	"encoding/base64"
6	"errors"
7	"fmt"
8	"log"
9	"net"
10	"strings"
11	"time"
12
13	"github.com/miekg/dns"
14	"golang.org/x/net/publicsuffix"
15)
16
17type preCheckDNSFunc func(fqdn, value string) (bool, error)
18
19var (
20	preCheckDNS preCheckDNSFunc = checkDNSPropagation
21	fqdnToZone                  = map[string]string{}
22)
23
24var RecursiveNameservers = []string{
25	"google-public-dns-a.google.com:53",
26	"google-public-dns-b.google.com:53",
27}
28
29// DNSTimeout is used to override the default DNS timeout of 10 seconds.
30var DNSTimeout = 10 * time.Second
31
// DNS01Record returns a DNS record which will fulfill the `dns-01` challenge.
//
// fqdn is "_acme-challenge.<domain>." (absolute, with a trailing dot), value is
// the base64url encoding, without padding, of the SHA-256 digest of keyAuth,
// and ttl is a fixed 120 seconds.
// NOTE(review): domain is expected without a trailing dot; a dotted domain
// would produce "..". Confirm against callers.
func DNS01Record(domain, keyAuth string) (fqdn string, value string, ttl int) {
	keyAuthSha := sha256.Sum256([]byte(keyAuth))
	// RawURLEncoding is base64url that never emits '=' padding, so no trimming
	// is needed afterwards.
	value = base64.RawURLEncoding.EncodeToString(keyAuthSha[:])
	fqdn = fmt.Sprintf("_acme-challenge.%s.", domain)
	return fqdn, value, 120
}
42
// dnsChallenge implements the dns-01 challenge according to ACME 7.5
//
// Solve publishes the challenge TXT record through provider, waits for the
// record to propagate, and then hands the challenge response to validate.
type dnsChallenge struct {
	// jws supplies the private key used to build the key authorization, and is
	// passed through to validate.
	jws      *jws
	// validate is called once the record is visible, to submit the response.
	validate validateFunc
	// provider creates and removes the TXT record; Solve errors out if it is nil.
	provider ChallengeProvider
}
49
50func (s *dnsChallenge) Solve(chlng challenge, domain string) error {
51	logf("[INFO][%s] acme: Trying to solve DNS-01", domain)
52
53	if s.provider == nil {
54		return errors.New("No DNS Provider configured")
55	}
56
57	// Generate the Key Authorization for the challenge
58	keyAuth, err := getKeyAuthorization(chlng.Token, s.jws.privKey)
59	if err != nil {
60		return err
61	}
62
63	err = s.provider.Present(domain, chlng.Token, keyAuth)
64	if err != nil {
65		return fmt.Errorf("Error presenting token: %s", err)
66	}
67	defer func() {
68		err := s.provider.CleanUp(domain, chlng.Token, keyAuth)
69		if err != nil {
70			log.Printf("Error cleaning up %s: %v ", domain, err)
71		}
72	}()
73
74	fqdn, value, _ := DNS01Record(domain, keyAuth)
75
76	logf("[INFO][%s] Checking DNS record propagation...", domain)
77
78	var timeout, interval time.Duration
79	switch provider := s.provider.(type) {
80	case ChallengeProviderTimeout:
81		timeout, interval = provider.Timeout()
82	default:
83		timeout, interval = 60*time.Second, 2*time.Second
84	}
85
86	err = WaitFor(timeout, interval, func() (bool, error) {
87		return preCheckDNS(fqdn, value)
88	})
89	if err != nil {
90		return err
91	}
92
93	return s.validate(s.jws, domain, chlng.URI, challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth})
94}
95
// checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers.
//
// It first asks the resolvers in RecursiveNameservers for the TXT record at
// fqdn. If that answer contains a CNAME chain starting at fqdn, the chain is
// followed to its final target, because that name's zone is where the record
// actually lives. The authoritative nameservers for the resulting name are then
// looked up, and each is queried directly via checkAuthoritativeNss.
func checkDNSPropagation(fqdn, value string) (bool, error) {
	// Initial attempt to resolve at the recursive NS
	r, err := dnsQuery(fqdn, dns.TypeTXT, RecursiveNameservers, true)
	if err != nil {
		return false, err
	}
	if r.Rcode == dns.RcodeSuccess {
		// A recursive resolver returns the whole alias chain (a -> b -> c), so
		// keep following until no CNAME owns the current name. DNS names compare
		// case-insensitively. Capping the hops at len(r.Answer) guarantees that a
		// CNAME loop cannot spin forever.
		for hops := 0; hops < len(r.Answer); hops++ {
			target := ""
			for _, rr := range r.Answer {
				if cn, ok := rr.(*dns.CNAME); ok && strings.EqualFold(cn.Hdr.Name, fqdn) {
					target = cn.Target
					break
				}
			}
			if target == "" {
				break
			}
			fqdn = target
		}
	}

	authoritativeNss, err := lookupNameservers(fqdn)
	if err != nil {
		return false, err
	}

	return checkAuthoritativeNss(fqdn, value, authoritativeNss)
}
122
123// checkAuthoritativeNss queries each of the given nameservers for the expected TXT record.
124func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, error) {
125	for _, ns := range nameservers {
126		r, err := dnsQuery(fqdn, dns.TypeTXT, []string{net.JoinHostPort(ns, "53")}, false)
127		if err != nil {
128			return false, err
129		}
130
131		if r.Rcode != dns.RcodeSuccess {
132			return false, fmt.Errorf("NS %s returned %s for %s", ns, dns.RcodeToString[r.Rcode], fqdn)
133		}
134
135		var found bool
136		for _, rr := range r.Answer {
137			if txt, ok := rr.(*dns.TXT); ok {
138				if strings.Join(txt.Txt, "") == value {
139					found = true
140					break
141				}
142			}
143		}
144
145		if !found {
146			return false, fmt.Errorf("NS %s did not return the expected TXT record", ns)
147		}
148	}
149
150	return true, nil
151}
152
// dnsQuery will query a nameserver, iterating through the supplied servers as it retries.
// The nameserver should include a port, to facilitate testing where we talk to a mock dns server.
//
// It makes len(nameservers)+1 attempts in total. The attempts start at
// nameservers[0] and wrap around, so a single server is tried twice. An attempt
// whose UDP reply is truncated is repeated over TCP against the same server.
// When recursive is false the RD bit is cleared, which suits queries sent
// straight to authoritative servers.
//
// On success it returns the first error-free response. Otherwise it returns the
// last attempt's response and error. An empty nameservers slice is an error
// rather than a panic.
func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (in *dns.Msg, err error) {
	if len(nameservers) == 0 {
		return nil, errors.New("no nameservers to query")
	}

	m := new(dns.Msg)
	m.SetQuestion(fqdn, rtype)
	m.SetEdns0(4096, false)
	// SetQuestion turns RD on; clear it for non-recursive queries.
	m.RecursionDesired = recursive

	udp := &dns.Client{Net: "udp", Timeout: DNSTimeout}
	tcp := &dns.Client{Net: "tcp", Timeout: DNSTimeout}

	// Will retry the request based on the number of servers (n+1), starting with
	// the first configured server.
	for i := 0; i <= len(nameservers); i++ {
		ns := nameservers[i%len(nameservers)]
		in, _, err = udp.Exchange(m, ns)

		// This miekg/dns version signals truncation with dns.ErrTruncated. Also
		// honour the TC bit on an otherwise successful reply, which is how later
		// releases report it, so a partial answer is never accepted silently.
		if err == dns.ErrTruncated || (err == nil && in != nil && in.Truncated) {
			// If the TCP request succeeds, err resets to nil.
			in, _, err = tcp.Exchange(m, ns)
		}

		if err == nil {
			return in, nil
		}
	}
	return in, err
}
182
// lookupNameservers returns the authoritative nameservers for the given fqdn.
//
// It resolves the zone containing fqdn with FindZoneByFqdn, asks
// RecursiveNameservers for that zone's NS records, and returns their targets
// lower-cased. It is an error if no NS record comes back.
func lookupNameservers(fqdn string) ([]string, error) {
	zone, err := FindZoneByFqdn(fqdn, RecursiveNameservers)
	if err != nil {
		return nil, err
	}

	resp, err := dnsQuery(zone, dns.TypeNS, RecursiveNameservers, true)
	if err != nil {
		return nil, err
	}

	var servers []string
	for _, rr := range resp.Answer {
		record, isNS := rr.(*dns.NS)
		if !isNS {
			continue
		}
		servers = append(servers, strings.ToLower(record.Ns))
	}

	if len(servers) == 0 {
		return nil, errors.New("Could not determine authoritative nameservers")
	}
	return servers, nil
}
208
// FindZoneByFqdn determines the zone of the given fqdn.
//
// Results are memoised in fqdnToZone. On a cache miss it asks nameservers
// (with recursion) for the SOA of fqdn. The zone is the owner name of the first
// SOA found in the answer section, or, for NXDOMAIN and NODATA replies, in the
// authority section. RFC 2308 says authoritative servers MUST put the zone's
// SOA there when reporting NXDOMAIN or NODATA. Any other response code is an
// error, as is a reply with no usable SOA.
func FindZoneByFqdn(fqdn string, nameservers []string) (string, error) {
	if cached, hit := fqdnToZone[fqdn]; hit {
		return cached, nil
	}

	// Query for a (hopefully non-existent) SOA so that the reply carries the
	// enclosing zone's SOA.
	resp, err := dnsQuery(fqdn, dns.TypeSOA, nameservers, true)
	if err != nil {
		return "", err
	}

	switch resp.Rcode {
	case dns.RcodeNameError:
		// NXDOMAIN: the SOA can only be in the authority section.
	case dns.RcodeSuccess:
		for _, rr := range resp.Answer {
			if soa, isSOA := rr.(*dns.SOA); isSOA {
				return checkIfTLD(fqdn, soa)
			}
		}
		// No SOA among the answers means NODATA; look in the authority section.
	default:
		return "", fmt.Errorf("The NS returned %s for %s", dns.RcodeToString[resp.Rcode], fqdn)
	}

	for _, rr := range resp.Ns {
		if soa, isSOA := rr.(*dns.SOA); isSOA {
			return checkIfTLD(fqdn, soa)
		}
	}
	return "", fmt.Errorf("The NS did not return the expected SOA record in the authority section")
}
246
// checkIfTLD validates the zone named by soa and caches it for fqdn.
//
// If the SOA owner name is itself a public suffix (for example a TLD), the
// lookup climbed past any registered domain, meaning the domain did not exist.
// In that case an error is returned and nothing is cached. Otherwise the SOA
// owner name is stored in fqdnToZone under fqdn and returned.
func checkIfTLD(fqdn string, soa *dns.SOA) (string, error) {
	zone := soa.Hdr.Name
	name := UnFqdn(zone)
	// The second result only says whether the suffix is ICANN-managed, which
	// does not matter here. The local is named suffix so it does not shadow the
	// publicsuffix package.
	suffix, _ := publicsuffix.PublicSuffix(name)
	if suffix == name {
		return "", fmt.Errorf("Could not determine zone authoritatively")
	}
	fqdnToZone[fqdn] = zone
	return zone, nil
}
257
258// ClearFqdnCache clears the cache of fqdn to zone mappings. Primarily used in testing.
259func ClearFqdnCache() {
260	fqdnToZone = map[string]string{}
261}
262
// ToFqdn converts the name into a fqdn by appending a trailing dot. Names that
// already end in a dot, and the empty string, are returned unchanged.
func ToFqdn(name string) string {
	if name == "" || strings.HasSuffix(name, ".") {
		return name
	}
	return name + "."
}
271
// UnFqdn converts the fqdn into a name by removing one trailing dot, if present.
// Any other input, including the empty string, is returned unchanged.
func UnFqdn(name string) string {
	return strings.TrimSuffix(name, ".")
}
280