1package buf
2
3import (
4	"io"
5	"net"
6	"os"
7	"syscall"
8	"time"
9
10	"github.com/xtls/xray-core/features/stats"
11	"github.com/xtls/xray-core/transport/internet/stat"
12)
13
14// Reader extends io.Reader with MultiBuffer.
15type Reader interface {
16	// ReadMultiBuffer reads content from underlying reader, and put it into a MultiBuffer.
17	ReadMultiBuffer() (MultiBuffer, error)
18}
19
20// ErrReadTimeout is an error that happens with IO timeout.
21var ErrReadTimeout = newError("IO timeout")
22
23// TimeoutReader is a reader that returns error if Read() operation takes longer than the given timeout.
24type TimeoutReader interface {
25	ReadMultiBufferTimeout(time.Duration) (MultiBuffer, error)
26}
27
28// Writer extends io.Writer with MultiBuffer.
29type Writer interface {
30	// WriteMultiBuffer writes a MultiBuffer into underlying writer.
31	WriteMultiBuffer(MultiBuffer) error
32}
33
34// WriteAllBytes ensures all bytes are written into the given writer.
35func WriteAllBytes(writer io.Writer, payload []byte, c stats.Counter) error {
36	wc := 0
37	defer func() {
38		if c != nil {
39			c.Add(int64(wc))
40		}
41	}()
42
43	for len(payload) > 0 {
44		n, err := writer.Write(payload)
45		wc += n
46		if err != nil {
47			return err
48		}
49		payload = payload[n:]
50	}
51	return nil
52}
53
// isPacketReader reports whether reader implements net.PacketConn, i.e. is a
// datagram-oriented connection rather than a byte stream.
func isPacketReader(reader io.Reader) bool {
	switch reader.(type) {
	case net.PacketConn:
		return true
	default:
		return false
	}
}
58
59// NewReader creates a new Reader.
60// The Reader instance doesn't take the ownership of reader.
61func NewReader(reader io.Reader) Reader {
62	if mr, ok := reader.(Reader); ok {
63		return mr
64	}
65
66	if isPacketReader(reader) {
67		return &PacketReader{
68			Reader: reader,
69		}
70	}
71
72	_, isFile := reader.(*os.File)
73	if !isFile && useReadv {
74		if sc, ok := reader.(syscall.Conn); ok {
75			rawConn, err := sc.SyscallConn()
76			if err != nil {
77				newError("failed to get sysconn").Base(err).WriteToLog()
78			} else {
79				var counter stats.Counter
80
81				if statConn, ok := reader.(*stat.CounterConnection); ok {
82					reader = statConn.Connection
83					counter = statConn.ReadCounter
84				}
85				return NewReadVReader(reader, rawConn, counter)
86			}
87		}
88	}
89
90	return &SingleReader{
91		Reader: reader,
92	}
93}
94
95// NewPacketReader creates a new PacketReader based on the given reader.
96func NewPacketReader(reader io.Reader) Reader {
97	if mr, ok := reader.(Reader); ok {
98		return mr
99	}
100
101	return &PacketReader{
102		Reader: reader,
103	}
104}
105
// isPacketWriter reports whether writer should be handled as a packet writer
// rather than as a byte stream.
//
// A net.PacketConn always qualifies. So does any writer that does not expose
// syscall.Conn, on the assumption that such a writer is probably not a TCP
// connection.
func isPacketWriter(writer io.Writer) bool {
	_, isPacketConn := writer.(net.PacketConn)
	_, hasRawConn := writer.(syscall.Conn)
	return isPacketConn || !hasRawConn
}
117
118// NewWriter creates a new Writer.
119func NewWriter(writer io.Writer) Writer {
120	if mw, ok := writer.(Writer); ok {
121		return mw
122	}
123
124	iConn := writer
125	if statConn, ok := writer.(*stat.CounterConnection); ok {
126		iConn = statConn.Connection
127	}
128
129	if isPacketWriter(iConn) {
130		return &SequentialWriter{
131			Writer: writer,
132		}
133	}
134
135	var counter stats.Counter
136
137	if statConn, ok := writer.(*stat.CounterConnection); ok {
138		counter = statConn.WriteCounter
139	}
140	return &BufferToBytesWriter{
141		Writer:  iConn,
142		counter: counter,
143	}
144}
145