1// +build linux
2
3package netlink
4
5import (
6	"encoding/binary"
7	"errors"
8
9	"github.com/vishvananda/netlink/nl"
10	"golang.org/x/sys/unix"
11)
12
// FOU_GENL_NAME is the name of the generic netlink family used for FOU
// configuration; FouFamilyId resolves it to a numeric family ID.
const FOU_GENL_NAME = "fou"
16
// Generic netlink command numbers of the "fou" family. A command is written
// as the first byte of each request's generic netlink header. The values are
// spelled out explicitly because they are fixed by the kernel ABI
// (FOU_CMD_* in <linux/fou.h>) and must not shift if entries are reordered.
const (
	FOU_CMD_UNSPEC uint8 = 0
	FOU_CMD_ADD    uint8 = 1
	FOU_CMD_DEL    uint8 = 2
	FOU_CMD_GET    uint8 = 3
	FOU_CMD_MAX          = FOU_CMD_GET
)
24
// Netlink attribute types of the "fou" generic netlink family, matching the
// FOU_ATTR_* values in <linux/fou.h>. FOU_ATTR_MAX is the highest attribute
// type defined by this package; replies may carry other types, which the
// decoder is expected to skip.
const (
	FOU_ATTR_UNSPEC            = 0
	FOU_ATTR_PORT              = 1
	FOU_ATTR_AF                = 2
	FOU_ATTR_IPPROTO           = 3
	FOU_ATTR_TYPE              = 4
	FOU_ATTR_REMCSUM_NOPARTIAL = 5
	FOU_ATTR_MAX               = FOU_ATTR_REMCSUM_NOPARTIAL
)
34
// Encapsulation types sent in the FOU_ATTR_TYPE attribute (Fou.EncapType),
// matching FOU_ENCAP_* in <linux/fou.h>.
const (
	FOU_ENCAP_UNSPEC = 0
	FOU_ENCAP_DIRECT = 1
	FOU_ENCAP_GUE    = 2
	FOU_ENCAP_MAX    = FOU_ENCAP_GUE
)
41
42var fouFamilyId int
43
44func FouFamilyId() (int, error) {
45	if fouFamilyId != 0 {
46		return fouFamilyId, nil
47	}
48
49	fam, err := GenlFamilyGet(FOU_GENL_NAME)
50	if err != nil {
51		return -1, err
52	}
53
54	fouFamilyId = int(fam.ID)
55	return fouFamilyId, nil
56}
57
58func FouAdd(f Fou) error {
59	return pkgHandle.FouAdd(f)
60}
61
62func (h *Handle) FouAdd(f Fou) error {
63	fam_id, err := FouFamilyId()
64	if err != nil {
65		return err
66	}
67
68	// setting ip protocol conflicts with encapsulation type GUE
69	if f.EncapType == FOU_ENCAP_GUE && f.Protocol != 0 {
70		return errors.New("GUE encapsulation doesn't specify an IP protocol")
71	}
72
73	req := h.newNetlinkRequest(fam_id, unix.NLM_F_ACK)
74
75	// int to byte for port
76	bp := make([]byte, 2)
77	binary.BigEndian.PutUint16(bp[0:2], uint16(f.Port))
78
79	attrs := []*nl.RtAttr{
80		nl.NewRtAttr(FOU_ATTR_PORT, bp),
81		nl.NewRtAttr(FOU_ATTR_TYPE, []byte{uint8(f.EncapType)}),
82		nl.NewRtAttr(FOU_ATTR_AF, []byte{uint8(f.Family)}),
83		nl.NewRtAttr(FOU_ATTR_IPPROTO, []byte{uint8(f.Protocol)}),
84	}
85	raw := []byte{FOU_CMD_ADD, 1, 0, 0}
86	for _, a := range attrs {
87		raw = append(raw, a.Serialize()...)
88	}
89
90	req.AddRawData(raw)
91
92	_, err = req.Execute(unix.NETLINK_GENERIC, 0)
93	return err
94}
95
96func FouDel(f Fou) error {
97	return pkgHandle.FouDel(f)
98}
99
100func (h *Handle) FouDel(f Fou) error {
101	fam_id, err := FouFamilyId()
102	if err != nil {
103		return err
104	}
105
106	req := h.newNetlinkRequest(fam_id, unix.NLM_F_ACK)
107
108	// int to byte for port
109	bp := make([]byte, 2)
110	binary.BigEndian.PutUint16(bp[0:2], uint16(f.Port))
111
112	attrs := []*nl.RtAttr{
113		nl.NewRtAttr(FOU_ATTR_PORT, bp),
114		nl.NewRtAttr(FOU_ATTR_AF, []byte{uint8(f.Family)}),
115	}
116	raw := []byte{FOU_CMD_DEL, 1, 0, 0}
117	for _, a := range attrs {
118		raw = append(raw, a.Serialize()...)
119	}
120
121	req.AddRawData(raw)
122
123	_, err = req.Execute(unix.NETLINK_GENERIC, 0)
124	if err != nil {
125		return err
126	}
127
128	return nil
129}
130
131func FouList(fam int) ([]Fou, error) {
132	return pkgHandle.FouList(fam)
133}
134
135func (h *Handle) FouList(fam int) ([]Fou, error) {
136	fam_id, err := FouFamilyId()
137	if err != nil {
138		return nil, err
139	}
140
141	req := h.newNetlinkRequest(fam_id, unix.NLM_F_DUMP)
142
143	attrs := []*nl.RtAttr{
144		nl.NewRtAttr(FOU_ATTR_AF, []byte{uint8(fam)}),
145	}
146	raw := []byte{FOU_CMD_GET, 1, 0, 0}
147	for _, a := range attrs {
148		raw = append(raw, a.Serialize()...)
149	}
150
151	req.AddRawData(raw)
152
153	msgs, err := req.Execute(unix.NETLINK_GENERIC, 0)
154	if err != nil {
155		return nil, err
156	}
157
158	fous := make([]Fou, 0, len(msgs))
159	for _, m := range msgs {
160		f, err := deserializeFouMsg(m)
161		if err != nil {
162			return fous, err
163		}
164
165		fous = append(fous, f)
166	}
167
168	return fous, nil
169}
170
171func deserializeFouMsg(msg []byte) (Fou, error) {
172	// we'll skip to byte 4 to first attribute
173	msg = msg[3:]
174	var shift int
175	fou := Fou{}
176
177	for {
178		// attribute header is at least 16 bits
179		if len(msg) < 4 {
180			return fou, ErrAttrHeaderTruncated
181		}
182
183		lgt := int(binary.BigEndian.Uint16(msg[0:2]))
184		if len(msg) < lgt+4 {
185			return fou, ErrAttrBodyTruncated
186		}
187		attr := binary.BigEndian.Uint16(msg[2:4])
188
189		shift = lgt + 3
190		switch attr {
191		case FOU_ATTR_AF:
192			fou.Family = int(msg[5])
193		case FOU_ATTR_PORT:
194			fou.Port = int(binary.BigEndian.Uint16(msg[5:7]))
195			// port is 2 bytes
196			shift = lgt + 2
197		case FOU_ATTR_IPPROTO:
198			fou.Protocol = int(msg[5])
199		case FOU_ATTR_TYPE:
200			fou.EncapType = int(msg[5])
201		}
202
203		msg = msg[shift:]
204
205		if len(msg) < 4 {
206			break
207		}
208	}
209
210	return fou, nil
211}
212