1package quic
2
3import (
4	"fmt"
5	"io"
6
7	"github.com/lucas-clemente/quic-go/internal/protocol"
8	"github.com/lucas-clemente/quic-go/internal/qerr"
9	"github.com/lucas-clemente/quic-go/internal/utils"
10	"github.com/lucas-clemente/quic-go/internal/wire"
11)
12
13type cryptoStream interface {
14	// for receiving data
15	HandleCryptoFrame(*wire.CryptoFrame) error
16	GetCryptoData() []byte
17	Finish() error
18	// for sending data
19	io.Writer
20	HasData() bool
21	PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame
22}
23
// cryptoStreamImpl is the cryptoStream implementation returned by
// newCryptoStream.
type cryptoStreamImpl struct {
	// Receiving side: incoming CRYPTO frame data is pushed into queue, and
	// whatever queue.Pop yields is appended to msgBuf, from which
	// GetCryptoData extracts complete messages.
	queue  *frameSorter
	msgBuf []byte

	// highestOffset is the largest end offset (offset + data length) seen in
	// an in-range CRYPTO frame handled while the stream was not finished.
	// finished is set by a successful call to Finish.
	highestOffset protocol.ByteCount
	finished      bool

	// Sending side: writeOffset is the stream offset of the next byte to be
	// popped into a CRYPTO frame; writeBuf holds bytes passed to Write that
	// have not yet been popped.
	writeOffset protocol.ByteCount
	writeBuf    []byte
}
34
35func newCryptoStream() cryptoStream {
36	return &cryptoStreamImpl{queue: newFrameSorter()}
37}
38
39func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
40	highestOffset := f.Offset + protocol.ByteCount(len(f.Data))
41	if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset {
42		return &qerr.TransportError{
43			ErrorCode:    qerr.CryptoBufferExceeded,
44			ErrorMessage: fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset),
45		}
46	}
47	if s.finished {
48		if highestOffset > s.highestOffset {
49			// reject crypto data received after this stream was already finished
50			return &qerr.TransportError{
51				ErrorCode:    qerr.ProtocolViolation,
52				ErrorMessage: "received crypto data after change of encryption level",
53			}
54		}
55		// ignore data with a smaller offset than the highest received
56		// could e.g. be a retransmission
57		return nil
58	}
59	s.highestOffset = utils.MaxByteCount(s.highestOffset, highestOffset)
60	if err := s.queue.Push(f.Data, f.Offset, nil); err != nil {
61		return err
62	}
63	for {
64		_, data, _ := s.queue.Pop()
65		if data == nil {
66			return nil
67		}
68		s.msgBuf = append(s.msgBuf, data...)
69	}
70}
71
72// GetCryptoData retrieves data that was received in CRYPTO frames
73func (s *cryptoStreamImpl) GetCryptoData() []byte {
74	if len(s.msgBuf) < 4 {
75		return nil
76	}
77	msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3])
78	if len(s.msgBuf) < msgLen {
79		return nil
80	}
81	msg := make([]byte, msgLen)
82	copy(msg, s.msgBuf[:msgLen])
83	s.msgBuf = s.msgBuf[msgLen:]
84	return msg
85}
86
87func (s *cryptoStreamImpl) Finish() error {
88	if s.queue.HasMoreData() {
89		return &qerr.TransportError{
90			ErrorCode:    qerr.ProtocolViolation,
91			ErrorMessage: "encryption level changed, but crypto stream has more data to read",
92		}
93	}
94	s.finished = true
95	return nil
96}
97
98// Writes writes data that should be sent out in CRYPTO frames
99func (s *cryptoStreamImpl) Write(p []byte) (int, error) {
100	s.writeBuf = append(s.writeBuf, p...)
101	return len(p), nil
102}
103
104func (s *cryptoStreamImpl) HasData() bool {
105	return len(s.writeBuf) > 0
106}
107
108func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
109	f := &wire.CryptoFrame{Offset: s.writeOffset}
110	n := utils.MinByteCount(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf)))
111	f.Data = s.writeBuf[:n]
112	s.writeBuf = s.writeBuf[n:]
113	s.writeOffset += n
114	return f
115}
116