1package wire 2 3import ( 4 "bytes" 5 "errors" 6 "fmt" 7 "io" 8 9 "github.com/ooni/psiphon/oopsi/github.com/Psiphon-Labs/quic-go/internal/protocol" 10 "github.com/ooni/psiphon/oopsi/github.com/Psiphon-Labs/quic-go/internal/utils" 11) 12 13// ErrInvalidReservedBits is returned when the reserved bits are incorrect. 14// When this error is returned, parsing continues, and an ExtendedHeader is returned. 15// This is necessary because we need to decrypt the packet in that case, 16// in order to avoid a timing side-channel. 17var ErrInvalidReservedBits = errors.New("invalid reserved bits") 18 19// ExtendedHeader is the header of a QUIC packet. 20type ExtendedHeader struct { 21 Header 22 23 typeByte byte 24 25 KeyPhase protocol.KeyPhaseBit 26 27 PacketNumberLen protocol.PacketNumberLen 28 PacketNumber protocol.PacketNumber 29} 30 31func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) { 32 // read the (now unencrypted) first byte 33 var err error 34 h.typeByte, err = b.ReadByte() 35 if err != nil { 36 return nil, err 37 } 38 if _, err := b.Seek(int64(h.ParsedLen())-1, io.SeekCurrent); err != nil { 39 return nil, err 40 } 41 if h.IsLongHeader { 42 return h.parseLongHeader(b, v) 43 } 44 return h.parseShortHeader(b, v) 45} 46 47func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (*ExtendedHeader, error) { 48 if err := h.readPacketNumber(b); err != nil { 49 return nil, err 50 } 51 var err error 52 if h.typeByte&0xc != 0 { 53 err = ErrInvalidReservedBits 54 } 55 return h, err 56} 57 58func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, _ protocol.VersionNumber) (*ExtendedHeader, error) { 59 h.KeyPhase = protocol.KeyPhaseZero 60 if h.typeByte&0x4 > 0 { 61 h.KeyPhase = protocol.KeyPhaseOne 62 } 63 64 if err := h.readPacketNumber(b); err != nil { 65 return nil, err 66 } 67 var err error 68 if h.typeByte&0x18 != 0 { 69 err = ErrInvalidReservedBits 70 } 71 return h, err 72} 73 74func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error { 75 h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1 76 switch h.PacketNumberLen { 77 case protocol.PacketNumberLen1: 78 n, err := b.ReadByte() 79 if err != nil { 80 return err 81 } 82 h.PacketNumber = protocol.PacketNumber(n) 83 case protocol.PacketNumberLen2: 84 n, err := utils.BigEndian.ReadUint16(b) 85 if err != nil { 86 return err 87 } 88 h.PacketNumber = protocol.PacketNumber(n) 89 case protocol.PacketNumberLen3: 90 n, err := utils.BigEndian.ReadUint24(b) 91 if err != nil { 92 return err 93 } 94 h.PacketNumber = protocol.PacketNumber(n) 95 case protocol.PacketNumberLen4: 96 n, err := utils.BigEndian.ReadUint32(b) 97 if err != nil { 98 return err 99 } 100 h.PacketNumber = protocol.PacketNumber(n) 101 default: 102 return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) 103 } 104 return nil 105} 106 107// Write writes the Header. 108func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error { 109 if h.DestConnectionID.Len() > protocol.MaxConnIDLen { 110 return fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len()) 111 } 112 if h.SrcConnectionID.Len() > protocol.MaxConnIDLen { 113 return fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len()) 114 } 115 if h.OrigDestConnectionID.Len() > protocol.MaxConnIDLen { 116 return fmt.Errorf("invalid connection ID length: %d bytes", h.OrigDestConnectionID.Len()) 117 } 118 if h.IsLongHeader { 119 return h.writeLongHeader(b, ver) 120 } 121 return h.writeShortHeader(b, ver) 122} 123 124func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, _ protocol.VersionNumber) error { 125 var packetType uint8 126 switch h.Type { 127 case protocol.PacketTypeInitial: 128 packetType = 0x0 129 case protocol.PacketType0RTT: 130 packetType = 0x1 131 case protocol.PacketTypeHandshake: 132 packetType = 0x2 133 case protocol.PacketTypeRetry: 134 packetType = 0x3 135 } 136 firstByte := 0xc0 | packetType<<4 137 if h.Type != protocol.PacketTypeRetry { 138 // Retry packets don't have a packet number 139 firstByte |= uint8(h.PacketNumberLen - 1) 140 } 141 142 b.WriteByte(firstByte) 143 utils.BigEndian.WriteUint32(b, uint32(h.Version)) 144 b.WriteByte(uint8(h.DestConnectionID.Len())) 145 b.Write(h.DestConnectionID.Bytes()) 146 b.WriteByte(uint8(h.SrcConnectionID.Len())) 147 b.Write(h.SrcConnectionID.Bytes()) 148 149 switch h.Type { 150 case protocol.PacketTypeRetry: 151 b.WriteByte(uint8(h.OrigDestConnectionID.Len())) 152 b.Write(h.OrigDestConnectionID.Bytes()) 153 b.Write(h.Token) 154 return nil 155 case protocol.PacketTypeInitial: 156 utils.WriteVarInt(b, uint64(len(h.Token))) 157 b.Write(h.Token) 158 } 159 160 utils.WriteVarInt(b, uint64(h.Length)) 161 return h.writePacketNumber(b) 162} 163 164func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ protocol.VersionNumber) error { 165 typeByte := 0x40 | uint8(h.PacketNumberLen-1) 166 if h.KeyPhase == protocol.KeyPhaseOne { 167 typeByte |= byte(1 << 2) 168 } 169 170 b.WriteByte(typeByte) 171 b.Write(h.DestConnectionID.Bytes()) 172 return h.writePacketNumber(b) 173} 174 175func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error { 176 switch h.PacketNumberLen { 177 case protocol.PacketNumberLen1: 178 b.WriteByte(uint8(h.PacketNumber)) 179 case protocol.PacketNumberLen2: 180 utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber)) 181 case protocol.PacketNumberLen3: 182 utils.BigEndian.WriteUint24(b, uint32(h.PacketNumber)) 183 case protocol.PacketNumberLen4: 184 utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) 185 default: 186 return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) 187 } 188 return nil 189} 190 191// GetLength determines the length of the Header. 192func (h *ExtendedHeader) GetLength(v protocol.VersionNumber) protocol.ByteCount { 193 if h.IsLongHeader { 194 length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.Length)) 195 if h.Type == protocol.PacketTypeInitial { 196 length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token)) 197 } 198 return length 199 } 200 201 length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len()) 202 length += protocol.ByteCount(h.PacketNumberLen) 203 return length 204} 205 206// Log logs the Header 207func (h *ExtendedHeader) Log(logger utils.Logger) { 208 if h.IsLongHeader { 209 var token string 210 if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry { 211 if len(h.Token) == 0 { 212 token = "Token: (empty), " 213 } else { 214 token = fmt.Sprintf("Token: %#x, ", h.Token) 215 } 216 if h.Type == protocol.PacketTypeRetry { 217 logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version) 218 return 219 } 220 } 221 logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version) 222 } else { 223 logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) 224 } 225} 226