1// Package nl has low level primitives for making Netlink calls.
2package nl
3
4import (
5	"bytes"
6	"encoding/binary"
7	"fmt"
8	"net"
9	"sync/atomic"
10	"syscall"
11	"unsafe"
12)
13
const (
	// Family type definitions. Each constant is the syscall address-family
	// value it is named after; GetIPFamily returns FAMILY_V4 or FAMILY_V6.
	FAMILY_ALL = syscall.AF_UNSPEC
	FAMILY_V4  = syscall.AF_INET
	FAMILY_V6  = syscall.AF_INET6
)
20
// nextSeqNr is the counter that NewNetlinkRequest advances with
// atomic.AddUint32 to obtain the Seq value of each new request.
var nextSeqNr uint32
22
23// GetIPFamily returns the family type of a net.IP.
24func GetIPFamily(ip net.IP) int {
25	if len(ip) <= net.IPv4len {
26		return FAMILY_V4
27	}
28	if ip.To4() != nil {
29		return FAMILY_V4
30	}
31	return FAMILY_V6
32}
33
// nativeEndian holds the byte order of the host. It is computed once, during
// package initialization, so NativeEndian never writes to it afterwards and
// can be called from several goroutines without a data race.
var nativeEndian = detectNativeEndian()

// detectNativeEndian inspects the first byte in memory of a known 32-bit
// value to decide whether the host stores the most significant byte first.
func detectNativeEndian() binary.ByteOrder {
	var probe uint32 = 0x01020304
	if *(*byte)(unsafe.Pointer(&probe)) == 0x01 {
		return binary.BigEndian
	}
	return binary.LittleEndian
}

// NativeEndian returns the native byte order of the system, either
// binary.BigEndian or binary.LittleEndian.
func NativeEndian() binary.ByteOrder {
	return nativeEndian
}
48
49// Byte swap a 16 bit value if we aren't big endian
50func Swap16(i uint16) uint16 {
51	if NativeEndian() == binary.BigEndian {
52		return i
53	}
54	return (i&0xff00)>>8 | (i&0xff)<<8
55}
56
57// Byte swap a 32 bit value if aren't big endian
58func Swap32(i uint32) uint32 {
59	if NativeEndian() == binary.BigEndian {
60		return i
61	}
62	return (i&0xff000000)>>24 | (i&0xff0000)>>8 | (i&0xff00)<<8 | (i&0xff)<<24
63}
64
// NetlinkRequestData is the payload abstraction used by NetlinkRequest.Data
// and by the children of an RtAttr.
type NetlinkRequestData interface {
	// Len returns the size in bytes the implementation reports for itself;
	// RtAttr.Len aligns and sums this value over its children.
	Len() int
	// Serialize returns the bytes that are copied into the outgoing message.
	Serialize() []byte
}
69
// IfInfomsg wraps syscall.IfInfomsg so it can be used as NetlinkRequestData.
// It is related to links, but it is used for list requests as well.
type IfInfomsg struct {
	syscall.IfInfomsg
}

// NewIfInfomsg creates an IfInfomsg whose Family is set to family (truncated
// to uint8); every other field keeps its zero value.
func NewIfInfomsg(family int) *IfInfomsg {
	msg := new(IfInfomsg)
	msg.Family = uint8(family)
	return msg
}

// DeserializeIfInfomsg reinterprets the first syscall.SizeofIfInfomsg bytes
// of b as an IfInfomsg. No copy is made: the result points into b's memory.
// The slice expression panics when b cannot be resliced to that length.
func DeserializeIfInfomsg(b []byte) *IfInfomsg {
	head := b[:syscall.SizeofIfInfomsg]
	return (*IfInfomsg)(unsafe.Pointer(&head[0]))
}

// Serialize returns a byte slice view over msg itself. No copy is made, so
// the slice and msg share memory.
func (msg *IfInfomsg) Serialize() []byte {
	raw := (*[syscall.SizeofIfInfomsg]byte)(unsafe.Pointer(msg))
	return raw[:]
}

