1// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file).
2
3package main
4
5import (
6	"compress/gzip"
7	"context"
8	"crypto/tls"
9	"crypto/x509"
10	"encoding/json"
11	"flag"
12	"fmt"
13	"io"
14	"io/ioutil"
15	"log"
16	"net"
17	"net/http"
18	"net/url"
19	"os"
20	"path/filepath"
21	"strconv"
22	"strings"
23	"time"
24
25	"github.com/syncthing/syncthing/lib/protocol"
26
27	"github.com/golang/groupcache/lru"
28	"github.com/oschwald/geoip2-golang"
29	"github.com/prometheus/client_golang/prometheus"
30	"github.com/prometheus/client_golang/prometheus/promhttp"
31	"github.com/syncthing/syncthing/cmd/strelaypoolsrv/auto"
32	"github.com/syncthing/syncthing/lib/assets"
33	"github.com/syncthing/syncthing/lib/rand"
34	"github.com/syncthing/syncthing/lib/relay/client"
35	"github.com/syncthing/syncthing/lib/sync"
36	"github.com/syncthing/syncthing/lib/tlsutil"
37	"golang.org/x/time/rate"
38)
39
// location is the GeoIP position of a relay as produced by getLocation. All
// fields are left at their zero values when the lookup fails. Country holds
// the ISO country code and Continent the continent code from the GeoLite2
// record.
type location struct {
	Latitude  float64 `json:"latitude"`
	Longitude float64 `json:"longitude"`
	City      string  `json:"city"`
	Country   string  `json:"country"`
	Continent string  `json:"continent"`
}
47
// relay is a relay server known to the pool; lists of them are sent to
// clients by handleGetRequest.
type relay struct {
	// URL is the relay URL as advertised (possibly rewritten with the
	// client's IP by handlePostRequest).
	URL            string   `json:"url"`
	// Location is filled in from GeoIP after loading or a successful test.
	Location       location `json:"location"`
	// uri is the parsed URL. It is unexported, so it is not serialized. Its
	// Host identifies the relay in knownRelays and evictionTimers.
	uri            *url.URL
	// Stats is the last result of fetchStats; it may be nil.
	Stats          *stats    `json:"stats"`
	// StatsRetrieved is when Stats was set, truncated to the second.
	StatsRetrieved time.Time `json:"statsRetrieved"`
}
55
// stats is a relay's self-reported status as returned by fetchStats
// (presumably decoded from the relay's JSON status endpoint — confirm in
// fetchStats). The JSON keys below must match what the relay emits.
//
// NOTE(review): the int counters are 32 bits wide on 32-bit platforms, where
// large values (e.g. BytesProxied) would fail to decode.
type stats struct {
	StartTime          time.Time `json:"startTime"`
	UptimeSeconds      int       `json:"uptimeSeconds"`
	PendingSessionKeys int       `json:"numPendingSessionKeys"`
	ActiveSessions     int       `json:"numActiveSessions"`
	Connections        int       `json:"numConnections"`
	Proxies            int       `json:"numProxies"`
	BytesProxied       int       `json:"bytesProxied"`
	GoVersion          string    `json:"goVersion"`
	GoOS               string    `json:"goOS"`
	GoArch             string    `json:"goArch"`
	GoMaxProcs         int       `json:"goMaxProcs"`
	GoRoutines         int       `json:"goNumRoutine"`
	// Presumably transfer rates in kbps over 10s, 1m, 5m, 15m, 30m and 60m
	// windows, judging by the JSON key — confirm against the relay.
	Rates              []int64   `json:"kbps10s1m5m15m30m60m"`
	// Options mirrors the relay's command-line configuration.
	Options            struct {
		NetworkTimeout int      `json:"network-timeout"`
		// NOTE(review): field name looks like a typo for PingInterval; the
		// JSON key itself is correct.
		PintInterval   int      `json:"ping-interval"`
		MessageTimeout int      `json:"message-timeout"`
		SessionRate    int      `json:"per-session-rate"`
		GlobalRate     int      `json:"global-rate"`
		Pools          []string `json:"pools"`
		ProvidedBy     string   `json:"provided-by"`
	} `json:"options"`
}
80
81func (r relay) String() string {
82	return r.URL
83}
84
// request asks a requestProcessor to test relay.
type request struct {
	relay      *relay
	// result receives the outcome. Every sender in this file creates it
	// unbuffered and waits on it.
	result     chan result
	// queueTimer, when non-nil, measures time spent waiting in the queue and
	// is observed when a processor picks the request up.
	queueTimer *prometheus.Timer
}
90
// result is the outcome of a relay test.
type result struct {
	// err is nil when the relay passed the test.
	err      error
	// eviction is evictionTime on success (how long until the relay is evicted
	// unless it is tested again) and 0 on failure.
	eviction time.Duration
}
95
// Package-level configuration and shared state.
//
// Most entries in the first group are defaults for the command-line flags
// registered in main. getLimit and postLimit are derived in main from the
// *LimitAvg flags; each is the interval handed to rate.Every in limit. testCert
// is generated in main by createTestCertificate and passed to client.TestRelay.
// knownRelaysFile holds one URL per line for relays that passed a test; it is
// reloaded and re-tested at startup.
var (
	testCert          tls.Certificate
	knownRelaysFile   = filepath.Join(os.TempDir(), "strelaypoolsrv_known_relays")
	listen            = ":80"
	dir               string
	evictionTime      = time.Hour
	debug             bool
	getLRUSize        = 10 << 10
	getLimitBurst     = 10
	getLimitAvg       = 2
	postLRUSize       = 1 << 10
	postLimitBurst    = 2
	postLimitAvg      = 2
	getLimit          time.Duration
	postLimit         time.Duration
	permRelaysFile    string
	ipHeader          string
	geoipPath         string
	proto             string
	statsRefresh      = time.Minute / 2
	requestQueueLen   = 10
	requestProcessors = 1

	// Per-client GET rate limiters (see limit); getMut guards getLRUCache.
	getMut      = sync.NewMutex()
	getLRUCache *lru.Cache

	// Per-client POST rate limiters (see limit); postMut guards postLRUCache.
	postMut      = sync.NewMutex()
	postLRUCache *lru.Cache

	// requests queues relay test requests. main creates it with capacity
	// requestQueueLen and the requestProcessor goroutines drain it.
	requests chan request

	// mut guards knownRelays and evictionTimers; handleGetRequest also holds
	// it while copying permanentRelays. permanentRelays is loaded from
	// permRelaysFile in main before serving starts. evictionTimers is keyed by
	// relay uri.Host.
	mut             = sync.NewRWMutex()
	knownRelays     = make([]*relay, 0)
	permanentRelays = make([]*relay, 0)
	evictionTimers  = make(map[string]*time.Timer)
)
132
// httpStatusEnhanceYourCalm is the status sent to clients that exceed their
// request rate limit, or whose relay test cannot be queued. Its value is 429,
// which net/http names StatusTooManyRequests.
const httpStatusEnhanceYourCalm = http.StatusTooManyRequests
136
137func main() {
138	log.SetOutput(os.Stdout)
139	log.SetFlags(log.Lshortfile)
140
141	flag.StringVar(&listen, "listen", listen, "Listen address")
142	flag.StringVar(&dir, "keys", dir, "Directory where http-cert.pem and http-key.pem is stored for TLS listening")
143	flag.BoolVar(&debug, "debug", debug, "Enable debug output")
144	flag.DurationVar(&evictionTime, "eviction", evictionTime, "After how long the relay is evicted")
145	flag.IntVar(&getLRUSize, "get-limit-cache", getLRUSize, "Get request limiter cache size")
146	flag.IntVar(&getLimitAvg, "get-limit-avg", getLimitAvg, "Allowed average get request rate, per 10 s")
147	flag.IntVar(&getLimitBurst, "get-limit-burst", getLimitBurst, "Allowed burst get requests")
148	flag.IntVar(&postLRUSize, "post-limit-cache", postLRUSize, "Post request limiter cache size")
149	flag.IntVar(&postLimitAvg, "post-limit-avg", postLimitAvg, "Allowed average post request rate, per minute")
150	flag.IntVar(&postLimitBurst, "post-limit-burst", postLimitBurst, "Allowed burst post requests")
151	flag.StringVar(&permRelaysFile, "perm-relays", "", "Path to list of permanent relays")
152	flag.StringVar(&ipHeader, "ip-header", "", "Name of header which holds clients ip:port. Only meaningful when running behind a reverse proxy.")
153	flag.StringVar(&geoipPath, "geoip", "GeoLite2-City.mmdb", "Path to GeoLite2-City database")
154	flag.StringVar(&proto, "protocol", "tcp", "Protocol used for listening. 'tcp' for IPv4 and IPv6, 'tcp4' for IPv4, 'tcp6' for IPv6")
155	flag.DurationVar(&statsRefresh, "stats-refresh", statsRefresh, "Interval at which to refresh relay stats")
156	flag.IntVar(&requestQueueLen, "request-queue", requestQueueLen, "Queue length for incoming test requests")
157	flag.IntVar(&requestProcessors, "request-processors", requestProcessors, "Number of request processor routines")
158
159	flag.Parse()
160
161	requests = make(chan request, requestQueueLen)
162
163	getLimit = 10 * time.Second / time.Duration(getLimitAvg)
164	postLimit = time.Minute / time.Duration(postLimitAvg)
165
166	getLRUCache = lru.New(getLRUSize)
167	postLRUCache = lru.New(postLRUSize)
168
169	var listener net.Listener
170	var err error
171
172	if permRelaysFile != "" {
173		permanentRelays = loadRelays(permRelaysFile)
174	}
175
176	testCert = createTestCertificate()
177
178	for i := 0; i < requestProcessors; i++ {
179		go requestProcessor()
180	}
181
182	// Load relays from cache in the background.
183	// Load them in a serial fashion to make sure any genuine requests
184	// are not dropped.
185	go func() {
186		for _, relay := range loadRelays(knownRelaysFile) {
187			resultChan := make(chan result)
188			requests <- request{relay, resultChan, nil}
189			result := <-resultChan
190			if result.err != nil {
191				relayTestsTotal.WithLabelValues("failed").Inc()
192			} else {
193				relayTestsTotal.WithLabelValues("success").Inc()
194			}
195		}
196		// Run the the stats refresher once the relays are loaded.
197		statsRefresher(statsRefresh)
198	}()
199
200	if dir != "" {
201		if debug {
202			log.Println("Starting TLS listener on", listen)
203		}
204		certFile, keyFile := filepath.Join(dir, "http-cert.pem"), filepath.Join(dir, "http-key.pem")
205		var cert tls.Certificate
206		cert, err = tls.LoadX509KeyPair(certFile, keyFile)
207		if err != nil {
208			log.Fatalln("Failed to load HTTP X509 key pair:", err)
209		}
210
211		tlsCfg := &tls.Config{
212			Certificates: []tls.Certificate{cert},
213			MinVersion:   tls.VersionTLS10, // No SSLv3
214			ClientAuth:   tls.RequestClientCert,
215			CipherSuites: []uint16{
216				// No RC4
217				tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
218				tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
219				tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
220				tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
221				tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
222				tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
223				tls.TLS_RSA_WITH_AES_128_CBC_SHA,
224				tls.TLS_RSA_WITH_AES_256_CBC_SHA,
225				tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
226				tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
227			},
228		}
229
230		listener, err = tls.Listen(proto, listen, tlsCfg)
231	} else {
232		if debug {
233			log.Println("Starting plain listener on", listen)
234		}
235		listener, err = net.Listen(proto, listen)
236	}
237
238	if err != nil {
239		log.Fatalln("listen:", err)
240	}
241
242	handler := http.NewServeMux()
243	handler.HandleFunc("/", handleAssets)
244	handler.HandleFunc("/endpoint", handleRequest)
245	handler.HandleFunc("/metrics", handleMetrics)
246
247	srv := http.Server{
248		Handler:     handler,
249		ReadTimeout: 10 * time.Second,
250	}
251
252	err = srv.Serve(listener)
253	if err != nil {
254		log.Fatalln("serve:", err)
255	}
256}
257
258func handleMetrics(w http.ResponseWriter, r *http.Request) {
259	timer := prometheus.NewTimer(metricsRequestsSeconds)
260	// Acquire the mutex just to make sure we're not caught mid-way stats collection
261	mut.RLock()
262	promhttp.Handler().ServeHTTP(w, r)
263	mut.RUnlock()
264	timer.ObserveDuration()
265}
266
267func handleAssets(w http.ResponseWriter, r *http.Request) {
268	w.Header().Set("Cache-Control", "no-cache, must-revalidate")
269
270	path := r.URL.Path[1:]
271	if path == "" {
272		path = "index.html"
273	}
274
275	as, ok := auto.Assets()[path]
276	if !ok {
277		w.WriteHeader(http.StatusNotFound)
278		return
279	}
280
281	assets.Serve(w, r, as)
282}
283
284func handleRequest(w http.ResponseWriter, r *http.Request) {
285	timer := prometheus.NewTimer(apiRequestsSeconds.WithLabelValues(r.Method))
286
287	w = NewLoggingResponseWriter(w)
288	defer func() {
289		timer.ObserveDuration()
290		lw := w.(*loggingResponseWriter)
291		apiRequestsTotal.WithLabelValues(r.Method, strconv.Itoa(lw.statusCode)).Inc()
292	}()
293
294	if ipHeader != "" {
295		r.RemoteAddr = r.Header.Get(ipHeader)
296	}
297	w.Header().Set("Access-Control-Allow-Origin", "*")
298	switch r.Method {
299	case "GET":
300		if limit(r.RemoteAddr, getLRUCache, getMut, getLimit, getLimitBurst) {
301			w.WriteHeader(httpStatusEnhanceYourCalm)
302			return
303		}
304		handleGetRequest(w, r)
305	case "POST":
306		if limit(r.RemoteAddr, postLRUCache, postMut, postLimit, postLimitBurst) {
307			w.WriteHeader(httpStatusEnhanceYourCalm)
308			return
309		}
310		handlePostRequest(w, r)
311	default:
312		if debug {
313			log.Println("Unhandled HTTP method", r.Method)
314		}
315		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
316	}
317}
318
319func handleGetRequest(rw http.ResponseWriter, r *http.Request) {
320	rw.Header().Set("Content-Type", "application/json; charset=utf-8")
321
322	mut.RLock()
323	relays := make([]*relay, len(permanentRelays)+len(knownRelays))
324	n := copy(relays, permanentRelays)
325	copy(relays[n:], knownRelays)
326	mut.RUnlock()
327
328	// Shuffle
329	rand.Shuffle(relays)
330
331	w := io.Writer(rw)
332	if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
333		rw.Header().Set("Content-Encoding", "gzip")
334		gw := gzip.NewWriter(rw)
335		defer gw.Close()
336		w = gw
337	}
338
339	_ = json.NewEncoder(w).Encode(map[string][]*relay{
340		"relays": relays,
341	})
342}
343
344func handlePostRequest(w http.ResponseWriter, r *http.Request) {
345	var relayCert *x509.Certificate
346	if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
347		relayCert = r.TLS.PeerCertificates[0]
348		log.Printf("Got TLS cert from relay server")
349	}
350
351	var newRelay relay
352	err := json.NewDecoder(r.Body).Decode(&newRelay)
353	r.Body.Close()
354
355	if err != nil {
356		if debug {
357			log.Println("Failed to parse payload")
358		}
359		http.Error(w, err.Error(), http.StatusBadRequest)
360		return
361	}
362
363	uri, err := url.Parse(newRelay.URL)
364	if err != nil {
365		if debug {
366			log.Println("Failed to parse URI", newRelay.URL)
367		}
368		http.Error(w, err.Error(), http.StatusBadRequest)
369		return
370	}
371
372	if relayCert != nil {
373		advertisedId := uri.Query().Get("id")
374		idFromCert := protocol.NewDeviceID(relayCert.Raw).String()
375		if advertisedId != idFromCert {
376			log.Println("Warning: Relay server requested to join with an ID different from the join request, rejecting")
377			http.Error(w, "mismatched advertised id and join request cert", http.StatusBadRequest)
378			return
379		}
380	}
381
382	host, port, err := net.SplitHostPort(uri.Host)
383	if err != nil {
384		if debug {
385			log.Println("Failed to split URI", newRelay.URL)
386		}
387		http.Error(w, err.Error(), http.StatusBadRequest)
388		return
389	}
390
391	// Get the IP address of the client
392	rhost := r.RemoteAddr
393	if host, _, err := net.SplitHostPort(rhost); err == nil {
394		rhost = host
395	}
396
397	ip := net.ParseIP(host)
398	// The client did not provide an IP address, use the IP address of the client.
399	if ip == nil || ip.IsUnspecified() {
400		uri.Host = net.JoinHostPort(rhost, port)
401		newRelay.URL = uri.String()
402	} else if host != rhost && relayCert == nil {
403		if debug {
404			log.Println("IP address advertised does not match client IP address", r.RemoteAddr, uri)
405		}
406		http.Error(w, fmt.Sprintf("IP advertised %s does not match client IP %s", host, rhost), http.StatusUnauthorized)
407		return
408	}
409
410	newRelay.uri = uri
411
412	for _, current := range permanentRelays {
413		if current.uri.Host == newRelay.uri.Host {
414			if debug {
415				log.Println("Asked to add a relay", newRelay, "which exists in permanent list")
416			}
417			http.Error(w, "Invalid request", http.StatusBadRequest)
418			return
419		}
420	}
421
422	reschan := make(chan result)
423
424	select {
425	case requests <- request{&newRelay, reschan, prometheus.NewTimer(relayTestActionsSeconds.WithLabelValues("queue"))}:
426		result := <-reschan
427		if result.err != nil {
428			relayTestsTotal.WithLabelValues("failed").Inc()
429			http.Error(w, result.err.Error(), http.StatusBadRequest)
430			return
431		}
432		relayTestsTotal.WithLabelValues("success").Inc()
433		w.Header().Set("Content-Type", "application/json; charset=utf-8")
434		json.NewEncoder(w).Encode(map[string]time.Duration{
435			"evictionIn": result.eviction,
436		})
437
438	default:
439		relayTestsTotal.WithLabelValues("dropped").Inc()
440		if debug {
441			log.Println("Dropping request")
442		}
443		w.WriteHeader(httpStatusEnhanceYourCalm)
444	}
445}
446
447func requestProcessor() {
448	for request := range requests {
449		if request.queueTimer != nil {
450			request.queueTimer.ObserveDuration()
451		}
452
453		timer := prometheus.NewTimer(relayTestActionsSeconds.WithLabelValues("test"))
454		handleRelayTest(request)
455		timer.ObserveDuration()
456	}
457}
458
459func handleRelayTest(request request) {
460	if debug {
461		log.Println("Request for", request.relay)
462	}
463	if err := client.TestRelay(context.TODO(), request.relay.uri, []tls.Certificate{testCert}, time.Second, 2*time.Second, 3); err != nil {
464		if debug {
465			log.Println("Test for relay", request.relay, "failed:", err)
466		}
467		request.result <- result{err, 0}
468		return
469	}
470
471	stats := fetchStats(request.relay)
472	location := getLocation(request.relay.uri.Host)
473
474	mut.Lock()
475	if stats != nil {
476		updateMetrics(request.relay.uri.Host, *stats, location)
477	}
478	request.relay.Stats = stats
479	request.relay.StatsRetrieved = time.Now().Truncate(time.Second)
480	request.relay.Location = location
481
482	timer, ok := evictionTimers[request.relay.uri.Host]
483	if ok {
484		if debug {
485			log.Println("Stopping existing timer for", request.relay)
486		}
487		timer.Stop()
488	}
489
490	for i, current := range knownRelays {
491		if current.uri.Host == request.relay.uri.Host {
492			if debug {
493				log.Println("Relay", request.relay, "already exists")
494			}
495
496			// Evict the old entry anyway, as configuration might have changed.
497			last := len(knownRelays) - 1
498			knownRelays[i] = knownRelays[last]
499			knownRelays = knownRelays[:last]
500
501			goto found
502		}
503	}
504
505	if debug {
506		log.Println("Adding new relay", request.relay)
507	}
508
509found:
510
511	knownRelays = append(knownRelays, request.relay)
512	evictionTimers[request.relay.uri.Host] = time.AfterFunc(evictionTime, evict(request.relay))
513
514	mut.Unlock()
515
516	if err := saveRelays(knownRelaysFile, knownRelays); err != nil {
517		log.Println("Failed to write known relays: " + err.Error())
518	}
519
520	request.result <- result{nil, evictionTime}
521}
522
523func evict(relay *relay) func() {
524	return func() {
525		mut.Lock()
526		defer mut.Unlock()
527		if debug {
528			log.Println("Evicting", relay)
529		}
530		for i, current := range knownRelays {
531			if current.uri.Host == relay.uri.Host {
532				if debug {
533					log.Println("Evicted", relay)
534				}
535				last := len(knownRelays) - 1
536				knownRelays[i] = knownRelays[last]
537				knownRelays = knownRelays[:last]
538				deleteMetrics(current.uri.Host)
539			}
540		}
541		delete(evictionTimers, relay.uri.Host)
542	}
543}
544
545func limit(addr string, cache *lru.Cache, lock sync.Mutex, intv time.Duration, burst int) bool {
546	if host, _, err := net.SplitHostPort(addr); err == nil {
547		addr = host
548	}
549
550	lock.Lock()
551	v, _ := cache.Get(addr)
552	bkt, ok := v.(*rate.Limiter)
553	if !ok {
554		bkt = rate.NewLimiter(rate.Every(intv), burst)
555		cache.Add(addr, bkt)
556	}
557	lock.Unlock()
558
559	return !bkt.Allow()
560}
561
562func loadRelays(file string) []*relay {
563	content, err := ioutil.ReadFile(file)
564	if err != nil {
565		log.Println("Failed to load relays: " + err.Error())
566		return nil
567	}
568
569	var relays []*relay
570	for _, line := range strings.Split(string(content), "\n") {
571		if len(line) == 0 {
572			continue
573		}
574
575		uri, err := url.Parse(line)
576		if err != nil {
577			if debug {
578				log.Println("Skipping relay", line, "due to parse error", err)
579			}
580			continue
581
582		}
583
584		relays = append(relays, &relay{
585			URL:      line,
586			Location: getLocation(uri.Host),
587			uri:      uri,
588		})
589		if debug {
590			log.Println("Adding relay", line)
591		}
592	}
593	return relays
594}
595
596func saveRelays(file string, relays []*relay) error {
597	var content string
598	for _, relay := range relays {
599		content += relay.uri.String() + "\n"
600	}
601	return ioutil.WriteFile(file, []byte(content), 0777)
602}
603
604func createTestCertificate() tls.Certificate {
605	tmpDir, err := ioutil.TempDir("", "relaypoolsrv")
606	if err != nil {
607		log.Fatal(err)
608	}
609
610	certFile, keyFile := filepath.Join(tmpDir, "cert.pem"), filepath.Join(tmpDir, "key.pem")
611	cert, err := tlsutil.NewCertificate(certFile, keyFile, "relaypoolsrv", 20*365)
612	if err != nil {
613		log.Fatalln("Failed to create test X509 key pair:", err)
614	}
615
616	return cert
617}
618
619func getLocation(host string) location {
620	timer := prometheus.NewTimer(locationLookupSeconds)
621	defer timer.ObserveDuration()
622	db, err := geoip2.Open(geoipPath)
623	if err != nil {
624		return location{}
625	}
626	defer db.Close()
627
628	addr, err := net.ResolveTCPAddr("tcp", host)
629	if err != nil {
630		return location{}
631	}
632
633	city, err := db.City(addr.IP)
634	if err != nil {
635		return location{}
636	}
637
638	return location{
639		Longitude: city.Location.Longitude,
640		Latitude:  city.Location.Latitude,
641		City:      city.City.Names["en"],
642		Country:   city.Country.IsoCode,
643		Continent: city.Continent.Code,
644	}
645}
646
// loggingResponseWriter wraps an http.ResponseWriter and remembers the status
// code passed to WriteHeader; handleRequest uses it to label its metrics.
type loggingResponseWriter struct {
	http.ResponseWriter
	statusCode int
}

// NewLoggingResponseWriter wraps w. The recorded status starts at 200, the
// status net/http sends when a handler writes without calling WriteHeader.
func NewLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter {
	wrapped := &loggingResponseWriter{ResponseWriter: w}
	wrapped.statusCode = http.StatusOK
	return wrapped
}

// WriteHeader records code and then forwards it to the wrapped writer.
func (lrw *loggingResponseWriter) WriteHeader(code int) {
	lrw.statusCode = code
	lrw.ResponseWriter.WriteHeader(code)
}
660