1package quic
2
3import (
4	"bytes"
5	"context"
6	"errors"
7	"math/rand"
8	"time"
9
10	"github.com/golang/mock/gomock"
11	"github.com/lucas-clemente/quic-go/internal/protocol"
12	"github.com/lucas-clemente/quic-go/internal/wire"
13
14	. "github.com/onsi/ginkgo"
15	. "github.com/onsi/gomega"
16)
17
18type mockGenericStream struct {
19	num protocol.StreamNum
20
21	closed     bool
22	closeErr   error
23	sendWindow protocol.ByteCount
24}
25
26func (s *mockGenericStream) closeForShutdown(err error) {
27	s.closed = true
28	s.closeErr = err
29}
30
31func (s *mockGenericStream) updateSendWindow(limit protocol.ByteCount) {
32	s.sendWindow = limit
33}
34
35var _ = Describe("Streams Map (incoming)", func() {
36	var (
37		m              *incomingItemsMap
38		newItemCounter int
39		mockSender     *MockStreamSender
40		maxNumStreams  uint64
41	)
42
43	// check that the frame can be serialized and deserialized
44	checkFrameSerialization := func(f wire.Frame) {
45		b := &bytes.Buffer{}
46		ExpectWithOffset(1, f.Write(b, protocol.VersionTLS)).To(Succeed())
47		frame, err := wire.NewFrameParser(false, protocol.VersionTLS).ParseNext(bytes.NewReader(b.Bytes()), protocol.Encryption1RTT)
48		ExpectWithOffset(1, err).ToNot(HaveOccurred())
49		Expect(f).To(Equal(frame))
50	}
51
52	BeforeEach(func() { maxNumStreams = 5 })
53
54	JustBeforeEach(func() {
55		newItemCounter = 0
56		mockSender = NewMockStreamSender(mockCtrl)
57		m = newIncomingItemsMap(
58			func(num protocol.StreamNum) item {
59				newItemCounter++
60				return &mockGenericStream{num: num}
61			},
62			maxNumStreams,
63			mockSender.queueControlFrame,
64		)
65	})
66
67	It("opens all streams up to the id on GetOrOpenStream", func() {
68		_, err := m.GetOrOpenStream(4)
69		Expect(err).ToNot(HaveOccurred())
70		Expect(newItemCounter).To(Equal(4))
71	})
72
73	It("starts opening streams at the right position", func() {
74		// like the test above, but with 2 calls to GetOrOpenStream
75		_, err := m.GetOrOpenStream(2)
76		Expect(err).ToNot(HaveOccurred())
77		Expect(newItemCounter).To(Equal(2))
78		_, err = m.GetOrOpenStream(5)
79		Expect(err).ToNot(HaveOccurred())
80		Expect(newItemCounter).To(Equal(5))
81	})
82
83	It("accepts streams in the right order", func() {
84		_, err := m.GetOrOpenStream(2) // open streams 1 and 2
85		Expect(err).ToNot(HaveOccurred())
86		str, err := m.AcceptStream(context.Background())
87		Expect(err).ToNot(HaveOccurred())
88		Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
89		str, err = m.AcceptStream(context.Background())
90		Expect(err).ToNot(HaveOccurred())
91		Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
92	})
93
94	It("allows opening the maximum stream ID", func() {
95		str, err := m.GetOrOpenStream(1)
96		Expect(err).ToNot(HaveOccurred())
97		Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
98	})
99
100	It("errors when trying to get a stream ID higher than the maximum", func() {
101		_, err := m.GetOrOpenStream(6)
102		Expect(err).To(HaveOccurred())
103		Expect(err.(streamError).TestError()).To(MatchError("peer tried to open stream 6 (current limit: 5)"))
104	})
105
106	It("blocks AcceptStream until a new stream is available", func() {
107		strChan := make(chan item)
108		go func() {
109			defer GinkgoRecover()
110			str, err := m.AcceptStream(context.Background())
111			Expect(err).ToNot(HaveOccurred())
112			strChan <- str
113		}()
114		Consistently(strChan).ShouldNot(Receive())
115		str, err := m.GetOrOpenStream(1)
116		Expect(err).ToNot(HaveOccurred())
117		Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
118		var acceptedStr item
119		Eventually(strChan).Should(Receive(&acceptedStr))
120		Expect(acceptedStr.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
121	})
122
123	It("unblocks AcceptStream when the context is canceled", func() {
124		ctx, cancel := context.WithCancel(context.Background())
125		done := make(chan struct{})
126		go func() {
127			defer GinkgoRecover()
128			_, err := m.AcceptStream(ctx)
129			Expect(err).To(MatchError("context canceled"))
130			close(done)
131		}()
132		Consistently(done).ShouldNot(BeClosed())
133		cancel()
134		Eventually(done).Should(BeClosed())
135	})
136
137	It("unblocks AcceptStream when it is closed", func() {
138		testErr := errors.New("test error")
139		done := make(chan struct{})
140		go func() {
141			defer GinkgoRecover()
142			_, err := m.AcceptStream(context.Background())
143			Expect(err).To(MatchError(testErr))
144			close(done)
145		}()
146		Consistently(done).ShouldNot(BeClosed())
147		m.CloseWithError(testErr)
148		Eventually(done).Should(BeClosed())
149	})
150
151	It("errors AcceptStream immediately if it is closed", func() {
152		testErr := errors.New("test error")
153		m.CloseWithError(testErr)
154		_, err := m.AcceptStream(context.Background())
155		Expect(err).To(MatchError(testErr))
156	})
157
158	It("closes all streams when CloseWithError is called", func() {
159		str1, err := m.GetOrOpenStream(1)
160		Expect(err).ToNot(HaveOccurred())
161		str2, err := m.GetOrOpenStream(3)
162		Expect(err).ToNot(HaveOccurred())
163		testErr := errors.New("test err")
164		m.CloseWithError(testErr)
165		Expect(str1.(*mockGenericStream).closed).To(BeTrue())
166		Expect(str1.(*mockGenericStream).closeErr).To(MatchError(testErr))
167		Expect(str2.(*mockGenericStream).closed).To(BeTrue())
168		Expect(str2.(*mockGenericStream).closeErr).To(MatchError(testErr))
169	})
170
171	It("deletes streams", func() {
172		mockSender.EXPECT().queueControlFrame(gomock.Any())
173		_, err := m.GetOrOpenStream(1)
174		Expect(err).ToNot(HaveOccurred())
175		str, err := m.AcceptStream(context.Background())
176		Expect(err).ToNot(HaveOccurred())
177		Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
178		Expect(m.DeleteStream(1)).To(Succeed())
179		str, err = m.GetOrOpenStream(1)
180		Expect(err).ToNot(HaveOccurred())
181		Expect(str).To(BeNil())
182	})
183
184	It("waits until a stream is accepted before actually deleting it", func() {
185		_, err := m.GetOrOpenStream(2)
186		Expect(err).ToNot(HaveOccurred())
187		Expect(m.DeleteStream(2)).To(Succeed())
188		str, err := m.AcceptStream(context.Background())
189		Expect(err).ToNot(HaveOccurred())
190		Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
191		// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
192		mockSender.EXPECT().queueControlFrame(gomock.Any())
193		str, err = m.AcceptStream(context.Background())
194		Expect(err).ToNot(HaveOccurred())
195		Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
196	})
197
198	It("doesn't return a stream queued for deleting from GetOrOpenStream", func() {
199		str, err := m.GetOrOpenStream(1)
200		Expect(err).ToNot(HaveOccurred())
201		Expect(str).ToNot(BeNil())
202		Expect(m.DeleteStream(1)).To(Succeed())
203		str, err = m.GetOrOpenStream(1)
204		Expect(err).ToNot(HaveOccurred())
205		Expect(str).To(BeNil())
206		// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
207		mockSender.EXPECT().queueControlFrame(gomock.Any())
208		str, err = m.AcceptStream(context.Background())
209		Expect(err).ToNot(HaveOccurred())
210		Expect(str).ToNot(BeNil())
211	})
212
213	It("errors when deleting a non-existing stream", func() {
214		err := m.DeleteStream(1337)
215		Expect(err).To(HaveOccurred())
216		Expect(err.(streamError).TestError()).To(MatchError("Tried to delete unknown incoming stream 1337"))
217	})
218
219	It("sends MAX_STREAMS frames when streams are deleted", func() {
220		// open a bunch of streams
221		_, err := m.GetOrOpenStream(5)
222		Expect(err).ToNot(HaveOccurred())
223		// accept all streams
224		for i := 0; i < 5; i++ {
225			_, err := m.AcceptStream(context.Background())
226			Expect(err).ToNot(HaveOccurred())
227		}
228		mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
229			Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 1)))
230			checkFrameSerialization(f)
231		})
232		Expect(m.DeleteStream(3)).To(Succeed())
233		mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
234			Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 2)))
235			checkFrameSerialization(f)
236		})
237		Expect(m.DeleteStream(4)).To(Succeed())
238	})
239
240	Context("using high stream limits", func() {
241		BeforeEach(func() { maxNumStreams = uint64(protocol.MaxStreamCount) - 2 })
242
243		It("doesn't send MAX_STREAMS frames if they would overflow 2^60 (the maximum stream count)", func() {
244			// open a bunch of streams
245			_, err := m.GetOrOpenStream(5)
246			Expect(err).ToNot(HaveOccurred())
247			// accept all streams
248			for i := 0; i < 5; i++ {
249				_, err := m.AcceptStream(context.Background())
250				Expect(err).ToNot(HaveOccurred())
251			}
252			mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
253				Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount - 1))
254				checkFrameSerialization(f)
255			})
256			Expect(m.DeleteStream(4)).To(Succeed())
257			mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
258				Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount))
259				checkFrameSerialization(f)
260			})
261			Expect(m.DeleteStream(3)).To(Succeed())
262			// at this point, we can't increase the stream limit any further, so no more MAX_STREAMS frames will be sent
263			Expect(m.DeleteStream(2)).To(Succeed())
264			Expect(m.DeleteStream(1)).To(Succeed())
265		})
266	})
267
268	Context("randomized tests", func() {
269		const num = 1000
270
271		BeforeEach(func() { maxNumStreams = num })
272
273		It("opens and accepts streams", func() {
274			rand.Seed(GinkgoRandomSeed())
275			ids := make([]protocol.StreamNum, num)
276			for i := 0; i < num; i++ {
277				ids[i] = protocol.StreamNum(i + 1)
278			}
279			rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] })
280
281			const timeout = 5 * time.Second
282			done := make(chan struct{}, 2)
283			go func() {
284				defer GinkgoRecover()
285				ctx, cancel := context.WithTimeout(context.Background(), timeout)
286				defer cancel()
287				for i := 0; i < num; i++ {
288					_, err := m.AcceptStream(ctx)
289					Expect(err).ToNot(HaveOccurred())
290				}
291				done <- struct{}{}
292			}()
293
294			go func() {
295				defer GinkgoRecover()
296				for i := 0; i < num; i++ {
297					_, err := m.GetOrOpenStream(ids[i])
298					Expect(err).ToNot(HaveOccurred())
299				}
300				done <- struct{}{}
301			}()
302
303			Eventually(done, timeout*3/2).Should(Receive())
304			Eventually(done, timeout*3/2).Should(Receive())
305		})
306	})
307})
308