1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package net
6
7import (
8	"context"
9	"internal/syscall/windows"
10	"os"
11	"runtime"
12	"syscall"
13	"unsafe"
14)
15
// _WSAHOST_NOT_FOUND is the Winsock error code 11001 (WSAHOST_NOT_FOUND).
// winError translates it into errNoSuchHost.
const _WSAHOST_NOT_FOUND syscall.Errno = 11001
17
18func winError(call string, err error) error {
19	switch err {
20	case _WSAHOST_NOT_FOUND:
21		return errNoSuchHost
22	}
23	return os.NewSyscallError(call, err)
24}
25
26func getprotobyname(name string) (proto int, err error) {
27	p, err := syscall.GetProtoByName(name)
28	if err != nil {
29		return 0, winError("getprotobyname", err)
30	}
31	return int(p.Proto), nil
32}
33
34// lookupProtocol looks up IP protocol name and returns correspondent protocol number.
35func lookupProtocol(ctx context.Context, name string) (int, error) {
36	// GetProtoByName return value is stored in thread local storage.
37	// Start new os thread before the call to prevent races.
38	type result struct {
39		proto int
40		err   error
41	}
42	ch := make(chan result) // unbuffered
43	go func() {
44		acquireThread()
45		defer releaseThread()
46		runtime.LockOSThread()
47		defer runtime.UnlockOSThread()
48		proto, err := getprotobyname(name)
49		select {
50		case ch <- result{proto: proto, err: err}:
51		case <-ctx.Done():
52		}
53	}()
54	select {
55	case r := <-ch:
56		if r.err != nil {
57			if proto, err := lookupProtocolMap(name); err == nil {
58				return proto, nil
59			}
60
61			dnsError := &DNSError{Err: r.err.Error(), Name: name}
62			if r.err == errNoSuchHost {
63				dnsError.IsNotFound = true
64			}
65			r.err = dnsError
66		}
67		return r.proto, r.err
68	case <-ctx.Done():
69		return 0, mapErr(ctx.Err())
70	}
71}
72
73func (r *Resolver) lookupHost(ctx context.Context, name string) ([]string, error) {
74	ips, err := r.lookupIP(ctx, "ip", name)
75	if err != nil {
76		return nil, err
77	}
78	addrs := make([]string, 0, len(ips))
79	for _, ip := range ips {
80		addrs = append(addrs, ip.String())
81	}
82	return addrs, nil
83}
84
85func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr, error) {
86	// TODO(bradfitz,brainman): use ctx more. See TODO below.
87
88	var family int32 = syscall.AF_UNSPEC
89	switch ipVersion(network) {
90	case '4':
91		family = syscall.AF_INET
92	case '6':
93		family = syscall.AF_INET6
94	}
95
96	getaddr := func() ([]IPAddr, error) {
97		acquireThread()
98		defer releaseThread()
99		hints := syscall.AddrinfoW{
100			Family:   family,
101			Socktype: syscall.SOCK_STREAM,
102			Protocol: syscall.IPPROTO_IP,
103		}
104		var result *syscall.AddrinfoW
105		name16p, err := syscall.UTF16PtrFromString(name)
106		if err != nil {
107			return nil, &DNSError{Name: name, Err: err.Error()}
108		}
109		e := syscall.GetAddrInfoW(name16p, nil, &hints, &result)
110		if e != nil {
111			err := winError("getaddrinfow", e)
112			dnsError := &DNSError{Err: err.Error(), Name: name}
113			if err == errNoSuchHost {
114				dnsError.IsNotFound = true
115			}
116			return nil, dnsError
117		}
118		defer syscall.FreeAddrInfoW(result)
119		addrs := make([]IPAddr, 0, 5)
120		for ; result != nil; result = result.Next {
121			addr := unsafe.Pointer(result.Addr)
122			switch result.Family {
123			case syscall.AF_INET:
124				a := (*syscall.RawSockaddrInet4)(addr).Addr
125				addrs = append(addrs, IPAddr{IP: IPv4(a[0], a[1], a[2], a[3])})
126			case syscall.AF_INET6:
127				a := (*syscall.RawSockaddrInet6)(addr).Addr
128				zone := zoneCache.name(int((*syscall.RawSockaddrInet6)(addr).Scope_id))
129				addrs = append(addrs, IPAddr{IP: IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}, Zone: zone})
130			default:
131				return nil, &DNSError{Err: syscall.EWINDOWS.Error(), Name: name}
132			}
133		}
134		return addrs, nil
135	}
136
137	type ret struct {
138		addrs []IPAddr
139		err   error
140	}
141
142	var ch chan ret
143	if ctx.Err() == nil {
144		ch = make(chan ret, 1)
145		go func() {
146			addr, err := getaddr()
147			ch <- ret{addrs: addr, err: err}
148		}()
149	}
150
151	select {
152	case r := <-ch:
153		return r.addrs, r.err
154	case <-ctx.Done():
155		// TODO(bradfitz,brainman): cancel the ongoing
156		// GetAddrInfoW? It would require conditionally using
157		// GetAddrInfoEx with lpOverlapped, which requires
158		// Windows 8 or newer. I guess we'll need oldLookupIP,
159		// newLookupIP, and newerLookUP.
160		//
161		// For now we just let it finish and write to the
162		// buffered channel.
163		return nil, &DNSError{
164			Name:      name,
165			Err:       ctx.Err().Error(),
166			IsTimeout: ctx.Err() == context.DeadlineExceeded,
167		}
168	}
169}
170
171func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) {
172	if r.preferGo() {
173		return lookupPortMap(network, service)
174	}
175
176	// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
177	acquireThread()
178	defer releaseThread()
179	var stype int32
180	switch network {
181	case "tcp4", "tcp6":
182		stype = syscall.SOCK_STREAM
183	case "udp4", "udp6":
184		stype = syscall.SOCK_DGRAM
185	}
186	hints := syscall.AddrinfoW{
187		Family:   syscall.AF_UNSPEC,
188		Socktype: stype,
189		Protocol: syscall.IPPROTO_IP,
190	}
191	var result *syscall.AddrinfoW
192	e := syscall.GetAddrInfoW(nil, syscall.StringToUTF16Ptr(service), &hints, &result)
193	if e != nil {
194		if port, err := lookupPortMap(network, service); err == nil {
195			return port, nil
196		}
197		err := winError("getaddrinfow", e)
198		dnsError := &DNSError{Err: err.Error(), Name: network + "/" + service}
199		if err == errNoSuchHost {
200			dnsError.IsNotFound = true
201		}
202		return 0, dnsError
203	}
204	defer syscall.FreeAddrInfoW(result)
205	if result == nil {
206		return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service}
207	}
208	addr := unsafe.Pointer(result.Addr)
209	switch result.Family {
210	case syscall.AF_INET:
211		a := (*syscall.RawSockaddrInet4)(addr)
212		return int(syscall.Ntohs(a.Port)), nil
213	case syscall.AF_INET6:
214		a := (*syscall.RawSockaddrInet6)(addr)
215		return int(syscall.Ntohs(a.Port)), nil
216	}
217	return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service}
218}
219
220func (*Resolver) lookupCNAME(ctx context.Context, name string) (string, error) {
221	// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
222	acquireThread()
223	defer releaseThread()
224	var r *syscall.DNSRecord
225	e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil)
226	// windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s
227	if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS {
228		// if there are no aliases, the canonical name is the input name
229		return absDomainName(name), nil
230	}
231	if e != nil {
232		return "", &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
233	}
234	defer syscall.DnsRecordListFree(r, 1)
235
236	resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r)
237	cname := windows.UTF16PtrToString(resolved)
238	return absDomainName(cname), nil
239}
240
241func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
242	// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
243	acquireThread()
244	defer releaseThread()
245	var target string
246	if service == "" && proto == "" {
247		target = name
248	} else {
249		target = "_" + service + "._" + proto + "." + name
250	}
251	var r *syscall.DNSRecord
252	e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &r, nil)
253	if e != nil {
254		return "", nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: target}
255	}
256	defer syscall.DnsRecordListFree(r, 1)
257
258	srvs := make([]*SRV, 0, 10)
259	for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) {
260		v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0]))
261		srvs = append(srvs, &SRV{absDomainName(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:])), v.Port, v.Priority, v.Weight})
262	}
263	byPriorityWeight(srvs).sort()
264	return absDomainName(target), srvs, nil
265}
266
267func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
268	// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
269	acquireThread()
270	defer releaseThread()
271	var r *syscall.DNSRecord
272	e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &r, nil)
273	if e != nil {
274		return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
275	}
276	defer syscall.DnsRecordListFree(r, 1)
277
278	mxs := make([]*MX, 0, 10)
279	for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) {
280		v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0]))
281		mxs = append(mxs, &MX{absDomainName(windows.UTF16PtrToString(v.NameExchange)), v.Preference})
282	}
283	byPref(mxs).sort()
284	return mxs, nil
285}
286
287func (*Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
288	// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
289	acquireThread()
290	defer releaseThread()
291	var r *syscall.DNSRecord
292	e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &r, nil)
293	if e != nil {
294		return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
295	}
296	defer syscall.DnsRecordListFree(r, 1)
297
298	nss := make([]*NS, 0, 10)
299	for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) {
300		v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
301		nss = append(nss, &NS{absDomainName(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))})
302	}
303	return nss, nil
304}
305
306func (*Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) {
307	// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
308	acquireThread()
309	defer releaseThread()
310	var r *syscall.DNSRecord
311	e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &r, nil)
312	if e != nil {
313		return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
314	}
315	defer syscall.DnsRecordListFree(r, 1)
316
317	txts := make([]string, 0, 10)
318	for _, p := range validRecs(r, syscall.DNS_TYPE_TEXT, name) {
319		d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0]))
320		s := ""
321		for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount:d.StringCount] {
322			s += windows.UTF16PtrToString(v)
323		}
324		txts = append(txts, s)
325	}
326	return txts, nil
327}
328
329func (*Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) {
330	// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
331	acquireThread()
332	defer releaseThread()
333	arpa, err := reverseaddr(addr)
334	if err != nil {
335		return nil, err
336	}
337	var r *syscall.DNSRecord
338	e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &r, nil)
339	if e != nil {
340		return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: addr}
341	}
342	defer syscall.DnsRecordListFree(r, 1)
343
344	ptrs := make([]string, 0, 10)
345	for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) {
346		v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
347		ptrs = append(ptrs, absDomainName(windows.UTF16PtrToString(v.Host)))
348	}
349	return ptrs, nil
350}
351
// dnsSectionMask keeps the low two bits of DNSRecord.Dw. The masked value is
// compared against syscall.DnsSectionAnswer / syscall.DnsSectionQuestion to
// tell which section of the response a record belongs to.
const dnsSectionMask = 0x3
353
354// returns only results applicable to name and resolves CNAME entries
355func validRecs(r *syscall.DNSRecord, dnstype uint16, name string) []*syscall.DNSRecord {
356	cname := syscall.StringToUTF16Ptr(name)
357	if dnstype != syscall.DNS_TYPE_CNAME {
358		cname = resolveCNAME(cname, r)
359	}
360	rec := make([]*syscall.DNSRecord, 0, 10)
361	for p := r; p != nil; p = p.Next {
362		// in case of a local machine, DNS records are returned with DNSREC_QUESTION flag instead of DNS_ANSWER
363		if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer && p.Dw&dnsSectionMask != syscall.DnsSectionQuestion {
364			continue
365		}
366		if p.Type != dnstype {
367			continue
368		}
369		if !syscall.DnsNameCompare(cname, p.Name) {
370			continue
371		}
372		rec = append(rec, p)
373	}
374	return rec
375}
376
377// returns the last CNAME in chain
378func resolveCNAME(name *uint16, r *syscall.DNSRecord) *uint16 {
379	// limit cname resolving to 10 in case of an infinite CNAME loop
380Cname:
381	for cnameloop := 0; cnameloop < 10; cnameloop++ {
382		for p := r; p != nil; p = p.Next {
383			if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
384				continue
385			}
386			if p.Type != syscall.DNS_TYPE_CNAME {
387				continue
388			}
389			if !syscall.DnsNameCompare(name, p.Name) {
390				continue
391			}
392			name = (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0])).Host
393			continue Cname
394		}
395		break
396	}
397	return name
398}
399
// concurrentThreadsLimit reports how many threads may perform DNS lookups
// at the same time.
func concurrentThreadsLimit() int {
	const maxLookupThreads = 500
	return maxLookupThreads
}
405