1package http3
2
3import (
4	"crypto/tls"
5	"errors"
6	"fmt"
7	"io"
8	"net/http"
9	"strings"
10	"sync"
11
12	quic "github.com/lucas-clemente/quic-go"
13
14	"golang.org/x/net/http/httpguts"
15)
16
// roundTripCloser is an http.RoundTripper that can also be closed.
// It is the value type of the RoundTripper's client cache.
type roundTripCloser interface {
	http.RoundTripper
	io.Closer
}
21
// RoundTripper implements the http.RoundTripper interface.
//
// It lazily creates one client per request authority and reuses that client
// for later requests to the same authority. Close closes the cached clients.
type RoundTripper struct {
	// mutex guards clients.
	mutex sync.Mutex

	// DisableCompression, if true, prevents the RoundTripper from
	// requesting compression with an "Accept-Encoding: gzip"
	// request header when the Request contains no existing
	// Accept-Encoding value. If the RoundTripper requests gzip on
	// its own and gets a gzipped response, it's transparently
	// decoded in the Response.Body. However, if the user
	// explicitly requested gzip it is not automatically
	// uncompressed.
	DisableCompression bool

	// TLSClientConfig specifies the TLS configuration to use with
	// tls.Client. If nil, the default configuration is used.
	TLSClientConfig *tls.Config

	// QuicConfig is the quic.Config used for dialing new connections.
	// If nil, reasonable default values will be used.
	QuicConfig *quic.Config

	// Dial specifies an optional dial function for creating QUIC
	// connections for requests.
	// If Dial is nil, quic.DialAddr will be used.
	Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)

	// MaxResponseHeaderBytes specifies a limit on how many response bytes are
	// allowed in the server's response header.
	// Zero means to use a default limit.
	MaxResponseHeaderBytes int64

	// clients caches the clients created so far, keyed by the authority
	// string passed to getClient. It is allocated on first use and reset
	// to nil by Close.
	clients map[string]roundTripCloser
}
56
// RoundTripOpt holds the options for the RoundTripper.RoundTripOpt method.
type RoundTripOpt struct {
	// OnlyCachedConn controls whether the RoundTripper may
	// create a new client (and with it a new QUIC connection).
	// If set to true and no cached client is available for the
	// request's authority, RoundTripOpt returns ErrNoCachedConn.
	OnlyCachedConn bool
}
65
// Compile-time check that *RoundTripper satisfies roundTripCloser.
var _ roundTripCloser = &RoundTripper{}

// ErrNoCachedConn is returned when RoundTripOpt.OnlyCachedConn is set
// and no client is cached for the request's authority.
var ErrNoCachedConn = errors.New("http3: no cached connection was available")
70
71// RoundTripOpt is like RoundTrip, but takes options.
72func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
73	if req.URL == nil {
74		closeRequestBody(req)
75		return nil, errors.New("http3: nil Request.URL")
76	}
77	if req.URL.Host == "" {
78		closeRequestBody(req)
79		return nil, errors.New("http3: no Host in request URL")
80	}
81	if req.Header == nil {
82		closeRequestBody(req)
83		return nil, errors.New("http3: nil Request.Header")
84	}
85
86	if req.URL.Scheme == "https" {
87		for k, vv := range req.Header {
88			if !httpguts.ValidHeaderFieldName(k) {
89				return nil, fmt.Errorf("http3: invalid http header field name %q", k)
90			}
91			for _, v := range vv {
92				if !httpguts.ValidHeaderFieldValue(v) {
93					return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k)
94				}
95			}
96		}
97	} else {
98		closeRequestBody(req)
99		return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme)
100	}
101
102	if req.Method != "" && !validMethod(req.Method) {
103		closeRequestBody(req)
104		return nil, fmt.Errorf("http3: invalid method %q", req.Method)
105	}
106
107	hostname := authorityAddr("https", hostnameFromRequest(req))
108	cl, err := r.getClient(hostname, opt.OnlyCachedConn)
109	if err != nil {
110		return nil, err
111	}
112	return cl.RoundTrip(req)
113}
114
115// RoundTrip does a round trip.
116func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
117	return r.RoundTripOpt(req, RoundTripOpt{})
118}
119
120func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) {
121	r.mutex.Lock()
122	defer r.mutex.Unlock()
123
124	if r.clients == nil {
125		r.clients = make(map[string]roundTripCloser)
126	}
127
128	client, ok := r.clients[hostname]
129	if !ok {
130		if onlyCached {
131			return nil, ErrNoCachedConn
132		}
133		var err error
134		client, err = newClient(
135			hostname,
136			r.TLSClientConfig,
137			&roundTripperOpts{
138				DisableCompression: r.DisableCompression,
139				MaxHeaderBytes:     r.MaxResponseHeaderBytes,
140			},
141			r.QuicConfig,
142			r.Dial,
143		)
144		if err != nil {
145			return nil, err
146		}
147		r.clients[hostname] = client
148	}
149	return client, nil
150}
151
152// Close closes the QUIC connections that this RoundTripper has used
153func (r *RoundTripper) Close() error {
154	r.mutex.Lock()
155	defer r.mutex.Unlock()
156	for _, client := range r.clients {
157		if err := client.Close(); err != nil {
158			return err
159		}
160	}
161	r.clients = nil
162	return nil
163}
164
// closeRequestBody closes req.Body if the request has one.
// req itself must be non-nil.
func closeRequestBody(req *http.Request) {
	if req.Body == nil {
		return
	}
	// Best effort: the error returned by Close is intentionally discarded.
	_ = req.Body.Close()
}
170
171func validMethod(method string) bool {
172	/*
173				     Method         = "OPTIONS"                ; Section 9.2
174		   		                    | "GET"                    ; Section 9.3
175		   		                    | "HEAD"                   ; Section 9.4
176		   		                    | "POST"                   ; Section 9.5
177		   		                    | "PUT"                    ; Section 9.6
178		   		                    | "DELETE"                 ; Section 9.7
179		   		                    | "TRACE"                  ; Section 9.8
180		   		                    | "CONNECT"                ; Section 9.9
181		   		                    | extension-method
182		   		   extension-method = token
183		   		     token          = 1*<any CHAR except CTLs or separators>
184	*/
185	return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
186}
187
188// copied from net/http/http.go
189func isNotToken(r rune) bool {
190	return !httpguts.IsTokenRune(r)
191}
192