1// +build darwin linux freebsd
2
3package quic
4
5import (
6	"encoding/binary"
7	"errors"
8	"fmt"
9	"net"
10	"runtime"
11	"syscall"
12	"time"
13	"unsafe"
14
15	"golang.org/x/net/ipv4"
16	"golang.org/x/net/ipv6"
17	"golang.org/x/sys/unix"
18
19	"github.com/lucas-clemente/quic-go/internal/protocol"
20	"github.com/lucas-clemente/quic-go/internal/utils"
21)
22
// ecnMask selects the two least significant bits of the IPv4 TOS / IPv6 traffic class
// byte; ReadPacket interprets the masked value as the packet's ECN bits.
// oobBufferSize is the size in bytes of the out-of-band (control message) buffer that
// newConn allocates for every message of a read batch.
const (
	ecnMask       = 0x3
	oobBufferSize = 128
)
27
// Contrary to what the naming suggests, ipv4.Message and ipv6.Message do not depend on the IP version.
// They're both just aliases for x/net/internal/socket.Message.
// This means we can use a single message type to read from a socket that receives both IPv4 and IPv6 messages.
// The assignment below is a compile-time check: it stops compiling if the two types ever diverge.
var _ ipv4.Message = ipv6.Message{}
32
// batchConn is a connection that can read several messages with a single call.
// newConn uses the wrapped connection directly if it implements this interface,
// and falls back to an *ipv4.PacketConn otherwise.
type batchConn interface {
	// ReadBatch fills ms with received messages and returns how many were read.
	ReadBatch(ms []ipv4.Message, flags int) (int, error)
}
36
37func inspectReadBuffer(c interface{}) (int, error) {
38	conn, ok := c.(interface {
39		SyscallConn() (syscall.RawConn, error)
40	})
41	if !ok {
42		return 0, errors.New("doesn't have a SyscallConn")
43	}
44	rawConn, err := conn.SyscallConn()
45	if err != nil {
46		return 0, fmt.Errorf("couldn't get syscall.RawConn: %w", err)
47	}
48	var size int
49	var serr error
50	if err := rawConn.Control(func(fd uintptr) {
51		size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF)
52	}); err != nil {
53		return 0, err
54	}
55	return size, serr
56}
57
// oobConn wraps an OOBCapablePacketConn. It reads packets from the kernel in
// batches and hands them out one at a time via ReadPacket, together with the
// information parsed from each packet's control messages.
type oobConn struct {
	OOBCapablePacketConn
	// batchConn is what ReadPacket calls ReadBatch on.
	batchConn batchConn

	// readPos is the index into messages of the next message ReadPacket returns.
	readPos uint8
	// messages holds the packets received from the kernel, but not yet returned by ReadPacket().
	// buffers[i] is the packet buffer whose Data backs messages[i].
	messages []ipv4.Message
	buffers  [batchSize]*packetBuffer
}

