1package tcp
2
3import (
4	"context"
5	"fmt"
6	"io"
7
8	"github.com/v2fly/v2ray-core/v4/common/buf"
9	"github.com/v2fly/v2ray-core/v4/common/net"
10	"github.com/v2fly/v2ray-core/v4/common/task"
11	"github.com/v2fly/v2ray-core/v4/transport/internet"
12	"github.com/v2fly/v2ray-core/v4/transport/pipe"
13)
14
// Server is a TCP server that reads from each client connection, passes the
// bytes through MsgProcessor and writes the reply back to the same
// connection. Configure the exported fields, then call Start or StartContext;
// call Close to stop accepting connections.
type Server struct {
	// Port to listen on. StartContext overwrites it with the port that was
	// actually bound (presumably 0 lets the OS choose — TODO confirm).
	Port         net.Port
	// MsgProcessor transforms each chunk read from a connection into the
	// bytes written back to that connection. It is called unconditionally,
	// so it must be non-nil.
	MsgProcessor func(msg []byte) []byte
	// NOTE(review): ShouldClose is not read anywhere in this file; confirm
	// whether code elsewhere depends on it.
	ShouldClose  bool
	// SendFirst, when non-empty, is written to every accepted connection
	// before any data is read from it.
	SendFirst    []byte
	// Listen is the address to bind; nil means net.LocalHostIP.
	Listen       net.Address
	// listener is set by StartContext and closed by Close.
	listener     net.Listener
}
23
24func (server *Server) Start() (net.Destination, error) {
25	return server.StartContext(context.Background(), nil)
26}
27
28func (server *Server) StartContext(ctx context.Context, sockopt *internet.SocketConfig) (net.Destination, error) {
29	listenerAddr := server.Listen
30	if listenerAddr == nil {
31		listenerAddr = net.LocalHostIP
32	}
33	listener, err := internet.ListenSystem(ctx, &net.TCPAddr{
34		IP:   listenerAddr.IP(),
35		Port: int(server.Port),
36	}, sockopt)
37	if err != nil {
38		return net.Destination{}, err
39	}
40
41	localAddr := listener.Addr().(*net.TCPAddr)
42	server.Port = net.Port(localAddr.Port)
43	server.listener = listener
44	go server.acceptConnections(listener.(*net.TCPListener))
45
46	return net.TCPDestination(net.IPAddress(localAddr.IP), net.Port(localAddr.Port)), nil
47}
48
49func (server *Server) acceptConnections(listener *net.TCPListener) {
50	for {
51		conn, err := listener.Accept()
52		if err != nil {
53			fmt.Printf("Failed accept TCP connection: %v\n", err)
54			return
55		}
56
57		go server.handleConnection(conn)
58	}
59}
60
61func (server *Server) handleConnection(conn net.Conn) {
62	if len(server.SendFirst) > 0 {
63		conn.Write(server.SendFirst)
64	}
65
66	pReader, pWriter := pipe.New(pipe.WithoutSizeLimit())
67	err := task.Run(context.Background(), func() error {
68		defer pWriter.Close()
69
70		for {
71			b := buf.New()
72			if _, err := b.ReadFrom(conn); err != nil {
73				if err == io.EOF {
74					return nil
75				}
76				return err
77			}
78			copy(b.Bytes(), server.MsgProcessor(b.Bytes()))
79			if err := pWriter.WriteMultiBuffer(buf.MultiBuffer{b}); err != nil {
80				return err
81			}
82		}
83	}, func() error {
84		defer pReader.Interrupt()
85
86		w := buf.NewWriter(conn)
87		for {
88			mb, err := pReader.ReadMultiBuffer()
89			if err != nil {
90				if err == io.EOF {
91					return nil
92				}
93				return err
94			}
95			if err := w.WriteMultiBuffer(mb); err != nil {
96				return err
97			}
98		}
99	})
100
101	if err != nil {
102		fmt.Println("failed to transfer data: ", err.Error())
103	}
104
105	conn.Close()
106}
107
108func (server *Server) Close() error {
109	return server.listener.Close()
110}
111