1package netlink
2
import (
	"encoding/binary"
	"fmt"
	"net"

	"github.com/vishvananda/netlink/nl"
	"golang.org/x/sys/unix"
)
10
// FibRuleInvert is the rule-header flag bit (0x2, mirroring the kernel's
// FIB_RULE_INVERT) that negates a rule's selector. It is OR-ed into the
// header flags when Rule.Invert is set and tested when rules are listed.
const FibRuleInvert = 1 << 1
12
13// RuleAdd adds a rule to the system.
14// Equivalent to: ip rule add
15func RuleAdd(rule *Rule) error {
16	return pkgHandle.RuleAdd(rule)
17}
18
19// RuleAdd adds a rule to the system.
20// Equivalent to: ip rule add
21func (h *Handle) RuleAdd(rule *Rule) error {
22	req := h.newNetlinkRequest(unix.RTM_NEWRULE, unix.NLM_F_CREATE|unix.NLM_F_EXCL|unix.NLM_F_ACK)
23	return ruleHandle(rule, req)
24}
25
26// RuleDel deletes a rule from the system.
27// Equivalent to: ip rule del
28func RuleDel(rule *Rule) error {
29	return pkgHandle.RuleDel(rule)
30}
31
32// RuleDel deletes a rule from the system.
33// Equivalent to: ip rule del
34func (h *Handle) RuleDel(rule *Rule) error {
35	req := h.newNetlinkRequest(unix.RTM_DELRULE, unix.NLM_F_ACK)
36	return ruleHandle(rule, req)
37}
38
39func ruleHandle(rule *Rule, req *nl.NetlinkRequest) error {
40	msg := nl.NewRtMsg()
41	msg.Family = unix.AF_INET
42	msg.Protocol = unix.RTPROT_BOOT
43	msg.Scope = unix.RT_SCOPE_UNIVERSE
44	msg.Table = unix.RT_TABLE_UNSPEC
45	msg.Type = unix.RTN_UNSPEC
46	if req.NlMsghdr.Flags&unix.NLM_F_CREATE > 0 {
47		msg.Type = unix.RTN_UNICAST
48	}
49	if rule.Invert {
50		msg.Flags |= FibRuleInvert
51	}
52	if rule.Family != 0 {
53		msg.Family = uint8(rule.Family)
54	}
55	if rule.Table >= 0 && rule.Table < 256 {
56		msg.Table = uint8(rule.Table)
57	}
58
59	var dstFamily uint8
60	var rtAttrs []*nl.RtAttr
61	if rule.Dst != nil && rule.Dst.IP != nil {
62		dstLen, _ := rule.Dst.Mask.Size()
63		msg.Dst_len = uint8(dstLen)
64		msg.Family = uint8(nl.GetIPFamily(rule.Dst.IP))
65		dstFamily = msg.Family
66		var dstData []byte
67		if msg.Family == unix.AF_INET {
68			dstData = rule.Dst.IP.To4()
69		} else {
70			dstData = rule.Dst.IP.To16()
71		}
72		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_DST, dstData))
73	}
74
75	if rule.Src != nil && rule.Src.IP != nil {
76		msg.Family = uint8(nl.GetIPFamily(rule.Src.IP))
77		if dstFamily != 0 && dstFamily != msg.Family {
78			return fmt.Errorf("source and destination ip are not the same IP family")
79		}
80		srcLen, _ := rule.Src.Mask.Size()
81		msg.Src_len = uint8(srcLen)
82		var srcData []byte
83		if msg.Family == unix.AF_INET {
84			srcData = rule.Src.IP.To4()
85		} else {
86			srcData = rule.Src.IP.To16()
87		}
88		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_SRC, srcData))
89	}
90
91	req.AddData(msg)
92	for i := range rtAttrs {
93		req.AddData(rtAttrs[i])
94	}
95
96	native := nl.NativeEndian()
97
98	if rule.Priority >= 0 {
99		b := make([]byte, 4)
100		native.PutUint32(b, uint32(rule.Priority))
101		req.AddData(nl.NewRtAttr(nl.FRA_PRIORITY, b))
102	}
103	if rule.Mark >= 0 {
104		b := make([]byte, 4)
105		native.PutUint32(b, uint32(rule.Mark))
106		req.AddData(nl.NewRtAttr(nl.FRA_FWMARK, b))
107	}
108	if rule.Mask >= 0 {
109		b := make([]byte, 4)
110		native.PutUint32(b, uint32(rule.Mask))
111		req.AddData(nl.NewRtAttr(nl.FRA_FWMASK, b))
112	}
113	if rule.Flow >= 0 {
114		b := make([]byte, 4)
115		native.PutUint32(b, uint32(rule.Flow))
116		req.AddData(nl.NewRtAttr(nl.FRA_FLOW, b))
117	}
118	if rule.TunID > 0 {
119		b := make([]byte, 4)
120		native.PutUint32(b, uint32(rule.TunID))
121		req.AddData(nl.NewRtAttr(nl.FRA_TUN_ID, b))
122	}
123	if rule.Table >= 256 {
124		b := make([]byte, 4)
125		native.PutUint32(b, uint32(rule.Table))
126		req.AddData(nl.NewRtAttr(nl.FRA_TABLE, b))
127	}
128	if msg.Table > 0 {
129		if rule.SuppressPrefixlen >= 0 {
130			b := make([]byte, 4)
131			native.PutUint32(b, uint32(rule.SuppressPrefixlen))
132			req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_PREFIXLEN, b))
133		}
134		if rule.SuppressIfgroup >= 0 {
135			b := make([]byte, 4)
136			native.PutUint32(b, uint32(rule.SuppressIfgroup))
137			req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_IFGROUP, b))
138		}
139	}
140	if rule.IifName != "" {
141		req.AddData(nl.NewRtAttr(nl.FRA_IIFNAME, []byte(rule.IifName)))
142	}
143	if rule.OifName != "" {
144		req.AddData(nl.NewRtAttr(nl.FRA_OIFNAME, []byte(rule.OifName)))
145	}
146	if rule.Goto >= 0 {
147		msg.Type = nl.FR_ACT_GOTO
148		b := make([]byte, 4)
149		native.PutUint32(b, uint32(rule.Goto))
150		req.AddData(nl.NewRtAttr(nl.FRA_GOTO, b))
151	}
152
153	_, err := req.Execute(unix.NETLINK_ROUTE, 0)
154	return err
155}
156
157// RuleList lists rules in the system.
158// Equivalent to: ip rule list
159func RuleList(family int) ([]Rule, error) {
160	return pkgHandle.RuleList(family)
161}
162
163// RuleList lists rules in the system.
164// Equivalent to: ip rule list
165func (h *Handle) RuleList(family int) ([]Rule, error) {
166	req := h.newNetlinkRequest(unix.RTM_GETRULE, unix.NLM_F_DUMP|unix.NLM_F_REQUEST)
167	msg := nl.NewIfInfomsg(family)
168	req.AddData(msg)
169
170	msgs, err := req.Execute(unix.NETLINK_ROUTE, unix.RTM_NEWRULE)
171	if err != nil {
172		return nil, err
173	}
174
175	native := nl.NativeEndian()
176	var res = make([]Rule, 0)
177	for i := range msgs {
178		msg := nl.DeserializeRtMsg(msgs[i])
179		attrs, err := nl.ParseRouteAttr(msgs[i][msg.Len():])
180		if err != nil {
181			return nil, err
182		}
183
184		rule := NewRule()
185
186		rule.Invert = msg.Flags&FibRuleInvert > 0
187
188		for j := range attrs {
189			switch attrs[j].Attr.Type {
190			case unix.RTA_TABLE:
191				rule.Table = int(native.Uint32(attrs[j].Value[0:4]))
192			case nl.FRA_SRC:
193				rule.Src = &net.IPNet{
194					IP:   attrs[j].Value,
195					Mask: net.CIDRMask(int(msg.Src_len), 8*len(attrs[j].Value)),
196				}
197			case nl.FRA_DST:
198				rule.Dst = &net.IPNet{
199					IP:   attrs[j].Value,
200					Mask: net.CIDRMask(int(msg.Dst_len), 8*len(attrs[j].Value)),
201				}
202			case nl.FRA_FWMARK:
203				rule.Mark = int(native.Uint32(attrs[j].Value[0:4]))
204			case nl.FRA_FWMASK:
205				rule.Mask = int(native.Uint32(attrs[j].Value[0:4]))
206			case nl.FRA_TUN_ID:
207				rule.TunID = uint(native.Uint64(attrs[j].Value[0:4]))
208			case nl.FRA_IIFNAME:
209				rule.IifName = string(attrs[j].Value[:len(attrs[j].Value)-1])
210			case nl.FRA_OIFNAME:
211				rule.OifName = string(attrs[j].Value[:len(attrs[j].Value)-1])
212			case nl.FRA_SUPPRESS_PREFIXLEN:
213				i := native.Uint32(attrs[j].Value[0:4])
214				if i != 0xffffffff {
215					rule.SuppressPrefixlen = int(i)
216				}
217			case nl.FRA_SUPPRESS_IFGROUP:
218				i := native.Uint32(attrs[j].Value[0:4])
219				if i != 0xffffffff {
220					rule.SuppressIfgroup = int(i)
221				}
222			case nl.FRA_FLOW:
223				rule.Flow = int(native.Uint32(attrs[j].Value[0:4]))
224			case nl.FRA_GOTO:
225				rule.Goto = int(native.Uint32(attrs[j].Value[0:4]))
226			case nl.FRA_PRIORITY:
227				rule.Priority = int(native.Uint32(attrs[j].Value[0:4]))
228			}
229		}
230		res = append(res, *rule)
231	}
232
233	return res, nil
234}
235