1package wire
2
3import (
4	"bytes"
5	"errors"
6	"fmt"
7	"io"
8
9	"github.com/ooni/psiphon/oopsi/github.com/Psiphon-Labs/quic-go/internal/protocol"
10	"github.com/ooni/psiphon/oopsi/github.com/Psiphon-Labs/quic-go/internal/utils"
11)
12
// ErrInvalidReservedBits is returned when the reserved bits of the first byte
// are not zero (mask 0x0c for long headers, 0x18 for short headers).
// When this error is returned, parsing has already completed, and the parsed
// ExtendedHeader is returned together with the error.
// This is necessary because the packet still needs to be decrypted in that
// case, in order to avoid a timing side-channel.
var ErrInvalidReservedBits = errors.New("invalid reserved bits")
18
// ExtendedHeader is the header of a QUIC packet.
// It embeds Header and adds the fields that parse decodes from the (now
// unencrypted) first byte and from the packet number field.
type ExtendedHeader struct {
	Header

	// typeByte is the raw first byte of the packet, as read by parse.
	typeByte byte

	// KeyPhase is the key phase bit (0x04 of the first byte). It is only
	// decoded / encoded for short headers.
	KeyPhase protocol.KeyPhaseBit

	// PacketNumberLen is the number of bytes (1 to 4) used to encode
	// PacketNumber; it is stored in the first byte as PacketNumberLen-1.
	// PacketNumber is the packet number value as encoded on the wire.
	PacketNumberLen protocol.PacketNumberLen
	PacketNumber    protocol.PacketNumber
}
30
31func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
32	// read the (now unencrypted) first byte
33	var err error
34	h.typeByte, err = b.ReadByte()
35	if err != nil {
36		return nil, err
37	}
38	if _, err := b.Seek(int64(h.ParsedLen())-1, io.SeekCurrent); err != nil {
39		return nil, err
40	}
41	if h.IsLongHeader {
42		return h.parseLongHeader(b, v)
43	}
44	return h.parseShortHeader(b, v)
45}
46
47func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (*ExtendedHeader, error) {
48	if err := h.readPacketNumber(b); err != nil {
49		return nil, err
50	}
51	var err error
52	if h.typeByte&0xc != 0 {
53		err = ErrInvalidReservedBits
54	}
55	return h, err
56}
57
58func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, _ protocol.VersionNumber) (*ExtendedHeader, error) {
59	h.KeyPhase = protocol.KeyPhaseZero
60	if h.typeByte&0x4 > 0 {
61		h.KeyPhase = protocol.KeyPhaseOne
62	}
63
64	if err := h.readPacketNumber(b); err != nil {
65		return nil, err
66	}
67	var err error
68	if h.typeByte&0x18 != 0 {
69		err = ErrInvalidReservedBits
70	}
71	return h, err
72}
73
74func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
75	h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1
76	switch h.PacketNumberLen {
77	case protocol.PacketNumberLen1:
78		n, err := b.ReadByte()
79		if err != nil {
80			return err
81		}
82		h.PacketNumber = protocol.PacketNumber(n)
83	case protocol.PacketNumberLen2:
84		n, err := utils.BigEndian.ReadUint16(b)
85		if err != nil {
86			return err
87		}
88		h.PacketNumber = protocol.PacketNumber(n)
89	case protocol.PacketNumberLen3:
90		n, err := utils.BigEndian.ReadUint24(b)
91		if err != nil {
92			return err
93		}
94		h.PacketNumber = protocol.PacketNumber(n)
95	case protocol.PacketNumberLen4:
96		n, err := utils.BigEndian.ReadUint32(b)
97		if err != nil {
98			return err
99		}
100		h.PacketNumber = protocol.PacketNumber(n)
101	default:
102		return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
103	}
104	return nil
105}
106
107// Write writes the Header.
108func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error {
109	if h.DestConnectionID.Len() > protocol.MaxConnIDLen {
110		return fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len())
111	}
112	if h.SrcConnectionID.Len() > protocol.MaxConnIDLen {
113		return fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len())
114	}
115	if h.OrigDestConnectionID.Len() > protocol.MaxConnIDLen {
116		return fmt.Errorf("invalid connection ID length: %d bytes", h.OrigDestConnectionID.Len())
117	}
118	if h.IsLongHeader {
119		return h.writeLongHeader(b, ver)
120	}
121	return h.writeShortHeader(b, ver)
122}
123
124func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, _ protocol.VersionNumber) error {
125	var packetType uint8
126	switch h.Type {
127	case protocol.PacketTypeInitial:
128		packetType = 0x0
129	case protocol.PacketType0RTT:
130		packetType = 0x1
131	case protocol.PacketTypeHandshake:
132		packetType = 0x2
133	case protocol.PacketTypeRetry:
134		packetType = 0x3
135	}
136	firstByte := 0xc0 | packetType<<4
137	if h.Type != protocol.PacketTypeRetry {
138		// Retry packets don't have a packet number
139		firstByte |= uint8(h.PacketNumberLen - 1)
140	}
141
142	b.WriteByte(firstByte)
143	utils.BigEndian.WriteUint32(b, uint32(h.Version))
144	b.WriteByte(uint8(h.DestConnectionID.Len()))
145	b.Write(h.DestConnectionID.Bytes())
146	b.WriteByte(uint8(h.SrcConnectionID.Len()))
147	b.Write(h.SrcConnectionID.Bytes())
148
149	switch h.Type {
150	case protocol.PacketTypeRetry:
151		b.WriteByte(uint8(h.OrigDestConnectionID.Len()))
152		b.Write(h.OrigDestConnectionID.Bytes())
153		b.Write(h.Token)
154		return nil
155	case protocol.PacketTypeInitial:
156		utils.WriteVarInt(b, uint64(len(h.Token)))
157		b.Write(h.Token)
158	}
159
160	utils.WriteVarInt(b, uint64(h.Length))
161	return h.writePacketNumber(b)
162}
163
164func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ protocol.VersionNumber) error {
165	typeByte := 0x40 | uint8(h.PacketNumberLen-1)
166	if h.KeyPhase == protocol.KeyPhaseOne {
167		typeByte |= byte(1 << 2)
168	}
169
170	b.WriteByte(typeByte)
171	b.Write(h.DestConnectionID.Bytes())
172	return h.writePacketNumber(b)
173}
174
175func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error {
176	switch h.PacketNumberLen {
177	case protocol.PacketNumberLen1:
178		b.WriteByte(uint8(h.PacketNumber))
179	case protocol.PacketNumberLen2:
180		utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
181	case protocol.PacketNumberLen3:
182		utils.BigEndian.WriteUint24(b, uint32(h.PacketNumber))
183	case protocol.PacketNumberLen4:
184		utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
185	default:
186		return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
187	}
188	return nil
189}
190
191// GetLength determines the length of the Header.
192func (h *ExtendedHeader) GetLength(v protocol.VersionNumber) protocol.ByteCount {
193	if h.IsLongHeader {
194		length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.Length))
195		if h.Type == protocol.PacketTypeInitial {
196			length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
197		}
198		return length
199	}
200
201	length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
202	length += protocol.ByteCount(h.PacketNumberLen)
203	return length
204}
205
206// Log logs the Header
207func (h *ExtendedHeader) Log(logger utils.Logger) {
208	if h.IsLongHeader {
209		var token string
210		if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
211			if len(h.Token) == 0 {
212				token = "Token: (empty), "
213			} else {
214				token = fmt.Sprintf("Token: %#x, ", h.Token)
215			}
216			if h.Type == protocol.PacketTypeRetry {
217				logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
218				return
219			}
220		}
221		logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
222	} else {
223		logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
224	}
225}
226