1package wire
2
3import (
4	"bytes"
5	"errors"
6	"fmt"
7	"io"
8
9	"github.com/lucas-clemente/quic-go/internal/protocol"
10	"github.com/lucas-clemente/quic-go/internal/utils"
11	"github.com/lucas-clemente/quic-go/quicvarint"
12)
13
var (
	// ErrInvalidReservedBits signals that the reserved bits of the first header
	// byte are not set as required. It is deliberately non-fatal: parsing goes on
	// and an ExtendedHeader is still produced, because the packet has to be
	// decrypted anyway so that rejecting it does not open a timing side-channel.
	ErrInvalidReservedBits = errors.New("invalid reserved bits")
)
19
// ExtendedHeader is the header of a QUIC packet.
// It embeds Header and adds the parts decoded by parse: the first byte
// (reserved bits, key phase, packet number length) and the packet number.
type ExtendedHeader struct {
	Header

	// typeByte is the first byte of the packet as read by parse. The reserved
	// bits, the key phase and the packet number length are decoded from it.
	typeByte byte

	// KeyPhase is only decoded from / encoded into short headers (bit 0x04 of
	// the first byte).
	KeyPhase protocol.KeyPhaseBit

	// PacketNumberLen is the number of bytes used to encode PacketNumber
	// (1 to 4 when parsed: the low two bits of the first byte plus one).
	// PacketNumber holds the value exactly as read from / written to those bytes.
	PacketNumberLen protocol.PacketNumberLen
	PacketNumber    protocol.PacketNumber

	// parsedLen is the number of bytes consumed by parse; exposed via ParsedLen.
	parsedLen protocol.ByteCount
}
33
34func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool /* reserved bits valid */, error) {
35	startLen := b.Len()
36	// read the (now unencrypted) first byte
37	var err error
38	h.typeByte, err = b.ReadByte()
39	if err != nil {
40		return false, err
41	}
42	if _, err := b.Seek(int64(h.Header.ParsedLen())-1, io.SeekCurrent); err != nil {
43		return false, err
44	}
45	var reservedBitsValid bool
46	if h.IsLongHeader {
47		reservedBitsValid, err = h.parseLongHeader(b, v)
48	} else {
49		reservedBitsValid, err = h.parseShortHeader(b, v)
50	}
51	if err != nil {
52		return false, err
53	}
54	h.parsedLen = protocol.ByteCount(startLen - b.Len())
55	return reservedBitsValid, err
56}
57
58func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) {
59	if err := h.readPacketNumber(b); err != nil {
60		return false, err
61	}
62	if h.typeByte&0xc != 0 {
63		return false, nil
64	}
65	return true, nil
66}
67
68func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) {
69	h.KeyPhase = protocol.KeyPhaseZero
70	if h.typeByte&0x4 > 0 {
71		h.KeyPhase = protocol.KeyPhaseOne
72	}
73
74	if err := h.readPacketNumber(b); err != nil {
75		return false, err
76	}
77	if h.typeByte&0x18 != 0 {
78		return false, nil
79	}
80	return true, nil
81}
82
83func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
84	h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1
85	switch h.PacketNumberLen {
86	case protocol.PacketNumberLen1:
87		n, err := b.ReadByte()
88		if err != nil {
89			return err
90		}
91		h.PacketNumber = protocol.PacketNumber(n)
92	case protocol.PacketNumberLen2:
93		n, err := utils.BigEndian.ReadUint16(b)
94		if err != nil {
95			return err
96		}
97		h.PacketNumber = protocol.PacketNumber(n)
98	case protocol.PacketNumberLen3:
99		n, err := utils.BigEndian.ReadUint24(b)
100		if err != nil {
101			return err
102		}
103		h.PacketNumber = protocol.PacketNumber(n)
104	case protocol.PacketNumberLen4:
105		n, err := utils.BigEndian.ReadUint32(b)
106		if err != nil {
107			return err
108		}
109		h.PacketNumber = protocol.PacketNumber(n)
110	default:
111		return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
112	}
113	return nil
114}
115
116// Write writes the Header.
117func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error {
118	if h.DestConnectionID.Len() > protocol.MaxConnIDLen {
119		return fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len())
120	}
121	if h.SrcConnectionID.Len() > protocol.MaxConnIDLen {
122		return fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len())
123	}
124	if h.IsLongHeader {
125		return h.writeLongHeader(b, ver)
126	}
127	return h.writeShortHeader(b, ver)
128}
129
130func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, _ protocol.VersionNumber) error {
131	var packetType uint8
132	//nolint:exhaustive
133	switch h.Type {
134	case protocol.PacketTypeInitial:
135		packetType = 0x0
136	case protocol.PacketType0RTT:
137		packetType = 0x1
138	case protocol.PacketTypeHandshake:
139		packetType = 0x2
140	case protocol.PacketTypeRetry:
141		packetType = 0x3
142	}
143	firstByte := 0xc0 | packetType<<4
144	if h.Type != protocol.PacketTypeRetry {
145		// Retry packets don't have a packet number
146		firstByte |= uint8(h.PacketNumberLen - 1)
147	}
148
149	b.WriteByte(firstByte)
150	utils.BigEndian.WriteUint32(b, uint32(h.Version))
151	b.WriteByte(uint8(h.DestConnectionID.Len()))
152	b.Write(h.DestConnectionID.Bytes())
153	b.WriteByte(uint8(h.SrcConnectionID.Len()))
154	b.Write(h.SrcConnectionID.Bytes())
155
156	//nolint:exhaustive
157	switch h.Type {
158	case protocol.PacketTypeRetry:
159		b.Write(h.Token)
160		return nil
161	case protocol.PacketTypeInitial:
162		quicvarint.Write(b, uint64(len(h.Token)))
163		b.Write(h.Token)
164	}
165	quicvarint.WriteWithLen(b, uint64(h.Length), 2)
166	return h.writePacketNumber(b)
167}
168
169func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ protocol.VersionNumber) error {
170	typeByte := 0x40 | uint8(h.PacketNumberLen-1)
171	if h.KeyPhase == protocol.KeyPhaseOne {
172		typeByte |= byte(1 << 2)
173	}
174
175	b.WriteByte(typeByte)
176	b.Write(h.DestConnectionID.Bytes())
177	return h.writePacketNumber(b)
178}
179
180func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error {
181	switch h.PacketNumberLen {
182	case protocol.PacketNumberLen1:
183		b.WriteByte(uint8(h.PacketNumber))
184	case protocol.PacketNumberLen2:
185		utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
186	case protocol.PacketNumberLen3:
187		utils.BigEndian.WriteUint24(b, uint32(h.PacketNumber))
188	case protocol.PacketNumberLen4:
189		utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
190	default:
191		return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
192	}
193	return nil
194}
195
196// ParsedLen returns the number of bytes that were consumed when parsing the header
197func (h *ExtendedHeader) ParsedLen() protocol.ByteCount {
198	return h.parsedLen
199}
200
201// GetLength determines the length of the Header.
202func (h *ExtendedHeader) GetLength(v protocol.VersionNumber) protocol.ByteCount {
203	if h.IsLongHeader {
204		length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */
205		if h.Type == protocol.PacketTypeInitial {
206			length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
207		}
208		return length
209	}
210
211	length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
212	length += protocol.ByteCount(h.PacketNumberLen)
213	return length
214}
215
216// Log logs the Header
217func (h *ExtendedHeader) Log(logger utils.Logger) {
218	if h.IsLongHeader {
219		var token string
220		if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
221			if len(h.Token) == 0 {
222				token = "Token: (empty), "
223			} else {
224				token = fmt.Sprintf("Token: %#x, ", h.Token)
225			}
226			if h.Type == protocol.PacketTypeRetry {
227				logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.Version)
228				return
229			}
230		}
231		logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
232	} else {
233		logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
234	}
235}
236