1package quic
2
3import (
4	"fmt"
5
6	"github.com/lucas-clemente/quic-go/internal/protocol"
7	"github.com/lucas-clemente/quic-go/internal/qerr"
8	"github.com/lucas-clemente/quic-go/internal/wire"
9
10	. "github.com/onsi/ginkgo"
11	. "github.com/onsi/gomega"
12)
13
14var _ = Describe("Connection ID Generator", func() {
15	var (
16		addedConnIDs       []protocol.ConnectionID
17		retiredConnIDs     []protocol.ConnectionID
18		removedConnIDs     []protocol.ConnectionID
19		replacedWithClosed map[string]packetHandler
20		queuedFrames       []wire.Frame
21		g                  *connIDGenerator
22	)
23	initialConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7}
24	initialClientDestConnID := protocol.ConnectionID{0xa, 0xb, 0xc, 0xd, 0xe}
25
26	connIDToToken := func(c protocol.ConnectionID) protocol.StatelessResetToken {
27		return protocol.StatelessResetToken{c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0]}
28	}
29
30	BeforeEach(func() {
31		addedConnIDs = nil
32		retiredConnIDs = nil
33		removedConnIDs = nil
34		queuedFrames = nil
35		replacedWithClosed = make(map[string]packetHandler)
36		g = newConnIDGenerator(
37			initialConnID,
38			initialClientDestConnID,
39			func(c protocol.ConnectionID) { addedConnIDs = append(addedConnIDs, c) },
40			connIDToToken,
41			func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) },
42			func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) },
43			func(c protocol.ConnectionID, h packetHandler) { replacedWithClosed[string(c)] = h },
44			func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
45			protocol.VersionDraft29,
46		)
47	})
48
49	It("issues new connection IDs", func() {
50		Expect(g.SetMaxActiveConnIDs(4)).To(Succeed())
51		Expect(retiredConnIDs).To(BeEmpty())
52		Expect(addedConnIDs).To(HaveLen(3))
53		for i := 0; i < len(addedConnIDs)-1; i++ {
54			Expect(addedConnIDs[i]).ToNot(Equal(addedConnIDs[i+1]))
55		}
56		Expect(queuedFrames).To(HaveLen(3))
57		for i := 0; i < 3; i++ {
58			f := queuedFrames[i]
59			Expect(f).To(BeAssignableToTypeOf(&wire.NewConnectionIDFrame{}))
60			nf := f.(*wire.NewConnectionIDFrame)
61			Expect(nf.SequenceNumber).To(BeEquivalentTo(i + 1))
62			Expect(nf.ConnectionID.Len()).To(Equal(7))
63			Expect(nf.StatelessResetToken).To(Equal(connIDToToken(nf.ConnectionID)))
64		}
65	})
66
67	It("limits the number of connection IDs that it issues", func() {
68		Expect(g.SetMaxActiveConnIDs(9999999)).To(Succeed())
69		Expect(retiredConnIDs).To(BeEmpty())
70		Expect(addedConnIDs).To(HaveLen(protocol.MaxIssuedConnectionIDs - 1))
71		Expect(queuedFrames).To(HaveLen(protocol.MaxIssuedConnectionIDs - 1))
72	})
73
74	// SetMaxActiveConnIDs is called twice when we dialing a 0-RTT connection:
75	// once for the restored from the old connections, once when we receive the transport parameters
76	Context("dealing with 0-RTT", func() {
77		It("doesn't issue new connection IDs when SetMaxActiveConnIDs is called with the same value", func() {
78			Expect(g.SetMaxActiveConnIDs(4)).To(Succeed())
79			Expect(queuedFrames).To(HaveLen(3))
80			queuedFrames = nil
81			Expect(g.SetMaxActiveConnIDs(4)).To(Succeed())
82			Expect(queuedFrames).To(BeEmpty())
83		})
84
85		It("issues more connection IDs if the server allows a higher limit on the resumed connection", func() {
86			Expect(g.SetMaxActiveConnIDs(3)).To(Succeed())
87			Expect(queuedFrames).To(HaveLen(2))
88			queuedFrames = nil
89			Expect(g.SetMaxActiveConnIDs(6)).To(Succeed())
90			Expect(queuedFrames).To(HaveLen(3))
91		})
92
93		It("issues more connection IDs if the server allows a higher limit on the resumed connection, when connection IDs were retired in between", func() {
94			Expect(g.SetMaxActiveConnIDs(3)).To(Succeed())
95			Expect(queuedFrames).To(HaveLen(2))
96			queuedFrames = nil
97			g.Retire(1, protocol.ConnectionID{})
98			Expect(queuedFrames).To(HaveLen(1))
99			queuedFrames = nil
100			Expect(g.SetMaxActiveConnIDs(6)).To(Succeed())
101			Expect(queuedFrames).To(HaveLen(3))
102		})
103	})
104
105	It("errors if the peers tries to retire a connection ID that wasn't yet issued", func() {
106		Expect(g.Retire(1, protocol.ConnectionID{})).To(MatchError(&qerr.TransportError{
107			ErrorCode:    qerr.ProtocolViolation,
108			ErrorMessage: "retired connection ID 1 (highest issued: 0)",
109		}))
110	})
111
112	It("errors if the peers tries to retire a connection ID in a packet with that connection ID", func() {
113		Expect(g.SetMaxActiveConnIDs(4)).To(Succeed())
114		Expect(queuedFrames).ToNot(BeEmpty())
115		Expect(queuedFrames[0]).To(BeAssignableToTypeOf(&wire.NewConnectionIDFrame{}))
116		f := queuedFrames[0].(*wire.NewConnectionIDFrame)
117		Expect(g.Retire(f.SequenceNumber, f.ConnectionID)).To(MatchError(&qerr.TransportError{
118			ErrorCode:    qerr.ProtocolViolation,
119			ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", f.SequenceNumber, f.ConnectionID),
120		}))
121	})
122
123	It("issues new connection IDs, when old ones are retired", func() {
124		Expect(g.SetMaxActiveConnIDs(5)).To(Succeed())
125		queuedFrames = nil
126		Expect(retiredConnIDs).To(BeEmpty())
127		Expect(g.Retire(3, protocol.ConnectionID{})).To(Succeed())
128		Expect(queuedFrames).To(HaveLen(1))
129		Expect(queuedFrames[0]).To(BeAssignableToTypeOf(&wire.NewConnectionIDFrame{}))
130		nf := queuedFrames[0].(*wire.NewConnectionIDFrame)
131		Expect(nf.SequenceNumber).To(BeEquivalentTo(5))
132		Expect(nf.ConnectionID.Len()).To(Equal(7))
133	})
134
135	It("retires the initial connection ID", func() {
136		Expect(g.Retire(0, protocol.ConnectionID{})).To(Succeed())
137		Expect(removedConnIDs).To(BeEmpty())
138		Expect(retiredConnIDs).To(HaveLen(1))
139		Expect(retiredConnIDs[0]).To(Equal(initialConnID))
140		Expect(addedConnIDs).To(BeEmpty())
141	})
142
143	It("handles duplicate retirements", func() {
144		Expect(g.SetMaxActiveConnIDs(11)).To(Succeed())
145		queuedFrames = nil
146		Expect(retiredConnIDs).To(BeEmpty())
147		Expect(g.Retire(5, protocol.ConnectionID{})).To(Succeed())
148		Expect(retiredConnIDs).To(HaveLen(1))
149		Expect(queuedFrames).To(HaveLen(1))
150		Expect(g.Retire(5, protocol.ConnectionID{})).To(Succeed())
151		Expect(retiredConnIDs).To(HaveLen(1))
152		Expect(queuedFrames).To(HaveLen(1))
153	})
154
155	It("retires the client's initial destination connection ID when the handshake completes", func() {
156		g.SetHandshakeComplete()
157		Expect(retiredConnIDs).To(HaveLen(1))
158		Expect(retiredConnIDs[0]).To(Equal(initialClientDestConnID))
159	})
160
161	It("removes all connection IDs", func() {
162		Expect(g.SetMaxActiveConnIDs(5)).To(Succeed())
163		Expect(queuedFrames).To(HaveLen(4))
164		g.RemoveAll()
165		Expect(removedConnIDs).To(HaveLen(6)) // initial conn ID, initial client dest conn id, and newly issued ones
166		Expect(removedConnIDs).To(ContainElement(initialConnID))
167		Expect(removedConnIDs).To(ContainElement(initialClientDestConnID))
168		for _, f := range queuedFrames {
169			nf := f.(*wire.NewConnectionIDFrame)
170			Expect(removedConnIDs).To(ContainElement(nf.ConnectionID))
171		}
172	})
173
174	It("replaces with a closed session for all connection IDs", func() {
175		Expect(g.SetMaxActiveConnIDs(5)).To(Succeed())
176		Expect(queuedFrames).To(HaveLen(4))
177		sess := NewMockPacketHandler(mockCtrl)
178		g.ReplaceWithClosed(sess)
179		Expect(replacedWithClosed).To(HaveLen(6)) // initial conn ID, initial client dest conn id, and newly issued ones
180		Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialClientDestConnID), sess))
181		Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialConnID), sess))
182		for _, f := range queuedFrames {
183			nf := f.(*wire.NewConnectionIDFrame)
184			Expect(replacedWithClosed).To(HaveKeyWithValue(string(nf.ConnectionID), sess))
185		}
186	})
187})
188