1package congestion
2
3import (
4	"time"
5
6	"github.com/lucas-clemente/quic-go/internal/protocol"
7	"github.com/lucas-clemente/quic-go/internal/utils"
8)
9
10const (
11	maxBurstBytes                                     = 3 * protocol.DefaultTCPMSS
12	renoBeta                       float32            = 0.7 // Reno backoff factor.
13	defaultMinimumCongestionWindow protocol.ByteCount = 2 * protocol.DefaultTCPMSS
14)
15
// cubicSender is a congestion controller combining TCP slow start (with a
// hybrid slow start exit), PRR while in loss recovery, and either CUBIC or
// Reno congestion avoidance. All window sizes are byte counts.
type cubicSender struct {
	// Collaborators: hybrid slow start (RTT-based early exit from slow start),
	// PRR (used while in recovery), the RTT estimator, loss statistics, and the
	// CUBIC window calculator that is used when reno is false.
	hybridSlowStart HybridSlowStart
	prr             PrrSender
	rttStats        *RTTStats
	stats           connectionStats
	cubic           *Cubic

	// When true, Reno (instead of CUBIC) is used both for congestion
	// avoidance and for the window reduction on a loss event.
	reno bool

	// Track the largest packet that has been sent.
	// NOTE(review): OnPacketSent assigns this unconditionally for every
	// retransmittable packet, so it is the most recently sent one; that is the
	// largest only if packet numbers increase monotonically.
	largestSentPacketNumber protocol.PacketNumber

	// Track the largest packet that has been acked. InRecovery treats 0 as
	// "nothing acked yet".
	largestAckedPacketNumber protocol.PacketNumber

	// Track the largest packet number outstanding when a CWND cutback occurs.
	// Losses of packets up to this number are treated as part of the same loss
	// event. Reset to 0 on retransmission timeout and connection migration.
	largestSentAtLastCutback protocol.PacketNumber

	// Whether the last loss event caused us to exit slowstart.
	// Used for stats collection of slowstartPacketsLost
	lastCutbackExitedSlowstart bool

	// When true, exit slow start with large cutback of congestion window
	// (the slow start large reduction experiment).
	slowStartLargeReduction bool

	// Congestion window in bytes.
	congestionWindow protocol.ByteCount

	// Minimum congestion window in bytes. After a new loss event the window is
	// clamped to at least this value, and a retransmission timeout resets the
	// window to it.
	minCongestionWindow protocol.ByteCount

	// Maximum congestion window in bytes; no further increase is attempted
	// once the window has reached it.
	maxCongestionWindow protocol.ByteCount

	// Slow start congestion window in bytes, aka ssthresh.
	slowstartThreshold protocol.ByteCount

	// Number of connections to simulate. Used by RenoBeta and by the Reno
	// window increase; SetNumEmulatedConnections keeps it at least 1.
	numConnections int

	// ACK counter for the Reno implementation. Reset when the window grows by
	// one packet, on a loss cutback, and on connection migration.
	numAckedPackets uint64

	// The window values the sender starts with; restored on connection migration.
	initialCongestionWindow    protocol.ByteCount
	initialMaxCongestionWindow protocol.ByteCount

	// Lower bound for slow start large reduction cutbacks. Set to half the
	// congestion window at a slow start loss once the window has reached twice
	// the initial congestion window.
	minSlowStartExitWindow protocol.ByteCount
}
64
65var _ SendAlgorithm = &cubicSender{}
66var _ SendAlgorithmWithDebugInfo = &cubicSender{}
67
68// NewCubicSender makes a new cubic sender
69func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount) SendAlgorithmWithDebugInfo {
70	return &cubicSender{
71		rttStats:                   rttStats,
72		initialCongestionWindow:    initialCongestionWindow,
73		initialMaxCongestionWindow: initialMaxCongestionWindow,
74		congestionWindow:           initialCongestionWindow,
75		minCongestionWindow:        defaultMinimumCongestionWindow,
76		slowstartThreshold:         initialMaxCongestionWindow,
77		maxCongestionWindow:        initialMaxCongestionWindow,
78		numConnections:             defaultNumConnections,
79		cubic:                      NewCubic(clock),
80		reno:                       reno,
81	}
82}
83
84// TimeUntilSend returns when the next packet should be sent.
85func (c *cubicSender) TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration {
86	if c.InRecovery() {
87		// PRR is used when in recovery.
88		if c.prr.CanSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) {
89			return 0
90		}
91	}
92	delay := c.rttStats.SmoothedRTT() / time.Duration(2*c.GetCongestionWindow())
93	if !c.InSlowStart() { // adjust delay, such that it's 1.25*cwd/rtt
94		delay = delay * 8 / 5
95	}
96	return delay
97}
98
99func (c *cubicSender) OnPacketSent(
100	sentTime time.Time,
101	bytesInFlight protocol.ByteCount,
102	packetNumber protocol.PacketNumber,
103	bytes protocol.ByteCount,
104	isRetransmittable bool,
105) {
106	if !isRetransmittable {
107		return
108	}
109	if c.InRecovery() {
110		// PRR is used when in recovery.
111		c.prr.OnPacketSent(bytes)
112	}
113	c.largestSentPacketNumber = packetNumber
114	c.hybridSlowStart.OnPacketSent(packetNumber)
115}
116
117func (c *cubicSender) InRecovery() bool {
118	return c.largestAckedPacketNumber <= c.largestSentAtLastCutback && c.largestAckedPacketNumber != 0
119}
120
121func (c *cubicSender) InSlowStart() bool {
122	return c.GetCongestionWindow() < c.GetSlowStartThreshold()
123}
124
125func (c *cubicSender) GetCongestionWindow() protocol.ByteCount {
126	return c.congestionWindow
127}
128
129func (c *cubicSender) GetSlowStartThreshold() protocol.ByteCount {
130	return c.slowstartThreshold
131}
132
133func (c *cubicSender) ExitSlowstart() {
134	c.slowstartThreshold = c.congestionWindow
135}
136
137func (c *cubicSender) SlowstartThreshold() protocol.ByteCount {
138	return c.slowstartThreshold
139}
140
141func (c *cubicSender) MaybeExitSlowStart() {
142	if c.InSlowStart() && c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/protocol.DefaultTCPMSS) {
143		c.ExitSlowstart()
144	}
145}
146
147func (c *cubicSender) OnPacketAcked(
148	ackedPacketNumber protocol.PacketNumber,
149	ackedBytes protocol.ByteCount,
150	priorInFlight protocol.ByteCount,
151	eventTime time.Time,
152) {
153	c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber)
154	if c.InRecovery() {
155		// PRR is used when in recovery.
156		c.prr.OnPacketAcked(ackedBytes)
157		return
158	}
159	c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
160	if c.InSlowStart() {
161		c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
162	}
163}
164
165func (c *cubicSender) OnPacketLost(
166	packetNumber protocol.PacketNumber,
167	lostBytes protocol.ByteCount,
168	priorInFlight protocol.ByteCount,
169) {
170	// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
171	// already sent should be treated as a single loss event, since it's expected.
172	if packetNumber <= c.largestSentAtLastCutback {
173		if c.lastCutbackExitedSlowstart {
174			c.stats.slowstartPacketsLost++
175			c.stats.slowstartBytesLost += lostBytes
176			if c.slowStartLargeReduction {
177				// Reduce congestion window by lost_bytes for every loss.
178				c.congestionWindow = utils.MaxByteCount(c.congestionWindow-lostBytes, c.minSlowStartExitWindow)
179				c.slowstartThreshold = c.congestionWindow
180			}
181		}
182		return
183	}
184	c.lastCutbackExitedSlowstart = c.InSlowStart()
185	if c.InSlowStart() {
186		c.stats.slowstartPacketsLost++
187	}
188
189	c.prr.OnPacketLost(priorInFlight)
190
191	// TODO(chromium): Separate out all of slow start into a separate class.
192	if c.slowStartLargeReduction && c.InSlowStart() {
193		if c.congestionWindow >= 2*c.initialCongestionWindow {
194			c.minSlowStartExitWindow = c.congestionWindow / 2
195		}
196		c.congestionWindow -= protocol.DefaultTCPMSS
197	} else if c.reno {
198		c.congestionWindow = protocol.ByteCount(float32(c.congestionWindow) * c.RenoBeta())
199	} else {
200		c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
201	}
202	if c.congestionWindow < c.minCongestionWindow {
203		c.congestionWindow = c.minCongestionWindow
204	}
205	c.slowstartThreshold = c.congestionWindow
206	c.largestSentAtLastCutback = c.largestSentPacketNumber
207	// reset packet count from congestion avoidance mode. We start
208	// counting again when we're out of recovery.
209	c.numAckedPackets = 0
210}
211
212func (c *cubicSender) RenoBeta() float32 {
213	// kNConnectionBeta is the backoff factor after loss for our N-connection
214	// emulation, which emulates the effective backoff of an ensemble of N
215	// TCP-Reno connections on a single loss event. The effective multiplier is
216	// computed as:
217	return (float32(c.numConnections) - 1. + renoBeta) / float32(c.numConnections)
218}
219
220// Called when we receive an ack. Normal TCP tracks how many packets one ack
221// represents, but quic has a separate ack for each packet.
222func (c *cubicSender) maybeIncreaseCwnd(
223	ackedPacketNumber protocol.PacketNumber,
224	ackedBytes protocol.ByteCount,
225	priorInFlight protocol.ByteCount,
226	eventTime time.Time,
227) {
228	// Do not increase the congestion window unless the sender is close to using
229	// the current window.
230	if !c.isCwndLimited(priorInFlight) {
231		c.cubic.OnApplicationLimited()
232		return
233	}
234	if c.congestionWindow >= c.maxCongestionWindow {
235		return
236	}
237	if c.InSlowStart() {
238		// TCP slow start, exponential growth, increase by one for each ACK.
239		c.congestionWindow += protocol.DefaultTCPMSS
240		return
241	}
242	// Congestion avoidance
243	if c.reno {
244		// Classic Reno congestion avoidance.
245		c.numAckedPackets++
246		// Divide by num_connections to smoothly increase the CWND at a faster
247		// rate than conventional Reno.
248		if c.numAckedPackets*uint64(c.numConnections) >= uint64(c.congestionWindow)/uint64(protocol.DefaultTCPMSS) {
249			c.congestionWindow += protocol.DefaultTCPMSS
250			c.numAckedPackets = 0
251		}
252	} else {
253		c.congestionWindow = utils.MinByteCount(c.maxCongestionWindow, c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
254	}
255}
256
257func (c *cubicSender) isCwndLimited(bytesInFlight protocol.ByteCount) bool {
258	congestionWindow := c.GetCongestionWindow()
259	if bytesInFlight >= congestionWindow {
260		return true
261	}
262	availableBytes := congestionWindow - bytesInFlight
263	slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
264	return slowStartLimited || availableBytes <= maxBurstBytes
265}
266
267// BandwidthEstimate returns the current bandwidth estimate
268func (c *cubicSender) BandwidthEstimate() Bandwidth {
269	srtt := c.rttStats.SmoothedRTT()
270	if srtt == 0 {
271		// If we haven't measured an rtt, the bandwidth estimate is unknown.
272		return 0
273	}
274	return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
275}
276
277// HybridSlowStart returns the hybrid slow start instance for testing
278func (c *cubicSender) HybridSlowStart() *HybridSlowStart {
279	return &c.hybridSlowStart
280}
281
282// SetNumEmulatedConnections sets the number of emulated connections
283func (c *cubicSender) SetNumEmulatedConnections(n int) {
284	c.numConnections = utils.Max(n, 1)
285	c.cubic.SetNumConnections(c.numConnections)
286}
287
288// OnRetransmissionTimeout is called on an retransmission timeout
289func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
290	c.largestSentAtLastCutback = 0
291	if !packetsRetransmitted {
292		return
293	}
294	c.hybridSlowStart.Restart()
295	c.cubic.Reset()
296	c.slowstartThreshold = c.congestionWindow / 2
297	c.congestionWindow = c.minCongestionWindow
298}
299
300// OnConnectionMigration is called when the connection is migrated (?)
301func (c *cubicSender) OnConnectionMigration() {
302	c.hybridSlowStart.Restart()
303	c.prr = PrrSender{}
304	c.largestSentPacketNumber = 0
305	c.largestAckedPacketNumber = 0
306	c.largestSentAtLastCutback = 0
307	c.lastCutbackExitedSlowstart = false
308	c.cubic.Reset()
309	c.numAckedPackets = 0
310	c.congestionWindow = c.initialCongestionWindow
311	c.slowstartThreshold = c.initialMaxCongestionWindow
312	c.maxCongestionWindow = c.initialMaxCongestionWindow
313}
314
315// SetSlowStartLargeReduction allows enabling the SSLR experiment
316func (c *cubicSender) SetSlowStartLargeReduction(enabled bool) {
317	c.slowStartLargeReduction = enabled
318}
319