1package netlink
2
3import (
4	"fmt"
5	"net"
6	"strings"
7	"syscall"
8
9	"github.com/vishvananda/netlink/nl"
10	"github.com/vishvananda/netns"
11	"golang.org/x/sys/unix"
12)
13
14// RtAttr is shared so it is in netlink_linux.go
15
// Route scope values. Each one is the corresponding RT_SCOPE_* constant
// from golang.org/x/sys/unix, typed as Scope.
const (
	SCOPE_UNIVERSE Scope = unix.RT_SCOPE_UNIVERSE
	SCOPE_SITE     Scope = unix.RT_SCOPE_SITE
	SCOPE_LINK     Scope = unix.RT_SCOPE_LINK
	SCOPE_HOST     Scope = unix.RT_SCOPE_HOST
	SCOPE_NOWHERE  Scope = unix.RT_SCOPE_NOWHERE
)

// Bit flags for the filterMask argument of RouteListFiltered. Each flag
// selects one field of the filter Route that listed routes are compared
// against. Note that the first flag is 1<<1, not 1<<0, because the shift
// amount is (1 + iota).
const (
	RT_FILTER_PROTOCOL uint64 = 1 << (1 + iota)
	RT_FILTER_SCOPE
	RT_FILTER_TYPE
	RT_FILTER_TOS
	RT_FILTER_IIF
	RT_FILTER_OIF
	RT_FILTER_DST
	RT_FILTER_SRC
	RT_FILTER_GW
	RT_FILTER_TABLE
	RT_FILTER_HOPLIMIT
)

// Next-hop flags. Each one is the corresponding RTNH_F_* constant from
// golang.org/x/sys/unix, typed as NextHopFlag.
const (
	FLAG_ONLINK    NextHopFlag = unix.RTNH_F_ONLINK
	FLAG_PERVASIVE NextHopFlag = unix.RTNH_F_PERVASIVE
)

// testFlags pairs every known next-hop flag with the name that listFlags
// reports for it.
var testFlags = []flagString{
	{f: FLAG_ONLINK, s: "onlink"},
	{f: FLAG_PERVASIVE, s: "pervasive"},
}
47
48func listFlags(flag int) []string {
49	var flags []string
50	for _, tf := range testFlags {
51		if flag&int(tf.f) != 0 {
52			flags = append(flags, tf.s)
53		}
54	}
55	return flags
56}
57
// ListFlags returns the names of the known next-hop flags set in r.Flags.
func (r *Route) ListFlags() []string {
	return listFlags(r.Flags)
}
61
// ListFlags returns the names of the known next-hop flags set in n.Flags.
func (n *NexthopInfo) ListFlags() []string {
	return listFlags(n.Flags)
}
65
// MPLSDestination is a destination expressed as an MPLS label stack.
type MPLSDestination struct {
	// Labels is the label stack, as produced by nl.DecodeMPLSStack and
	// consumed by nl.EncodeMPLSStack.
	Labels []int
}

// Family returns the MPLS address family (nl.FAMILY_MPLS).
func (d *MPLSDestination) Family() int {
	return nl.FAMILY_MPLS
}

// Decode replaces d.Labels with the label stack decoded from buf.
// It always returns nil.
func (d *MPLSDestination) Decode(buf []byte) error {
	d.Labels = nl.DecodeMPLSStack(buf)
	return nil
}

// Encode returns d.Labels encoded by nl.EncodeMPLSStack.
// The returned error is always nil.
func (d *MPLSDestination) Encode() ([]byte, error) {
	return nl.EncodeMPLSStack(d.Labels...), nil
}
82
83func (d *MPLSDestination) String() string {
84	s := make([]string, 0, len(d.Labels))
85	for _, l := range d.Labels {
86		s = append(s, fmt.Sprintf("%d", l))
87	}
88	return strings.Join(s, "/")
89}
90
91func (d *MPLSDestination) Equal(x Destination) bool {
92	o, ok := x.(*MPLSDestination)
93	if !ok {
94		return false
95	}
96	if d == nil && o == nil {
97		return true
98	}
99	if d == nil || o == nil {
100		return false
101	}
102	if d.Labels == nil && o.Labels == nil {
103		return true
104	}
105	if d.Labels == nil || o.Labels == nil {
106		return false
107	}
108	if len(d.Labels) != len(o.Labels) {
109		return false
110	}
111	for i := range d.Labels {
112		if d.Labels[i] != o.Labels[i] {
113			return false
114		}
115	}
116	return true
117}
118
// MPLSEncap is a lightweight-tunnel encapsulation that carries an MPLS
// label stack.
type MPLSEncap struct {
	// Labels is the label stack, as produced by nl.DecodeMPLSStack and
	// consumed by nl.EncodeMPLSStack.
	Labels []int
}

