1package checkdomain 2 3import ( 4 "bytes" 5 "encoding/json" 6 "errors" 7 "fmt" 8 "io" 9 "net/http" 10 "strconv" 11 "strings" 12) 13 14const ( 15 ns1 = "ns.checkdomain.de" 16 ns2 = "ns2.checkdomain.de" 17) 18 19const domainNotFound = -1 20 21// max page limit that the checkdomain api allows. 22const maxLimit = 100 23 24// max integer value. 25const maxInt = int((^uint(0)) >> 1) 26 27type ( 28 // Some fields have been omitted from the structs 29 // because they are not required for this application. 30 31 DomainListingResponse struct { 32 Page int `json:"page"` 33 Limit int `json:"limit"` 34 Pages int `json:"pages"` 35 Total int `json:"total"` 36 Embedded EmbeddedDomainList `json:"_embedded"` 37 } 38 39 EmbeddedDomainList struct { 40 Domains []*Domain `json:"domains"` 41 } 42 43 Domain struct { 44 ID int `json:"id"` 45 Name string `json:"name"` 46 } 47 48 DomainResponse struct { 49 ID int `json:"id"` 50 Name string `json:"name"` 51 Created string `json:"created"` 52 PaidUp string `json:"payed_up"` 53 Active bool `json:"active"` 54 } 55 56 NameserverResponse struct { 57 General NameserverGeneral `json:"general"` 58 Nameservers []*Nameserver `json:"nameservers"` 59 SOA NameserverSOA `json:"soa"` 60 } 61 62 NameserverGeneral struct { 63 IPv4 string `json:"ip_v4"` 64 IPv6 string `json:"ip_v6"` 65 IncludeWWW bool `json:"include_www"` 66 } 67 68 NameserverSOA struct { 69 Mail string `json:"mail"` 70 Refresh int `json:"refresh"` 71 Retry int `json:"retry"` 72 Expiry int `json:"expiry"` 73 TTL int `json:"ttl"` 74 } 75 76 Nameserver struct { 77 Name string `json:"name"` 78 } 79 80 RecordListingResponse struct { 81 Page int `json:"page"` 82 Limit int `json:"limit"` 83 Pages int `json:"pages"` 84 Total int `json:"total"` 85 Embedded EmbeddedRecordList `json:"_embedded"` 86 } 87 88 EmbeddedRecordList struct { 89 Records []*Record `json:"records"` 90 } 91 92 Record struct { 93 Name string `json:"name"` 94 Value string `json:"value"` 95 TTL int `json:"ttl"` 96 Priority int `json:"priority"` 97 Type string `json:"type"` 98 } 99) 100 101func (d *DNSProvider) getDomainIDByName(name string) (int, error) { 102 // Load from cache if exists 103 d.domainIDMu.Lock() 104 id, ok := d.domainIDMapping[name] 105 d.domainIDMu.Unlock() 106 if ok { 107 return id, nil 108 } 109 110 // Find out by querying API 111 domains, err := d.listDomains() 112 if err != nil { 113 return domainNotFound, err 114 } 115 116 // Linear search over all registered domains 117 for _, domain := range domains { 118 if domain.Name == name || strings.HasSuffix(name, "."+domain.Name) { 119 d.domainIDMu.Lock() 120 d.domainIDMapping[name] = domain.ID 121 d.domainIDMu.Unlock() 122 123 return domain.ID, nil 124 } 125 } 126 127 return domainNotFound, errors.New("domain not found") 128} 129 130func (d *DNSProvider) listDomains() ([]*Domain, error) { 131 req, err := d.makeRequest(http.MethodGet, "/v1/domains", http.NoBody) 132 if err != nil { 133 return nil, fmt.Errorf("failed to make request: %w", err) 134 } 135 136 // Checkdomain also provides a query param 'query' which allows filtering domains for a string. 137 // But that functionality is kinda broken, 138 // so we scan through the whole list of registered domains to later find the one that is of interest to us. 139 q := req.URL.Query() 140 q.Set("limit", strconv.Itoa(maxLimit)) 141 142 currentPage := 1 143 totalPages := maxInt 144 145 var domainList []*Domain 146 for currentPage <= totalPages { 147 q.Set("page", strconv.Itoa(currentPage)) 148 req.URL.RawQuery = q.Encode() 149 150 var res DomainListingResponse 151 if err := d.sendRequest(req, &res); err != nil { 152 return nil, fmt.Errorf("failed to send domain listing request: %w", err) 153 } 154 155 // This is the first response, 156 // so we update totalPages and allocate the slice memory. 157 if totalPages == maxInt { 158 totalPages = res.Pages 159 domainList = make([]*Domain, 0, res.Total) 160 } 161 162 domainList = append(domainList, res.Embedded.Domains...) 163 currentPage++ 164 } 165 166 return domainList, nil 167} 168 169func (d *DNSProvider) getNameserverInfo(domainID int) (*NameserverResponse, error) { 170 req, err := d.makeRequest(http.MethodGet, fmt.Sprintf("/v1/domains/%d/nameservers", domainID), http.NoBody) 171 if err != nil { 172 return nil, err 173 } 174 175 res := &NameserverResponse{} 176 if err := d.sendRequest(req, res); err != nil { 177 return nil, err 178 } 179 180 return res, nil 181} 182 183func (d *DNSProvider) checkNameservers(domainID int) error { 184 info, err := d.getNameserverInfo(domainID) 185 if err != nil { 186 return err 187 } 188 189 var found1, found2 bool 190 for _, item := range info.Nameservers { 191 switch item.Name { 192 case ns1: 193 found1 = true 194 case ns2: 195 found2 = true 196 } 197 } 198 199 if !found1 || !found2 { 200 return errors.New("not using checkdomain nameservers, can not update records") 201 } 202 203 return nil 204} 205 206func (d *DNSProvider) createRecord(domainID int, record *Record) error { 207 bs, err := json.Marshal(record) 208 if err != nil { 209 return fmt.Errorf("encoding record failed: %w", err) 210 } 211 212 req, err := d.makeRequest(http.MethodPost, fmt.Sprintf("/v1/domains/%d/nameservers/records", domainID), bytes.NewReader(bs)) 213 if err != nil { 214 return err 215 } 216 217 return d.sendRequest(req, nil) 218} 219 220// Checkdomain doesn't seem provide a way to delete records but one can replace all records at once. 221// The current solution is to fetch all records and then use that list minus the record deleted as the new record list. 222// TODO: Simplify this function once Checkdomain do provide the functionality. 223func (d *DNSProvider) deleteTXTRecord(domainID int, recordName, recordValue string) error { 224 domainInfo, err := d.getDomainInfo(domainID) 225 if err != nil { 226 return err 227 } 228 229 nsInfo, err := d.getNameserverInfo(domainID) 230 if err != nil { 231 return err 232 } 233 234 allRecords, err := d.listRecords(domainID, "") 235 if err != nil { 236 return err 237 } 238 239 recordName = strings.TrimSuffix(recordName, "."+domainInfo.Name+".") 240 241 var recordsToKeep []*Record 242 243 // Find and delete matching records 244 for _, record := range allRecords { 245 if skipRecord(recordName, recordValue, record, nsInfo) { 246 continue 247 } 248 249 // Checkdomain API can return records without any TTL set (indicated by the value of 0). 250 // The API Call to replace the records would fail if we wouldn't specify a value. 251 // Thus, we use the default TTL queried beforehand 252 if record.TTL == 0 { 253 record.TTL = nsInfo.SOA.TTL 254 } 255 256 recordsToKeep = append(recordsToKeep, record) 257 } 258 259 return d.replaceRecords(domainID, recordsToKeep) 260} 261 262func (d *DNSProvider) getDomainInfo(domainID int) (*DomainResponse, error) { 263 req, err := d.makeRequest(http.MethodGet, fmt.Sprintf("/v1/domains/%d", domainID), http.NoBody) 264 if err != nil { 265 return nil, err 266 } 267 268 var res DomainResponse 269 err = d.sendRequest(req, &res) 270 if err != nil { 271 return nil, err 272 } 273 274 return &res, nil 275} 276 277func (d *DNSProvider) listRecords(domainID int, recordType string) ([]*Record, error) { 278 req, err := d.makeRequest(http.MethodGet, fmt.Sprintf("/v1/domains/%d/nameservers/records", domainID), http.NoBody) 279 if err != nil { 280 return nil, fmt.Errorf("failed to make request: %w", err) 281 } 282 283 q := req.URL.Query() 284 q.Set("limit", strconv.Itoa(maxLimit)) 285 if recordType != "" { 286 q.Set("type", recordType) 287 } 288 289 currentPage := 1 290 totalPages := maxInt 291 292 var recordList []*Record 293 for currentPage <= totalPages { 294 q.Set("page", strconv.Itoa(currentPage)) 295 req.URL.RawQuery = q.Encode() 296 297 var res RecordListingResponse 298 if err := d.sendRequest(req, &res); err != nil { 299 return nil, fmt.Errorf("failed to send record listing request: %w", err) 300 } 301 302 // This is the first response, so we update totalPages and allocate the slice memory. 303 if totalPages == maxInt { 304 totalPages = res.Pages 305 recordList = make([]*Record, 0, res.Total) 306 } 307 308 recordList = append(recordList, res.Embedded.Records...) 309 currentPage++ 310 } 311 312 return recordList, nil 313} 314 315func (d *DNSProvider) replaceRecords(domainID int, records []*Record) error { 316 bs, err := json.Marshal(records) 317 if err != nil { 318 return fmt.Errorf("encoding record failed: %w", err) 319 } 320 321 req, err := d.makeRequest(http.MethodPut, fmt.Sprintf("/v1/domains/%d/nameservers/records", domainID), bytes.NewReader(bs)) 322 if err != nil { 323 return err 324 } 325 326 return d.sendRequest(req, nil) 327} 328 329func skipRecord(recordName, recordValue string, record *Record, nsInfo *NameserverResponse) bool { 330 // Skip empty records 331 if record.Value == "" { 332 return true 333 } 334 335 // Skip some special records, otherwise we would get a "Nameserver update failed" 336 if record.Type == "SOA" || record.Type == "NS" || record.Name == "@" || (nsInfo.General.IncludeWWW && record.Name == "www") { 337 return true 338 } 339 340 nameMatch := recordName == "" || record.Name == recordName 341 valueMatch := recordValue == "" || record.Value == recordValue 342 343 // Skip our matching record 344 if record.Type == "TXT" && nameMatch && valueMatch { 345 return true 346 } 347 348 return false 349} 350 351func (d *DNSProvider) makeRequest(method, resource string, body io.Reader) (*http.Request, error) { 352 uri, err := d.config.Endpoint.Parse(resource) 353 if err != nil { 354 return nil, err 355 } 356 357 req, err := http.NewRequest(method, uri.String(), body) 358 if err != nil { 359 return nil, err 360 } 361 362 req.Header.Set("Accept", "application/json") 363 req.Header.Set("Authorization", "Bearer "+d.config.Token) 364 if method != http.MethodGet { 365 req.Header.Set("Content-Type", "application/json") 366 } 367 368 return req, nil 369} 370 371func (d *DNSProvider) sendRequest(req *http.Request, result interface{}) error { 372 resp, err := d.config.HTTPClient.Do(req) 373 if err != nil { 374 return err 375 } 376 377 if err = checkResponse(resp); err != nil { 378 return err 379 } 380 381 defer func() { _ = resp.Body.Close() }() 382 383 if result == nil { 384 return nil 385 } 386 387 raw, err := io.ReadAll(resp.Body) 388 if err != nil { 389 return err 390 } 391 392 err = json.Unmarshal(raw, result) 393 if err != nil { 394 return fmt.Errorf("unmarshaling %T error [status code=%d]: %w: %s", result, resp.StatusCode, err, string(raw)) 395 } 396 return nil 397} 398 399func checkResponse(resp *http.Response) error { 400 if resp.StatusCode < http.StatusBadRequest { 401 return nil 402 } 403 404 if resp.Body == nil { 405 return fmt.Errorf("response body is nil, status code=%d", resp.StatusCode) 406 } 407 408 defer func() { _ = resp.Body.Close() }() 409 410 raw, err := io.ReadAll(resp.Body) 411 if err != nil { 412 return fmt.Errorf("unable to read body: status code=%d, error=%w", resp.StatusCode, err) 413 } 414 415 return fmt.Errorf("status code=%d: %s", resp.StatusCode, string(raw)) 416} 417