1// Copyright 2012 Google, Inc. All rights reserved.
2//
3// Use of this source code is governed by a BSD-style license
4// that can be found in the LICENSE file in the root of the source
5// tree.
6
7// +build linux
8
9// Package routing provides a very basic but mostly functional implementation of
10// a routing table for IPv4/IPv6 addresses.  It uses a routing table pulled from
11// the kernel via netlink to find the correct interface, gateway, and preferred
12// source IP address for packets destined to a particular location.
13//
14// The routing package is meant to be used with applications that are sending
15// raw packet data, which don't have the benefit of having the kernel route
16// packets for them.
17package routing
18
19import (
20	"bytes"
21	"errors"
22	"fmt"
23	"net"
24	"sort"
25	"strings"
26	"syscall"
27	"unsafe"
28)
29
// routeInfoInMemory mirrors the fixed-size header at the start of an rtnetlink
// route message.  The layout is pulled from 'struct rtmsg' in the RTM_NEWROUTE
// section of http://man7.org/linux/man-pages/man7/rtnetlink.7.html.
//
// New overlays this type directly onto the raw message bytes, so the order
// and size of the fields must not change.
type routeInfoInMemory struct {
	// Family picks the v4 or v6 table in New; DstLen and SrcLen are the
	// prefix lengths, in bits, applied to the RTA_DST and RTA_SRC attributes.
	Family, DstLen, SrcLen, TOS byte

	Table, Protocol, Scope, Type byte

	Flags uint32
}
45
// rtInfo holds everything the router keeps about a single route.
type rtInfo struct {
	// Src and Dst are the source and destination networks the route applies
	// to.  A nil network places no restriction on that address.
	Src *net.IPNet
	Dst *net.IPNet

	// Gateway and PrefSrc are handed back to the caller when the route is
	// selected; either may be nil.
	Gateway net.IP
	PrefSrc net.IP

	// InputIface and OutputIface are one-based interface indices.  A zero
	// InputIface matches any input interface during lookup.
	InputIface  uint32
	OutputIface uint32

	// Priority is the key routeSlice sorts on.
	Priority uint32
}
54
55// routeSlice implements sort.Interface to sort routes by Priority.
56type routeSlice []*rtInfo
57
58func (r routeSlice) Len() int {
59	return len(r)
60}
61func (r routeSlice) Less(i, j int) bool {
62	return r[i].Priority < r[j].Priority
63}
64func (r routeSlice) Swap(i, j int) {
65	r[i], r[j] = r[j], r[i]
66}
67
68type router struct {
69	ifaces []net.Interface
70	addrs  []ipAddrs
71	v4, v6 routeSlice
72}
73
74func (r *router) String() string {
75	strs := []string{"ROUTER", "--- V4 ---"}
76	for _, route := range r.v4 {
77		strs = append(strs, fmt.Sprintf("%+v", *route))
78	}
79	strs = append(strs, "--- V6 ---")
80	for _, route := range r.v6 {
81		strs = append(strs, fmt.Sprintf("%+v", *route))
82	}
83	return strings.Join(strs, "\n")
84}
85
// ipAddrs holds the addresses of one interface that are used as the source
// address when a selected route does not name a preferred source itself.
type ipAddrs struct {
	v4 net.IP
	v6 net.IP
}
89
90func (r *router) Route(dst net.IP) (iface *net.Interface, gateway, preferredSrc net.IP, err error) {
91	return r.RouteWithSrc(nil, nil, dst)
92}
93
94func (r *router) RouteWithSrc(input net.HardwareAddr, src, dst net.IP) (iface *net.Interface, gateway, preferredSrc net.IP, err error) {
95	var ifaceIndex int
96	switch {
97	case dst.To4() != nil:
98		ifaceIndex, gateway, preferredSrc, err = r.route(r.v4, input, src, dst)
99	case dst.To16() != nil:
100		ifaceIndex, gateway, preferredSrc, err = r.route(r.v6, input, src, dst)
101	default:
102		err = errors.New("IP is not valid as IPv4 or IPv6")
103		return
104	}
105
106	// Interfaces are 1-indexed, but we store them in a 0-indexed array.
107	ifaceIndex--
108
109	iface = &r.ifaces[ifaceIndex]
110	if preferredSrc == nil {
111		switch {
112		case dst.To4() != nil:
113			preferredSrc = r.addrs[ifaceIndex].v4
114		case dst.To16() != nil:
115			preferredSrc = r.addrs[ifaceIndex].v6
116		}
117	}
118	return
119}
120
121func (r *router) route(routes routeSlice, input net.HardwareAddr, src, dst net.IP) (iface int, gateway, preferredSrc net.IP, err error) {
122	var inputIndex uint32
123	if input != nil {
124		for i, iface := range r.ifaces {
125			if bytes.Equal(input, iface.HardwareAddr) {
126				// Convert from zero- to one-indexed.
127				inputIndex = uint32(i + 1)
128				break
129			}
130		}
131	}
132	for _, rt := range routes {
133		if rt.InputIface != 0 && rt.InputIface != inputIndex {
134			continue
135		}
136		if rt.Src != nil && !rt.Src.Contains(src) {
137			continue
138		}
139		if rt.Dst != nil && !rt.Dst.Contains(dst) {
140			continue
141		}
142		return int(rt.OutputIface), rt.Gateway, rt.PrefSrc, nil
143	}
144	err = fmt.Errorf("no route found for %v", dst)
145	return
146}
147
148// New creates a new router object.  The router returned by New currently does
149// not update its routes after construction... care should be taken for
150// long-running programs to call New() regularly to take into account any
151// changes to the routing table which have occurred since the last New() call.
152func New() (Router, error) {
153	rtr := &router{}
154	tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC)
155	if err != nil {
156		return nil, err
157	}
158	msgs, err := syscall.ParseNetlinkMessage(tab)
159	if err != nil {
160		return nil, err
161	}
162loop:
163	for _, m := range msgs {
164		switch m.Header.Type {
165		case syscall.NLMSG_DONE:
166			break loop
167		case syscall.RTM_NEWROUTE:
168			rt := (*routeInfoInMemory)(unsafe.Pointer(&m.Data[0]))
169			routeInfo := rtInfo{}
170			attrs, err := syscall.ParseNetlinkRouteAttr(&m)
171			if err != nil {
172				return nil, err
173			}
174			switch rt.Family {
175			case syscall.AF_INET:
176				rtr.v4 = append(rtr.v4, &routeInfo)
177			case syscall.AF_INET6:
178				rtr.v6 = append(rtr.v6, &routeInfo)
179			default:
180				continue loop
181			}
182			for _, attr := range attrs {
183				switch attr.Attr.Type {
184				case syscall.RTA_DST:
185					routeInfo.Dst = &net.IPNet{
186						IP:   net.IP(attr.Value),
187						Mask: net.CIDRMask(int(rt.DstLen), len(attr.Value)*8),
188					}
189				case syscall.RTA_SRC:
190					routeInfo.Src = &net.IPNet{
191						IP:   net.IP(attr.Value),
192						Mask: net.CIDRMask(int(rt.SrcLen), len(attr.Value)*8),
193					}
194				case syscall.RTA_GATEWAY:
195					routeInfo.Gateway = net.IP(attr.Value)
196				case syscall.RTA_PREFSRC:
197					routeInfo.PrefSrc = net.IP(attr.Value)
198				case syscall.RTA_IIF:
199					routeInfo.InputIface = *(*uint32)(unsafe.Pointer(&attr.Value[0]))
200				case syscall.RTA_OIF:
201					routeInfo.OutputIface = *(*uint32)(unsafe.Pointer(&attr.Value[0]))
202				case syscall.RTA_PRIORITY:
203					routeInfo.Priority = *(*uint32)(unsafe.Pointer(&attr.Value[0]))
204				}
205			}
206		}
207	}
208	sort.Sort(rtr.v4)
209	sort.Sort(rtr.v6)
210	ifaces, err := net.Interfaces()
211	if err != nil {
212		return nil, err
213	}
214	for i, iface := range ifaces {
215		if i != iface.Index-1 {
216			return nil, fmt.Errorf("out of order iface %d = %v", i, iface)
217		}
218		rtr.ifaces = append(rtr.ifaces, iface)
219		var addrs ipAddrs
220		ifaceAddrs, err := iface.Addrs()
221		if err != nil {
222			return nil, err
223		}
224		for _, addr := range ifaceAddrs {
225			if inet, ok := addr.(*net.IPNet); ok {
226				// Go has a nasty habit of giving you IPv4s as ::ffff:1.2.3.4 instead of 1.2.3.4.
227				// We want to use mapped v4 addresses as v4 preferred addresses, never as v6
228				// preferred addresses.
229				if v4 := inet.IP.To4(); v4 != nil {
230					if addrs.v4 == nil {
231						addrs.v4 = v4
232					}
233				} else if addrs.v6 == nil {
234					addrs.v6 = inet.IP
235				}
236			}
237		}
238		rtr.addrs = append(rtr.addrs, addrs)
239	}
240	return rtr, nil
241}
242