1package quic 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "net" 8 9 "github.com/lucas-clemente/quic-go/internal/flowcontrol" 10 "github.com/lucas-clemente/quic-go/internal/protocol" 11 "github.com/lucas-clemente/quic-go/internal/qerr" 12 "github.com/lucas-clemente/quic-go/internal/wire" 13) 14 15type streamError struct { 16 message string 17 nums []protocol.StreamNum 18} 19 20func (e streamError) Error() string { 21 return e.message 22} 23 24func convertStreamError(err error, stype protocol.StreamType, pers protocol.Perspective) error { 25 strError, ok := err.(streamError) 26 if !ok { 27 return err 28 } 29 ids := make([]interface{}, len(strError.nums)) 30 for i, num := range strError.nums { 31 ids[i] = num.StreamID(stype, pers) 32 } 33 return fmt.Errorf(strError.Error(), ids...) 34} 35 36type streamOpenErr struct{ error } 37 38var _ net.Error = &streamOpenErr{} 39 40func (e streamOpenErr) Temporary() bool { return e.error == errTooManyOpenStreams } 41func (streamOpenErr) Timeout() bool { return false } 42 43// errTooManyOpenStreams is used internally by the outgoing streams maps. 44var errTooManyOpenStreams = errors.New("too many open streams") 45 46type streamsMap struct { 47 perspective protocol.Perspective 48 49 sender streamSender 50 newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController 51 52 outgoingBidiStreams *outgoingBidiStreamsMap 53 outgoingUniStreams *outgoingUniStreamsMap 54 incomingBidiStreams *incomingBidiStreamsMap 55 incomingUniStreams *incomingUniStreamsMap 56} 57 58var _ streamManager = &streamsMap{} 59 60func newStreamsMap( 61 sender streamSender, 62 newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController, 63 maxIncomingBidiStreams uint64, 64 maxIncomingUniStreams uint64, 65 perspective protocol.Perspective, 66 version protocol.VersionNumber, 67) streamManager { 68 m := &streamsMap{ 69 perspective: perspective, 70 newFlowController: newFlowController, 71 sender: sender, 72 } 73 m.outgoingBidiStreams = newOutgoingBidiStreamsMap( 74 func(num protocol.StreamNum) streamI { 75 id := num.StreamID(protocol.StreamTypeBidi, perspective) 76 return newStream(id, m.sender, m.newFlowController(id), version) 77 }, 78 sender.queueControlFrame, 79 ) 80 m.incomingBidiStreams = newIncomingBidiStreamsMap( 81 func(num protocol.StreamNum) streamI { 82 id := num.StreamID(protocol.StreamTypeBidi, perspective.Opposite()) 83 return newStream(id, m.sender, m.newFlowController(id), version) 84 }, 85 maxIncomingBidiStreams, 86 sender.queueControlFrame, 87 ) 88 m.outgoingUniStreams = newOutgoingUniStreamsMap( 89 func(num protocol.StreamNum) sendStreamI { 90 id := num.StreamID(protocol.StreamTypeUni, perspective) 91 return newSendStream(id, m.sender, m.newFlowController(id), version) 92 }, 93 sender.queueControlFrame, 94 ) 95 m.incomingUniStreams = newIncomingUniStreamsMap( 96 func(num protocol.StreamNum) receiveStreamI { 97 id := num.StreamID(protocol.StreamTypeUni, perspective.Opposite()) 98 return newReceiveStream(id, m.sender, m.newFlowController(id), version) 99 }, 100 maxIncomingUniStreams, 101 sender.queueControlFrame, 102 ) 103 return m 104} 105 106func (m *streamsMap) OpenStream() (Stream, error) { 107 str, err := m.outgoingBidiStreams.OpenStream() 108 return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) 109} 110 111func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) { 112 str, err := m.outgoingBidiStreams.OpenStreamSync(ctx) 113 return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) 114} 115 116func (m *streamsMap) OpenUniStream() (SendStream, error) { 117 str, err := m.outgoingUniStreams.OpenStream() 118 return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) 119} 120 121func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) { 122 str, err := m.outgoingUniStreams.OpenStreamSync(ctx) 123 return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) 124} 125 126func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) { 127 str, err := m.incomingBidiStreams.AcceptStream(ctx) 128 return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite()) 129} 130 131func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) { 132 str, err := m.incomingUniStreams.AcceptStream(ctx) 133 return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite()) 134} 135 136func (m *streamsMap) DeleteStream(id protocol.StreamID) error { 137 num := id.StreamNum() 138 switch id.Type() { 139 case protocol.StreamTypeUni: 140 if id.InitiatedBy() == m.perspective { 141 return convertStreamError(m.outgoingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective) 142 } 143 return convertStreamError(m.incomingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective.Opposite()) 144 case protocol.StreamTypeBidi: 145 if id.InitiatedBy() == m.perspective { 146 return convertStreamError(m.outgoingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective) 147 } 148 return convertStreamError(m.incomingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective.Opposite()) 149 } 150 panic("") 151} 152 153func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { 154 str, err := m.getOrOpenReceiveStream(id) 155 if err != nil { 156 return nil, qerr.NewError(qerr.StreamStateError, err.Error()) 157 } 158 return str, nil 159} 160 161func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { 162 num := id.StreamNum() 163 switch id.Type() { 164 case protocol.StreamTypeUni: 165 if id.InitiatedBy() == m.perspective { 166 // an outgoing unidirectional stream is a send stream, not a receive stream 167 return nil, fmt.Errorf("peer attempted to open receive stream %d", id) 168 } 169 str, err := m.incomingUniStreams.GetOrOpenStream(num) 170 return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) 171 case protocol.StreamTypeBidi: 172 var str receiveStreamI 173 var err error 174 if id.InitiatedBy() == m.perspective { 175 str, err = m.outgoingBidiStreams.GetStream(num) 176 } else { 177 str, err = m.incomingBidiStreams.GetOrOpenStream(num) 178 } 179 return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy()) 180 } 181 panic("") 182} 183 184func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { 185 str, err := m.getOrOpenSendStream(id) 186 if err != nil { 187 return nil, qerr.NewError(qerr.StreamStateError, err.Error()) 188 } 189 return str, nil 190} 191 192func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { 193 num := id.StreamNum() 194 switch id.Type() { 195 case protocol.StreamTypeUni: 196 if id.InitiatedBy() == m.perspective { 197 str, err := m.outgoingUniStreams.GetStream(num) 198 return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) 199 } 200 // an incoming unidirectional stream is a receive stream, not a send stream 201 return nil, fmt.Errorf("peer attempted to open send stream %d", id) 202 case protocol.StreamTypeBidi: 203 var str sendStreamI 204 var err error 205 if id.InitiatedBy() == m.perspective { 206 str, err = m.outgoingBidiStreams.GetStream(num) 207 } else { 208 str, err = m.incomingBidiStreams.GetOrOpenStream(num) 209 } 210 return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy()) 211 } 212 panic("") 213} 214 215func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) error { 216 switch f.Type { 217 case protocol.StreamTypeUni: 218 m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum) 219 case protocol.StreamTypeBidi: 220 m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum) 221 } 222 return nil 223} 224 225func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) error { 226 if p.MaxBidiStreamNum > protocol.MaxStreamCount || 227 p.MaxUniStreamNum > protocol.MaxStreamCount { 228 return qerr.StreamLimitError 229 } 230 // Max{Uni,Bidi}StreamID returns the highest stream ID that the peer is allowed to open. 231 m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum) 232 m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum) 233 return nil 234} 235 236func (m *streamsMap) CloseWithError(err error) { 237 m.outgoingBidiStreams.CloseWithError(err) 238 m.outgoingUniStreams.CloseWithError(err) 239 m.incomingBidiStreams.CloseWithError(err) 240 m.incomingUniStreams.CloseWithError(err) 241} 242