1package quic 2 3import ( 4 "bytes" 5 "context" 6 "errors" 7 "math/rand" 8 "time" 9 10 "github.com/golang/mock/gomock" 11 "github.com/lucas-clemente/quic-go/internal/protocol" 12 "github.com/lucas-clemente/quic-go/internal/wire" 13 14 . "github.com/onsi/ginkgo" 15 . "github.com/onsi/gomega" 16) 17 18type mockGenericStream struct { 19 num protocol.StreamNum 20 21 closed bool 22 closeErr error 23 sendWindow protocol.ByteCount 24} 25 26func (s *mockGenericStream) closeForShutdown(err error) { 27 s.closed = true 28 s.closeErr = err 29} 30 31func (s *mockGenericStream) updateSendWindow(limit protocol.ByteCount) { 32 s.sendWindow = limit 33} 34 35var _ = Describe("Streams Map (incoming)", func() { 36 var ( 37 m *incomingItemsMap 38 newItemCounter int 39 mockSender *MockStreamSender 40 maxNumStreams uint64 41 ) 42 43 // check that the frame can be serialized and deserialized 44 checkFrameSerialization := func(f wire.Frame) { 45 b := &bytes.Buffer{} 46 ExpectWithOffset(1, f.Write(b, protocol.VersionTLS)).To(Succeed()) 47 frame, err := wire.NewFrameParser(false, protocol.VersionTLS).ParseNext(bytes.NewReader(b.Bytes()), protocol.Encryption1RTT) 48 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 49 Expect(f).To(Equal(frame)) 50 } 51 52 BeforeEach(func() { maxNumStreams = 5 }) 53 54 JustBeforeEach(func() { 55 newItemCounter = 0 56 mockSender = NewMockStreamSender(mockCtrl) 57 m = newIncomingItemsMap( 58 func(num protocol.StreamNum) item { 59 newItemCounter++ 60 return &mockGenericStream{num: num} 61 }, 62 maxNumStreams, 63 mockSender.queueControlFrame, 64 ) 65 }) 66 67 It("opens all streams up to the id on GetOrOpenStream", func() { 68 _, err := m.GetOrOpenStream(4) 69 Expect(err).ToNot(HaveOccurred()) 70 Expect(newItemCounter).To(Equal(4)) 71 }) 72 73 It("starts opening streams at the right position", func() { 74 // like the test above, but with 2 calls to GetOrOpenStream 75 _, err := m.GetOrOpenStream(2) 76 Expect(err).ToNot(HaveOccurred()) 77 Expect(newItemCounter).To(Equal(2)) 78 _, err = m.GetOrOpenStream(5) 79 Expect(err).ToNot(HaveOccurred()) 80 Expect(newItemCounter).To(Equal(5)) 81 }) 82 83 It("accepts streams in the right order", func() { 84 _, err := m.GetOrOpenStream(2) // open streams 1 and 2 85 Expect(err).ToNot(HaveOccurred()) 86 str, err := m.AcceptStream(context.Background()) 87 Expect(err).ToNot(HaveOccurred()) 88 Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) 89 str, err = m.AcceptStream(context.Background()) 90 Expect(err).ToNot(HaveOccurred()) 91 Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) 92 }) 93 94 It("allows opening the maximum stream ID", func() { 95 str, err := m.GetOrOpenStream(1) 96 Expect(err).ToNot(HaveOccurred()) 97 Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) 98 }) 99 100 It("errors when trying to get a stream ID higher than the maximum", func() { 101 _, err := m.GetOrOpenStream(6) 102 Expect(err).To(HaveOccurred()) 103 Expect(err.(streamError).TestError()).To(MatchError("peer tried to open stream 6 (current limit: 5)")) 104 }) 105 106 It("blocks AcceptStream until a new stream is available", func() { 107 strChan := make(chan item) 108 go func() { 109 defer GinkgoRecover() 110 str, err := m.AcceptStream(context.Background()) 111 Expect(err).ToNot(HaveOccurred()) 112 strChan <- str 113 }() 114 Consistently(strChan).ShouldNot(Receive()) 115 str, err := m.GetOrOpenStream(1) 116 Expect(err).ToNot(HaveOccurred()) 117 Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) 118 var acceptedStr item 119 Eventually(strChan).Should(Receive(&acceptedStr)) 120 Expect(acceptedStr.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) 121 }) 122 123 It("unblocks AcceptStream when the context is canceled", func() { 124 ctx, cancel := context.WithCancel(context.Background()) 125 done := make(chan struct{}) 126 go func() { 127 defer GinkgoRecover() 128 _, err := m.AcceptStream(ctx) 129 Expect(err).To(MatchError("context canceled")) 130 close(done) 131 }() 132 Consistently(done).ShouldNot(BeClosed()) 133 cancel() 134 Eventually(done).Should(BeClosed()) 135 }) 136 137 It("unblocks AcceptStream when it is closed", func() { 138 testErr := errors.New("test error") 139 done := make(chan struct{}) 140 go func() { 141 defer GinkgoRecover() 142 _, err := m.AcceptStream(context.Background()) 143 Expect(err).To(MatchError(testErr)) 144 close(done) 145 }() 146 Consistently(done).ShouldNot(BeClosed()) 147 m.CloseWithError(testErr) 148 Eventually(done).Should(BeClosed()) 149 }) 150 151 It("errors AcceptStream immediately if it is closed", func() { 152 testErr := errors.New("test error") 153 m.CloseWithError(testErr) 154 _, err := m.AcceptStream(context.Background()) 155 Expect(err).To(MatchError(testErr)) 156 }) 157 158 It("closes all streams when CloseWithError is called", func() { 159 str1, err := m.GetOrOpenStream(1) 160 Expect(err).ToNot(HaveOccurred()) 161 str2, err := m.GetOrOpenStream(3) 162 Expect(err).ToNot(HaveOccurred()) 163 testErr := errors.New("test err") 164 m.CloseWithError(testErr) 165 Expect(str1.(*mockGenericStream).closed).To(BeTrue()) 166 Expect(str1.(*mockGenericStream).closeErr).To(MatchError(testErr)) 167 Expect(str2.(*mockGenericStream).closed).To(BeTrue()) 168 Expect(str2.(*mockGenericStream).closeErr).To(MatchError(testErr)) 169 }) 170 171 It("deletes streams", func() { 172 mockSender.EXPECT().queueControlFrame(gomock.Any()) 173 _, err := m.GetOrOpenStream(1) 174 Expect(err).ToNot(HaveOccurred()) 175 str, err := m.AcceptStream(context.Background()) 176 Expect(err).ToNot(HaveOccurred()) 177 Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) 178 Expect(m.DeleteStream(1)).To(Succeed()) 179 str, err = m.GetOrOpenStream(1) 180 Expect(err).ToNot(HaveOccurred()) 181 Expect(str).To(BeNil()) 182 }) 183 184 It("waits until a stream is accepted before actually deleting it", func() { 185 _, err := m.GetOrOpenStream(2) 186 Expect(err).ToNot(HaveOccurred()) 187 Expect(m.DeleteStream(2)).To(Succeed()) 188 str, err := m.AcceptStream(context.Background()) 189 Expect(err).ToNot(HaveOccurred()) 190 Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) 191 // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued 192 mockSender.EXPECT().queueControlFrame(gomock.Any()) 193 str, err = m.AcceptStream(context.Background()) 194 Expect(err).ToNot(HaveOccurred()) 195 Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) 196 }) 197 198 It("doesn't return a stream queued for deleting from GetOrOpenStream", func() { 199 str, err := m.GetOrOpenStream(1) 200 Expect(err).ToNot(HaveOccurred()) 201 Expect(str).ToNot(BeNil()) 202 Expect(m.DeleteStream(1)).To(Succeed()) 203 str, err = m.GetOrOpenStream(1) 204 Expect(err).ToNot(HaveOccurred()) 205 Expect(str).To(BeNil()) 206 // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued 207 mockSender.EXPECT().queueControlFrame(gomock.Any()) 208 str, err = m.AcceptStream(context.Background()) 209 Expect(err).ToNot(HaveOccurred()) 210 Expect(str).ToNot(BeNil()) 211 }) 212 213 It("errors when deleting a non-existing stream", func() { 214 err := m.DeleteStream(1337) 215 Expect(err).To(HaveOccurred()) 216 Expect(err.(streamError).TestError()).To(MatchError("Tried to delete unknown incoming stream 1337")) 217 }) 218 219 It("sends MAX_STREAMS frames when streams are deleted", func() { 220 // open a bunch of streams 221 _, err := m.GetOrOpenStream(5) 222 Expect(err).ToNot(HaveOccurred()) 223 // accept all streams 224 for i := 0; i < 5; i++ { 225 _, err := m.AcceptStream(context.Background()) 226 Expect(err).ToNot(HaveOccurred()) 227 } 228 mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { 229 Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 1))) 230 checkFrameSerialization(f) 231 }) 232 Expect(m.DeleteStream(3)).To(Succeed()) 233 mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { 234 Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 2))) 235 checkFrameSerialization(f) 236 }) 237 Expect(m.DeleteStream(4)).To(Succeed()) 238 }) 239 240 Context("using high stream limits", func() { 241 BeforeEach(func() { maxNumStreams = uint64(protocol.MaxStreamCount) - 2 }) 242 243 It("doesn't send MAX_STREAMS frames if they would overflow 2^60 (the maximum stream count)", func() { 244 // open a bunch of streams 245 _, err := m.GetOrOpenStream(5) 246 Expect(err).ToNot(HaveOccurred()) 247 // accept all streams 248 for i := 0; i < 5; i++ { 249 _, err := m.AcceptStream(context.Background()) 250 Expect(err).ToNot(HaveOccurred()) 251 } 252 mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { 253 Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount - 1)) 254 checkFrameSerialization(f) 255 }) 256 Expect(m.DeleteStream(4)).To(Succeed()) 257 mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { 258 Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount)) 259 checkFrameSerialization(f) 260 }) 261 Expect(m.DeleteStream(3)).To(Succeed()) 262 // at this point, we can't increase the stream limit any further, so no more MAX_STREAMS frames will be sent 263 Expect(m.DeleteStream(2)).To(Succeed()) 264 Expect(m.DeleteStream(1)).To(Succeed()) 265 }) 266 }) 267 268 Context("randomized tests", func() { 269 const num = 1000 270 271 BeforeEach(func() { maxNumStreams = num }) 272 273 It("opens and accepts streams", func() { 274 rand.Seed(GinkgoRandomSeed()) 275 ids := make([]protocol.StreamNum, num) 276 for i := 0; i < num; i++ { 277 ids[i] = protocol.StreamNum(i + 1) 278 } 279 rand.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] }) 280 281 const timeout = 5 * time.Second 282 done := make(chan struct{}, 2) 283 go func() { 284 defer GinkgoRecover() 285 ctx, cancel := context.WithTimeout(context.Background(), timeout) 286 defer cancel() 287 for i := 0; i < num; i++ { 288 _, err := m.AcceptStream(ctx) 289 Expect(err).ToNot(HaveOccurred()) 290 } 291 done <- struct{}{} 292 }() 293 294 go func() { 295 defer GinkgoRecover() 296 for i := 0; i < num; i++ { 297 _, err := m.GetOrOpenStream(ids[i]) 298 Expect(err).ToNot(HaveOccurred()) 299 } 300 done <- struct{}{} 301 }() 302 303 Eventually(done, timeout*3/2).Should(Receive()) 304 Eventually(done, timeout*3/2).Should(Receive()) 305 }) 306 }) 307}) 308