1package http
2
3//go:generate go run github.com/v2fly/v2ray-core/v4/common/errors/errorgen
4
5import (
6	"bufio"
7	"bytes"
8	"context"
9	"io"
10	"net"
11	"net/http"
12	"strings"
13	"time"
14
15	"github.com/v2fly/v2ray-core/v4/common"
16	"github.com/v2fly/v2ray-core/v4/common/buf"
17)
18
const (
	// CRLF terminates every line of an HTTP header.
	CRLF = "\r\n"

	// ENDING is the empty line (two CRLFs) that separates an HTTP header
	// from the body that follows it.
	ENDING = CRLF + CRLF

	// maxHeaderLength bounds how many header bytes are accepted before the
	// header is rejected, so a peer cannot make us buffer without limit.
	maxHeaderLength = 8 * 1024
)
29
30var (
31	ErrHeaderToLong = newError("Header too long.")
32
33	ErrHeaderMisMatch = newError("Header Mismatch.")
34)
35
36type Reader interface {
37	Read(io.Reader) (*buf.Buffer, error)
38}
39
40type Writer interface {
41	Write(io.Writer) error
42}
43
44type NoOpReader struct{}
45
46func (NoOpReader) Read(io.Reader) (*buf.Buffer, error) {
47	return nil, nil
48}
49
50type NoOpWriter struct{}
51
52func (NoOpWriter) Write(io.Writer) error {
53	return nil
54}
55
56type HeaderReader struct {
57	req            *http.Request
58	expectedHeader *RequestConfig
59}
60
61func (h *HeaderReader) ExpectThisRequest(expectedHeader *RequestConfig) *HeaderReader {
62	h.expectedHeader = expectedHeader
63	return h
64}
65
66func (h *HeaderReader) Read(reader io.Reader) (*buf.Buffer, error) {
67	buffer := buf.New()
68	totalBytes := int32(0)
69	endingDetected := false
70
71	var headerBuf bytes.Buffer
72
73	for totalBytes < maxHeaderLength {
74		_, err := buffer.ReadFrom(reader)
75		if err != nil {
76			buffer.Release()
77			return nil, err
78		}
79		if n := bytes.Index(buffer.Bytes(), []byte(ENDING)); n != -1 {
80			headerBuf.Write(buffer.BytesRange(0, int32(n+len(ENDING))))
81			buffer.Advance(int32(n + len(ENDING)))
82			endingDetected = true
83			break
84		}
85		lenEnding := int32(len(ENDING))
86		if buffer.Len() >= lenEnding {
87			totalBytes += buffer.Len() - lenEnding
88			headerBuf.Write(buffer.BytesRange(0, buffer.Len()-lenEnding))
89			leftover := buffer.BytesFrom(-lenEnding)
90			buffer.Clear()
91			copy(buffer.Extend(lenEnding), leftover)
92
93			if _, err := readRequest(bufio.NewReader(bytes.NewReader(headerBuf.Bytes())), false); err != io.ErrUnexpectedEOF {
94				return nil, err
95			}
96		}
97	}
98
99	if !endingDetected {
100		buffer.Release()
101		return nil, ErrHeaderToLong
102	}
103
104	if h.expectedHeader == nil {
105		if buffer.IsEmpty() {
106			buffer.Release()
107			return nil, nil
108		}
109		return buffer, nil
110	}
111
112	// Parse the request
113	if req, err := readRequest(bufio.NewReader(bytes.NewReader(headerBuf.Bytes())), false); err != nil {
114		return nil, err
115	} else { // nolint: golint
116		h.req = req
117	}
118
119	// Check req
120	path := h.req.URL.Path
121	hasThisURI := false
122	for _, u := range h.expectedHeader.Uri {
123		if u == path {
124			hasThisURI = true
125		}
126	}
127
128	if !hasThisURI {
129		return nil, ErrHeaderMisMatch
130	}
131
132	if buffer.IsEmpty() {
133		buffer.Release()
134		return nil, nil
135	}
136
137	return buffer, nil
138}
139
140type HeaderWriter struct {
141	header *buf.Buffer
142}
143
144func NewHeaderWriter(header *buf.Buffer) *HeaderWriter {
145	return &HeaderWriter{
146		header: header,
147	}
148}
149
150func (w *HeaderWriter) Write(writer io.Writer) error {
151	if w.header == nil {
152		return nil
153	}
154	err := buf.WriteAllBytes(writer, w.header.Bytes())
155	w.header.Release()
156	w.header = nil
157	return err
158}
159
160type Conn struct {
161	net.Conn
162
163	readBuffer          *buf.Buffer
164	oneTimeReader       Reader
165	oneTimeWriter       Writer
166	errorWriter         Writer
167	errorMismatchWriter Writer
168	errorTooLongWriter  Writer
169	errReason           error
170}
171
172func NewConn(conn net.Conn, reader Reader, writer Writer, errorWriter Writer, errorMismatchWriter Writer, errorTooLongWriter Writer) *Conn {
173	return &Conn{
174		Conn:                conn,
175		oneTimeReader:       reader,
176		oneTimeWriter:       writer,
177		errorWriter:         errorWriter,
178		errorMismatchWriter: errorMismatchWriter,
179		errorTooLongWriter:  errorTooLongWriter,
180	}
181}
182
183func (c *Conn) Read(b []byte) (int, error) {
184	if c.oneTimeReader != nil {
185		buffer, err := c.oneTimeReader.Read(c.Conn)
186		if err != nil {
187			c.errReason = err
188			return 0, err
189		}
190		c.readBuffer = buffer
191		c.oneTimeReader = nil
192	}
193
194	if !c.readBuffer.IsEmpty() {
195		nBytes, _ := c.readBuffer.Read(b)
196		if c.readBuffer.IsEmpty() {
197			c.readBuffer.Release()
198			c.readBuffer = nil
199		}
200		return nBytes, nil
201	}
202
203	return c.Conn.Read(b)
204}
205
206// Write implements io.Writer.
207func (c *Conn) Write(b []byte) (int, error) {
208	if c.oneTimeWriter != nil {
209		err := c.oneTimeWriter.Write(c.Conn)
210		c.oneTimeWriter = nil
211		if err != nil {
212			return 0, err
213		}
214	}
215
216	return c.Conn.Write(b)
217}
218
219// Close implements net.Conn.Close().
220func (c *Conn) Close() error {
221	if c.oneTimeWriter != nil && c.errorWriter != nil {
222		// Connection is being closed but header wasn't sent. This means the client request
223		// is probably not valid. Sending back a server error header in this case.
224
225		// Write response based on error reason
226		switch c.errReason {
227		case ErrHeaderMisMatch:
228			c.errorMismatchWriter.Write(c.Conn)
229		case ErrHeaderToLong:
230			c.errorTooLongWriter.Write(c.Conn)
231		default:
232			c.errorWriter.Write(c.Conn)
233		}
234	}
235
236	return c.Conn.Close()
237}
238
239func formResponseHeader(config *ResponseConfig) *HeaderWriter {
240	header := buf.New()
241	common.Must2(header.WriteString(strings.Join([]string{config.GetFullVersion(), config.GetStatusValue().Code, config.GetStatusValue().Reason}, " ")))
242	common.Must2(header.WriteString(CRLF))
243
244	headers := config.PickHeaders()
245	for _, h := range headers {
246		common.Must2(header.WriteString(h))
247		common.Must2(header.WriteString(CRLF))
248	}
249	if !config.HasHeader("Date") {
250		common.Must2(header.WriteString("Date: "))
251		common.Must2(header.WriteString(time.Now().Format(http.TimeFormat)))
252		common.Must2(header.WriteString(CRLF))
253	}
254	common.Must2(header.WriteString(CRLF))
255	return &HeaderWriter{
256		header: header,
257	}
258}
259
260type Authenticator struct {
261	config *Config
262}
263
264func (a Authenticator) GetClientWriter() *HeaderWriter {
265	header := buf.New()
266	config := a.config.Request
267	common.Must2(header.WriteString(strings.Join([]string{config.GetMethodValue(), config.PickURI(), config.GetFullVersion()}, " ")))
268	common.Must2(header.WriteString(CRLF))
269
270	headers := config.PickHeaders()
271	for _, h := range headers {
272		common.Must2(header.WriteString(h))
273		common.Must2(header.WriteString(CRLF))
274	}
275	common.Must2(header.WriteString(CRLF))
276	return &HeaderWriter{
277		header: header,
278	}
279}
280
281func (a Authenticator) GetServerWriter() *HeaderWriter {
282	return formResponseHeader(a.config.Response)
283}
284
285func (a Authenticator) Client(conn net.Conn) net.Conn {
286	if a.config.Request == nil && a.config.Response == nil {
287		return conn
288	}
289	var reader Reader = NoOpReader{}
290	if a.config.Request != nil {
291		reader = new(HeaderReader)
292	}
293
294	var writer Writer = NoOpWriter{}
295	if a.config.Response != nil {
296		writer = a.GetClientWriter()
297	}
298	return NewConn(conn, reader, writer, NoOpWriter{}, NoOpWriter{}, NoOpWriter{})
299}
300
301func (a Authenticator) Server(conn net.Conn) net.Conn {
302	if a.config.Request == nil && a.config.Response == nil {
303		return conn
304	}
305	return NewConn(conn, new(HeaderReader).ExpectThisRequest(a.config.Request), a.GetServerWriter(),
306		formResponseHeader(resp400),
307		formResponseHeader(resp404),
308		formResponseHeader(resp400))
309}
310
311func NewAuthenticator(ctx context.Context, config *Config) (Authenticator, error) {
312	return Authenticator{
313		config: config,
314	}, nil
315}
316
317func init() {
318	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
319		return NewAuthenticator(ctx, config.(*Config))
320	}))
321}
322