// Type returns the encapsulation type (nl.LWTUNNEL_ENCAP_MPLS).
func (e *MPLSEncap) Type() int {
	return nl.LWTUNNEL_ENCAP_MPLS
}
126
127func (e *MPLSEncap) Decode(buf []byte) error {
128	if len(buf) < 4 {
129		return fmt.Errorf("lack of bytes")
130	}
131	native := nl.NativeEndian()
132	l := native.Uint16(buf)
133	if len(buf) < int(l) {
134		return fmt.Errorf("lack of bytes")
135	}
136	buf = buf[:l]
137	typ := native.Uint16(buf[2:])
138	if typ != nl.MPLS_IPTUNNEL_DST {
139		return fmt.Errorf("unknown MPLS Encap Type: %d", typ)
140	}
141	e.Labels = nl.DecodeMPLSStack(buf[4:])
142	return nil
143}
144
145func (e *MPLSEncap) Encode() ([]byte, error) {
146	s := nl.EncodeMPLSStack(e.Labels...)
147	native := nl.NativeEndian()
148	hdr := make([]byte, 4)
149	native.PutUint16(hdr, uint16(len(s)+4))
150	native.PutUint16(hdr[2:], nl.MPLS_IPTUNNEL_DST)
151	return append(hdr, s...), nil
152}
153
154func (e *MPLSEncap) String() string {
155	s := make([]string, 0, len(e.Labels))
156	for _, l := range e.Labels {
157		s = append(s, fmt.Sprintf("%d", l))
158	}
159	return strings.Join(s, "/")
160}
161
162func (e *MPLSEncap) Equal(x Encap) bool {
163	o, ok := x.(*MPLSEncap)
164	if !ok {
165		return false
166	}
167	if e == nil && o == nil {
168		return true
169	}
170	if e == nil || o == nil {
171		return false
172	}
173	if e.Labels == nil && o.Labels == nil {
174		return true
175	}
176	if e.Labels == nil || o.Labels == nil {
177		return false
178	}
179	if len(e.Labels) != len(o.Labels) {
180		return false
181	}
182	for i := range e.Labels {
183		if e.Labels[i] != o.Labels[i] {
184			return false
185		}
186	}
187	return true
188}
189
// SEG6Encap is a lightweight-tunnel encapsulation of type
// nl.LWTUNNEL_ENCAP_SEG6.
type SEG6Encap struct {
	// Mode is the encapsulation mode; it is rendered with
	// nl.SEG6EncapModeString by String.
	Mode int
	// Segments is the segment list. Index 0 is the last segment, which is
	// why String prints the list in reverse.
	Segments []net.IP
}

// Type returns the encapsulation type (nl.LWTUNNEL_ENCAP_SEG6).
func (e *SEG6Encap) Type() int {
	return nl.LWTUNNEL_ENCAP_SEG6
}
199func (e *SEG6Encap) Decode(buf []byte) error {
200	if len(buf) < 4 {
201		return fmt.Errorf("lack of bytes")
202	}
203	native := nl.NativeEndian()
204	// Get Length(l) & Type(typ) : 2 + 2 bytes
205	l := native.Uint16(buf)
206	if len(buf) < int(l) {
207		return fmt.Errorf("lack of bytes")
208	}
209	buf = buf[:l] // make sure buf size upper limit is Length
210	typ := native.Uint16(buf[2:])
211	// LWTUNNEL_ENCAP_SEG6 has only one attr type SEG6_IPTUNNEL_SRH
212	if typ != nl.SEG6_IPTUNNEL_SRH {
213		return fmt.Errorf("unknown SEG6 Type: %d", typ)
214	}
215
216	var err error
217	e.Mode, e.Segments, err = nl.DecodeSEG6Encap(buf[4:])
218
219	return err
220}
221func (e *SEG6Encap) Encode() ([]byte, error) {
222	s, err := nl.EncodeSEG6Encap(e.Mode, e.Segments)
223	native := nl.NativeEndian()
224	hdr := make([]byte, 4)
225	native.PutUint16(hdr, uint16(len(s)+4))
226	native.PutUint16(hdr[2:], nl.SEG6_IPTUNNEL_SRH)
227	return append(hdr, s...), err
228}
229func (e *SEG6Encap) String() string {
230	segs := make([]string, 0, len(e.Segments))
231	// append segment backwards (from n to 0) since seg#0 is the last segment.
232	for i := len(e.Segments); i > 0; i-- {
233		segs = append(segs, fmt.Sprintf("%s", e.Segments[i-1]))
234	}
235	str := fmt.Sprintf("mode %s segs %d [ %s ]", nl.SEG6EncapModeString(e.Mode),
236		len(e.Segments), strings.Join(segs, " "))
237	return str
238}
239func (e *SEG6Encap) Equal(x Encap) bool {
240	o, ok := x.(*SEG6Encap)
241	if !ok {
242		return false
243	}
244	if e == o {
245		return true
246	}
247	if e == nil || o == nil {
248		return false
249	}
250	if e.Mode != o.Mode {
251		return false
252	}
253	if len(e.Segments) != len(o.Segments) {
254		return false
255	}
256	for i := range e.Segments {
257		if !e.Segments[i].Equal(o.Segments[i]) {
258			return false
259		}
260	}
261	return true
262}
263
// SEG6LocalEncap is a lightweight-tunnel encapsulation of type
// nl.LWTUNNEL_ENCAP_SEG6_LOCAL.
type SEG6LocalEncap struct {
	// Flags is indexed by the nl.SEG6_LOCAL_* attribute constants. An entry
	// is set by Decode when that attribute was present, and Encode and
	// String only emit the optional attributes whose entry is true.
	Flags    [nl.SEG6_LOCAL_MAX]bool
	Action   int
	Segments []net.IP // from SRH in seg6_local_lwt
	Table    int      // table id for End.T and End.DT6
	InAddr   net.IP
	In6Addr  net.IP
	Iif      int
	Oif      int
}

