1// Package client implements the API for a TURN client
2package client
3
4import (
5	"errors"
6	"fmt"
7	"io"
8	"math"
9	"net"
10	"sync"
11	"time"
12
13	"github.com/pion/logging"
14	"github.com/pion/stun"
15	"github.com/pion/turn/v2/internal/proto"
16)
17
const (
	// maxReadQueueSize is the capacity of the buffered channel that holds
	// inbound packets until ReadFrom picks them up.
	maxReadQueueSize = 1024

	// permRefreshInterval is the period of the permission refresh timer.
	permRefreshInterval = 2 * time.Minute

	// maxRetryAttempts bounds how often an operation is repeated while it
	// keeps reporting errTryAgain.
	maxRetryAttempts = 3
)
23
// Timer identifiers. They are handed to NewPeriodicTimer and come back as the
// id argument of onRefreshTimers, which dispatches on them.
const (
	// timerIDRefreshAlloc identifies the allocation refresh timer.
	timerIDRefreshAlloc int = iota
	// timerIDRefreshPerms identifies the permission refresh timer.
	timerIDRefreshPerms
)
28
// noDeadline returns the zero time.Time, the value that stands for
// "no deadline" when passed to the deadline setters.
func noDeadline() time.Time {
	var zero time.Time
	return zero
}
32
// inboundData is one received packet queued on UDPConn.readCh: the payload
// (a private copy made by HandleInbound) and the address it came from.
type inboundData struct {
	data []byte
	from net.Addr
}
37
// UDPConnObserver is the interface through which UDPConn reaches the TURN
// server and notifies its owner.
type UDPConnObserver interface {
	// TURNServerAddr returns the address UDPConn sends its requests,
	// Send indications and ChannelData to.
	TURNServerAddr() net.Addr
	// Username returns the username attribute UDPConn adds to its requests.
	Username() stun.Username
	// Realm returns the realm attribute UDPConn adds to its requests.
	Realm() stun.Realm
	// WriteTo sends raw bytes to the given address. UDPConn uses it for
	// Send indications and ChannelData, neither of which is a transaction.
	WriteTo(data []byte, to net.Addr) (int, error)
	// PerformTransaction sends msg to the given address. UDPConn passes
	// dontWait=true when it does not look at the response (see Close).
	PerformTransaction(msg *stun.Message, to net.Addr, dontWait bool) (TransactionResult, error)
	// OnDeallocated is called by UDPConn.Close with the relayed address of
	// the connection being closed.
	OnDeallocated(relayedAddr net.Addr)
}
47
// UDPConnConfig is the set of configuration parameters used by NewUDPConn.
type UDPConnConfig struct {
	// Observer is the callback interface UDPConn talks to the TURN server through.
	Observer    UDPConnObserver
	// RelayedAddr is the address reported by UDPConn.LocalAddr.
	RelayedAddr net.Addr
	// Integrity is the message-integrity setter appended to every request.
	Integrity   stun.MessageIntegrity
	// Nonce is the initial nonce; it is replaced when the server reports a stale nonce.
	Nonce       stun.Nonce
	// Lifetime is the initial allocation lifetime; the allocation refresh
	// timer is started with half of this value.
	Lifetime    time.Duration
	// Log is used unconditionally by NewUDPConn and therefore must not be nil.
	Log         logging.LeveledLogger
}
57
// UDPConn is the implementation of the Conn and PacketConn interfaces for UDP network connections.
// It is meant to be compatible with net.PacketConn and net.Conn.
//
// _nonce and _lifetime must only be accessed through nonce/setNonce and
// lifetime/setLifetime, which take mutex.
type UDPConn struct {
	obs               UDPConnObserver       // read-only
	relayedAddr       net.Addr              // read-only
	permMap           *permissionMap        // thread-safe
	bindingMgr        *bindingManager       // thread-safe
	integrity         stun.MessageIntegrity // read-only
	_nonce            stun.Nonce            // guarded by mutex
	_lifetime         time.Duration         // guarded by mutex
	readCh            chan *inboundData     // thread-safe; filled by HandleInbound, drained by ReadFrom
	closeCh           chan struct{}         // thread-safe; closed by Close
	readTimer         *time.Timer           // thread-safe; read deadline timer
	refreshAllocTimer *PeriodicTimer        // thread-safe
	refreshPermsTimer *PeriodicTimer        // thread-safe
	mutex             sync.RWMutex          // thread-safe
	log               logging.LeveledLogger // read-only
}
76
77// NewUDPConn creates a new instance of UDPConn
78func NewUDPConn(config *UDPConnConfig) *UDPConn {
79	c := &UDPConn{
80		obs:         config.Observer,
81		relayedAddr: config.RelayedAddr,
82		permMap:     newPermissionMap(),
83		bindingMgr:  newBindingManager(),
84		integrity:   config.Integrity,
85		_nonce:      config.Nonce,
86		_lifetime:   config.Lifetime,
87		readCh:      make(chan *inboundData, maxReadQueueSize),
88		closeCh:     make(chan struct{}),
89		readTimer:   time.NewTimer(time.Duration(math.MaxInt64)),
90		log:         config.Log,
91	}
92
93	c.log.Debugf("initial lifetime: %d seconds", int(c.lifetime().Seconds()))
94
95	c.refreshAllocTimer = NewPeriodicTimer(
96		timerIDRefreshAlloc,
97		c.onRefreshTimers,
98		c.lifetime()/2,
99	)
100
101	c.refreshPermsTimer = NewPeriodicTimer(
102		timerIDRefreshPerms,
103		c.onRefreshTimers,
104		permRefreshInterval,
105	)
106
107	if c.refreshAllocTimer.Start() {
108		c.log.Debugf("refreshAllocTimer started")
109	}
110	if c.refreshPermsTimer.Start() {
111		c.log.Debugf("refreshPermsTimer started")
112	}
113
114	return c
115}
116
117// ReadFrom reads a packet from the connection,
118// copying the payload into p. It returns the number of
119// bytes copied into p and the return address that
120// was on the packet.
121// It returns the number of bytes read (0 <= n <= len(p))
122// and any error encountered. Callers should always process
123// the n > 0 bytes returned before considering the error err.
124// ReadFrom can be made to time out and return
125// an Error with Timeout() == true after a fixed time limit;
126// see SetDeadline and SetReadDeadline.
127func (c *UDPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
128	for {
129		select {
130		case ibData := <-c.readCh:
131			n := copy(p, ibData.data)
132			if n < len(ibData.data) {
133				return 0, nil, io.ErrShortBuffer
134			}
135			return n, ibData.from, nil
136
137		case <-c.readTimer.C:
138			return 0, nil, &net.OpError{
139				Op:   "read",
140				Net:  c.LocalAddr().Network(),
141				Addr: c.LocalAddr(),
142				Err:  newTimeoutError("i/o timeout"),
143			}
144
145		case <-c.closeCh:
146			return 0, nil, &net.OpError{
147				Op:   "read",
148				Net:  c.LocalAddr().Network(),
149				Addr: c.LocalAddr(),
150				Err:  errClosed,
151			}
152		}
153	}
154}
155
156// WriteTo writes a packet with payload p to addr.
157// WriteTo can be made to time out and return
158// an Error with Timeout() == true after a fixed time limit;
159// see SetDeadline and SetWriteDeadline.
160// On packet-oriented connections, write timeouts are rare.
161func (c *UDPConn) WriteTo(p []byte, addr net.Addr) (int, error) { //nolint: gocognit
162	var err error
163	_, ok := addr.(*net.UDPAddr)
164	if !ok {
165		return 0, errUDPAddrCast
166	}
167
168	// check if we have a permission for the destination IP addr
169	perm, ok := c.permMap.find(addr)
170	if !ok {
171		perm = &permission{}
172		c.permMap.insert(addr, perm)
173	}
174
175	// This func-block would block, per destination IP (, or perm), until
176	// the perm state becomes "requested". Purpose of this is to guarantee
177	// the order of packets (within the same perm).
178	// Note that CreatePermission transaction may not be complete before
179	// all the data transmission. This is done assuming that the request
180	// will be mostly likely successful and we can tolerate some loss of
181	// UDP packet (or reorder), inorder to minimize the latency in most cases.
182	createPermission := func() error {
183		perm.mutex.Lock()
184		defer perm.mutex.Unlock()
185
186		if perm.state() == permStateIdle {
187			// punch a hole! (this would block a bit..)
188			if err = c.createPermissions(addr); err != nil {
189				c.permMap.delete(addr)
190				return err
191			}
192			perm.setState(permStatePermitted)
193		}
194		return nil
195	}
196
197	for i := 0; i < maxRetryAttempts; i++ {
198		if err = createPermission(); !errors.Is(err, errTryAgain) {
199			break
200		}
201	}
202	if err != nil {
203		return 0, err
204	}
205
206	// bind channel
207	b, ok := c.bindingMgr.findByAddr(addr)
208	if !ok {
209		b = c.bindingMgr.create(addr)
210	}
211
212	bindSt := b.state()
213
214	if bindSt == bindingStateIdle || bindSt == bindingStateRequest || bindSt == bindingStateFailed {
215		func() {
216			// block only callers with the same binding until
217			// the binding transaction has been complete
218			b.muBind.Lock()
219			defer b.muBind.Unlock()
220
221			// binding state may have been changed while waiting. check again.
222			if b.state() == bindingStateIdle {
223				b.setState(bindingStateRequest)
224				go func() {
225					err2 := c.bind(b)
226					if err2 != nil {
227						c.log.Warnf("bind() failed: %s", err2.Error())
228						b.setState(bindingStateFailed)
229						// keep going...
230					} else {
231						b.setState(bindingStateReady)
232					}
233				}()
234			}
235		}()
236
237		// send data using SendIndication
238		peerAddr := addr2PeerAddress(addr)
239		var msg *stun.Message
240		msg, err = stun.Build(
241			stun.TransactionID,
242			stun.NewType(stun.MethodSend, stun.ClassIndication),
243			proto.Data(p),
244			peerAddr,
245			stun.Fingerprint,
246		)
247		if err != nil {
248			return 0, err
249		}
250
251		// indication has no transaction (fire-and-forget)
252
253		return c.obs.WriteTo(msg.Raw, c.obs.TURNServerAddr())
254	}
255
256	// binding is either ready
257
258	// check if the binding needs a refresh
259	func() {
260		b.muBind.Lock()
261		defer b.muBind.Unlock()
262
263		if b.state() == bindingStateReady && time.Since(b.refreshedAt()) > 5*time.Minute {
264			b.setState(bindingStateRefresh)
265			go func() {
266				err = c.bind(b)
267				if err != nil {
268					c.log.Warnf("bind() for refresh failed: %s", err.Error())
269					b.setState(bindingStateFailed)
270					// keep going...
271				} else {
272					b.setRefreshedAt(time.Now())
273					b.setState(bindingStateReady)
274				}
275			}()
276		}
277	}()
278
279	// send via ChannelData
280	return c.sendChannelData(p, b.number)
281}
282
283// Close closes the connection.
284// Any blocked ReadFrom or WriteTo operations will be unblocked and return errors.
285func (c *UDPConn) Close() error {
286	c.refreshAllocTimer.Stop()
287	c.refreshPermsTimer.Stop()
288
289	select {
290	case <-c.closeCh:
291		return errAlreadyClosed
292	default:
293		close(c.closeCh)
294	}
295
296	c.obs.OnDeallocated(c.relayedAddr)
297	return c.refreshAllocation(0, true /* dontWait=true */)
298}
299
300// LocalAddr returns the local network address.
301func (c *UDPConn) LocalAddr() net.Addr {
302	return c.relayedAddr
303}
304
305// SetDeadline sets the read and write deadlines associated
306// with the connection. It is equivalent to calling both
307// SetReadDeadline and SetWriteDeadline.
308//
309// A deadline is an absolute time after which I/O operations
310// fail with a timeout (see type Error) instead of
311// blocking. The deadline applies to all future and pending
312// I/O, not just the immediately following call to ReadFrom or
313// WriteTo. After a deadline has been exceeded, the connection
314// can be refreshed by setting a deadline in the future.
315//
316// An idle timeout can be implemented by repeatedly extending
317// the deadline after successful ReadFrom or WriteTo calls.
318//
319// A zero value for t means I/O operations will not time out.
320func (c *UDPConn) SetDeadline(t time.Time) error {
321	return c.SetReadDeadline(t)
322}
323
324// SetReadDeadline sets the deadline for future ReadFrom calls
325// and any currently-blocked ReadFrom call.
326// A zero value for t means ReadFrom will not time out.
327func (c *UDPConn) SetReadDeadline(t time.Time) error {
328	var d time.Duration
329	if t == noDeadline() {
330		d = time.Duration(math.MaxInt64)
331	} else {
332		d = time.Until(t)
333	}
334	c.readTimer.Reset(d)
335	return nil
336}
337
338// SetWriteDeadline sets the deadline for future WriteTo calls
339// and any currently-blocked WriteTo call.
340// Even if write times out, it may return n > 0, indicating that
341// some of the data was successfully written.
342// A zero value for t means WriteTo will not time out.
343func (c *UDPConn) SetWriteDeadline(t time.Time) error {
344	// Write never blocks.
345	return nil
346}
347
348func addr2PeerAddress(addr net.Addr) proto.PeerAddress {
349	var peerAddr proto.PeerAddress
350	switch a := addr.(type) {
351	case *net.UDPAddr:
352		peerAddr.IP = a.IP
353		peerAddr.Port = a.Port
354	case *net.TCPAddr:
355		peerAddr.IP = a.IP
356		peerAddr.Port = a.Port
357	}
358
359	return peerAddr
360}
361
362func (c *UDPConn) createPermissions(addrs ...net.Addr) error {
363	setters := []stun.Setter{
364		stun.TransactionID,
365		stun.NewType(stun.MethodCreatePermission, stun.ClassRequest),
366	}
367
368	for _, addr := range addrs {
369		setters = append(setters, addr2PeerAddress(addr))
370	}
371
372	setters = append(setters,
373		c.obs.Username(),
374		c.obs.Realm(),
375		c.nonce(),
376		c.integrity,
377		stun.Fingerprint)
378
379	msg, err := stun.Build(setters...)
380	if err != nil {
381		return err
382	}
383
384	trRes, err := c.obs.PerformTransaction(msg, c.obs.TURNServerAddr(), false)
385	if err != nil {
386		return err
387	}
388
389	res := trRes.Msg
390
391	if res.Type.Class == stun.ClassErrorResponse {
392		var code stun.ErrorCodeAttribute
393		if err = code.GetFrom(res); err == nil {
394			if code.Code == stun.CodeStaleNonce {
395				c.setNonceFromMsg(res)
396				return errTryAgain
397			}
398			return fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113
399		}
400
401		return fmt.Errorf("%s", res.Type) //nolint:goerr113
402	}
403
404	return nil
405}
406
407// HandleInbound passes inbound data in UDPConn
408func (c *UDPConn) HandleInbound(data []byte, from net.Addr) {
409	// copy data
410	copied := make([]byte, len(data))
411	copy(copied, data)
412
413	select {
414	case c.readCh <- &inboundData{data: copied, from: from}:
415	default:
416		c.log.Warnf("receive buffer full")
417	}
418}
419
420// FindAddrByChannelNumber returns a peer address associated with the
421// channel number on this UDPConn
422func (c *UDPConn) FindAddrByChannelNumber(chNum uint16) (net.Addr, bool) {
423	b, ok := c.bindingMgr.findByNumber(chNum)
424	if !ok {
425		return nil, false
426	}
427	return b.addr, true
428}
429
430func (c *UDPConn) setNonceFromMsg(msg *stun.Message) {
431	// Update nonce
432	var nonce stun.Nonce
433	if err := nonce.GetFrom(msg); err == nil {
434		c.setNonce(nonce)
435		c.log.Debug("refresh allocation: 438, got new nonce.")
436	} else {
437		c.log.Warn("refresh allocation: 438 but no nonce.")
438	}
439}
440
441func (c *UDPConn) refreshAllocation(lifetime time.Duration, dontWait bool) error {
442	msg, err := stun.Build(
443		stun.TransactionID,
444		stun.NewType(stun.MethodRefresh, stun.ClassRequest),
445		proto.Lifetime{Duration: lifetime},
446		c.obs.Username(),
447		c.obs.Realm(),
448		c.nonce(),
449		c.integrity,
450		stun.Fingerprint,
451	)
452	if err != nil {
453		return fmt.Errorf("%w: %s", errFailedToBuildRefreshRequest, err.Error())
454	}
455
456	c.log.Debugf("send refresh request (dontWait=%v)", dontWait)
457	trRes, err := c.obs.PerformTransaction(msg, c.obs.TURNServerAddr(), dontWait)
458	if err != nil {
459		return fmt.Errorf("%w: %s", errFailedToRefreshAllocation, err.Error())
460	}
461
462	if dontWait {
463		c.log.Debug("refresh request sent")
464		return nil
465	}
466
467	c.log.Debug("refresh request sent, and waiting response")
468
469	res := trRes.Msg
470	if res.Type.Class == stun.ClassErrorResponse {
471		var code stun.ErrorCodeAttribute
472		if err = code.GetFrom(res); err == nil {
473			if code.Code == stun.CodeStaleNonce {
474				c.setNonceFromMsg(res)
475				return errTryAgain
476			}
477			return err
478		}
479		return fmt.Errorf("%s", res.Type) //nolint:goerr113
480	}
481
482	// Getting lifetime from response
483	var updatedLifetime proto.Lifetime
484	if err := updatedLifetime.GetFrom(res); err != nil {
485		return fmt.Errorf("%w: %s", errFailedToGetLifetime, err.Error())
486	}
487
488	c.setLifetime(updatedLifetime.Duration)
489	c.log.Debugf("updated lifetime: %d seconds", int(c.lifetime().Seconds()))
490	return nil
491}
492
493func (c *UDPConn) refreshPermissions() error {
494	addrs := c.permMap.addrs()
495	if len(addrs) == 0 {
496		c.log.Debug("no permission to refresh")
497		return nil
498	}
499	if err := c.createPermissions(addrs...); err != nil {
500		if errors.Is(err, errTryAgain) {
501			return errTryAgain
502		}
503		c.log.Errorf("fail to refresh permissions: %s", err.Error())
504		return err
505	}
506	c.log.Debug("refresh permissions successful")
507	return nil
508}
509
510func (c *UDPConn) bind(b *binding) error {
511	setters := []stun.Setter{
512		stun.TransactionID,
513		stun.NewType(stun.MethodChannelBind, stun.ClassRequest),
514		addr2PeerAddress(b.addr),
515		proto.ChannelNumber(b.number),
516		c.obs.Username(),
517		c.obs.Realm(),
518		c.nonce(),
519		c.integrity,
520		stun.Fingerprint,
521	}
522
523	msg, err := stun.Build(setters...)
524	if err != nil {
525		return err
526	}
527
528	trRes, err := c.obs.PerformTransaction(msg, c.obs.TURNServerAddr(), false)
529	if err != nil {
530		c.bindingMgr.deleteByAddr(b.addr)
531		return err
532	}
533
534	res := trRes.Msg
535
536	if res.Type != stun.NewType(stun.MethodChannelBind, stun.ClassSuccessResponse) {
537		return fmt.Errorf("unexpected response type %s", res.Type) //nolint:goerr113
538	}
539
540	c.log.Debugf("channel binding successful: %s %d", b.addr.String(), b.number)
541
542	// Success.
543	return nil
544}
545
546func (c *UDPConn) sendChannelData(data []byte, chNum uint16) (int, error) {
547	chData := &proto.ChannelData{
548		Data:   data,
549		Number: proto.ChannelNumber(chNum),
550	}
551	chData.Encode()
552	return c.obs.WriteTo(chData.Raw, c.obs.TURNServerAddr())
553}
554
555func (c *UDPConn) onRefreshTimers(id int) {
556	c.log.Debugf("refresh timer %d expired", id)
557	switch id {
558	case timerIDRefreshAlloc:
559		var err error
560		lifetime := c.lifetime()
561		// limit the max retries on errTryAgain to 3
562		// when stale nonce returns, sencond retry should succeed
563		for i := 0; i < maxRetryAttempts; i++ {
564			err = c.refreshAllocation(lifetime, false)
565			if !errors.Is(err, errTryAgain) {
566				break
567			}
568		}
569		if err != nil {
570			c.log.Warnf("refresh allocation failed")
571		}
572	case timerIDRefreshPerms:
573		var err error
574		for i := 0; i < maxRetryAttempts; i++ {
575			err = c.refreshPermissions()
576			if !errors.Is(err, errTryAgain) {
577				break
578			}
579		}
580		if err != nil {
581			c.log.Warnf("refresh permissions failed")
582		}
583	}
584}
585
586func (c *UDPConn) nonce() stun.Nonce {
587	c.mutex.RLock()
588	defer c.mutex.RUnlock()
589
590	return c._nonce
591}
592
593func (c *UDPConn) setNonce(nonce stun.Nonce) {
594	c.mutex.Lock()
595	defer c.mutex.Unlock()
596
597	c.log.Debugf("set new nonce with %d bytes", len(nonce))
598	c._nonce = nonce
599}
600
601func (c *UDPConn) lifetime() time.Duration {
602	c.mutex.RLock()
603	defer c.mutex.RUnlock()
604
605	return c._lifetime
606}
607
608func (c *UDPConn) setLifetime(lifetime time.Duration) {
609	c.mutex.Lock()
610	defer c.mutex.Unlock()
611
612	c._lifetime = lifetime
613}
614