1package quic
2
3import (
4	"errors"
5	"io"
6	"os"
7	"strconv"
8	"time"
9
10	"github.com/lucas-clemente/quic-go/internal/mocks"
11	"github.com/lucas-clemente/quic-go/internal/protocol"
12	"github.com/lucas-clemente/quic-go/internal/wire"
13	. "github.com/onsi/ginkgo"
14	. "github.com/onsi/gomega"
15	"github.com/onsi/gomega/gbytes"
16)
17
18// in the tests for the stream deadlines we set a deadline
19// and wait to make an assertion when Read / Write was unblocked
20// on the CIs, the timing is a lot less precise, so scale every duration by this factor
21func scaleDuration(t time.Duration) time.Duration {
22	scaleFactor := 1
23	if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set
24		scaleFactor = f
25	}
26	Expect(scaleFactor).ToNot(BeZero())
27	return time.Duration(scaleFactor) * t
28}
29
30var _ = Describe("Stream", func() {
31	const streamID protocol.StreamID = 1337
32
33	var (
34		str            *stream
35		strWithTimeout io.ReadWriter // str wrapped with gbytes.Timeout{Reader,Writer}
36		mockFC         *mocks.MockStreamFlowController
37		mockSender     *MockStreamSender
38	)
39
40	BeforeEach(func() {
41		mockSender = NewMockStreamSender(mockCtrl)
42		mockFC = mocks.NewMockStreamFlowController(mockCtrl)
43		str = newStream(streamID, mockSender, mockFC, protocol.VersionWhatever)
44
45		timeout := scaleDuration(250 * time.Millisecond)
46		strWithTimeout = struct {
47			io.Reader
48			io.Writer
49		}{
50			gbytes.TimeoutReader(str, timeout),
51			gbytes.TimeoutWriter(str, timeout),
52		}
53	})
54
55	It("gets stream id", func() {
56		Expect(str.StreamID()).To(Equal(protocol.StreamID(1337)))
57	})
58
59	Context("deadlines", func() {
60		It("sets a write deadline, when SetDeadline is called", func() {
61			str.SetDeadline(time.Now().Add(-time.Second))
62			n, err := strWithTimeout.Write([]byte("foobar"))
63			Expect(err).To(MatchError(errDeadline))
64			Expect(n).To(BeZero())
65		})
66
67		It("sets a read deadline, when SetDeadline is called", func() {
68			mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false).AnyTimes()
69			f := &wire.StreamFrame{Data: []byte("foobar")}
70			err := str.handleStreamFrame(f)
71			Expect(err).ToNot(HaveOccurred())
72			str.SetDeadline(time.Now().Add(-time.Second))
73			b := make([]byte, 6)
74			n, err := strWithTimeout.Read(b)
75			Expect(err).To(MatchError(errDeadline))
76			Expect(n).To(BeZero())
77		})
78	})
79
80	Context("completing", func() {
81		It("is not completed when only the receive side is completed", func() {
82			// don't EXPECT a call to mockSender.onStreamCompleted()
83			str.receiveStream.sender.onStreamCompleted(streamID)
84		})
85
86		It("is not completed when only the send side is completed", func() {
87			// don't EXPECT a call to mockSender.onStreamCompleted()
88			str.sendStream.sender.onStreamCompleted(streamID)
89		})
90
91		It("is completed when both sides are completed", func() {
92			mockSender.EXPECT().onStreamCompleted(streamID)
93			str.sendStream.sender.onStreamCompleted(streamID)
94			str.receiveStream.sender.onStreamCompleted(streamID)
95		})
96	})
97})
98
99var _ = Describe("Deadline Error", func() {
100	It("is a net.Error that wraps os.ErrDeadlineError", func() {
101		err := deadlineError{}
102		Expect(err.Temporary()).To(BeTrue())
103		Expect(err.Timeout()).To(BeTrue())
104		Expect(errors.Is(err, os.ErrDeadlineExceeded)).To(BeTrue())
105		Expect(errors.Unwrap(err)).To(Equal(os.ErrDeadlineExceeded))
106	})
107})
108