1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package net
6
7import (
8	"context"
9	"errors"
10	"internal/bytealg"
11	"io"
12	"os"
13)
14
15func query(ctx context.Context, filename, query string, bufSize int) (addrs []string, err error) {
16	queryAddrs := func() (addrs []string, err error) {
17		file, err := os.OpenFile(filename, os.O_RDWR, 0)
18		if err != nil {
19			return nil, err
20		}
21		defer file.Close()
22
23		_, err = file.Seek(0, io.SeekStart)
24		if err != nil {
25			return nil, err
26		}
27		_, err = file.WriteString(query)
28		if err != nil {
29			return nil, err
30		}
31		_, err = file.Seek(0, io.SeekStart)
32		if err != nil {
33			return nil, err
34		}
35		buf := make([]byte, bufSize)
36		for {
37			n, _ := file.Read(buf)
38			if n <= 0 {
39				break
40			}
41			addrs = append(addrs, string(buf[:n]))
42		}
43		return addrs, nil
44	}
45
46	type ret struct {
47		addrs []string
48		err   error
49	}
50
51	ch := make(chan ret, 1)
52	go func() {
53		addrs, err := queryAddrs()
54		ch <- ret{addrs: addrs, err: err}
55	}()
56
57	select {
58	case r := <-ch:
59		return r.addrs, r.err
60	case <-ctx.Done():
61		return nil, &DNSError{
62			Name:      query,
63			Err:       ctx.Err().Error(),
64			IsTimeout: ctx.Err() == context.DeadlineExceeded,
65		}
66	}
67}
68
69func queryCS(ctx context.Context, net, host, service string) (res []string, err error) {
70	switch net {
71	case "tcp4", "tcp6":
72		net = "tcp"
73	case "udp4", "udp6":
74		net = "udp"
75	}
76	if host == "" {
77		host = "*"
78	}
79	return query(ctx, netdir+"/cs", net+"!"+host+"!"+service, 128)
80}
81
82func queryCS1(ctx context.Context, net string, ip IP, port int) (clone, dest string, err error) {
83	ips := "*"
84	if len(ip) != 0 && !ip.IsUnspecified() {
85		ips = ip.String()
86	}
87	lines, err := queryCS(ctx, net, ips, itoa(port))
88	if err != nil {
89		return
90	}
91	f := getFields(lines[0])
92	if len(f) < 2 {
93		return "", "", errors.New("bad response from ndb/cs")
94	}
95	clone, dest = f[0], f[1]
96	return
97}
98
99func queryDNS(ctx context.Context, addr string, typ string) (res []string, err error) {
100	return query(ctx, netdir+"/dns", addr+" "+typ, 1024)
101}
102
// toLower returns in with the ASCII letters 'A'-'Z' mapped to lower case;
// all other bytes (including non-ASCII UTF-8) are left untouched. Handling
// only ASCII is enough for IP protocol names and keeps this package free
// of a dependency on strings and unicode. When nothing needs changing, in
// is returned as-is without allocating.
func toLower(in string) string {
	first := -1
	for i := 0; i < len(in); i++ {
		if b := in[i]; b >= 'A' && b <= 'Z' {
			first = i
			break
		}
	}
	if first < 0 {
		return in
	}
	buf := []byte(in)
	for i := first; i < len(buf); i++ {
		if b := buf[i]; b >= 'A' && b <= 'Z' {
			buf[i] = b + ('a' - 'A')
		}
	}
	return string(buf)
}
123
124// lookupProtocol looks up IP protocol name and returns
125// the corresponding protocol number.
126func lookupProtocol(ctx context.Context, name string) (proto int, err error) {
127	lines, err := query(ctx, netdir+"/cs", "!protocol="+toLower(name), 128)
128	if err != nil {
129		return 0, err
130	}
131	if len(lines) == 0 {
132		return 0, UnknownNetworkError(name)
133	}
134	f := getFields(lines[0])
135	if len(f) < 2 {
136		return 0, UnknownNetworkError(name)
137	}
138	s := f[1]
139	if n, _, ok := dtoi(s[bytealg.IndexByteString(s, '=')+1:]); ok {
140		return n, nil
141	}
142	return 0, UnknownNetworkError(name)
143}
144
145func (*Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) {
146	// Use netdir/cs instead of netdir/dns because cs knows about
147	// host names in local network (e.g. from /lib/ndb/local)
148	lines, err := queryCS(ctx, "net", host, "1")
149	if err != nil {
150		dnsError := &DNSError{Err: err.Error(), Name: host}
151		if stringsHasSuffix(err.Error(), "dns failure") {
152			dnsError.Err = errNoSuchHost.Error()
153			dnsError.IsNotFound = true
154		}
155		return nil, dnsError
156	}
157loop:
158	for _, line := range lines {
159		f := getFields(line)
160		if len(f) < 2 {
161			continue
162		}
163		addr := f[1]
164		if i := bytealg.IndexByteString(addr, '!'); i >= 0 {
165			addr = addr[:i] // remove port
166		}
167		if ParseIP(addr) == nil {
168			continue
169		}
170		// only return unique addresses
171		for _, a := range addrs {
172			if a == addr {
173				continue loop
174			}
175		}
176		addrs = append(addrs, addr)
177	}
178	return
179}
180
181func (r *Resolver) lookupIP(ctx context.Context, _, host string) (addrs []IPAddr, err error) {
182	lits, err := r.lookupHost(ctx, host)
183	if err != nil {
184		return
185	}
186	for _, lit := range lits {
187		host, zone := splitHostZone(lit)
188		if ip := ParseIP(host); ip != nil {
189			addr := IPAddr{IP: ip, Zone: zone}
190			addrs = append(addrs, addr)
191		}
192	}
193	return
194}
195
196func (*Resolver) lookupPort(ctx context.Context, network, service string) (port int, err error) {
197	switch network {
198	case "tcp4", "tcp6":
199		network = "tcp"
200	case "udp4", "udp6":
201		network = "udp"
202	}
203	lines, err := queryCS(ctx, network, "127.0.0.1", toLower(service))
204	if err != nil {
205		return
206	}
207	unknownPortError := &AddrError{Err: "unknown port", Addr: network + "/" + service}
208	if len(lines) == 0 {
209		return 0, unknownPortError
210	}
211	f := getFields(lines[0])
212	if len(f) < 2 {
213		return 0, unknownPortError
214	}
215	s := f[1]
216	if i := bytealg.IndexByteString(s, '!'); i >= 0 {
217		s = s[i+1:] // remove address
218	}
219	if n, _, ok := dtoi(s); ok {
220		return n, nil
221	}
222	return 0, unknownPortError
223}
224
225func (*Resolver) lookupCNAME(ctx context.Context, name string) (cname string, err error) {
226	lines, err := queryDNS(ctx, name, "cname")
227	if err != nil {
228		if stringsHasSuffix(err.Error(), "dns failure") || stringsHasSuffix(err.Error(), "resource does not exist; negrcode 0") {
229			cname = name + "."
230			err = nil
231		}
232		return
233	}
234	if len(lines) > 0 {
235		if f := getFields(lines[0]); len(f) >= 3 {
236			return f[2] + ".", nil
237		}
238	}
239	return "", errors.New("bad response from ndb/dns")
240}
241
242func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*SRV, err error) {
243	var target string
244	if service == "" && proto == "" {
245		target = name
246	} else {
247		target = "_" + service + "._" + proto + "." + name
248	}
249	lines, err := queryDNS(ctx, target, "srv")
250	if err != nil {
251		return
252	}
253	for _, line := range lines {
254		f := getFields(line)
255		if len(f) < 6 {
256			continue
257		}
258		port, _, portOk := dtoi(f[4])
259		priority, _, priorityOk := dtoi(f[3])
260		weight, _, weightOk := dtoi(f[2])
261		if !(portOk && priorityOk && weightOk) {
262			continue
263		}
264		addrs = append(addrs, &SRV{absDomainName([]byte(f[5])), uint16(port), uint16(priority), uint16(weight)})
265		cname = absDomainName([]byte(f[0]))
266	}
267	byPriorityWeight(addrs).sort()
268	return
269}
270
271func (*Resolver) lookupMX(ctx context.Context, name string) (mx []*MX, err error) {
272	lines, err := queryDNS(ctx, name, "mx")
273	if err != nil {
274		return
275	}
276	for _, line := range lines {
277		f := getFields(line)
278		if len(f) < 4 {
279			continue
280		}
281		if pref, _, ok := dtoi(f[2]); ok {
282			mx = append(mx, &MX{absDomainName([]byte(f[3])), uint16(pref)})
283		}
284	}
285	byPref(mx).sort()
286	return
287}
288
289func (*Resolver) lookupNS(ctx context.Context, name string) (ns []*NS, err error) {
290	lines, err := queryDNS(ctx, name, "ns")
291	if err != nil {
292		return
293	}
294	for _, line := range lines {
295		f := getFields(line)
296		if len(f) < 3 {
297			continue
298		}
299		ns = append(ns, &NS{absDomainName([]byte(f[2]))})
300	}
301	return
302}
303
304func (*Resolver) lookupTXT(ctx context.Context, name string) (txt []string, err error) {
305	lines, err := queryDNS(ctx, name, "txt")
306	if err != nil {
307		return
308	}
309	for _, line := range lines {
310		if i := bytealg.IndexByteString(line, '\t'); i >= 0 {
311			txt = append(txt, absDomainName([]byte(line[i+1:])))
312		}
313	}
314	return
315}
316
317func (*Resolver) lookupAddr(ctx context.Context, addr string) (name []string, err error) {
318	arpa, err := reverseaddr(addr)
319	if err != nil {
320		return
321	}
322	lines, err := queryDNS(ctx, arpa, "ptr")
323	if err != nil {
324		return
325	}
326	for _, line := range lines {
327		f := getFields(line)
328		if len(f) < 3 {
329			continue
330		}
331		name = append(name, absDomainName([]byte(f[2])))
332	}
333	return
334}
335
// concurrentThreadsLimit reports the maximum number of threads allowed to
// perform DNS lookups at the same time. The limit is a fixed constant.
func concurrentThreadsLimit() int {
	const maxConcurrentLookups = 500
	return maxConcurrentLookups
}
341