1package vnet
2
3import (
4	"encoding/binary"
5	"fmt"
6	"math/rand"
7	"net"
8	"strconv"
9	"strings"
10	"sync"
11)
12
// macAddrCounter is the value the next generated MAC address is taken from.
var macAddrCounter uint64 = 0xBEEFED910200

// newMACAddress returns a 6-byte hardware address made of the low 48 bits of
// macAddrCounter (big-endian), then advances the counter by one.
// The counter is a plain package-level variable; no locking is done here.
func newMACAddress() net.HardwareAddr {
	var raw [8]byte
	binary.BigEndian.PutUint64(raw[:], macAddrCounter)
	macAddrCounter++

	// Skip the two most significant bytes: only 48 bits are used.
	mac := make(net.HardwareAddr, len(raw)-2)
	copy(mac, raw[2:])

	return mac
}
21
// vNet holds the state of one virtual network stack: its interfaces, the
// static IPs requested at construction, the router it is linked to and the
// table of open UDP connections.
type vNet struct {
	interfaces []*Interface // read-only
	staticIPs  []net.IP     // read-only
	router     *Router      // nil until setRouter assigns it (done under mutex)
	udpConns   *udpConnMap  // read-only pointer; entries change via find/insert/delete
	mutex      sync.RWMutex // taken by the locking entry points (getInterfaces, setRouter, dial, ...)
}
29
30func (v *vNet) _getInterfaces() ([]*Interface, error) {
31	if len(v.interfaces) == 0 {
32		return nil, fmt.Errorf("no interface is available")
33	}
34
35	return v.interfaces, nil
36}
37
38func (v *vNet) getInterfaces() ([]*Interface, error) {
39	v.mutex.RLock()
40	defer v.mutex.RUnlock()
41
42	return v._getInterfaces()
43}
44
45// caller must hold the mutex (read)
46func (v *vNet) _getInterface(ifName string) (*Interface, error) {
47	ifs, err := v._getInterfaces()
48	if err != nil {
49		return nil, err
50	}
51	for _, ifc := range ifs {
52		if ifc.Name == ifName {
53			return ifc, nil
54		}
55	}
56
57	return nil, fmt.Errorf("interface %s not found", ifName)
58}
59
60func (v *vNet) getInterface(ifName string) (*Interface, error) {
61	v.mutex.RLock()
62	defer v.mutex.RUnlock()
63
64	return v._getInterface(ifName)
65}
66
67// caller must hold the mutex
68func (v *vNet) getAllIPAddrs(ipv6 bool) []net.IP {
69	ips := []net.IP{}
70
71	for _, ifc := range v.interfaces {
72		addrs, err := ifc.Addrs()
73		if err != nil {
74			continue
75		}
76
77		for _, addr := range addrs {
78			var ip net.IP
79			if ipNet, ok := addr.(*net.IPNet); ok {
80				ip = ipNet.IP
81			} else if ipAddr, ok := addr.(*net.IPAddr); ok {
82				ip = ipAddr.IP
83			} else {
84				continue
85			}
86
87			if !ipv6 {
88				if ip.To4() != nil {
89					ips = append(ips, ip)
90				}
91			}
92		}
93	}
94
95	return ips
96}
97
98func (v *vNet) setRouter(r *Router) error {
99	v.mutex.Lock()
100	defer v.mutex.Unlock()
101
102	v.router = r
103	return nil
104}
105
106func (v *vNet) onInboundChunk(c Chunk) {
107	v.mutex.Lock()
108	defer v.mutex.Unlock()
109
110	if c.Network() == "udp" {
111		if conn, ok := v.udpConns.find(c.DestinationAddr()); ok {
112			conn.onInboundChunk(c)
113		}
114	}
115}
116
117// caller must hold the mutex
118func (v *vNet) _dialUDP(network string, locAddr, remAddr *net.UDPAddr) (UDPPacketConn, error) {
119	// validate network
120	if network != "udp" && network != "udp4" {
121		return nil, fmt.Errorf("unexpected network: %s", network)
122	}
123
124	if locAddr == nil {
125		locAddr = &net.UDPAddr{
126			IP: net.IPv4zero,
127		}
128	} else if locAddr.IP == nil {
129		locAddr.IP = net.IPv4zero
130	}
131
132	// validate address. do we have that address?
133	if !v.hasIPAddr(locAddr.IP) {
134		return nil, &net.OpError{
135			Op:   "listen",
136			Net:  network,
137			Addr: locAddr,
138			Err:  fmt.Errorf("bind: can't assign requested address"),
139		}
140	}
141
142	if locAddr.Port == 0 {
143		// choose randomly from the range between 5000 and 5999
144		port, err := v.assignPort(locAddr.IP, 5000, 5999)
145		if err != nil {
146			return nil, &net.OpError{
147				Op:   "listen",
148				Net:  network,
149				Addr: locAddr,
150				Err:  err,
151			}
152		}
153		locAddr.Port = port
154	} else if _, ok := v.udpConns.find(locAddr); ok {
155		return nil, &net.OpError{
156			Op:   "listen",
157			Net:  network,
158			Addr: locAddr,
159			Err:  fmt.Errorf("bind: address already in use"),
160		}
161	}
162
163	conn, err := newUDPConn(locAddr, remAddr, v)
164	if err != nil {
165		return nil, err
166	}
167
168	err = v.udpConns.insert(conn)
169	if err != nil {
170		return nil, err
171	}
172
173	return conn, nil
174}
175
176func (v *vNet) listenPacket(network string, address string) (UDPPacketConn, error) {
177	v.mutex.Lock()
178	defer v.mutex.Unlock()
179
180	locAddr, err := v.resolveUDPAddr(network, address)
181	if err != nil {
182		return nil, err
183	}
184
185	return v._dialUDP(network, locAddr, nil)
186}
187
188func (v *vNet) listenUDP(network string, locAddr *net.UDPAddr) (UDPPacketConn, error) {
189	v.mutex.Lock()
190	defer v.mutex.Unlock()
191
192	return v._dialUDP(network, locAddr, nil)
193}
194
195func (v *vNet) dialUDP(network string, locAddr, remAddr *net.UDPAddr) (UDPPacketConn, error) {
196	v.mutex.Lock()
197	defer v.mutex.Unlock()
198
199	return v._dialUDP(network, locAddr, remAddr)
200}
201
202func (v *vNet) dial(network string, address string) (UDPPacketConn, error) {
203	v.mutex.Lock()
204	defer v.mutex.Unlock()
205
206	remAddr, err := v.resolveUDPAddr(network, address)
207	if err != nil {
208		return nil, err
209	}
210
211	// Determine source address
212	srcIP := v.determineSourceIP(nil, remAddr.IP)
213
214	locAddr := &net.UDPAddr{IP: srcIP, Port: 0}
215
216	return v._dialUDP(network, locAddr, remAddr)
217}
218
219func (v *vNet) resolveUDPAddr(network, address string) (*net.UDPAddr, error) {
220	if network != "udp" && network != "udp4" {
221		return nil, fmt.Errorf("unknown network %s", network)
222	}
223
224	host, sPort, err := net.SplitHostPort(address)
225	if err != nil {
226		return nil, err
227	}
228
229	// Check if host is a domain name
230	ip := net.ParseIP(host)
231	if ip == nil {
232		host = strings.ToLower(host)
233		if host == "localhost" {
234			ip = net.IPv4(127, 0, 0, 1)
235		} else {
236			// host is a domain name. resolve IP address by the name
237			if v.router == nil {
238				return nil, fmt.Errorf("no router linked")
239			}
240
241			ip, err = v.router.resolver.lookUp(host)
242			if err != nil {
243				return nil, err
244			}
245		}
246	}
247
248	port, err := strconv.Atoi(sPort)
249	if err != nil {
250		return nil, fmt.Errorf("invalid port number")
251	}
252
253	udpAddr := &net.UDPAddr{
254		IP:   ip,
255		Port: port,
256	}
257
258	return udpAddr, nil
259}
260
261func (v *vNet) write(c Chunk) error {
262	if c.Network() == "udp" {
263		if udp, ok := c.(*chunkUDP); ok {
264			if c.getDestinationIP().IsLoopback() {
265				if conn, ok := v.udpConns.find(udp.DestinationAddr()); ok {
266					conn.onInboundChunk(udp)
267				}
268				return nil
269			}
270		} else {
271			return fmt.Errorf("unexpected type-switch failure")
272		}
273	}
274
275	if v.router == nil {
276		return fmt.Errorf("no router linked")
277	}
278
279	v.router.push(c)
280	return nil
281}
282
283func (v *vNet) onClosed(addr net.Addr) {
284	if addr.Network() == "udp" {
285		//nolint:errcheck
286		v.udpConns.delete(addr) // #nosec
287	}
288}
289
290// This method determines the srcIP based on the dstIP when locIP
291// is any IP address ("0.0.0.0" or "::"). If locIP is a non-any addr,
292// this method simply returns locIP.
293// caller must hold the mutex
294func (v *vNet) determineSourceIP(locIP, dstIP net.IP) net.IP {
295	if locIP != nil && !locIP.IsUnspecified() {
296		return locIP
297	}
298
299	var srcIP net.IP
300
301	if dstIP.IsLoopback() {
302		srcIP = net.ParseIP("127.0.0.1")
303	} else {
304		ifc, err2 := v._getInterface("eth0")
305		if err2 != nil {
306			return nil
307		}
308
309		addrs, err2 := ifc.Addrs()
310		if err2 != nil {
311			return nil
312		}
313
314		if len(addrs) == 0 {
315			return nil
316		}
317
318		var findIPv4 bool
319		if locIP != nil {
320			findIPv4 = (locIP.To4() != nil)
321		} else {
322			findIPv4 = (dstIP.To4() != nil)
323		}
324
325		for _, addr := range addrs {
326			ip := addr.(*net.IPNet).IP
327			if findIPv4 {
328				if ip.To4() != nil {
329					srcIP = ip
330					break
331				}
332			} else {
333				if ip.To4() == nil {
334					srcIP = ip
335					break
336				}
337			}
338		}
339	}
340
341	return srcIP
342}
343
344// caller must hold the mutex
345func (v *vNet) hasIPAddr(ip net.IP) bool {
346	for _, ifc := range v.interfaces {
347		if addrs, err := ifc.Addrs(); err == nil {
348			for _, addr := range addrs {
349				var locIP net.IP
350				if ipNet, ok := addr.(*net.IPNet); ok {
351					locIP = ipNet.IP
352				} else if ipAddr, ok := addr.(*net.IPAddr); ok {
353					locIP = ipAddr.IP
354				} else {
355					continue
356				}
357
358				switch ip.String() {
359				case "0.0.0.0":
360					if locIP.To4() != nil {
361						return true
362					}
363				case "::":
364					if locIP.To4() == nil {
365						return true
366					}
367				default:
368					if locIP.Equal(ip) {
369						return true
370					}
371				}
372			}
373		}
374	}
375
376	return false
377}
378
379// caller must hold the mutex
380func (v *vNet) allocateLocalAddr(ip net.IP, port int) error {
381	// gather local IP addresses to bind
382	var ips []net.IP
383	if ip.IsUnspecified() {
384		ips = v.getAllIPAddrs(ip.To4() == nil)
385	} else if v.hasIPAddr(ip) {
386		ips = []net.IP{ip}
387	}
388
389	if len(ips) == 0 {
390		return fmt.Errorf("bind failed for %s", ip.String())
391	}
392
393	// check if all these transport addresses are not in use
394	for _, ip2 := range ips {
395		addr := &net.UDPAddr{
396			IP:   ip2,
397			Port: port,
398		}
399		if _, ok := v.udpConns.find(addr); ok {
400			return &net.OpError{
401				Op:   "bind",
402				Net:  "udp",
403				Addr: addr,
404				Err:  fmt.Errorf("bind: address already in use"),
405			}
406		}
407	}
408
409	return nil
410}
411
412// caller must hold the mutex
413func (v *vNet) assignPort(ip net.IP, start, end int) (int, error) {
414	// choose randomly from the range between start and end (inclusive)
415	if end < start {
416		return -1, fmt.Errorf("end port is less than the start")
417	}
418
419	space := end + 1 - start
420	offset := rand.Intn(space)
421	for i := 0; i < space; i++ {
422		port := ((offset + i) % space) + start
423
424		err := v.allocateLocalAddr(ip, port)
425		if err == nil {
426			return port, nil
427		}
428	}
429
430	return -1, fmt.Errorf("port space exhausted")
431}
432
// NetConfig is a bag of configuration parameters passed to NewNet().
type NetConfig struct {
	// StaticIPs is an array of static IP addresses to be assigned for this Net.
	// If no static IP address is given, the router will automatically assign
	// an IP address.
	// NewNet parses each entry with net.ParseIP and skips entries that do not
	// parse, without reporting an error.
	StaticIPs []string

	// StaticIP is a single static IP address. When non-empty and parseable,
	// NewNet appends it after the entries of StaticIPs.
	//
	// Deprecated: use StaticIPs.
	StaticIP string
}
443
// Net represents a local network stack equivalent to a set of layers from NIC
// up to the transport (UDP / TCP) layer.
type Net struct {
	v   *vNet        // virtual network stack; nil when the virtual network is disabled
	ifs []*Interface // interfaces gathered from the host by NewNet(nil); unset when v is non-nil
}
450
451// NewNet creates an instance of Net.
452// If config is nil, the virtual network is disabled. (uses corresponding
453// net.Xxxx() operations.
454// By design, it always have lo0 and eth0 interfaces.
455// The lo0 has the address 127.0.0.1 assigned by default.
456// IP address for eth0 will be assigned when this Net is added to a router.
457func NewNet(config *NetConfig) *Net {
458	if config == nil {
459		ifs := []*Interface{}
460		if orgIfs, err := net.Interfaces(); err == nil {
461			for _, orgIfc := range orgIfs {
462				ifc := NewInterface(orgIfc)
463				if addrs, err := orgIfc.Addrs(); err == nil {
464					for _, addr := range addrs {
465						ifc.AddAddr(addr)
466					}
467				}
468
469				ifs = append(ifs, ifc)
470			}
471		}
472
473		return &Net{ifs: ifs}
474	}
475
476	lo0 := NewInterface(net.Interface{
477		Index:        1,
478		MTU:          16384,
479		Name:         "lo0",
480		HardwareAddr: nil,
481		Flags:        net.FlagUp | net.FlagLoopback | net.FlagMulticast,
482	})
483	lo0.AddAddr(&net.IPNet{
484		IP:   net.ParseIP("127.0.0.1"),
485		Mask: net.CIDRMask(8, 32),
486	})
487
488	eth0 := NewInterface(net.Interface{
489		Index:        2,
490		MTU:          1500,
491		Name:         "eth0",
492		HardwareAddr: newMACAddress(),
493		Flags:        net.FlagUp | net.FlagMulticast,
494	})
495
496	var staticIPs []net.IP
497	for _, ipStr := range config.StaticIPs {
498		if ip := net.ParseIP(ipStr); ip != nil {
499			staticIPs = append(staticIPs, ip)
500		}
501	}
502	if len(config.StaticIP) > 0 {
503		if ip := net.ParseIP(config.StaticIP); ip != nil {
504			staticIPs = append(staticIPs, ip)
505		}
506	}
507
508	v := &vNet{
509		interfaces: []*Interface{lo0, eth0},
510		staticIPs:  staticIPs,
511		udpConns:   newUDPConnMap(),
512	}
513
514	return &Net{
515		v: v,
516	}
517}
518
519// Interfaces returns a list of the system's network interfaces.
520func (n *Net) Interfaces() ([]*Interface, error) {
521	if n.v == nil {
522		return n.ifs, nil
523	}
524
525	return n.v.getInterfaces()
526}
527
528// InterfaceByName returns the interface specified by name.
529func (n *Net) InterfaceByName(name string) (*Interface, error) {
530	if n.v == nil {
531		for _, ifc := range n.ifs {
532			if ifc.Name == name {
533				return ifc, nil
534			}
535		}
536
537		return nil, fmt.Errorf("interface %s not found", name)
538	}
539
540	return n.v.getInterface(name)
541}
542
543// ListenPacket announces on the local network address.
544func (n *Net) ListenPacket(network string, address string) (net.PacketConn, error) {
545	if n.v == nil {
546		return net.ListenPacket(network, address)
547	}
548
549	return n.v.listenPacket(network, address)
550}
551
552// ListenUDP acts like ListenPacket for UDP networks.
553func (n *Net) ListenUDP(network string, locAddr *net.UDPAddr) (UDPPacketConn, error) {
554	if n.v == nil {
555		return net.ListenUDP(network, locAddr)
556	}
557
558	return n.v.listenUDP(network, locAddr)
559}
560
561// Dial connects to the address on the named network.
562func (n *Net) Dial(network, address string) (net.Conn, error) {
563	if n.v == nil {
564		return net.Dial(network, address)
565	}
566
567	return n.v.dial(network, address)
568}
569
570// CreateDialer creates an instance of vnet.Dialer
571func (n *Net) CreateDialer(dialer *net.Dialer) Dialer {
572	if n.v == nil {
573		return &vDialer{
574			dialer: dialer,
575		}
576	}
577
578	return &vDialer{
579		dialer: dialer,
580		v:      n.v,
581	}
582}
583
584// DialUDP acts like Dial for UDP networks.
585func (n *Net) DialUDP(network string, laddr, raddr *net.UDPAddr) (UDPPacketConn, error) {
586	if n.v == nil {
587		return net.DialUDP(network, laddr, raddr)
588	}
589
590	return n.v.dialUDP(network, laddr, raddr)
591}
592
593// ResolveUDPAddr returns an address of UDP end point.
594func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
595	if n.v == nil {
596		return net.ResolveUDPAddr(network, address)
597	}
598
599	return n.v.resolveUDPAddr(network, address)
600}
601
602func (n *Net) getInterface(ifName string) (*Interface, error) {
603	if n.v == nil {
604		return nil, fmt.Errorf("vnet is not enabled")
605	}
606
607	return n.v.getInterface(ifName)
608}
609
610func (n *Net) setRouter(r *Router) error {
611	if n.v == nil {
612		return fmt.Errorf("vnet is not enabled")
613	}
614
615	return n.v.setRouter(r)
616}
617
618func (n *Net) onInboundChunk(c Chunk) {
619	if n.v == nil {
620		return
621	}
622
623	n.v.onInboundChunk(c)
624}
625
626func (n *Net) getStaticIPs() []net.IP {
627	if n.v == nil {
628		return nil
629	}
630
631	return n.v.staticIPs
632}
633
634// IsVirtual tests if the virtual network is enabled.
635func (n *Net) IsVirtual() bool {
636	return n.v != nil
637}
638
// Dialer exposes the Dial method of net.Dialer, routed through the virtual
// network when one is enabled. Only Dial is part of this interface.
// Use Net.CreateDialer() to create an instance of this Dialer.
type Dialer interface {
	// Dial connects to the address on the named network.
	Dial(network, address string) (net.Conn, error)
}
645
// vDialer is the Dialer implementation returned by Net.CreateDialer.
type vDialer struct {
	dialer *net.Dialer // used by Dial when v is nil
	v      *vNet       // virtual network to dial through; nil when it is disabled
}
650
651func (d *vDialer) Dial(network, address string) (net.Conn, error) {
652	if d.v == nil {
653		return d.dialer.Dial(network, address)
654	}
655
656	return d.v.dial(network, address)
657}
658