1package wanfed
2
3import (
4	"fmt"
5	"net"
6	"sync"
7	"time"
8)
9
// connPool pools idle negotiated ALPN_WANGossipPacket flavored connections to
// remote servers, bucketed by a caller-supplied key. Idle connections only
// remain pooled for up to maxTime after they were last acquired; a background
// reaper goroutine checks for expired idle connections about once a second and
// closes them.
type connPool struct {
	// maxTime is how long an idle pooled connection may go without being
	// acquired before the reaper closes it. It is measured from the conn's
	// lastUsed timestamp, which is stamped when the conn is acquired or
	// dialed, not when it is returned to the pool.
	maxTime time.Duration

	// mu protects pool and shutdown.
	// pool holds the idle connections for each AcquireOrDial key; it is set
	// to nil by Close.
	// shutdown is set by Close; afterwards acquisition fails and returned
	// connections are closed instead of pooled.
	mu       sync.Mutex
	pool     map[string][]*conn
	shutdown bool

	// shutdownCh is closed by Close to tell the reaper goroutine to exit.
	// reapWg tracks the reaper goroutine so Close can wait for it to finish.
	shutdownCh chan struct{}
	reapWg     sync.WaitGroup
}
25
26func newConnPool(maxTime time.Duration) (*connPool, error) {
27	if maxTime == 0 {
28		return nil, fmt.Errorf("wanfed: conn pool needs a max time configured")
29	}
30
31	p := &connPool{
32		maxTime:    maxTime,
33		pool:       make(map[string][]*conn),
34		shutdownCh: make(chan struct{}),
35	}
36
37	p.reapWg.Add(1)
38	go p.reap()
39
40	return p, nil
41}
42
43func (p *connPool) Close() error {
44	p.mu.Lock()
45	defer p.mu.Unlock()
46
47	if p.shutdown {
48		return nil
49	}
50
51	for _, conns := range p.pool {
52		for _, conn := range conns {
53			conn.Close()
54		}
55	}
56	p.pool = nil
57	p.shutdown = true
58
59	close(p.shutdownCh)
60	p.reapWg.Wait()
61
62	return nil
63}
64
65// AcquireOrDial either removes an idle connection from the pool or
66// estabilishes a new one using the provided dialer function.
67func (p *connPool) AcquireOrDial(key string, dialer func() (net.Conn, error)) (*conn, error) {
68	c, err := p.maybeAcquire(key)
69	if err != nil {
70		return nil, err
71	}
72	if c != nil {
73		c.markForUse()
74		return c, nil
75	}
76
77	nc, err := dialer()
78	if err != nil {
79		return nil, err
80	}
81
82	c = &conn{
83		key:  key,
84		pool: p,
85		Conn: nc,
86	}
87	c.markForUse()
88
89	return c, nil
90}
91
// errPoolClosed is the sentinel error maybeAcquire reports, and AcquireOrDial
// passes through, once the pool has been shut down.
var (
	errPoolClosed = fmt.Errorf("wanfed: connection pool is closed")
)
93
94// maybeAcquire removes an idle connection from the pool if possible otherwise
95// returns nil indicating there were no idle connections ready. It is the
96// caller's responsibility to open a new connection if that is desired.
97func (p *connPool) maybeAcquire(key string) (*conn, error) {
98	p.mu.Lock()
99	defer p.mu.Unlock()
100	if p.shutdown {
101		return nil, errPoolClosed
102	}
103	conns, ok := p.pool[key]
104	if !ok {
105		return nil, nil
106	}
107
108	switch len(conns) {
109	case 0:
110		delete(p.pool, key) // stray cleanup
111		return nil, nil
112
113	case 1:
114		c := conns[0]
115		delete(p.pool, key)
116		return c, nil
117
118	default:
119		sz := len(conns)
120		remaining, last := conns[0:sz-1], conns[sz-1]
121		p.pool[key] = remaining
122		return last, nil
123	}
124}
125
126// returnConn puts the connection back into the idle pool for reuse.
127func (p *connPool) returnConn(c *conn) error {
128	p.mu.Lock()
129	defer p.mu.Unlock()
130
131	if p.shutdown {
132		return c.Conn.Close() // actual shutdown
133	}
134
135	p.pool[c.key] = append(p.pool[c.key], c)
136
137	return nil
138}
139
140// reap periodically scans the idle pool for connections that have not been
141// used recently and closes them.
142func (p *connPool) reap() {
143	defer p.reapWg.Done()
144	for {
145		select {
146		case <-p.shutdownCh:
147			return
148		case <-time.After(time.Second):
149		}
150
151		p.reapOnce()
152	}
153}
154
155func (p *connPool) reapOnce() {
156	p.mu.Lock()
157	defer p.mu.Unlock()
158
159	if p.shutdown {
160		return
161	}
162
163	now := time.Now()
164
165	var removedKeys []string
166	for key, conns := range p.pool {
167		if len(conns) == 0 {
168			removedKeys = append(removedKeys, key) // cleanup
169			continue
170		}
171
172		var retain []*conn
173		for _, c := range conns {
174			// Skip recently used connections
175			if now.Sub(c.lastUsed) < p.maxTime {
176				retain = append(retain, c)
177			} else {
178				c.Conn.Close()
179			}
180		}
181
182		if len(retain) == len(conns) {
183			continue // no change
184
185		} else if len(retain) == 0 {
186			removedKeys = append(removedKeys, key)
187			continue
188		}
189
190		p.pool[key] = retain
191	}
192
193	for _, key := range removedKeys {
194		delete(p.pool, key)
195	}
196}
197
// conn wraps a net.Conn handed out by a connPool. It carries the bookkeeping
// needed to either return the connection to its pool or close it.
type conn struct {
	// key is the AcquireOrDial key this connection was acquired or dialed
	// under; returnConn files it back into the pool under the same key.
	key string

	// mu guards failed and closed, and is held when markForUse writes
	// lastUsed.
	// NOTE(review): reapOnce reads lastUsed while holding only the pool's
	// lock, not mu. This appears to rely on pooled conns never being in use
	// concurrently — confirm.
	// lastUsed is stamped by markForUse on each acquire/dial; the reaper
	// measures idleness from it.
	// failed is set by MarkFailed and makes ReturnOrClose close the conn
	// instead of pooling it.
	// closed is set once Close, or ReturnOrClose on a failed conn, has
	// closed the underlying net.Conn.
	mu       sync.Mutex
	lastUsed time.Time
	failed   bool
	closed   bool

	// pool is the connPool that ReturnOrClose hands this connection back to.
	pool *connPool

	net.Conn
}
210
211func (c *conn) ReturnOrClose() error {
212	c.mu.Lock()
213	closed := c.closed
214	failed := c.failed
215	if failed {
216		c.closed = true
217	}
218	c.mu.Unlock()
219
220	if closed {
221		return nil
222	}
223
224	if failed {
225		return c.Conn.Close()
226	}
227
228	return c.pool.returnConn(c)
229}
230
231func (c *conn) Close() error {
232	c.mu.Lock()
233	closed := c.closed
234	c.closed = true
235	c.mu.Unlock()
236
237	if closed {
238		return nil
239	}
240
241	return c.Conn.Close()
242}
243
244func (c *conn) markForUse() {
245	c.mu.Lock()
246	c.lastUsed = time.Now()
247	c.failed = false
248	c.mu.Unlock()
249}
250
251func (c *conn) MarkFailed() {
252	c.mu.Lock()
253	c.failed = true
254	c.mu.Unlock()
255}
256