1package quic 2 3import ( 4 "bytes" 5 "context" 6 "errors" 7 "math/rand" 8 "time" 9 10 "github.com/golang/mock/gomock" 11 "github.com/lucas-clemente/quic-go/internal/protocol" 12 "github.com/lucas-clemente/quic-go/internal/wire" 13 14 . "github.com/onsi/ginkgo" 15 . "github.com/onsi/gomega" 16) 17 18type mockGenericStream struct { 19 num protocol.StreamNum 20 21 closed bool 22 closeErr error 23} 24 25func (s *mockGenericStream) closeForShutdown(err error) { 26 s.closed = true 27 s.closeErr = err 28} 29 30var _ = Describe("Streams Map (incoming)", func() { 31 var ( 32 m *incomingItemsMap 33 newItemCounter int 34 mockSender *MockStreamSender 35 maxNumStreams uint64 36 ) 37 38 // check that the frame can be serialized and deserialized 39 checkFrameSerialization := func(f wire.Frame) { 40 b := &bytes.Buffer{} 41 ExpectWithOffset(1, f.Write(b, protocol.VersionTLS)).To(Succeed()) 42 frame, err := wire.NewFrameParser(protocol.VersionTLS).ParseNext(bytes.NewReader(b.Bytes()), protocol.Encryption1RTT) 43 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 44 Expect(f).To(Equal(frame)) 45 } 46 47 BeforeEach(func() { maxNumStreams = 5 }) 48 49 JustBeforeEach(func() { 50 newItemCounter = 0 51 mockSender = NewMockStreamSender(mockCtrl) 52 m = newIncomingItemsMap( 53 func(num protocol.StreamNum) item { 54 newItemCounter++ 55 return &mockGenericStream{num: num} 56 }, 57 maxNumStreams, 58 mockSender.queueControlFrame, 59 ) 60 }) 61 62 It("opens all streams up to the id on GetOrOpenStream", func() { 63 _, err := m.GetOrOpenStream(4) 64 Expect(err).ToNot(HaveOccurred()) 65 Expect(newItemCounter).To(Equal(4)) 66 }) 67 68 It("starts opening streams at the right position", func() { 69 // like the test above, but with 2 calls to GetOrOpenStream 70 _, err := m.GetOrOpenStream(2) 71 Expect(err).ToNot(HaveOccurred()) 72 Expect(newItemCounter).To(Equal(2)) 73 _, err = m.GetOrOpenStream(5) 74 Expect(err).ToNot(HaveOccurred()) 75 Expect(newItemCounter).To(Equal(5)) 76 }) 77 78 It("accepts streams in the right order", func() { 79 _, err := m.GetOrOpenStream(2) // open streams 1 and 2 80 Expect(err).ToNot(HaveOccurred()) 81 str, err := m.AcceptStream(context.Background()) 82 Expect(err).ToNot(HaveOccurred()) 83 Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) 84 str, err = m.AcceptStream(context.Background()) 85 Expect(err).ToNot(HaveOccurred()) 86 Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) 87 }) 88 89 It("allows opening the maximum stream ID", func() { 90 str, err := m.GetOrOpenStream(1) 91 Expect(err).ToNot(HaveOccurred()) 92 Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) 93 }) 94 95 It("errors when trying to get a stream ID higher than the maximum", func() { 96 _, err := m.GetOrOpenStream(6) 97 Expect(err).To(HaveOccurred()) 98 Expect(err.(streamError).TestError()).To(MatchError("peer tried to open stream 6 (current limit: 5)")) 99 }) 100 101 It("blocks AcceptStream until a new stream is available", func() { 102 strChan := make(chan item) 103 go func() { 104 defer GinkgoRecover() 105 str, err := m.AcceptStream(context.Background()) 106 Expect(err).ToNot(HaveOccurred()) 107 strChan <- str 108 }() 109 Consistently(strChan).ShouldNot(Receive()) 110 str, err := m.GetOrOpenStream(1) 111 Expect(err).ToNot(HaveOccurred()) 112 Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) 113 var acceptedStr item 114 Eventually(strChan).Should(Receive(&acceptedStr)) 115 Expect(acceptedStr.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) 116 }) 117 118 It("unblocks AcceptStream when the context is canceled", func() { 119 ctx, cancel := context.WithCancel(context.Background()) 120 done := make(chan struct{}) 121 go func() { 122 defer GinkgoRecover() 123 _, err := m.AcceptStream(ctx) 124 Expect(err).To(MatchError("context canceled")) 125 close(done) 126 }() 127 Consistently(done).ShouldNot(BeClosed()) 128 cancel() 129 Eventually(done).Should(BeClosed()) 130 }) 131 132 It("unblocks AcceptStream when it is closed", func() { 133 testErr := errors.New("test error") 134 done := make(chan struct{}) 135 go func() { 136 defer GinkgoRecover() 137 _, err := m.AcceptStream(context.Background()) 138 Expect(err).To(MatchError(testErr)) 139 close(done) 140 }() 141 Consistently(done).ShouldNot(BeClosed()) 142 m.CloseWithError(testErr) 143 Eventually(done).Should(BeClosed()) 144 }) 145 146 It("errors AcceptStream immediately if it is closed", func() { 147 testErr := errors.New("test error") 148 m.CloseWithError(testErr) 149 _, err := m.AcceptStream(context.Background()) 150 Expect(err).To(MatchError(testErr)) 151 }) 152 153 It("closes all streams when CloseWithError is called", func() { 154 str1, err := m.GetOrOpenStream(1) 155 Expect(err).ToNot(HaveOccurred()) 156 str2, err := m.GetOrOpenStream(3) 157 Expect(err).ToNot(HaveOccurred()) 158 testErr := errors.New("test err") 159 m.CloseWithError(testErr) 160 Expect(str1.(*mockGenericStream).closed).To(BeTrue()) 161 Expect(str1.(*mockGenericStream).closeErr).To(MatchError(testErr)) 162 Expect(str2.(*mockGenericStream).closed).To(BeTrue()) 163 Expect(str2.(*mockGenericStream).closeErr).To(MatchError(testErr)) 164 }) 165 166 It("deletes streams", func() { 167 mockSender.EXPECT().queueControlFrame(gomock.Any()) 168 _, err := m.GetOrOpenStream(1) 169 Expect(err).ToNot(HaveOccurred()) 170 str, err := m.AcceptStream(context.Background()) 171 Expect(err).ToNot(HaveOccurred()) 172 Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) 173 Expect(m.DeleteStream(1)).To(Succeed()) 174 str, err = m.GetOrOpenStream(1) 175 Expect(err).ToNot(HaveOccurred()) 176 Expect(str).To(BeNil()) 177 }) 178 179 It("waits until a stream is accepted before actually deleting it", func() { 180 _, err := m.GetOrOpenStream(2) 181 Expect(err).ToNot(HaveOccurred()) 182 Expect(m.DeleteStream(2)).To(Succeed()) 183 str, err := m.AcceptStream(context.Background()) 184 Expect(err).ToNot(HaveOccurred()) 185 Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) 186 // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued 187 mockSender.EXPECT().queueControlFrame(gomock.Any()) 188 str, err = m.AcceptStream(context.Background()) 189 Expect(err).ToNot(HaveOccurred()) 190 Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) 191 }) 192 193 It("doesn't return a stream queued for deleting from GetOrOpenStream", func() { 194 str, err := m.GetOrOpenStream(1) 195 Expect(err).ToNot(HaveOccurred()) 196 Expect(str).ToNot(BeNil()) 197 Expect(m.DeleteStream(1)).To(Succeed()) 198 str, err = m.GetOrOpenStream(1) 199 Expect(err).ToNot(HaveOccurred()) 200 Expect(str).To(BeNil()) 201 // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued 202 mockSender.EXPECT().queueControlFrame(gomock.Any()) 203 str, err = m.AcceptStream(context.Background()) 204 Expect(err).ToNot(HaveOccurred()) 205 Expect(str).ToNot(BeNil()) 206 }) 207 208 It("errors when deleting a non-existing stream", func() { 209 err := m.DeleteStream(1337) 210 Expect(err).To(HaveOccurred()) 211 Expect(err.(streamError).TestError()).To(MatchError("Tried to delete unknown incoming stream 1337")) 212 }) 213 214 It("sends MAX_STREAMS frames when streams are deleted", func() { 215 // open a bunch of streams 216 _, err := m.GetOrOpenStream(5) 217 Expect(err).ToNot(HaveOccurred()) 218 // accept all streams 219 for i := 0; i < 5; i++ { 220 _, err := m.AcceptStream(context.Background()) 221 Expect(err).ToNot(HaveOccurred()) 222 } 223 mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { 224 Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 1))) 225 checkFrameSerialization(f) 226 }) 227 Expect(m.DeleteStream(3)).To(Succeed()) 228 mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { 229 Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 2))) 230 checkFrameSerialization(f) 231 }) 232 Expect(m.DeleteStream(4)).To(Succeed()) 233 }) 234 235 Context("using high stream limits", func() { 236 BeforeEach(func() { maxNumStreams = uint64(protocol.MaxStreamCount) - 2 }) 237 238 It("doesn't send MAX_STREAMS frames if they would overflow 2^60 (the maximum stream count)", func() { 239 // open a bunch of streams 240 _, err := m.GetOrOpenStream(5) 241 Expect(err).ToNot(HaveOccurred()) 242 // accept all streams 243 for i := 0; i < 5; i++ { 244 _, err := m.AcceptStream(context.Background()) 245 Expect(err).ToNot(HaveOccurred()) 246 } 247 mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { 248 Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount - 1)) 249 checkFrameSerialization(f) 250 }) 251 Expect(m.DeleteStream(4)).To(Succeed()) 252 mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { 253 Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount)) 254 checkFrameSerialization(f) 255 }) 256 Expect(m.DeleteStream(3)).To(Succeed()) 257 // at this point, we can't increase the stream limit any further, so no more MAX_STREAMS frames will be sent 258 Expect(m.DeleteStream(2)).To(Succeed()) 259 Expect(m.DeleteStream(1)).To(Succeed()) 260 }) 261 }) 262 263 Context("randomized tests", func() { 264 const num = 1000 265 266 BeforeEach(func() { maxNumStreams = num }) 267 268 It("opens and accepts streams", func() { 269 rand.Seed(GinkgoRandomSeed()) 270 ids := make([]protocol.StreamNum, num) 271 for i := 0; i < num; i++ { 272 ids[i] = protocol.StreamNum(i + 1) 273 } 274 rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] }) 275 276 const timeout = 5 * time.Second 277 done := make(chan struct{}, 2) 278 go func() { 279 defer GinkgoRecover() 280 ctx, cancel := context.WithTimeout(context.Background(), timeout) 281 defer cancel() 282 for i := 0; i < num; i++ { 283 _, err := m.AcceptStream(ctx) 284 Expect(err).ToNot(HaveOccurred()) 285 } 286 done <- struct{}{} 287 }() 288 289 go func() { 290 defer GinkgoRecover() 291 for i := 0; i < num; i++ { 292 _, err := m.GetOrOpenStream(ids[i]) 293 Expect(err).ToNot(HaveOccurred()) 294 } 295 done <- struct{}{} 296 }() 297 298 Eventually(done, timeout*3/2).Should(Receive()) 299 Eventually(done, timeout*3/2).Should(Receive()) 300 }) 301 }) 302}) 303