1// +build !confonly
2
3package trojan
4
5import (
6	"context"
7	"time"
8
9	core "github.com/v2fly/v2ray-core/v4"
10	"github.com/v2fly/v2ray-core/v4/common"
11	"github.com/v2fly/v2ray-core/v4/common/buf"
12	"github.com/v2fly/v2ray-core/v4/common/net"
13	"github.com/v2fly/v2ray-core/v4/common/protocol"
14	"github.com/v2fly/v2ray-core/v4/common/retry"
15	"github.com/v2fly/v2ray-core/v4/common/session"
16	"github.com/v2fly/v2ray-core/v4/common/signal"
17	"github.com/v2fly/v2ray-core/v4/common/task"
18	"github.com/v2fly/v2ray-core/v4/features/policy"
19	"github.com/v2fly/v2ray-core/v4/transport"
20	"github.com/v2fly/v2ray-core/v4/transport/internet"
21)
22
23// Client is a inbound handler for trojan protocol
24type Client struct {
25	serverPicker  protocol.ServerPicker
26	policyManager policy.Manager
27}
28
29// NewClient create a new trojan client.
30func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) {
31	serverList := protocol.NewServerList()
32	for _, rec := range config.Server {
33		s, err := protocol.NewServerSpecFromPB(rec)
34		if err != nil {
35			return nil, newError("failed to parse server spec").Base(err)
36		}
37		serverList.AddServer(s)
38	}
39	if serverList.Size() == 0 {
40		return nil, newError("0 server")
41	}
42
43	v := core.MustFromContext(ctx)
44	client := &Client{
45		serverPicker:  protocol.NewRoundRobinServerPicker(serverList),
46		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
47	}
48	return client, nil
49}
50
51// Process implements OutboundHandler.Process().
52func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
53	outbound := session.OutboundFromContext(ctx)
54	if outbound == nil || !outbound.Target.IsValid() {
55		return newError("target not specified")
56	}
57	destination := outbound.Target
58	network := destination.Network
59
60	var server *protocol.ServerSpec
61	var conn internet.Connection
62
63	err := retry.ExponentialBackoff(5, 100).On(func() error {
64		server = c.serverPicker.PickServer()
65		rawConn, err := dialer.Dial(ctx, server.Destination())
66		if err != nil {
67			return err
68		}
69
70		conn = rawConn
71		return nil
72	})
73	if err != nil {
74		return newError("failed to find an available destination").AtWarning().Base(err)
75	}
76	newError("tunneling request to ", destination, " via ", server.Destination()).WriteToLog(session.ExportIDToError(ctx))
77
78	defer conn.Close()
79
80	user := server.PickUser()
81	account, ok := user.Account.(*MemoryAccount)
82	if !ok {
83		return newError("user account is not valid")
84	}
85
86	sessionPolicy := c.policyManager.ForLevel(user.Level)
87	ctx, cancel := context.WithCancel(ctx)
88	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
89
90	postRequest := func() error {
91		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
92
93		var bodyWriter buf.Writer
94		bufferWriter := buf.NewBufferedWriter(buf.NewWriter(conn))
95		connWriter := &ConnWriter{Writer: bufferWriter, Target: destination, Account: account}
96
97		if destination.Network == net.Network_UDP {
98			bodyWriter = &PacketWriter{Writer: connWriter, Target: destination}
99		} else {
100			bodyWriter = connWriter
101		}
102
103		// write some request payload to buffer
104		if err = buf.CopyOnceTimeout(link.Reader, bodyWriter, time.Millisecond*100); err != nil && err != buf.ErrNotTimeoutReader && err != buf.ErrReadTimeout {
105			return newError("failed to write A request payload").Base(err).AtWarning()
106		}
107
108		// Flush; bufferWriter.WriteMultiBufer now is bufferWriter.writer.WriteMultiBuffer
109		if err = bufferWriter.SetBuffered(false); err != nil {
110			return newError("failed to flush payload").Base(err).AtWarning()
111		}
112
113		if err = buf.Copy(link.Reader, bodyWriter, buf.UpdateActivity(timer)); err != nil {
114			return newError("failed to transfer request payload").Base(err).AtInfo()
115		}
116
117		return nil
118	}
119
120	getResponse := func() error {
121		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
122
123		var reader buf.Reader
124		if network == net.Network_UDP {
125			reader = &PacketReader{
126				Reader: conn,
127			}
128		} else {
129			reader = buf.NewReader(conn)
130		}
131		return buf.Copy(reader, link.Writer, buf.UpdateActivity(timer))
132	}
133
134	var responseDoneAndCloseWriter = task.OnSuccess(getResponse, task.Close(link.Writer))
135	if err := task.Run(ctx, postRequest, responseDoneAndCloseWriter); err != nil {
136		return newError("connection ends").Base(err)
137	}
138
139	return nil
140}
141
142func init() {
143	common.Must(common.RegisterConfig((*ClientConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
144		return NewClient(ctx, config.(*ClientConfig))
145	}))
146}
147