1package congestion
2
3import (
4	"fmt"
5	"time"
6
7	"github.com/lucas-clemente/quic-go/internal/protocol"
8	"github.com/lucas-clemente/quic-go/internal/utils"
9	"github.com/lucas-clemente/quic-go/logging"
10)
11
const (
	// initialMaxDatagramSize is the default maximum packet size used in the Linux TCP implementation.
	// Used in QUIC for congestion window computations in bytes.
	initialMaxDatagramSize     = protocol.ByteCount(protocol.InitialPacketSizeIPv4)
	maxBurstPackets            = 3
	renoBeta                   = 0.7 // Reno backoff factor.
	minCongestionWindowPackets = 2
	initialCongestionWindow    = 32
)
21
// cubicSender is a congestion controller that runs in either Cubic or Reno
// mode, selected by the reno field. It implements SendAlgorithm and
// SendAlgorithmWithDebugInfos.
type cubicSender struct {
	// Collaborators: slow start exit detection, RTT measurements, the Cubic
	// window computation, the pacer and the time source.
	hybridSlowStart HybridSlowStart
	rttStats        *utils.RTTStats
	cubic           *Cubic
	pacer           *pacer
	clock           Clock

	// If set, window reduction and growth use Reno instead of Cubic.
	reno bool

	// Track the largest packet that has been sent.
	largestSentPacketNumber protocol.PacketNumber

	// Track the largest packet that has been acked.
	largestAckedPacketNumber protocol.PacketNumber

	// Track the largest packet number outstanding when a CWND cutback occurs.
	largestSentAtLastCutback protocol.PacketNumber

	// Whether the last loss event caused us to exit slowstart.
	// Used for stats collection of slowstartPacketsLost
	lastCutbackExitedSlowstart bool

	// Congestion window in bytes.
	congestionWindow protocol.ByteCount

	// Slow start congestion window in bytes, aka ssthresh.
	slowStartThreshold protocol.ByteCount

	// ACK counter for the Reno implementation.
	numAckedPackets uint64

	// Values (in bytes) supplied at construction time. They are restored as
	// the congestion window and the slow start threshold, respectively, by
	// OnConnectionMigration.
	initialCongestionWindow    protocol.ByteCount
	initialMaxCongestionWindow protocol.ByteCount

	// Current maximum datagram size in bytes; updated by SetMaxDatagramSize.
	maxDatagramSize protocol.ByteCount

	// lastState is the congestion state most recently reported to tracer.
	// tracer may be nil, in which case no state changes are reported.
	lastState logging.CongestionState
	tracer    logging.ConnectionTracer
}
61
62var (
63	_ SendAlgorithm               = &cubicSender{}
64	_ SendAlgorithmWithDebugInfos = &cubicSender{}
65)
66
67// NewCubicSender makes a new cubic sender
68func NewCubicSender(
69	clock Clock,
70	rttStats *utils.RTTStats,
71	initialMaxDatagramSize protocol.ByteCount,
72	reno bool,
73	tracer logging.ConnectionTracer,
74) *cubicSender {
75	return newCubicSender(
76		clock,
77		rttStats,
78		reno,
79		initialMaxDatagramSize,
80		initialCongestionWindow*initialMaxDatagramSize,
81		protocol.MaxCongestionWindowPackets*initialMaxDatagramSize,
82		tracer,
83	)
84}
85
86func newCubicSender(
87	clock Clock,
88	rttStats *utils.RTTStats,
89	reno bool,
90	initialMaxDatagramSize,
91	initialCongestionWindow,
92	initialMaxCongestionWindow protocol.ByteCount,
93	tracer logging.ConnectionTracer,
94) *cubicSender {
95	c := &cubicSender{
96		rttStats:                   rttStats,
97		largestSentPacketNumber:    protocol.InvalidPacketNumber,
98		largestAckedPacketNumber:   protocol.InvalidPacketNumber,
99		largestSentAtLastCutback:   protocol.InvalidPacketNumber,
100		initialCongestionWindow:    initialCongestionWindow,
101		initialMaxCongestionWindow: initialMaxCongestionWindow,
102		congestionWindow:           initialCongestionWindow,
103		slowStartThreshold:         protocol.MaxByteCount,
104		cubic:                      NewCubic(clock),
105		clock:                      clock,
106		reno:                       reno,
107		tracer:                     tracer,
108		maxDatagramSize:            initialMaxDatagramSize,
109	}
110	c.pacer = newPacer(c.BandwidthEstimate)
111	if c.tracer != nil {
112		c.lastState = logging.CongestionStateSlowStart
113		c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart)
114	}
115	return c
116}
117
118// TimeUntilSend returns when the next packet should be sent.
119func (c *cubicSender) TimeUntilSend(_ protocol.ByteCount) time.Time {
120	return c.pacer.TimeUntilSend()
121}
122
123func (c *cubicSender) HasPacingBudget() bool {
124	return c.pacer.Budget(c.clock.Now()) >= c.maxDatagramSize
125}
126
127func (c *cubicSender) maxCongestionWindow() protocol.ByteCount {
128	return c.maxDatagramSize * protocol.MaxCongestionWindowPackets
129}
130
131func (c *cubicSender) minCongestionWindow() protocol.ByteCount {
132	return c.maxDatagramSize * minCongestionWindowPackets
133}
134
135func (c *cubicSender) OnPacketSent(
136	sentTime time.Time,
137	_ protocol.ByteCount,
138	packetNumber protocol.PacketNumber,
139	bytes protocol.ByteCount,
140	isRetransmittable bool,
141) {
142	c.pacer.SentPacket(sentTime, bytes)
143	if !isRetransmittable {
144		return
145	}
146	c.largestSentPacketNumber = packetNumber
147	c.hybridSlowStart.OnPacketSent(packetNumber)
148}
149
150func (c *cubicSender) CanSend(bytesInFlight protocol.ByteCount) bool {
151	return bytesInFlight < c.GetCongestionWindow()
152}
153
154func (c *cubicSender) InRecovery() bool {
155	return c.largestAckedPacketNumber != protocol.InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback
156}
157
158func (c *cubicSender) InSlowStart() bool {
159	return c.GetCongestionWindow() < c.slowStartThreshold
160}
161
162func (c *cubicSender) GetCongestionWindow() protocol.ByteCount {
163	return c.congestionWindow
164}
165
166func (c *cubicSender) MaybeExitSlowStart() {
167	if c.InSlowStart() &&
168		c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) {
169		// exit slow start
170		c.slowStartThreshold = c.congestionWindow
171		c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
172	}
173}
174
175func (c *cubicSender) OnPacketAcked(
176	ackedPacketNumber protocol.PacketNumber,
177	ackedBytes protocol.ByteCount,
178	priorInFlight protocol.ByteCount,
179	eventTime time.Time,
180) {
181	c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber)
182	if c.InRecovery() {
183		return
184	}
185	c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
186	if c.InSlowStart() {
187		c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
188	}
189}
190
191func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) {
192	// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
193	// already sent should be treated as a single loss event, since it's expected.
194	if packetNumber <= c.largestSentAtLastCutback {
195		return
196	}
197	c.lastCutbackExitedSlowstart = c.InSlowStart()
198	c.maybeTraceStateChange(logging.CongestionStateRecovery)
199
200	if c.reno {
201		c.congestionWindow = protocol.ByteCount(float64(c.congestionWindow) * renoBeta)
202	} else {
203		c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
204	}
205	if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd {
206		c.congestionWindow = minCwnd
207	}
208	c.slowStartThreshold = c.congestionWindow
209	c.largestSentAtLastCutback = c.largestSentPacketNumber
210	// reset packet count from congestion avoidance mode. We start
211	// counting again when we're out of recovery.
212	c.numAckedPackets = 0
213}
214
215// Called when we receive an ack. Normal TCP tracks how many packets one ack
216// represents, but quic has a separate ack for each packet.
217func (c *cubicSender) maybeIncreaseCwnd(
218	_ protocol.PacketNumber,
219	ackedBytes protocol.ByteCount,
220	priorInFlight protocol.ByteCount,
221	eventTime time.Time,
222) {
223	// Do not increase the congestion window unless the sender is close to using
224	// the current window.
225	if !c.isCwndLimited(priorInFlight) {
226		c.cubic.OnApplicationLimited()
227		c.maybeTraceStateChange(logging.CongestionStateApplicationLimited)
228		return
229	}
230	if c.congestionWindow >= c.maxCongestionWindow() {
231		return
232	}
233	if c.InSlowStart() {
234		// TCP slow start, exponential growth, increase by one for each ACK.
235		c.congestionWindow += c.maxDatagramSize
236		c.maybeTraceStateChange(logging.CongestionStateSlowStart)
237		return
238	}
239	// Congestion avoidance
240	c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
241	if c.reno {
242		// Classic Reno congestion avoidance.
243		c.numAckedPackets++
244		if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) {
245			c.congestionWindow += c.maxDatagramSize
246			c.numAckedPackets = 0
247		}
248	} else {
249		c.congestionWindow = utils.MinByteCount(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
250	}
251}
252
253func (c *cubicSender) isCwndLimited(bytesInFlight protocol.ByteCount) bool {
254	congestionWindow := c.GetCongestionWindow()
255	if bytesInFlight >= congestionWindow {
256		return true
257	}
258	availableBytes := congestionWindow - bytesInFlight
259	slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
260	return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize
261}
262
263// BandwidthEstimate returns the current bandwidth estimate
264func (c *cubicSender) BandwidthEstimate() Bandwidth {
265	srtt := c.rttStats.SmoothedRTT()
266	if srtt == 0 {
267		// If we haven't measured an rtt, the bandwidth estimate is unknown.
268		return infBandwidth
269	}
270	return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
271}
272
273// OnRetransmissionTimeout is called on an retransmission timeout
274func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
275	c.largestSentAtLastCutback = protocol.InvalidPacketNumber
276	if !packetsRetransmitted {
277		return
278	}
279	c.hybridSlowStart.Restart()
280	c.cubic.Reset()
281	c.slowStartThreshold = c.congestionWindow / 2
282	c.congestionWindow = c.minCongestionWindow()
283}
284
285// OnConnectionMigration is called when the connection is migrated (?)
286func (c *cubicSender) OnConnectionMigration() {
287	c.hybridSlowStart.Restart()
288	c.largestSentPacketNumber = protocol.InvalidPacketNumber
289	c.largestAckedPacketNumber = protocol.InvalidPacketNumber
290	c.largestSentAtLastCutback = protocol.InvalidPacketNumber
291	c.lastCutbackExitedSlowstart = false
292	c.cubic.Reset()
293	c.numAckedPackets = 0
294	c.congestionWindow = c.initialCongestionWindow
295	c.slowStartThreshold = c.initialMaxCongestionWindow
296}
297
298func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) {
299	if c.tracer == nil || new == c.lastState {
300		return
301	}
302	c.tracer.UpdatedCongestionState(new)
303	c.lastState = new
304}
305
306func (c *cubicSender) SetMaxDatagramSize(s protocol.ByteCount) {
307	if s < c.maxDatagramSize {
308		panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s))
309	}
310	cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow()
311	c.maxDatagramSize = s
312	if cwndIsMinCwnd {
313		c.congestionWindow = c.minCongestionWindow()
314	}
315	c.pacer.SetMaxDatagramSize(s)
316}
317