// Len returns the fixed size of the message, syscall.SizeofIfInfomsg.
func (msg *IfInfomsg) Len() int {
	const size = syscall.SizeofIfInfomsg
	return size
}
95
// rtaAlignOf rounds attrlen up to the next multiple of syscall.RTA_ALIGNTO.
func rtaAlignOf(attrlen int) int {
	const mask = syscall.RTA_ALIGNTO - 1
	return (attrlen + mask) &^ mask
}
99
100func NewIfInfomsgChild(parent *RtAttr, family int) *IfInfomsg {
101	msg := NewIfInfomsg(family)
102	parent.children = append(parent.children, msg)
103	return msg
104}
105
// RtAttr extends syscall.RtAttr to handle data and children.
type RtAttr struct {
	syscall.RtAttr
	// Data is the raw payload supplied to NewRtAttr.
	Data []byte
	// children are the nested payloads appended by NewRtAttrChild and
	// NewIfInfomsgChild.
	children []NetlinkRequestData
}
112
113// Create a new Extended RtAttr object
114func NewRtAttr(attrType int, data []byte) *RtAttr {
115	return &RtAttr{
116		RtAttr: syscall.RtAttr{
117			Type: uint16(attrType),
118		},
119		children: []NetlinkRequestData{},
120		Data:     data,
121	}
122}
123
124// Create a new RtAttr obj anc add it as a child of an existing object
125func NewRtAttrChild(parent *RtAttr, attrType int, data []byte) *RtAttr {
126	attr := NewRtAttr(attrType, data)
127	parent.children = append(parent.children, attr)
128	return attr
129}
130
131func (a *RtAttr) Len() int {
132	if len(a.children) == 0 {
133		return (syscall.SizeofRtAttr + len(a.Data))
134	}
135
136	l := 0
137	for _, child := range a.children {
138		l += rtaAlignOf(child.Len())
139	}
140	l += syscall.SizeofRtAttr
141	return rtaAlignOf(l + len(a.Data))
142}
143
144// Serialize the RtAttr into a byte array
145// This can't just unsafe.cast because it must iterate through children.
146func (a *RtAttr) Serialize() []byte {
147	native := NativeEndian()
148
149	length := a.Len()
150	buf := make([]byte, rtaAlignOf(length))
151
152	if a.Data != nil {
153		copy(buf[4:], a.Data)
154	} else {
155		next := 4
156		for _, child := range a.children {
157			childBuf := child.Serialize()
158			copy(buf[next:], childBuf)
159			next += rtaAlignOf(len(childBuf))
160		}
161	}
162
163	if l := uint16(length); l != 0 {
164		native.PutUint16(buf[0:2], l)
165	}
166	native.PutUint16(buf[2:4], a.Type)
167	return buf
168}
169
// NetlinkRequest is a netlink message under construction: the embedded
// syscall.NlMsghdr is the header and Data holds the payloads that Serialize
// writes after it.
type NetlinkRequest struct {
	syscall.NlMsghdr
	Data []NetlinkRequestData
}
174
175// Serialize the Netlink Request into a byte array
176func (req *NetlinkRequest) Serialize() []byte {
177	length := syscall.SizeofNlMsghdr
178	dataBytes := make([][]byte, len(req.Data))
179	for i, data := range req.Data {
180		dataBytes[i] = data.Serialize()
181		length = length + len(dataBytes[i])
182	}
183	req.Len = uint32(length)
184	b := make([]byte, length)
185	hdr := (*(*[syscall.SizeofNlMsghdr]byte)(unsafe.Pointer(req)))[:]
186	next := syscall.SizeofNlMsghdr
187	copy(b[0:next], hdr)
188	for _, data := range dataBytes {
189		for _, dataByte := range data {
190			b[next] = dataByte
191			next = next + 1
192		}
193	}
194	return b
195}
196
197func (req *NetlinkRequest) AddData(data NetlinkRequestData) {
198	if data != nil {
199		req.Data = append(req.Data, data)
200	}
201}
202
203// Execute the request against a the given sockType.
204// Returns a list of netlink messages in seriaized format, optionally filtered
205// by resType.
206func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, error) {
207	s, err := getNetlinkSocket(sockType)
208	if err != nil {
209		return nil, err
210	}
211	defer s.Close()
212
213	if err := s.Send(req); err != nil {
214		return nil, err
215	}
216
217	pid, err := s.GetPid()
218	if err != nil {
219		return nil, err
220	}
221
222	var res [][]byte
223
224done:
225	for {
226		msgs, err := s.Receive()
227		if err != nil {
228			return nil, err
229		}
230		for _, m := range msgs {
231			if m.Header.Seq != req.Seq {
232				return nil, fmt.Errorf("Wrong Seq nr %d, expected 1", m.Header.Seq)
233			}
234			if m.Header.Pid != pid {
235				return nil, fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid)
236			}
237			if m.Header.Type == syscall.NLMSG_DONE {
238				break done
239			}
240			if m.Header.Type == syscall.NLMSG_ERROR {
241				native := NativeEndian()
242				error := int32(native.Uint32(m.Data[0:4]))
243				if error == 0 {
244					break done
245				}
246				return nil, syscall.Errno(-error)
247			}
248			if resType != 0 && m.Header.Type != resType {
249				continue
250			}
251			res = append(res, m.Data)
252			if m.Header.Flags&syscall.NLM_F_MULTI == 0 {
253				break done
254			}
255		}
256	}
257	return res, nil
258}
259
260// Create a new netlink request from proto and flags
261// Note the Len value will be inaccurate once data is added until
262// the message is serialized
263func NewNetlinkRequest(proto, flags int) *NetlinkRequest {
264	return &NetlinkRequest{
265		NlMsghdr: syscall.NlMsghdr{
266			Len:   uint32(syscall.SizeofNlMsghdr),
267			Type:  uint16(proto),
268			Flags: syscall.NLM_F_REQUEST | uint16(flags),
269			Seq:   atomic.AddUint32(&nextSeqNr, 1),
270		},
271	}
272}
273
// NetlinkSocket pairs a netlink socket file descriptor with the local
// address it was bound to.
type NetlinkSocket struct {
	// fd is the descriptor returned by syscall.Socket.
	fd int
	// lsa is the address passed to syscall.Bind; Send also uses it as the
	// destination of syscall.Sendto.
	lsa syscall.SockaddrNetlink
}
278
279func getNetlinkSocket(protocol int) (*NetlinkSocket, error) {
280	fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, protocol)
281	if err != nil {
282		return nil, err
283	}
284	s := &NetlinkSocket{
285		fd: fd,
286	}
287	s.lsa.Family = syscall.AF_NETLINK
288	if err := syscall.Bind(fd, &s.lsa); err != nil {
289		syscall.Close(fd)
290		return nil, err
291	}
292
293	return s, nil
294}
295
296// Create a netlink socket with a given protocol (e.g. NETLINK_ROUTE)
297// and subscribe it to multicast groups passed in variable argument list.
298// Returns the netlink socket on which Receive() method can be called
299// to retrieve the messages from the kernel.
300func Subscribe(protocol int, groups ...uint) (*NetlinkSocket, error) {
301	fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, protocol)
302	if err != nil {
303		return nil, err
304	}
305	s := &NetlinkSocket{
306		fd: fd,
307	}
308	s.lsa.Family = syscall.AF_NETLINK
309
310	for _, g := range groups {
311		s.lsa.Groups |= (1 << (g - 1))
312	}
313
314	if err := syscall.Bind(fd, &s.lsa); err != nil {
315		syscall.Close(fd)
316		return nil, err
317	}
318
319	return s, nil
320}
321
322func (s *NetlinkSocket) Close() {
323	syscall.Close(s.fd)
324}
325
326func (s *NetlinkSocket) Send(request *NetlinkRequest) error {
327	if err := syscall.Sendto(s.fd, request.Serialize(), 0, &s.lsa); err != nil {
328		return err
329	}
330	return nil
331}
332
333func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, error) {
334	rb := make([]byte, syscall.Getpagesize())
335	nr, _, err := syscall.Recvfrom(s.fd, rb, 0)
336	if err != nil {
337		return nil, err
338	}
339	if nr < syscall.NLMSG_HDRLEN {
340		return nil, fmt.Errorf("Got short response from netlink")
341	}
342	rb = rb[:nr]
343	return syscall.ParseNetlinkMessage(rb)
344}
345
346func (s *NetlinkSocket) GetPid() (uint32, error) {
347	lsa, err := syscall.Getsockname(s.fd)
348	if err != nil {
349		return 0, err
350	}
351	switch v := lsa.(type) {
352	case *syscall.SockaddrNetlink:
353		return v.Pid, nil
354	}
355	return 0, fmt.Errorf("Wrong socket type")
356}
357
// ZeroTerminated returns the bytes of s followed by a single zero byte.
func ZeroTerminated(s string) []byte {
	// make zero-fills the buffer, so the extra trailing byte is already 0.
	buf := make([]byte, len(s)+1)
	copy(buf, s)
	return buf
}
366
// NonZeroTerminated returns a fresh copy of the bytes of s, without any
// terminator.
func NonZeroTerminated(s string) []byte {
	buf := make([]byte, len(s))
	copy(buf, s)
	return buf
}
374
// BytesToString converts a zero-terminated byte buffer to a string: the
// result holds the bytes of b up to, but not including, the first zero byte.
// When b contains no zero byte the whole slice is converted; previously this
// case panicked with a slice-bounds error.
func BytesToString(b []byte) string {
	if n := bytes.IndexByte(b, 0); n >= 0 {
		return string(b[:n])
	}
	return string(b)
}
379
// Uint8Attr returns v as a one-byte slice.
func Uint8Attr(v uint8) []byte {
	return []byte{v}
}
383
384func Uint16Attr(v uint16) []byte {
385	native := NativeEndian()
386	bytes := make([]byte, 2)
387	native.PutUint16(bytes, v)
388	return bytes
389}
390
391func Uint32Attr(v uint32) []byte {
392	native := NativeEndian()
393	bytes := make([]byte, 4)
394	native.PutUint32(bytes, v)
395	return bytes
396}
397
398func ParseRouteAttr(b []byte) ([]syscall.NetlinkRouteAttr, error) {
399	var attrs []syscall.NetlinkRouteAttr
400	for len(b) >= syscall.SizeofRtAttr {
401		a, vbuf, alen, err := netlinkRouteAttrAndValue(b)
402		if err != nil {
403			return nil, err
404		}
405		ra := syscall.NetlinkRouteAttr{Attr: *a, Value: vbuf[:int(a.Len)-syscall.SizeofRtAttr]}
406		attrs = append(attrs, ra)
407		b = b[alen:]
408	}
409	return attrs, nil
410}
411
412func netlinkRouteAttrAndValue(b []byte) (*syscall.RtAttr, []byte, int, error) {
413	a := (*syscall.RtAttr)(unsafe.Pointer(&b[0]))
414	if int(a.Len) < syscall.SizeofRtAttr || int(a.Len) > len(b) {
415		return nil, nil, 0, syscall.EINVAL
416	}
417	return a, b[syscall.SizeofRtAttr:], rtaAlignOf(int(a.Len)), nil
418}
419