1package buf
2
3import (
4	"io"
5
6	"github.com/v2fly/v2ray-core/v4/common"
7	"github.com/v2fly/v2ray-core/v4/common/errors"
8)
9
10func readOneUDP(r io.Reader) (*Buffer, error) {
11	b := New()
12	for i := 0; i < 64; i++ {
13		_, err := b.ReadFrom(r)
14		if !b.IsEmpty() {
15			return b, nil
16		}
17		if err != nil {
18			b.Release()
19			return nil, err
20		}
21	}
22
23	b.Release()
24	return nil, newError("Reader returns too many empty payloads.")
25}
26
27// ReadBuffer reads a Buffer from the given reader.
28func ReadBuffer(r io.Reader) (*Buffer, error) {
29	b := New()
30	n, err := b.ReadFrom(r)
31	if n > 0 {
32		return b, err
33	}
34	b.Release()
35	return nil, err
36}
37
// BufferedReader is a Reader that wraps another Reader and keeps an internal
// buffer of data that has been read from it but not yet consumed. Reads are
// served from that internal buffer first, and only go to the underlying
// Reader once it is empty.
type BufferedReader struct {
	// Reader is the underlying reader to be read from.
	Reader Reader
	// Buffer is the internal buffer to be read from first. A nil or empty
	// MultiBuffer means nothing is cached.
	Buffer MultiBuffer
	// Spliter copies bytes from a MultiBuffer into a byte slice and returns
	// the MultiBuffer that remains plus the number of bytes copied. When nil,
	// Read falls back to SplitBytes. (The field name's spelling is kept for
	// compatibility.)
	Spliter func(MultiBuffer, []byte) (MultiBuffer, int)
}
47
48// BufferedBytes returns the number of bytes that is cached in this reader.
49func (r *BufferedReader) BufferedBytes() int32 {
50	return r.Buffer.Len()
51}
52
53// ReadByte implements io.ByteReader.
54func (r *BufferedReader) ReadByte() (byte, error) {
55	var b [1]byte
56	_, err := r.Read(b[:])
57	return b[0], err
58}
59
60// Read implements io.Reader. It reads from internal buffer first (if available) and then reads from the underlying reader.
61func (r *BufferedReader) Read(b []byte) (int, error) {
62	spliter := r.Spliter
63	if spliter == nil {
64		spliter = SplitBytes
65	}
66
67	if !r.Buffer.IsEmpty() {
68		buffer, nBytes := spliter(r.Buffer, b)
69		r.Buffer = buffer
70		if r.Buffer.IsEmpty() {
71			r.Buffer = nil
72		}
73		return nBytes, nil
74	}
75
76	mb, err := r.Reader.ReadMultiBuffer()
77	if err != nil {
78		return 0, err
79	}
80
81	mb, nBytes := spliter(mb, b)
82	if !mb.IsEmpty() {
83		r.Buffer = mb
84	}
85	return nBytes, nil
86}
87
88// ReadMultiBuffer implements Reader.
89func (r *BufferedReader) ReadMultiBuffer() (MultiBuffer, error) {
90	if !r.Buffer.IsEmpty() {
91		mb := r.Buffer
92		r.Buffer = nil
93		return mb, nil
94	}
95
96	return r.Reader.ReadMultiBuffer()
97}
98
99// ReadAtMost returns a MultiBuffer with at most size.
100func (r *BufferedReader) ReadAtMost(size int32) (MultiBuffer, error) {
101	if r.Buffer.IsEmpty() {
102		mb, err := r.Reader.ReadMultiBuffer()
103		if mb.IsEmpty() && err != nil {
104			return nil, err
105		}
106		r.Buffer = mb
107	}
108
109	rb, mb := SplitSize(r.Buffer, size)
110	r.Buffer = rb
111	if r.Buffer.IsEmpty() {
112		r.Buffer = nil
113	}
114	return mb, nil
115}
116
117func (r *BufferedReader) writeToInternal(writer io.Writer) (int64, error) {
118	mbWriter := NewWriter(writer)
119	var sc SizeCounter
120	if r.Buffer != nil {
121		sc.Size = int64(r.Buffer.Len())
122		if err := mbWriter.WriteMultiBuffer(r.Buffer); err != nil {
123			return 0, err
124		}
125		r.Buffer = nil
126	}
127
128	err := Copy(r.Reader, mbWriter, CountSize(&sc))
129	return sc.Size, err
130}
131
132// WriteTo implements io.WriterTo.
133func (r *BufferedReader) WriteTo(writer io.Writer) (int64, error) {
134	nBytes, err := r.writeToInternal(writer)
135	if errors.Cause(err) == io.EOF {
136		return nBytes, nil
137	}
138	return nBytes, err
139}
140
141// Interrupt implements common.Interruptible.
142func (r *BufferedReader) Interrupt() {
143	common.Interrupt(r.Reader)
144}
145
146// Close implements io.Closer.
147func (r *BufferedReader) Close() error {
148	return common.Close(r.Reader)
149}
150
// SingleReader is a Reader that reads one Buffer every time: each
// ReadMultiBuffer call performs a single ReadBuffer on the embedded io.Reader.
type SingleReader struct {
	io.Reader
}
155
156// ReadMultiBuffer implements Reader.
157func (r *SingleReader) ReadMultiBuffer() (MultiBuffer, error) {
158	b, err := ReadBuffer(r.Reader)
159	return MultiBuffer{b}, err
160}
161
// PacketReader is a Reader that reads one non-empty Buffer every time.
// Unlike SingleReader, it retries reads that return no data (via readOneUDP)
// and never returns data together with an error.
type PacketReader struct {
	io.Reader
}
166
167// ReadMultiBuffer implements Reader.
168func (r *PacketReader) ReadMultiBuffer() (MultiBuffer, error) {
169	b, err := readOneUDP(r.Reader)
170	if err != nil {
171		return nil, err
172	}
173	return MultiBuffer{b}, nil
174}
175