1// Package nl has low level primitives for making Netlink calls. 2package nl 3 4import ( 5 "bytes" 6 "encoding/binary" 7 "fmt" 8 "net" 9 "sync/atomic" 10 "syscall" 11 "unsafe" 12) 13 14const ( 15 // Family type definitions 16 FAMILY_ALL = syscall.AF_UNSPEC 17 FAMILY_V4 = syscall.AF_INET 18 FAMILY_V6 = syscall.AF_INET6 19) 20 21var nextSeqNr uint32 22 23// GetIPFamily returns the family type of a net.IP. 24func GetIPFamily(ip net.IP) int { 25 if len(ip) <= net.IPv4len { 26 return FAMILY_V4 27 } 28 if ip.To4() != nil { 29 return FAMILY_V4 30 } 31 return FAMILY_V6 32} 33 34var nativeEndian binary.ByteOrder 35 36// Get native endianness for the system 37func NativeEndian() binary.ByteOrder { 38 if nativeEndian == nil { 39 var x uint32 = 0x01020304 40 if *(*byte)(unsafe.Pointer(&x)) == 0x01 { 41 nativeEndian = binary.BigEndian 42 } else { 43 nativeEndian = binary.LittleEndian 44 } 45 } 46 return nativeEndian 47} 48 49// Byte swap a 16 bit value if we aren't big endian 50func Swap16(i uint16) uint16 { 51 if NativeEndian() == binary.BigEndian { 52 return i 53 } 54 return (i&0xff00)>>8 | (i&0xff)<<8 55} 56 57// Byte swap a 32 bit value if aren't big endian 58func Swap32(i uint32) uint32 { 59 if NativeEndian() == binary.BigEndian { 60 return i 61 } 62 return (i&0xff000000)>>24 | (i&0xff0000)>>8 | (i&0xff00)<<8 | (i&0xff)<<24 63} 64 65type NetlinkRequestData interface { 66 Len() int 67 Serialize() []byte 68} 69 70// IfInfomsg is related to links, but it is used for list requests as well 71type IfInfomsg struct { 72 syscall.IfInfomsg 73} 74 75// Create an IfInfomsg with family specified 76func NewIfInfomsg(family int) *IfInfomsg { 77 return &IfInfomsg{ 78 IfInfomsg: syscall.IfInfomsg{ 79 Family: uint8(family), 80 }, 81 } 82} 83 84func DeserializeIfInfomsg(b []byte) *IfInfomsg { 85 return (*IfInfomsg)(unsafe.Pointer(&b[0:syscall.SizeofIfInfomsg][0])) 86} 87 88func (msg *IfInfomsg) Serialize() []byte { 89 return (*(*[syscall.SizeofIfInfomsg]byte)(unsafe.Pointer(msg)))[:] 90} 91 92func (msg *IfInfomsg) Len() int { 93 return syscall.SizeofIfInfomsg 94} 95 96func rtaAlignOf(attrlen int) int { 97 return (attrlen + syscall.RTA_ALIGNTO - 1) & ^(syscall.RTA_ALIGNTO - 1) 98} 99 100func NewIfInfomsgChild(parent *RtAttr, family int) *IfInfomsg { 101 msg := NewIfInfomsg(family) 102 parent.children = append(parent.children, msg) 103 return msg 104} 105 106// Extend RtAttr to handle data and children 107type RtAttr struct { 108 syscall.RtAttr 109 Data []byte 110 children []NetlinkRequestData 111} 112 113// Create a new Extended RtAttr object 114func NewRtAttr(attrType int, data []byte) *RtAttr { 115 return &RtAttr{ 116 RtAttr: syscall.RtAttr{ 117 Type: uint16(attrType), 118 }, 119 children: []NetlinkRequestData{}, 120 Data: data, 121 } 122} 123 124// Create a new RtAttr obj anc add it as a child of an existing object 125func NewRtAttrChild(parent *RtAttr, attrType int, data []byte) *RtAttr { 126 attr := NewRtAttr(attrType, data) 127 parent.children = append(parent.children, attr) 128 return attr 129} 130 131func (a *RtAttr) Len() int { 132 if len(a.children) == 0 { 133 return (syscall.SizeofRtAttr + len(a.Data)) 134 } 135 136 l := 0 137 for _, child := range a.children { 138 l += rtaAlignOf(child.Len()) 139 } 140 l += syscall.SizeofRtAttr 141 return rtaAlignOf(l + len(a.Data)) 142} 143 144// Serialize the RtAttr into a byte array 145// This can't just unsafe.cast because it must iterate through children. 146func (a *RtAttr) Serialize() []byte { 147 native := NativeEndian() 148 149 length := a.Len() 150 buf := make([]byte, rtaAlignOf(length)) 151 152 if a.Data != nil { 153 copy(buf[4:], a.Data) 154 } else { 155 next := 4 156 for _, child := range a.children { 157 childBuf := child.Serialize() 158 copy(buf[next:], childBuf) 159 next += rtaAlignOf(len(childBuf)) 160 } 161 } 162 163 if l := uint16(length); l != 0 { 164 native.PutUint16(buf[0:2], l) 165 } 166 native.PutUint16(buf[2:4], a.Type) 167 return buf 168} 169 170type NetlinkRequest struct { 171 syscall.NlMsghdr 172 Data []NetlinkRequestData 173} 174 175// Serialize the Netlink Request into a byte array 176func (req *NetlinkRequest) Serialize() []byte { 177 length := syscall.SizeofNlMsghdr 178 dataBytes := make([][]byte, len(req.Data)) 179 for i, data := range req.Data { 180 dataBytes[i] = data.Serialize() 181 length = length + len(dataBytes[i]) 182 } 183 req.Len = uint32(length) 184 b := make([]byte, length) 185 hdr := (*(*[syscall.SizeofNlMsghdr]byte)(unsafe.Pointer(req)))[:] 186 next := syscall.SizeofNlMsghdr 187 copy(b[0:next], hdr) 188 for _, data := range dataBytes { 189 for _, dataByte := range data { 190 b[next] = dataByte 191 next = next + 1 192 } 193 } 194 return b 195} 196 197func (req *NetlinkRequest) AddData(data NetlinkRequestData) { 198 if data != nil { 199 req.Data = append(req.Data, data) 200 } 201} 202 203// Execute the request against a the given sockType. 204// Returns a list of netlink messages in seriaized format, optionally filtered 205// by resType. 206func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, error) { 207 s, err := getNetlinkSocket(sockType) 208 if err != nil { 209 return nil, err 210 } 211 defer s.Close() 212 213 if err := s.Send(req); err != nil { 214 return nil, err 215 } 216 217 pid, err := s.GetPid() 218 if err != nil { 219 return nil, err 220 } 221 222 var res [][]byte 223 224done: 225 for { 226 msgs, err := s.Receive() 227 if err != nil { 228 return nil, err 229 } 230 for _, m := range msgs { 231 if m.Header.Seq != req.Seq { 232 return nil, fmt.Errorf("Wrong Seq nr %d, expected 1", m.Header.Seq) 233 } 234 if m.Header.Pid != pid { 235 return nil, fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid) 236 } 237 if m.Header.Type == syscall.NLMSG_DONE { 238 break done 239 } 240 if m.Header.Type == syscall.NLMSG_ERROR { 241 native := NativeEndian() 242 error := int32(native.Uint32(m.Data[0:4])) 243 if error == 0 { 244 break done 245 } 246 return nil, syscall.Errno(-error) 247 } 248 if resType != 0 && m.Header.Type != resType { 249 continue 250 } 251 res = append(res, m.Data) 252 if m.Header.Flags&syscall.NLM_F_MULTI == 0 { 253 break done 254 } 255 } 256 } 257 return res, nil 258} 259 260// Create a new netlink request from proto and flags 261// Note the Len value will be inaccurate once data is added until 262// the message is serialized 263func NewNetlinkRequest(proto, flags int) *NetlinkRequest { 264 return &NetlinkRequest{ 265 NlMsghdr: syscall.NlMsghdr{ 266 Len: uint32(syscall.SizeofNlMsghdr), 267 Type: uint16(proto), 268 Flags: syscall.NLM_F_REQUEST | uint16(flags), 269 Seq: atomic.AddUint32(&nextSeqNr, 1), 270 }, 271 } 272} 273 274type NetlinkSocket struct { 275 fd int 276 lsa syscall.SockaddrNetlink 277} 278 279func getNetlinkSocket(protocol int) (*NetlinkSocket, error) { 280 fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, protocol) 281 if err != nil { 282 return nil, err 283 } 284 s := &NetlinkSocket{ 285 fd: fd, 286 } 287 s.lsa.Family = syscall.AF_NETLINK 288 if err := syscall.Bind(fd, &s.lsa); err != nil { 289 syscall.Close(fd) 290 return nil, err 291 } 292 293 return s, nil 294} 295 296// Create a netlink socket with a given protocol (e.g. NETLINK_ROUTE) 297// and subscribe it to multicast groups passed in variable argument list. 298// Returns the netlink socket on which Receive() method can be called 299// to retrieve the messages from the kernel. 300func Subscribe(protocol int, groups ...uint) (*NetlinkSocket, error) { 301 fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, protocol) 302 if err != nil { 303 return nil, err 304 } 305 s := &NetlinkSocket{ 306 fd: fd, 307 } 308 s.lsa.Family = syscall.AF_NETLINK 309 310 for _, g := range groups { 311 s.lsa.Groups |= (1 << (g - 1)) 312 } 313 314 if err := syscall.Bind(fd, &s.lsa); err != nil { 315 syscall.Close(fd) 316 return nil, err 317 } 318 319 return s, nil 320} 321 322func (s *NetlinkSocket) Close() { 323 syscall.Close(s.fd) 324} 325 326func (s *NetlinkSocket) Send(request *NetlinkRequest) error { 327 if err := syscall.Sendto(s.fd, request.Serialize(), 0, &s.lsa); err != nil { 328 return err 329 } 330 return nil 331} 332 333func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, error) { 334 rb := make([]byte, syscall.Getpagesize()) 335 nr, _, err := syscall.Recvfrom(s.fd, rb, 0) 336 if err != nil { 337 return nil, err 338 } 339 if nr < syscall.NLMSG_HDRLEN { 340 return nil, fmt.Errorf("Got short response from netlink") 341 } 342 rb = rb[:nr] 343 return syscall.ParseNetlinkMessage(rb) 344} 345 346func (s *NetlinkSocket) GetPid() (uint32, error) { 347 lsa, err := syscall.Getsockname(s.fd) 348 if err != nil { 349 return 0, err 350 } 351 switch v := lsa.(type) { 352 case *syscall.SockaddrNetlink: 353 return v.Pid, nil 354 } 355 return 0, fmt.Errorf("Wrong socket type") 356} 357 358func ZeroTerminated(s string) []byte { 359 bytes := make([]byte, len(s)+1) 360 for i := 0; i < len(s); i++ { 361 bytes[i] = s[i] 362 } 363 bytes[len(s)] = 0 364 return bytes 365} 366 367func NonZeroTerminated(s string) []byte { 368 bytes := make([]byte, len(s)) 369 for i := 0; i < len(s); i++ { 370 bytes[i] = s[i] 371 } 372 return bytes 373} 374 375func BytesToString(b []byte) string { 376 n := bytes.Index(b, []byte{0}) 377 return string(b[:n]) 378} 379 380func Uint8Attr(v uint8) []byte { 381 return []byte{byte(v)} 382} 383 384func Uint16Attr(v uint16) []byte { 385 native := NativeEndian() 386 bytes := make([]byte, 2) 387 native.PutUint16(bytes, v) 388 return bytes 389} 390 391func Uint32Attr(v uint32) []byte { 392 native := NativeEndian() 393 bytes := make([]byte, 4) 394 native.PutUint32(bytes, v) 395 return bytes 396} 397 398func ParseRouteAttr(b []byte) ([]syscall.NetlinkRouteAttr, error) { 399 var attrs []syscall.NetlinkRouteAttr 400 for len(b) >= syscall.SizeofRtAttr { 401 a, vbuf, alen, err := netlinkRouteAttrAndValue(b) 402 if err != nil { 403 return nil, err 404 } 405 ra := syscall.NetlinkRouteAttr{Attr: *a, Value: vbuf[:int(a.Len)-syscall.SizeofRtAttr]} 406 attrs = append(attrs, ra) 407 b = b[alen:] 408 } 409 return attrs, nil 410} 411 412func netlinkRouteAttrAndValue(b []byte) (*syscall.RtAttr, []byte, int, error) { 413 a := (*syscall.RtAttr)(unsafe.Pointer(&b[0])) 414 if int(a.Len) < syscall.SizeofRtAttr || int(a.Len) > len(b) { 415 return nil, nil, 0, syscall.EINVAL 416 } 417 return a, b[syscall.SizeofRtAttr:], rtaAlignOf(int(a.Len)), nil 418} 419