1package peerdiscovery
2
3import (
4	"fmt"
5	"net"
6	"strconv"
7	"sync"
8	"time"
9
10	"golang.org/x/net/ipv4"
11	"golang.org/x/net/ipv6"
12)
13
// IPVersion selects which Internet Protocol family (IPv4 or IPv6) the
// discovery sockets use.
type IPVersion uint

// Supported values for Settings.IPVersion.
const (
	// IPv4 selects IPv4 multicast (network "udp4").
	IPv4 = IPVersion(4)
	// IPv6 selects IPv6 multicast (network "udp6").
	IPv6 = IPVersion(6)
)
21
// Discovered describes one peer found during discovery: the host part of the
// address its datagram came from (port stripped) and the payload that
// datagram carried, if any.
type Discovered struct {
	// Address is the host portion of the discovered peer's source address.
	Address string
	// Payload holds the bytes the peer sent.
	Payload []byte
}

// String formats the peer as "address: <Address>, payload: <Payload>",
// writing the payload bytes verbatim as text.
func (peer Discovered) String() string {
	const layout = "address: %s, payload: %s"
	return fmt.Sprintf(layout, peer.Address, string(peer.Payload))
}
35
36// Settings are the settings that can be specified for
37// doing peer discovery.
38type Settings struct {
39	// Limit is the number of peers to discover, use < 1 for unlimited.
40	Limit int
41	// Port is the port to broadcast on (the peers must also broadcast using the same port).
42	// The default port is 9999.
43	Port string
44	// MulticastAddress specifies the multicast address.
45	// You should be able to use any of 224.0.0.0/4 or ff00::/8.
46	// By default it uses the Simple Service Discovery Protocol
47	// address (239.255.255.250 for IPv4 or ff02::c for IPv6).
48	MulticastAddress string
49	// Payload is the bytes that are sent out with each broadcast. Must be short.
50	Payload []byte
51	// PayloadFunc is the function that will be called to dynamically generate payload
52	// before every broadcast. If this pointer is nil `Payload` field will be broadcasted instead.
53	PayloadFunc func() []byte
54	// Delay is the amount of time between broadcasts. The default delay is 1 second.
55	Delay time.Duration
56	// TimeLimit is the amount of time to spend discovering, if the limit is not reached.
57	// A negative limit indiciates scanning until the limit was reached or, if an
58	// unlimited scanning was requested, no timeout.
59	// The default time limit is 10 seconds.
60	TimeLimit time.Duration
61	// StopChan is a channel to stop the peer discvoery immediatley after reception.
62	StopChan chan struct{}
63	// AllowSelf will allow discovery the local machine (default false)
64	AllowSelf bool
65	// DisableBroadcast will not allow sending out a broadcast
66	DisableBroadcast bool
67	// IPVersion specifies the version of the Internet Protocol (default IPv4)
68	IPVersion IPVersion
69	// Notify will be called each time a new peer was discovered.
70	// The default is nil, which means no notification whatsoever.
71	Notify func(Discovered)
72
73	portNum                 int
74	multicastAddressNumbers net.IP
75}
76
77// peerDiscovery is the object that can do the discovery for finding LAN peers.
78type peerDiscovery struct {
79	settings Settings
80
81	received map[string][]byte
82	sync.RWMutex
83}
84
85// initialize returns a new peerDiscovery object which can be used to discover peers.
86// The settings are optional. If any setting is not supplied, then defaults are used.
87// See the Settings for more information.
88func initialize(settings Settings) (p *peerDiscovery, err error) {
89	p = new(peerDiscovery)
90	p.Lock()
91	defer p.Unlock()
92
93	// initialize settings
94	p.settings = settings
95
96	// defaults
97	if p.settings.Port == "" {
98		p.settings.Port = "9999"
99	}
100	if p.settings.IPVersion == 0 {
101		p.settings.IPVersion = IPv4
102	}
103	if p.settings.MulticastAddress == "" {
104		if p.settings.IPVersion == IPv4 {
105			p.settings.MulticastAddress = "239.255.255.250"
106		} else {
107			p.settings.MulticastAddress = "ff02::c"
108		}
109	}
110	if len(p.settings.Payload) == 0 {
111		p.settings.Payload = []byte("hi")
112	}
113	if p.settings.Delay == 0 {
114		p.settings.Delay = 1 * time.Second
115	}
116	if p.settings.TimeLimit == 0 {
117		p.settings.TimeLimit = 10 * time.Second
118	}
119	if p.settings.StopChan == nil {
120		p.settings.StopChan = make(chan struct{})
121	}
122	p.received = make(map[string][]byte)
123	p.settings.multicastAddressNumbers = net.ParseIP(p.settings.MulticastAddress)
124	if p.settings.multicastAddressNumbers == nil {
125		err = fmt.Errorf("Multicast Address %s could not be converted to an IP",
126			p.settings.MulticastAddress)
127		return
128	}
129	p.settings.portNum, err = strconv.Atoi(p.settings.Port)
130	if err != nil {
131		return
132	}
133	return
134}
135
// NetPacketConn is the subset of packet-connection behaviour that discovery
// needs. It lets the IPv4 and IPv6 packet connections, whose ReadFrom and
// WriteTo signatures differ from net.PacketConn, be used interchangeably.
type NetPacketConn interface {
	// ReadFrom reads one datagram into buf, returning its size and sender.
	ReadFrom(buf []byte) (int, net.Addr, error)
	// WriteTo sends buf as one datagram to dst.
	WriteTo(buf []byte, dst net.Addr) (int, error)
	// JoinGroup joins the multicast group on the given interface.
	JoinGroup(ifi *net.Interface, group net.Addr) error
	// SetMulticastInterface selects the interface used for outgoing multicast.
	SetMulticastInterface(ifi *net.Interface) error
	// SetMulticastTTL sets the TTL (IPv4) or hop limit (IPv6) for outgoing multicast.
	SetMulticastTTL(int) error
}
143
// filterInterfaces returns the network interfaces that are up, support
// broadcasting, and carry at least one address of the requested family
// (IPv4 when ipv4 is true, otherwise IPv6).
func filterInterfaces(ipv4 bool) (ifaces []net.Interface, err error) {
	allIfaces, err := net.Interfaces()
	if err != nil {
		return nil, err
	}
	ifaces = make([]net.Interface, 0, len(allIfaces))
	for _, iface := range allIfaces {
		if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagBroadcast == 0 {
			// interface is down or does not support broadcasting
			continue
		}
		addrs, addrErr := iface.Addrs()
		if addrErr != nil {
			// The address family cannot be checked, so the interface is skipped.
			continue
		}
		for _, a := range addrs {
			// Addrs returns net.Addr values; accept both concrete IP kinds
			// instead of asserting one and panicking on the other.
			var ip net.IP
			switch v := a.(type) {
			case *net.IPNet:
				if v != nil {
					ip = v.IP
				}
			case *net.IPAddr:
				if v != nil {
					ip = v.IP
				}
			}
			if ip == nil {
				continue
			}
			if (ip.To4() != nil) == ipv4 {
				// IP family matches, go on and use interface
				ifaces = append(ifaces, iface)
				break
			}
		}
	}
	return ifaces, nil
}
177
178// Discover will use the created settings to scan for LAN peers. It will return
179// an array of the discovered peers and their associate payloads. It will not
180// return broadcasts sent to itself.
181func Discover(settings ...Settings) (discoveries []Discovered, err error) {
182	s := Settings{}
183	if len(settings) > 0 {
184		s = settings[0]
185	}
186	p, err := initialize(s)
187	if err != nil {
188		return
189	}
190
191	p.RLock()
192	address := net.JoinHostPort(p.settings.MulticastAddress, p.settings.Port)
193	portNum := p.settings.portNum
194
195	tickerDuration := p.settings.Delay
196	timeLimit := p.settings.TimeLimit
197	p.RUnlock()
198
199	ifaces, err := filterInterfaces(p.settings.IPVersion == IPv4)
200	if err != nil {
201		return
202	}
203	if len(ifaces) == 0 {
204		err = fmt.Errorf("no multicast interface found")
205		return
206	}
207
208	// Open up a connection
209	c, err := net.ListenPacket(fmt.Sprintf("udp%d", p.settings.IPVersion), address)
210	if err != nil {
211		return
212	}
213	defer c.Close()
214
215	group := p.settings.multicastAddressNumbers
216
217	// ipv{4,6} have an own PacketConn, which does not implement net.PacketConn
218	var p2 NetPacketConn
219	if p.settings.IPVersion == IPv4 {
220		p2 = PacketConn4{ipv4.NewPacketConn(c)}
221	} else {
222		p2 = PacketConn6{ipv6.NewPacketConn(c)}
223	}
224
225	for i := range ifaces {
226		p2.JoinGroup(&ifaces[i], &net.UDPAddr{IP: group, Port: portNum})
227	}
228
229	go p.listen()
230	ticker := time.NewTicker(tickerDuration)
231	defer ticker.Stop()
232	start := time.Now()
233
234	for {
235		exit := false
236
237		p.RLock()
238		if len(p.received) >= p.settings.Limit && p.settings.Limit > 0 {
239			exit = true
240		}
241		p.RUnlock()
242
243		if !s.DisableBroadcast {
244			payload := p.settings.Payload
245			if p.settings.PayloadFunc != nil {
246				payload = p.settings.PayloadFunc()
247			}
248			// write to multicast
249			broadcast(p2, payload, ifaces, &net.UDPAddr{IP: group, Port: portNum})
250		}
251
252		select {
253		case <-p.settings.StopChan:
254			exit = true
255		case <-ticker.C:
256		}
257
258		if exit || timeLimit > 0 && time.Since(start) > timeLimit {
259			break
260		}
261	}
262
263	if !s.DisableBroadcast {
264		payload := p.settings.Payload
265		if p.settings.PayloadFunc != nil {
266			payload = p.settings.PayloadFunc()
267		}
268		// send out broadcast that is finished
269		broadcast(p2, payload, ifaces, &net.UDPAddr{IP: group, Port: portNum})
270	}
271
272	p.RLock()
273	discoveries = make([]Discovered, len(p.received))
274	i := 0
275	for ip, payload := range p.received {
276		discoveries[i] = Discovered{
277			Address: ip,
278			Payload: payload,
279		}
280		i++
281	}
282	p.RUnlock()
283	return
284}
285
286func broadcast(p2 NetPacketConn, payload []byte, ifaces []net.Interface, dst net.Addr) {
287	for i := range ifaces {
288		if errMulticast := p2.SetMulticastInterface(&ifaces[i]); errMulticast != nil {
289			continue
290		}
291		p2.SetMulticastTTL(2)
292		if _, errMulticast := p2.WriteTo([]byte(payload), dst); errMulticast != nil {
293			continue
294		}
295	}
296}
297
const (
	// maxDatagramSize is the size of the receive buffer. 65535 is the largest
	// value of the 16-bit UDP length field, so it bounds the payload of any
	// non-jumbogram UDP datagram over IPv4 (at most 65507 bytes) or IPv6 (at
	// most 65527 bytes). The previous value 66507 was an apparent typo.
	// https://en.wikipedia.org/wiki/User_Datagram_Protocol#Packet_structure
	maxDatagramSize = 65535
)
302
303// Listen binds to the UDP address and port given and writes packets received
304// from that address to a buffer which is passed to a hander
305func (p *peerDiscovery) listen() (recievedBytes []byte, err error) {
306	p.RLock()
307	address := net.JoinHostPort(p.settings.MulticastAddress, p.settings.Port)
308	portNum := p.settings.portNum
309	allowSelf := p.settings.AllowSelf
310	timeLimit := p.settings.TimeLimit
311	notify := p.settings.Notify
312	p.RUnlock()
313	localIPs := getLocalIPs()
314
315	// get interfaces
316	ifaces, err := net.Interfaces()
317	if err != nil {
318		return
319	}
320	// log.Println(ifaces)
321
322	// Open up a connection
323	c, err := net.ListenPacket(fmt.Sprintf("udp%d", p.settings.IPVersion), address)
324	if err != nil {
325		return
326	}
327	defer c.Close()
328
329	group := p.settings.multicastAddressNumbers
330	var p2 NetPacketConn
331	if p.settings.IPVersion == IPv4 {
332		p2 = PacketConn4{ipv4.NewPacketConn(c)}
333	} else {
334		p2 = PacketConn6{ipv6.NewPacketConn(c)}
335	}
336
337	for i := range ifaces {
338		p2.JoinGroup(&ifaces[i], &net.UDPAddr{IP: group, Port: portNum})
339	}
340
341	start := time.Now()
342	// Loop forever reading from the socket
343	for {
344		buffer := make([]byte, maxDatagramSize)
345		var (
346			n       int
347			src     net.Addr
348			errRead error
349		)
350		n, src, errRead = p2.ReadFrom(buffer)
351		if errRead != nil {
352			err = errRead
353			return
354		}
355
356		srcHost, _, _ := net.SplitHostPort(src.String())
357
358		if _, ok := localIPs[srcHost]; ok && !allowSelf {
359			continue
360		}
361
362		// log.Println(src, hex.Dump(buffer[:n]))
363
364		p.Lock()
365		if _, ok := p.received[srcHost]; !ok {
366			p.received[srcHost] = buffer[:n]
367		}
368		p.Unlock()
369
370		if notify != nil {
371			notify(Discovered{
372				Address: srcHost,
373				Payload: buffer[:n],
374			})
375		}
376
377		p.RLock()
378		if len(p.received) >= p.settings.Limit && p.settings.Limit > 0 {
379			p.RUnlock()
380			break
381		}
382		if timeLimit > 0 && time.Since(start) > timeLimit {
383			p.RUnlock()
384			break
385		}
386		p.RUnlock()
387	}
388
389	return
390}
391
// getLocalIPs returns the set of host strings that identify this machine.
// The set contains "localhost", 127.0.0.1, ::1, and every address assigned to
// a local interface. Each interface address appears both bare and with a
// "%<interface name>" suffix. If the interfaces cannot be listed, only the
// three fixed entries are returned.
func getLocalIPs() (ips map[string]struct{}) {
	ips = map[string]struct{}{
		"localhost": {},
		"127.0.0.1": {},
		"::1":       {},
	}

	ifaces, err := net.Interfaces()
	if err != nil {
		return ips
	}

	for i := range ifaces {
		addrs, addrErr := ifaces[i].Addrs()
		if addrErr != nil {
			continue
		}
		name := ifaces[i].Name
		for _, a := range addrs {
			ip, _, parseErr := net.ParseCIDR(a.String())
			if parseErr != nil {
				continue
			}
			host := ip.String()
			ips[host] = struct{}{}
			ips[host+"%"+name] = struct{}{}
		}
	}
	return ips
}
423
424type PacketConn4 struct {
425	*ipv4.PacketConn
426}
427
428// ReadFrom wraps the ipv4 ReadFrom without a control message
429func (pc4 PacketConn4) ReadFrom(buf []byte) (int, net.Addr, error) {
430	n, _, addr, err := pc4.PacketConn.ReadFrom(buf)
431	return n, addr, err
432}
433
434// WriteTo wraps the ipv4 WriteTo without a control message
435func (pc4 PacketConn4) WriteTo(buf []byte, dst net.Addr) (int, error) {
436	return pc4.PacketConn.WriteTo(buf, nil, dst)
437}
438
439type PacketConn6 struct {
440	*ipv6.PacketConn
441}
442
443// ReadFrom wraps the ipv6 ReadFrom without a control message
444func (pc6 PacketConn6) ReadFrom(buf []byte) (int, net.Addr, error) {
445	n, _, addr, err := pc6.PacketConn.ReadFrom(buf)
446	return n, addr, err
447}
448
449// WriteTo wraps the ipv6 WriteTo without a control message
450func (pc6 PacketConn6) WriteTo(buf []byte, dst net.Addr) (int, error) {
451	return pc6.PacketConn.WriteTo(buf, nil, dst)
452}
453
454// SetMulticastTTL wraps the hop limit of ipv6
455func (pc6 PacketConn6) SetMulticastTTL(i int) error {
456	return pc6.SetMulticastHopLimit(i)
457}
458