// Type returns the encapsulation type (nl.LWTUNNEL_ENCAP_SEG6_LOCAL).
func (e *SEG6LocalEncap) Type() int {
	return nl.LWTUNNEL_ENCAP_SEG6_LOCAL
}
279func (e *SEG6LocalEncap) Decode(buf []byte) error {
280	attrs, err := nl.ParseRouteAttr(buf)
281	if err != nil {
282		return err
283	}
284	native := nl.NativeEndian()
285	for _, attr := range attrs {
286		switch attr.Attr.Type {
287		case nl.SEG6_LOCAL_ACTION:
288			e.Action = int(native.Uint32(attr.Value[0:4]))
289			e.Flags[nl.SEG6_LOCAL_ACTION] = true
290		case nl.SEG6_LOCAL_SRH:
291			e.Segments, err = nl.DecodeSEG6Srh(attr.Value[:])
292			e.Flags[nl.SEG6_LOCAL_SRH] = true
293		case nl.SEG6_LOCAL_TABLE:
294			e.Table = int(native.Uint32(attr.Value[0:4]))
295			e.Flags[nl.SEG6_LOCAL_TABLE] = true
296		case nl.SEG6_LOCAL_NH4:
297			e.InAddr = net.IP(attr.Value[0:4])
298			e.Flags[nl.SEG6_LOCAL_NH4] = true
299		case nl.SEG6_LOCAL_NH6:
300			e.In6Addr = net.IP(attr.Value[0:16])
301			e.Flags[nl.SEG6_LOCAL_NH6] = true
302		case nl.SEG6_LOCAL_IIF:
303			e.Iif = int(native.Uint32(attr.Value[0:4]))
304			e.Flags[nl.SEG6_LOCAL_IIF] = true
305		case nl.SEG6_LOCAL_OIF:
306			e.Oif = int(native.Uint32(attr.Value[0:4]))
307			e.Flags[nl.SEG6_LOCAL_OIF] = true
308		}
309	}
310	return err
311}
312func (e *SEG6LocalEncap) Encode() ([]byte, error) {
313	var err error
314	native := nl.NativeEndian()
315	res := make([]byte, 8)
316	native.PutUint16(res, 8) // length
317	native.PutUint16(res[2:], nl.SEG6_LOCAL_ACTION)
318	native.PutUint32(res[4:], uint32(e.Action))
319	if e.Flags[nl.SEG6_LOCAL_SRH] {
320		srh, err := nl.EncodeSEG6Srh(e.Segments)
321		if err != nil {
322			return nil, err
323		}
324		attr := make([]byte, 4)
325		native.PutUint16(attr, uint16(len(srh)+4))
326		native.PutUint16(attr[2:], nl.SEG6_LOCAL_SRH)
327		attr = append(attr, srh...)
328		res = append(res, attr...)
329	}
330	if e.Flags[nl.SEG6_LOCAL_TABLE] {
331		attr := make([]byte, 8)
332		native.PutUint16(attr, 8)
333		native.PutUint16(attr[2:], nl.SEG6_LOCAL_TABLE)
334		native.PutUint32(attr[4:], uint32(e.Table))
335		res = append(res, attr...)
336	}
337	if e.Flags[nl.SEG6_LOCAL_NH4] {
338		attr := make([]byte, 4)
339		native.PutUint16(attr, 8)
340		native.PutUint16(attr[2:], nl.SEG6_LOCAL_NH4)
341		ipv4 := e.InAddr.To4()
342		if ipv4 == nil {
343			err = fmt.Errorf("SEG6_LOCAL_NH4 has invalid IPv4 address")
344			return nil, err
345		}
346		attr = append(attr, ipv4...)
347		res = append(res, attr...)
348	}
349	if e.Flags[nl.SEG6_LOCAL_NH6] {
350		attr := make([]byte, 4)
351		native.PutUint16(attr, 20)
352		native.PutUint16(attr[2:], nl.SEG6_LOCAL_NH6)
353		attr = append(attr, e.In6Addr...)
354		res = append(res, attr...)
355	}
356	if e.Flags[nl.SEG6_LOCAL_IIF] {
357		attr := make([]byte, 8)
358		native.PutUint16(attr, 8)
359		native.PutUint16(attr[2:], nl.SEG6_LOCAL_IIF)
360		native.PutUint32(attr[4:], uint32(e.Iif))
361		res = append(res, attr...)
362	}
363	if e.Flags[nl.SEG6_LOCAL_OIF] {
364		attr := make([]byte, 8)
365		native.PutUint16(attr, 8)
366		native.PutUint16(attr[2:], nl.SEG6_LOCAL_OIF)
367		native.PutUint32(attr[4:], uint32(e.Oif))
368		res = append(res, attr...)
369	}
370	return res, err
371}
372func (e *SEG6LocalEncap) String() string {
373	strs := make([]string, 0, nl.SEG6_LOCAL_MAX)
374	strs = append(strs, fmt.Sprintf("action %s", nl.SEG6LocalActionString(e.Action)))
375
376	if e.Flags[nl.SEG6_LOCAL_TABLE] {
377		strs = append(strs, fmt.Sprintf("table %d", e.Table))
378	}
379	if e.Flags[nl.SEG6_LOCAL_NH4] {
380		strs = append(strs, fmt.Sprintf("nh4 %s", e.InAddr))
381	}
382	if e.Flags[nl.SEG6_LOCAL_NH6] {
383		strs = append(strs, fmt.Sprintf("nh6 %s", e.In6Addr))
384	}
385	if e.Flags[nl.SEG6_LOCAL_IIF] {
386		link, err := LinkByIndex(e.Iif)
387		if err != nil {
388			strs = append(strs, fmt.Sprintf("iif %d", e.Iif))
389		} else {
390			strs = append(strs, fmt.Sprintf("iif %s", link.Attrs().Name))
391		}
392	}
393	if e.Flags[nl.SEG6_LOCAL_OIF] {
394		link, err := LinkByIndex(e.Oif)
395		if err != nil {
396			strs = append(strs, fmt.Sprintf("oif %d", e.Oif))
397		} else {
398			strs = append(strs, fmt.Sprintf("oif %s", link.Attrs().Name))
399		}
400	}
401	if e.Flags[nl.SEG6_LOCAL_SRH] {
402		segs := make([]string, 0, len(e.Segments))
403		//append segment backwards (from n to 0) since seg#0 is the last segment.
404		for i := len(e.Segments); i > 0; i-- {
405			segs = append(segs, fmt.Sprintf("%s", e.Segments[i-1]))
406		}
407		strs = append(strs, fmt.Sprintf("segs %d [ %s ]", len(e.Segments), strings.Join(segs, " ")))
408	}
409	return strings.Join(strs, " ")
410}
411func (e *SEG6LocalEncap) Equal(x Encap) bool {
412	o, ok := x.(*SEG6LocalEncap)
413	if !ok {
414		return false
415	}
416	if e == o {
417		return true
418	}
419	if e == nil || o == nil {
420		return false
421	}
422	// compare all arrays first
423	for i := range e.Flags {
424		if e.Flags[i] != o.Flags[i] {
425			return false
426		}
427	}
428	if len(e.Segments) != len(o.Segments) {
429		return false
430	}
431	for i := range e.Segments {
432		if !e.Segments[i].Equal(o.Segments[i]) {
433			return false
434		}
435	}
436	// compare values
437	if !e.InAddr.Equal(o.InAddr) || !e.In6Addr.Equal(o.In6Addr) {
438		return false
439	}
440	if e.Action != o.Action || e.Table != o.Table || e.Iif != o.Iif || e.Oif != o.Oif {
441		return false
442	}
443	return true
444}
445
// RouteAdd will add a route to the system.
// Equivalent to: `ip route add $route`
// It delegates to the RouteAdd method of the package-level handle.
func RouteAdd(route *Route) error {
	return pkgHandle.RouteAdd(route)
}
451
452// RouteAdd will add a route to the system.
453// Equivalent to: `ip route add $route`
454func (h *Handle) RouteAdd(route *Route) error {
455	flags := unix.NLM_F_CREATE | unix.NLM_F_EXCL | unix.NLM_F_ACK
456	req := h.newNetlinkRequest(unix.RTM_NEWROUTE, flags)
457	return h.routeHandle(route, req, nl.NewRtMsg())
458}
459
// RouteReplace will add a route to the system, or replace it if it already
// exists.
// Equivalent to: `ip route replace $route`
// It delegates to the RouteReplace method of the package-level handle.
func RouteReplace(route *Route) error {
	return pkgHandle.RouteReplace(route)
}
465
466// RouteReplace will add a route to the system.
467// Equivalent to: `ip route replace $route`
468func (h *Handle) RouteReplace(route *Route) error {
469	flags := unix.NLM_F_CREATE | unix.NLM_F_REPLACE | unix.NLM_F_ACK
470	req := h.newNetlinkRequest(unix.RTM_NEWROUTE, flags)
471	return h.routeHandle(route, req, nl.NewRtMsg())
472}
473
// RouteDel will delete a route from the system.
// Equivalent to: `ip route del $route`
// It delegates to the RouteDel method of the package-level handle.
func RouteDel(route *Route) error {
	return pkgHandle.RouteDel(route)
}
479
480// RouteDel will delete a route from the system.
481// Equivalent to: `ip route del $route`
482func (h *Handle) RouteDel(route *Route) error {
483	req := h.newNetlinkRequest(unix.RTM_DELROUTE, unix.NLM_F_ACK)
484	return h.routeHandle(route, req, nl.NewRtDelMsg())
485}
486
// routeHandle is the shared implementation behind RouteAdd, RouteReplace
// and RouteDel. It fills in the route message header msg and a list of
// route attributes from route, attaches them to req and executes the
// request on NETLINK_ROUTE.
//
// An error is returned when none of Dst.IP, Src, Gw and MPLSDst is set,
// when the address families of the supplied addresses disagree, when a
// NewDst or Encap fails to encode, or when executing the request fails.
func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg) error {
	// At least one of Dst.IP, Src, Gw or MPLSDst is required (the error
	// text does not mention MPLSDst).
	if (route.Dst == nil || route.Dst.IP == nil) && route.Src == nil && route.Gw == nil && route.MPLSDst == nil {
		return fmt.Errorf("one of Dst.IP, Src, or Gw must not be nil")
	}

	// family stays -1 until the first address fixes it; every later address
	// must agree with it.
	family := -1
	var rtAttrs []*nl.RtAttr

	if route.Dst != nil && route.Dst.IP != nil {
		dstLen, _ := route.Dst.Mask.Size()
		msg.Dst_len = uint8(dstLen)
		dstFamily := nl.GetIPFamily(route.Dst.IP)
		family = dstFamily
		var dstData []byte
		if dstFamily == FAMILY_V4 {
			dstData = route.Dst.IP.To4()
		} else {
			dstData = route.Dst.IP.To16()
		}
		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_DST, dstData))
	} else if route.MPLSDst != nil {
		// An MPLS destination is only used when no IP destination is given.
		family = nl.FAMILY_MPLS
		msg.Dst_len = uint8(20)
		msg.Type = unix.RTN_UNICAST
		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_DST, nl.EncodeMPLSStack(*route.MPLSDst)))
	}

	if route.NewDst != nil {
		if family != -1 && family != route.NewDst.Family() {
			return fmt.Errorf("new destination and destination are not the same address family")
		}
		buf, err := route.NewDst.Encode()
		if err != nil {
			return err
		}
		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_NEWDST, buf))
	}

	if route.Encap != nil {
		// NOTE(review): `native` here (and in the uses below, up to the
		// local declaration near the end of this function) is not declared
		// in this function; presumably it is a package-level variable
		// defined elsewhere — confirm.
		buf := make([]byte, 2)
		native.PutUint16(buf, uint16(route.Encap.Type()))
		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_ENCAP_TYPE, buf))
		buf, err := route.Encap.Encode()
		if err != nil {
			return err
		}
		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_ENCAP, buf))
	}

	if route.Src != nil {
		srcFamily := nl.GetIPFamily(route.Src)
		if family != -1 && family != srcFamily {
			return fmt.Errorf("source and destination ip are not the same IP family")
		}
		family = srcFamily
		var srcData []byte
		if srcFamily == FAMILY_V4 {
			srcData = route.Src.To4()
		} else {
			srcData = route.Src.To16()
		}
		// The commonly used src ip for routes is actually PREFSRC
		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_PREFSRC, srcData))
	}

	if route.Gw != nil {
		gwFamily := nl.GetIPFamily(route.Gw)
		if family != -1 && family != gwFamily {
			return fmt.Errorf("gateway, source, and destination ip are not the same IP family")
		}
		family = gwFamily
		var gwData []byte
		if gwFamily == FAMILY_V4 {
			gwData = route.Gw.To4()
		} else {
			gwData = route.Gw.To16()
		}
		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_GATEWAY, gwData))
	}

	if len(route.MultiPath) > 0 {
		// Each next hop is serialized as an rtnexthop header followed by its
		// own gateway / new-destination / encap attributes; the results are
		// concatenated into a single RTA_MULTIPATH attribute.
		buf := []byte{}
		for _, nh := range route.MultiPath {
			rtnh := &nl.RtNexthop{
				RtNexthop: unix.RtNexthop{
					Hops:    uint8(nh.Hops),
					Ifindex: int32(nh.LinkIndex),
					Flags:   uint8(nh.Flags),
				},
			}
			children := []nl.NetlinkRequestData{}
			if nh.Gw != nil {
				// A next-hop gateway is checked against family but, unlike
				// route.Gw, does not set it.
				gwFamily := nl.GetIPFamily(nh.Gw)
				if family != -1 && family != gwFamily {
					return fmt.Errorf("gateway, source, and destination ip are not the same IP family")
				}
				if gwFamily == FAMILY_V4 {
					children = append(children, nl.NewRtAttr(unix.RTA_GATEWAY, []byte(nh.Gw.To4())))
				} else {
					children = append(children, nl.NewRtAttr(unix.RTA_GATEWAY, []byte(nh.Gw.To16())))
				}
			}
			if nh.NewDst != nil {
				if family != -1 && family != nh.NewDst.Family() {
					return fmt.Errorf("new destination and destination are not the same address family")
				}
				buf, err := nh.NewDst.Encode()
				if err != nil {
					return err
				}
				children = append(children, nl.NewRtAttr(unix.RTA_NEWDST, buf))
			}
			if nh.Encap != nil {
				buf := make([]byte, 2)
				native.PutUint16(buf, uint16(nh.Encap.Type()))
				children = append(children, nl.NewRtAttr(unix.RTA_ENCAP_TYPE, buf))
				buf, err := nh.Encap.Encode()
				if err != nil {
					return err
				}
				children = append(children, nl.NewRtAttr(unix.RTA_ENCAP, buf))
			}
			rtnh.Children = children
			buf = append(buf, rtnh.Serialize()...)
		}
		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_MULTIPATH, buf))
	}

	if route.Table > 0 {
		if route.Table >= 256 {
			// Table ids of 256 and above do not fit the 8-bit header field;
			// the header is set to RT_TABLE_UNSPEC and the id is sent as a
			// 32-bit RTA_TABLE attribute instead.
			msg.Table = unix.RT_TABLE_UNSPEC
			b := make([]byte, 4)
			native.PutUint32(b, uint32(route.Table))
			rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_TABLE, b))
		} else {
			msg.Table = uint8(route.Table)
		}
	}

	// The following fields are only written when positive, so the values
	// already present in msg are kept otherwise.
	if route.Priority > 0 {
		b := make([]byte, 4)
		native.PutUint32(b, uint32(route.Priority))
		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_PRIORITY, b))
	}
	if route.Tos > 0 {
		msg.Tos = uint8(route.Tos)
	}
	if route.Protocol > 0 {
		msg.Protocol = uint8(route.Protocol)
	}
	if route.Type > 0 {
		msg.Type = uint8(route.Type)
	}

	var metrics []*nl.RtAttr
	// TODO: support other rta_metric values
	if route.MTU > 0 {
		b := nl.Uint32Attr(uint32(route.MTU))
		metrics = append(metrics, nl.NewRtAttr(unix.RTAX_MTU, b))
	}
	if route.AdvMSS > 0 {
		b := nl.Uint32Attr(uint32(route.AdvMSS))
		metrics = append(metrics, nl.NewRtAttr(unix.RTAX_ADVMSS, b))
	}
	if route.Hoplimit > 0 {
		b := nl.Uint32Attr(uint32(route.Hoplimit))
		metrics = append(metrics, nl.NewRtAttr(unix.RTAX_HOPLIMIT, b))
	}

	// Metrics are nested as children of a single RTA_METRICS attribute.
	if metrics != nil {
		attr := nl.NewRtAttr(unix.RTA_METRICS, nil)
		for _, metric := range metrics {
			attr.AddChild(metric)
		}
		rtAttrs = append(rtAttrs, attr)
	}

	msg.Flags = uint32(route.Flags)
	msg.Scope = uint8(route.Scope)
	msg.Family = uint8(family)
	// The header goes first, followed by the attributes in the order they
	// were collected above.
	req.AddData(msg)
	for _, attr := range rtAttrs {
		req.AddData(attr)
	}

	// From here on `native` is this local variable.
	var (
		b      = make([]byte, 4)
		native = nl.NativeEndian()
	)
	native.PutUint32(b, uint32(route.LinkIndex))

	// RTA_OIF is always appended, even when route.LinkIndex is zero.
	req.AddData(nl.NewRtAttr(unix.RTA_OIF, b))

	_, err := req.Execute(unix.NETLINK_ROUTE, 0)
	return err
}
683
// RouteList gets a list of routes in the system.
// Equivalent to: `ip route show`.
// The list can be filtered by link and ip family.
// It delegates to the RouteList method of the package-level handle.
func RouteList(link Link, family int) ([]Route, error) {
	return pkgHandle.RouteList(link, family)
}
690
691// RouteList gets a list of routes in the system.
692// Equivalent to: `ip route show`.
693// The list can be filtered by link and ip family.
694func (h *Handle) RouteList(link Link, family int) ([]Route, error) {
695	var routeFilter *Route
696	if link != nil {
697		routeFilter = &Route{
698			LinkIndex: link.Attrs().Index,
699		}
700	}
701	return h.RouteListFiltered(family, routeFilter, RT_FILTER_OIF)
702}
703
// RouteListFiltered gets a list of routes in the system filtered with specified rules.
// All rules must be defined in the filter Route; filterMask (a combination
// of RT_FILTER_* flags) selects which of its fields are compared.
// It delegates to the RouteListFiltered method of the package-level handle.
func RouteListFiltered(family int, filter *Route, filterMask uint64) ([]Route, error) {
	return pkgHandle.RouteListFiltered(family, filter, filterMask)
}
709
710// RouteListFiltered gets a list of routes in the system filtered with specified rules.
711// All rules must be defined in RouteFilter struct
712func (h *Handle) RouteListFiltered(family int, filter *Route, filterMask uint64) ([]Route, error) {
713	req := h.newNetlinkRequest(unix.RTM_GETROUTE, unix.NLM_F_DUMP)
714	infmsg := nl.NewIfInfomsg(family)
715	req.AddData(infmsg)
716
717	msgs, err := req.Execute(unix.NETLINK_ROUTE, unix.RTM_NEWROUTE)
718	if err != nil {
719		return nil, err
720	}
721
722	var res []Route
723	for _, m := range msgs {
724		msg := nl.DeserializeRtMsg(m)
725		if msg.Flags&unix.RTM_F_CLONED != 0 {
726			// Ignore cloned routes
727			continue
728		}
729		if msg.Table != unix.RT_TABLE_MAIN {
730			if filter == nil || filter != nil && filterMask&RT_FILTER_TABLE == 0 {
731				// Ignore non-main tables
732				continue
733			}
734		}
735		route, err := deserializeRoute(m)
736		if err != nil {
737			return nil, err
738		}
739		if filter != nil {
740			switch {
741			case filterMask&RT_FILTER_TABLE != 0 && filter.Table != unix.RT_TABLE_UNSPEC && route.Table != filter.Table:
742				continue
743			case filterMask&RT_FILTER_PROTOCOL != 0 && route.Protocol != filter.Protocol:
744				continue
745			case filterMask&RT_FILTER_SCOPE != 0 && route.Scope != filter.Scope:
746				continue
747			case filterMask&RT_FILTER_TYPE != 0 && route.Type != filter.Type:
748				continue
749			case filterMask&RT_FILTER_TOS != 0 && route.Tos != filter.Tos:
750				continue
751			case filterMask&RT_FILTER_OIF != 0 && route.LinkIndex != filter.LinkIndex:
752				continue
753			case filterMask&RT_FILTER_IIF != 0 && route.ILinkIndex != filter.ILinkIndex:
754				continue
755			case filterMask&RT_FILTER_GW != 0 && !route.Gw.Equal(filter.Gw):
756				continue
757			case filterMask&RT_FILTER_SRC != 0 && !route.Src.Equal(filter.Src):
758				continue
759			case filterMask&RT_FILTER_DST != 0:
760				if filter.MPLSDst == nil || route.MPLSDst == nil || (*filter.MPLSDst) != (*route.MPLSDst) {
761					if !ipNetEqual(route.Dst, filter.Dst) {
762						continue
763					}
764				}
765			case filterMask&RT_FILTER_HOPLIMIT != 0 && route.Hoplimit != filter.Hoplimit:
766				continue
767			}
768		}
769		res = append(res, route)
770	}
771	return res, nil
772}
773
774// deserializeRoute decodes a binary netlink message into a Route struct
775func deserializeRoute(m []byte) (Route, error) {
776	msg := nl.DeserializeRtMsg(m)
777	attrs, err := nl.ParseRouteAttr(m[msg.Len():])
778	if err != nil {
779		return Route{}, err
780	}
781	route := Route{
782		Scope:    Scope(msg.Scope),
783		Protocol: int(msg.Protocol),
784		Table:    int(msg.Table),
785		Type:     int(msg.Type),
786		Tos:      int(msg.Tos),
787		Flags:    int(msg.Flags),
788	}
789
790	native := nl.NativeEndian()
791	var encap, encapType syscall.NetlinkRouteAttr
792	for _, attr := range attrs {
793		switch attr.Attr.Type {
794		case unix.RTA_GATEWAY:
795			route.Gw = net.IP(attr.Value)
796		case unix.RTA_PREFSRC:
797			route.Src = net.IP(attr.Value)
798		case unix.RTA_DST:
799			if msg.Family == nl.FAMILY_MPLS {
800				stack := nl.DecodeMPLSStack(attr.Value)
801				if len(stack) == 0 || len(stack) > 1 {
802					return route, fmt.Errorf("invalid MPLS RTA_DST")
803				}
804				route.MPLSDst = &stack[0]
805			} else {
806				route.Dst = &net.IPNet{
807					IP:   attr.Value,
808					Mask: net.CIDRMask(int(msg.Dst_len), 8*len(attr.Value)),
809				}
810			}
811		case unix.RTA_OIF:
812			route.LinkIndex = int(native.Uint32(attr.Value[0:4]))
813		case unix.RTA_IIF:
814			route.ILinkIndex = int(native.Uint32(attr.Value[0:4]))
815		case unix.RTA_PRIORITY:
816			route.Priority = int(native.Uint32(attr.Value[0:4]))
817		case unix.RTA_TABLE:
818			route.Table = int(native.Uint32(attr.Value[0:4]))
819		case unix.RTA_MULTIPATH:
820			parseRtNexthop := func(value []byte) (*NexthopInfo, []byte, error) {
821				if len(value) < unix.SizeofRtNexthop {
822					return nil, nil, fmt.Errorf("lack of bytes")
823				}
824				nh := nl.DeserializeRtNexthop(value)
825				if len(value) < int(nh.RtNexthop.Len) {
826					return nil, nil, fmt.Errorf("lack of bytes")
827				}
828				info := &NexthopInfo{
829					LinkIndex: int(nh.RtNexthop.Ifindex),
830					Hops:      int(nh.RtNexthop.Hops),
831					Flags:     int(nh.RtNexthop.Flags),
832				}
833				attrs, err := nl.ParseRouteAttr(value[unix.SizeofRtNexthop:int(nh.RtNexthop.Len)])
834				if err != nil {
835					return nil, nil, err
836				}
837				var encap, encapType syscall.NetlinkRouteAttr
838				for _, attr := range attrs {
839					switch attr.Attr.Type {
840					case unix.RTA_GATEWAY:
841						info.Gw = net.IP(attr.Value)
842					case unix.RTA_NEWDST:
843						var d Destination
844						switch msg.Family {
845						case nl.FAMILY_MPLS:
846							d = &MPLSDestination{}
847						}
848						if err := d.Decode(attr.Value); err != nil {
849							return nil, nil, err
850						}
851						info.NewDst = d
852					case unix.RTA_ENCAP_TYPE:
853						encapType = attr
854					case unix.RTA_ENCAP:
855						encap = attr
856					}
857				}
858
859				if len(encap.Value) != 0 && len(encapType.Value) != 0 {
860					typ := int(native.Uint16(encapType.Value[0:2]))
861					var e Encap
862					switch typ {
863					case nl.LWTUNNEL_ENCAP_MPLS:
864						e = &MPLSEncap{}
865						if err := e.Decode(encap.Value); err != nil {
866							return nil, nil, err
867						}
868					}
869					info.Encap = e
870				}
871
872				return info, value[int(nh.RtNexthop.Len):], nil
873			}
874			rest := attr.Value
875			for len(rest) > 0 {
876				info, buf, err := parseRtNexthop(rest)
877				if err != nil {
878					return route, err
879				}
880				route.MultiPath = append(route.MultiPath, info)
881				rest = buf
882			}
883		case unix.RTA_NEWDST:
884			var d Destination
885			switch msg.Family {
886			case nl.FAMILY_MPLS:
887				d = &MPLSDestination{}
888			}
889			if err := d.Decode(attr.Value); err != nil {
890				return route, err
891			}
892			route.NewDst = d
893		case unix.RTA_ENCAP_TYPE:
894			encapType = attr
895		case unix.RTA_ENCAP:
896			encap = attr
897		case unix.RTA_METRICS:
898			metrics, err := nl.ParseRouteAttr(attr.Value)
899			if err != nil {
900				return route, err
901			}
902			for _, metric := range metrics {
903				switch metric.Attr.Type {
904				case unix.RTAX_MTU:
905					route.MTU = int(native.Uint32(metric.Value[0:4]))
906				case unix.RTAX_ADVMSS:
907					route.AdvMSS = int(native.Uint32(metric.Value[0:4]))
908				case unix.RTAX_HOPLIMIT:
909					route.Hoplimit = int(native.Uint32(metric.Value[0:4]))
910				}
911			}
912		}
913	}
914
915	if len(encap.Value) != 0 && len(encapType.Value) != 0 {
916		typ := int(native.Uint16(encapType.Value[0:2]))
917		var e Encap
918		switch typ {
919		case nl.LWTUNNEL_ENCAP_MPLS:
920			e = &MPLSEncap{}
921			if err := e.Decode(encap.Value); err != nil {
922				return route, err
923			}
924		case nl.LWTUNNEL_ENCAP_SEG6:
925			e = &SEG6Encap{}
926			if err := e.Decode(encap.Value); err != nil {
927				return route, err
928			}
929		case nl.LWTUNNEL_ENCAP_SEG6_LOCAL:
930			e = &SEG6LocalEncap{}
931			if err := e.Decode(encap.Value); err != nil {
932				return route, err
933			}
934		}
935		route.Encap = e
936	}
937
938	return route, nil
939}
940
// RouteGet gets a route to a specific destination from the host system.
// Equivalent to: 'ip route get'.
// It delegates to the RouteGet method of the package-level handle.
func RouteGet(destination net.IP) ([]Route, error) {
	return pkgHandle.RouteGet(destination)
}
946
947// RouteGet gets a route to a specific destination from the host system.
948// Equivalent to: 'ip route get'.
949func (h *Handle) RouteGet(destination net.IP) ([]Route, error) {
950	req := h.newNetlinkRequest(unix.RTM_GETROUTE, unix.NLM_F_REQUEST)
951	family := nl.GetIPFamily(destination)
952	var destinationData []byte
953	var bitlen uint8
954	if family == FAMILY_V4 {
955		destinationData = destination.To4()
956		bitlen = 32
957	} else {
958		destinationData = destination.To16()
959		bitlen = 128
960	}
961	msg := &nl.RtMsg{}
962	msg.Family = uint8(family)
963	msg.Dst_len = bitlen
964	req.AddData(msg)
965
966	rtaDst := nl.NewRtAttr(unix.RTA_DST, destinationData)
967	req.AddData(rtaDst)
968
969	msgs, err := req.Execute(unix.NETLINK_ROUTE, unix.RTM_NEWROUTE)
970	if err != nil {
971		return nil, err
972	}
973
974	var res []Route
975	for _, m := range msgs {
976		route, err := deserializeRoute(m)
977		if err != nil {
978			return nil, err
979		}
980		res = append(res, route)
981	}
982	return res, nil
983
984}
985
// RouteSubscribe takes a chan down which notifications will be sent
// when routes are added or deleted. Close the 'done' chan to stop subscription.
// No namespace is selected (netns.None() is used), no error callback is
// installed and existing routes are not listed.
func RouteSubscribe(ch chan<- RouteUpdate, done <-chan struct{}) error {
	return routeSubscribeAt(netns.None(), netns.None(), ch, done, nil, false)
}
991
// RouteSubscribeAt works like RouteSubscribe plus it allows the caller
// to choose the network namespace in which to subscribe (ns).
// No error callback is installed and existing routes are not listed.
func RouteSubscribeAt(ns netns.NsHandle, ch chan<- RouteUpdate, done <-chan struct{}) error {
	return routeSubscribeAt(ns, netns.None(), ch, done, nil, false)
}
997
// RouteSubscribeOptions contains a set of options to use with
// RouteSubscribeWithOptions.
type RouteSubscribeOptions struct {
	// Namespace is the network namespace to subscribe in. When nil,
	// RouteSubscribeWithOptions uses netns.None().
	Namespace *netns.NsHandle
	// ErrorCallback, when non-nil, is called with errors hit by the
	// background receive loop.
	ErrorCallback func(error)
	// ListExisting, when true, makes the subscription request a dump of
	// the routes on the subscription socket right after subscribing.
	ListExisting bool
}
1005
1006// RouteSubscribeWithOptions work like RouteSubscribe but enable to
1007// provide additional options to modify the behavior. Currently, the
1008// namespace can be provided as well as an error callback.
1009func RouteSubscribeWithOptions(ch chan<- RouteUpdate, done <-chan struct{}, options RouteSubscribeOptions) error {
1010	if options.Namespace == nil {
1011		none := netns.None()
1012		options.Namespace = &none
1013	}
1014	return routeSubscribeAt(*options.Namespace, netns.None(), ch, done, options.ErrorCallback, options.ListExisting)
1015}
1016
1017func routeSubscribeAt(newNs, curNs netns.NsHandle, ch chan<- RouteUpdate, done <-chan struct{}, cberr func(error), listExisting bool) error {
1018	s, err := nl.SubscribeAt(newNs, curNs, unix.NETLINK_ROUTE, unix.RTNLGRP_IPV4_ROUTE, unix.RTNLGRP_IPV6_ROUTE)
1019	if err != nil {
1020		return err
1021	}
1022	if done != nil {
1023		go func() {
1024			<-done
1025			s.Close()
1026		}()
1027	}
1028	if listExisting {
1029		req := pkgHandle.newNetlinkRequest(unix.RTM_GETROUTE,
1030			unix.NLM_F_DUMP)
1031		infmsg := nl.NewIfInfomsg(unix.AF_UNSPEC)
1032		req.AddData(infmsg)
1033		if err := s.Send(req); err != nil {
1034			return err
1035		}
1036	}
1037	go func() {
1038		defer close(ch)
1039		for {
1040			msgs, from, err := s.Receive()
1041			if err != nil {
1042				if cberr != nil {
1043					cberr(err)
1044				}
1045				return
1046			}
1047			if from.Pid != nl.PidKernel {
1048				if cberr != nil {
1049					cberr(fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel))
1050				}
1051				continue
1052			}
1053			for _, m := range msgs {
1054				if m.Header.Type == unix.NLMSG_DONE {
1055					continue
1056				}
1057				if m.Header.Type == unix.NLMSG_ERROR {
1058					native := nl.NativeEndian()
1059					error := int32(native.Uint32(m.Data[0:4]))
1060					if error == 0 {
1061						continue
1062					}
1063					if cberr != nil {
1064						cberr(syscall.Errno(-error))
1065					}
1066					return
1067				}
1068				route, err := deserializeRoute(m.Data)
1069				if err != nil {
1070					if cberr != nil {
1071						cberr(err)
1072					}
1073					return
1074				}
1075				ch <- RouteUpdate{Type: m.Header.Type, Route: route}
1076			}
1077		}
1078	}()
1079
1080	return nil
1081}
1082