1package quic 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "net" 8 "sync" 9 10 "github.com/lucas-clemente/quic-go/internal/flowcontrol" 11 "github.com/lucas-clemente/quic-go/internal/protocol" 12 "github.com/lucas-clemente/quic-go/internal/qerr" 13 "github.com/lucas-clemente/quic-go/internal/wire" 14) 15 16type streamError struct { 17 message string 18 nums []protocol.StreamNum 19} 20 21func (e streamError) Error() string { 22 return e.message 23} 24 25func convertStreamError(err error, stype protocol.StreamType, pers protocol.Perspective) error { 26 strError, ok := err.(streamError) 27 if !ok { 28 return err 29 } 30 ids := make([]interface{}, len(strError.nums)) 31 for i, num := range strError.nums { 32 ids[i] = num.StreamID(stype, pers) 33 } 34 return fmt.Errorf(strError.Error(), ids...) 35} 36 37type streamOpenErr struct{ error } 38 39var _ net.Error = &streamOpenErr{} 40 41func (e streamOpenErr) Temporary() bool { return e.error == errTooManyOpenStreams } 42func (streamOpenErr) Timeout() bool { return false } 43 44// errTooManyOpenStreams is used internally by the outgoing streams maps. 45var errTooManyOpenStreams = errors.New("too many open streams") 46 47type streamsMap struct { 48 perspective protocol.Perspective 49 version protocol.VersionNumber 50 51 maxIncomingBidiStreams uint64 52 maxIncomingUniStreams uint64 53 54 sender streamSender 55 newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController 56 57 mutex sync.Mutex 58 outgoingBidiStreams *outgoingBidiStreamsMap 59 outgoingUniStreams *outgoingUniStreamsMap 60 incomingBidiStreams *incomingBidiStreamsMap 61 incomingUniStreams *incomingUniStreamsMap 62 reset bool 63} 64 65var _ streamManager = &streamsMap{} 66 67func newStreamsMap( 68 sender streamSender, 69 newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController, 70 maxIncomingBidiStreams uint64, 71 maxIncomingUniStreams uint64, 72 perspective protocol.Perspective, 73 version protocol.VersionNumber, 74) streamManager { 75 m := &streamsMap{ 76 perspective: perspective, 77 newFlowController: newFlowController, 78 maxIncomingBidiStreams: maxIncomingBidiStreams, 79 maxIncomingUniStreams: maxIncomingUniStreams, 80 sender: sender, 81 version: version, 82 } 83 m.initMaps() 84 return m 85} 86 87func (m *streamsMap) initMaps() { 88 m.outgoingBidiStreams = newOutgoingBidiStreamsMap( 89 func(num protocol.StreamNum) streamI { 90 id := num.StreamID(protocol.StreamTypeBidi, m.perspective) 91 return newStream(id, m.sender, m.newFlowController(id), m.version) 92 }, 93 m.sender.queueControlFrame, 94 ) 95 m.incomingBidiStreams = newIncomingBidiStreamsMap( 96 func(num protocol.StreamNum) streamI { 97 id := num.StreamID(protocol.StreamTypeBidi, m.perspective.Opposite()) 98 return newStream(id, m.sender, m.newFlowController(id), m.version) 99 }, 100 m.maxIncomingBidiStreams, 101 m.sender.queueControlFrame, 102 ) 103 m.outgoingUniStreams = newOutgoingUniStreamsMap( 104 func(num protocol.StreamNum) sendStreamI { 105 id := num.StreamID(protocol.StreamTypeUni, m.perspective) 106 return newSendStream(id, m.sender, m.newFlowController(id), m.version) 107 }, 108 m.sender.queueControlFrame, 109 ) 110 m.incomingUniStreams = newIncomingUniStreamsMap( 111 func(num protocol.StreamNum) receiveStreamI { 112 id := num.StreamID(protocol.StreamTypeUni, m.perspective.Opposite()) 113 return newReceiveStream(id, m.sender, m.newFlowController(id), m.version) 114 }, 115 m.maxIncomingUniStreams, 116 m.sender.queueControlFrame, 117 ) 118} 119 120func (m *streamsMap) OpenStream() (Stream, error) { 121 m.mutex.Lock() 122 reset := m.reset 123 mm := m.outgoingBidiStreams 124 m.mutex.Unlock() 125 if reset { 126 return nil, Err0RTTRejected 127 } 128 str, err := mm.OpenStream() 129 return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) 130} 131 132func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) { 133 m.mutex.Lock() 134 reset := m.reset 135 mm := m.outgoingBidiStreams 136 m.mutex.Unlock() 137 if reset { 138 return nil, Err0RTTRejected 139 } 140 str, err := mm.OpenStreamSync(ctx) 141 return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) 142} 143 144func (m *streamsMap) OpenUniStream() (SendStream, error) { 145 m.mutex.Lock() 146 reset := m.reset 147 mm := m.outgoingUniStreams 148 m.mutex.Unlock() 149 if reset { 150 return nil, Err0RTTRejected 151 } 152 str, err := mm.OpenStream() 153 return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) 154} 155 156func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) { 157 m.mutex.Lock() 158 reset := m.reset 159 mm := m.outgoingUniStreams 160 m.mutex.Unlock() 161 if reset { 162 return nil, Err0RTTRejected 163 } 164 str, err := mm.OpenStreamSync(ctx) 165 return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) 166} 167 168func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) { 169 m.mutex.Lock() 170 reset := m.reset 171 mm := m.incomingBidiStreams 172 m.mutex.Unlock() 173 if reset { 174 return nil, Err0RTTRejected 175 } 176 str, err := mm.AcceptStream(ctx) 177 return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite()) 178} 179 180func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) { 181 m.mutex.Lock() 182 reset := m.reset 183 mm := m.incomingUniStreams 184 m.mutex.Unlock() 185 if reset { 186 return nil, Err0RTTRejected 187 } 188 str, err := mm.AcceptStream(ctx) 189 return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite()) 190} 191 192func (m *streamsMap) DeleteStream(id protocol.StreamID) error { 193 num := id.StreamNum() 194 switch id.Type() { 195 case protocol.StreamTypeUni: 196 if id.InitiatedBy() == m.perspective { 197 return convertStreamError(m.outgoingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective) 198 } 199 return convertStreamError(m.incomingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective.Opposite()) 200 case protocol.StreamTypeBidi: 201 if id.InitiatedBy() == m.perspective { 202 return convertStreamError(m.outgoingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective) 203 } 204 return convertStreamError(m.incomingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective.Opposite()) 205 } 206 panic("") 207} 208 209func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { 210 str, err := m.getOrOpenReceiveStream(id) 211 if err != nil { 212 return nil, &qerr.TransportError{ 213 ErrorCode: qerr.StreamStateError, 214 ErrorMessage: err.Error(), 215 } 216 } 217 return str, nil 218} 219 220func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { 221 num := id.StreamNum() 222 switch id.Type() { 223 case protocol.StreamTypeUni: 224 if id.InitiatedBy() == m.perspective { 225 // an outgoing unidirectional stream is a send stream, not a receive stream 226 return nil, fmt.Errorf("peer attempted to open receive stream %d", id) 227 } 228 str, err := m.incomingUniStreams.GetOrOpenStream(num) 229 return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) 230 case protocol.StreamTypeBidi: 231 var str receiveStreamI 232 var err error 233 if id.InitiatedBy() == m.perspective { 234 str, err = m.outgoingBidiStreams.GetStream(num) 235 } else { 236 str, err = m.incomingBidiStreams.GetOrOpenStream(num) 237 } 238 return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy()) 239 } 240 panic("") 241} 242 243func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { 244 str, err := m.getOrOpenSendStream(id) 245 if err != nil { 246 return nil, &qerr.TransportError{ 247 ErrorCode: qerr.StreamStateError, 248 ErrorMessage: err.Error(), 249 } 250 } 251 return str, nil 252} 253 254func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { 255 num := id.StreamNum() 256 switch id.Type() { 257 case protocol.StreamTypeUni: 258 if id.InitiatedBy() == m.perspective { 259 str, err := m.outgoingUniStreams.GetStream(num) 260 return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) 261 } 262 // an incoming unidirectional stream is a receive stream, not a send stream 263 return nil, fmt.Errorf("peer attempted to open send stream %d", id) 264 case protocol.StreamTypeBidi: 265 var str sendStreamI 266 var err error 267 if id.InitiatedBy() == m.perspective { 268 str, err = m.outgoingBidiStreams.GetStream(num) 269 } else { 270 str, err = m.incomingBidiStreams.GetOrOpenStream(num) 271 } 272 return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy()) 273 } 274 panic("") 275} 276 277func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) { 278 switch f.Type { 279 case protocol.StreamTypeUni: 280 m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum) 281 case protocol.StreamTypeBidi: 282 m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum) 283 } 284} 285 286func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) { 287 m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote) 288 m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum) 289 m.outgoingUniStreams.UpdateSendWindow(p.InitialMaxStreamDataUni) 290 m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum) 291} 292 293func (m *streamsMap) CloseWithError(err error) { 294 m.outgoingBidiStreams.CloseWithError(err) 295 m.outgoingUniStreams.CloseWithError(err) 296 m.incomingBidiStreams.CloseWithError(err) 297 m.incomingUniStreams.CloseWithError(err) 298} 299 300// ResetFor0RTT resets is used when 0-RTT is rejected. In that case, the streams maps are 301// 1. closed with an Err0RTTRejected, making calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream return that error. 302// 2. reset to their initial state, such that we can immediately process new incoming stream data. 303// Afterwards, calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream will continue to return the error, 304// until UseResetMaps() has been called. 305func (m *streamsMap) ResetFor0RTT() { 306 m.mutex.Lock() 307 defer m.mutex.Unlock() 308 m.reset = true 309 m.CloseWithError(Err0RTTRejected) 310 m.initMaps() 311} 312 313func (m *streamsMap) UseResetMaps() { 314 m.mutex.Lock() 315 m.reset = false 316 m.mutex.Unlock() 317} 318