1package vnet
2
3import (
4	"fmt"
5	"io"
6	"math"
7	"net"
8	"sync"
9	"time"
10)
11
// maxReadQueueSize is the capacity of a UDPConn's inbound chunk queue;
// chunks arriving while the queue is full are dropped.
const maxReadQueueSize = 1024
15
// noDeadline is the zero time.Time; passing it to SetReadDeadline disables the read deadline.
var noDeadline = time.Time{}
17
// UDPPacketConn is a packet-oriented UDP connection. On top of
// net.PacketConn it exposes the connected-socket methods Read, Write and
// RemoteAddr, so one value can be used either way.
type UDPPacketConn interface {
	net.PacketConn

	// Connected-socket operations.
	RemoteAddr() net.Addr
	Read(b []byte) (int, error)
	Write(b []byte) (int, error)
}
25
26// vNet implements this
27type connObserver interface {
28	write(c Chunk) error
29	onClosed(addr net.Addr)
30	determineSourceIP(locIP, dstIP net.IP) net.IP
31}
32
33// UDPConn is the implementation of the Conn and PacketConn interfaces for UDP network connections.
34// comatible with net.PacketConn and net.Conn
35type UDPConn struct {
36	locAddr   *net.UDPAddr // read-only
37	remAddr   *net.UDPAddr // read-only
38	obs       connObserver // read-only
39	readCh    chan Chunk   // requires mutex for writers
40	muReadCh  sync.Mutex   // to mutex readCh writers
41	readTimer *time.Timer  // thread-safe
42}
43
44func newUDPConn(locAddr, remAddr *net.UDPAddr, obs connObserver) (*UDPConn, error) {
45	if obs == nil {
46		return nil, fmt.Errorf("obs cannot be nil")
47	}
48
49	return &UDPConn{
50		locAddr:   locAddr,
51		remAddr:   remAddr,
52		obs:       obs,
53		readCh:    make(chan Chunk, maxReadQueueSize),
54		readTimer: time.NewTimer(time.Duration(math.MaxInt64)),
55	}, nil
56}
57
58// ReadFrom reads a packet from the connection,
59// copying the payload into p. It returns the number of
60// bytes copied into p and the return address that
61// was on the packet.
62// It returns the number of bytes read (0 <= n <= len(p))
63// and any error encountered. Callers should always process
64// the n > 0 bytes returned before considering the error err.
65// ReadFrom can be made to time out and return
66// an Error with Timeout() == true after a fixed time limit;
67// see SetDeadline and SetReadDeadline.
68func (c *UDPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
69loop:
70	for {
71		select {
72		case chunk, ok := <-c.readCh:
73			if !ok {
74				break loop
75			}
76			var err error
77			n := copy(p, chunk.UserData())
78			addr := chunk.SourceAddr()
79			if n < len(chunk.UserData()) {
80				err = io.ErrShortBuffer
81			}
82
83			if c.remAddr != nil {
84				if addr.String() != c.remAddr.String() {
85					break // discard (shouldn't happen)
86				}
87			}
88			return n, addr, err
89
90		case <-c.readTimer.C:
91			return 0, nil, &net.OpError{
92				Op:   "read",
93				Net:  c.locAddr.Network(),
94				Addr: c.locAddr,
95				Err:  newTimeoutError("i/o timeout"),
96			}
97		}
98	}
99
100	return 0, nil, &net.OpError{
101		Op:   "read",
102		Net:  c.locAddr.Network(),
103		Addr: c.locAddr,
104		Err:  fmt.Errorf("use of closed network connection"),
105	}
106}
107
108// WriteTo writes a packet with payload p to addr.
109// WriteTo can be made to time out and return
110// an Error with Timeout() == true after a fixed time limit;
111// see SetDeadline and SetWriteDeadline.
112// On packet-oriented connections, write timeouts are rare.
113func (c *UDPConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
114	dstAddr, ok := addr.(*net.UDPAddr)
115	if !ok {
116		return 0, fmt.Errorf("addr is not a net.UDPAddr")
117	}
118
119	srcIP := c.obs.determineSourceIP(c.locAddr.IP, dstAddr.IP)
120	if srcIP == nil {
121		return 0, fmt.Errorf("something went wrong with locAddr")
122	}
123	srcAddr := &net.UDPAddr{
124		IP:   srcIP,
125		Port: c.locAddr.Port,
126	}
127
128	chunk := newChunkUDP(srcAddr, dstAddr)
129	chunk.userData = make([]byte, len(p))
130	copy(chunk.userData, p)
131	if err := c.obs.write(chunk); err != nil {
132		return 0, err
133	}
134	return len(p), nil
135}
136
137// Close closes the connection.
138// Any blocked ReadFrom or WriteTo operations will be unblocked and return errors.
139// See: https://play.golang.org/p/GrCRAII0VSN
140func (c *UDPConn) Close() error {
141loop:
142	for {
143		select {
144		case _, ok := <-c.readCh:
145			if !ok {
146				return fmt.Errorf("already closed")
147			}
148		default:
149			c.muReadCh.Lock()
150			close(c.readCh)
151			c.muReadCh.Unlock()
152			c.obs.onClosed(c.locAddr)
153			break loop
154		}
155	}
156	return nil
157}
158
159// LocalAddr returns the local network address.
160func (c *UDPConn) LocalAddr() net.Addr {
161	return c.locAddr
162}
163
164// SetDeadline sets the read and write deadlines associated
165// with the connection. It is equivalent to calling both
166// SetReadDeadline and SetWriteDeadline.
167//
168// A deadline is an absolute time after which I/O operations
169// fail with a timeout (see type Error) instead of
170// blocking. The deadline applies to all future and pending
171// I/O, not just the immediately following call to ReadFrom or
172// WriteTo. After a deadline has been exceeded, the connection
173// can be refreshed by setting a deadline in the future.
174//
175// An idle timeout can be implemented by repeatedly extending
176// the deadline after successful ReadFrom or WriteTo calls.
177//
178// A zero value for t means I/O operations will not time out.
179func (c *UDPConn) SetDeadline(t time.Time) error {
180	return c.SetReadDeadline(t)
181}
182
183// SetReadDeadline sets the deadline for future ReadFrom calls
184// and any currently-blocked ReadFrom call.
185// A zero value for t means ReadFrom will not time out.
186func (c *UDPConn) SetReadDeadline(t time.Time) error {
187	var d time.Duration
188	if t == noDeadline {
189		d = time.Duration(math.MaxInt64)
190	} else {
191		d = time.Until(t)
192	}
193	c.readTimer.Reset(d)
194	return nil
195}
196
197// SetWriteDeadline sets the deadline for future WriteTo calls
198// and any currently-blocked WriteTo call.
199// Even if write times out, it may return n > 0, indicating that
200// some of the data was successfully written.
201// A zero value for t means WriteTo will not time out.
202func (c *UDPConn) SetWriteDeadline(t time.Time) error {
203	// Write never blocks.
204	return nil
205}
206
207// Read reads data from the connection.
208// Read can be made to time out and return an Error with Timeout() == true
209// after a fixed time limit; see SetDeadline and SetReadDeadline.
210func (c *UDPConn) Read(b []byte) (int, error) {
211	n, _, err := c.ReadFrom(b)
212	return n, err
213}
214
215// RemoteAddr returns the remote network address.
216func (c *UDPConn) RemoteAddr() net.Addr {
217	return c.remAddr
218}
219
220// Write writes data to the connection.
221// Write can be made to time out and return an Error with Timeout() == true
222// after a fixed time limit; see SetDeadline and SetWriteDeadline.
223func (c *UDPConn) Write(b []byte) (int, error) {
224	if c.remAddr == nil {
225		return 0, fmt.Errorf("no remAddr defined")
226	}
227
228	return c.WriteTo(b, c.remAddr)
229}
230
231func (c *UDPConn) onInboundChunk(chunk Chunk) {
232	c.muReadCh.Lock()
233	defer c.muReadCh.Unlock()
234
235	select {
236	case c.readCh <- chunk:
237	default:
238	}
239}
240