1package quic
2
3import (
4	"errors"
5
6	"github.com/lucas-clemente/quic-go/internal/protocol"
7	"github.com/lucas-clemente/quic-go/internal/utils"
8)
9
// frameSorterEntry is one chunk of stream data held in the frameSorter queue.
type frameSorterEntry struct {
	// Data holds the bytes; the offset of the first byte is the entry's key in
	// frameSorter.queue.
	Data   []byte
	// DoneCb, if non-nil, is invoked by the sorter when the entry is replaced or
	// deleted from the queue; Pop hands it back to the caller instead.
	DoneCb func()
}
14
// frameSorter collects stream data that may arrive out of order and with
// overlaps, and hands it back in offset order via Pop.
type frameSorter struct {
	// queue maps the start offset of each stored chunk to that chunk.
	queue   map[protocol.ByteCount]frameSorterEntry
	// readPos is the offset at which the next Pop looks for queued data.
	readPos protocol.ByteCount
	// gaps lists the offset ranges for which no data has been received yet.
	// findStartGap/findEndGap walk it from Front and rely on it being ordered
	// by offset.
	gaps    *utils.ByteIntervalList
}
20
// errDuplicateStreamData is returned internally by push when the pushed data
// contributes nothing that is not already received or queued. Push converts it
// into a nil error after invoking the data's callback.
var errDuplicateStreamData = errors.New("duplicate stream data")
22
23func newFrameSorter() *frameSorter {
24	s := frameSorter{
25		gaps:  utils.NewByteIntervalList(),
26		queue: make(map[protocol.ByteCount]frameSorterEntry),
27	}
28	s.gaps.PushFront(utils.ByteInterval{Start: 0, End: protocol.MaxByteCount})
29	return &s
30}
31
32func (s *frameSorter) Push(data []byte, offset protocol.ByteCount, doneCb func()) error {
33	err := s.push(data, offset, doneCb)
34	if err == errDuplicateStreamData {
35		if doneCb != nil {
36			doneCb()
37		}
38		return nil
39	}
40	return err
41}
42
// push does the work for Push. It records data, covering the stream range
// [offset, offset+len(data)), in the queue, and shrinks, splits or removes the
// gaps that this range fills.
//
// It returns errDuplicateStreamData, without modifying the sorter, when the
// data cannot contribute anything new. That is the case when it is empty, when
// it ends at or before the start of the first gap, when it falls entirely
// within a range that is not a gap, or when an entry queued at the same start
// offset already covers it.
//
// Queued entries that the new data fully covers are removed, and their DoneCb
// is invoked. The new data may be trimmed at the front and/or the back so that
// it only covers ranges that were still gaps or that belonged to removed
// entries. It is then stored under its possibly adjusted start offset.
//
// It returns a non-sentinel error when the number of gaps exceeds
// protocol.MaxStreamFrameSorterGaps.
func (s *frameSorter) push(data []byte, offset protocol.ByteCount, doneCb func()) error {
	if len(data) == 0 {
		return errDuplicateStreamData
	}

	start := offset
	end := offset + protocol.ByteCount(len(data))

	// The data ends before the first missing offset, so every byte was seen before.
	if end <= s.gaps.Front().Value.Start {
		return errDuplicateStreamData
	}

	// startGap contains start (or is the first gap after it).
	// endGap contains end (or is the gap just before the first gap after it).
	startGap, startsInGap := s.findStartGap(start)
	endGap, endsInGap := s.findEndGap(startGap, end)

	startGapEqualsEndGap := startGap == endGap

	// The data lies entirely before startGap, i.e. in a range already received.
	if (startGapEqualsEndGap && end <= startGap.Value.Start) ||
		(!startGapEqualsEndGap && startGap.Value.End >= endGap.Value.Start && end <= startGap.Value.Start) {
		return errDuplicateStreamData
	}

	startGapNext := startGap.Next()
	startGapEnd := startGap.Value.End // save it, in case startGap is modified
	endGapStart := endGap.Value.Start // save it, in case endGap is modified
	endGapEnd := endGap.Value.End     // save it, in case endGap is modified
	var adjustedStartGapEnd bool
	var wasCut bool

	// Walk the entries queued back-to-back starting exactly at start, replacing
	// those the new data supersedes.
	pos := start
	var hasReplacedAtLeastOne bool
	for {
		oldEntry, ok := s.queue[pos]
		if !ok {
			break
		}
		oldEntryLen := protocol.ByteCount(len(oldEntry.Data))
		// Replace the entry if the new data extends past it. Once at least one
		// entry was replaced, an entry ending exactly at end is replaced too.
		if end-pos > oldEntryLen || (hasReplacedAtLeastOne && end-pos == oldEntryLen) {
			// The existing frame is shorter than the new frame. Replace it.
			delete(s.queue, pos)
			pos += oldEntryLen
			hasReplacedAtLeastOne = true
			if oldEntry.DoneCb != nil {
				oldEntry.DoneCb()
			}
		} else {
			if !hasReplacedAtLeastOne {
				// An entry at start already covers all of the new data.
				return errDuplicateStreamData
			}
			// The existing frame is longer than the new frame.
			// Cut the new frame such that the end aligns with the start of the existing frame.
			data = data[:pos-start]
			end = pos
			wasCut = true
			break
		}
	}

	if !startsInGap && !hasReplacedAtLeastOne {
		// cut the frame, such that it starts at the start of the gap
		data = data[startGap.Value.Start-start:]
		start = startGap.Value.Start
		wasCut = true
	}
	// Shrink (or remove) startGap by the part that the data now covers.
	if start <= startGap.Value.Start {
		if end >= startGap.Value.End {
			// The frame covers the whole startGap. Delete the gap.
			s.gaps.Remove(startGap)
		} else {
			startGap.Value.Start = end
		}
	} else if !hasReplacedAtLeastOne {
		// The data starts inside startGap: the gap now ends where the data
		// begins. Any remainder after the data is re-inserted further below.
		startGap.Value.End = start
		adjustedStartGapEnd = true
	}

	// The data spans more than one gap. Drop the entries starting at startGap's
	// original end, and remove every gap lying wholly between startGap and
	// endGap together with the entries that follow it.
	if !startGapEqualsEndGap {
		s.deleteConsecutive(startGapEnd)
		var nextGap *utils.ByteIntervalElement
		for gap := startGapNext; gap.Value.End < endGapStart; gap = nextGap {
			nextGap = gap.Next()
			s.deleteConsecutive(gap.Value.End)
			s.gaps.Remove(gap)
		}
	}

	if !endsInGap && start != endGapEnd && end > endGapEnd {
		// cut the frame, such that it ends at the end of the gap
		data = data[:endGapEnd-start]
		end = endGapEnd
		wasCut = true
	}
	// Update endGap, or split startGap if the data fell strictly inside it.
	if end == endGapEnd {
		if !startGapEqualsEndGap {
			// The frame covers the whole endGap. Delete the gap.
			s.gaps.Remove(endGap)
		}
	} else {
		if startGapEqualsEndGap && adjustedStartGapEnd {
			// The frame split the existing gap into two.
			s.gaps.InsertAfter(utils.ByteInterval{Start: end, End: startGapEnd}, startGap)
		} else if !startGapEqualsEndGap {
			endGap.Value.Start = end
		}
	}

	// A trimmed chunk shorter than protocol.MinStreamFrameBufferSize is copied
	// into its own slice, and doneCb is called now instead of being stored. The
	// queued entry then no longer references the caller's buffer. Presumably
	// this lets that (larger) buffer be released early; confirm.
	if wasCut && len(data) < protocol.MinStreamFrameBufferSize {
		newData := make([]byte, len(data))
		copy(newData, data)
		data = newData
		if doneCb != nil {
			doneCb()
			doneCb = nil
		}
	}

	// NOTE(review): gaps and queue have already been modified by this point, and
	// the new data is not stored on this path. Presumably the caller treats this
	// error as fatal for the stream/connection; confirm.
	if s.gaps.Len() > protocol.MaxStreamFrameSorterGaps {
		return errors.New("too many gaps in received data")
	}

	s.queue[start] = frameSorterEntry{Data: data, DoneCb: doneCb}
	return nil
}
166
167func (s *frameSorter) findStartGap(offset protocol.ByteCount) (*utils.ByteIntervalElement, bool) {
168	for gap := s.gaps.Front(); gap != nil; gap = gap.Next() {
169		if offset >= gap.Value.Start && offset <= gap.Value.End {
170			return gap, true
171		}
172		if offset < gap.Value.Start {
173			return gap, false
174		}
175	}
176	panic("no gap found")
177}
178
179func (s *frameSorter) findEndGap(startGap *utils.ByteIntervalElement, offset protocol.ByteCount) (*utils.ByteIntervalElement, bool) {
180	for gap := startGap; gap != nil; gap = gap.Next() {
181		if offset >= gap.Value.Start && offset < gap.Value.End {
182			return gap, true
183		}
184		if offset < gap.Value.Start {
185			return gap.Prev(), false
186		}
187	}
188	panic("no gap found")
189}
190
191// deleteConsecutive deletes consecutive frames from the queue, starting at pos
192func (s *frameSorter) deleteConsecutive(pos protocol.ByteCount) {
193	for {
194		oldEntry, ok := s.queue[pos]
195		if !ok {
196			break
197		}
198		oldEntryLen := protocol.ByteCount(len(oldEntry.Data))
199		delete(s.queue, pos)
200		if oldEntry.DoneCb != nil {
201			oldEntry.DoneCb()
202		}
203		pos += oldEntryLen
204	}
205}
206
207func (s *frameSorter) Pop() (protocol.ByteCount, []byte, func()) {
208	entry, ok := s.queue[s.readPos]
209	if !ok {
210		return s.readPos, nil, nil
211	}
212	delete(s.queue, s.readPos)
213	offset := s.readPos
214	s.readPos += protocol.ByteCount(len(entry.Data))
215	if s.gaps.Front().Value.End <= s.readPos {
216		panic("frame sorter BUG: read position higher than a gap")
217	}
218	return offset, entry.Data, entry.DoneCb
219}
220
221// HasMoreData says if there is any more data queued at *any* offset.
222func (s *frameSorter) HasMoreData() bool {
223	return len(s.queue) > 0
224}
225