1package main
2
3import (
4	"bytes"
5	"fmt"
6	"io/ioutil"
7	"math/rand"
8	"net/url"
9	"os"
10	"path/filepath"
11	"strings"
12	"time"
13
14	"github.com/dchest/safefile"
15
16	"github.com/jedisct1/dlog"
17	"github.com/jedisct1/go-dnsstamps"
18	"github.com/jedisct1/go-minisign"
19)
20
// SourceFormat identifies the layout of a source's server list.
type SourceFormat int

// Supported source formats.
const (
	// SourceFormatV2 is the "v2" list format: entries introduced by
	// "## <name>" lines, followed by DNS stamps and description lines.
	// It is typed so that the constant actually carries SourceFormat.
	SourceFormatV2 SourceFormat = iota
)
26
// Scheduling bounds for source downloads.
const (
	// DefaultPrefetchDelay is how long a freshly downloaded source is kept
	// before the next download, and the smallest cache TTL NewSource accepts.
	DefaultPrefetchDelay = 24 * time.Hour

	// MinimumPrefetchInterval is the retry delay after a failed download and
	// the shortest interval PrefetchSources reports.
	MinimumPrefetchInterval = 10 * time.Minute
)
31
32type Source struct {
33	name                    string
34	urls                    []*url.URL
35	format                  SourceFormat
36	in                      []byte
37	minisignKey             *minisign.PublicKey
38	cacheFile               string
39	cacheTTL, prefetchDelay time.Duration
40	refresh                 time.Time
41	prefix                  string
42}
43
44func (source *Source) checkSignature(bin, sig []byte) (err error) {
45	var signature minisign.Signature
46	if signature, err = minisign.DecodeSignature(string(sig)); err == nil {
47		_, err = source.minisignKey.Verify(bin, signature)
48	}
49	return err
50}
51
// timeNow is the clock used when scheduling sources. It defaults to time.Now
// and is a variable only so tests can substitute a fixed time.
var (
	timeNow = time.Now
)
54
55func (source *Source) fetchFromCache(now time.Time) (delay time.Duration, err error) {
56	var bin, sig []byte
57	if bin, err = ioutil.ReadFile(source.cacheFile); err != nil {
58		return
59	}
60	if sig, err = ioutil.ReadFile(source.cacheFile + ".minisig"); err != nil {
61		return
62	}
63	if err = source.checkSignature(bin, sig); err != nil {
64		return
65	}
66	source.in = bin
67	var fi os.FileInfo
68	if fi, err = os.Stat(source.cacheFile); err != nil {
69		return
70	}
71	if elapsed := now.Sub(fi.ModTime()); elapsed < source.cacheTTL {
72		delay = source.prefetchDelay - elapsed
73		dlog.Debugf("Source [%s] cache file [%s] is still fresh, next update: %v", source.name, source.cacheFile, delay)
74	} else {
75		dlog.Debugf("Source [%s] cache file [%s] needs to be refreshed", source.name, source.cacheFile)
76	}
77	return
78}
79
80func writeSource(f string, bin, sig []byte) (err error) {
81	var fSrc, fSig *safefile.File
82	if fSrc, err = safefile.Create(f, 0644); err != nil {
83		return
84	}
85	defer fSrc.Close()
86	if fSig, err = safefile.Create(f+".minisig", 0644); err != nil {
87		return
88	}
89	defer fSig.Close()
90	if _, err = fSrc.Write(bin); err != nil {
91		return
92	}
93	if _, err = fSig.Write(sig); err != nil {
94		return
95	}
96	if err = fSrc.Commit(); err != nil {
97		return
98	}
99	return fSig.Commit()
100}
101
102func (source *Source) writeToCache(bin, sig []byte, now time.Time) {
103	f := source.cacheFile
104	var writeErr error // an error writing cache isn't fatal
105	defer func() {
106		source.in = bin
107		if writeErr == nil {
108			return
109		}
110		if absPath, absErr := filepath.Abs(f); absErr == nil {
111			f = absPath
112		}
113		dlog.Warnf("%s: %s", f, writeErr)
114	}()
115	if !bytes.Equal(source.in, bin) {
116		if writeErr = writeSource(f, bin, sig); writeErr != nil {
117			return
118		}
119	}
120	writeErr = os.Chtimes(f, now, now)
121}
122
123func (source *Source) parseURLs(urls []string) {
124	for _, urlStr := range urls {
125		if srcURL, err := url.Parse(urlStr); err != nil {
126			dlog.Warnf("Source [%s] failed to parse URL [%s]", source.name, urlStr)
127		} else {
128			source.urls = append(source.urls, srcURL)
129		}
130	}
131}
132
133func fetchFromURL(xTransport *XTransport, u *url.URL) (bin []byte, err error) {
134	bin, _, _, _, err = xTransport.Get(u, "", DefaultTimeout)
135	return bin, err
136}
137
// fetchWithCache loads the source into source.in. It prefers the local cache
// and downloads from source.urls only when needed.
//
// The cache is tried first. If the cache cannot be used (missing, unreadable,
// or bad signature) and the source has no URLs, that error is returned.
//
// When URLs are configured, a deferred call sets source.refresh to now+delay,
// using the value the named result delay holds when the function returns.
//
// A download is attempted only when URLs are configured and the cache did not
// report a positive delay. Each URL is tried in order: the list and its
// ".minisig" signature are fetched and verified, and the first URL that passes
// wins. Its content is then cached and delay becomes prefetchDelay.
//
// If every URL fails, the last error is returned with delay set to
// MinimumPrefetchInterval. source.in then keeps whatever fetchFromCache
// loaded, if anything.
func (source *Source) fetchWithCache(xTransport *XTransport, now time.Time) (delay time.Duration, err error) {
	if delay, err = source.fetchFromCache(now); err != nil {
		if len(source.urls) == 0 {
			dlog.Errorf("Source [%s] cache file [%s] not present and no valid URL", source.name, source.cacheFile)
			return
		}
		dlog.Debugf("Source [%s] cache file [%s] not present", source.name, source.cacheFile)
	}
	if len(source.urls) > 0 {
		// Schedule the next refresh from the final value of delay.
		defer func() {
			source.refresh = now.Add(delay)
		}()
	}
	// Nothing to download without URLs; a positive delay means the cache is still fresh.
	if len(source.urls) == 0 || delay > 0 {
		return
	}
	// Retry delay reported if every download attempt below fails.
	delay = MinimumPrefetchInterval
	var bin, sig []byte
	for _, srcURL := range source.urls {
		dlog.Infof("Source [%s] loading from URL [%s]", source.name, srcURL)
		sigURL := &url.URL{}
		*sigURL = *srcURL // shallow struct copy (avoids parsing twice); only Path is changed below
		sigURL.Path += ".minisig"
		if bin, err = fetchFromURL(xTransport, srcURL); err != nil {
			dlog.Debugf("Source [%s] failed to download from URL [%s]", source.name, srcURL)
			continue
		}
		if sig, err = fetchFromURL(xTransport, sigURL); err != nil {
			dlog.Debugf("Source [%s] failed to download signature from URL [%s]", source.name, sigURL)
			continue
		}
		if err = source.checkSignature(bin, sig); err == nil {
			break // valid signature
		} // on failure, fall through to the log line below and then the next URL
		dlog.Debugf("Source [%s] failed signature check using URL [%s]", source.name, srcURL)
	}
	// err is nil only if some URL passed the signature check.
	if err != nil {
		return
	}
	source.writeToCache(bin, sig, now)
	delay = source.prefetchDelay
	return
}
181
182// NewSource loads a new source using the given cacheFile and urls, ensuring it has a valid signature
183func NewSource(name string, xTransport *XTransport, urls []string, minisignKeyStr string, cacheFile string, formatStr string, refreshDelay time.Duration, prefix string) (source *Source, err error) {
184	if refreshDelay < DefaultPrefetchDelay {
185		refreshDelay = DefaultPrefetchDelay
186	}
187	source = &Source{name: name, urls: []*url.URL{}, cacheFile: cacheFile, cacheTTL: refreshDelay, prefetchDelay: DefaultPrefetchDelay, prefix: prefix}
188	if formatStr == "v2" {
189		source.format = SourceFormatV2
190	} else {
191		return source, fmt.Errorf("Unsupported source format: [%s]", formatStr)
192	}
193	if minisignKey, err := minisign.NewPublicKey(minisignKeyStr); err == nil {
194		source.minisignKey = &minisignKey
195	} else {
196		return source, err
197	}
198	source.parseURLs(urls)
199	if _, err = source.fetchWithCache(xTransport, timeNow()); err == nil {
200		dlog.Noticef("Source [%s] loaded", name)
201	}
202	return
203}
204
205// PrefetchSources downloads latest versions of given sources, ensuring they have a valid signature before caching
206func PrefetchSources(xTransport *XTransport, sources []*Source) time.Duration {
207	now := timeNow()
208	interval := MinimumPrefetchInterval
209	for _, source := range sources {
210		if source.refresh.IsZero() || source.refresh.After(now) {
211			continue
212		}
213		dlog.Debugf("Prefetching [%s]", source.name)
214		if delay, err := source.fetchWithCache(xTransport, now); err != nil {
215			dlog.Infof("Prefetching [%s] failed: %v, will retry in %v", source.name, err, interval)
216		} else {
217			dlog.Debugf("Prefetching [%s] succeeded, next update: %v", source.name, delay)
218			if delay >= MinimumPrefetchInterval && (interval == MinimumPrefetchInterval || interval > delay) {
219				interval = delay
220			}
221		}
222	}
223	return interval
224}
225
226func (source *Source) Parse() ([]RegisteredServer, error) {
227	if source.format == SourceFormatV2 {
228		return source.parseV2()
229	}
230	dlog.Fatal("Unexpected source format")
231	return []RegisteredServer{}, nil
232}
233
234func (source *Source) parseV2() ([]RegisteredServer, error) {
235	var registeredServers []RegisteredServer
236	var stampErrs []string
237	appendStampErr := func(format string, a ...interface{}) {
238		stampErr := fmt.Sprintf(format, a...)
239		stampErrs = append(stampErrs, stampErr)
240		dlog.Warn(stampErr)
241	}
242	in := string(source.in)
243	parts := strings.Split(in, "## ")
244	if len(parts) < 2 {
245		return registeredServers, fmt.Errorf("Invalid format for source at [%v]", source.urls)
246	}
247	parts = parts[1:]
248	for _, part := range parts {
249		part = strings.TrimSpace(part)
250		subparts := strings.Split(part, "\n")
251		if len(subparts) < 2 {
252			return registeredServers, fmt.Errorf("Invalid format for source at [%v]", source.urls)
253		}
254		name := strings.TrimSpace(subparts[0])
255		if len(name) == 0 {
256			return registeredServers, fmt.Errorf("Invalid format for source at [%v]", source.urls)
257		}
258		subparts = subparts[1:]
259		name = source.prefix + name
260		var stampStr, description string
261		stampStrs := make([]string, 0)
262		for _, subpart := range subparts {
263			subpart = strings.TrimSpace(subpart)
264			if strings.HasPrefix(subpart, "sdns:") && len(subpart) >= 6 {
265				stampStrs = append(stampStrs, subpart)
266				continue
267			} else if len(subpart) == 0 || strings.HasPrefix(subpart, "//") {
268				continue
269			}
270			if len(description) > 0 {
271				description += "\n"
272			}
273			description += subpart
274		}
275		stampStrsLen := len(stampStrs)
276		if stampStrsLen <= 0 {
277			appendStampErr("Missing stamp for server [%s]", name)
278			continue
279		} else if stampStrsLen > 1 {
280			rand.Shuffle(stampStrsLen, func(i, j int) { stampStrs[i], stampStrs[j] = stampStrs[j], stampStrs[i] })
281		}
282		var stamp dnsstamps.ServerStamp
283		var err error
284		for _, stampStr = range stampStrs {
285			stamp, err = dnsstamps.NewServerStampFromString(stampStr)
286			if err == nil {
287				break
288			}
289			appendStampErr("Invalid or unsupported stamp [%v]: %s", stampStr, err.Error())
290		}
291		if err != nil {
292			continue
293		}
294		registeredServer := RegisteredServer{
295			name: name, stamp: stamp, description: description,
296		}
297		dlog.Debugf("Registered [%s] with stamp [%s]", name, stamp.String())
298		registeredServers = append(registeredServers, registeredServer)
299	}
300	if len(stampErrs) > 0 {
301		return registeredServers, fmt.Errorf("%s", strings.Join(stampErrs, ", "))
302	}
303	return registeredServers, nil
304}
305