1package netlink
2
3import (
4	"github.com/vishvananda/netlink/nl"
5	"golang.org/x/sys/unix"
6)
7
8func selFromPolicy(sel *nl.XfrmSelector, policy *XfrmPolicy) {
9	sel.Family = uint16(nl.FAMILY_V4)
10	if policy.Dst != nil {
11		sel.Family = uint16(nl.GetIPFamily(policy.Dst.IP))
12		sel.Daddr.FromIP(policy.Dst.IP)
13		prefixlenD, _ := policy.Dst.Mask.Size()
14		sel.PrefixlenD = uint8(prefixlenD)
15	}
16	if policy.Src != nil {
17		sel.Saddr.FromIP(policy.Src.IP)
18		prefixlenS, _ := policy.Src.Mask.Size()
19		sel.PrefixlenS = uint8(prefixlenS)
20	}
21	sel.Proto = uint8(policy.Proto)
22	sel.Dport = nl.Swap16(uint16(policy.DstPort))
23	sel.Sport = nl.Swap16(uint16(policy.SrcPort))
24	if sel.Dport != 0 {
25		sel.DportMask = ^uint16(0)
26	}
27	if sel.Sport != 0 {
28		sel.SportMask = ^uint16(0)
29	}
30	sel.Ifindex = int32(policy.Ifindex)
31}
32
// XfrmPolicyAdd will add an xfrm policy to the system.
// Equivalent to: `ip xfrm policy add $policy`
//
// It delegates to the XfrmPolicyAdd method of the package-level handle
// pkgHandle and returns that call's error unchanged.
func XfrmPolicyAdd(policy *XfrmPolicy) error {
	return pkgHandle.XfrmPolicyAdd(policy)
}
38
// XfrmPolicyAdd will add an xfrm policy to the system.
// Equivalent to: `ip xfrm policy add $policy`
//
// The policy is sent on h as an XFRM_MSG_NEWPOLICY request via
// xfrmPolicyAddOrUpdate; the resulting error is returned as is.
func (h *Handle) XfrmPolicyAdd(policy *XfrmPolicy) error {
	return h.xfrmPolicyAddOrUpdate(policy, nl.XFRM_MSG_NEWPOLICY)
}
44
// XfrmPolicyUpdate will update an xfrm policy in the system.
// Equivalent to: `ip xfrm policy update $policy`
//
// It delegates to the XfrmPolicyUpdate method of the package-level handle
// pkgHandle and returns that call's error unchanged.
func XfrmPolicyUpdate(policy *XfrmPolicy) error {
	return pkgHandle.XfrmPolicyUpdate(policy)
}
50
// XfrmPolicyUpdate will update an xfrm policy in the system.
// Equivalent to: `ip xfrm policy update $policy`
//
// The policy is sent on h as an XFRM_MSG_UPDPOLICY request via
// xfrmPolicyAddOrUpdate; the resulting error is returned as is.
func (h *Handle) XfrmPolicyUpdate(policy *XfrmPolicy) error {
	return h.xfrmPolicyAddOrUpdate(policy, nl.XFRM_MSG_UPDPOLICY)
}
56
57func (h *Handle) xfrmPolicyAddOrUpdate(policy *XfrmPolicy, nlProto int) error {
58	req := h.newNetlinkRequest(nlProto, unix.NLM_F_CREATE|unix.NLM_F_EXCL|unix.NLM_F_ACK)
59
60	msg := &nl.XfrmUserpolicyInfo{}
61	selFromPolicy(&msg.Sel, policy)
62	msg.Priority = uint32(policy.Priority)
63	msg.Index = uint32(policy.Index)
64	msg.Dir = uint8(policy.Dir)
65	msg.Action = uint8(policy.Action)
66	msg.Lft.SoftByteLimit = nl.XFRM_INF
67	msg.Lft.HardByteLimit = nl.XFRM_INF
68	msg.Lft.SoftPacketLimit = nl.XFRM_INF
69	msg.Lft.HardPacketLimit = nl.XFRM_INF
70	req.AddData(msg)
71
72	tmplData := make([]byte, nl.SizeofXfrmUserTmpl*len(policy.Tmpls))
73	for i, tmpl := range policy.Tmpls {
74		start := i * nl.SizeofXfrmUserTmpl
75		userTmpl := nl.DeserializeXfrmUserTmpl(tmplData[start : start+nl.SizeofXfrmUserTmpl])
76		userTmpl.XfrmId.Daddr.FromIP(tmpl.Dst)
77		userTmpl.Saddr.FromIP(tmpl.Src)
78		userTmpl.XfrmId.Proto = uint8(tmpl.Proto)
79		userTmpl.XfrmId.Spi = nl.Swap32(uint32(tmpl.Spi))
80		userTmpl.Mode = uint8(tmpl.Mode)
81		userTmpl.Reqid = uint32(tmpl.Reqid)
82		userTmpl.Aalgos = ^uint32(0)
83		userTmpl.Ealgos = ^uint32(0)
84		userTmpl.Calgos = ^uint32(0)
85	}
86	if len(tmplData) > 0 {
87		tmpls := nl.NewRtAttr(nl.XFRMA_TMPL, tmplData)
88		req.AddData(tmpls)
89	}
90	if policy.Mark != nil {
91		out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(policy.Mark))
92		req.AddData(out)
93	}
94
95	ifId := nl.NewRtAttr(nl.XFRMA_IF_ID, nl.Uint32Attr(uint32(policy.Ifid)))
96	req.AddData(ifId)
97
98	_, err := req.Execute(unix.NETLINK_XFRM, 0)
99	return err
100}
101
// XfrmPolicyDel will delete an xfrm policy from the system. Note that
// the Tmpls are ignored when matching the policy to delete.
// Equivalent to: `ip xfrm policy del $policy`
//
// It delegates to the XfrmPolicyDel method of the package-level handle
// pkgHandle and returns that call's error unchanged.
func XfrmPolicyDel(policy *XfrmPolicy) error {
	return pkgHandle.XfrmPolicyDel(policy)
}
108
// XfrmPolicyDel will delete an xfrm policy from the system. Note that
// the Tmpls are ignored when matching the policy to delete.
// Equivalent to: `ip xfrm policy del $policy`
//
// The request is issued through xfrmPolicyGetOrDelete with
// XFRM_MSG_DELPOLICY; the policy value it returns is discarded and only the
// error is reported.
func (h *Handle) XfrmPolicyDel(policy *XfrmPolicy) error {
	_, err := h.xfrmPolicyGetOrDelete(policy, nl.XFRM_MSG_DELPOLICY)
	return err
}
116
// XfrmPolicyList gets a list of xfrm policies in the system.
// Equivalent to: `ip xfrm policy show`.
// The list can be filtered by ip family.
//
// It delegates to the XfrmPolicyList method of the package-level handle
// pkgHandle and returns that call's results unchanged.
func XfrmPolicyList(family int) ([]XfrmPolicy, error) {
	return pkgHandle.XfrmPolicyList(family)
}
123
124// XfrmPolicyList gets a list of xfrm policies in the system.
125// Equivalent to: `ip xfrm policy show`.
126// The list can be filtered by ip family.
127func (h *Handle) XfrmPolicyList(family int) ([]XfrmPolicy, error) {
128	req := h.newNetlinkRequest(nl.XFRM_MSG_GETPOLICY, unix.NLM_F_DUMP)
129
130	msg := nl.NewIfInfomsg(family)
131	req.AddData(msg)
132
133	msgs, err := req.Execute(unix.NETLINK_XFRM, nl.XFRM_MSG_NEWPOLICY)
134	if err != nil {
135		return nil, err
136	}
137
138	var res []XfrmPolicy
139	for _, m := range msgs {
140		if policy, err := parseXfrmPolicy(m, family); err == nil {
141			res = append(res, *policy)
142		} else if err == familyError {
143			continue
144		} else {
145			return nil, err
146		}
147	}
148	return res, nil
149}
150
// XfrmPolicyGet gets the policy described by the index or selector, if found.
// Equivalent to: `ip xfrm policy get { SELECTOR | index INDEX } dir DIR [ctx CTX ] [ mark MARK [ mask MASK ] ] [ ptype PTYPE ]`.
//
// It delegates to the XfrmPolicyGet method of the package-level handle
// pkgHandle and returns that call's results unchanged.
func XfrmPolicyGet(policy *XfrmPolicy) (*XfrmPolicy, error) {
	return pkgHandle.XfrmPolicyGet(policy)
}
156
// XfrmPolicyGet gets the policy described by the index or selector, if found.
// Equivalent to: `ip xfrm policy get { SELECTOR | index INDEX } dir DIR [ctx CTX ] [ mark MARK [ mask MASK ] ] [ ptype PTYPE ]`.
//
// The lookup is issued through xfrmPolicyGetOrDelete with
// XFRM_MSG_GETPOLICY, whose results are returned unchanged.
func (h *Handle) XfrmPolicyGet(policy *XfrmPolicy) (*XfrmPolicy, error) {
	return h.xfrmPolicyGetOrDelete(policy, nl.XFRM_MSG_GETPOLICY)
}
162
// XfrmPolicyFlush will flush the policies on the system.
// Equivalent to: `ip xfrm policy flush`
//
// It delegates to the XfrmPolicyFlush method of the package-level handle
// pkgHandle and returns that call's error unchanged.
func XfrmPolicyFlush() error {
	return pkgHandle.XfrmPolicyFlush()
}
168
// XfrmPolicyFlush will flush the policies on the system.
// Equivalent to: `ip xfrm policy flush`
//
// It sends an XFRM_MSG_FLUSHPOLICY request with no payload, asking for an
// acknowledgement (NLM_F_ACK), and returns the error from the exchange.
func (h *Handle) XfrmPolicyFlush() error {
	req := h.newNetlinkRequest(nl.XFRM_MSG_FLUSHPOLICY, unix.NLM_F_ACK)
	// The reply payload is not needed; only success or failure matters.
	_, err := req.Execute(unix.NETLINK_XFRM, 0)
	return err
}
176
177func (h *Handle) xfrmPolicyGetOrDelete(policy *XfrmPolicy, nlProto int) (*XfrmPolicy, error) {
178	req := h.newNetlinkRequest(nlProto, unix.NLM_F_ACK)
179
180	msg := &nl.XfrmUserpolicyId{}
181	selFromPolicy(&msg.Sel, policy)
182	msg.Index = uint32(policy.Index)
183	msg.Dir = uint8(policy.Dir)
184	req.AddData(msg)
185
186	if policy.Mark != nil {
187		out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(policy.Mark))
188		req.AddData(out)
189	}
190
191	ifId := nl.NewRtAttr(nl.XFRMA_IF_ID, nl.Uint32Attr(uint32(policy.Ifid)))
192	req.AddData(ifId)
193
194	resType := nl.XFRM_MSG_NEWPOLICY
195	if nlProto == nl.XFRM_MSG_DELPOLICY {
196		resType = 0
197	}
198
199	msgs, err := req.Execute(unix.NETLINK_XFRM, uint16(resType))
200	if err != nil {
201		return nil, err
202	}
203
204	if nlProto == nl.XFRM_MSG_DELPOLICY {
205		return nil, err
206	}
207
208	return parseXfrmPolicy(msgs[0], FAMILY_ALL)
209}
210
211func parseXfrmPolicy(m []byte, family int) (*XfrmPolicy, error) {
212	msg := nl.DeserializeXfrmUserpolicyInfo(m)
213
214	// This is mainly for the policy dump
215	if family != FAMILY_ALL && family != int(msg.Sel.Family) {
216		return nil, familyError
217	}
218
219	var policy XfrmPolicy
220
221	policy.Dst = msg.Sel.Daddr.ToIPNet(msg.Sel.PrefixlenD)
222	policy.Src = msg.Sel.Saddr.ToIPNet(msg.Sel.PrefixlenS)
223	policy.Proto = Proto(msg.Sel.Proto)
224	policy.DstPort = int(nl.Swap16(msg.Sel.Dport))
225	policy.SrcPort = int(nl.Swap16(msg.Sel.Sport))
226	policy.Ifindex = int(msg.Sel.Ifindex)
227	policy.Priority = int(msg.Priority)
228	policy.Index = int(msg.Index)
229	policy.Dir = Dir(msg.Dir)
230	policy.Action = PolicyAction(msg.Action)
231
232	attrs, err := nl.ParseRouteAttr(m[msg.Len():])
233	if err != nil {
234		return nil, err
235	}
236
237	for _, attr := range attrs {
238		switch attr.Attr.Type {
239		case nl.XFRMA_TMPL:
240			max := len(attr.Value)
241			for i := 0; i < max; i += nl.SizeofXfrmUserTmpl {
242				var resTmpl XfrmPolicyTmpl
243				tmpl := nl.DeserializeXfrmUserTmpl(attr.Value[i : i+nl.SizeofXfrmUserTmpl])
244				resTmpl.Dst = tmpl.XfrmId.Daddr.ToIP()
245				resTmpl.Src = tmpl.Saddr.ToIP()
246				resTmpl.Proto = Proto(tmpl.XfrmId.Proto)
247				resTmpl.Mode = Mode(tmpl.Mode)
248				resTmpl.Spi = int(nl.Swap32(tmpl.XfrmId.Spi))
249				resTmpl.Reqid = int(tmpl.Reqid)
250				policy.Tmpls = append(policy.Tmpls, resTmpl)
251			}
252		case nl.XFRMA_MARK:
253			mark := nl.DeserializeXfrmMark(attr.Value[:])
254			policy.Mark = new(XfrmMark)
255			policy.Mark.Value = mark.Value
256			policy.Mark.Mask = mark.Mask
257		case nl.XFRMA_IF_ID:
258			policy.Ifid = int(native.Uint32(attr.Value))
259		}
260	}
261
262	return &policy, nil
263}
264