1package quic
2
3import (
4	"context"
5	"fmt"
6	"sync"
7	"time"
8
9	"github.com/lucas-clemente/quic-go/internal/ackhandler"
10
11	"github.com/lucas-clemente/quic-go/internal/flowcontrol"
12	"github.com/lucas-clemente/quic-go/internal/protocol"
13	"github.com/lucas-clemente/quic-go/internal/utils"
14	"github.com/lucas-clemente/quic-go/internal/wire"
15)
16
17type sendStreamI interface {
18	SendStream
19	handleStopSendingFrame(*wire.StopSendingFrame)
20	hasData() bool
21	popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool)
22	closeForShutdown(error)
23	handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
24}
25
// sendStream implements the sending direction of a QUIC stream.
//
// mutex guards the mutable state below. streamID, sender, flowController,
// writeChan, version and ctx are assigned once in newSendStream and are read
// without holding the mutex.
type sendStream struct {
	mutex sync.Mutex

	// numOutstandingFrames counts frames handed out by popStreamFrame that have
	// been neither acknowledged (frameAcked) nor declared lost (queueRetransmission).
	// retransmissionQueue holds lost frames waiting to be sent again.
	numOutstandingFrames int64
	retransmissionQueue  []*wire.StreamFrame

	// ctx is canceled by Close, by canceling the write side (CancelWrite or a
	// STOP_SENDING frame), and by closeForShutdown.
	ctx       context.Context
	ctxCancel context.CancelFunc

	streamID protocol.StreamID
	sender   streamSender

	// writeOffset is the number of bytes handed out in new STREAM frames so far,
	// i.e. the offset at which the next new frame starts.
	writeOffset protocol.ByteCount

	// Errors returned from Write after the write side was canceled or the
	// stream was closed for shutdown.
	cancelWriteErr      error
	closeForShutdownErr error

	closedForShutdown bool // set when closeForShutdown() is called
	finishedWriting   bool // set once Close() is called
	canceledWrite     bool // set when CancelWrite() is called, or a STOP_SENDING frame is received
	finSent           bool // set when a STREAM frame with FIN bit has been sent
	completed         bool // set when this stream has been reported to the streamSender as completed

	// nextFrame buffers data that Write has already copied out of the caller's
	// slice; it is sent before any remaining dataForWriting.
	dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out
	nextFrame      *wire.StreamFrame

	// writeChan (capacity 1) wakes up a blocked Write, see signalWrite.
	// deadline is the write deadline set by SetWriteDeadline (zero: none).
	writeChan chan struct{}
	deadline  time.Time

	flowController flowcontrol.StreamFlowController

	version protocol.VersionNumber
}
59
60var (
61	_ SendStream  = &sendStream{}
62	_ sendStreamI = &sendStream{}
63)
64
// newSendStream creates the send side of the stream with the given ID.
// The stream's context derives from context.Background and is canceled once
// writing on the stream ends (Close, write cancellation or closeForShutdown).
func newSendStream(
	streamID protocol.StreamID,
	sender streamSender,
	flowController flowcontrol.StreamFlowController,
	version protocol.VersionNumber,
) *sendStream {
	ctx, cancel := context.WithCancel(context.Background())
	return &sendStream{
		streamID:       streamID,
		sender:         sender,
		flowController: flowController,
		version:        version,
		writeChan:      make(chan struct{}, 1), // capacity 1: see signalWrite
		ctx:            ctx,
		ctxCancel:      cancel,
	}
}
81
82func (s *sendStream) StreamID() protocol.StreamID {
83	return s.streamID // same for receiveStream and sendStream
84}
85
// Write implements io.Writer.
//
// It blocks until all of p has been handed out in STREAM frames or copied into
// the buffered frame s.nextFrame, until the write deadline expires, or until
// the write side is canceled / the stream is closed for shutdown.
// It returns the number of bytes consumed from p. If fewer than len(p) bytes
// were consumed, the error is errDeadline, the shutdown error or the
// cancellation error. Write never keeps a reference to p after returning.
func (s *sendStream) Write(p []byte) (int, error) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.finishedWriting {
		return 0, fmt.Errorf("write on closed stream %d", s.streamID)
	}
	if s.canceledWrite {
		return 0, s.cancelWriteErr
	}
	if s.closeForShutdownErr != nil {
		return 0, s.closeForShutdownErr
	}
	if !s.deadline.IsZero() && !time.Now().Before(s.deadline) {
		return 0, errDeadline
	}
	if len(p) == 0 {
		return 0, nil
	}

	s.dataForWriting = p

	var (
		deadlineTimer  *utils.Timer
		bytesWritten   int
		notifiedSender bool
	)
	for {
		var copied bool
		var deadline time.Time
		// As soon as dataForWriting becomes smaller than a certain size x, we copy all the data to a STREAM frame (s.nextFrame),
		// which can then be popped the next time we assemble a packet.
		// This allows us to return Write() when all data but x bytes have been sent out.
		// When the user now calls Close(), this is much more likely to happen before we popped that last STREAM frame,
		// allowing us to set the FIN bit on that frame (instead of sending an empty STREAM frame with FIN).
		if s.canBufferStreamFrame() && len(s.dataForWriting) > 0 {
			if s.nextFrame == nil {
				f := wire.GetStreamFrame()
				f.Offset = s.writeOffset
				f.StreamID = s.streamID
				f.DataLenPresent = true
				f.Data = f.Data[:len(s.dataForWriting)]
				copy(f.Data, s.dataForWriting)
				s.nextFrame = f
			} else {
				l := len(s.nextFrame.Data)
				s.nextFrame.Data = s.nextFrame.Data[:l+len(s.dataForWriting)]
				copy(s.nextFrame.Data[l:], s.dataForWriting)
			}
			s.dataForWriting = nil
			bytesWritten = len(p)
			copied = true
		} else {
			bytesWritten = len(p) - len(s.dataForWriting)
			deadline = s.deadline
			if !deadline.IsZero() {
				if !time.Now().Before(deadline) {
					s.dataForWriting = nil
					return bytesWritten, errDeadline
				}
				if deadlineTimer == nil {
					deadlineTimer = utils.NewTimer()
					defer deadlineTimer.Stop()
				}
				deadlineTimer.Reset(deadline)
			}
			if s.dataForWriting == nil || s.canceledWrite || s.closedForShutdown {
				break
			}
		}

		s.mutex.Unlock()
		if !notifiedSender {
			s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex
			notifiedSender = true
		}
		if copied {
			s.mutex.Lock()
			break
		}
		if deadline.IsZero() {
			<-s.writeChan
		} else {
			select {
			case <-s.writeChan:
			case <-deadlineTimer.Chan():
				deadlineTimer.SetRead()
			}
		}
		s.mutex.Lock()
	}

	if bytesWritten == len(p) {
		return bytesWritten, nil
	}
	// We only get here if the write side was canceled or the stream was closed
	// for shutdown while part of p was still pending. That data will never be
	// sent, so drop the reference to the caller's buffer (io.Writer must not
	// retain p), just like the deadline path above does.
	s.dataForWriting = nil
	if s.closeForShutdownErr != nil {
		return bytesWritten, s.closeForShutdownErr
	} else if s.cancelWriteErr != nil {
		return bytesWritten, s.cancelWriteErr
	}
	return bytesWritten, nil
}
188
189func (s *sendStream) canBufferStreamFrame() bool {
190	var l protocol.ByteCount
191	if s.nextFrame != nil {
192		l = s.nextFrame.DataLen()
193	}
194	return l+protocol.ByteCount(len(s.dataForWriting)) <= protocol.MaxReceivePacketSize
195}
196
197// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
198// maxBytes is the maximum length this frame (including frame header) will have.
199func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool /* has more data to send */) {
200	s.mutex.Lock()
201	f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes)
202	if f != nil {
203		s.numOutstandingFrames++
204	}
205	s.mutex.Unlock()
206
207	if f == nil {
208		return nil, hasMoreData
209	}
210	return &ackhandler.Frame{Frame: f, OnLost: s.queueRetransmission, OnAcked: s.frameAcked}, hasMoreData
211}
212
213func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more data to send */) {
214	if s.canceledWrite || s.closeForShutdownErr != nil {
215		return nil, false
216	}
217
218	if len(s.retransmissionQueue) > 0 {
219		f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes)
220		if f != nil || hasMoreRetransmissions {
221			if f == nil {
222				return nil, true
223			}
224			// We always claim that we have more data to send.
225			// This might be incorrect, in which case there'll be a spurious call to popStreamFrame in the future.
226			return f, true
227		}
228	}
229
230	if len(s.dataForWriting) == 0 && s.nextFrame == nil {
231		if s.finishedWriting && !s.finSent {
232			s.finSent = true
233			return &wire.StreamFrame{
234				StreamID:       s.streamID,
235				Offset:         s.writeOffset,
236				DataLenPresent: true,
237				Fin:            true,
238			}, false
239		}
240		return nil, false
241	}
242
243	sendWindow := s.flowController.SendWindowSize()
244	if sendWindow == 0 {
245		if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked {
246			s.sender.queueControlFrame(&wire.StreamDataBlockedFrame{
247				StreamID:          s.streamID,
248				MaximumStreamData: offset,
249			})
250			return nil, false
251		}
252		return nil, true
253	}
254
255	f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow)
256	if dataLen := f.DataLen(); dataLen > 0 {
257		s.writeOffset += f.DataLen()
258		s.flowController.AddBytesSent(f.DataLen())
259	}
260	f.Fin = s.finishedWriting && s.dataForWriting == nil && s.nextFrame == nil && !s.finSent
261	if f.Fin {
262		s.finSent = true
263	}
264	return f, hasMoreData
265}
266
267func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount) (*wire.StreamFrame, bool) {
268	if s.nextFrame != nil {
269		nextFrame := s.nextFrame
270		s.nextFrame = nil
271
272		maxDataLen := utils.MinByteCount(sendWindow, nextFrame.MaxDataLen(maxBytes, s.version))
273		if nextFrame.DataLen() > maxDataLen {
274			s.nextFrame = wire.GetStreamFrame()
275			s.nextFrame.StreamID = s.streamID
276			s.nextFrame.Offset = s.writeOffset + maxDataLen
277			s.nextFrame.Data = s.nextFrame.Data[:nextFrame.DataLen()-maxDataLen]
278			s.nextFrame.DataLenPresent = true
279			copy(s.nextFrame.Data, nextFrame.Data[maxDataLen:])
280			nextFrame.Data = nextFrame.Data[:maxDataLen]
281		} else {
282			s.signalWrite()
283		}
284		return nextFrame, s.nextFrame != nil || s.dataForWriting != nil
285	}
286
287	f := wire.GetStreamFrame()
288	f.Fin = false
289	f.StreamID = s.streamID
290	f.Offset = s.writeOffset
291	f.DataLenPresent = true
292	f.Data = f.Data[:0]
293
294	hasMoreData := s.popNewStreamFrameWithoutBuffer(f, maxBytes, sendWindow)
295	if len(f.Data) == 0 && !f.Fin {
296		f.PutBack()
297		return nil, hasMoreData
298	}
299	return f, hasMoreData
300}
301
302func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount) bool {
303	maxDataLen := f.MaxDataLen(maxBytes, s.version)
304	if maxDataLen == 0 { // a STREAM frame must have at least one byte of data
305		return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
306	}
307	s.getDataForWriting(f, utils.MinByteCount(maxDataLen, sendWindow))
308
309	return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
310}
311
312func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more retransmissions */) {
313	f := s.retransmissionQueue[0]
314	newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, s.version)
315	if needsSplit {
316		return newFrame, true
317	}
318	s.retransmissionQueue = s.retransmissionQueue[1:]
319	return f, len(s.retransmissionQueue) > 0
320}
321
322func (s *sendStream) hasData() bool {
323	s.mutex.Lock()
324	hasData := len(s.dataForWriting) > 0
325	s.mutex.Unlock()
326	return hasData
327}
328
329func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.ByteCount) {
330	if protocol.ByteCount(len(s.dataForWriting)) <= maxBytes {
331		f.Data = f.Data[:len(s.dataForWriting)]
332		copy(f.Data, s.dataForWriting)
333		s.dataForWriting = nil
334		s.signalWrite()
335		return
336	}
337	f.Data = f.Data[:maxBytes]
338	copy(f.Data, s.dataForWriting)
339	s.dataForWriting = s.dataForWriting[maxBytes:]
340	if s.canBufferStreamFrame() {
341		s.signalWrite()
342	}
343}
344
345func (s *sendStream) frameAcked(f wire.Frame) {
346	f.(*wire.StreamFrame).PutBack()
347
348	s.mutex.Lock()
349	s.numOutstandingFrames--
350	if s.numOutstandingFrames < 0 {
351		panic("numOutStandingFrames negative")
352	}
353	newlyCompleted := s.isNewlyCompleted()
354	s.mutex.Unlock()
355
356	if newlyCompleted {
357		s.sender.onStreamCompleted(s.streamID)
358	}
359}
360
361func (s *sendStream) isNewlyCompleted() bool {
362	completed := (s.finSent || s.canceledWrite) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0
363	if completed && !s.completed {
364		s.completed = true
365		return true
366	}
367	return false
368}
369
370func (s *sendStream) queueRetransmission(f wire.Frame) {
371	sf := f.(*wire.StreamFrame)
372	sf.DataLenPresent = true
373	s.mutex.Lock()
374	s.retransmissionQueue = append(s.retransmissionQueue, sf)
375	s.numOutstandingFrames--
376	if s.numOutstandingFrames < 0 {
377		panic("numOutStandingFrames negative")
378	}
379	s.mutex.Unlock()
380
381	s.sender.onHasStreamData(s.streamID)
382}
383
// Close closes the write direction of the stream. No more data may be written,
// and a FIN is sent once all pending data has been sent. It cancels the
// stream's context.
// Close is a no-op on a stream that was closed for shutdown. It returns an
// error if the write side was canceled.
func (s *sendStream) Close() error {
	s.mutex.Lock()
	switch {
	case s.closedForShutdown:
		s.mutex.Unlock()
		return nil
	case s.canceledWrite:
		s.mutex.Unlock()
		return fmt.Errorf("close called for canceled stream %d", s.streamID)
	}
	s.ctxCancel()
	s.finishedWriting = true
	s.mutex.Unlock()

	// Get the FIN sent. onHasStreamData must be called without holding the mutex.
	s.sender.onHasStreamData(s.streamID)
	return nil
}
401
// CancelWrite aborts sending on this stream with the given application error code.
// A RESET_STREAM frame is queued, and both blocked and future Writes fail.
func (s *sendStream) CancelWrite(errorCode protocol.ApplicationErrorCode) {
	err := fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode)
	s.cancelWriteImpl(errorCode, err)
}
405
406// must be called after locking the mutex
407func (s *sendStream) cancelWriteImpl(errorCode protocol.ApplicationErrorCode, writeErr error) {
408	s.mutex.Lock()
409	if s.canceledWrite {
410		s.mutex.Unlock()
411		return
412	}
413	s.ctxCancel()
414	s.canceledWrite = true
415	s.cancelWriteErr = writeErr
416	newlyCompleted := s.isNewlyCompleted()
417	s.mutex.Unlock()
418
419	s.signalWrite()
420	s.sender.queueControlFrame(&wire.ResetStreamFrame{
421		StreamID:  s.streamID,
422		FinalSize: s.writeOffset,
423		ErrorCode: errorCode,
424	})
425	if newlyCompleted {
426		s.sender.onStreamCompleted(s.streamID)
427	}
428}
429
430func (s *sendStream) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) {
431	s.mutex.Lock()
432	hasStreamData := s.dataForWriting != nil || s.nextFrame != nil
433	s.mutex.Unlock()
434
435	s.flowController.UpdateSendWindow(frame.MaximumStreamData)
436	if hasStreamData {
437		s.sender.onHasStreamData(s.streamID)
438	}
439}
440
// handleStopSendingFrame handles a STOP_SENDING frame from the peer by canceling
// the write side with the peer's error code. Unless the write side was already
// canceled, Write then fails with a streamCanceledError carrying that code.
func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
	code := frame.ErrorCode
	s.cancelWriteImpl(code, streamCanceledError{
		errorCode: code,
		error:     fmt.Errorf("stream %d was reset with error code %d", s.streamID, code),
	})
}
448
// Context returns the stream's context. It is canceled when the write side is
// closed (Close), canceled (CancelWrite / STOP_SENDING) or closed for shutdown.
func (s *sendStream) Context() context.Context {
	ctx := s.ctx
	return ctx
}
452
// SetWriteDeadline sets the deadline for future and currently blocked Write calls.
// A zero t means Write does not time out. It always returns nil.
func (s *sendStream) SetWriteDeadline(t time.Time) error {
	func() {
		s.mutex.Lock()
		defer s.mutex.Unlock()
		s.deadline = t
	}()
	// Wake up a blocked Write so it picks up the new deadline.
	s.signalWrite()
	return nil
}
460
461// CloseForShutdown closes a stream abruptly.
462// It makes Write unblock (and return the error) immediately.
463// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
464func (s *sendStream) closeForShutdown(err error) {
465	s.mutex.Lock()
466	s.ctxCancel()
467	s.closedForShutdown = true
468	s.closeForShutdownErr = err
469	s.mutex.Unlock()
470	s.signalWrite()
471}
472
473// signalWrite performs a non-blocking send on the writeChan
474func (s *sendStream) signalWrite() {
475	select {
476	case s.writeChan <- struct{}{}:
477	default:
478	}
479}
480