1package agent
2
3import (
4	"fmt"
5	"io"
6	"log"
7	"net"
8	"time"
9
10	"github.com/hashicorp/mdns"
11)
12
const (
	// mdnsPollInterval is the delay between successive mDNS queries
	// started by the run loop.
	mdnsPollInterval = 60 * time.Second

	// mdnsQuietInterval is how long the run loop waits after the most
	// recently queued host before attempting to join the queued hosts.
	mdnsQuietInterval = 100 * time.Millisecond
)
17
// AgentMDNS is used to advertise ourself using mDNS and to
// attempt to join peers periodically using mDNS queries.
//
// The seen map is read and written by the run goroutine without any
// locking; within this file no other goroutine touches it.
type AgentMDNS struct {
	agent    *Agent              // agent whose Join is called with discovered addresses
	discover string              // cluster name used to build the mDNS service name
	logger   *log.Logger         // receives join and poll log messages
	seen     map[string]struct{} // "ip:port" addresses already passed to a join attempt
	server   *mdns.Server        // mDNS server created for this node's service record
	replay   bool                // forwarded unchanged as the second argument to agent.Join
	iface    *net.Interface      // interface handed to both the mDNS server and the queries
}
29
30// NewAgentMDNS is used to create a new AgentMDNS
31func NewAgentMDNS(agent *Agent, logOutput io.Writer, replay bool,
32	node, discover string, iface *net.Interface, bind net.IP, port int) (*AgentMDNS, error) {
33	// Create the service
34	service, err := mdns.NewMDNSService(
35		node,
36		mdnsName(discover),
37		"",
38		"",
39		port,
40		[]net.IP{bind},
41		[]string{fmt.Sprintf("Serf '%s' cluster", discover)})
42	if err != nil {
43		return nil, err
44	}
45
46	// Configure mdns server
47	conf := &mdns.Config{
48		Zone:  service,
49		Iface: iface,
50	}
51
52	// Create the server
53	server, err := mdns.NewServer(conf)
54	if err != nil {
55		return nil, err
56	}
57
58	// Initialize the AgentMDNS
59	m := &AgentMDNS{
60		agent:    agent,
61		discover: discover,
62		logger:   log.New(logOutput, "", log.LstdFlags),
63		seen:     make(map[string]struct{}),
64		server:   server,
65		replay:   replay,
66		iface:    iface,
67	}
68
69	// Start the background workers
70	go m.run()
71	return m, nil
72}
73
74// run is a long running goroutine that scans for new hosts periodically
75func (m *AgentMDNS) run() {
76	hosts := make(chan *mdns.ServiceEntry, 32)
77	poll := time.After(0)
78	var quiet <-chan time.Time
79	var join []string
80
81	for {
82		select {
83		case h := <-hosts:
84			// Format the host address
85			addr := net.TCPAddr{IP: h.Addr, Port: h.Port}
86			addrS := addr.String()
87
88			// Skip if we've handled this host already
89			if _, ok := m.seen[addrS]; ok {
90				continue
91			}
92
93			// Queue for handling
94			join = append(join, addrS)
95			quiet = time.After(mdnsQuietInterval)
96
97		case <-quiet:
98			// Attempt the join
99			n, err := m.agent.Join(join, m.replay)
100			if err != nil {
101				m.logger.Printf("[ERR] agent.mdns: Failed to join: %v", err)
102			}
103			if n > 0 {
104				m.logger.Printf("[INFO] agent.mdns: Joined %d hosts", n)
105			}
106
107			// Mark all as seen
108			for _, n := range join {
109				m.seen[n] = struct{}{}
110			}
111			join = nil
112
113		case <-poll:
114			poll = time.After(mdnsPollInterval)
115			go m.poll(hosts)
116		}
117	}
118}
119
120// poll is invoked periodically to check for new hosts
121func (m *AgentMDNS) poll(hosts chan *mdns.ServiceEntry) {
122	params := mdns.QueryParam{
123		Service:   mdnsName(m.discover),
124		Interface: m.iface,
125		Entries:   hosts,
126	}
127	if err := mdns.Query(&params); err != nil {
128		m.logger.Printf("[ERR] agent.mdns: Failed to poll for new hosts: %v", err)
129	}
130}
131
// mdnsName builds the mDNS service name, "_serf_<discover>._tcp", that is
// used both when registering this node and when querying for peers.
func mdnsName(discover string) string {
	const (
		prefix = "_serf_"
		suffix = "._tcp"
	)
	// All operands are strings, so Sprint joins them without separators.
	return fmt.Sprint(prefix, discover, suffix)
}
136