1package netlink
2
3import (
4	"syscall"
5
6	"github.com/vishvananda/netlink/nl"
7)
8
9func selFromPolicy(sel *nl.XfrmSelector, policy *XfrmPolicy) {
10	sel.Family = uint16(nl.GetIPFamily(policy.Dst.IP))
11	sel.Daddr.FromIP(policy.Dst.IP)
12	sel.Saddr.FromIP(policy.Src.IP)
13	prefixlenD, _ := policy.Dst.Mask.Size()
14	sel.PrefixlenD = uint8(prefixlenD)
15	prefixlenS, _ := policy.Src.Mask.Size()
16	sel.PrefixlenS = uint8(prefixlenS)
17}
18
19// XfrmPolicyAdd will add an xfrm policy to the system.
20// Equivalent to: `ip xfrm policy add $policy`
21func XfrmPolicyAdd(policy *XfrmPolicy) error {
22	req := nl.NewNetlinkRequest(nl.XFRM_MSG_NEWPOLICY, syscall.NLM_F_CREATE|syscall.NLM_F_EXCL|syscall.NLM_F_ACK)
23
24	msg := &nl.XfrmUserpolicyInfo{}
25	selFromPolicy(&msg.Sel, policy)
26	msg.Priority = uint32(policy.Priority)
27	msg.Index = uint32(policy.Index)
28	msg.Dir = uint8(policy.Dir)
29	msg.Lft.SoftByteLimit = nl.XFRM_INF
30	msg.Lft.HardByteLimit = nl.XFRM_INF
31	msg.Lft.SoftPacketLimit = nl.XFRM_INF
32	msg.Lft.HardPacketLimit = nl.XFRM_INF
33	req.AddData(msg)
34
35	tmplData := make([]byte, nl.SizeofXfrmUserTmpl*len(policy.Tmpls))
36	for i, tmpl := range policy.Tmpls {
37		start := i * nl.SizeofXfrmUserTmpl
38		userTmpl := nl.DeserializeXfrmUserTmpl(tmplData[start : start+nl.SizeofXfrmUserTmpl])
39		userTmpl.XfrmId.Daddr.FromIP(tmpl.Dst)
40		userTmpl.Saddr.FromIP(tmpl.Src)
41		userTmpl.XfrmId.Proto = uint8(tmpl.Proto)
42		userTmpl.Mode = uint8(tmpl.Mode)
43		userTmpl.Reqid = uint32(tmpl.Reqid)
44		userTmpl.Aalgos = ^uint32(0)
45		userTmpl.Ealgos = ^uint32(0)
46		userTmpl.Calgos = ^uint32(0)
47	}
48	if len(tmplData) > 0 {
49		tmpls := nl.NewRtAttr(nl.XFRMA_TMPL, tmplData)
50		req.AddData(tmpls)
51	}
52
53	_, err := req.Execute(syscall.NETLINK_XFRM, 0)
54	return err
55}
56
57// XfrmPolicyDel will delete an xfrm policy from the system. Note that
58// the Tmpls are ignored when matching the policy to delete.
59// Equivalent to: `ip xfrm policy del $policy`
60func XfrmPolicyDel(policy *XfrmPolicy) error {
61	req := nl.NewNetlinkRequest(nl.XFRM_MSG_DELPOLICY, syscall.NLM_F_ACK)
62
63	msg := &nl.XfrmUserpolicyId{}
64	selFromPolicy(&msg.Sel, policy)
65	msg.Index = uint32(policy.Index)
66	msg.Dir = uint8(policy.Dir)
67	req.AddData(msg)
68
69	_, err := req.Execute(syscall.NETLINK_XFRM, 0)
70	return err
71}
72
73// XfrmPolicyList gets a list of xfrm policies in the system.
74// Equivalent to: `ip xfrm policy show`.
75// The list can be filtered by ip family.
76func XfrmPolicyList(family int) ([]XfrmPolicy, error) {
77	req := nl.NewNetlinkRequest(nl.XFRM_MSG_GETPOLICY, syscall.NLM_F_DUMP)
78
79	msg := nl.NewIfInfomsg(family)
80	req.AddData(msg)
81
82	msgs, err := req.Execute(syscall.NETLINK_XFRM, nl.XFRM_MSG_NEWPOLICY)
83	if err != nil {
84		return nil, err
85	}
86
87	var res []XfrmPolicy
88	for _, m := range msgs {
89		msg := nl.DeserializeXfrmUserpolicyInfo(m)
90
91		if family != FAMILY_ALL && family != int(msg.Sel.Family) {
92			continue
93		}
94
95		var policy XfrmPolicy
96
97		policy.Dst = msg.Sel.Daddr.ToIPNet(msg.Sel.PrefixlenD)
98		policy.Src = msg.Sel.Saddr.ToIPNet(msg.Sel.PrefixlenS)
99		policy.Priority = int(msg.Priority)
100		policy.Index = int(msg.Index)
101		policy.Dir = Dir(msg.Dir)
102
103		attrs, err := nl.ParseRouteAttr(m[msg.Len():])
104		if err != nil {
105			return nil, err
106		}
107
108		for _, attr := range attrs {
109			switch attr.Attr.Type {
110			case nl.XFRMA_TMPL:
111				max := len(attr.Value)
112				for i := 0; i < max; i += nl.SizeofXfrmUserTmpl {
113					var resTmpl XfrmPolicyTmpl
114					tmpl := nl.DeserializeXfrmUserTmpl(attr.Value[i : i+nl.SizeofXfrmUserTmpl])
115					resTmpl.Dst = tmpl.XfrmId.Daddr.ToIP()
116					resTmpl.Src = tmpl.Saddr.ToIP()
117					resTmpl.Proto = Proto(tmpl.XfrmId.Proto)
118					resTmpl.Mode = Mode(tmpl.Mode)
119					resTmpl.Reqid = int(tmpl.Reqid)
120					policy.Tmpls = append(policy.Tmpls, resTmpl)
121				}
122			}
123		}
124		res = append(res, policy)
125	}
126	return res, nil
127}
128