1package resolver
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"net"
8	"strconv"
9	"strings"
10	"sync/atomic"
11)
12
// resolver backs the Dial hook of the *net.Resolver built by NewResolver.
// It ignores the name server the Go resolver asks for and instead hands out
// the configured servers in round-robin order.
type resolver struct {
	// idx counts calls to address and is only accessed through sync/atomic.
	// It is kept as the first field because the sync/atomic documentation
	// guarantees 64-bit alignment only for the first word of an allocated
	// struct, which 64-bit atomic operations require on 32-bit platforms.
	idx uint64

	// addrs lists the DNS servers as validated host:port strings.
	addrs []string

	// dialer opens the connection to whichever server address selects.
	dialer *net.Dialer
}
18
19// NewResolver - create a new instance of a dns resolver for plugging
20// into net.DefaultResolver.  Addresses should be a list of
21// ip addrs and optional port numbers, separated by colon.
22// For example: 1.2.3.4:53 and 1.2.3.4 are both valid.  In the absence
23// of a port number, 53 will be used instead.
24func NewResolver(addrs []string) (*net.Resolver, error) {
25	if len(addrs) == 0 {
26		return nil, errors.New("must specify at least resolver address")
27	}
28	cleanAddrs, err := normalizeAddrs(addrs)
29	if err != nil {
30		return nil, err
31	}
32	return &net.Resolver{
33		PreferGo: true,
34		Dial:     (&resolver{addrs: cleanAddrs, dialer: &net.Dialer{}}).dial,
35	}, nil
36}
37
// normalizeAddrs validates resolver addresses and returns them in host:port
// form, ready to pass to net.Dialer.DialContext.
//
// Each entry must be an IP literal, optionally followed by a port. All of
// "1.2.3.4", "1.2.3.4:5353", "::1", "[::1]" and "[::1]:5353" are accepted.
// Entries without a port get port 53. Hostnames are rejected, as are ports
// outside 1-65535. The first invalid entry aborts the call with an error
// that names it. The returned slice is newly allocated; addrs is not
// modified.
func normalizeAddrs(addrs []string) ([]string, error) {
	normal := make([]string, len(addrs))
	for i, addr := range addrs {
		// An entry that is just an IP (optionally bracketed) has no port.
		// Check this before SplitHostPort: a bare IPv6 literal such as
		// "::1" contains colons, so looking for ':' alone cannot tell it
		// apart from host:port.
		bare := addr
		if strings.HasPrefix(bare, "[") && strings.HasSuffix(bare, "]") {
			bare = bare[1 : len(bare)-1]
		}
		host, port := bare, "53"
		if net.ParseIP(bare) == nil {
			var err error
			host, port, err = net.SplitHostPort(addr)
			if err != nil {
				return nil, fmt.Errorf("invalid resolver address %q: %w", addr, err)
			}
		}

		// Port must fit in 16 bits. Port 0 parses fine but can never
		// reach a server, so reject it as well.
		n, err := strconv.ParseUint(port, 10, 16)
		if err != nil {
			return nil, fmt.Errorf("invalid port in resolver address %q: %w", addr, err)
		}
		if n == 0 {
			return nil, fmt.Errorf("invalid port in resolver address %q: port must be 1-65535", addr)
		}

		// Only IP literals are allowed: resolving a hostname here would
		// itself require DNS.
		if net.ParseIP(host) == nil {
			return nil, fmt.Errorf("host %q is not an IP address", host)
		}

		// JoinHostPort re-adds brackets around IPv6 hosts.
		normal[i] = net.JoinHostPort(host, port)
	}
	return normal, nil
}
69
70// ignore the third parameter, as this represents the dns server address that
71// we are overriding.
72func (r *resolver) dial(ctx context.Context, network, _ string) (net.Conn, error) {
73	return r.dialer.DialContext(ctx, network, r.address())
74}
75
76func (r *resolver) address() string {
77	return r.addrs[atomic.AddUint64(&r.idx, 1)%uint64(len(r.addrs))]
78}
79