1// Copyright (c) 2020 The Jaeger Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package utils
16
17import (
18	"fmt"
19	"net"
20	"sync"
21	"sync/atomic"
22	"time"
23
24	"github.com/uber/jaeger-client-go/log"
25)
26
27// reconnectingUDPConn is an implementation of udpConn that resolves hostPort every resolveTimeout, if the resolved address is
28// different than the current conn then the new address is dialed and the conn is swapped.
29type reconnectingUDPConn struct {
30	hostPort    string
31	resolveFunc resolveFunc
32	dialFunc    dialFunc
33	logger      log.Logger
34	bufferBytes int64
35
36	connMtx   sync.RWMutex
37	conn      *net.UDPConn
38	destAddr  *net.UDPAddr
39	closeChan chan struct{}
40}
41
type (
	// resolveFunc resolves hostPort on the given network into a UDP address
	// (net.ResolveUDPAddr has this shape).
	resolveFunc func(network, hostPort string) (*net.UDPAddr, error)

	// dialFunc opens a UDP connection to raddr, optionally bound to laddr
	// (net.DialUDP has this shape).
	dialFunc func(network string, laddr *net.UDPAddr, raddr *net.UDPAddr) (*net.UDPConn, error)
)
44
45// newReconnectingUDPConn returns a new udpConn that resolves hostPort every resolveTimeout, if the resolved address is
46// different than the current conn then the new address is dialed and the conn is swapped.
47func newReconnectingUDPConn(hostPort string, resolveTimeout time.Duration, resolveFunc resolveFunc, dialFunc dialFunc, logger log.Logger) (*reconnectingUDPConn, error) {
48	conn := &reconnectingUDPConn{
49		hostPort:    hostPort,
50		resolveFunc: resolveFunc,
51		dialFunc:    dialFunc,
52		logger:      logger,
53		closeChan:   make(chan struct{}),
54	}
55
56	if err := conn.attemptResolveAndDial(); err != nil {
57		logger.Error(fmt.Sprintf("failed resolving destination address on connection startup, with err: %q. retrying in %s", err.Error(), resolveTimeout))
58	}
59
60	go conn.reconnectLoop(resolveTimeout)
61
62	return conn, nil
63}
64
65func (c *reconnectingUDPConn) reconnectLoop(resolveTimeout time.Duration) {
66	ticker := time.NewTicker(resolveTimeout)
67	defer ticker.Stop()
68
69	for {
70		select {
71		case <-c.closeChan:
72			return
73		case <-ticker.C:
74			if err := c.attemptResolveAndDial(); err != nil {
75				c.logger.Error(err.Error())
76			}
77		}
78	}
79}
80
81func (c *reconnectingUDPConn) attemptResolveAndDial() error {
82	newAddr, err := c.resolveFunc("udp", c.hostPort)
83	if err != nil {
84		return fmt.Errorf("failed to resolve new addr for host %q, with err: %w", c.hostPort, err)
85	}
86
87	c.connMtx.RLock()
88	curAddr := c.destAddr
89	c.connMtx.RUnlock()
90
91	// dont attempt dial if an addr was successfully dialed previously and, resolved addr is the same as current conn
92	if curAddr != nil && newAddr.String() == curAddr.String() {
93		return nil
94	}
95
96	if err := c.attemptDialNewAddr(newAddr); err != nil {
97		return fmt.Errorf("failed to dial newly resolved addr '%s', with err: %w", newAddr, err)
98	}
99
100	return nil
101}
102
103func (c *reconnectingUDPConn) attemptDialNewAddr(newAddr *net.UDPAddr) error {
104	connUDP, err := c.dialFunc(newAddr.Network(), nil, newAddr)
105	if err != nil {
106		return err
107	}
108
109	if bufferBytes := int(atomic.LoadInt64(&c.bufferBytes)); bufferBytes != 0 {
110		if err = connUDP.SetWriteBuffer(bufferBytes); err != nil {
111			return err
112		}
113	}
114
115	c.connMtx.Lock()
116	c.destAddr = newAddr
117	// store prev to close later
118	prevConn := c.conn
119	c.conn = connUDP
120	c.connMtx.Unlock()
121
122	if prevConn != nil {
123		return prevConn.Close()
124	}
125
126	return nil
127}
128
129// Write calls net.udpConn.Write, if it fails an attempt is made to connect to a new addr, if that succeeds the write is retried before returning
130func (c *reconnectingUDPConn) Write(b []byte) (int, error) {
131	var bytesWritten int
132	var err error
133
134	c.connMtx.RLock()
135	if c.conn == nil {
136		// if connection is not initialized indicate this with err in order to hook into retry logic
137		err = fmt.Errorf("UDP connection not yet initialized, an address has not been resolved")
138	} else {
139		bytesWritten, err = c.conn.Write(b)
140	}
141	c.connMtx.RUnlock()
142
143	if err == nil {
144		return bytesWritten, nil
145	}
146
147	// attempt to resolve and dial new address in case that's the problem, if resolve and dial succeeds, try write again
148	if reconnErr := c.attemptResolveAndDial(); reconnErr == nil {
149		c.connMtx.RLock()
150		defer c.connMtx.RUnlock()
151		return c.conn.Write(b)
152	}
153
154	// return original error if reconn fails
155	return bytesWritten, err
156}
157
158// Close stops the reconnectLoop, then closes the connection via net.udpConn 's implementation
159func (c *reconnectingUDPConn) Close() error {
160	close(c.closeChan)
161
162	// acquire rw lock before closing conn to ensure calls to Write drain
163	c.connMtx.Lock()
164	defer c.connMtx.Unlock()
165
166	if c.conn != nil {
167		return c.conn.Close()
168	}
169
170	return nil
171}
172
173// SetWriteBuffer defers to the net.udpConn SetWriteBuffer implementation wrapped with a RLock. if no conn is currently held
174// and SetWriteBuffer is called store bufferBytes to be set for new conns
175func (c *reconnectingUDPConn) SetWriteBuffer(bytes int) error {
176	var err error
177
178	c.connMtx.RLock()
179	if c.conn != nil {
180		err = c.conn.SetWriteBuffer(bytes)
181	}
182	c.connMtx.RUnlock()
183
184	if err == nil {
185		atomic.StoreInt64(&c.bufferBytes, int64(bytes))
186	}
187
188	return err
189}
190