1package wire
2
3import (
4	"bytes"
5	"errors"
6	"sort"
7	"time"
8
9	"github.com/lucas-clemente/quic-go/internal/protocol"
10	"github.com/lucas-clemente/quic-go/internal/utils"
11	"github.com/lucas-clemente/quic-go/quicvarint"
12)
13
var (
	// errInvalidAckRanges is returned by parseAckFrame when the encoded ACK
	// ranges are inconsistent (for example, a range that would reach below
	// packet number 0, or a set of ranges rejected by validateAckRanges).
	errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges")
)
15
16// An AckFrame is an ACK frame
17type AckFrame struct {
18	AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last
19	DelayTime time.Duration
20
21	ECT0, ECT1, ECNCE uint64
22}
23
24// parseAckFrame reads an ACK frame
25func parseAckFrame(r *bytes.Reader, ackDelayExponent uint8, _ protocol.VersionNumber) (*AckFrame, error) {
26	typeByte, err := r.ReadByte()
27	if err != nil {
28		return nil, err
29	}
30	ecn := typeByte&0x1 > 0
31
32	frame := &AckFrame{}
33
34	la, err := quicvarint.Read(r)
35	if err != nil {
36		return nil, err
37	}
38	largestAcked := protocol.PacketNumber(la)
39	delay, err := quicvarint.Read(r)
40	if err != nil {
41		return nil, err
42	}
43
44	delayTime := time.Duration(delay*1<<ackDelayExponent) * time.Microsecond
45	if delayTime < 0 {
46		// If the delay time overflows, set it to the maximum encodable value.
47		delayTime = utils.InfDuration
48	}
49	frame.DelayTime = delayTime
50
51	numBlocks, err := quicvarint.Read(r)
52	if err != nil {
53		return nil, err
54	}
55
56	// read the first ACK range
57	ab, err := quicvarint.Read(r)
58	if err != nil {
59		return nil, err
60	}
61	ackBlock := protocol.PacketNumber(ab)
62	if ackBlock > largestAcked {
63		return nil, errors.New("invalid first ACK range")
64	}
65	smallest := largestAcked - ackBlock
66
67	// read all the other ACK ranges
68	frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked})
69	for i := uint64(0); i < numBlocks; i++ {
70		g, err := quicvarint.Read(r)
71		if err != nil {
72			return nil, err
73		}
74		gap := protocol.PacketNumber(g)
75		if smallest < gap+2 {
76			return nil, errInvalidAckRanges
77		}
78		largest := smallest - gap - 2
79
80		ab, err := quicvarint.Read(r)
81		if err != nil {
82			return nil, err
83		}
84		ackBlock := protocol.PacketNumber(ab)
85
86		if ackBlock > largest {
87			return nil, errInvalidAckRanges
88		}
89		smallest = largest - ackBlock
90		frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest})
91	}
92
93	if !frame.validateAckRanges() {
94		return nil, errInvalidAckRanges
95	}
96
97	// parse (and skip) the ECN section
98	if ecn {
99		for i := 0; i < 3; i++ {
100			if _, err := quicvarint.Read(r); err != nil {
101				return nil, err
102			}
103		}
104	}
105
106	return frame, nil
107}
108
109// Write writes an ACK frame.
110func (f *AckFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
111	hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0
112	if hasECN {
113		b.WriteByte(0x3)
114	} else {
115		b.WriteByte(0x2)
116	}
117	quicvarint.Write(b, uint64(f.LargestAcked()))
118	quicvarint.Write(b, encodeAckDelay(f.DelayTime))
119
120	numRanges := f.numEncodableAckRanges()
121	quicvarint.Write(b, uint64(numRanges-1))
122
123	// write the first range
124	_, firstRange := f.encodeAckRange(0)
125	quicvarint.Write(b, firstRange)
126
127	// write all the other range
128	for i := 1; i < numRanges; i++ {
129		gap, len := f.encodeAckRange(i)
130		quicvarint.Write(b, gap)
131		quicvarint.Write(b, len)
132	}
133
134	if hasECN {
135		quicvarint.Write(b, f.ECT0)
136		quicvarint.Write(b, f.ECT1)
137		quicvarint.Write(b, f.ECNCE)
138	}
139	return nil
140}
141
142// Length of a written frame
143func (f *AckFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
144	largestAcked := f.AckRanges[0].Largest
145	numRanges := f.numEncodableAckRanges()
146
147	length := 1 + quicvarint.Len(uint64(largestAcked)) + quicvarint.Len(encodeAckDelay(f.DelayTime))
148
149	length += quicvarint.Len(uint64(numRanges - 1))
150	lowestInFirstRange := f.AckRanges[0].Smallest
151	length += quicvarint.Len(uint64(largestAcked - lowestInFirstRange))
152
153	for i := 1; i < numRanges; i++ {
154		gap, len := f.encodeAckRange(i)
155		length += quicvarint.Len(gap)
156		length += quicvarint.Len(len)
157	}
158	if f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 {
159		length += quicvarint.Len(f.ECT0)
160		length += quicvarint.Len(f.ECT1)
161		length += quicvarint.Len(f.ECNCE)
162	}
163	return length
164}
165
166// gets the number of ACK ranges that can be encoded
167// such that the resulting frame is smaller than the maximum ACK frame size
168func (f *AckFrame) numEncodableAckRanges() int {
169	length := 1 + quicvarint.Len(uint64(f.LargestAcked())) + quicvarint.Len(encodeAckDelay(f.DelayTime))
170	length += 2 // assume that the number of ranges will consume 2 bytes
171	for i := 1; i < len(f.AckRanges); i++ {
172		gap, len := f.encodeAckRange(i)
173		rangeLen := quicvarint.Len(gap) + quicvarint.Len(len)
174		if length+rangeLen > protocol.MaxAckFrameSize {
175			// Writing range i would exceed the MaxAckFrameSize.
176			// So encode one range less than that.
177			return i - 1
178		}
179		length += rangeLen
180	}
181	return len(f.AckRanges)
182}
183
184func (f *AckFrame) encodeAckRange(i int) (uint64 /* gap */, uint64 /* length */) {
185	if i == 0 {
186		return 0, uint64(f.AckRanges[0].Largest - f.AckRanges[0].Smallest)
187	}
188	return uint64(f.AckRanges[i-1].Smallest - f.AckRanges[i].Largest - 2),
189		uint64(f.AckRanges[i].Largest - f.AckRanges[i].Smallest)
190}
191
192// HasMissingRanges returns if this frame reports any missing packets
193func (f *AckFrame) HasMissingRanges() bool {
194	return len(f.AckRanges) > 1
195}
196
197func (f *AckFrame) validateAckRanges() bool {
198	if len(f.AckRanges) == 0 {
199		return false
200	}
201
202	// check the validity of every single ACK range
203	for _, ackRange := range f.AckRanges {
204		if ackRange.Smallest > ackRange.Largest {
205			return false
206		}
207	}
208
209	// check the consistency for ACK with multiple NACK ranges
210	for i, ackRange := range f.AckRanges {
211		if i == 0 {
212			continue
213		}
214		lastAckRange := f.AckRanges[i-1]
215		if lastAckRange.Smallest <= ackRange.Smallest {
216			return false
217		}
218		if lastAckRange.Smallest <= ackRange.Largest+1 {
219			return false
220		}
221	}
222
223	return true
224}
225
226// LargestAcked is the largest acked packet number
227func (f *AckFrame) LargestAcked() protocol.PacketNumber {
228	return f.AckRanges[0].Largest
229}
230
231// LowestAcked is the lowest acked packet number
232func (f *AckFrame) LowestAcked() protocol.PacketNumber {
233	return f.AckRanges[len(f.AckRanges)-1].Smallest
234}
235
236// AcksPacket determines if this ACK frame acks a certain packet number
237func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool {
238	if p < f.LowestAcked() || p > f.LargestAcked() {
239		return false
240	}
241
242	i := sort.Search(len(f.AckRanges), func(i int) bool {
243		return p >= f.AckRanges[i].Smallest
244	})
245	// i will always be < len(f.AckRanges), since we checked above that p is not bigger than the largest acked
246	return p <= f.AckRanges[i].Largest
247}
248
249func encodeAckDelay(delay time.Duration) uint64 {
250	return uint64(delay.Nanoseconds() / (1000 * (1 << protocol.AckDelayExponent)))
251}
252