1package quic
2
3import (
4	"fmt"
5	"io"
6
7	"github.com/lucas-clemente/quic-go/internal/protocol"
8	"github.com/lucas-clemente/quic-go/internal/qerr"
9	"github.com/lucas-clemente/quic-go/internal/utils"
10	"github.com/lucas-clemente/quic-go/internal/wire"
11)
12
13type cryptoStream interface {
14	// for receiving data
15	HandleCryptoFrame(*wire.CryptoFrame) error
16	GetCryptoData() []byte
17	Finish() error
18	// for sending data
19	io.Writer
20	HasData() bool
21	PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame
22}
23
24type cryptoStreamImpl struct {
25	queue  *frameSorter
26	msgBuf []byte
27
28	highestOffset protocol.ByteCount
29	finished      bool
30
31	writeOffset protocol.ByteCount
32	writeBuf    []byte
33}
34
35func newCryptoStream() cryptoStream {
36	return &cryptoStreamImpl{queue: newFrameSorter()}
37}
38
39func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
40	highestOffset := f.Offset + protocol.ByteCount(len(f.Data))
41	if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset {
42		return qerr.NewError(qerr.CryptoBufferExceeded, fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset))
43	}
44	if s.finished {
45		if highestOffset > s.highestOffset {
46			// reject crypto data received after this stream was already finished
47			return qerr.NewError(qerr.ProtocolViolation, "received crypto data after change of encryption level")
48		}
49		// ignore data with a smaller offset than the highest received
50		// could e.g. be a retransmission
51		return nil
52	}
53	s.highestOffset = utils.MaxByteCount(s.highestOffset, highestOffset)
54	if err := s.queue.Push(f.Data, f.Offset, nil); err != nil {
55		return err
56	}
57	for {
58		_, data, _ := s.queue.Pop()
59		if data == nil {
60			return nil
61		}
62		s.msgBuf = append(s.msgBuf, data...)
63	}
64}
65
66// GetCryptoData retrieves data that was received in CRYPTO frames
67func (s *cryptoStreamImpl) GetCryptoData() []byte {
68	if len(s.msgBuf) < 4 {
69		return nil
70	}
71	msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3])
72	if len(s.msgBuf) < msgLen {
73		return nil
74	}
75	msg := make([]byte, msgLen)
76	copy(msg, s.msgBuf[:msgLen])
77	s.msgBuf = s.msgBuf[msgLen:]
78	return msg
79}
80
81func (s *cryptoStreamImpl) Finish() error {
82	if s.queue.HasMoreData() {
83		return qerr.NewError(qerr.ProtocolViolation, "encryption level changed, but crypto stream has more data to read")
84	}
85	s.finished = true
86	return nil
87}
88
89// Writes writes data that should be sent out in CRYPTO frames
90func (s *cryptoStreamImpl) Write(p []byte) (int, error) {
91	s.writeBuf = append(s.writeBuf, p...)
92	return len(p), nil
93}
94
95func (s *cryptoStreamImpl) HasData() bool {
96	return len(s.writeBuf) > 0
97}
98
99func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
100	f := &wire.CryptoFrame{Offset: s.writeOffset}
101	n := utils.MinByteCount(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf)))
102	f.Data = s.writeBuf[:n]
103	s.writeBuf = s.writeBuf[n:]
104	s.writeOffset += n
105	return f
106}
107