1package dnsutil
2
import (
	"errors"
	"fmt"
	"net"
	"os"

	"github.com/miekg/dns"
)
10
11// ParseHostPortOrFile parses the strings in s, each string can either be a address,
12// address:port or a filename. The address part is checked and the filename case a
13// resolv.conf like file is parsed and the nameserver found are returned.
14func ParseHostPortOrFile(s ...string) ([]string, error) {
15	var servers []string
16	for _, host := range s {
17		addr, _, err := net.SplitHostPort(host)
18		if err != nil {
19			// Parse didn't work, it is not a addr:port combo
20			if net.ParseIP(host) == nil {
21				// Not an IP address.
22				ss, err := tryFile(host)
23				if err == nil {
24					servers = append(servers, ss...)
25					continue
26				}
27				return servers, fmt.Errorf("not an IP address or file: %q", host)
28			}
29			ss := net.JoinHostPort(host, "53")
30			servers = append(servers, ss)
31			continue
32		}
33
34		if net.ParseIP(addr) == nil {
35			// No an IP address.
36			ss, err := tryFile(host)
37			if err == nil {
38				servers = append(servers, ss...)
39				continue
40			}
41			return servers, fmt.Errorf("not an IP address or file: %q", host)
42		}
43		servers = append(servers, host)
44	}
45	return servers, nil
46}
47
48// Try to open this is a file first.
49func tryFile(s string) ([]string, error) {
50	c, err := dns.ClientConfigFromFile(s)
51	if err == os.ErrNotExist {
52		return nil, fmt.Errorf("failed to open file %q: %q", s, err)
53	} else if err != nil {
54		return nil, err
55	}
56
57	servers := []string{}
58	for _, s := range c.Servers {
59		servers = append(servers, net.JoinHostPort(s, c.Port))
60	}
61	return servers, nil
62}
63
// ParseHostPort validates that the host part of s is an IP address and returns
// it joined with a port. If s cannot be split into host and port, the whole of
// s is treated as the address. If no port is present (or the port is empty),
// defaultPort is used instead.
//
// A non-IP host yields an error that quotes the offending host text.
func ParseHostPort(s, defaultPort string) (string, error) {
	host, port, splitErr := net.SplitHostPort(s)
	if splitErr != nil {
		// Not in host:port form; the entire input must then be the address.
		host, port = s, ""
	}
	if port == "" {
		port = defaultPort
	}

	if ip := net.ParseIP(host); ip == nil {
		return "", fmt.Errorf("must specify an IP address: `%s'", host)
	}
	return net.JoinHostPort(host, port), nil
}
83