1package wire
2
3import (
4	"bytes"
5	"crypto/rand"
6	"errors"
7	"fmt"
8
9	"github.com/ooni/psiphon/oopsi/github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/protocol"
10	"github.com/ooni/psiphon/oopsi/github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/utils"
11)
12
// Header is the header of a QUIC packet.
// One Header type serves both wire formats: the gQUIC Public Header (written
// when the version does not use the IETF header format) and the IETF draft
// header in its long and short forms. Which fields are meaningful depends on
// the format being written.
type Header struct {
	// IsPublicHeader is set by Write when it emits a gQUIC Public Header;
	// Log uses it to pick the Public Header log format.
	IsPublicHeader bool

	// NOTE(review): Raw is not read or written by the code in this file;
	// presumably it holds the raw header bytes recorded by the parser — confirm.
	Raw []byte

	// Version is written in every long header and, by clients, in Public
	// Headers with VersionFlag set. When logging a long header, a zero Version
	// identifies a Version Negotiation packet.
	Version protocol.VersionNumber

	// Long headers carry both DestConnectionID and SrcConnectionID, each either
	// empty or 4 to 18 bytes long. Short headers and Public Headers carry only
	// DestConnectionID (for a Public Header it must be empty or 8 bytes).
	DestConnectionID     protocol.ConnectionID
	SrcConnectionID      protocol.ConnectionID
	OrigDestConnectionID protocol.ConnectionID // only needed in the Retry packet

	// PacketNumberLen is the number of bytes used to encode PacketNumber.
	// Long headers of versions without varint packet numbers always encode the
	// packet number in 4 bytes, independent of PacketNumberLen.
	PacketNumberLen protocol.PacketNumberLen
	PacketNumber    protocol.PacketNumber

	// NOTE(review): IsVersionNegotiation is not used in this file, and
	// SupportedVersions is only logged here.
	IsVersionNegotiation bool
	SupportedVersions    []protocol.VersionNumber // Version Numbers sent in a Version Negotiation Packet by the server

	// only needed for the gQUIC Public Header
	// Write rejects Public Headers with ResetFlag set, or with VersionFlag set
	// when writing as the server. A non-empty DiversificationNonce must be
	// exactly 32 bytes. The nonce is also written in Version44 0-RTT long
	// headers, which require exactly 32 bytes.
	VersionFlag          bool
	ResetFlag            bool
	DiversificationNonce []byte

	// only needed for the IETF Header
	// KeyPhase is shifted into bit 6 of a short header's first byte.
	// PayloadLen is written as a varint for versions that put a length in the
	// header. Token is written length-prefixed in Initial packets (for versions
	// that carry a token in the header) and unprefixed at the end of Retry
	// packets.
	Type         protocol.PacketType
	IsLongHeader bool
	KeyPhase     int
	PayloadLen   protocol.ByteCount
	Token        []byte
}
44
var (
	// errInvalidPacketNumberLen reports a PacketNumberLen that the requested
	// header format cannot encode.
	errInvalidPacketNumberLen = errors.New("invalid packet number length")
)
46
47// Write writes the Header.
48func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, ver protocol.VersionNumber) error {
49	if !ver.UsesIETFHeaderFormat() {
50		h.IsPublicHeader = true // save that this is a Public Header, so we can log it correctly later
51		return h.writePublicHeader(b, pers, ver)
52	}
53	// write an IETF QUIC header
54	if h.IsLongHeader {
55		return h.writeLongHeader(b, ver)
56	}
57	return h.writeShortHeader(b, ver)
58}
59
60// TODO: add support for the key phase
61func (h *Header) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
62	b.WriteByte(byte(0x80 | h.Type))
63	utils.BigEndian.WriteUint32(b, uint32(h.Version))
64	connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
65	if err != nil {
66		return err
67	}
68	b.WriteByte(connIDLen)
69	b.Write(h.DestConnectionID.Bytes())
70	b.Write(h.SrcConnectionID.Bytes())
71
72	if h.Type == protocol.PacketTypeInitial && v.UsesTokenInHeader() {
73		utils.WriteVarInt(b, uint64(len(h.Token)))
74		b.Write(h.Token)
75	}
76
77	if h.Type == protocol.PacketTypeRetry {
78		odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID)
79		if err != nil {
80			return err
81		}
82		// randomize the first 4 bits
83		odcilByte := make([]byte, 1)
84		_, _ = rand.Read(odcilByte) // it's safe to ignore the error here
85		odcilByte[0] = (odcilByte[0] & 0xf0) | odcil
86		b.Write(odcilByte)
87		b.Write(h.OrigDestConnectionID.Bytes())
88		b.Write(h.Token)
89		return nil
90	}
91
92	if v.UsesLengthInHeader() {
93		utils.WriteVarInt(b, uint64(h.PayloadLen))
94	}
95	if v.UsesVarintPacketNumbers() {
96		return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
97	}
98	utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
99	if h.Type == protocol.PacketType0RTT && v == protocol.Version44 {
100		if len(h.DiversificationNonce) != 32 {
101			return errors.New("invalid diversification nonce length")
102		}
103		b.Write(h.DiversificationNonce)
104	}
105	return nil
106}
107
108func (h *Header) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
109	typeByte := byte(0x30)
110	typeByte |= byte(h.KeyPhase << 6)
111	if !v.UsesVarintPacketNumbers() {
112		switch h.PacketNumberLen {
113		case protocol.PacketNumberLen1:
114		case protocol.PacketNumberLen2:
115			typeByte |= 0x1
116		case protocol.PacketNumberLen4:
117			typeByte |= 0x2
118		default:
119			return errInvalidPacketNumberLen
120		}
121	}
122
123	b.WriteByte(typeByte)
124	b.Write(h.DestConnectionID.Bytes())
125
126	if !v.UsesVarintPacketNumbers() {
127		switch h.PacketNumberLen {
128		case protocol.PacketNumberLen1:
129			b.WriteByte(uint8(h.PacketNumber))
130		case protocol.PacketNumberLen2:
131			utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
132		case protocol.PacketNumberLen4:
133			utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
134		}
135		return nil
136	}
137	return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
138}
139
140// writePublicHeader writes a Public Header.
141func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ protocol.VersionNumber) error {
142	if h.ResetFlag || (h.VersionFlag && pers == protocol.PerspectiveServer) {
143		return errors.New("PublicHeader: Can only write regular packets")
144	}
145	if h.SrcConnectionID.Len() != 0 {
146		return errors.New("PublicHeader: SrcConnectionID must not be set")
147	}
148	if len(h.DestConnectionID) != 0 && len(h.DestConnectionID) != 8 {
149		return fmt.Errorf("PublicHeader: wrong length for Connection ID: %d (expected 8)", len(h.DestConnectionID))
150	}
151
152	publicFlagByte := uint8(0x00)
153	if h.VersionFlag {
154		publicFlagByte |= 0x01
155	}
156	if h.DestConnectionID.Len() > 0 {
157		publicFlagByte |= 0x08
158	}
159	if len(h.DiversificationNonce) > 0 {
160		if len(h.DiversificationNonce) != 32 {
161			return errors.New("invalid diversification nonce length")
162		}
163		publicFlagByte |= 0x04
164	}
165	switch h.PacketNumberLen {
166	case protocol.PacketNumberLen1:
167		publicFlagByte |= 0x00
168	case protocol.PacketNumberLen2:
169		publicFlagByte |= 0x10
170	case protocol.PacketNumberLen4:
171		publicFlagByte |= 0x20
172	}
173	b.WriteByte(publicFlagByte)
174
175	if h.DestConnectionID.Len() > 0 {
176		b.Write(h.DestConnectionID)
177	}
178	if h.VersionFlag && pers == protocol.PerspectiveClient {
179		utils.BigEndian.WriteUint32(b, uint32(h.Version))
180	}
181	if len(h.DiversificationNonce) > 0 {
182		b.Write(h.DiversificationNonce)
183	}
184
185	switch h.PacketNumberLen {
186	case protocol.PacketNumberLen1:
187		b.WriteByte(uint8(h.PacketNumber))
188	case protocol.PacketNumberLen2:
189		utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
190	case protocol.PacketNumberLen4:
191		utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
192	case protocol.PacketNumberLen6:
193		return errInvalidPacketNumberLen
194	default:
195		return errors.New("PublicHeader: PacketNumberLen not set")
196	}
197
198	return nil
199}
200
201// GetLength determines the length of the Header.
202func (h *Header) GetLength(v protocol.VersionNumber) (protocol.ByteCount, error) {
203	if !v.UsesIETFHeaderFormat() {
204		return h.getPublicHeaderLength()
205	}
206	return h.getHeaderLength(v)
207}
208
209func (h *Header) getHeaderLength(v protocol.VersionNumber) (protocol.ByteCount, error) {
210	if h.IsLongHeader {
211		length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen)
212		if v.UsesLengthInHeader() {
213			length += utils.VarIntLen(uint64(h.PayloadLen))
214		}
215		if h.Type == protocol.PacketTypeInitial && v.UsesTokenInHeader() {
216			length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
217		}
218		if h.Type == protocol.PacketType0RTT && v == protocol.Version44 {
219			length += protocol.ByteCount(len(h.DiversificationNonce))
220		}
221		return length, nil
222	}
223
224	length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
225	if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
226		return 0, fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
227	}
228	length += protocol.ByteCount(h.PacketNumberLen)
229	return length, nil
230}
231
232// getPublicHeaderLength gets the length of the publicHeader in bytes.
233// It can only be called for regular packets.
234func (h *Header) getPublicHeaderLength() (protocol.ByteCount, error) {
235	length := protocol.ByteCount(1) // 1 byte for public flags
236	if h.PacketNumberLen == protocol.PacketNumberLen6 {
237		return 0, errInvalidPacketNumberLen
238	}
239	if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
240		return 0, errPacketNumberLenNotSet
241	}
242	length += protocol.ByteCount(h.PacketNumberLen)
243	length += protocol.ByteCount(h.DestConnectionID.Len())
244	// Version Number in packets sent by the client
245	if h.VersionFlag {
246		length += 4
247	}
248	length += protocol.ByteCount(len(h.DiversificationNonce))
249	return length, nil
250}
251
252// Log logs the Header
253func (h *Header) Log(logger utils.Logger) {
254	if h.IsPublicHeader {
255		h.logPublicHeader(logger)
256	} else {
257		h.logHeader(logger)
258	}
259}
260
261func (h *Header) logHeader(logger utils.Logger) {
262	if h.IsLongHeader {
263		if h.Version == 0 {
264			logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
265		} else {
266			var token string
267			if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
268				if len(h.Token) == 0 {
269					token = "Token: (empty), "
270				} else {
271					token = fmt.Sprintf("Token: %#x, ", h.Token)
272				}
273			}
274			if h.Type == protocol.PacketTypeRetry {
275				logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
276				return
277			}
278			if h.Version == protocol.Version44 {
279				var divNonce string
280				if h.Type == protocol.PacketType0RTT {
281					divNonce = fmt.Sprintf("Diversification Nonce: %#x, ", h.DiversificationNonce)
282				}
283				logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.PacketNumberLen, divNonce, h.Version)
284				return
285			}
286			logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.PayloadLen, h.Version)
287		}
288	} else {
289		logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
290	}
291}
292
293func (h *Header) logPublicHeader(logger utils.Logger) {
294	ver := "(unset)"
295	if h.Version != 0 {
296		ver = h.Version.String()
297	}
298	logger.Debugf("\tPublic Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce)
299}
300
301func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
302	dcil, err := encodeSingleConnIDLen(dest)
303	if err != nil {
304		return 0, err
305	}
306	scil, err := encodeSingleConnIDLen(src)
307	if err != nil {
308		return 0, err
309	}
310	return scil | dcil<<4, nil
311}
312
313func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
314	len := id.Len()
315	if len == 0 {
316		return 0, nil
317	}
318	if len < 4 || len > 18 {
319		return 0, fmt.Errorf("invalid connection ID length: %d bytes", len)
320	}
321	return byte(len - 3), nil
322}
323
324func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) {
325	return decodeSingleConnIDLen(enc >> 4), decodeSingleConnIDLen(enc & 0xf)
326}
327
// decodeSingleConnIDLen reverses encodeSingleConnIDLen: 0 means an empty
// connection ID, and any other value v means a connection ID of v+3 bytes.
func decodeSingleConnIDLen(enc uint8) int {
	switch enc {
	case 0:
		return 0
	default:
		return 3 + int(enc)
	}
}
334