1package crypto
2
3import (
4	"crypto/cipher"
5	"io"
6	"math/rand"
7
8	"github.com/v2fly/v2ray-core/v4/common"
9	"github.com/v2fly/v2ray-core/v4/common/buf"
10	"github.com/v2fly/v2ray-core/v4/common/bytespool"
11	"github.com/v2fly/v2ray-core/v4/common/errors"
12	"github.com/v2fly/v2ray-core/v4/common/protocol"
13)
14
// BytesGenerator is a function that yields a byte slice every time it is
// called. Within this file it supplies the nonce and the additional data
// consumed by AEADAuthenticator.
type BytesGenerator func() []byte
16
17func GenerateEmptyBytes() BytesGenerator {
18	var b [1]byte
19	return func() []byte {
20		return b[:0]
21	}
22}
23
24func GenerateStaticBytes(content []byte) BytesGenerator {
25	return func() []byte {
26		return content
27	}
28}
29
30func GenerateIncreasingNonce(nonce []byte) BytesGenerator {
31	c := append([]byte(nil), nonce...)
32	return func() []byte {
33		for i := range c {
34			c[i]++
35			if c[i] != 0 {
36				break
37			}
38		}
39		return c
40	}
41}
42
43func GenerateInitialAEADNonce() BytesGenerator {
44	return GenerateIncreasingNonce([]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF})
45}
46
// Authenticator seals and opens the individual chunks handled by
// AuthenticationReader and AuthenticationWriter.
type Authenticator interface {
	// NonceSize reports the nonce length; it is not called by the reader or
	// writer in this file. (Presumably bytes, as in cipher.AEAD - confirm
	// against implementations.)
	NonceSize() int

	// Overhead reports how many bytes Seal adds on top of the plaintext. The
	// writer sizes a sealed chunk as len(plaintext)+Overhead().
	Overhead() int

	// Open authenticates cipherText and returns the recovered plaintext.
	// The reader passes a zero-length view of its own buffer as dst.
	Open(dst, cipherText []byte) ([]byte, error)

	// Seal protects plainText and returns the sealed bytes. The writer passes
	// a zero-length view of its own buffer as dst.
	Seal(dst, plainText []byte) ([]byte, error)
}
53
54type AEADAuthenticator struct {
55	cipher.AEAD
56	NonceGenerator          BytesGenerator
57	AdditionalDataGenerator BytesGenerator
58}
59
60func (v *AEADAuthenticator) Open(dst, cipherText []byte) ([]byte, error) {
61	iv := v.NonceGenerator()
62	if len(iv) != v.AEAD.NonceSize() {
63		return nil, newError("invalid AEAD nonce size: ", len(iv))
64	}
65
66	var additionalData []byte
67	if v.AdditionalDataGenerator != nil {
68		additionalData = v.AdditionalDataGenerator()
69	}
70	return v.AEAD.Open(dst, iv, cipherText, additionalData)
71}
72
73func (v *AEADAuthenticator) Seal(dst, plainText []byte) ([]byte, error) {
74	iv := v.NonceGenerator()
75	if len(iv) != v.AEAD.NonceSize() {
76		return nil, newError("invalid AEAD nonce size: ", len(iv))
77	}
78
79	var additionalData []byte
80	if v.AdditionalDataGenerator != nil {
81		additionalData = v.AdditionalDataGenerator()
82	}
83	return v.AEAD.Seal(dst, iv, plainText, additionalData), nil
84}
85
86type AuthenticationReader struct {
87	auth         Authenticator
88	reader       *buf.BufferedReader
89	sizeParser   ChunkSizeDecoder
90	sizeBytes    []byte
91	transferType protocol.TransferType
92	padding      PaddingLengthGenerator
93	size         uint16
94	paddingLen   uint16
95	hasSize      bool
96	done         bool
97}
98
99func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, reader io.Reader, transferType protocol.TransferType, paddingLen PaddingLengthGenerator) *AuthenticationReader {
100	r := &AuthenticationReader{
101		auth:         auth,
102		sizeParser:   sizeParser,
103		transferType: transferType,
104		padding:      paddingLen,
105		sizeBytes:    make([]byte, sizeParser.SizeBytes()),
106	}
107	if breader, ok := reader.(*buf.BufferedReader); ok {
108		r.reader = breader
109	} else {
110		r.reader = &buf.BufferedReader{Reader: buf.NewReader(reader)}
111	}
112	return r
113}
114
115func (r *AuthenticationReader) readSize() (uint16, uint16, error) {
116	if r.hasSize {
117		r.hasSize = false
118		return r.size, r.paddingLen, nil
119	}
120	if _, err := io.ReadFull(r.reader, r.sizeBytes); err != nil {
121		return 0, 0, err
122	}
123	var padding uint16
124	if r.padding != nil {
125		padding = r.padding.NextPaddingLen()
126	}
127	size, err := r.sizeParser.Decode(r.sizeBytes)
128	return size, padding, err
129}
130
131var errSoft = newError("waiting for more data")
132
133func (r *AuthenticationReader) readBuffer(size int32, padding int32) (*buf.Buffer, error) {
134	b := buf.New()
135	if _, err := b.ReadFullFrom(r.reader, size); err != nil {
136		b.Release()
137		return nil, err
138	}
139	size -= padding
140	rb, err := r.auth.Open(b.BytesTo(0), b.BytesTo(size))
141	if err != nil {
142		b.Release()
143		return nil, err
144	}
145	b.Resize(0, int32(len(rb)))
146	return b, nil
147}
148
149func (r *AuthenticationReader) readInternal(soft bool, mb *buf.MultiBuffer) error {
150	if soft && r.reader.BufferedBytes() < r.sizeParser.SizeBytes() {
151		return errSoft
152	}
153
154	if r.done {
155		return io.EOF
156	}
157
158	size, padding, err := r.readSize()
159	if err != nil {
160		return err
161	}
162
163	if size == uint16(r.auth.Overhead())+padding {
164		r.done = true
165		return io.EOF
166	}
167
168	if soft && int32(size) > r.reader.BufferedBytes() {
169		r.size = size
170		r.paddingLen = padding
171		r.hasSize = true
172		return errSoft
173	}
174
175	if size <= buf.Size {
176		b, err := r.readBuffer(int32(size), int32(padding))
177		if err != nil {
178			return nil
179		}
180		*mb = append(*mb, b)
181		return nil
182	}
183
184	payload := bytespool.Alloc(int32(size))
185	defer bytespool.Free(payload)
186
187	if _, err := io.ReadFull(r.reader, payload[:size]); err != nil {
188		return err
189	}
190
191	size -= padding
192
193	rb, err := r.auth.Open(payload[:0], payload[:size])
194	if err != nil {
195		return err
196	}
197
198	*mb = buf.MergeBytes(*mb, rb)
199	return nil
200}
201
202func (r *AuthenticationReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
203	const readSize = 16
204	mb := make(buf.MultiBuffer, 0, readSize)
205	if err := r.readInternal(false, &mb); err != nil {
206		buf.ReleaseMulti(mb)
207		return nil, err
208	}
209
210	for i := 1; i < readSize; i++ {
211		err := r.readInternal(true, &mb)
212		if err == errSoft || err == io.EOF {
213			break
214		}
215		if err != nil {
216			buf.ReleaseMulti(mb)
217			return nil, err
218		}
219	}
220
221	return mb, nil
222}
223
224type AuthenticationWriter struct {
225	auth         Authenticator
226	writer       buf.Writer
227	sizeParser   ChunkSizeEncoder
228	transferType protocol.TransferType
229	padding      PaddingLengthGenerator
230}
231
232func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, writer io.Writer, transferType protocol.TransferType, padding PaddingLengthGenerator) *AuthenticationWriter {
233	w := &AuthenticationWriter{
234		auth:         auth,
235		writer:       buf.NewWriter(writer),
236		sizeParser:   sizeParser,
237		transferType: transferType,
238	}
239	if padding != nil {
240		w.padding = padding
241	}
242	return w
243}
244
245func (w *AuthenticationWriter) seal(b []byte) (*buf.Buffer, error) {
246	encryptedSize := int32(len(b) + w.auth.Overhead())
247	var paddingSize int32
248	if w.padding != nil {
249		paddingSize = int32(w.padding.NextPaddingLen())
250	}
251
252	sizeBytes := w.sizeParser.SizeBytes()
253	totalSize := sizeBytes + encryptedSize + paddingSize
254	if totalSize > buf.Size {
255		return nil, newError("size too large: ", totalSize)
256	}
257
258	eb := buf.New()
259	w.sizeParser.Encode(uint16(encryptedSize+paddingSize), eb.Extend(sizeBytes))
260	if _, err := w.auth.Seal(eb.Extend(encryptedSize)[:0], b); err != nil {
261		eb.Release()
262		return nil, err
263	}
264	if paddingSize > 0 {
265		// With size of the chunk and padding length encrypted, the content of padding doesn't matter much.
266		paddingBytes := eb.Extend(paddingSize)
267		common.Must2(rand.Read(paddingBytes))
268	}
269
270	return eb, nil
271}
272
273func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error {
274	defer buf.ReleaseMulti(mb)
275
276	var maxPadding int32
277	if w.padding != nil {
278		maxPadding = int32(w.padding.MaxPaddingLen())
279	}
280
281	payloadSize := buf.Size - int32(w.auth.Overhead()) - w.sizeParser.SizeBytes() - maxPadding
282	if len(mb)+10 > 64*1024*1024 {
283		return errors.New("value too large")
284	}
285	sliceSize := len(mb) + 10
286	mb2Write := make(buf.MultiBuffer, 0, sliceSize)
287
288	temp := buf.New()
289	defer temp.Release()
290
291	rawBytes := temp.Extend(payloadSize)
292
293	for {
294		nb, nBytes := buf.SplitBytes(mb, rawBytes)
295		mb = nb
296
297		eb, err := w.seal(rawBytes[:nBytes])
298
299		if err != nil {
300			buf.ReleaseMulti(mb2Write)
301			return err
302		}
303		mb2Write = append(mb2Write, eb)
304		if mb.IsEmpty() {
305			break
306		}
307	}
308
309	return w.writer.WriteMultiBuffer(mb2Write)
310}
311
312func (w *AuthenticationWriter) writePacket(mb buf.MultiBuffer) error {
313	defer buf.ReleaseMulti(mb)
314
315	if len(mb)+1 > 64*1024*1024 {
316		return errors.New("value too large")
317	}
318	sliceSize := len(mb) + 1
319	mb2Write := make(buf.MultiBuffer, 0, sliceSize)
320
321	for _, b := range mb {
322		if b.IsEmpty() {
323			continue
324		}
325
326		eb, err := w.seal(b.Bytes())
327		if err != nil {
328			continue
329		}
330
331		mb2Write = append(mb2Write, eb)
332	}
333
334	if mb2Write.IsEmpty() {
335		return nil
336	}
337
338	return w.writer.WriteMultiBuffer(mb2Write)
339}
340
341// WriteMultiBuffer implements buf.Writer.
342func (w *AuthenticationWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
343	if mb.IsEmpty() {
344		eb, err := w.seal([]byte{})
345		common.Must(err)
346		return w.writer.WriteMultiBuffer(buf.MultiBuffer{eb})
347	}
348
349	if w.transferType == protocol.TransferTypeStream {
350		return w.writeStream(mb)
351	}
352
353	return w.writePacket(mb)
354}
355