1package checkdomain
2
3import (
4	"bytes"
5	"encoding/json"
6	"errors"
7	"fmt"
8	"io"
9	"net/http"
10	"strconv"
11	"strings"
12)
13
// Nameservers a domain must be delegated to; records can only be updated via
// the API when both are in use (see checkNameservers).
const (
	ns1 = "ns.checkdomain.de"
	ns2 = "ns2.checkdomain.de"
)

const (
	// domainNotFound is the ID returned together with an error when no domain
	// ID could be determined.
	domainNotFound = -1

	// maxLimit is the largest page size ("limit" query parameter) the
	// checkdomain API allows.
	maxLimit = 100

	// maxInt is the largest value representable by int on this platform.
	maxInt = int(^uint(0) >> 1)
)
26
// Some fields have been omitted from the structs below
// because they are not required for this application.

// DomainListingResponse is one page of the domain listing (GET /v1/domains).
type DomainListingResponse struct {
	Page     int                `json:"page"`
	Limit    int                `json:"limit"`
	Pages    int                `json:"pages"`
	Total    int                `json:"total"`
	Embedded EmbeddedDomainList `json:"_embedded"`
}

// EmbeddedDomainList holds the domains contained in a DomainListingResponse page.
type EmbeddedDomainList struct {
	Domains []*Domain `json:"domains"`
}

// Domain is a registered domain as it appears in the domain listing.
type Domain struct {
	ID   int    `json:"id"`
	Name string `json:"name"`
}

// DomainResponse describes a single domain (GET /v1/domains/{id}).
type DomainResponse struct {
	ID      int    `json:"id"`
	Name    string `json:"name"`
	Created string `json:"created"`
	// The API spells this key "payed_up".
	PaidUp string `json:"payed_up"`
	Active bool   `json:"active"`
}

// NameserverResponse is the nameserver configuration of a domain
// (GET /v1/domains/{id}/nameservers).
type NameserverResponse struct {
	General     NameserverGeneral `json:"general"`
	Nameservers []*Nameserver     `json:"nameservers"`
	SOA         NameserverSOA     `json:"soa"`
}

// NameserverGeneral holds the "general" section of a NameserverResponse.
type NameserverGeneral struct {
	IPv4       string `json:"ip_v4"`
	IPv6       string `json:"ip_v6"`
	IncludeWWW bool   `json:"include_www"`
}

// NameserverSOA holds the SOA settings of a NameserverResponse; its TTL is
// used as the fallback TTL for records that report none.
type NameserverSOA struct {
	Mail    string `json:"mail"`
	Refresh int    `json:"refresh"`
	Retry   int    `json:"retry"`
	Expiry  int    `json:"expiry"`
	TTL     int    `json:"ttl"`
}

// Nameserver is one nameserver entry of a NameserverResponse.
type Nameserver struct {
	Name string `json:"name"`
}

// RecordListingResponse is one page of the record listing
// (GET /v1/domains/{id}/nameservers/records).
type RecordListingResponse struct {
	Page     int                `json:"page"`
	Limit    int                `json:"limit"`
	Pages    int                `json:"pages"`
	Total    int                `json:"total"`
	Embedded EmbeddedRecordList `json:"_embedded"`
}

// EmbeddedRecordList holds the records contained in a RecordListingResponse page.
type EmbeddedRecordList struct {
	Records []*Record `json:"records"`
}

