1package netlink
2
import (
	"encoding/binary"
	"fmt"
	"net"
	"syscall"

	"github.com/vishvananda/netlink/nl"
)
10
11// RuleAdd adds a rule to the system.
12// Equivalent to: ip rule add
13func RuleAdd(rule *Rule) error {
14	return pkgHandle.RuleAdd(rule)
15}
16
17// RuleAdd adds a rule to the system.
18// Equivalent to: ip rule add
19func (h *Handle) RuleAdd(rule *Rule) error {
20	req := h.newNetlinkRequest(syscall.RTM_NEWRULE, syscall.NLM_F_CREATE|syscall.NLM_F_EXCL|syscall.NLM_F_ACK)
21	return ruleHandle(rule, req)
22}
23
24// RuleDel deletes a rule from the system.
25// Equivalent to: ip rule del
26func RuleDel(rule *Rule) error {
27	return pkgHandle.RuleDel(rule)
28}
29
30// RuleDel deletes a rule from the system.
31// Equivalent to: ip rule del
32func (h *Handle) RuleDel(rule *Rule) error {
33	req := h.newNetlinkRequest(syscall.RTM_DELRULE, syscall.NLM_F_CREATE|syscall.NLM_F_EXCL|syscall.NLM_F_ACK)
34	return ruleHandle(rule, req)
35}
36
37func ruleHandle(rule *Rule, req *nl.NetlinkRequest) error {
38	msg := nl.NewRtMsg()
39	msg.Family = syscall.AF_INET
40	if rule.Family != 0 {
41		msg.Family = uint8(rule.Family)
42	}
43	var dstFamily uint8
44
45	var rtAttrs []*nl.RtAttr
46	if rule.Dst != nil && rule.Dst.IP != nil {
47		dstLen, _ := rule.Dst.Mask.Size()
48		msg.Dst_len = uint8(dstLen)
49		msg.Family = uint8(nl.GetIPFamily(rule.Dst.IP))
50		dstFamily = msg.Family
51		var dstData []byte
52		if msg.Family == syscall.AF_INET {
53			dstData = rule.Dst.IP.To4()
54		} else {
55			dstData = rule.Dst.IP.To16()
56		}
57		rtAttrs = append(rtAttrs, nl.NewRtAttr(syscall.RTA_DST, dstData))
58	}
59
60	if rule.Src != nil && rule.Src.IP != nil {
61		msg.Family = uint8(nl.GetIPFamily(rule.Src.IP))
62		if dstFamily != 0 && dstFamily != msg.Family {
63			return fmt.Errorf("source and destination ip are not the same IP family")
64		}
65		srcLen, _ := rule.Src.Mask.Size()
66		msg.Src_len = uint8(srcLen)
67		var srcData []byte
68		if msg.Family == syscall.AF_INET {
69			srcData = rule.Src.IP.To4()
70		} else {
71			srcData = rule.Src.IP.To16()
72		}
73		rtAttrs = append(rtAttrs, nl.NewRtAttr(syscall.RTA_SRC, srcData))
74	}
75
76	if rule.Table >= 0 {
77		msg.Table = uint8(rule.Table)
78		if rule.Table >= 256 {
79			msg.Table = syscall.RT_TABLE_UNSPEC
80		}
81	}
82
83	req.AddData(msg)
84	for i := range rtAttrs {
85		req.AddData(rtAttrs[i])
86	}
87
88	native := nl.NativeEndian()
89
90	if rule.Priority >= 0 {
91		b := make([]byte, 4)
92		native.PutUint32(b, uint32(rule.Priority))
93		req.AddData(nl.NewRtAttr(nl.FRA_PRIORITY, b))
94	}
95	if rule.Mark >= 0 {
96		b := make([]byte, 4)
97		native.PutUint32(b, uint32(rule.Mark))
98		req.AddData(nl.NewRtAttr(nl.FRA_FWMARK, b))
99	}
100	if rule.Mask >= 0 {
101		b := make([]byte, 4)
102		native.PutUint32(b, uint32(rule.Mask))
103		req.AddData(nl.NewRtAttr(nl.FRA_FWMASK, b))
104	}
105	if rule.Flow >= 0 {
106		b := make([]byte, 4)
107		native.PutUint32(b, uint32(rule.Flow))
108		req.AddData(nl.NewRtAttr(nl.FRA_FLOW, b))
109	}
110	if rule.TunID > 0 {
111		b := make([]byte, 4)
112		native.PutUint32(b, uint32(rule.TunID))
113		req.AddData(nl.NewRtAttr(nl.FRA_TUN_ID, b))
114	}
115	if rule.Table >= 256 {
116		b := make([]byte, 4)
117		native.PutUint32(b, uint32(rule.Table))
118		req.AddData(nl.NewRtAttr(nl.FRA_TABLE, b))
119	}
120	if msg.Table > 0 {
121		if rule.SuppressPrefixlen >= 0 {
122			b := make([]byte, 4)
123			native.PutUint32(b, uint32(rule.SuppressPrefixlen))
124			req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_PREFIXLEN, b))
125		}
126		if rule.SuppressIfgroup >= 0 {
127			b := make([]byte, 4)
128			native.PutUint32(b, uint32(rule.SuppressIfgroup))
129			req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_IFGROUP, b))
130		}
131	}
132	if rule.IifName != "" {
133		req.AddData(nl.NewRtAttr(nl.FRA_IIFNAME, []byte(rule.IifName)))
134	}
135	if rule.OifName != "" {
136		req.AddData(nl.NewRtAttr(nl.FRA_OIFNAME, []byte(rule.OifName)))
137	}
138	if rule.Goto >= 0 {
139		msg.Type = nl.FR_ACT_NOP
140		b := make([]byte, 4)
141		native.PutUint32(b, uint32(rule.Goto))
142		req.AddData(nl.NewRtAttr(nl.FRA_GOTO, b))
143	}
144
145	_, err := req.Execute(syscall.NETLINK_ROUTE, 0)
146	return err
147}
148
149// RuleList lists rules in the system.
150// Equivalent to: ip rule list
151func RuleList(family int) ([]Rule, error) {
152	return pkgHandle.RuleList(family)
153}
154
155// RuleList lists rules in the system.
156// Equivalent to: ip rule list
157func (h *Handle) RuleList(family int) ([]Rule, error) {
158	req := h.newNetlinkRequest(syscall.RTM_GETRULE, syscall.NLM_F_DUMP|syscall.NLM_F_REQUEST)
159	msg := nl.NewIfInfomsg(family)
160	req.AddData(msg)
161
162	msgs, err := req.Execute(syscall.NETLINK_ROUTE, syscall.RTM_NEWRULE)
163	if err != nil {
164		return nil, err
165	}
166
167	native := nl.NativeEndian()
168	var res = make([]Rule, 0)
169	for i := range msgs {
170		msg := nl.DeserializeRtMsg(msgs[i])
171		attrs, err := nl.ParseRouteAttr(msgs[i][msg.Len():])
172		if err != nil {
173			return nil, err
174		}
175
176		rule := NewRule()
177
178		for j := range attrs {
179			switch attrs[j].Attr.Type {
180			case syscall.RTA_TABLE:
181				rule.Table = int(native.Uint32(attrs[j].Value[0:4]))
182			case nl.FRA_SRC:
183				rule.Src = &net.IPNet{
184					IP:   attrs[j].Value,
185					Mask: net.CIDRMask(int(msg.Src_len), 8*len(attrs[j].Value)),
186				}
187			case nl.FRA_DST:
188				rule.Dst = &net.IPNet{
189					IP:   attrs[j].Value,
190					Mask: net.CIDRMask(int(msg.Dst_len), 8*len(attrs[j].Value)),
191				}
192			case nl.FRA_FWMARK:
193				rule.Mark = int(native.Uint32(attrs[j].Value[0:4]))
194			case nl.FRA_FWMASK:
195				rule.Mask = int(native.Uint32(attrs[j].Value[0:4]))
196			case nl.FRA_TUN_ID:
197				rule.TunID = uint(native.Uint64(attrs[j].Value[0:4]))
198			case nl.FRA_IIFNAME:
199				rule.IifName = string(attrs[j].Value[:len(attrs[j].Value)-1])
200			case nl.FRA_OIFNAME:
201				rule.OifName = string(attrs[j].Value[:len(attrs[j].Value)-1])
202			case nl.FRA_SUPPRESS_PREFIXLEN:
203				i := native.Uint32(attrs[j].Value[0:4])
204				if i != 0xffffffff {
205					rule.SuppressPrefixlen = int(i)
206				}
207			case nl.FRA_SUPPRESS_IFGROUP:
208				i := native.Uint32(attrs[j].Value[0:4])
209				if i != 0xffffffff {
210					rule.SuppressIfgroup = int(i)
211				}
212			case nl.FRA_FLOW:
213				rule.Flow = int(native.Uint32(attrs[j].Value[0:4]))
214			case nl.FRA_GOTO:
215				rule.Goto = int(native.Uint32(attrs[j].Value[0:4]))
216			case nl.FRA_PRIORITY:
217				rule.Priority = int(native.Uint32(attrs[j].Value[0:4]))
218			}
219		}
220		res = append(res, *rule)
221	}
222
223	return res, nil
224}
225