1//+build linux
2
3package genetlink
4
5import (
6	"errors"
7	"fmt"
8	"math"
9
10	"github.com/mdlayher/netlink"
11	"github.com/mdlayher/netlink/nlenc"
12	"golang.org/x/sys/unix"
13)
14
var (
	// errInvalidFamilyVersion is reported by Family.parseAttributes when the
	// version attribute carries a value that cannot be represented in the
	// 8-bit Family.Version field (i.e. it exceeds math.MaxUint8).
	errInvalidFamilyVersion = errors.New("invalid family version attribute")
)
18
19// getFamily retrieves a generic netlink family with the specified name.
20func (c *Conn) getFamily(name string) (Family, error) {
21	b, err := netlink.MarshalAttributes([]netlink.Attribute{{
22		Type: unix.CTRL_ATTR_FAMILY_NAME,
23		Data: nlenc.Bytes(name),
24	}})
25	if err != nil {
26		return Family{}, err
27	}
28
29	req := Message{
30		Header: Header{
31			Command: unix.CTRL_CMD_GETFAMILY,
32			// TODO(mdlayher): grab nlctrl version?
33			Version: 1,
34		},
35		Data: b,
36	}
37
38	msgs, err := c.Execute(req, unix.GENL_ID_CTRL, netlink.Request)
39	if err != nil {
40		return Family{}, err
41	}
42
43	// TODO(mdlayher): consider interpreting generic netlink header values
44
45	families, err := buildFamilies(msgs)
46	if err != nil {
47		return Family{}, err
48	}
49	if len(families) != 1 {
50		// If this were to ever happen, netlink must be in a state where
51		// its answers cannot be trusted
52		panic(fmt.Sprintf("netlink returned multiple families for name: %q", name))
53	}
54
55	return families[0], nil
56}
57
58// listFamilies retrieves all registered generic netlink families.
59func (c *Conn) listFamilies() ([]Family, error) {
60	req := Message{
61		Header: Header{
62			Command: unix.CTRL_CMD_GETFAMILY,
63			// TODO(mdlayher): grab nlctrl version?
64			Version: 1,
65		},
66	}
67
68	flags := netlink.Request | netlink.Dump
69	msgs, err := c.Execute(req, unix.GENL_ID_CTRL, flags)
70	if err != nil {
71		return nil, err
72	}
73
74	return buildFamilies(msgs)
75}
76
77// buildFamilies builds a slice of Families by parsing attributes from the
78// input Messages.
79func buildFamilies(msgs []Message) ([]Family, error) {
80	families := make([]Family, 0, len(msgs))
81	for _, m := range msgs {
82		var f Family
83		if err := (&f).parseAttributes(m.Data); err != nil {
84			return nil, err
85		}
86
87		families = append(families, f)
88	}
89
90	return families, nil
91}
92
93// parseAttributes decodes netlink attributes into a Family's fields.
94func (f *Family) parseAttributes(b []byte) error {
95	ad, err := netlink.NewAttributeDecoder(b)
96	if err != nil {
97		return err
98	}
99
100	for ad.Next() {
101		switch ad.Type() {
102		case unix.CTRL_ATTR_FAMILY_ID:
103			f.ID = ad.Uint16()
104		case unix.CTRL_ATTR_FAMILY_NAME:
105			f.Name = ad.String()
106		case unix.CTRL_ATTR_VERSION:
107			v := ad.Uint32()
108			if v > math.MaxUint8 {
109				return errInvalidFamilyVersion
110			}
111
112			f.Version = uint8(v)
113		case unix.CTRL_ATTR_MCAST_GROUPS:
114			ad.Nested(func(nad *netlink.AttributeDecoder) error {
115				f.Groups = parseMulticastGroups(nad)
116				return nil
117			})
118		}
119	}
120
121	return ad.Err()
122}
123
124// parseMulticastGroups parses an array of multicast group nested attributes
125// into a slice of MulticastGroups.
126func parseMulticastGroups(ad *netlink.AttributeDecoder) []MulticastGroup {
127	groups := make([]MulticastGroup, 0, ad.Len())
128	for ad.Next() {
129		ad.Nested(func(nad *netlink.AttributeDecoder) error {
130			var g MulticastGroup
131			for nad.Next() {
132				switch nad.Type() {
133				case unix.CTRL_ATTR_MCAST_GRP_NAME:
134					g.Name = nad.String()
135				case unix.CTRL_ATTR_MCAST_GRP_ID:
136					g.ID = nad.Uint32()
137				}
138			}
139
140			groups = append(groups, g)
141			return nil
142		})
143	}
144
145	return groups
146}
147