// Compile-time check that *oobConn satisfies the connection interface.
var _ connection = &oobConn{}
69
70func newConn(c OOBCapablePacketConn) (*oobConn, error) {
71	rawConn, err := c.SyscallConn()
72	if err != nil {
73		return nil, err
74	}
75	needsPacketInfo := false
76	if udpAddr, ok := c.LocalAddr().(*net.UDPAddr); ok && udpAddr.IP.IsUnspecified() {
77		needsPacketInfo = true
78	}
79	// We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection.
80	// Try enabling receiving of ECN and packet info for both IP versions.
81	// We expect at least one of those syscalls to succeed.
82	var errECNIPv4, errECNIPv6, errPIIPv4, errPIIPv6 error
83	if err := rawConn.Control(func(fd uintptr) {
84		errECNIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1)
85		errECNIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1)
86
87		if needsPacketInfo {
88			errPIIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, ipv4RECVPKTINFO, 1)
89			errPIIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, ipv6RECVPKTINFO, 1)
90		}
91	}); err != nil {
92		return nil, err
93	}
94	switch {
95	case errECNIPv4 == nil && errECNIPv6 == nil:
96		utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4 and IPv6.")
97	case errECNIPv4 == nil && errECNIPv6 != nil:
98		utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4.")
99	case errECNIPv4 != nil && errECNIPv6 == nil:
100		utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv6.")
101	case errECNIPv4 != nil && errECNIPv6 != nil:
102		return nil, errors.New("activating ECN failed for both IPv4 and IPv6")
103	}
104	if needsPacketInfo {
105		switch {
106		case errPIIPv4 == nil && errPIIPv6 == nil:
107			utils.DefaultLogger.Debugf("Activating reading of packet info for IPv4 and IPv6.")
108		case errPIIPv4 == nil && errPIIPv6 != nil:
109			utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv4.")
110		case errPIIPv4 != nil && errPIIPv6 == nil:
111			utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv6.")
112		case errPIIPv4 != nil && errPIIPv6 != nil:
113			return nil, errors.New("activating packet info failed for both IPv4 and IPv6")
114		}
115	}
116
117	// Allows callers to pass in a connection that already satisfies batchConn interface
118	// to make use of the optimisation. Otherwise, ipv4.NewPacketConn would unwrap the file descriptor
119	// via SyscallConn(), and read it that way, which might not be what the caller wants.
120	var bc batchConn
121	if ibc, ok := c.(batchConn); ok {
122		bc = ibc
123	} else {
124		bc = ipv4.NewPacketConn(c)
125	}
126
127	oobConn := &oobConn{
128		OOBCapablePacketConn: c,
129		batchConn:            bc,
130		messages:             make([]ipv4.Message, batchSize),
131		readPos:              batchSize,
132	}
133	for i := 0; i < batchSize; i++ {
134		oobConn.messages[i].OOB = make([]byte, oobBufferSize)
135	}
136	return oobConn, nil
137}
138
139func (c *oobConn) ReadPacket() (*receivedPacket, error) {
140	if len(c.messages) == int(c.readPos) { // all messages read. Read the next batch of messages.
141		c.messages = c.messages[:batchSize]
142		// replace buffers data buffers up to the packet that has been consumed during the last ReadBatch call
143		for i := uint8(0); i < c.readPos; i++ {
144			buffer := getPacketBuffer()
145			buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize]
146			c.buffers[i] = buffer
147			c.messages[i].Buffers = [][]byte{c.buffers[i].Data}
148		}
149		c.readPos = 0
150
151		n, err := c.batchConn.ReadBatch(c.messages, 0)
152		if n == 0 || err != nil {
153			return nil, err
154		}
155		c.messages = c.messages[:n]
156	}
157
158	msg := c.messages[c.readPos]
159	buffer := c.buffers[c.readPos]
160	c.readPos++
161	ctrlMsgs, err := unix.ParseSocketControlMessage(msg.OOB[:msg.NN])
162	if err != nil {
163		return nil, err
164	}
165	var ecn protocol.ECN
166	var destIP net.IP
167	var ifIndex uint32
168	for _, ctrlMsg := range ctrlMsgs {
169		if ctrlMsg.Header.Level == unix.IPPROTO_IP {
170			switch ctrlMsg.Header.Type {
171			case msgTypeIPTOS:
172				ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask)
173			case msgTypeIPv4PKTINFO:
174				// struct in_pktinfo {
175				// 	unsigned int   ipi_ifindex;  /* Interface index */
176				// 	struct in_addr ipi_spec_dst; /* Local address */
177				// 	struct in_addr ipi_addr;     /* Header Destination
178				// 									address */
179				// };
180				ip := make([]byte, 4)
181				if len(ctrlMsg.Data) == 12 {
182					ifIndex = binary.LittleEndian.Uint32(ctrlMsg.Data)
183					copy(ip, ctrlMsg.Data[8:12])
184				} else if len(ctrlMsg.Data) == 4 {
185					// FreeBSD
186					copy(ip, ctrlMsg.Data)
187				}
188				destIP = net.IP(ip)
189			}
190		}
191		if ctrlMsg.Header.Level == unix.IPPROTO_IPV6 {
192			switch ctrlMsg.Header.Type {
193			case unix.IPV6_TCLASS:
194				ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask)
195			case msgTypeIPv6PKTINFO:
196				// struct in6_pktinfo {
197				// 	struct in6_addr ipi6_addr;    /* src/dst IPv6 address */
198				// 	unsigned int    ipi6_ifindex; /* send/recv interface index */
199				// };
200				if len(ctrlMsg.Data) == 20 {
201					ip := make([]byte, 16)
202					copy(ip, ctrlMsg.Data[:16])
203					destIP = net.IP(ip)
204					ifIndex = binary.LittleEndian.Uint32(ctrlMsg.Data[16:])
205				}
206			}
207		}
208	}
209	var info *packetInfo
210	if destIP != nil {
211		info = &packetInfo{
212			addr:    destIP,
213			ifIndex: ifIndex,
214		}
215	}
216	return &receivedPacket{
217		remoteAddr: msg.Addr,
218		rcvTime:    time.Now(),
219		data:       msg.Buffers[0][:msg.N],
220		ecn:        ecn,
221		info:       info,
222		buffer:     buffer,
223	}, nil
224}
225
226func (c *oobConn) WritePacket(b []byte, addr net.Addr, oob []byte) (n int, err error) {
227	n, _, err = c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr))
228	return n, err
229}
230
231func (info *packetInfo) OOB() []byte {
232	if info == nil {
233		return nil
234	}
235	if ip4 := info.addr.To4(); ip4 != nil {
236		// struct in_pktinfo {
237		// 	unsigned int   ipi_ifindex;  /* Interface index */
238		// 	struct in_addr ipi_spec_dst; /* Local address */
239		// 	struct in_addr ipi_addr;     /* Header Destination address */
240		// };
241		msgLen := 12
242		if runtime.GOOS == "freebsd" {
243			msgLen = 4
244		}
245		cmsglen := cmsgLen(msgLen)
246		oob := make([]byte, cmsglen)
247		cmsg := (*syscall.Cmsghdr)(unsafe.Pointer(&oob[0]))
248		cmsg.Level = syscall.IPPROTO_TCP
249		cmsg.Type = msgTypeIPv4PKTINFO
250		cmsg.SetLen(cmsglen)
251		off := cmsgLen(0)
252		if runtime.GOOS != "freebsd" {
253			// FreeBSD does not support in_pktinfo, just an in_addr is sent
254			binary.LittleEndian.PutUint32(oob[off:], info.ifIndex)
255			off += 4
256		}
257		copy(oob[off:], ip4)
258		return oob
259	} else if len(info.addr) == 16 {
260		// struct in6_pktinfo {
261		// 	struct in6_addr ipi6_addr;    /* src/dst IPv6 address */
262		// 	unsigned int    ipi6_ifindex; /* send/recv interface index */
263		// };
264		const msgLen = 20
265		cmsglen := cmsgLen(msgLen)
266		oob := make([]byte, cmsglen)
267		cmsg := (*syscall.Cmsghdr)(unsafe.Pointer(&oob[0]))
268		cmsg.Level = syscall.IPPROTO_IPV6
269		cmsg.Type = msgTypeIPv6PKTINFO
270		cmsg.SetLen(cmsglen)
271		off := cmsgLen(0)
272		off += copy(oob[off:], info.addr)
273		binary.LittleEndian.PutUint32(oob[off:], info.ifIndex)
274		return oob
275	}
276	return nil
277}
278
279func cmsgLen(datalen int) int {
280	return cmsgAlign(syscall.SizeofCmsghdr) + datalen
281}
282
// cmsgAlign rounds salen up to the next multiple of 8.
//
// NOTE(review): the alignment is fixed at 8 bytes regardless of platform —
// confirm this matches the kernel's control message alignment on every
// targeted OS.
func cmsgAlign(salen int) int {
	const (
		alignment = 0x8
		mask      = alignment - 1
	)
	// Adding mask and then clearing the low bits rounds up to a multiple of alignment.
	return (salen + mask) &^ mask
}
288