// Record is a single DNS record. It is decoded from record listings and also
// encoded as the request body when creating or replacing records, so the field
// order below defines the JSON key order that is sent.
type Record struct {
	Name     string `json:"name"`
	Value    string `json:"value"`
	TTL      int    `json:"ttl"`
	Priority int    `json:"priority"`
	Type     string `json:"type"`
}
100
101func (d *DNSProvider) getDomainIDByName(name string) (int, error) {
102	// Load from cache if exists
103	d.domainIDMu.Lock()
104	id, ok := d.domainIDMapping[name]
105	d.domainIDMu.Unlock()
106	if ok {
107		return id, nil
108	}
109
110	// Find out by querying API
111	domains, err := d.listDomains()
112	if err != nil {
113		return domainNotFound, err
114	}
115
116	// Linear search over all registered domains
117	for _, domain := range domains {
118		if domain.Name == name || strings.HasSuffix(name, "."+domain.Name) {
119			d.domainIDMu.Lock()
120			d.domainIDMapping[name] = domain.ID
121			d.domainIDMu.Unlock()
122
123			return domain.ID, nil
124		}
125	}
126
127	return domainNotFound, errors.New("domain not found")
128}
129
130func (d *DNSProvider) listDomains() ([]*Domain, error) {
131	req, err := d.makeRequest(http.MethodGet, "/v1/domains", http.NoBody)
132	if err != nil {
133		return nil, fmt.Errorf("failed to make request: %w", err)
134	}
135
136	// Checkdomain also provides a query param 'query' which allows filtering domains for a string.
137	// But that functionality is kinda broken,
138	// so we scan through the whole list of registered domains to later find the one that is of interest to us.
139	q := req.URL.Query()
140	q.Set("limit", strconv.Itoa(maxLimit))
141
142	currentPage := 1
143	totalPages := maxInt
144
145	var domainList []*Domain
146	for currentPage <= totalPages {
147		q.Set("page", strconv.Itoa(currentPage))
148		req.URL.RawQuery = q.Encode()
149
150		var res DomainListingResponse
151		if err := d.sendRequest(req, &res); err != nil {
152			return nil, fmt.Errorf("failed to send domain listing request: %w", err)
153		}
154
155		// This is the first response,
156		// so we update totalPages and allocate the slice memory.
157		if totalPages == maxInt {
158			totalPages = res.Pages
159			domainList = make([]*Domain, 0, res.Total)
160		}
161
162		domainList = append(domainList, res.Embedded.Domains...)
163		currentPage++
164	}
165
166	return domainList, nil
167}
168
169func (d *DNSProvider) getNameserverInfo(domainID int) (*NameserverResponse, error) {
170	req, err := d.makeRequest(http.MethodGet, fmt.Sprintf("/v1/domains/%d/nameservers", domainID), http.NoBody)
171	if err != nil {
172		return nil, err
173	}
174
175	res := &NameserverResponse{}
176	if err := d.sendRequest(req, res); err != nil {
177		return nil, err
178	}
179
180	return res, nil
181}
182
183func (d *DNSProvider) checkNameservers(domainID int) error {
184	info, err := d.getNameserverInfo(domainID)
185	if err != nil {
186		return err
187	}
188
189	var found1, found2 bool
190	for _, item := range info.Nameservers {
191		switch item.Name {
192		case ns1:
193			found1 = true
194		case ns2:
195			found2 = true
196		}
197	}
198
199	if !found1 || !found2 {
200		return errors.New("not using checkdomain nameservers, can not update records")
201	}
202
203	return nil
204}
205
206func (d *DNSProvider) createRecord(domainID int, record *Record) error {
207	bs, err := json.Marshal(record)
208	if err != nil {
209		return fmt.Errorf("encoding record failed: %w", err)
210	}
211
212	req, err := d.makeRequest(http.MethodPost, fmt.Sprintf("/v1/domains/%d/nameservers/records", domainID), bytes.NewReader(bs))
213	if err != nil {
214		return err
215	}
216
217	return d.sendRequest(req, nil)
218}
219
220// Checkdomain doesn't seem provide a way to delete records but one can replace all records at once.
221// The current solution is to fetch all records and then use that list minus the record deleted as the new record list.
222// TODO: Simplify this function once Checkdomain do provide the functionality.
223func (d *DNSProvider) deleteTXTRecord(domainID int, recordName, recordValue string) error {
224	domainInfo, err := d.getDomainInfo(domainID)
225	if err != nil {
226		return err
227	}
228
229	nsInfo, err := d.getNameserverInfo(domainID)
230	if err != nil {
231		return err
232	}
233
234	allRecords, err := d.listRecords(domainID, "")
235	if err != nil {
236		return err
237	}
238
239	recordName = strings.TrimSuffix(recordName, "."+domainInfo.Name+".")
240
241	var recordsToKeep []*Record
242
243	// Find and delete matching records
244	for _, record := range allRecords {
245		if skipRecord(recordName, recordValue, record, nsInfo) {
246			continue
247		}
248
249		// Checkdomain API can return records without any TTL set (indicated by the value of 0).
250		// The API Call to replace the records would fail if we wouldn't specify a value.
251		// Thus, we use the default TTL queried beforehand
252		if record.TTL == 0 {
253			record.TTL = nsInfo.SOA.TTL
254		}
255
256		recordsToKeep = append(recordsToKeep, record)
257	}
258
259	return d.replaceRecords(domainID, recordsToKeep)
260}
261
262func (d *DNSProvider) getDomainInfo(domainID int) (*DomainResponse, error) {
263	req, err := d.makeRequest(http.MethodGet, fmt.Sprintf("/v1/domains/%d", domainID), http.NoBody)
264	if err != nil {
265		return nil, err
266	}
267
268	var res DomainResponse
269	err = d.sendRequest(req, &res)
270	if err != nil {
271		return nil, err
272	}
273
274	return &res, nil
275}
276
277func (d *DNSProvider) listRecords(domainID int, recordType string) ([]*Record, error) {
278	req, err := d.makeRequest(http.MethodGet, fmt.Sprintf("/v1/domains/%d/nameservers/records", domainID), http.NoBody)
279	if err != nil {
280		return nil, fmt.Errorf("failed to make request: %w", err)
281	}
282
283	q := req.URL.Query()
284	q.Set("limit", strconv.Itoa(maxLimit))
285	if recordType != "" {
286		q.Set("type", recordType)
287	}
288
289	currentPage := 1
290	totalPages := maxInt
291
292	var recordList []*Record
293	for currentPage <= totalPages {
294		q.Set("page", strconv.Itoa(currentPage))
295		req.URL.RawQuery = q.Encode()
296
297		var res RecordListingResponse
298		if err := d.sendRequest(req, &res); err != nil {
299			return nil, fmt.Errorf("failed to send record listing request: %w", err)
300		}
301
302		// This is the first response, so we update totalPages and allocate the slice memory.
303		if totalPages == maxInt {
304			totalPages = res.Pages
305			recordList = make([]*Record, 0, res.Total)
306		}
307
308		recordList = append(recordList, res.Embedded.Records...)
309		currentPage++
310	}
311
312	return recordList, nil
313}
314
315func (d *DNSProvider) replaceRecords(domainID int, records []*Record) error {
316	bs, err := json.Marshal(records)
317	if err != nil {
318		return fmt.Errorf("encoding record failed: %w", err)
319	}
320
321	req, err := d.makeRequest(http.MethodPut, fmt.Sprintf("/v1/domains/%d/nameservers/records", domainID), bytes.NewReader(bs))
322	if err != nil {
323		return err
324	}
325
326	return d.sendRequest(req, nil)
327}
328
329func skipRecord(recordName, recordValue string, record *Record, nsInfo *NameserverResponse) bool {
330	// Skip empty records
331	if record.Value == "" {
332		return true
333	}
334
335	// Skip some special records, otherwise we would get a "Nameserver update failed"
336	if record.Type == "SOA" || record.Type == "NS" || record.Name == "@" || (nsInfo.General.IncludeWWW && record.Name == "www") {
337		return true
338	}
339
340	nameMatch := recordName == "" || record.Name == recordName
341	valueMatch := recordValue == "" || record.Value == recordValue
342
343	// Skip our matching record
344	if record.Type == "TXT" && nameMatch && valueMatch {
345		return true
346	}
347
348	return false
349}
350
351func (d *DNSProvider) makeRequest(method, resource string, body io.Reader) (*http.Request, error) {
352	uri, err := d.config.Endpoint.Parse(resource)
353	if err != nil {
354		return nil, err
355	}
356
357	req, err := http.NewRequest(method, uri.String(), body)
358	if err != nil {
359		return nil, err
360	}
361
362	req.Header.Set("Accept", "application/json")
363	req.Header.Set("Authorization", "Bearer "+d.config.Token)
364	if method != http.MethodGet {
365		req.Header.Set("Content-Type", "application/json")
366	}
367
368	return req, nil
369}
370
371func (d *DNSProvider) sendRequest(req *http.Request, result interface{}) error {
372	resp, err := d.config.HTTPClient.Do(req)
373	if err != nil {
374		return err
375	}
376
377	if err = checkResponse(resp); err != nil {
378		return err
379	}
380
381	defer func() { _ = resp.Body.Close() }()
382
383	if result == nil {
384		return nil
385	}
386
387	raw, err := io.ReadAll(resp.Body)
388	if err != nil {
389		return err
390	}
391
392	err = json.Unmarshal(raw, result)
393	if err != nil {
394		return fmt.Errorf("unmarshaling %T error [status code=%d]: %w: %s", result, resp.StatusCode, err, string(raw))
395	}
396	return nil
397}
398
// checkResponse turns an HTTP error response (status code >= 400) into an
// error that carries the status code and the raw response body.
//
// For any status below 400 it returns nil and leaves the body untouched.
// Otherwise it reads and closes the body (if there is one).
func checkResponse(resp *http.Response) error {
	code := resp.StatusCode

	switch {
	case code < http.StatusBadRequest:
		return nil
	case resp.Body == nil:
		return fmt.Errorf("response body is nil, status code=%d", code)
	}

	defer func() { _ = resp.Body.Close() }()

	payload, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return fmt.Errorf("unable to read body: status code=%d, error=%w", code, readErr)
	}

	return fmt.Errorf("status code=%d: %s", code, string(payload))
}
417