1package http09
2
3import (
4	"context"
5	"crypto/tls"
6	"errors"
7	"io/ioutil"
8	"log"
9	"net"
10	"net/http"
11	"strings"
12	"sync"
13
14	"golang.org/x/net/idna"
15
16	"github.com/lucas-clemente/quic-go"
17)
18
const (
	// MethodGet0RTT is a pseudo request method understood by RoundTripper.
	// It is sent on the wire as a plain GET, but RoundTrip does not wait for
	// the QUIC handshake to complete first, so the request may go out as
	// 0-RTT data.
	// Note that 0-RTT data doesn't provide replay protection.
	MethodGet0RTT = "GET_0RTT"
)
22
23// RoundTripper performs HTTP/0.9 roundtrips over QUIC.
24type RoundTripper struct {
25	mutex sync.Mutex
26
27	TLSClientConfig *tls.Config
28	QuicConfig      *quic.Config
29
30	clients map[string]*client
31}
32
33var _ http.RoundTripper = &RoundTripper{}
34
// RoundTrip performs a HTTP/0.9 request.
// It only supports GET requests (an empty Method is treated as GET, as it is
// for client requests in net/http) and MethodGet0RTT, which does not wait for
// the QUIC handshake to complete before sending the request.
//
// One client is kept per remote host:port (port 443 unless the URL names one).
// It is created on first use and shared by all later requests to the same
// address; the request itself is performed outside the lock.
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if req.URL == nil {
		// The request path is read from req.URL later on; reject the request
		// up front instead of dialing a bogus address and then panicking.
		return nil, errors.New("nil Request.URL")
	}
	if req.Method != "" && req.Method != http.MethodGet && req.Method != MethodGet0RTT {
		return nil, errors.New("only GET requests supported")
	}

	log.Printf("Requesting %s.\n", req.URL)

	r.mutex.Lock()
	hostname := authorityAddr("https", hostnameFromRequest(req))
	if r.clients == nil {
		r.clients = make(map[string]*client)
	}
	c, ok := r.clients[hostname]
	if !ok {
		// Work on a clone so that overriding NextProtos never mutates the
		// caller's TLS config.
		tlsConf := &tls.Config{}
		if r.TLSClientConfig != nil {
			tlsConf = r.TLSClientConfig.Clone()
		}
		tlsConf.NextProtos = []string{h09alpn}
		c = &client{
			hostname: hostname,
			tlsConf:  tlsConf,
			quicConf: r.QuicConfig,
		}
		r.clients[hostname] = c
	}
	r.mutex.Unlock()
	return c.RoundTrip(req)
}
66
67// Close closes the roundtripper.
68func (r *RoundTripper) Close() error {
69	r.mutex.Lock()
70	defer r.mutex.Unlock()
71
72	for id, c := range r.clients {
73		if err := c.Close(); err != nil {
74			return err
75		}
76		delete(r.clients, id)
77	}
78	return nil
79}
80
81type client struct {
82	hostname string
83	tlsConf  *tls.Config
84	quicConf *quic.Config
85
86	once    sync.Once
87	sess    quic.EarlySession
88	dialErr error
89}
90
// RoundTrip sends req over the client's QUIC session, dialing the session on
// first use. The dial happens at most once; if it failed, its error is
// returned to this and every later request.
//
// Unless req.Method is MethodGet0RTT, RoundTrip waits for the handshake to
// complete before sending the request. That wait is abandoned, returning the
// context's error, as soon as req's context is done.
func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
	c.once.Do(func() {
		c.sess, c.dialErr = quic.DialAddrEarly(c.hostname, c.tlsConf, c.quicConf)
	})
	if c.dialErr != nil {
		return nil, c.dialErr
	}
	if req.Method != MethodGet0RTT {
		select {
		case <-c.sess.HandshakeComplete().Done():
		case <-req.Context().Done():
			return nil, req.Context().Err()
		}
	}
	return c.doRequest(req)
}
103
104func (c *client) doRequest(req *http.Request) (*http.Response, error) {
105	str, err := c.sess.OpenStreamSync(context.Background())
106	if err != nil {
107		return nil, err
108	}
109	cmd := "GET " + req.URL.Path + "\r\n"
110	if _, err := str.Write([]byte(cmd)); err != nil {
111		return nil, err
112	}
113	if err := str.Close(); err != nil {
114		return nil, err
115	}
116	rsp := &http.Response{
117		Proto:      "HTTP/0.9",
118		ProtoMajor: 0,
119		ProtoMinor: 9,
120		Request:    req,
121		Body:       ioutil.NopCloser(str),
122	}
123	return rsp, nil
124}
125
// Close closes the client's QUIC session, if one was established.
//
// It goes through c.once so that it synchronizes with the dial performed by
// RoundTrip:
//   - If a dial is in flight, Close waits for it to finish and then closes the
//     resulting session instead of leaking it.
//   - If no dial has happened yet, none will happen any more; later RoundTrip
//     calls on this client fail with a "client closed" error.
//
// Without this, reading c.sess here would race with its write inside once.Do.
func (c *client) Close() error {
	c.once.Do(func() {
		c.dialErr = errors.New("client closed")
	})
	if c.sess == nil {
		return nil
	}
	return c.sess.CloseWithError(0, "")
}
132
// hostnameFromRequest returns the Host part of the request URL, including a
// port if the URL has one. It returns "" when the request carries no URL.
func hostnameFromRequest(req *http.Request) string {
	u := req.URL
	if u == nil {
		return ""
	}
	return u.Host
}
139
// authorityAddr converts an authority (a host or IP, optionally followed by
// ":port") into a host:port address suitable for dialing.
// If the authority has no port, or an empty one (as in "example.com:"), the
// default port for scheme is used: 80 for "http", 443 for anything else.
// The host is converted with idna.ToASCII when that conversion succeeds,
// and IPv6 literals are returned in bracketed form.
func authorityAddr(scheme string, authority string) (addr string) {
	host, port, err := net.SplitHostPort(authority)
	if err != nil { // authority didn't have a port
		host = authority
		port = ""
	}
	if port == "" {
		port = "443"
		if scheme == "http" {
			port = "80"
		}
	}
	if a, err := idna.ToASCII(host); err == nil {
		host = a
	}
	// IPv6 address literal, without a port:
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		return host + ":" + port
	}
	return net.JoinHostPort(host, port)
}
160