1package http3
2
3import (
4	"crypto/tls"
5	"errors"
6	"fmt"
7	"io"
8	"net/http"
9	"strings"
10	"sync"
11
12	quic "github.com/lucas-clemente/quic-go"
13
14	"golang.org/x/net/http/httpguts"
15)
16
// roundTripCloser combines http.RoundTripper and io.Closer. It is the type of
// the per-hostname clients that RoundTripper caches and later closes.
type roundTripCloser interface {
	http.RoundTripper
	io.Closer
}
21
// RoundTripper implements the http.RoundTripper interface.
// It keeps one client per hostname and reuses it for subsequent requests.
type RoundTripper struct {
	// mutex guards clients.
	mutex sync.Mutex

	// DisableCompression, if true, prevents the Transport from
	// requesting compression with an "Accept-Encoding: gzip"
	// request header when the Request contains no existing
	// Accept-Encoding value. If the Transport requests gzip on
	// its own and gets a gzipped response, it's transparently
	// decoded in the Response.Body. However, if the user
	// explicitly requested gzip it is not automatically
	// uncompressed.
	DisableCompression bool

	// TLSClientConfig specifies the TLS configuration to use with
	// tls.Client. If nil, the default configuration is used.
	TLSClientConfig *tls.Config

	// QuicConfig is the quic.Config used for dialing new connections.
	// If nil, reasonable default values will be used.
	QuicConfig *quic.Config

	// EnableDatagrams enables support for HTTP/3 datagrams.
	// If set to true, QuicConfig.EnableDatagram will be set.
	// See https://www.ietf.org/archive/id/draft-schinazi-masque-h3-datagram-02.html.
	EnableDatagrams bool

	// Dial specifies an optional dial function for creating QUIC
	// connections for requests.
	// If Dial is nil, quic.DialAddrEarly will be used.
	Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)

	// MaxResponseHeaderBytes specifies a limit on how many response bytes are
	// allowed in the server's response header.
	// Zero means to use a default limit.
	MaxResponseHeaderBytes int64

	// clients caches the client created for each hostname.
	// It is allocated lazily on first use and reset to nil by Close.
	clients map[string]roundTripCloser
}
61
// RoundTripOpt are options for the RoundTripper.RoundTripOpt method.
// The zero value is what RoundTripper.RoundTrip uses.
type RoundTripOpt struct {
	// OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection.
	// If set true and no cached connection is available, RoundTrip will return ErrNoCachedConn.
	OnlyCachedConn bool
	// SkipSchemeCheck controls whether we check if the scheme is https.
	// This allows the use of different schemes, e.g. masque://target.example.com:443/.
	SkipSchemeCheck bool
}
71
// Compile-time check that *RoundTripper satisfies roundTripCloser.
var _ roundTripCloser = &RoundTripper{}
73
// ErrNoCachedConn is returned when RoundTripOpt.OnlyCachedConn is set
// and no client is cached for the requested host.
var ErrNoCachedConn = errors.New("http3: no cached connection was available")
76
77// RoundTripOpt is like RoundTrip, but takes options.
78func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
79	if req.URL == nil {
80		closeRequestBody(req)
81		return nil, errors.New("http3: nil Request.URL")
82	}
83	if req.URL.Host == "" {
84		closeRequestBody(req)
85		return nil, errors.New("http3: no Host in request URL")
86	}
87	if req.Header == nil {
88		closeRequestBody(req)
89		return nil, errors.New("http3: nil Request.Header")
90	}
91
92	if req.URL.Scheme == "https" {
93		for k, vv := range req.Header {
94			if !httpguts.ValidHeaderFieldName(k) {
95				return nil, fmt.Errorf("http3: invalid http header field name %q", k)
96			}
97			for _, v := range vv {
98				if !httpguts.ValidHeaderFieldValue(v) {
99					return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k)
100				}
101			}
102		}
103	} else if !opt.SkipSchemeCheck {
104		closeRequestBody(req)
105		return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme)
106	}
107
108	if req.Method != "" && !validMethod(req.Method) {
109		closeRequestBody(req)
110		return nil, fmt.Errorf("http3: invalid method %q", req.Method)
111	}
112
113	hostname := authorityAddr("https", hostnameFromRequest(req))
114	cl, err := r.getClient(hostname, opt.OnlyCachedConn)
115	if err != nil {
116		return nil, err
117	}
118	return cl.RoundTrip(req)
119}
120
121// RoundTrip does a round trip.
122func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
123	return r.RoundTripOpt(req, RoundTripOpt{})
124}
125
126func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) {
127	r.mutex.Lock()
128	defer r.mutex.Unlock()
129
130	if r.clients == nil {
131		r.clients = make(map[string]roundTripCloser)
132	}
133
134	client, ok := r.clients[hostname]
135	if !ok {
136		if onlyCached {
137			return nil, ErrNoCachedConn
138		}
139		var err error
140		client, err = newClient(
141			hostname,
142			r.TLSClientConfig,
143			&roundTripperOpts{
144				EnableDatagram:     r.EnableDatagrams,
145				DisableCompression: r.DisableCompression,
146				MaxHeaderBytes:     r.MaxResponseHeaderBytes,
147			},
148			r.QuicConfig,
149			r.Dial,
150		)
151		if err != nil {
152			return nil, err
153		}
154		r.clients[hostname] = client
155	}
156	return client, nil
157}
158
159// Close closes the QUIC connections that this RoundTripper has used
160func (r *RoundTripper) Close() error {
161	r.mutex.Lock()
162	defer r.mutex.Unlock()
163	for _, client := range r.clients {
164		if err := client.Close(); err != nil {
165			return err
166		}
167	}
168	r.clients = nil
169	return nil
170}
171
// closeRequestBody closes req.Body if the request has one.
// Any error returned by Close is ignored.
func closeRequestBody(req *http.Request) {
	body := req.Body
	if body == nil {
		return
	}
	body.Close()
}
177
178func validMethod(method string) bool {
179	/*
180				     Method         = "OPTIONS"                ; Section 9.2
181		   		                    | "GET"                    ; Section 9.3
182		   		                    | "HEAD"                   ; Section 9.4
183		   		                    | "POST"                   ; Section 9.5
184		   		                    | "PUT"                    ; Section 9.6
185		   		                    | "DELETE"                 ; Section 9.7
186		   		                    | "TRACE"                  ; Section 9.8
187		   		                    | "CONNECT"                ; Section 9.9
188		   		                    | extension-method
189		   		   extension-method = token
190		   		     token          = 1*<any CHAR except CTLs or separators>
191	*/
192	return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
193}
194
195// copied from net/http/http.go
196func isNotToken(r rune) bool {
197	return !httpguts.IsTokenRune(r)
198}
199