1package quic
2
3import (
4	"fmt"
5
6	"github.com/lucas-clemente/quic-go/internal/protocol"
7	"github.com/lucas-clemente/quic-go/internal/qerr"
8	"github.com/lucas-clemente/quic-go/internal/utils"
9	"github.com/lucas-clemente/quic-go/internal/wire"
10)
11
12type connIDManager struct {
13	queue utils.NewConnectionIDList
14
15	handshakeComplete         bool
16	activeSequenceNumber      uint64
17	highestRetired            uint64
18	activeConnectionID        protocol.ConnectionID
19	activeStatelessResetToken *protocol.StatelessResetToken
20
21	// We change the connection ID after sending on average
22	// protocol.PacketsPerConnectionID packets. The actual value is randomized
23	// hide the packet loss rate from on-path observers.
24	rand                   utils.Rand
25	packetsSinceLastChange uint32
26	packetsPerConnectionID uint32
27
28	addStatelessResetToken    func(protocol.StatelessResetToken)
29	removeStatelessResetToken func(protocol.StatelessResetToken)
30	queueControlFrame         func(wire.Frame)
31}
32
33func newConnIDManager(
34	initialDestConnID protocol.ConnectionID,
35	addStatelessResetToken func(protocol.StatelessResetToken),
36	removeStatelessResetToken func(protocol.StatelessResetToken),
37	queueControlFrame func(wire.Frame),
38) *connIDManager {
39	return &connIDManager{
40		activeConnectionID:        initialDestConnID,
41		addStatelessResetToken:    addStatelessResetToken,
42		removeStatelessResetToken: removeStatelessResetToken,
43		queueControlFrame:         queueControlFrame,
44	}
45}
46
47func (h *connIDManager) AddFromPreferredAddress(connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error {
48	return h.addConnectionID(1, connID, resetToken)
49}
50
51func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error {
52	if err := h.add(f); err != nil {
53		return err
54	}
55	if h.queue.Len() >= protocol.MaxActiveConnectionIDs {
56		return &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}
57	}
58	return nil
59}
60
// add applies a NEW_CONNECTION_ID frame to the manager's state. It does not
// enforce the limit on the number of connection IDs; Add checks that after
// add returns.
//
// The steps, in order:
//  1. If the frame's sequence number is below the active one or below
//     highestRetired (a reordered or already-retired ID), queue a
//     RETIRE_CONNECTION_ID frame for it and return.
//  2. If RetirePriorTo advances past highestRetired, retire every queued ID
//     with a lower sequence number and record the new highestRetired.
//  3. If the frame repeats the active sequence number, stop here.
//  4. Insert the ID into the queue (addConnectionID returns an error for a
//     conflicting duplicate).
//  5. If RetirePriorTo now covers the active ID, switch to the next queued ID.
func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error {
	// If the NEW_CONNECTION_ID frame is reordered, such that its sequence number is smaller than the currently active
	// connection ID or if it was already retired, send the RETIRE_CONNECTION_ID frame immediately.
	if f.SequenceNumber < h.activeSequenceNumber || f.SequenceNumber < h.highestRetired {
		h.queueControlFrame(&wire.RetireConnectionIDFrame{
			SequenceNumber: f.SequenceNumber,
		})
		return nil
	}

	// Retire elements in the queue.
	// Doesn't retire the active connection ID.
	// The queue is kept in ascending sequence-number order by addConnectionID,
	// so the walk can stop at the first entry not covered by RetirePriorTo.
	if f.RetirePriorTo > h.highestRetired {
		var next *utils.NewConnectionIDElement
		for el := h.queue.Front(); el != nil; el = next {
			if el.Value.SequenceNumber >= f.RetirePriorTo {
				break
			}
			// Fetch the successor before el is removed from the list.
			next = el.Next()
			h.queueControlFrame(&wire.RetireConnectionIDFrame{
				SequenceNumber: el.Value.SequenceNumber,
			})
			h.queue.Remove(el)
		}
		h.highestRetired = f.RetirePriorTo
	}

	// A repeat of the active connection ID: nothing to insert.
	if f.SequenceNumber == h.activeSequenceNumber {
		return nil
	}

	if err := h.addConnectionID(f.SequenceNumber, f.ConnectionID, f.StatelessResetToken); err != nil {
		return err
	}

	// Retire the active connection ID, if necessary.
	if h.activeSequenceNumber < f.RetirePriorTo {
		// The queue is guaranteed to have at least one element at this point.
		// NOTE(review): this relies on the frame's own ID (just inserted) surviving,
		// i.e. on the frame parser having rejected RetirePriorTo > SequenceNumber — confirm.
		h.updateConnectionID()
	}
	return nil
}
103
104func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error {
105	// insert a new element at the end
106	if h.queue.Len() == 0 || h.queue.Back().Value.SequenceNumber < seq {
107		h.queue.PushBack(utils.NewConnectionID{
108			SequenceNumber:      seq,
109			ConnectionID:        connID,
110			StatelessResetToken: resetToken,
111		})
112		return nil
113	}
114	// insert a new element somewhere in the middle
115	for el := h.queue.Front(); el != nil; el = el.Next() {
116		if el.Value.SequenceNumber == seq {
117			if !el.Value.ConnectionID.Equal(connID) {
118				return fmt.Errorf("received conflicting connection IDs for sequence number %d", seq)
119			}
120			if el.Value.StatelessResetToken != resetToken {
121				return fmt.Errorf("received conflicting stateless reset tokens for sequence number %d", seq)
122			}
123			break
124		}
125		if el.Value.SequenceNumber > seq {
126			h.queue.InsertBefore(utils.NewConnectionID{
127				SequenceNumber:      seq,
128				ConnectionID:        connID,
129				StatelessResetToken: resetToken,
130			}, el)
131			break
132		}
133	}
134	return nil
135}
136
137func (h *connIDManager) updateConnectionID() {
138	h.queueControlFrame(&wire.RetireConnectionIDFrame{
139		SequenceNumber: h.activeSequenceNumber,
140	})
141	h.highestRetired = utils.MaxUint64(h.highestRetired, h.activeSequenceNumber)
142	if h.activeStatelessResetToken != nil {
143		h.removeStatelessResetToken(*h.activeStatelessResetToken)
144	}
145
146	front := h.queue.Remove(h.queue.Front())
147	h.activeSequenceNumber = front.SequenceNumber
148	h.activeConnectionID = front.ConnectionID
149	h.activeStatelessResetToken = &front.StatelessResetToken
150	h.packetsSinceLastChange = 0
151	h.packetsPerConnectionID = protocol.PacketsPerConnectionID/2 + uint32(h.rand.Int31n(protocol.PacketsPerConnectionID))
152	h.addStatelessResetToken(*h.activeStatelessResetToken)
153}
154
155func (h *connIDManager) Close() {
156	if h.activeStatelessResetToken != nil {
157		h.removeStatelessResetToken(*h.activeStatelessResetToken)
158	}
159}
160
161// is called when the server performs a Retry
162// and when the server changes the connection ID in the first Initial sent
163func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) {
164	if h.activeSequenceNumber != 0 {
165		panic("expected first connection ID to have sequence number 0")
166	}
167	h.activeConnectionID = newConnID
168}
169
170// is called when the server provides a stateless reset token in the transport parameters
171func (h *connIDManager) SetStatelessResetToken(token protocol.StatelessResetToken) {
172	if h.activeSequenceNumber != 0 {
173		panic("expected first connection ID to have sequence number 0")
174	}
175	h.activeStatelessResetToken = &token
176	h.addStatelessResetToken(token)
177}
178
179func (h *connIDManager) SentPacket() {
180	h.packetsSinceLastChange++
181}
182
183func (h *connIDManager) shouldUpdateConnID() bool {
184	if !h.handshakeComplete {
185		return false
186	}
187	// initiate the first change as early as possible (after handshake completion)
188	if h.queue.Len() > 0 && h.activeSequenceNumber == 0 {
189		return true
190	}
191	// For later changes, only change if
192	// 1. The queue of connection IDs is filled more than 50%.
193	// 2. We sent at least PacketsPerConnectionID packets
194	return 2*h.queue.Len() >= protocol.MaxActiveConnectionIDs &&
195		h.packetsSinceLastChange >= h.packetsPerConnectionID
196}
197
198func (h *connIDManager) Get() protocol.ConnectionID {
199	if h.shouldUpdateConnID() {
200		h.updateConnectionID()
201	}
202	return h.activeConnectionID
203}
204
205func (h *connIDManager) SetHandshakeComplete() {
206	h.handshakeComplete = true
207}
208