1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package net
6
7import (
8	"context"
9	"os"
10	"runtime"
11	"syscall"
12	"unsafe"
13)
14
15const _WSAHOST_NOT_FOUND = syscall.Errno(11001)
16
17func winError(call string, err error) error {
18	switch err {
19	case _WSAHOST_NOT_FOUND:
20		return errNoSuchHost
21	}
22	return os.NewSyscallError(call, err)
23}
24
25func getprotobyname(name string) (proto int, err error) {
26	p, err := syscall.GetProtoByName(name)
27	if err != nil {
28		return 0, winError("getprotobyname", err)
29	}
30	return int(p.Proto), nil
31}
32
33// lookupProtocol looks up IP protocol name and returns correspondent protocol number.
34func lookupProtocol(ctx context.Context, name string) (int, error) {
35	// GetProtoByName return value is stored in thread local storage.
36	// Start new os thread before the call to prevent races.
37	type result struct {
38		proto int
39		err   error
40	}
41	ch := make(chan result) // unbuffered
42	go func() {
43		acquireThread()
44		defer releaseThread()
45		runtime.LockOSThread()
46		defer runtime.UnlockOSThread()
47		proto, err := getprotobyname(name)
48		select {
49		case ch <- result{proto: proto, err: err}:
50		case <-ctx.Done():
51		}
52	}()
53	select {
54	case r := <-ch:
55		if r.err != nil {
56			if proto, err := lookupProtocolMap(name); err == nil {
57				return proto, nil
58			}
59			r.err = &DNSError{Err: r.err.Error(), Name: name}
60		}
61		return r.proto, r.err
62	case <-ctx.Done():
63		return 0, mapErr(ctx.Err())
64	}
65}
66
67func (r *Resolver) lookupHost(ctx context.Context, name string) ([]string, error) {
68	ips, err := r.lookupIP(ctx, name)
69	if err != nil {
70		return nil, err
71	}
72	addrs := make([]string, 0, len(ips))
73	for _, ip := range ips {
74		addrs = append(addrs, ip.String())
75	}
76	return addrs, nil
77}
78
79func (r *Resolver) lookupIP(ctx context.Context, name string) ([]IPAddr, error) {
80	// TODO(bradfitz,brainman): use ctx more. See TODO below.
81
82	type ret struct {
83		addrs []IPAddr
84		err   error
85	}
86	ch := make(chan ret, 1)
87	go func() {
88		acquireThread()
89		defer releaseThread()
90		hints := syscall.AddrinfoW{
91			Family:   syscall.AF_UNSPEC,
92			Socktype: syscall.SOCK_STREAM,
93			Protocol: syscall.IPPROTO_IP,
94		}
95		var result *syscall.AddrinfoW
96		e := syscall.GetAddrInfoW(syscall.StringToUTF16Ptr(name), nil, &hints, &result)
97		if e != nil {
98			ch <- ret{err: &DNSError{Err: winError("getaddrinfow", e).Error(), Name: name}}
99		}
100		defer syscall.FreeAddrInfoW(result)
101		addrs := make([]IPAddr, 0, 5)
102		for ; result != nil; result = result.Next {
103			addr := unsafe.Pointer(result.Addr)
104			switch result.Family {
105			case syscall.AF_INET:
106				a := (*syscall.RawSockaddrInet4)(addr).Addr
107				addrs = append(addrs, IPAddr{IP: IPv4(a[0], a[1], a[2], a[3])})
108			case syscall.AF_INET6:
109				a := (*syscall.RawSockaddrInet6)(addr).Addr
110				zone := zoneCache.name(int((*syscall.RawSockaddrInet6)(addr).Scope_id))
111				addrs = append(addrs, IPAddr{IP: IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}, Zone: zone})
112			default:
113				ch <- ret{err: &DNSError{Err: syscall.EWINDOWS.Error(), Name: name}}
114			}
115		}
116		ch <- ret{addrs: addrs}
117	}()
118	select {
119	case r := <-ch:
120		return r.addrs, r.err
121	case <-ctx.Done():
122		// TODO(bradfitz,brainman): cancel the ongoing
123		// GetAddrInfoW? It would require conditionally using
124		// GetAddrInfoEx with lpOverlapped, which requires
125		// Windows 8 or newer. I guess we'll need oldLookupIP,
126		// newLookupIP, and newerLookUP.
127		//
128		// For now we just let it finish and write to the
129		// buffered channel.
130		return nil, &DNSError{
131			Name:      name,
132			Err:       ctx.Err().Error(),
133			IsTimeout: ctx.Err() == context.DeadlineExceeded,
134		}
135	}
136}
137
138func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) {
139	if r.PreferGo {
140		return lookupPortMap(network, service)
141	}
142
143	// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
144	acquireThread()
145	defer releaseThread()
146	var stype int32
147	switch network {
148	case "tcp4", "tcp6":
149		stype = syscall.SOCK_STREAM
150	case "udp4", "udp6":
151		stype = syscall.SOCK_DGRAM
152	}
153	hints := syscall.AddrinfoW{
154		Family:   syscall.AF_UNSPEC,
155		Socktype: stype,
156		Protocol: syscall.IPPROTO_IP,
157	}
158	var result *syscall.AddrinfoW
159	e := syscall.GetAddrInfoW(nil, syscall.StringToUTF16Ptr(service), &hints, &result)
160	if e != nil {
161		if port, err := lookupPortMap(network, service); err == nil {
162			return port, nil
163		}
164		return 0, &DNSError{Err: winError("getaddrinfow", e).Error(), Name: network + "/" + service}
165	}
166	defer syscall.FreeAddrInfoW(result)
167	if result == nil {
168		return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service}
169	}
170	addr := unsafe.Pointer(result.Addr)
171	switch result.Family {
172	case syscall.AF_INET:
173		a := (*syscall.RawSockaddrInet4)(addr)
174		return int(syscall.Ntohs(a.Port)), nil
175	case syscall.AF_INET6:
176		a := (*syscall.RawSockaddrInet6)(addr)
177		return int(syscall.Ntohs(a.Port)), nil
178	}
179	return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service}
180}
181
182func (*Resolver) lookupCNAME(ctx context.Context, name string) (string, error) {
183	// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
184	acquireThread()
185	defer releaseThread()
186	var r *syscall.DNSRecord
187	e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil)
188	// windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s
189	if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS {
190		// if there are no aliases, the canonical name is the input name
191		return absDomainName([]byte(name)), nil
192	}
193	if e != nil {
194		return "", &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
195	}
196	defer syscall.DnsRecordListFree(r, 1)
197
198	resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r)
199	cname := syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(resolved))[:])
200	return absDomainName([]byte(cname)), nil
201}
202
203func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
204	// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
205	acquireThread()
206	defer releaseThread()
207	var target string
208	if service == "" && proto == "" {
209		target = name
210	} else {
211		target = "_" + service + "._" + proto + "." + name
212	}
213	var r *syscall.DNSRecord
214	e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &r, nil)
215	if e != nil {
216		return "", nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: target}
217	}
218	defer syscall.DnsRecordListFree(r, 1)
219
220	srvs := make([]*SRV, 0, 10)
221	for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) {
222		v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0]))
223		srvs = append(srvs, &SRV{absDomainName([]byte(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:]))), v.Port, v.Priority, v.Weight})
224	}
225	byPriorityWeight(srvs).sort()
226	return absDomainName([]byte(target)), srvs, nil
227}
228
229func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
230	// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
231	acquireThread()
232	defer releaseThread()
233	var r *syscall.DNSRecord
234	e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &r, nil)
235	if e != nil {
236		return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
237	}
238	defer syscall.DnsRecordListFree(r, 1)
239
240	mxs := make([]*MX, 0, 10)
241	for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) {
242		v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0]))
243		mxs = append(mxs, &MX{absDomainName([]byte(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.NameExchange))[:]))), v.Preference})
244	}
245	byPref(mxs).sort()
246	return mxs, nil
247}
248
249func (*Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
250	// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
251	acquireThread()
252	defer releaseThread()
253	var r *syscall.DNSRecord
254	e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &r, nil)
255	if e != nil {
256		return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
257	}
258	defer syscall.DnsRecordListFree(r, 1)
259
260	nss := make([]*NS, 0, 10)
261	for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) {
262		v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
263		nss = append(nss, &NS{absDomainName([]byte(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:])))})
264	}
265	return nss, nil
266}
267
268func (*Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) {
269	// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
270	acquireThread()
271	defer releaseThread()
272	var r *syscall.DNSRecord
273	e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &r, nil)
274	if e != nil {
275		return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
276	}
277	defer syscall.DnsRecordListFree(r, 1)
278
279	txts := make([]string, 0, 10)
280	for _, p := range validRecs(r, syscall.DNS_TYPE_TEXT, name) {
281		d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0]))
282		for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount] {
283			s := syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(v))[:])
284			txts = append(txts, s)
285		}
286	}
287	return txts, nil
288}
289
290func (*Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) {
291	// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
292	acquireThread()
293	defer releaseThread()
294	arpa, err := reverseaddr(addr)
295	if err != nil {
296		return nil, err
297	}
298	var r *syscall.DNSRecord
299	e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &r, nil)
300	if e != nil {
301		return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: addr}
302	}
303	defer syscall.DnsRecordListFree(r, 1)
304
305	ptrs := make([]string, 0, 10)
306	for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) {
307		v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
308		ptrs = append(ptrs, absDomainName([]byte(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))))
309	}
310	return ptrs, nil
311}
312
313const dnsSectionMask = 0x0003
314
315// returns only results applicable to name and resolves CNAME entries
316func validRecs(r *syscall.DNSRecord, dnstype uint16, name string) []*syscall.DNSRecord {
317	cname := syscall.StringToUTF16Ptr(name)
318	if dnstype != syscall.DNS_TYPE_CNAME {
319		cname = resolveCNAME(cname, r)
320	}
321	rec := make([]*syscall.DNSRecord, 0, 10)
322	for p := r; p != nil; p = p.Next {
323		if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
324			continue
325		}
326		if p.Type != dnstype {
327			continue
328		}
329		if !syscall.DnsNameCompare(cname, p.Name) {
330			continue
331		}
332		rec = append(rec, p)
333	}
334	return rec
335}
336
337// returns the last CNAME in chain
338func resolveCNAME(name *uint16, r *syscall.DNSRecord) *uint16 {
339	// limit cname resolving to 10 in case of a infinite CNAME loop
340Cname:
341	for cnameloop := 0; cnameloop < 10; cnameloop++ {
342		for p := r; p != nil; p = p.Next {
343			if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
344				continue
345			}
346			if p.Type != syscall.DNS_TYPE_CNAME {
347				continue
348			}
349			if !syscall.DnsNameCompare(name, p.Name) {
350				continue
351			}
352			name = (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0])).Host
353			continue Cname
354		}
355		break
356	}
357	return name
358}
359