// Snowflake-specific WebSocket server plugin. It reports the transport name as
// "snowflake".
3package main
4
import (
	"errors"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"os"
	"os/signal"
	"path/filepath"
	"strings"
	"sync"
	"syscall"
	"time"

	"git.torproject.org/pluggable-transports/snowflake.git/v2/common/safelog"
	"golang.org/x/crypto/acme/autocert"

	pt "git.torproject.org/pluggable-transports/goptlib.git"
	sf "git.torproject.org/pluggable-transports/snowflake.git/v2/server/lib"
)
26
// ptMethodName is the transport name this server reports to tor. main answers
// bindaddrs for any other method name with an SMETHOD-ERROR.
const ptMethodName = "snowflake"

// ptInfo is the server configuration returned by pt.ServerSetup in main;
// handleConn passes it to pt.DialOr when connecting a client to the ORPort.
var ptInfo pt.ServerInfo
30
// usage writes the command synopsis, a short description of how the transport
// is deployed, and the defaults of every registered flag to standard error.
// main installs it as flag.Usage.
func usage() {
	const synopsis = `Usage: %s [OPTIONS]

WebSocket server pluggable transport for Snowflake. Works only as a managed
proxy. Uses TLS with ACME (Let's Encrypt) by default. Set the certificate
hostnames with the --acme-hostnames option. Use ServerTransportListenAddr in
torrc to choose the listening port. When using TLS, this program will open an
additional HTTP listener on port 80 to work with ACME.

`
	program := os.Args[0]
	fmt.Fprintf(os.Stderr, synopsis, program)
	flag.PrintDefaults()
}
43
// proxy copies data bidirectionally between local (the ORPort connection, as
// passed by handleConn) and conn (the client connection), and returns only
// after both directions have finished.
//
// When the ORPort-to-client copy ends, the read half of local is shut down and
// conn is closed. When the client-to-ORPort copy ends, the write half of local
// is shut down so the ORPort sees EOF, and conn is closed.
//
// io.ErrClosedPipe is treated as a normal end of stream and is not logged.
// NOTE(review): presumably this is what conn returns after the other direction
// has closed it; confirm against the Snowflake transport. errors.Is is required
// rather than ==. When io.Copy writes into local it uses
// (*net.TCPConn).ReadFrom, which wraps the reader's error in a *net.OpError, so
// a plain comparison never matches in that direction.
func proxy(local *net.TCPConn, conn net.Conn) {
	var wg sync.WaitGroup
	wg.Add(2)

	go func() {
		defer wg.Done()
		if _, err := io.Copy(conn, local); err != nil && !errors.Is(err, io.ErrClosedPipe) {
			log.Printf("error copying ORPort to WebSocket %v", err)
		}
		local.CloseRead()
		conn.Close()
	}()
	go func() {
		defer wg.Done()
		if _, err := io.Copy(local, conn); err != nil && !errors.Is(err, io.ErrClosedPipe) {
			log.Printf("error copying WebSocket to ORPort %v", err)
		}
		local.CloseWrite()
		conn.Close()
	}()

	wg.Wait()
}
68
69//handleConn bidirectionally connects a client snowflake connection with an ORPort.
70func handleConn(conn net.Conn) error {
71	addr := conn.RemoteAddr().String()
72	statsChannel <- addr != ""
73	or, err := pt.DialOr(&ptInfo, addr, ptMethodName)
74	if err != nil {
75		return fmt.Errorf("failed to connect to ORPort: %s", err)
76	}
77	defer or.Close()
78	proxy(or, conn)
79	return nil
80}
81
82//acceptLoop accepts incoming client snowflake connection and passes them to a handler function.
83func acceptLoop(ln net.Listener) {
84	for {
85		conn, err := ln.Accept()
86		if err != nil {
87			if err, ok := err.(net.Error); ok && err.Temporary() {
88				continue
89			}
90			log.Printf("Snowflake accept error: %s", err)
91			break
92		}
93		go func() {
94			defer conn.Close()
95			err := handleConn(conn)
96			if err != nil {
97				log.Printf("handleConn: %v", err)
98			}
99		}()
100	}
101}
102
103func getCertificateCacheDir() (string, error) {
104	stateDir, err := pt.MakeStateDir()
105	if err != nil {
106		return "", err
107	}
108	return filepath.Join(stateDir, "snowflake-certificate-cache"), nil
109}
110
111func main() {
112	var acmeEmail string
113	var acmeHostnamesCommas string
114	var disableTLS bool
115	var logFilename string
116	var unsafeLogging bool
117
118	flag.Usage = usage
119	flag.StringVar(&acmeEmail, "acme-email", "", "optional contact email for Let's Encrypt notifications")
120	flag.StringVar(&acmeHostnamesCommas, "acme-hostnames", "", "comma-separated hostnames for TLS certificate")
121	flag.BoolVar(&disableTLS, "disable-tls", false, "don't use HTTPS")
122	flag.StringVar(&logFilename, "log", "", "log file to write to")
123	flag.BoolVar(&unsafeLogging, "unsafe-logging", false, "prevent logs from being scrubbed")
124	flag.Parse()
125
126	log.SetFlags(log.LstdFlags | log.LUTC)
127
128	var logOutput io.Writer = os.Stderr
129	if logFilename != "" {
130		f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
131		if err != nil {
132			log.Fatalf("can't open log file: %s", err)
133		}
134		defer f.Close()
135		logOutput = f
136	}
137	if unsafeLogging {
138		log.SetOutput(logOutput)
139	} else {
140		// We want to send the log output through our scrubber first
141		log.SetOutput(&safelog.LogScrubber{Output: logOutput})
142	}
143
144	if !disableTLS && acmeHostnamesCommas == "" {
145		log.Fatal("the --acme-hostnames option is required")
146	}
147	acmeHostnames := strings.Split(acmeHostnamesCommas, ",")
148
149	log.Printf("starting")
150	var err error
151	ptInfo, err = pt.ServerSetup(nil)
152	if err != nil {
153		log.Fatalf("error in setup: %s", err)
154	}
155
156	go statsThread()
157
158	var certManager *autocert.Manager
159	if !disableTLS {
160		log.Printf("ACME hostnames: %q", acmeHostnames)
161
162		var cache autocert.Cache
163		var cacheDir string
164		cacheDir, err = getCertificateCacheDir()
165		if err == nil {
166			log.Printf("caching ACME certificates in directory %q", cacheDir)
167			cache = autocert.DirCache(cacheDir)
168		} else {
169			log.Printf("disabling ACME certificate cache: %s", err)
170		}
171
172		certManager = &autocert.Manager{
173			Prompt:     autocert.AcceptTOS,
174			HostPolicy: autocert.HostWhitelist(acmeHostnames...),
175			Email:      acmeEmail,
176			Cache:      cache,
177		}
178	}
179
180	// The ACME HTTP-01 responder only works when it is running on port 80.
181	// We actually open the port in the loop below, so that any errors can
182	// be reported in the SMETHOD-ERROR of some bindaddr.
183	// https://github.com/ietf-wg-acme/acme/blob/master/draft-ietf-acme-acme.md#http-challenge
184	needHTTP01Listener := !disableTLS
185
186	listeners := make([]net.Listener, 0)
187	for _, bindaddr := range ptInfo.Bindaddrs {
188		if bindaddr.MethodName != ptMethodName {
189			pt.SmethodError(bindaddr.MethodName, "no such method")
190			continue
191		}
192
193		if needHTTP01Listener {
194			addr := *bindaddr.Addr
195			addr.Port = 80
196			log.Printf("Starting HTTP-01 ACME listener")
197			var lnHTTP01 *net.TCPListener
198			lnHTTP01, err = net.ListenTCP("tcp", &addr)
199			if err != nil {
200				log.Printf("error opening HTTP-01 ACME listener: %s", err)
201				pt.SmethodError(bindaddr.MethodName, "HTTP-01 ACME listener: "+err.Error())
202				continue
203			}
204			server := &http.Server{
205				Addr:    addr.String(),
206				Handler: certManager.HTTPHandler(nil),
207			}
208			go func() {
209				log.Fatal(server.Serve(lnHTTP01))
210			}()
211			listeners = append(listeners, lnHTTP01)
212			needHTTP01Listener = false
213		}
214
215		// We're not capable of listening on port 0 (i.e., an ephemeral port
216		// unknown in advance). The reason is that while the net/http package
217		// exposes ListenAndServe and ListenAndServeTLS, those functions never
218		// return, so there's no opportunity to find out what the port number
219		// is, in between the Listen and Serve steps.
220		// https://groups.google.com/d/msg/Golang-nuts/3F1VRCCENp8/3hcayZiwYM8J
221		if bindaddr.Addr.Port == 0 {
222			err := fmt.Errorf(
223				"cannot listen on port %d; configure a port using ServerTransportListenAddr",
224				bindaddr.Addr.Port)
225			log.Printf("error opening listener: %s", err)
226			pt.SmethodError(bindaddr.MethodName, err.Error())
227			continue
228		}
229
230		var transport *sf.Transport
231		args := pt.Args{}
232		if disableTLS {
233			args.Add("tls", "no")
234			transport = sf.NewSnowflakeServer(nil)
235		} else {
236			args.Add("tls", "yes")
237			for _, hostname := range acmeHostnames {
238				args.Add("hostname", hostname)
239			}
240			transport = sf.NewSnowflakeServer(certManager.GetCertificate)
241		}
242		ln, err := transport.Listen(bindaddr.Addr)
243		if err != nil {
244			log.Printf("error opening listener: %s", err)
245			pt.SmethodError(bindaddr.MethodName, err.Error())
246			continue
247		}
248		defer ln.Close()
249		go acceptLoop(ln)
250		pt.SmethodArgs(bindaddr.MethodName, bindaddr.Addr, args)
251		listeners = append(listeners, ln)
252	}
253	pt.SmethodsDone()
254
255	sigChan := make(chan os.Signal, 1)
256	signal.Notify(sigChan, syscall.SIGTERM)
257
258	if os.Getenv("TOR_PT_EXIT_ON_STDIN_CLOSE") == "1" {
259		// This environment variable means we should treat EOF on stdin
260		// just like SIGTERM: https://bugs.torproject.org/15435.
261		go func() {
262			if _, err := io.Copy(ioutil.Discard, os.Stdin); err != nil {
263				log.Printf("error copying os.Stdin to ioutil.Discard: %v", err)
264			}
265			log.Printf("synthesizing SIGTERM because of stdin close")
266			sigChan <- syscall.SIGTERM
267		}()
268	}
269
270	// Wait for a signal.
271	sig := <-sigChan
272
273	// Signal received, shut down.
274	log.Printf("caught signal %q, exiting", sig)
275	for _, ln := range listeners {
276		ln.Close()
277	}
278}
279