1package quic
2
3import (
4	"math/rand"
5	"time"
6
7	"github.com/lucas-clemente/quic-go/internal/protocol"
8
9	"github.com/lucas-clemente/quic-go/internal/utils"
10
11	. "github.com/onsi/ginkgo"
12	. "github.com/onsi/gomega"
13)
14
15var _ = Describe("MTU Discoverer", func() {
16	const (
17		rtt                         = 100 * time.Millisecond
18		startMTU protocol.ByteCount = 1000
19		maxMTU   protocol.ByteCount = 2000
20	)
21
22	var (
23		d             mtuDiscoverer
24		rttStats      *utils.RTTStats
25		now           time.Time
26		discoveredMTU protocol.ByteCount
27	)
28
29	BeforeEach(func() {
30		rttStats = &utils.RTTStats{}
31		rttStats.SetInitialRTT(rtt)
32		Expect(rttStats.SmoothedRTT()).To(Equal(rtt))
33		d = newMTUDiscoverer(rttStats, startMTU, maxMTU, func(s protocol.ByteCount) { discoveredMTU = s })
34		now = time.Now()
35		_ = discoveredMTU
36	})
37
38	It("only allows a probe 5 RTTs after the handshake completes", func() {
39		Expect(d.ShouldSendProbe(now)).To(BeFalse())
40		Expect(d.ShouldSendProbe(now.Add(rtt * 9 / 2))).To(BeFalse())
41		Expect(d.NextProbeTime()).To(BeTemporally("~", now.Add(5*rtt), scaleDuration(20*time.Millisecond)))
42		Expect(d.ShouldSendProbe(now.Add(rtt * 5))).To(BeTrue())
43	})
44
45	It("doesn't allow a probe if another probe is still in flight", func() {
46		ping, _ := d.GetPing()
47		Expect(d.ShouldSendProbe(now.Add(10 * rtt))).To(BeFalse())
48		Expect(d.NextProbeTime()).To(BeZero())
49		ping.OnLost(ping.Frame)
50		Expect(d.ShouldSendProbe(now.Add(10 * rtt))).To(BeTrue())
51		Expect(d.NextProbeTime()).ToNot(BeZero())
52	})
53
54	It("tries a lower size when a probe is lost", func() {
55		ping, size := d.GetPing()
56		Expect(size).To(Equal(protocol.ByteCount(1500)))
57		ping.OnLost(ping.Frame)
58		_, size = d.GetPing()
59		Expect(size).To(Equal(protocol.ByteCount(1250)))
60	})
61
62	It("tries a higher size and calls the callback when a probe is acknowledged", func() {
63		ping, size := d.GetPing()
64		Expect(size).To(Equal(protocol.ByteCount(1500)))
65		ping.OnAcked(ping.Frame)
66		Expect(discoveredMTU).To(Equal(protocol.ByteCount(1500)))
67		_, size = d.GetPing()
68		Expect(size).To(Equal(protocol.ByteCount(1750)))
69	})
70
71	It("stops discovery after getting close enough to the MTU", func() {
72		var sizes []protocol.ByteCount
73		t := now.Add(5 * rtt)
74		for d.ShouldSendProbe(t) {
75			ping, size := d.GetPing()
76			ping.OnAcked(ping.Frame)
77			sizes = append(sizes, size)
78			t = t.Add(5 * rtt)
79		}
80		Expect(sizes).To(Equal([]protocol.ByteCount{1500, 1750, 1875, 1937, 1968, 1984}))
81		Expect(d.ShouldSendProbe(t.Add(10 * rtt))).To(BeFalse())
82		Expect(d.NextProbeTime()).To(BeZero())
83	})
84
85	It("finds the MTU", func() {
86		const rep = 3000
87		var maxDiff protocol.ByteCount
88		for i := 0; i < rep; i++ {
89			max := protocol.ByteCount(rand.Intn(int(3000-startMTU))) + startMTU + 1
90			currentMTU := startMTU
91			d := newMTUDiscoverer(rttStats, startMTU, max, func(s protocol.ByteCount) { currentMTU = s })
92			now := time.Now()
93			realMTU := protocol.ByteCount(rand.Intn(int(max-startMTU))) + startMTU
94			t := now.Add(mtuProbeDelay * rtt)
95			var count int
96			for d.ShouldSendProbe(t) {
97				if count > 25 {
98					Fail("too many iterations")
99				}
100				count++
101
102				ping, size := d.GetPing()
103				if size <= realMTU {
104					ping.OnAcked(ping.Frame)
105				} else {
106					ping.OnLost(ping.Frame)
107				}
108				t = t.Add(mtuProbeDelay * rtt)
109			}
110			diff := realMTU - currentMTU
111			Expect(diff).To(BeNumerically(">=", 0))
112			maxDiff = utils.MaxByteCount(maxDiff, diff)
113		}
114		Expect(maxDiff).To(BeEquivalentTo(maxMTUDiff))
115	})
116})
117