1package quic
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"math/rand"
8	"sort"
9	"sync"
10	"time"
11
12	"github.com/golang/mock/gomock"
13	"github.com/lucas-clemente/quic-go/internal/protocol"
14	"github.com/lucas-clemente/quic-go/internal/wire"
15	. "github.com/onsi/ginkgo"
16	. "github.com/onsi/gomega"
17)
18
19var _ = Describe("Streams Map (outgoing)", func() {
20	var (
21		m          *outgoingItemsMap
22		newItem    func(num protocol.StreamNum) item
23		mockSender *MockStreamSender
24	)
25
26	// waitForEnqueued waits until there are n go routines waiting on OpenStreamSync()
27	waitForEnqueued := func(n int) {
28		Eventually(func() int {
29			m.mutex.Lock()
30			defer m.mutex.Unlock()
31			return len(m.openQueue)
32		}, 50*time.Millisecond, 100*time.Microsecond).Should(Equal(n))
33	}
34
35	BeforeEach(func() {
36		newItem = func(num protocol.StreamNum) item {
37			return &mockGenericStream{num: num}
38		}
39		mockSender = NewMockStreamSender(mockCtrl)
40		m = newOutgoingItemsMap(newItem, mockSender.queueControlFrame)
41	})
42
43	Context("no stream ID limit", func() {
44		BeforeEach(func() {
45			m.SetMaxStream(0xffffffff)
46		})
47
48		It("opens streams", func() {
49			str, err := m.OpenStream()
50			Expect(err).ToNot(HaveOccurred())
51			Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
52			str, err = m.OpenStream()
53			Expect(err).ToNot(HaveOccurred())
54			Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
55		})
56
57		It("doesn't open streams after it has been closed", func() {
58			testErr := errors.New("close")
59			m.CloseWithError(testErr)
60			_, err := m.OpenStream()
61			Expect(err).To(MatchError(testErr))
62		})
63
64		It("gets streams", func() {
65			_, err := m.OpenStream()
66			Expect(err).ToNot(HaveOccurred())
67			str, err := m.GetStream(1)
68			Expect(err).ToNot(HaveOccurred())
69			Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
70		})
71
72		It("errors when trying to get a stream that has not yet been opened", func() {
73			_, err := m.GetStream(1)
74			Expect(err).To(HaveOccurred())
75			Expect(err.(streamError).TestError()).To(MatchError("peer attempted to open stream 1"))
76		})
77
78		It("deletes streams", func() {
79			_, err := m.OpenStream()
80			Expect(err).ToNot(HaveOccurred())
81			Expect(m.DeleteStream(1)).To(Succeed())
82			Expect(err).ToNot(HaveOccurred())
83			str, err := m.GetStream(1)
84			Expect(err).ToNot(HaveOccurred())
85			Expect(str).To(BeNil())
86		})
87
88		It("errors when deleting a non-existing stream", func() {
89			err := m.DeleteStream(1337)
90			Expect(err).To(HaveOccurred())
91			Expect(err.(streamError).TestError()).To(MatchError("tried to delete unknown outgoing stream 1337"))
92		})
93
94		It("errors when deleting a stream twice", func() {
95			_, err := m.OpenStream() // opens firstNewStream
96			Expect(err).ToNot(HaveOccurred())
97			Expect(m.DeleteStream(1)).To(Succeed())
98			err = m.DeleteStream(1)
99			Expect(err).To(HaveOccurred())
100			Expect(err.(streamError).TestError()).To(MatchError("tried to delete unknown outgoing stream 1"))
101		})
102
103		It("closes all streams when CloseWithError is called", func() {
104			str1, err := m.OpenStream()
105			Expect(err).ToNot(HaveOccurred())
106			str2, err := m.OpenStream()
107			Expect(err).ToNot(HaveOccurred())
108			testErr := errors.New("test err")
109			m.CloseWithError(testErr)
110			Expect(str1.(*mockGenericStream).closed).To(BeTrue())
111			Expect(str1.(*mockGenericStream).closeErr).To(MatchError(testErr))
112			Expect(str2.(*mockGenericStream).closed).To(BeTrue())
113			Expect(str2.(*mockGenericStream).closeErr).To(MatchError(testErr))
114		})
115
116		It("updates the send window", func() {
117			str1, err := m.OpenStream()
118			Expect(err).ToNot(HaveOccurred())
119			str2, err := m.OpenStream()
120			Expect(err).ToNot(HaveOccurred())
121			m.UpdateSendWindow(1337)
122			Expect(str1.(*mockGenericStream).sendWindow).To(BeEquivalentTo(1337))
123			Expect(str2.(*mockGenericStream).sendWindow).To(BeEquivalentTo(1337))
124		})
125	})
126
127	Context("with stream ID limits", func() {
128		It("errors when no stream can be opened immediately", func() {
129			mockSender.EXPECT().queueControlFrame(gomock.Any())
130			_, err := m.OpenStream()
131			expectTooManyStreamsError(err)
132		})
133
134		It("returns immediately when called with a canceled context", func() {
135			ctx, cancel := context.WithCancel(context.Background())
136			cancel()
137			_, err := m.OpenStreamSync(ctx)
138			Expect(err).To(MatchError("context canceled"))
139		})
140
141		It("blocks until a stream can be opened synchronously", func() {
142			mockSender.EXPECT().queueControlFrame(gomock.Any())
143			done := make(chan struct{})
144			go func() {
145				defer GinkgoRecover()
146				str, err := m.OpenStreamSync(context.Background())
147				Expect(err).ToNot(HaveOccurred())
148				Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
149				close(done)
150			}()
151			waitForEnqueued(1)
152
153			m.SetMaxStream(1)
154			Eventually(done).Should(BeClosed())
155		})
156
157		It("unblocks when the context is canceled", func() {
158			mockSender.EXPECT().queueControlFrame(gomock.Any())
159			ctx, cancel := context.WithCancel(context.Background())
160			done := make(chan struct{})
161			go func() {
162				defer GinkgoRecover()
163				_, err := m.OpenStreamSync(ctx)
164				Expect(err).To(MatchError("context canceled"))
165				close(done)
166			}()
167			waitForEnqueued(1)
168
169			cancel()
170			Eventually(done).Should(BeClosed())
171
172			// make sure that the next stream opened is stream 1
173			m.SetMaxStream(1000)
174			str, err := m.OpenStream()
175			Expect(err).ToNot(HaveOccurred())
176			Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
177		})
178
179		It("opens streams in the right order", func() {
180			mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
181			done1 := make(chan struct{})
182			go func() {
183				defer GinkgoRecover()
184				str, err := m.OpenStreamSync(context.Background())
185				Expect(err).ToNot(HaveOccurred())
186				Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
187				close(done1)
188			}()
189			waitForEnqueued(1)
190
191			done2 := make(chan struct{})
192			go func() {
193				defer GinkgoRecover()
194				str, err := m.OpenStreamSync(context.Background())
195				Expect(err).ToNot(HaveOccurred())
196				Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
197				close(done2)
198			}()
199			waitForEnqueued(2)
200
201			m.SetMaxStream(1)
202			Eventually(done1).Should(BeClosed())
203			Consistently(done2).ShouldNot(BeClosed())
204			m.SetMaxStream(2)
205			Eventually(done2).Should(BeClosed())
206		})
207
208		It("opens streams in the right order, when one of the contexts is canceled", func() {
209			mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
210			done1 := make(chan struct{})
211			go func() {
212				defer GinkgoRecover()
213				str, err := m.OpenStreamSync(context.Background())
214				Expect(err).ToNot(HaveOccurred())
215				Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
216				close(done1)
217			}()
218			waitForEnqueued(1)
219
220			done2 := make(chan struct{})
221			ctx, cancel := context.WithCancel(context.Background())
222			go func() {
223				defer GinkgoRecover()
224				_, err := m.OpenStreamSync(ctx)
225				Expect(err).To(MatchError(context.Canceled))
226				close(done2)
227			}()
228			waitForEnqueued(2)
229
230			done3 := make(chan struct{})
231			go func() {
232				defer GinkgoRecover()
233				str, err := m.OpenStreamSync(context.Background())
234				Expect(err).ToNot(HaveOccurred())
235				Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
236				close(done3)
237			}()
238			waitForEnqueued(3)
239
240			cancel()
241			Eventually(done2).Should(BeClosed())
242			m.SetMaxStream(1000)
243			Eventually(done1).Should(BeClosed())
244			Eventually(done3).Should(BeClosed())
245		})
246
247		It("unblocks multiple OpenStreamSync calls at the same time", func() {
248			mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
249			done := make(chan struct{})
250			go func() {
251				defer GinkgoRecover()
252				_, err := m.OpenStreamSync(context.Background())
253				Expect(err).ToNot(HaveOccurred())
254				done <- struct{}{}
255			}()
256			go func() {
257				defer GinkgoRecover()
258				_, err := m.OpenStreamSync(context.Background())
259				Expect(err).ToNot(HaveOccurred())
260				done <- struct{}{}
261			}()
262			waitForEnqueued(2)
263			go func() {
264				defer GinkgoRecover()
265				_, err := m.OpenStreamSync(context.Background())
266				Expect(err).To(MatchError("test done"))
267				done <- struct{}{}
268			}()
269			waitForEnqueued(3)
270
271			m.SetMaxStream(2)
272			Eventually(done).Should(Receive())
273			Eventually(done).Should(Receive())
274			Consistently(done).ShouldNot(Receive())
275
276			m.CloseWithError(errors.New("test done"))
277			Eventually(done).Should(Receive())
278		})
279
280		It("returns an error for OpenStream while an OpenStreamSync call is blocking", func() {
281			mockSender.EXPECT().queueControlFrame(gomock.Any()).MaxTimes(2)
282			openedSync := make(chan struct{})
283			go func() {
284				defer GinkgoRecover()
285				str, err := m.OpenStreamSync(context.Background())
286				Expect(err).ToNot(HaveOccurred())
287				Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
288				close(openedSync)
289			}()
290			waitForEnqueued(1)
291
292			start := make(chan struct{})
293			openend := make(chan struct{})
294			go func() {
295				defer GinkgoRecover()
296				var hasStarted bool
297				for {
298					str, err := m.OpenStream()
299					if err == nil {
300						Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
301						close(openend)
302						return
303					}
304					expectTooManyStreamsError(err)
305					if !hasStarted {
306						close(start)
307						hasStarted = true
308					}
309				}
310			}()
311
312			Eventually(start).Should(BeClosed())
313			m.SetMaxStream(1)
314			Eventually(openedSync).Should(BeClosed())
315			Consistently(openend).ShouldNot(BeClosed())
316			m.SetMaxStream(2)
317			Eventually(openend).Should(BeClosed())
318		})
319
320		It("stops opening synchronously when it is closed", func() {
321			mockSender.EXPECT().queueControlFrame(gomock.Any())
322			testErr := errors.New("test error")
323			done := make(chan struct{})
324			go func() {
325				defer GinkgoRecover()
326				_, err := m.OpenStreamSync(context.Background())
327				Expect(err).To(MatchError(testErr))
328				close(done)
329			}()
330
331			Consistently(done).ShouldNot(BeClosed())
332			m.CloseWithError(testErr)
333			Eventually(done).Should(BeClosed())
334		})
335
336		It("doesn't reduce the stream limit", func() {
337			m.SetMaxStream(2)
338			m.SetMaxStream(1)
339			_, err := m.OpenStream()
340			Expect(err).ToNot(HaveOccurred())
341			str, err := m.OpenStream()
342			Expect(err).ToNot(HaveOccurred())
343			Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
344		})
345
346		It("queues a STREAMS_BLOCKED frame if no stream can be opened", func() {
347			m.SetMaxStream(6)
348			// open the 6 allowed streams
349			for i := 0; i < 6; i++ {
350				_, err := m.OpenStream()
351				Expect(err).ToNot(HaveOccurred())
352			}
353
354			mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
355				Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(6))
356			})
357			_, err := m.OpenStream()
358			Expect(err).To(HaveOccurred())
359			Expect(err.Error()).To(Equal(errTooManyOpenStreams.Error()))
360		})
361
362		It("only sends one STREAMS_BLOCKED frame for one stream ID", func() {
363			m.SetMaxStream(1)
364			mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
365				Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(1))
366			})
367			_, err := m.OpenStream()
368			Expect(err).ToNot(HaveOccurred())
369			// try to open a stream twice, but expect only one STREAMS_BLOCKED to be sent
370			_, err = m.OpenStream()
371			expectTooManyStreamsError(err)
372			_, err = m.OpenStream()
373			expectTooManyStreamsError(err)
374		})
375
376		It("queues a STREAMS_BLOCKED frame when there more streams waiting for OpenStreamSync than MAX_STREAMS allows", func() {
377			mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
378				Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(0))
379			})
380			done := make(chan struct{}, 2)
381			go func() {
382				defer GinkgoRecover()
383				_, err := m.OpenStreamSync(context.Background())
384				Expect(err).ToNot(HaveOccurred())
385				done <- struct{}{}
386			}()
387			go func() {
388				defer GinkgoRecover()
389				_, err := m.OpenStreamSync(context.Background())
390				Expect(err).ToNot(HaveOccurred())
391				done <- struct{}{}
392			}()
393			waitForEnqueued(2)
394
395			mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
396				Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(1))
397			})
398			m.SetMaxStream(1)
399			Eventually(done).Should(Receive())
400			Consistently(done).ShouldNot(Receive())
401			m.SetMaxStream(2)
402			Eventually(done).Should(Receive())
403		})
404	})
405
406	Context("randomized tests", func() {
407		It("opens streams", func() {
408			rand.Seed(GinkgoRandomSeed())
409			const n = 100
410			fmt.Fprintf(GinkgoWriter, "Opening %d streams concurrently.\n", n)
411
412			var blockedAt []protocol.StreamNum
413			mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
414				blockedAt = append(blockedAt, f.(*wire.StreamsBlockedFrame).StreamLimit)
415			}).AnyTimes()
416			done := make(map[int]chan struct{})
417			for i := 1; i <= n; i++ {
418				c := make(chan struct{})
419				done[i] = c
420
421				go func(doneChan chan struct{}, id protocol.StreamNum) {
422					defer GinkgoRecover()
423					defer close(doneChan)
424					str, err := m.OpenStreamSync(context.Background())
425					Expect(err).ToNot(HaveOccurred())
426					Expect(str.(*mockGenericStream).num).To(Equal(id))
427				}(c, protocol.StreamNum(i))
428				waitForEnqueued(i)
429			}
430
431			var limit int
432			limits := []protocol.StreamNum{0}
433			for limit < n {
434				limit += rand.Intn(n/5) + 1
435				if limit <= n {
436					limits = append(limits, protocol.StreamNum(limit))
437				}
438				fmt.Fprintf(GinkgoWriter, "Setting stream limit to %d.\n", limit)
439				m.SetMaxStream(protocol.StreamNum(limit))
440				for i := 1; i <= n; i++ {
441					if i <= limit {
442						Eventually(done[i]).Should(BeClosed())
443					} else {
444						Expect(done[i]).ToNot(BeClosed())
445					}
446				}
447				str, err := m.OpenStream()
448				if limit <= n {
449					Expect(err).To(HaveOccurred())
450					Expect(err.Error()).To(Equal(errTooManyOpenStreams.Error()))
451				} else {
452					Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(n + 1)))
453				}
454			}
455			Expect(blockedAt).To(Equal(limits))
456		})
457
458		It("opens streams, when some of them are getting canceled", func() {
459			rand.Seed(GinkgoRandomSeed())
460			const n = 100
461			fmt.Fprintf(GinkgoWriter, "Opening %d streams concurrently.\n", n)
462
463			var blockedAt []protocol.StreamNum
464			mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
465				blockedAt = append(blockedAt, f.(*wire.StreamsBlockedFrame).StreamLimit)
466			}).AnyTimes()
467
468			ctx, cancel := context.WithCancel(context.Background())
469			streamsToCancel := make(map[protocol.StreamNum]struct{}) // used as a set
470			for i := 0; i < 10; i++ {
471				id := protocol.StreamNum(rand.Intn(n) + 1)
472				fmt.Fprintf(GinkgoWriter, "Canceling stream %d.\n", id)
473				streamsToCancel[id] = struct{}{}
474			}
475
476			streamWillBeCanceled := func(id protocol.StreamNum) bool {
477				_, ok := streamsToCancel[id]
478				return ok
479			}
480
481			var streamIDs []int
482			var mutex sync.Mutex
483			done := make(map[int]chan struct{})
484			for i := 1; i <= n; i++ {
485				c := make(chan struct{})
486				done[i] = c
487
488				go func(doneChan chan struct{}, id protocol.StreamNum) {
489					defer GinkgoRecover()
490					defer close(doneChan)
491					cont := context.Background()
492					if streamWillBeCanceled(id) {
493						cont = ctx
494					}
495					str, err := m.OpenStreamSync(cont)
496					if streamWillBeCanceled(id) {
497						Expect(err).To(MatchError(context.Canceled))
498						return
499					}
500					Expect(err).ToNot(HaveOccurred())
501					mutex.Lock()
502					streamIDs = append(streamIDs, int(str.(*mockGenericStream).num))
503					mutex.Unlock()
504				}(c, protocol.StreamNum(i))
505				waitForEnqueued(i)
506			}
507
508			cancel()
509			for id := range streamsToCancel {
510				Eventually(done[int(id)]).Should(BeClosed())
511			}
512			var limit int
513			numStreams := n - len(streamsToCancel)
514			var limits []protocol.StreamNum
515			for limit < numStreams {
516				limits = append(limits, protocol.StreamNum(limit))
517				limit += rand.Intn(n/5) + 1
518				fmt.Fprintf(GinkgoWriter, "Setting stream limit to %d.\n", limit)
519				m.SetMaxStream(protocol.StreamNum(limit))
520				l := limit
521				if l > numStreams {
522					l = numStreams
523				}
524				Eventually(func() int {
525					mutex.Lock()
526					defer mutex.Unlock()
527					return len(streamIDs)
528				}).Should(Equal(l))
529				// check that all stream IDs were used
530				Expect(streamIDs).To(HaveLen(l))
531				sort.Ints(streamIDs)
532				for i := 0; i < l; i++ {
533					Expect(streamIDs[i]).To(Equal(i + 1))
534				}
535			}
536			Expect(blockedAt).To(Equal(limits))
537		})
538	})
539})
540