1// +build !windows
2
3package dns
4
5import (
6	"net"
7
8	"golang.org/x/net/ipv4"
9	"golang.org/x/net/ipv6"
10)
11
12// This is the required size of the OOB buffer to pass to ReadMsgUDP.
13var udpOOBSize = func() int {
14	// We can't know whether we'll get an IPv4 control message or an
15	// IPv6 control message ahead of time. To get around this, we size
16	// the buffer equal to the largest of the two.
17
18	oob4 := ipv4.NewControlMessage(ipv4.FlagDst | ipv4.FlagInterface)
19	oob6 := ipv6.NewControlMessage(ipv6.FlagDst | ipv6.FlagInterface)
20
21	if len(oob4) > len(oob6) {
22		return len(oob4)
23	}
24
25	return len(oob6)
26}()
27
// SessionUDP holds the remote address and the associated
// out-of-band data.
//
// ReadFromSessionUDP fills it in from a received packet, and
// WriteToSessionUDP uses it to address the reply.
type SessionUDP struct {
	raddr   *net.UDPAddr // source address reported by ReadMsgUDP
	context []byte       // control-message (OOB) bytes reported by ReadMsgUDP
}
34
35// RemoteAddr returns the remote network address.
36func (s *SessionUDP) RemoteAddr() net.Addr { return s.raddr }
37
38// ReadFromSessionUDP acts just like net.UDPConn.ReadFrom(), but returns a session object instead of a
39// net.UDPAddr.
40func ReadFromSessionUDP(conn *net.UDPConn, b []byte) (int, *SessionUDP, error) {
41	oob := make([]byte, udpOOBSize)
42	n, oobn, _, raddr, err := conn.ReadMsgUDP(b, oob)
43	if err != nil {
44		return n, nil, err
45	}
46	return n, &SessionUDP{raddr, oob[:oobn]}, err
47}
48
49// WriteToSessionUDP acts just like net.UDPConn.WriteTo(), but uses a *SessionUDP instead of a net.Addr.
50func WriteToSessionUDP(conn *net.UDPConn, b []byte, session *SessionUDP) (int, error) {
51	oob := correctSource(session.context)
52	n, _, err := conn.WriteMsgUDP(b, oob, session.raddr)
53	return n, err
54}
55
56func setUDPSocketOptions(conn *net.UDPConn) error {
57	// Try setting the flags for both families and ignore the errors unless they
58	// both error.
59	err6 := ipv6.NewPacketConn(conn).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true)
60	err4 := ipv4.NewPacketConn(conn).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true)
61	if err6 != nil && err4 != nil {
62		return err4
63	}
64	return nil
65}
66
67// parseDstFromOOB takes oob data and returns the destination IP.
68func parseDstFromOOB(oob []byte) net.IP {
69	// Start with IPv6 and then fallback to IPv4
70	// TODO(fastest963): Figure out a way to prefer one or the other. Looking at
71	// the lvl of the header for a 0 or 41 isn't cross-platform.
72	cm6 := new(ipv6.ControlMessage)
73	if cm6.Parse(oob) == nil && cm6.Dst != nil {
74		return cm6.Dst
75	}
76	cm4 := new(ipv4.ControlMessage)
77	if cm4.Parse(oob) == nil && cm4.Dst != nil {
78		return cm4.Dst
79	}
80	return nil
81}
82
83// correctSource takes oob data and returns new oob data with the Src equal to the Dst
84func correctSource(oob []byte) []byte {
85	dst := parseDstFromOOB(oob)
86	if dst == nil {
87		return nil
88	}
89	// If the dst is definitely an IPv6, then use ipv6's ControlMessage to
90	// respond otherwise use ipv4's because ipv6's marshal ignores ipv4
91	// addresses.
92	if dst.To4() == nil {
93		cm := new(ipv6.ControlMessage)
94		cm.Src = dst
95		oob = cm.Marshal()
96	} else {
97		cm := new(ipv4.ControlMessage)
98		cm.Src = dst
99		oob = cm.Marshal()
100	}
101	return oob
102}
103