1package buf
2
3import (
4	"io"
5	"net"
6	"sync"
7
8	"github.com/v2fly/v2ray-core/v4/common"
9	"github.com/v2fly/v2ray-core/v4/common/errors"
10)
11
// BufferToBytesWriter is a Writer that writes the contents of the Buffers in
// each MultiBuffer it receives into the embedded io.Writer as plain bytes.
type BufferToBytesWriter struct {
	// Writer is the destination that receives the raw bytes.
	io.Writer

	// cache is a reusable scratch slice used by WriteMultiBuffer to gather the
	// per-Buffer byte slices before a vectored write; it is kept at length 0
	// and only its capacity is reused.
	cache [][]byte
}
18
19// WriteMultiBuffer implements Writer. This method takes ownership of the given buffer.
20func (w *BufferToBytesWriter) WriteMultiBuffer(mb MultiBuffer) error {
21	defer ReleaseMulti(mb)
22
23	size := mb.Len()
24	if size == 0 {
25		return nil
26	}
27
28	if len(mb) == 1 {
29		return WriteAllBytes(w.Writer, mb[0].Bytes())
30	}
31
32	if cap(w.cache) < len(mb) {
33		w.cache = make([][]byte, 0, len(mb))
34	}
35
36	bs := w.cache
37	for _, b := range mb {
38		bs = append(bs, b.Bytes())
39	}
40
41	defer func() {
42		for idx := range bs {
43			bs[idx] = nil
44		}
45	}()
46
47	nb := net.Buffers(bs)
48
49	for size > 0 {
50		n, err := nb.WriteTo(w.Writer)
51		if err != nil {
52			return err
53		}
54		size -= int32(n)
55	}
56
57	return nil
58}
59
60// ReadFrom implements io.ReaderFrom.
61func (w *BufferToBytesWriter) ReadFrom(reader io.Reader) (int64, error) {
62	var sc SizeCounter
63	err := Copy(NewReader(reader), w, CountSize(&sc))
64	return sc.Size, err
65}
66
// BufferedWriter is a Writer with internal buffer.
//
// Written data accumulates in a single Buffer. That Buffer is handed to the
// underlying Writer when it fills up, on Flush/Close, or right away when
// buffering has been turned off with SetBuffered(false). The exported methods
// that touch the buffer acquire the embedded mutex.
type BufferedWriter struct {
	sync.Mutex
	writer   Writer  // destination; if it also implements io.Writer, bytes are written to it directly
	buffer   *Buffer // pending data; set to nil by each flush and re-allocated on the next write
	buffered bool    // false means writes bypass the buffer or flush it immediately
}
74
75// NewBufferedWriter creates a new BufferedWriter.
76func NewBufferedWriter(writer Writer) *BufferedWriter {
77	return &BufferedWriter{
78		writer:   writer,
79		buffer:   New(),
80		buffered: true,
81	}
82}
83
84// WriteByte implements io.ByteWriter.
85func (w *BufferedWriter) WriteByte(c byte) error {
86	return common.Error2(w.Write([]byte{c}))
87}
88
89// Write implements io.Writer.
90func (w *BufferedWriter) Write(b []byte) (int, error) {
91	if len(b) == 0 {
92		return 0, nil
93	}
94
95	w.Lock()
96	defer w.Unlock()
97
98	if !w.buffered {
99		if writer, ok := w.writer.(io.Writer); ok {
100			return writer.Write(b)
101		}
102	}
103
104	totalBytes := 0
105	for len(b) > 0 {
106		if w.buffer == nil {
107			w.buffer = New()
108		}
109
110		nBytes, err := w.buffer.Write(b)
111		totalBytes += nBytes
112		if err != nil {
113			return totalBytes, err
114		}
115		if !w.buffered || w.buffer.IsFull() {
116			if err := w.flushInternal(); err != nil {
117				return totalBytes, err
118			}
119		}
120		b = b[nBytes:]
121	}
122
123	return totalBytes, nil
124}
125
126// WriteMultiBuffer implements Writer. It takes ownership of the given MultiBuffer.
127func (w *BufferedWriter) WriteMultiBuffer(b MultiBuffer) error {
128	if b.IsEmpty() {
129		return nil
130	}
131
132	w.Lock()
133	defer w.Unlock()
134
135	if !w.buffered {
136		return w.writer.WriteMultiBuffer(b)
137	}
138
139	reader := MultiBufferContainer{
140		MultiBuffer: b,
141	}
142	defer reader.Close()
143
144	for !reader.MultiBuffer.IsEmpty() {
145		if w.buffer == nil {
146			w.buffer = New()
147		}
148		common.Must2(w.buffer.ReadFrom(&reader))
149		if w.buffer.IsFull() {
150			if err := w.flushInternal(); err != nil {
151				return err
152			}
153		}
154	}
155
156	return nil
157}
158
159// Flush flushes buffered content into underlying writer.
160func (w *BufferedWriter) Flush() error {
161	w.Lock()
162	defer w.Unlock()
163
164	return w.flushInternal()
165}
166
167func (w *BufferedWriter) flushInternal() error {
168	if w.buffer.IsEmpty() {
169		return nil
170	}
171
172	b := w.buffer
173	w.buffer = nil
174
175	if writer, ok := w.writer.(io.Writer); ok {
176		err := WriteAllBytes(writer, b.Bytes())
177		b.Release()
178		return err
179	}
180
181	return w.writer.WriteMultiBuffer(MultiBuffer{b})
182}
183
184// SetBuffered sets whether the internal buffer is used. If set to false, Flush() will be called to clear the buffer.
185func (w *BufferedWriter) SetBuffered(f bool) error {
186	w.Lock()
187	defer w.Unlock()
188
189	w.buffered = f
190	if !f {
191		return w.flushInternal()
192	}
193	return nil
194}
195
196// ReadFrom implements io.ReaderFrom.
197func (w *BufferedWriter) ReadFrom(reader io.Reader) (int64, error) {
198	if err := w.SetBuffered(false); err != nil {
199		return 0, err
200	}
201
202	var sc SizeCounter
203	err := Copy(NewReader(reader), w, CountSize(&sc))
204	return sc.Size, err
205}
206
207// Close implements io.Closable.
208func (w *BufferedWriter) Close() error {
209	if err := w.Flush(); err != nil {
210		return err
211	}
212	return common.Close(w.writer)
213}
214
// SequentialWriter is a Writer that writes MultiBuffer sequentially into the underlying io.Writer.
// Each WriteMultiBuffer call delegates to the package-level WriteMultiBuffer
// helper and then releases the MultiBuffer that helper returns.
type SequentialWriter struct {
	// Writer is the destination that receives the bytes.
	io.Writer
}
219
220// WriteMultiBuffer implements Writer.
221func (w *SequentialWriter) WriteMultiBuffer(mb MultiBuffer) error {
222	mb, err := WriteMultiBuffer(w.Writer, mb)
223	ReleaseMulti(mb)
224	return err
225}
226
// noOpWriter discards everything written to it. It backs both Discard and
// DiscardBytes; its methods never read the underlying byte value.
type noOpWriter byte
228
229func (noOpWriter) WriteMultiBuffer(b MultiBuffer) error {
230	ReleaseMulti(b)
231	return nil
232}
233
234func (noOpWriter) Write(b []byte) (int, error) {
235	return len(b), nil
236}
237
238func (noOpWriter) ReadFrom(reader io.Reader) (int64, error) {
239	b := New()
240	defer b.Release()
241
242	totalBytes := int64(0)
243	for {
244		b.Clear()
245		_, err := b.ReadFrom(reader)
246		totalBytes += int64(b.Len())
247		if err != nil {
248			if errors.Cause(err) == io.EOF {
249				return totalBytes, nil
250			}
251			return totalBytes, err
252		}
253	}
254}
255
256var (
257	// Discard is a Writer that swallows all contents written in.
258	Discard Writer = noOpWriter(0)
259
260	// DiscardBytes is an io.Writer that swallows all contents written in.
261	DiscardBytes io.Writer = noOpWriter(0)
262)
263