1// Copyright (C) 2020 Storj Labs, Inc.
2// See LICENSE for copying information.
3
4package rpcpool
5
6import (
7	"context"
8	"crypto/tls"
9	"runtime"
10	"time"
11
12	"storj.io/common/peertls/tlsopts"
13	"storj.io/common/rpc/rpccache"
14	"storj.io/drpc"
15)
16
// Options configures a connection Pool and is passed to New.
type Options struct {
	// Capacity is the total number of connections to keep open, and
	// KeyCapacity is the number to keep open for any single cache key.
	Capacity, KeyCapacity int

	// IdleExpiration is the longest a pooled connection may stay idle.
	// A zero value means pooled connections never expire.
	IdleExpiration time.Duration
}
29
30// Pool is a wrapper around a cache of connections that allows one to get or
31// create new cached connections.
32type Pool struct {
33	cache *rpccache.Cache
34}
35
36// New constructs a new Pool with the Options.
37func New(opts Options) *Pool {
38	p := &Pool{cache: rpccache.New(rpccache.Options{
39		Expiration:  opts.IdleExpiration,
40		Capacity:    opts.Capacity,
41		KeyCapacity: opts.KeyCapacity,
42		Close: func(pv interface{}) error {
43			return pv.(*poolValue).conn.Close()
44		},
45		Stale: func(pv interface{}) bool {
46			select {
47			case <-pv.(*poolValue).conn.Closed():
48				return true
49			default:
50				return false
51			}
52		},
53	})}
54
55	// As much as I dislike finalizers, especially for cases where it handles
56	// file descriptors, I think it's important to add one here at least until
57	// a full audit of all of the uses of the rpc.Dialer type and ensuring they
58	// all get closed.
59	runtime.SetFinalizer(p, func(p *Pool) {
60		mon.Event("pool_leaked")
61		_ = p.Close()
62	})
63
64	return p
65}
66
67// poolKey is the type of keys in the cache.
68type poolKey struct {
69	key        string
70	tlsOptions *tlsopts.Options
71}
72
73// poolValue is the type of values in the cache.
74type poolValue struct {
75	conn  drpc.Conn
76	state *tls.ConnectionState
77}
78
79// Dialer is the type of function to create a new connection.
80type Dialer = func(context.Context) (drpc.Conn, *tls.ConnectionState, error)
81
82// Close closes all of the cached connections. It is safe to call on a nil receiver.
83func (p *Pool) Close() error {
84	if p == nil {
85		return nil
86	}
87
88	runtime.SetFinalizer(p, nil)
89	return p.cache.Close()
90}
91
92// get returns a drpc connection from the cache if possible, dialing if necessary.
93func (p *Pool) get(ctx context.Context, pk poolKey, dial Dialer) (pv *poolValue, err error) {
94	defer mon.Task()(&ctx)(&err)
95
96	if p != nil {
97		pv, ok := p.cache.Take(pk).(*poolValue)
98		if ok {
99			mon.Event("connection_from_cache")
100			return pv, nil
101		}
102	}
103
104	mon.Event("connection_dialed")
105	conn, state, err := dial(ctx)
106	if err != nil {
107		return nil, err
108	}
109
110	return &poolValue{
111		conn:  conn,
112		state: state,
113	}, nil
114}
115
116// Get looks up a connection with the same key and TLS options and returns it if it
117// exists. If it does not exist, it calls the dial function to create one. It is safe
118// to call on a nil receiver, and if so, always returns a dialed connection.
119func (p *Pool) Get(ctx context.Context, key string, tlsOptions *tlsopts.Options, dial Dialer) (
120	conn drpc.Conn, state *tls.ConnectionState, err error) {
121	defer mon.Task()(&ctx)(&err)
122
123	pk := poolKey{
124		key:        key,
125		tlsOptions: tlsOptions,
126	}
127
128	pv, err := p.get(ctx, pk, dial)
129	if err != nil {
130		return nil, nil, err
131	}
132
133	// if we have a nil pool, we always dial once and do not return a wrapped connection.
134	if p == nil {
135		return pv.conn, pv.state, nil
136	}
137
138	// we immediately place the connection back into the pool so that it may be used
139	// by the returned poolConn.
140	p.cache.Put(pk, pv)
141
142	return &poolConn{
143		ch:   make(chan struct{}),
144		pk:   pk,
145		dial: dial,
146		pool: p,
147	}, pv.state, nil
148}
149