1package quic
2
3import (
4	"bytes"
5	"context"
6	"errors"
7	"math/rand"
8	"time"
9
10	"github.com/golang/mock/gomock"
11	"github.com/lucas-clemente/quic-go/internal/protocol"
12	"github.com/lucas-clemente/quic-go/internal/wire"
13
14	. "github.com/onsi/ginkgo"
15	. "github.com/onsi/gomega"
16)
17
18type mockGenericStream struct {
19	num protocol.StreamNum
20
21	closed   bool
22	closeErr error
23}
24
25func (s *mockGenericStream) closeForShutdown(err error) {
26	s.closed = true
27	s.closeErr = err
28}
29
30var _ = Describe("Streams Map (incoming)", func() {
31	var (
32		m              *incomingItemsMap
33		newItemCounter int
34		mockSender     *MockStreamSender
35		maxNumStreams  uint64
36	)
37
38	// check that the frame can be serialized and deserialized
39	checkFrameSerialization := func(f wire.Frame) {
40		b := &bytes.Buffer{}
41		ExpectWithOffset(1, f.Write(b, protocol.VersionTLS)).To(Succeed())
42		frame, err := wire.NewFrameParser(protocol.VersionTLS).ParseNext(bytes.NewReader(b.Bytes()), protocol.Encryption1RTT)
43		ExpectWithOffset(1, err).ToNot(HaveOccurred())
44		Expect(f).To(Equal(frame))
45	}
46
47	BeforeEach(func() { maxNumStreams = 5 })
48
49	JustBeforeEach(func() {
50		newItemCounter = 0
51		mockSender = NewMockStreamSender(mockCtrl)
52		m = newIncomingItemsMap(
53			func(num protocol.StreamNum) item {
54				newItemCounter++
55				return &mockGenericStream{num: num}
56			},
57			maxNumStreams,
58			mockSender.queueControlFrame,
59		)
60	})
61
62	It("opens all streams up to the id on GetOrOpenStream", func() {
63		_, err := m.GetOrOpenStream(4)
64		Expect(err).ToNot(HaveOccurred())
65		Expect(newItemCounter).To(Equal(4))
66	})
67
68	It("starts opening streams at the right position", func() {
69		// like the test above, but with 2 calls to GetOrOpenStream
70		_, err := m.GetOrOpenStream(2)
71		Expect(err).ToNot(HaveOccurred())
72		Expect(newItemCounter).To(Equal(2))
73		_, err = m.GetOrOpenStream(5)
74		Expect(err).ToNot(HaveOccurred())
75		Expect(newItemCounter).To(Equal(5))
76	})
77
78	It("accepts streams in the right order", func() {
79		_, err := m.GetOrOpenStream(2) // open streams 1 and 2
80		Expect(err).ToNot(HaveOccurred())
81		str, err := m.AcceptStream(context.Background())
82		Expect(err).ToNot(HaveOccurred())
83		Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
84		str, err = m.AcceptStream(context.Background())
85		Expect(err).ToNot(HaveOccurred())
86		Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
87	})
88
89	It("allows opening the maximum stream ID", func() {
90		str, err := m.GetOrOpenStream(1)
91		Expect(err).ToNot(HaveOccurred())
92		Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
93	})
94
95	It("errors when trying to get a stream ID higher than the maximum", func() {
96		_, err := m.GetOrOpenStream(6)
97		Expect(err).To(HaveOccurred())
98		Expect(err.(streamError).TestError()).To(MatchError("peer tried to open stream 6 (current limit: 5)"))
99	})
100
101	It("blocks AcceptStream until a new stream is available", func() {
102		strChan := make(chan item)
103		go func() {
104			defer GinkgoRecover()
105			str, err := m.AcceptStream(context.Background())
106			Expect(err).ToNot(HaveOccurred())
107			strChan <- str
108		}()
109		Consistently(strChan).ShouldNot(Receive())
110		str, err := m.GetOrOpenStream(1)
111		Expect(err).ToNot(HaveOccurred())
112		Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
113		var acceptedStr item
114		Eventually(strChan).Should(Receive(&acceptedStr))
115		Expect(acceptedStr.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
116	})
117
118	It("unblocks AcceptStream when the context is canceled", func() {
119		ctx, cancel := context.WithCancel(context.Background())
120		done := make(chan struct{})
121		go func() {
122			defer GinkgoRecover()
123			_, err := m.AcceptStream(ctx)
124			Expect(err).To(MatchError("context canceled"))
125			close(done)
126		}()
127		Consistently(done).ShouldNot(BeClosed())
128		cancel()
129		Eventually(done).Should(BeClosed())
130	})
131
132	It("unblocks AcceptStream when it is closed", func() {
133		testErr := errors.New("test error")
134		done := make(chan struct{})
135		go func() {
136			defer GinkgoRecover()
137			_, err := m.AcceptStream(context.Background())
138			Expect(err).To(MatchError(testErr))
139			close(done)
140		}()
141		Consistently(done).ShouldNot(BeClosed())
142		m.CloseWithError(testErr)
143		Eventually(done).Should(BeClosed())
144	})
145
146	It("errors AcceptStream immediately if it is closed", func() {
147		testErr := errors.New("test error")
148		m.CloseWithError(testErr)
149		_, err := m.AcceptStream(context.Background())
150		Expect(err).To(MatchError(testErr))
151	})
152
153	It("closes all streams when CloseWithError is called", func() {
154		str1, err := m.GetOrOpenStream(1)
155		Expect(err).ToNot(HaveOccurred())
156		str2, err := m.GetOrOpenStream(3)
157		Expect(err).ToNot(HaveOccurred())
158		testErr := errors.New("test err")
159		m.CloseWithError(testErr)
160		Expect(str1.(*mockGenericStream).closed).To(BeTrue())
161		Expect(str1.(*mockGenericStream).closeErr).To(MatchError(testErr))
162		Expect(str2.(*mockGenericStream).closed).To(BeTrue())
163		Expect(str2.(*mockGenericStream).closeErr).To(MatchError(testErr))
164	})
165
166	It("deletes streams", func() {
167		mockSender.EXPECT().queueControlFrame(gomock.Any())
168		_, err := m.GetOrOpenStream(1)
169		Expect(err).ToNot(HaveOccurred())
170		str, err := m.AcceptStream(context.Background())
171		Expect(err).ToNot(HaveOccurred())
172		Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
173		Expect(m.DeleteStream(1)).To(Succeed())
174		str, err = m.GetOrOpenStream(1)
175		Expect(err).ToNot(HaveOccurred())
176		Expect(str).To(BeNil())
177	})
178
179	It("waits until a stream is accepted before actually deleting it", func() {
180		_, err := m.GetOrOpenStream(2)
181		Expect(err).ToNot(HaveOccurred())
182		Expect(m.DeleteStream(2)).To(Succeed())
183		str, err := m.AcceptStream(context.Background())
184		Expect(err).ToNot(HaveOccurred())
185		Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
186		// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
187		mockSender.EXPECT().queueControlFrame(gomock.Any())
188		str, err = m.AcceptStream(context.Background())
189		Expect(err).ToNot(HaveOccurred())
190		Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
191	})
192
193	It("doesn't return a stream queued for deleting from GetOrOpenStream", func() {
194		str, err := m.GetOrOpenStream(1)
195		Expect(err).ToNot(HaveOccurred())
196		Expect(str).ToNot(BeNil())
197		Expect(m.DeleteStream(1)).To(Succeed())
198		str, err = m.GetOrOpenStream(1)
199		Expect(err).ToNot(HaveOccurred())
200		Expect(str).To(BeNil())
201		// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
202		mockSender.EXPECT().queueControlFrame(gomock.Any())
203		str, err = m.AcceptStream(context.Background())
204		Expect(err).ToNot(HaveOccurred())
205		Expect(str).ToNot(BeNil())
206	})
207
208	It("errors when deleting a non-existing stream", func() {
209		err := m.DeleteStream(1337)
210		Expect(err).To(HaveOccurred())
211		Expect(err.(streamError).TestError()).To(MatchError("Tried to delete unknown incoming stream 1337"))
212	})
213
214	It("sends MAX_STREAMS frames when streams are deleted", func() {
215		// open a bunch of streams
216		_, err := m.GetOrOpenStream(5)
217		Expect(err).ToNot(HaveOccurred())
218		// accept all streams
219		for i := 0; i < 5; i++ {
220			_, err := m.AcceptStream(context.Background())
221			Expect(err).ToNot(HaveOccurred())
222		}
223		mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
224			Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 1)))
225			checkFrameSerialization(f)
226		})
227		Expect(m.DeleteStream(3)).To(Succeed())
228		mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
229			Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 2)))
230			checkFrameSerialization(f)
231		})
232		Expect(m.DeleteStream(4)).To(Succeed())
233	})
234
235	Context("using high stream limits", func() {
236		BeforeEach(func() { maxNumStreams = uint64(protocol.MaxStreamCount) - 2 })
237
238		It("doesn't send MAX_STREAMS frames if they would overflow 2^60 (the maximum stream count)", func() {
239			// open a bunch of streams
240			_, err := m.GetOrOpenStream(5)
241			Expect(err).ToNot(HaveOccurred())
242			// accept all streams
243			for i := 0; i < 5; i++ {
244				_, err := m.AcceptStream(context.Background())
245				Expect(err).ToNot(HaveOccurred())
246			}
247			mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
248				Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount - 1))
249				checkFrameSerialization(f)
250			})
251			Expect(m.DeleteStream(4)).To(Succeed())
252			mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
253				Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount))
254				checkFrameSerialization(f)
255			})
256			Expect(m.DeleteStream(3)).To(Succeed())
257			// at this point, we can't increase the stream limit any further, so no more MAX_STREAMS frames will be sent
258			Expect(m.DeleteStream(2)).To(Succeed())
259			Expect(m.DeleteStream(1)).To(Succeed())
260		})
261	})
262
263	Context("randomized tests", func() {
264		const num = 1000
265
266		BeforeEach(func() { maxNumStreams = num })
267
268		It("opens and accepts streams", func() {
269			rand.Seed(GinkgoRandomSeed())
270			ids := make([]protocol.StreamNum, num)
271			for i := 0; i < num; i++ {
272				ids[i] = protocol.StreamNum(i + 1)
273			}
274			rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] })
275
276			const timeout = 5 * time.Second
277			done := make(chan struct{}, 2)
278			go func() {
279				defer GinkgoRecover()
280				ctx, cancel := context.WithTimeout(context.Background(), timeout)
281				defer cancel()
282				for i := 0; i < num; i++ {
283					_, err := m.AcceptStream(ctx)
284					Expect(err).ToNot(HaveOccurred())
285				}
286				done <- struct{}{}
287			}()
288
289			go func() {
290				defer GinkgoRecover()
291				for i := 0; i < num; i++ {
292					_, err := m.GetOrOpenStream(ids[i])
293					Expect(err).ToNot(HaveOccurred())
294				}
295				done <- struct{}{}
296			}()
297
298			Eventually(done, timeout*3/2).Should(Receive())
299			Eventually(done, timeout*3/2).Should(Receive())
300		})
301	})
302})
303