1package main
2
3import (
4	"crypto/tls"
5	"errors"
6	"flag"
7	"fmt"
8	"io"
9	"log"
10	"net/http"
11	"os"
12	"strings"
13	"time"
14
15	"golang.org/x/sync/errgroup"
16
17	"github.com/lucas-clemente/quic-go"
18	"github.com/lucas-clemente/quic-go/http3"
19	"github.com/lucas-clemente/quic-go/internal/handshake"
20	"github.com/lucas-clemente/quic-go/internal/protocol"
21	"github.com/lucas-clemente/quic-go/interop/http09"
22	"github.com/lucas-clemente/quic-go/interop/utils"
23	"github.com/lucas-clemente/quic-go/qlog"
24)
25
var (
	// errUnsupported is returned by runTestcase when the requested test case
	// is not implemented by this client; main reports it with exit code 127.
	errUnsupported = errors.New("unsupported test case")

	// tlsConf is the TLS configuration handed to the round trippers. main
	// initializes it, and individual test cases may adjust it before dialing.
	tlsConf *tls.Config
)
29
30func main() {
31	logFile, err := os.Create("/logs/log.txt")
32	if err != nil {
33		fmt.Printf("Could not create log file: %s\n", err.Error())
34		os.Exit(1)
35	}
36	defer logFile.Close()
37	log.SetOutput(logFile)
38
39	keyLog, err := utils.GetSSLKeyLog()
40	if err != nil {
41		fmt.Printf("Could not create key log: %s\n", err.Error())
42		os.Exit(1)
43	}
44	if keyLog != nil {
45		defer keyLog.Close()
46	}
47
48	tlsConf = &tls.Config{
49		InsecureSkipVerify: true,
50		KeyLogWriter:       keyLog,
51	}
52	testcase := os.Getenv("TESTCASE")
53	if err := runTestcase(testcase); err != nil {
54		if err == errUnsupported {
55			fmt.Printf("unsupported test case: %s\n", testcase)
56			os.Exit(127)
57		}
58		fmt.Printf("Downloading files failed: %s\n", err.Error())
59		os.Exit(1)
60	}
61}
62
63func runTestcase(testcase string) error {
64	flag.Parse()
65	urls := flag.Args()
66
67	getLogWriter, err := utils.GetQLOGWriter()
68	if err != nil {
69		return err
70	}
71	quicConf := &quic.Config{Tracer: qlog.NewTracer(getLogWriter)}
72
73	if testcase == "http3" {
74		r := &http3.RoundTripper{
75			TLSClientConfig: tlsConf,
76			QuicConfig:      quicConf,
77		}
78		defer r.Close()
79		return downloadFiles(r, urls, false)
80	}
81
82	r := &http09.RoundTripper{
83		TLSClientConfig: tlsConf,
84		QuicConfig:      quicConf,
85	}
86	defer r.Close()
87
88	switch testcase {
89	case "handshake", "transfer", "retry":
90	case "keyupdate":
91		handshake.KeyUpdateInterval = 100
92	case "chacha20":
93		tlsConf.CipherSuites = []uint16{tls.TLS_CHACHA20_POLY1305_SHA256}
94	case "multiconnect":
95		return runMultiConnectTest(r, urls)
96	case "versionnegotiation":
97		return runVersionNegotiationTest(r, urls)
98	case "resumption":
99		return runResumptionTest(r, urls, false)
100	case "zerortt":
101		return runResumptionTest(r, urls, true)
102	default:
103		return errUnsupported
104	}
105
106	return downloadFiles(r, urls, false)
107}
108
109func runVersionNegotiationTest(r *http09.RoundTripper, urls []string) error {
110	if len(urls) != 1 {
111		return errors.New("expected at least 2 URLs")
112	}
113	protocol.SupportedVersions = []protocol.VersionNumber{0x1a2a3a4a}
114	err := downloadFile(r, urls[0], false)
115	if err == nil {
116		return errors.New("expected version negotiation to fail")
117	}
118	if !strings.Contains(err.Error(), "No compatible QUIC version found") {
119		return fmt.Errorf("expect version negotiation error, got: %s", err.Error())
120	}
121	return nil
122}
123
124func runMultiConnectTest(r *http09.RoundTripper, urls []string) error {
125	for _, url := range urls {
126		if err := downloadFile(r, url, false); err != nil {
127			return err
128		}
129		if err := r.Close(); err != nil {
130			return err
131		}
132	}
133	return nil
134}
135
// sessionCache wraps a tls.ClientSessionCache and raises a signal whenever a
// session is stored, so a caller can wait for a session ticket to arrive.
type sessionCache struct {
	tls.ClientSessionCache
	put chan<- struct{}
}

var _ tls.ClientSessionCache = (*sessionCache)(nil)

// newSessionCache wraps c and returns the wrapper together with a channel that
// becomes readable after Put has stored a session. Notifications coalesce:
// at most one is pending at a time, which is all a "has a ticket arrived yet?"
// waiter needs.
func newSessionCache(c tls.ClientSessionCache) (tls.ClientSessionCache, <-chan struct{}) {
	put := make(chan struct{}, 1)
	return &sessionCache{ClientSessionCache: c, put: put}, put
}

// Put stores cs in the wrapped cache and signals the notification channel.
// The send is non-blocking: Put is invoked by the TLS stack when it receives a
// session ticket, and it must never stall there just because nobody is
// draining the channel (servers may send several tickets per connection).
func (c *sessionCache) Put(key string, cs *tls.ClientSessionState) {
	c.ClientSessionCache.Put(key, cs)
	select {
	case c.put <- struct{}{}:
	default:
		// A notification is already pending; this one is redundant.
	}
}
150
151func runResumptionTest(r *http09.RoundTripper, urls []string, use0RTT bool) error {
152	if len(urls) < 2 {
153		return errors.New("expected at least 2 URLs")
154	}
155
156	var put <-chan struct{}
157	tlsConf.ClientSessionCache, put = newSessionCache(tls.NewLRUClientSessionCache(1))
158
159	// do the first transfer
160	if err := downloadFiles(r, urls[:1], false); err != nil {
161		return err
162	}
163
164	// wait for the session ticket to arrive
165	select {
166	case <-time.NewTimer(10 * time.Second).C:
167		return errors.New("expected to receive a session ticket within 10 seconds")
168	case <-put:
169	}
170
171	if err := r.Close(); err != nil {
172		return err
173	}
174
175	// reestablish the connection, using the session ticket that the server (hopefully provided)
176	defer r.Close()
177	return downloadFiles(r, urls[1:], use0RTT)
178}
179
180func downloadFiles(cl http.RoundTripper, urls []string, use0RTT bool) error {
181	var g errgroup.Group
182	for _, u := range urls {
183		url := u
184		g.Go(func() error {
185			return downloadFile(cl, url, use0RTT)
186		})
187	}
188	return g.Wait()
189}
190
191func downloadFile(cl http.RoundTripper, url string, use0RTT bool) error {
192	method := http.MethodGet
193	if use0RTT {
194		method = http09.MethodGet0RTT
195	}
196	req, err := http.NewRequest(method, url, nil)
197	if err != nil {
198		return err
199	}
200	rsp, err := cl.RoundTrip(req)
201	if err != nil {
202		return err
203	}
204	defer rsp.Body.Close()
205
206	file, err := os.Create("/downloads" + req.URL.Path)
207	if err != nil {
208		return err
209	}
210	defer file.Close()
211	_, err = io.Copy(file, rsp.Body)
212	return err
213}
214