1package etchosts
2
3import (
4	"bufio"
5	"bytes"
6	"fmt"
7	"io"
8	"io/ioutil"
9	"os"
10	"regexp"
11	"strings"
12	"sync"
13)
14
// Record is one entry of a hosts file: an IP address and the host names that
// map to it. Hosts is written verbatim after the address, so several names go
// in a single space-separated string (e.g. "localhost ip6-localhost").
type Record struct {
	Hosts string
	IP    string
}

// WriteTo emits r to w as a single hosts-file line of the form
// "<IP>\t<Hosts>\n". It returns the number of bytes written and any error
// reported by w. Its signature matches io.WriterTo.
func (r Record) WriteTo(w io.Writer) (int64, error) {
	line := fmt.Sprintf("%s\t%s\n", r.IP, r.Hosts)
	written, err := w.Write([]byte(line))
	return int64(written), err
}
26
27var (
28	// Default hosts config records slice
29	defaultContent = []Record{
30		{Hosts: "localhost", IP: "127.0.0.1"},
31		{Hosts: "localhost ip6-localhost ip6-loopback", IP: "::1"},
32		{Hosts: "ip6-localnet", IP: "fe00::0"},
33		{Hosts: "ip6-mcastprefix", IP: "ff00::0"},
34		{Hosts: "ip6-allnodes", IP: "ff02::1"},
35		{Hosts: "ip6-allrouters", IP: "ff02::2"},
36	}
37
38	// A cache of path level locks for synchronizing /etc/hosts
39	// updates on a file level
40	pathMap = make(map[string]*sync.Mutex)
41
42	// A package level mutex to synchronize the cache itself
43	pathMutex sync.Mutex
44)
45
46func pathLock(path string) func() {
47	pathMutex.Lock()
48	defer pathMutex.Unlock()
49
50	pl, ok := pathMap[path]
51	if !ok {
52		pl = &sync.Mutex{}
53		pathMap[path] = pl
54	}
55
56	pl.Lock()
57	return func() {
58		pl.Unlock()
59	}
60}
61
62// Drop drops the path string from the path cache
63func Drop(path string) {
64	pathMutex.Lock()
65	defer pathMutex.Unlock()
66
67	delete(pathMap, path)
68}
69
70// Build function
71// path is path to host file string required
72// IP, hostname, and domainname set main record leave empty for no master record
73// extraContent is an array of extra host records.
74func Build(path, IP, hostname, domainname string, extraContent []Record) error {
75	defer pathLock(path)()
76
77	content := bytes.NewBuffer(nil)
78	if IP != "" {
79		//set main record
80		var mainRec Record
81		mainRec.IP = IP
82		// User might have provided a FQDN in hostname or split it across hostname
83		// and domainname.  We want the FQDN and the bare hostname.
84		fqdn := hostname
85		if domainname != "" {
86			fqdn = fmt.Sprintf("%s.%s", fqdn, domainname)
87		}
88		parts := strings.SplitN(fqdn, ".", 2)
89		if len(parts) == 2 {
90			mainRec.Hosts = fmt.Sprintf("%s %s", fqdn, parts[0])
91		} else {
92			mainRec.Hosts = fqdn
93		}
94		if _, err := mainRec.WriteTo(content); err != nil {
95			return err
96		}
97	}
98	// Write defaultContent slice to buffer
99	for _, r := range defaultContent {
100		if _, err := r.WriteTo(content); err != nil {
101			return err
102		}
103	}
104	// Write extra content from function arguments
105	for _, r := range extraContent {
106		if _, err := r.WriteTo(content); err != nil {
107			return err
108		}
109	}
110
111	return ioutil.WriteFile(path, content.Bytes(), 0644)
112}
113
114// Add adds an arbitrary number of Records to an already existing /etc/hosts file
115func Add(path string, recs []Record) error {
116	defer pathLock(path)()
117
118	if len(recs) == 0 {
119		return nil
120	}
121
122	b, err := mergeRecords(path, recs)
123	if err != nil {
124		return err
125	}
126
127	return ioutil.WriteFile(path, b, 0644)
128}
129
130func mergeRecords(path string, recs []Record) ([]byte, error) {
131	f, err := os.Open(path)
132	if err != nil {
133		return nil, err
134	}
135	defer f.Close()
136
137	content := bytes.NewBuffer(nil)
138
139	if _, err := content.ReadFrom(f); err != nil {
140		return nil, err
141	}
142
143	for _, r := range recs {
144		if _, err := r.WriteTo(content); err != nil {
145			return nil, err
146		}
147	}
148
149	return content.Bytes(), nil
150}
151
152// Delete deletes an arbitrary number of Records already existing in /etc/hosts file
153func Delete(path string, recs []Record) error {
154	defer pathLock(path)()
155
156	if len(recs) == 0 {
157		return nil
158	}
159	old, err := os.Open(path)
160	if err != nil {
161		return err
162	}
163
164	var buf bytes.Buffer
165
166	s := bufio.NewScanner(old)
167	eol := []byte{'\n'}
168loop:
169	for s.Scan() {
170		b := s.Bytes()
171		if len(b) == 0 {
172			continue
173		}
174
175		if b[0] == '#' {
176			buf.Write(b)
177			buf.Write(eol)
178			continue
179		}
180		for _, r := range recs {
181			if bytes.HasSuffix(b, []byte("\t"+r.Hosts)) {
182				continue loop
183			}
184		}
185		buf.Write(b)
186		buf.Write(eol)
187	}
188	old.Close()
189	if err := s.Err(); err != nil {
190		return err
191	}
192	return ioutil.WriteFile(path, buf.Bytes(), 0644)
193}
194
195// Update all IP addresses where hostname matches.
196// path is path to host file
197// IP is new IP address
198// hostname is hostname to search for to replace IP
199func Update(path, IP, hostname string) error {
200	defer pathLock(path)()
201
202	old, err := ioutil.ReadFile(path)
203	if err != nil {
204		return err
205	}
206	var re = regexp.MustCompile(fmt.Sprintf("(\\S*)(\\t%s)(\\s|\\.)", regexp.QuoteMeta(hostname)))
207	return ioutil.WriteFile(path, re.ReplaceAll(old, []byte(IP+"$2"+"$3")), 0644)
208}
209