1package mdns
2
3import (
4	"fmt"
5	"log"
6	"net"
7	"strings"
8	"sync/atomic"
9	"time"
10
11	"github.com/miekg/dns"
12	"golang.org/x/net/ipv4"
13	"golang.org/x/net/ipv6"
14)
15
// ServiceEntry is returned after we query for a service. It accumulates the
// PTR, SRV, TXT, A and AAAA records seen for one service instance and is
// delivered to the caller once complete reports true.
type ServiceEntry struct {
	Name       string   // instance name (a PTR target or a record owner name)
	Host       string   // target host from the SRV record
	AddrV4     net.IP   // address from an A record
	AddrV6     net.IP   // address from an AAAA record
	Port       int      // port from the SRV record
	Info       string   // TXT strings joined with "|"
	InfoFields []string // TXT strings as received

	// Addr holds whichever of AddrV4 or AddrV6 was filled in most recently.
	//
	// Deprecated: use AddrV4 or AddrV6 instead.
	Addr net.IP

	hasTXT bool // a TXT record has been seen for this entry
	sent   bool // the entry has already been handed to the caller
}

// complete reports whether the entry has everything needed to be delivered:
// at least one address (AddrV4, AddrV6 or the deprecated Addr), a non-zero
// port, and a TXT record.
func (s *ServiceEntry) complete() bool {
	return (s.AddrV4 != nil || s.AddrV6 != nil || s.Addr != nil) && s.Port != 0 && s.hasTXT
}
36
37// QueryParam is used to customize how a Lookup is performed
38type QueryParam struct {
39	Service             string               // Service to lookup
40	Domain              string               // Lookup domain, default "local"
41	Timeout             time.Duration        // Lookup timeout, default 1 second
42	Interface           *net.Interface       // Multicast interface to use
43	Entries             chan<- *ServiceEntry // Entries Channel
44	WantUnicastResponse bool                 // Unicast response desired, as per 5.4 in RFC
45}
46
47// DefaultParams is used to return a default set of QueryParam's
48func DefaultParams(service string) *QueryParam {
49	return &QueryParam{
50		Service:             service,
51		Domain:              "local",
52		Timeout:             time.Second,
53		Entries:             make(chan *ServiceEntry),
54		WantUnicastResponse: false, // TODO(reddaly): Change this default.
55	}
56}
57
58// Query looks up a given service, in a domain, waiting at most
59// for a timeout before finishing the query. The results are streamed
60// to a channel. Sends will not block, so clients should make sure to
61// either read or buffer.
62func Query(params *QueryParam) error {
63	// Create a new client
64	client, err := newClient()
65	if err != nil {
66		return err
67	}
68	defer client.Close()
69
70	// Set the multicast interface
71	if params.Interface != nil {
72		if err := client.setInterface(params.Interface); err != nil {
73			return err
74		}
75	}
76
77	// Ensure defaults are set
78	if params.Domain == "" {
79		params.Domain = "local"
80	}
81	if params.Timeout == 0 {
82		params.Timeout = time.Second
83	}
84
85	// Run the query
86	return client.query(params)
87}
88
89// Lookup is the same as Query, however it uses all the default parameters
90func Lookup(service string, entries chan<- *ServiceEntry) error {
91	params := DefaultParams(service)
92	params.Entries = entries
93	return Query(params)
94}
95
// client provides a query interface that can be used to
// search for service providers using mDNS.
//
// Any of the four sockets may be nil when newClient could not open it.
type client struct {
	// Sockets bound to ephemeral local ports; queries are written from these.
	ipv4UnicastConn *net.UDPConn
	ipv6UnicastConn *net.UDPConn

	// Sockets joined to the mDNS multicast groups.
	ipv4MulticastConn *net.UDPConn
	ipv6MulticastConn *net.UDPConn

	// closed is switched from 0 to 1 atomically by Close; recv polls it.
	closed int32
	// closedCh is closed by Close so that recv goroutines blocked while
	// forwarding a message can exit.
	closedCh chan struct{}
}
108
109// NewClient creates a new mdns Client that can be used to query
110// for records
111func newClient() (*client, error) {
112	// TODO(reddaly): At least attempt to bind to the port required in the spec.
113	// Create a IPv4 listener
114	uconn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
115	if err != nil {
116		log.Printf("[ERR] mdns: Failed to bind to udp4 port: %v", err)
117	}
118	uconn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
119	if err != nil {
120		log.Printf("[ERR] mdns: Failed to bind to udp6 port: %v", err)
121	}
122
123	if uconn4 == nil && uconn6 == nil {
124		return nil, fmt.Errorf("failed to bind to any unicast udp port")
125	}
126
127	mconn4, err := net.ListenMulticastUDP("udp4", nil, ipv4Addr)
128	if err != nil {
129		log.Printf("[ERR] mdns: Failed to bind to udp4 port: %v", err)
130	}
131	mconn6, err := net.ListenMulticastUDP("udp6", nil, ipv6Addr)
132	if err != nil {
133		log.Printf("[ERR] mdns: Failed to bind to udp6 port: %v", err)
134	}
135
136	if mconn4 == nil && mconn6 == nil {
137		return nil, fmt.Errorf("failed to bind to any multicast udp port")
138	}
139
140	c := &client{
141		ipv4MulticastConn: mconn4,
142		ipv6MulticastConn: mconn6,
143		ipv4UnicastConn:   uconn4,
144		ipv6UnicastConn:   uconn6,
145		closedCh:          make(chan struct{}),
146	}
147	return c, nil
148}
149
150// Close is used to cleanup the client
151func (c *client) Close() error {
152	if !atomic.CompareAndSwapInt32(&c.closed, 0, 1) {
153		// something else already closed it
154		return nil
155	}
156
157	log.Printf("[INFO] mdns: Closing client %v", *c)
158	close(c.closedCh)
159
160	if c.ipv4UnicastConn != nil {
161		c.ipv4UnicastConn.Close()
162	}
163	if c.ipv6UnicastConn != nil {
164		c.ipv6UnicastConn.Close()
165	}
166	if c.ipv4MulticastConn != nil {
167		c.ipv4MulticastConn.Close()
168	}
169	if c.ipv6MulticastConn != nil {
170		c.ipv6MulticastConn.Close()
171	}
172
173	return nil
174}
175
176// setInterface is used to set the query interface, uses system
177// default if not provided
178func (c *client) setInterface(iface *net.Interface) error {
179	p := ipv4.NewPacketConn(c.ipv4UnicastConn)
180	if err := p.SetMulticastInterface(iface); err != nil {
181		return err
182	}
183	p2 := ipv6.NewPacketConn(c.ipv6UnicastConn)
184	if err := p2.SetMulticastInterface(iface); err != nil {
185		return err
186	}
187	p = ipv4.NewPacketConn(c.ipv4MulticastConn)
188	if err := p.SetMulticastInterface(iface); err != nil {
189		return err
190	}
191	p2 = ipv6.NewPacketConn(c.ipv6MulticastConn)
192	if err := p2.SetMulticastInterface(iface); err != nil {
193		return err
194	}
195	return nil
196}
197
// query is used to perform a lookup and stream results.
//
// It sends a PTR question for "<Service>.<Domain>." and then processes
// incoming responses until params.Timeout elapses. Records from each
// response's Answer and Extra sections are merged into ServiceEntry values
// kept in a local map keyed by name. After each response:
//   - the entry touched by the last recognised record is checked;
//   - if it is complete, it is sent to params.Entries at most once, with a
//     non-blocking send (dropped if nobody is ready to receive);
//   - otherwise a follow-up PTR query is sent for that entry's name.
//
// It returns an error only if the initial query cannot be sent; failures of
// follow-up queries are logged. It returns nil when the timeout fires.
func (c *client) query(params *QueryParam) error {
	// Create the service name
	serviceAddr := fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain))

	// Start listening for response packets. recv returns at once for a nil
	// socket and otherwise runs until the client is closed.
	msgCh := make(chan *dns.Msg, 32)
	go c.recv(c.ipv4UnicastConn, msgCh)
	go c.recv(c.ipv6UnicastConn, msgCh)
	go c.recv(c.ipv4MulticastConn, msgCh)
	go c.recv(c.ipv6MulticastConn, msgCh)

	// Send the query
	m := new(dns.Msg)
	m.SetQuestion(serviceAddr, dns.TypePTR)
	// RFC 6762, section 18.12.  Repurposing of Top Bit of qclass in Question
	// Section
	//
	// In the Question Section of a Multicast DNS query, the top bit of the qclass
	// field is used to indicate that unicast responses are preferred for this
	// particular question.  (See Section 5.4.)
	if params.WantUnicastResponse {
		m.Question[0].Qclass |= 1 << 15
	}
	m.RecursionDesired = false
	if err := c.sendQuery(m); err != nil {
		return err
	}

	// Map the in-progress responses. Several names (e.g. an instance name and
	// its SRV target host) may point at the same entry via alias.
	inprogress := make(map[string]*ServiceEntry)

	// Listen until we reach the timeout
	finish := time.After(params.Timeout)
	for {
		select {
		case resp := <-msgCh:
			// inp ends up as the entry touched by the last recognised record in
			// this response. NOTE(review): other entries updated by the same
			// response are not checked for completeness here; confirm that is
			// acceptable for multi-service responses.
			var inp *ServiceEntry
			for _, answer := range append(resp.Answer, resp.Extra...) {
				// TODO(reddaly): Check that response corresponds to serviceAddr?
				switch rr := answer.(type) {
				case *dns.PTR:
					// Create new entry for this
					inp = ensureName(inprogress, rr.Ptr)

				case *dns.SRV:
					// Check for a target mismatch. Aliasing the target host to
					// this entry lets later A/AAAA records for the host fill in
					// its addresses. NOTE(review): alias overwrites any entry
					// already stored under the target name, so an address record
					// seen before this SRV record is not carried over.
					if rr.Target != rr.Hdr.Name {
						alias(inprogress, rr.Hdr.Name, rr.Target)
					}

					// Get the port
					inp = ensureName(inprogress, rr.Hdr.Name)
					inp.Host = rr.Target
					inp.Port = int(rr.Port)

				case *dns.TXT:
					// Pull out the txt
					inp = ensureName(inprogress, rr.Hdr.Name)
					inp.Info = strings.Join(rr.Txt, "|")
					inp.InfoFields = rr.Txt
					inp.hasTXT = true

				case *dns.A:
					// Pull out the IP
					inp = ensureName(inprogress, rr.Hdr.Name)
					inp.Addr = rr.A // @Deprecated
					inp.AddrV4 = rr.A

				case *dns.AAAA:
					// Pull out the IP
					inp = ensureName(inprogress, rr.Hdr.Name)
					inp.Addr = rr.AAAA // @Deprecated
					inp.AddrV6 = rr.AAAA
				}
			}

			// No PTR/SRV/TXT/A/AAAA record in this response.
			if inp == nil {
				continue
			}

			// Check if this entry is complete
			if inp.complete() {
				// Deliver each entry only once.
				if inp.sent {
					continue
				}
				inp.sent = true
				// Non-blocking send: the result is dropped if the receiver
				// is not ready.
				select {
				case params.Entries <- inp:
				default:
				}
			} else {
				// Fire off a node specific query. NOTE(review): this asks for
				// PTR on the instance name; confirm whether SRV/TXT/A queries
				// were intended instead.
				m := new(dns.Msg)
				m.SetQuestion(inp.Name, dns.TypePTR)
				m.RecursionDesired = false
				if err := c.sendQuery(m); err != nil {
					log.Printf("[ERR] mdns: Failed to query instance %s: %v", inp.Name, err)
				}
			}
		case <-finish:
			return nil
		}
	}
}
303
304// sendQuery is used to multicast a query out
305func (c *client) sendQuery(q *dns.Msg) error {
306	buf, err := q.Pack()
307	if err != nil {
308		return err
309	}
310	if c.ipv4UnicastConn != nil {
311		_, err = c.ipv4UnicastConn.WriteToUDP(buf, ipv4Addr)
312		if err != nil {
313			return err
314		}
315	}
316	if c.ipv6UnicastConn != nil {
317		_, err = c.ipv6UnicastConn.WriteToUDP(buf, ipv6Addr)
318		if err != nil {
319			return err
320		}
321	}
322	return nil
323}
324
325// recv is used to receive until we get a shutdown
326func (c *client) recv(l *net.UDPConn, msgCh chan *dns.Msg) {
327	if l == nil {
328		return
329	}
330	buf := make([]byte, 65536)
331	for atomic.LoadInt32(&c.closed) == 0 {
332		n, err := l.Read(buf)
333
334		if atomic.LoadInt32(&c.closed) == 1 {
335			return
336		}
337
338		if err != nil {
339			log.Printf("[ERR] mdns: Failed to read packet: %v", err)
340			continue
341		}
342		msg := new(dns.Msg)
343		if err := msg.Unpack(buf[:n]); err != nil {
344			log.Printf("[ERR] mdns: Failed to unpack packet: %v", err)
345			continue
346		}
347		select {
348		case msgCh <- msg:
349		case <-c.closedCh:
350			return
351		}
352	}
353}
354
355// ensureName is used to ensure the named node is in progress
356func ensureName(inprogress map[string]*ServiceEntry, name string) *ServiceEntry {
357	if inp, ok := inprogress[name]; ok {
358		return inp
359	}
360	inp := &ServiceEntry{
361		Name: name,
362	}
363	inprogress[name] = inp
364	return inp
365}
366
367// alias is used to setup an alias between two entries
368func alias(inprogress map[string]*ServiceEntry, src, dst string) {
369	srcEntry := ensureName(inprogress, src)
370	inprogress[dst] = srcEntry
371}
372