1package cmux
2
3import (
4	"bufio"
5	"io"
6	"io/ioutil"
7	"net/http"
8	"strings"
9
10	"golang.org/x/net/http2"
11	"golang.org/x/net/http2/hpack"
12)
13
14// Any is a Matcher that matches any connection.
15func Any() Matcher {
16	return func(r io.Reader) bool { return true }
17}
18
19// PrefixMatcher returns a matcher that matches a connection if it
20// starts with any of the strings in strs.
21func PrefixMatcher(strs ...string) Matcher {
22	pt := newPatriciaTreeString(strs...)
23	return pt.matchPrefix
24}
25
// defaultHTTPMethods are the request methods that HTTP1Fast always matches as
// a connection prefix. Extension methods passed to HTTP1Fast are matched in
// addition to these.
// NOTE(review): PATCH is not in this list; callers that need it must pass it
// to HTTP1Fast explicitly.
var defaultHTTPMethods = []string{
	"OPTIONS",
	"GET",
	"HEAD",
	"POST",
	"PUT",
	"DELETE",
	"TRACE",
	"CONNECT",
}
36
37// HTTP1Fast only matches the methods in the HTTP request.
38//
39// This matcher is very optimistic: if it returns true, it does not mean that
40// the request is a valid HTTP response. If you want a correct but slower HTTP1
41// matcher, use HTTP1 instead.
42func HTTP1Fast(extMethods ...string) Matcher {
43	return PrefixMatcher(append(defaultHTTPMethods, extMethods...)...)
44}
45
46const maxHTTPRead = 4096
47
48// HTTP1 parses the first line or upto 4096 bytes of the request to see if
49// the conection contains an HTTP request.
50func HTTP1() Matcher {
51	return func(r io.Reader) bool {
52		br := bufio.NewReader(&io.LimitedReader{R: r, N: maxHTTPRead})
53		l, part, err := br.ReadLine()
54		if err != nil || part {
55			return false
56		}
57
58		_, _, proto, ok := parseRequestLine(string(l))
59		if !ok {
60			return false
61		}
62
63		v, _, ok := http.ParseHTTPVersion(proto)
64		return ok && v == 1
65	}
66}
67
// parseRequestLine splits a request line such as "GET /path HTTP/1.1" at its
// first two spaces into method, request-URI and protocol. Everything after
// the second space is returned as proto. If the line has fewer than two
// spaces, all strings are empty and ok is false. Adapted from net/http.
func parseRequestLine(line string) (method, uri, proto string, ok bool) {
	first := strings.IndexByte(line, ' ')
	if first < 0 {
		return "", "", "", false
	}
	rest := line[first+1:]
	second := strings.IndexByte(rest, ' ')
	if second < 0 {
		return "", "", "", false
	}
	return line[:first], rest[:second], rest[second+1:], true
}
78
79// HTTP2 parses the frame header of the first frame to detect whether the
80// connection is an HTTP2 connection.
81func HTTP2() Matcher {
82	return hasHTTP2Preface
83}
84
85// HTTP1HeaderField returns a matcher matching the header fields of the first
86// request of an HTTP 1 connection.
87func HTTP1HeaderField(name, value string) Matcher {
88	return func(r io.Reader) bool {
89		return matchHTTP1Field(r, name, value)
90	}
91}
92
93// HTTP2HeaderField resturns a matcher matching the header fields of the first
94// headers frame.
95func HTTP2HeaderField(name, value string) Matcher {
96	return func(r io.Reader) bool {
97		return matchHTTP2Field(r, name, value)
98	}
99}
100
101func hasHTTP2Preface(r io.Reader) bool {
102	var b [len(http2.ClientPreface)]byte
103	if _, err := io.ReadFull(r, b[:]); err != nil {
104		return false
105	}
106
107	return string(b[:]) == http2.ClientPreface
108}
109
// matchHTTP1Field parses one HTTP/1 request from r. It reports whether the
// request's header name (looked up via http.Header.Get, which canonicalizes
// the key) equals value. A request that cannot be parsed never matches.
// Because Get returns "" for a missing header, an empty value matches when
// the header is absent.
func matchHTTP1Field(r io.Reader, name, value string) (matched bool) {
	req, err := http.ReadRequest(bufio.NewReader(r))
	if err == nil {
		matched = req.Header.Get(name) == value
	}
	return matched
}
118
119func matchHTTP2Field(r io.Reader, name, value string) (matched bool) {
120	if !hasHTTP2Preface(r) {
121		return false
122	}
123
124	framer := http2.NewFramer(ioutil.Discard, r)
125	hdec := hpack.NewDecoder(uint32(4<<10), func(hf hpack.HeaderField) {
126		if hf.Name == name && hf.Value == value {
127			matched = true
128		}
129	})
130	for {
131		f, err := framer.ReadFrame()
132		if err != nil {
133			return false
134		}
135
136		switch f := f.(type) {
137		case *http2.HeadersFrame:
138			if _, err := hdec.Write(f.HeaderBlockFragment()); err != nil {
139				return false
140			}
141			if matched {
142				return true
143			}
144
145			if f.FrameHeader.Flags&http2.FlagHeadersEndHeaders != 0 {
146				return false
147			}
148		}
149	}
150}
151