1package http
2
3import (
4	"bufio"
5	"context"
6	"encoding/base64"
7	"io"
8	"net/http"
9	"strings"
10	"time"
11
12	"github.com/xtls/xray-core/transport/internet/stat"
13
14	"github.com/xtls/xray-core/common"
15	"github.com/xtls/xray-core/common/buf"
16	"github.com/xtls/xray-core/common/errors"
17	"github.com/xtls/xray-core/common/log"
18	"github.com/xtls/xray-core/common/net"
19	"github.com/xtls/xray-core/common/protocol"
20	http_proto "github.com/xtls/xray-core/common/protocol/http"
21	"github.com/xtls/xray-core/common/session"
22	"github.com/xtls/xray-core/common/signal"
23	"github.com/xtls/xray-core/common/task"
24	"github.com/xtls/xray-core/core"
25	"github.com/xtls/xray-core/features/policy"
26	"github.com/xtls/xray-core/features/routing"
27)
28
// Server is an HTTP proxy server.
type Server struct {
	// config is the inbound configuration passed to NewServer; the methods of
	// Server read the user level, timeout, accounts and transparent-proxy flag from it.
	config        *ServerConfig
	// policyManager supplies the per-level session policy used by policy().
	policyManager policy.Manager
}
34
35// NewServer creates a new HTTP inbound handler.
36func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
37	v := core.MustFromContext(ctx)
38	s := &Server{
39		config:        config,
40		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
41	}
42
43	return s, nil
44}
45
46func (s *Server) policy() policy.Session {
47	config := s.config
48	p := s.policyManager.ForLevel(config.UserLevel)
49	if config.Timeout > 0 && config.UserLevel == 0 {
50		p.Timeouts.ConnectionIdle = time.Duration(config.Timeout) * time.Second
51	}
52	return p
53}
54
55// Network implements proxy.Inbound.
56func (*Server) Network() []net.Network {
57	return []net.Network{net.Network_TCP, net.Network_UNIX}
58}
59
60func isTimeout(err error) bool {
61	nerr, ok := errors.Cause(err).(net.Error)
62	return ok && nerr.Timeout()
63}
64
// parseBasicAuth decodes an HTTP Basic credential such as
// "Basic dXNlcjpwYXNz" into its username and password.
//
// The scheme name is matched case-insensitively (RFC 7617 section 2), so
// "basic ..." and "BASIC ..." are accepted as well. The payload must be
// standard base64 and contain a ':'; the username is everything before the
// first ':' and the password everything after it.
//
// ok is false, and both strings are empty, when the scheme is not Basic, the
// payload is not valid base64, or the decoded payload has no ':'.
func parseBasicAuth(auth string) (username, password string, ok bool) {
	const prefix = "Basic "

	// Compare only the leading bytes so the scheme match is case-insensitive
	// without lower-casing the (case-sensitive) credential payload.
	if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
		return "", "", false
	}

	decoded, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
	if err != nil {
		return "", "", false
	}

	credentials := string(decoded)
	sep := strings.IndexByte(credentials, ':')
	if sep < 0 {
		return "", "", false
	}

	return credentials[:sep], credentials[sep+1:], true
}
81
// readerOnly wraps an io.Reader so that only its Read method is exposed; any
// other methods of the wrapped value are not part of readerOnly's method set.
// Process uses it when handing the connection to bufio.
// NOTE(review): presumably this keeps bufio from using optional fast-path
// interfaces of the connection — confirm.
type readerOnly struct {
	io.Reader
}
85
86func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
87	inbound := session.InboundFromContext(ctx)
88	if inbound != nil {
89		inbound.User = &protocol.MemoryUser{
90			Level: s.config.UserLevel,
91		}
92	}
93
94	reader := bufio.NewReaderSize(readerOnly{conn}, buf.Size)
95
96Start:
97	if err := conn.SetReadDeadline(time.Now().Add(s.policy().Timeouts.Handshake)); err != nil {
98		newError("failed to set read deadline").Base(err).WriteToLog(session.ExportIDToError(ctx))
99	}
100
101	request, err := http.ReadRequest(reader)
102	if err != nil {
103		trace := newError("failed to read http request").Base(err)
104		if errors.Cause(err) != io.EOF && !isTimeout(errors.Cause(err)) {
105			trace.AtWarning()
106		}
107		return trace
108	}
109
110	if len(s.config.Accounts) > 0 {
111		user, pass, ok := parseBasicAuth(request.Header.Get("Proxy-Authorization"))
112		if !ok || !s.config.HasAccount(user, pass) {
113			return common.Error2(conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\nProxy-Authenticate: Basic realm=\"proxy\"\r\n\r\n")))
114		}
115		if inbound != nil {
116			inbound.User.Email = user
117		}
118	}
119
120	newError("request to Method [", request.Method, "] Host [", request.Host, "] with URL [", request.URL, "]").WriteToLog(session.ExportIDToError(ctx))
121	if err := conn.SetReadDeadline(time.Time{}); err != nil {
122		newError("failed to clear read deadline").Base(err).WriteToLog(session.ExportIDToError(ctx))
123	}
124
125	defaultPort := net.Port(80)
126	if strings.EqualFold(request.URL.Scheme, "https") {
127		defaultPort = net.Port(443)
128	}
129	host := request.Host
130	if host == "" {
131		host = request.URL.Host
132	}
133	dest, err := http_proto.ParseHost(host, defaultPort)
134	if err != nil {
135		return newError("malformed proxy host: ", host).AtWarning().Base(err)
136	}
137	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
138		From:   conn.RemoteAddr(),
139		To:     request.URL,
140		Status: log.AccessAccepted,
141		Reason: "",
142	})
143
144	if strings.EqualFold(request.Method, "CONNECT") {
145		return s.handleConnect(ctx, request, reader, conn, dest, dispatcher, inbound)
146	}
147
148	keepAlive := (strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive")
149
150	err = s.handlePlainHTTP(ctx, request, conn, dest, dispatcher)
151	if err == errWaitAnother {
152		if keepAlive {
153			goto Start
154		}
155		err = nil
156	}
157
158	return err
159}
160
161func (s *Server) handleConnect(ctx context.Context, _ *http.Request, reader *bufio.Reader, conn stat.Connection, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error {
162	_, err := conn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n"))
163	if err != nil {
164		return newError("failed to write back OK response").Base(err)
165	}
166
167	plcy := s.policy()
168	ctx, cancel := context.WithCancel(ctx)
169	timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
170
171	if inbound != nil {
172		inbound.Timer = timer
173	}
174
175	ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer)
176	link, err := dispatcher.Dispatch(ctx, dest)
177	if err != nil {
178		return err
179	}
180
181	if reader.Buffered() > 0 {
182		payload, err := buf.ReadFrom(io.LimitReader(reader, int64(reader.Buffered())))
183		if err != nil {
184			return err
185		}
186		if err := link.Writer.WriteMultiBuffer(payload); err != nil {
187			return err
188		}
189		reader = nil
190	}
191
192	requestDone := func() error {
193		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
194
195		return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
196	}
197
198	responseDone := func() error {
199		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
200
201		v2writer := buf.NewWriter(conn)
202		if err := buf.Copy(link.Reader, v2writer, buf.UpdateActivity(timer)); err != nil {
203			return err
204		}
205
206		return nil
207	}
208
209	closeWriter := task.OnSuccess(requestDone, task.Close(link.Writer))
210	if err := task.Run(ctx, closeWriter, responseDone); err != nil {
211		common.Interrupt(link.Reader)
212		common.Interrupt(link.Writer)
213		return newError("connection ends").Base(err)
214	}
215
216	return nil
217}
218
// errWaitAnother is the sentinel handlePlainHTTP returns when the response it
// relayed was marked keep-alive; Process compares against it to decide whether
// to read another request from the same connection.
var errWaitAnother = newError("keep alive")
220
221func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, writer io.Writer, dest net.Destination, dispatcher routing.Dispatcher) error {
222	if !s.config.AllowTransparent && request.URL.Host == "" {
223		// RFC 2068 (HTTP/1.1) requires URL to be absolute URL in HTTP proxy.
224		response := &http.Response{
225			Status:        "Bad Request",
226			StatusCode:    400,
227			Proto:         "HTTP/1.1",
228			ProtoMajor:    1,
229			ProtoMinor:    1,
230			Header:        http.Header(make(map[string][]string)),
231			Body:          nil,
232			ContentLength: 0,
233			Close:         true,
234		}
235		response.Header.Set("Proxy-Connection", "close")
236		response.Header.Set("Connection", "close")
237		return response.Write(writer)
238	}
239
240	if len(request.URL.Host) > 0 {
241		request.Host = request.URL.Host
242	}
243	http_proto.RemoveHopByHopHeaders(request.Header)
244
245	// Prevent UA from being set to golang's default ones
246	if request.Header.Get("User-Agent") == "" {
247		request.Header.Set("User-Agent", "")
248	}
249
250	content := &session.Content{
251		Protocol: "http/1.1",
252	}
253
254	content.SetAttribute(":method", strings.ToUpper(request.Method))
255	content.SetAttribute(":path", request.URL.Path)
256	for key := range request.Header {
257		value := request.Header.Get(key)
258		content.SetAttribute(strings.ToLower(key), value)
259	}
260
261	ctx = session.ContextWithContent(ctx, content)
262
263	link, err := dispatcher.Dispatch(ctx, dest)
264	if err != nil {
265		return err
266	}
267
268	// Plain HTTP request is not a stream. The request always finishes before response. Hense request has to be closed later.
269	defer common.Close(link.Writer)
270	var result error = errWaitAnother
271
272	requestDone := func() error {
273		request.Header.Set("Connection", "close")
274
275		requestWriter := buf.NewBufferedWriter(link.Writer)
276		common.Must(requestWriter.SetBuffered(false))
277		if err := request.Write(requestWriter); err != nil {
278			return newError("failed to write whole request").Base(err).AtWarning()
279		}
280		return nil
281	}
282
283	responseDone := func() error {
284		responseReader := bufio.NewReaderSize(&buf.BufferedReader{Reader: link.Reader}, buf.Size)
285		response, err := http.ReadResponse(responseReader, request)
286		if err == nil {
287			http_proto.RemoveHopByHopHeaders(response.Header)
288			if response.ContentLength >= 0 {
289				response.Header.Set("Proxy-Connection", "keep-alive")
290				response.Header.Set("Connection", "keep-alive")
291				response.Header.Set("Keep-Alive", "timeout=4")
292				response.Close = false
293			} else {
294				response.Close = true
295				result = nil
296			}
297			defer response.Body.Close()
298		} else {
299			newError("failed to read response from ", request.Host).Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
300			response = &http.Response{
301				Status:        "Service Unavailable",
302				StatusCode:    503,
303				Proto:         "HTTP/1.1",
304				ProtoMajor:    1,
305				ProtoMinor:    1,
306				Header:        http.Header(make(map[string][]string)),
307				Body:          nil,
308				ContentLength: 0,
309				Close:         true,
310			}
311			response.Header.Set("Connection", "close")
312			response.Header.Set("Proxy-Connection", "close")
313		}
314		if err := response.Write(writer); err != nil {
315			return newError("failed to write response").Base(err).AtWarning()
316		}
317		return nil
318	}
319
320	if err := task.Run(ctx, requestDone, responseDone); err != nil {
321		common.Interrupt(link.Reader)
322		common.Interrupt(link.Writer)
323		return newError("connection ends").Base(err)
324	}
325
326	return result
327}
328
329func init() {
330	common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
331		return NewServer(ctx, config.(*ServerConfig))
332	}))
333}
334