1package main 2 3import ( 4 "bytes" 5 "fmt" 6 "io/ioutil" 7 "math/rand" 8 "net/url" 9 "os" 10 "path/filepath" 11 "strings" 12 "time" 13 14 "github.com/dchest/safefile" 15 16 "github.com/jedisct1/dlog" 17 "github.com/jedisct1/go-dnsstamps" 18 "github.com/jedisct1/go-minisign" 19) 20 21type SourceFormat int 22 23const ( 24 SourceFormatV2 = iota 25) 26 27const ( 28 DefaultPrefetchDelay time.Duration = 24 * time.Hour 29 MinimumPrefetchInterval time.Duration = 10 * time.Minute 30) 31 32type Source struct { 33 name string 34 urls []*url.URL 35 format SourceFormat 36 in []byte 37 minisignKey *minisign.PublicKey 38 cacheFile string 39 cacheTTL, prefetchDelay time.Duration 40 refresh time.Time 41 prefix string 42} 43 44func (source *Source) checkSignature(bin, sig []byte) (err error) { 45 var signature minisign.Signature 46 if signature, err = minisign.DecodeSignature(string(sig)); err == nil { 47 _, err = source.minisignKey.Verify(bin, signature) 48 } 49 return err 50} 51 52// timeNow can be replaced by tests to provide a static value 53var timeNow = time.Now 54 55func (source *Source) fetchFromCache(now time.Time) (delay time.Duration, err error) { 56 var bin, sig []byte 57 if bin, err = ioutil.ReadFile(source.cacheFile); err != nil { 58 return 59 } 60 if sig, err = ioutil.ReadFile(source.cacheFile + ".minisig"); err != nil { 61 return 62 } 63 if err = source.checkSignature(bin, sig); err != nil { 64 return 65 } 66 source.in = bin 67 var fi os.FileInfo 68 if fi, err = os.Stat(source.cacheFile); err != nil { 69 return 70 } 71 if elapsed := now.Sub(fi.ModTime()); elapsed < source.cacheTTL { 72 delay = source.prefetchDelay - elapsed 73 dlog.Debugf("Source [%s] cache file [%s] is still fresh, next update: %v", source.name, source.cacheFile, delay) 74 } else { 75 dlog.Debugf("Source [%s] cache file [%s] needs to be refreshed", source.name, source.cacheFile) 76 } 77 return 78} 79 80func writeSource(f string, bin, sig []byte) (err error) { 81 var fSrc, fSig *safefile.File 82 if fSrc, err = safefile.Create(f, 0644); err != nil { 83 return 84 } 85 defer fSrc.Close() 86 if fSig, err = safefile.Create(f+".minisig", 0644); err != nil { 87 return 88 } 89 defer fSig.Close() 90 if _, err = fSrc.Write(bin); err != nil { 91 return 92 } 93 if _, err = fSig.Write(sig); err != nil { 94 return 95 } 96 if err = fSrc.Commit(); err != nil { 97 return 98 } 99 return fSig.Commit() 100} 101 102func (source *Source) writeToCache(bin, sig []byte, now time.Time) { 103 f := source.cacheFile 104 var writeErr error // an error writing cache isn't fatal 105 defer func() { 106 source.in = bin 107 if writeErr == nil { 108 return 109 } 110 if absPath, absErr := filepath.Abs(f); absErr == nil { 111 f = absPath 112 } 113 dlog.Warnf("%s: %s", f, writeErr) 114 }() 115 if !bytes.Equal(source.in, bin) { 116 if writeErr = writeSource(f, bin, sig); writeErr != nil { 117 return 118 } 119 } 120 writeErr = os.Chtimes(f, now, now) 121} 122 123func (source *Source) parseURLs(urls []string) { 124 for _, urlStr := range urls { 125 if srcURL, err := url.Parse(urlStr); err != nil { 126 dlog.Warnf("Source [%s] failed to parse URL [%s]", source.name, urlStr) 127 } else { 128 source.urls = append(source.urls, srcURL) 129 } 130 } 131} 132 133func fetchFromURL(xTransport *XTransport, u *url.URL) (bin []byte, err error) { 134 bin, _, _, _, err = xTransport.Get(u, "", DefaultTimeout) 135 return bin, err 136} 137 138func (source *Source) fetchWithCache(xTransport *XTransport, now time.Time) (delay time.Duration, err error) { 139 if delay, err = source.fetchFromCache(now); err != nil { 140 if len(source.urls) == 0 { 141 dlog.Errorf("Source [%s] cache file [%s] not present and no valid URL", source.name, source.cacheFile) 142 return 143 } 144 dlog.Debugf("Source [%s] cache file [%s] not present", source.name, source.cacheFile) 145 } 146 if len(source.urls) > 0 { 147 defer func() { 148 source.refresh = now.Add(delay) 149 }() 150 } 151 if len(source.urls) == 0 || delay > 0 { 152 return 153 } 154 delay = MinimumPrefetchInterval 155 var bin, sig []byte 156 for _, srcURL := range source.urls { 157 dlog.Infof("Source [%s] loading from URL [%s]", source.name, srcURL) 158 sigURL := &url.URL{} 159 *sigURL = *srcURL // deep copy to avoid parsing twice 160 sigURL.Path += ".minisig" 161 if bin, err = fetchFromURL(xTransport, srcURL); err != nil { 162 dlog.Debugf("Source [%s] failed to download from URL [%s]", source.name, srcURL) 163 continue 164 } 165 if sig, err = fetchFromURL(xTransport, sigURL); err != nil { 166 dlog.Debugf("Source [%s] failed to download signature from URL [%s]", source.name, sigURL) 167 continue 168 } 169 if err = source.checkSignature(bin, sig); err == nil { 170 break // valid signature 171 } // above err check inverted to make use of implicit continue 172 dlog.Debugf("Source [%s] failed signature check using URL [%s]", source.name, srcURL) 173 } 174 if err != nil { 175 return 176 } 177 source.writeToCache(bin, sig, now) 178 delay = source.prefetchDelay 179 return 180} 181 182// NewSource loads a new source using the given cacheFile and urls, ensuring it has a valid signature 183func NewSource(name string, xTransport *XTransport, urls []string, minisignKeyStr string, cacheFile string, formatStr string, refreshDelay time.Duration, prefix string) (source *Source, err error) { 184 if refreshDelay < DefaultPrefetchDelay { 185 refreshDelay = DefaultPrefetchDelay 186 } 187 source = &Source{name: name, urls: []*url.URL{}, cacheFile: cacheFile, cacheTTL: refreshDelay, prefetchDelay: DefaultPrefetchDelay, prefix: prefix} 188 if formatStr == "v2" { 189 source.format = SourceFormatV2 190 } else { 191 return source, fmt.Errorf("Unsupported source format: [%s]", formatStr) 192 } 193 if minisignKey, err := minisign.NewPublicKey(minisignKeyStr); err == nil { 194 source.minisignKey = &minisignKey 195 } else { 196 return source, err 197 } 198 source.parseURLs(urls) 199 if _, err = source.fetchWithCache(xTransport, timeNow()); err == nil { 200 dlog.Noticef("Source [%s] loaded", name) 201 } 202 return 203} 204 205// PrefetchSources downloads latest versions of given sources, ensuring they have a valid signature before caching 206func PrefetchSources(xTransport *XTransport, sources []*Source) time.Duration { 207 now := timeNow() 208 interval := MinimumPrefetchInterval 209 for _, source := range sources { 210 if source.refresh.IsZero() || source.refresh.After(now) { 211 continue 212 } 213 dlog.Debugf("Prefetching [%s]", source.name) 214 if delay, err := source.fetchWithCache(xTransport, now); err != nil { 215 dlog.Infof("Prefetching [%s] failed: %v, will retry in %v", source.name, err, interval) 216 } else { 217 dlog.Debugf("Prefetching [%s] succeeded, next update: %v", source.name, delay) 218 if delay >= MinimumPrefetchInterval && (interval == MinimumPrefetchInterval || interval > delay) { 219 interval = delay 220 } 221 } 222 } 223 return interval 224} 225 226func (source *Source) Parse() ([]RegisteredServer, error) { 227 if source.format == SourceFormatV2 { 228 return source.parseV2() 229 } 230 dlog.Fatal("Unexpected source format") 231 return []RegisteredServer{}, nil 232} 233 234func (source *Source) parseV2() ([]RegisteredServer, error) { 235 var registeredServers []RegisteredServer 236 var stampErrs []string 237 appendStampErr := func(format string, a ...interface{}) { 238 stampErr := fmt.Sprintf(format, a...) 239 stampErrs = append(stampErrs, stampErr) 240 dlog.Warn(stampErr) 241 } 242 in := string(source.in) 243 parts := strings.Split(in, "## ") 244 if len(parts) < 2 { 245 return registeredServers, fmt.Errorf("Invalid format for source at [%v]", source.urls) 246 } 247 parts = parts[1:] 248 for _, part := range parts { 249 part = strings.TrimSpace(part) 250 subparts := strings.Split(part, "\n") 251 if len(subparts) < 2 { 252 return registeredServers, fmt.Errorf("Invalid format for source at [%v]", source.urls) 253 } 254 name := strings.TrimSpace(subparts[0]) 255 if len(name) == 0 { 256 return registeredServers, fmt.Errorf("Invalid format for source at [%v]", source.urls) 257 } 258 subparts = subparts[1:] 259 name = source.prefix + name 260 var stampStr, description string 261 stampStrs := make([]string, 0) 262 for _, subpart := range subparts { 263 subpart = strings.TrimSpace(subpart) 264 if strings.HasPrefix(subpart, "sdns:") && len(subpart) >= 6 { 265 stampStrs = append(stampStrs, subpart) 266 continue 267 } else if len(subpart) == 0 || strings.HasPrefix(subpart, "//") { 268 continue 269 } 270 if len(description) > 0 { 271 description += "\n" 272 } 273 description += subpart 274 } 275 stampStrsLen := len(stampStrs) 276 if stampStrsLen <= 0 { 277 appendStampErr("Missing stamp for server [%s]", name) 278 continue 279 } else if stampStrsLen > 1 { 280 rand.Shuffle(stampStrsLen, func(i, j int) { stampStrs[i], stampStrs[j] = stampStrs[j], stampStrs[i] }) 281 } 282 var stamp dnsstamps.ServerStamp 283 var err error 284 for _, stampStr = range stampStrs { 285 stamp, err = dnsstamps.NewServerStampFromString(stampStr) 286 if err == nil { 287 break 288 } 289 appendStampErr("Invalid or unsupported stamp [%v]: %s", stampStr, err.Error()) 290 } 291 if err != nil { 292 continue 293 } 294 registeredServer := RegisteredServer{ 295 name: name, stamp: stamp, description: description, 296 } 297 dlog.Debugf("Registered [%s] with stamp [%s]", name, stamp.String()) 298 registeredServers = append(registeredServers, registeredServer) 299 } 300 if len(stampErrs) > 0 { 301 return registeredServers, fmt.Errorf("%s", strings.Join(stampErrs, ", ")) 302 } 303 return registeredServers, nil 304} 305