1package quic 2 3import ( 4 "fmt" 5 6 "github.com/lucas-clemente/quic-go/internal/protocol" 7 "github.com/lucas-clemente/quic-go/internal/qerr" 8 "github.com/lucas-clemente/quic-go/internal/utils" 9 "github.com/lucas-clemente/quic-go/internal/wire" 10) 11 12type connIDGenerator struct { 13 connIDLen int 14 highestSeq uint64 15 16 activeSrcConnIDs map[uint64]protocol.ConnectionID 17 initialClientDestConnID protocol.ConnectionID 18 19 addConnectionID func(protocol.ConnectionID) 20 getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken 21 removeConnectionID func(protocol.ConnectionID) 22 retireConnectionID func(protocol.ConnectionID) 23 replaceWithClosed func(protocol.ConnectionID, packetHandler) 24 queueControlFrame func(wire.Frame) 25 26 version protocol.VersionNumber 27} 28 29func newConnIDGenerator( 30 initialConnectionID protocol.ConnectionID, 31 initialClientDestConnID protocol.ConnectionID, // nil for the client 32 addConnectionID func(protocol.ConnectionID), 33 getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken, 34 removeConnectionID func(protocol.ConnectionID), 35 retireConnectionID func(protocol.ConnectionID), 36 replaceWithClosed func(protocol.ConnectionID, packetHandler), 37 queueControlFrame func(wire.Frame), 38 version protocol.VersionNumber, 39) *connIDGenerator { 40 m := &connIDGenerator{ 41 connIDLen: initialConnectionID.Len(), 42 activeSrcConnIDs: make(map[uint64]protocol.ConnectionID), 43 addConnectionID: addConnectionID, 44 getStatelessResetToken: getStatelessResetToken, 45 removeConnectionID: removeConnectionID, 46 retireConnectionID: retireConnectionID, 47 replaceWithClosed: replaceWithClosed, 48 queueControlFrame: queueControlFrame, 49 version: version, 50 } 51 m.activeSrcConnIDs[0] = initialConnectionID 52 m.initialClientDestConnID = initialClientDestConnID 53 return m 54} 55 56func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error { 57 if m.connIDLen == 0 { 58 return nil 59 } 60 // The active_connection_id_limit transport parameter is the number of 61 // connection IDs the peer will store. This limit includes the connection ID 62 // used during the handshake, and the one sent in the preferred_address 63 // transport parameter. 64 // We currently don't send the preferred_address transport parameter, 65 // so we can issue (limit - 1) connection IDs. 66 for i := uint64(len(m.activeSrcConnIDs)); i < utils.MinUint64(limit, protocol.MaxIssuedConnectionIDs); i++ { 67 if err := m.issueNewConnID(); err != nil { 68 return err 69 } 70 } 71 return nil 72} 73 74func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID) error { 75 if seq > m.highestSeq { 76 return &qerr.TransportError{ 77 ErrorCode: qerr.ProtocolViolation, 78 ErrorMessage: fmt.Sprintf("retired connection ID %d (highest issued: %d)", seq, m.highestSeq), 79 } 80 } 81 connID, ok := m.activeSrcConnIDs[seq] 82 // We might already have deleted this connection ID, if this is a duplicate frame. 83 if !ok { 84 return nil 85 } 86 if connID.Equal(sentWithDestConnID) { 87 return &qerr.TransportError{ 88 ErrorCode: qerr.ProtocolViolation, 89 ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID), 90 } 91 } 92 m.retireConnectionID(connID) 93 delete(m.activeSrcConnIDs, seq) 94 // Don't issue a replacement for the initial connection ID. 95 if seq == 0 { 96 return nil 97 } 98 return m.issueNewConnID() 99} 100 101func (m *connIDGenerator) issueNewConnID() error { 102 connID, err := protocol.GenerateConnectionID(m.connIDLen) 103 if err != nil { 104 return err 105 } 106 m.activeSrcConnIDs[m.highestSeq+1] = connID 107 m.addConnectionID(connID) 108 m.queueControlFrame(&wire.NewConnectionIDFrame{ 109 SequenceNumber: m.highestSeq + 1, 110 ConnectionID: connID, 111 StatelessResetToken: m.getStatelessResetToken(connID), 112 }) 113 m.highestSeq++ 114 return nil 115} 116 117func (m *connIDGenerator) SetHandshakeComplete() { 118 if m.initialClientDestConnID != nil { 119 m.retireConnectionID(m.initialClientDestConnID) 120 m.initialClientDestConnID = nil 121 } 122} 123 124func (m *connIDGenerator) RemoveAll() { 125 if m.initialClientDestConnID != nil { 126 m.removeConnectionID(m.initialClientDestConnID) 127 } 128 for _, connID := range m.activeSrcConnIDs { 129 m.removeConnectionID(connID) 130 } 131} 132 133func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) { 134 if m.initialClientDestConnID != nil { 135 m.replaceWithClosed(m.initialClientDestConnID, handler) 136 } 137 for _, connID := range m.activeSrcConnIDs { 138 m.replaceWithClosed(connID, handler) 139 } 140} 141