1package quic
2
3import (
4	"fmt"
5
6	"github.com/lucas-clemente/quic-go/internal/protocol"
7	"github.com/lucas-clemente/quic-go/internal/qerr"
8	"github.com/lucas-clemente/quic-go/internal/utils"
9	"github.com/lucas-clemente/quic-go/internal/wire"
10)
11
12type connIDGenerator struct {
13	connIDLen  int
14	highestSeq uint64
15
16	activeSrcConnIDs        map[uint64]protocol.ConnectionID
17	initialClientDestConnID protocol.ConnectionID
18
19	addConnectionID        func(protocol.ConnectionID)
20	getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
21	removeConnectionID     func(protocol.ConnectionID)
22	retireConnectionID     func(protocol.ConnectionID)
23	replaceWithClosed      func(protocol.ConnectionID, packetHandler)
24	queueControlFrame      func(wire.Frame)
25
26	version protocol.VersionNumber
27}
28
29func newConnIDGenerator(
30	initialConnectionID protocol.ConnectionID,
31	initialClientDestConnID protocol.ConnectionID, // nil for the client
32	addConnectionID func(protocol.ConnectionID),
33	getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
34	removeConnectionID func(protocol.ConnectionID),
35	retireConnectionID func(protocol.ConnectionID),
36	replaceWithClosed func(protocol.ConnectionID, packetHandler),
37	queueControlFrame func(wire.Frame),
38	version protocol.VersionNumber,
39) *connIDGenerator {
40	m := &connIDGenerator{
41		connIDLen:              initialConnectionID.Len(),
42		activeSrcConnIDs:       make(map[uint64]protocol.ConnectionID),
43		addConnectionID:        addConnectionID,
44		getStatelessResetToken: getStatelessResetToken,
45		removeConnectionID:     removeConnectionID,
46		retireConnectionID:     retireConnectionID,
47		replaceWithClosed:      replaceWithClosed,
48		queueControlFrame:      queueControlFrame,
49		version:                version,
50	}
51	m.activeSrcConnIDs[0] = initialConnectionID
52	m.initialClientDestConnID = initialClientDestConnID
53	return m
54}
55
56func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error {
57	if m.connIDLen == 0 {
58		return nil
59	}
60	// The active_connection_id_limit transport parameter is the number of
61	// connection IDs the peer will store. This limit includes the connection ID
62	// used during the handshake, and the one sent in the preferred_address
63	// transport parameter.
64	// We currently don't send the preferred_address transport parameter,
65	// so we can issue (limit - 1) connection IDs.
66	for i := uint64(len(m.activeSrcConnIDs)); i < utils.MinUint64(limit, protocol.MaxIssuedConnectionIDs); i++ {
67		if err := m.issueNewConnID(); err != nil {
68			return err
69		}
70	}
71	return nil
72}
73
74func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID) error {
75	if seq > m.highestSeq {
76		return &qerr.TransportError{
77			ErrorCode:    qerr.ProtocolViolation,
78			ErrorMessage: fmt.Sprintf("retired connection ID %d (highest issued: %d)", seq, m.highestSeq),
79		}
80	}
81	connID, ok := m.activeSrcConnIDs[seq]
82	// We might already have deleted this connection ID, if this is a duplicate frame.
83	if !ok {
84		return nil
85	}
86	if connID.Equal(sentWithDestConnID) {
87		return &qerr.TransportError{
88			ErrorCode:    qerr.ProtocolViolation,
89			ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID),
90		}
91	}
92	m.retireConnectionID(connID)
93	delete(m.activeSrcConnIDs, seq)
94	// Don't issue a replacement for the initial connection ID.
95	if seq == 0 {
96		return nil
97	}
98	return m.issueNewConnID()
99}
100
101func (m *connIDGenerator) issueNewConnID() error {
102	connID, err := protocol.GenerateConnectionID(m.connIDLen)
103	if err != nil {
104		return err
105	}
106	m.activeSrcConnIDs[m.highestSeq+1] = connID
107	m.addConnectionID(connID)
108	m.queueControlFrame(&wire.NewConnectionIDFrame{
109		SequenceNumber:      m.highestSeq + 1,
110		ConnectionID:        connID,
111		StatelessResetToken: m.getStatelessResetToken(connID),
112	})
113	m.highestSeq++
114	return nil
115}
116
117func (m *connIDGenerator) SetHandshakeComplete() {
118	if m.initialClientDestConnID != nil {
119		m.retireConnectionID(m.initialClientDestConnID)
120		m.initialClientDestConnID = nil
121	}
122}
123
124func (m *connIDGenerator) RemoveAll() {
125	if m.initialClientDestConnID != nil {
126		m.removeConnectionID(m.initialClientDestConnID)
127	}
128	for _, connID := range m.activeSrcConnIDs {
129		m.removeConnectionID(connID)
130	}
131}
132
133func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) {
134	if m.initialClientDestConnID != nil {
135		m.replaceWithClosed(m.initialClientDestConnID, handler)
136	}
137	for _, connID := range m.activeSrcConnIDs {
138		m.replaceWithClosed(connID, handler)
139	}
140}
141