1package wire 2 3import ( 4 "bytes" 5 "errors" 6 "fmt" 7 "io" 8 9 "github.com/lucas-clemente/quic-go/internal/protocol" 10 "github.com/lucas-clemente/quic-go/internal/utils" 11 "github.com/lucas-clemente/quic-go/quicvarint" 12) 13 14// ErrInvalidReservedBits is returned when the reserved bits are incorrect. 15// When this error is returned, parsing continues, and an ExtendedHeader is returned. 16// This is necessary because we need to decrypt the packet in that case, 17// in order to avoid a timing side-channel. 18var ErrInvalidReservedBits = errors.New("invalid reserved bits") 19 20// ExtendedHeader is the header of a QUIC packet. 21type ExtendedHeader struct { 22 Header 23 24 typeByte byte 25 26 KeyPhase protocol.KeyPhaseBit 27 28 PacketNumberLen protocol.PacketNumberLen 29 PacketNumber protocol.PacketNumber 30 31 parsedLen protocol.ByteCount 32} 33 34func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool /* reserved bits valid */, error) { 35 startLen := b.Len() 36 // read the (now unencrypted) first byte 37 var err error 38 h.typeByte, err = b.ReadByte() 39 if err != nil { 40 return false, err 41 } 42 if _, err := b.Seek(int64(h.Header.ParsedLen())-1, io.SeekCurrent); err != nil { 43 return false, err 44 } 45 var reservedBitsValid bool 46 if h.IsLongHeader { 47 reservedBitsValid, err = h.parseLongHeader(b, v) 48 } else { 49 reservedBitsValid, err = h.parseShortHeader(b, v) 50 } 51 if err != nil { 52 return false, err 53 } 54 h.parsedLen = protocol.ByteCount(startLen - b.Len()) 55 return reservedBitsValid, err 56} 57 58func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) { 59 if err := h.readPacketNumber(b); err != nil { 60 return false, err 61 } 62 if h.typeByte&0xc != 0 { 63 return false, nil 64 } 65 return true, nil 66} 67 68func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) { 69 h.KeyPhase = protocol.KeyPhaseZero 70 if h.typeByte&0x4 > 0 { 71 h.KeyPhase = protocol.KeyPhaseOne 72 } 73 74 if err := h.readPacketNumber(b); err != nil { 75 return false, err 76 } 77 if h.typeByte&0x18 != 0 { 78 return false, nil 79 } 80 return true, nil 81} 82 83func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error { 84 h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1 85 switch h.PacketNumberLen { 86 case protocol.PacketNumberLen1: 87 n, err := b.ReadByte() 88 if err != nil { 89 return err 90 } 91 h.PacketNumber = protocol.PacketNumber(n) 92 case protocol.PacketNumberLen2: 93 n, err := utils.BigEndian.ReadUint16(b) 94 if err != nil { 95 return err 96 } 97 h.PacketNumber = protocol.PacketNumber(n) 98 case protocol.PacketNumberLen3: 99 n, err := utils.BigEndian.ReadUint24(b) 100 if err != nil { 101 return err 102 } 103 h.PacketNumber = protocol.PacketNumber(n) 104 case protocol.PacketNumberLen4: 105 n, err := utils.BigEndian.ReadUint32(b) 106 if err != nil { 107 return err 108 } 109 h.PacketNumber = protocol.PacketNumber(n) 110 default: 111 return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) 112 } 113 return nil 114} 115 116// Write writes the Header. 117func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error { 118 if h.DestConnectionID.Len() > protocol.MaxConnIDLen { 119 return fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len()) 120 } 121 if h.SrcConnectionID.Len() > protocol.MaxConnIDLen { 122 return fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len()) 123 } 124 if h.IsLongHeader { 125 return h.writeLongHeader(b, ver) 126 } 127 return h.writeShortHeader(b, ver) 128} 129 130func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, _ protocol.VersionNumber) error { 131 var packetType uint8 132 //nolint:exhaustive 133 switch h.Type { 134 case protocol.PacketTypeInitial: 135 packetType = 0x0 136 case protocol.PacketType0RTT: 137 packetType = 0x1 138 case protocol.PacketTypeHandshake: 139 packetType = 0x2 140 case protocol.PacketTypeRetry: 141 packetType = 0x3 142 } 143 firstByte := 0xc0 | packetType<<4 144 if h.Type != protocol.PacketTypeRetry { 145 // Retry packets don't have a packet number 146 firstByte |= uint8(h.PacketNumberLen - 1) 147 } 148 149 b.WriteByte(firstByte) 150 utils.BigEndian.WriteUint32(b, uint32(h.Version)) 151 b.WriteByte(uint8(h.DestConnectionID.Len())) 152 b.Write(h.DestConnectionID.Bytes()) 153 b.WriteByte(uint8(h.SrcConnectionID.Len())) 154 b.Write(h.SrcConnectionID.Bytes()) 155 156 //nolint:exhaustive 157 switch h.Type { 158 case protocol.PacketTypeRetry: 159 b.Write(h.Token) 160 return nil 161 case protocol.PacketTypeInitial: 162 quicvarint.Write(b, uint64(len(h.Token))) 163 b.Write(h.Token) 164 } 165 quicvarint.WriteWithLen(b, uint64(h.Length), 2) 166 return h.writePacketNumber(b) 167} 168 169func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ protocol.VersionNumber) error { 170 typeByte := 0x40 | uint8(h.PacketNumberLen-1) 171 if h.KeyPhase == protocol.KeyPhaseOne { 172 typeByte |= byte(1 << 2) 173 } 174 175 b.WriteByte(typeByte) 176 b.Write(h.DestConnectionID.Bytes()) 177 return h.writePacketNumber(b) 178} 179 180func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error { 181 switch h.PacketNumberLen { 182 case protocol.PacketNumberLen1: 183 b.WriteByte(uint8(h.PacketNumber)) 184 case protocol.PacketNumberLen2: 185 utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber)) 186 case protocol.PacketNumberLen3: 187 utils.BigEndian.WriteUint24(b, uint32(h.PacketNumber)) 188 case protocol.PacketNumberLen4: 189 utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) 190 default: 191 return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) 192 } 193 return nil 194} 195 196// ParsedLen returns the number of bytes that were consumed when parsing the header 197func (h *ExtendedHeader) ParsedLen() protocol.ByteCount { 198 return h.parsedLen 199} 200 201// GetLength determines the length of the Header. 202func (h *ExtendedHeader) GetLength(v protocol.VersionNumber) protocol.ByteCount { 203 if h.IsLongHeader { 204 length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */ 205 if h.Type == protocol.PacketTypeInitial { 206 length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token)) 207 } 208 return length 209 } 210 211 length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len()) 212 length += protocol.ByteCount(h.PacketNumberLen) 213 return length 214} 215 216// Log logs the Header 217func (h *ExtendedHeader) Log(logger utils.Logger) { 218 if h.IsLongHeader { 219 var token string 220 if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry { 221 if len(h.Token) == 0 { 222 token = "Token: (empty), " 223 } else { 224 token = fmt.Sprintf("Token: %#x, ", h.Token) 225 } 226 if h.Type == protocol.PacketTypeRetry { 227 logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.Version) 228 return 229 } 230 } 231 logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version) 232 } else { 233 logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) 234 } 235} 236