1package rtnetlink
2
3import (
4	"errors"
5	"fmt"
6	"net"
7
8	"github.com/jsimonetti/rtnetlink/internal/unix"
9
10	"github.com/mdlayher/netlink"
11)
12
var (
	// errInvalidNeighMessage is returned when a NeighMessage is malformed,
	// for example when it is shorter than the fixed ndmsg header.
	errInvalidNeighMessage = errors.New("rtnetlink NeighMessage is invalid or too short")

	// errInvalidNeighMessageAttr is returned when a neighbor attribute has
	// an unexpected payload length.
	errInvalidNeighMessageAttr = errors.New("rtnetlink NeighMessage has a wrong attribute data length")
)
20
21var _ Message = &NeighMessage{}
22
23// A NeighMessage is a route netlink neighbor message.
24type NeighMessage struct {
25	// Always set to AF_UNSPEC (0)
26	Family uint16
27
28	// Unique interface index
29	Index uint32
30
31	// Neighbor State is a bitmask of neighbor states (see rtnetlink(7))
32	State uint16
33
34	// Neighbor flags
35	Flags uint8
36
37	// Neighbor type
38	Type uint8
39
40	// Attributes List
41	Attributes *NeighAttributes
42}
43
44// MarshalBinary marshals a NeighMessage into a byte slice.
45func (m *NeighMessage) MarshalBinary() ([]byte, error) {
46	b := make([]byte, unix.SizeofNdMsg)
47
48	nativeEndian.PutUint16(b[0:2], m.Family)
49	// bytes 3 and 4 are padding
50	nativeEndian.PutUint32(b[4:8], m.Index)
51	nativeEndian.PutUint16(b[8:10], m.State)
52	b[10] = m.Flags
53	b[11] = m.Type
54
55	if m.Attributes != nil {
56		ae := netlink.NewAttributeEncoder()
57		ae.ByteOrder = nativeEndian
58		err := m.Attributes.encode(ae)
59		if err != nil {
60			return nil, err
61		}
62
63		a, err := ae.Encode()
64		if err != nil {
65			return nil, err
66		}
67
68		return append(b, a...), nil
69	}
70	return b, nil
71}
72
73// UnmarshalBinary unmarshals the contents of a byte slice into a NeighMessage.
74func (m *NeighMessage) UnmarshalBinary(b []byte) error {
75	l := len(b)
76	if l < unix.SizeofNdMsg {
77		return errInvalidNeighMessage
78	}
79
80	m.Family = nativeEndian.Uint16(b[0:2])
81	m.Index = nativeEndian.Uint32(b[4:8])
82	m.State = nativeEndian.Uint16(b[8:10])
83	m.Flags = b[10]
84	m.Type = b[11]
85
86	if l > unix.SizeofNdMsg {
87		m.Attributes = &NeighAttributes{}
88		ad, err := netlink.NewAttributeDecoder(b[unix.SizeofNdMsg:])
89		if err != nil {
90			return err
91		}
92		ad.ByteOrder = nativeEndian
93		err = m.Attributes.decode(ad)
94		if err != nil {
95			return err
96		}
97	}
98
99	return nil
100}
101
102// rtMessage is an empty method to sattisfy the Message interface.
103func (*NeighMessage) rtMessage() {}
104
105// NeighService is used to retrieve rtnetlink family information.
106type NeighService struct {
107	c *Conn
108}
109
110// New creates a new interface using the LinkMessage information.
111func (l *NeighService) New(req *NeighMessage) error {
112	flags := netlink.Request | netlink.Create | netlink.Acknowledge | netlink.Excl
113	_, err := l.c.Execute(req, unix.RTM_NEWNEIGH, flags)
114	if err != nil {
115		return err
116	}
117
118	return nil
119}
120
121// Delete removes an neighbor entry by index.
122func (l *NeighService) Delete(index uint32) error {
123	req := &NeighMessage{}
124
125	flags := netlink.Request | netlink.Acknowledge
126	_, err := l.c.Execute(req, unix.RTM_DELNEIGH, flags)
127	if err != nil {
128		return err
129	}
130
131	return nil
132}
133
134// List retrieves all neighbors.
135func (l *NeighService) List() ([]NeighMessage, error) {
136	req := NeighMessage{}
137
138	flags := netlink.Request | netlink.Dump
139	msgs, err := l.c.Execute(&req, unix.RTM_GETNEIGH, flags)
140	if err != nil {
141		return nil, err
142	}
143
144	neighs := make([]NeighMessage, len(msgs))
145	for i := range msgs {
146		neighs[i] = *msgs[i].(*NeighMessage)
147	}
148
149	return neighs, nil
150}
151
152// NeighCacheInfo contains neigh information
153type NeighCacheInfo struct {
154	Confirmed uint32
155	Used      uint32
156	Updated   uint32
157	RefCount  uint32
158}
159
160// UnmarshalBinary unmarshals the contents of a byte slice into a NeighMessage.
161func (n *NeighCacheInfo) unmarshalBinary(b []byte) error {
162	if len(b) != 16 {
163		return fmt.Errorf("incorrect size, want: 16, got: %d", len(b))
164	}
165
166	n.Confirmed = nativeEndian.Uint32(b[0:4])
167	n.Used = nativeEndian.Uint32(b[4:8])
168	n.Updated = nativeEndian.Uint32(b[8:12])
169	n.RefCount = nativeEndian.Uint32(b[12:16])
170
171	return nil
172}
173
174// NeighAttributes contains all attributes for a neighbor.
175type NeighAttributes struct {
176	Address   net.IP           // a neighbor cache n/w layer destination address
177	LLAddress net.HardwareAddr // a neighbor cache link layer address
178	CacheInfo *NeighCacheInfo  // cache statistics
179	IfIndex   uint32
180}
181
182func (a *NeighAttributes) decode(ad *netlink.AttributeDecoder) error {
183	for ad.Next() {
184		switch ad.Type() {
185		case unix.NDA_UNSPEC:
186			// unused attribute
187		case unix.NDA_DST:
188			l := len(ad.Bytes())
189			if l != 4 && l != 16 {
190				return errInvalidNeighMessageAttr
191			}
192			a.Address = ad.Bytes()
193		case unix.NDA_LLADDR:
194			if len(ad.Bytes()) != 6 {
195				return errInvalidNeighMessageAttr
196			}
197			a.LLAddress = ad.Bytes()
198		case unix.NDA_CACHEINFO:
199			a.CacheInfo = &NeighCacheInfo{}
200			err := a.CacheInfo.unmarshalBinary(ad.Bytes())
201			if err != nil {
202				return err
203			}
204		case unix.NDA_IFINDEX:
205			a.IfIndex = ad.Uint32()
206		}
207	}
208
209	return nil
210}
211
212func (a *NeighAttributes) encode(ae *netlink.AttributeEncoder) error {
213	ae.Uint16(unix.NDA_UNSPEC, 0)
214	ae.Bytes(unix.NDA_DST, a.Address)
215	ae.Bytes(unix.NDA_LLADDR, a.LLAddress)
216	ae.Uint32(unix.NDA_IFINDEX, a.IfIndex)
217
218	return nil
219}
220