1package crr
2
3import (
4	"sync/atomic"
5)
6
7// EndpointCache is an LRU cache that holds a series of endpoints
8// based on some key. The datastructure makes use of a read write
9// mutex to enable asynchronous use.
10type EndpointCache struct {
11	endpoints     syncMap
12	endpointLimit int64
13	// size is used to count the number elements in the cache.
14	// The atomic package is used to ensure this size is accurate when
15	// using multiple goroutines.
16	size int64
17}
18
19// NewEndpointCache will return a newly initialized cache with a limit
20// of endpointLimit entries.
21func NewEndpointCache(endpointLimit int64) *EndpointCache {
22	return &EndpointCache{
23		endpointLimit: endpointLimit,
24		endpoints:     newSyncMap(),
25	}
26}
27
28// get is a concurrent safe get operation that will retrieve an endpoint
29// based on endpointKey. A boolean will also be returned to illustrate whether
30// or not the endpoint had been found.
31func (c *EndpointCache) get(endpointKey string) (Endpoint, bool) {
32	endpoint, ok := c.endpoints.Load(endpointKey)
33	if !ok {
34		return Endpoint{}, false
35	}
36
37	c.endpoints.Store(endpointKey, endpoint)
38	return endpoint.(Endpoint), true
39}
40
41// Has returns if the enpoint cache contains a valid entry for the endpoint key
42// provided.
43func (c *EndpointCache) Has(endpointKey string) bool {
44	endpoint, ok := c.get(endpointKey)
45	_, found := endpoint.GetValidAddress()
46
47	return ok && found
48}
49
50// Get will retrieve a weighted address  based off of the endpoint key. If an endpoint
51// should be retrieved, due to not existing or the current endpoint has expired
52// the Discoverer object that was passed in will attempt to discover a new endpoint
53// and add that to the cache.
54func (c *EndpointCache) Get(d Discoverer, endpointKey string, required bool) (WeightedAddress, error) {
55	var err error
56	endpoint, ok := c.get(endpointKey)
57	weighted, found := endpoint.GetValidAddress()
58	shouldGet := !ok || !found
59
60	if required && shouldGet {
61		if endpoint, err = c.discover(d, endpointKey); err != nil {
62			return WeightedAddress{}, err
63		}
64
65		weighted, _ = endpoint.GetValidAddress()
66	} else if shouldGet {
67		go c.discover(d, endpointKey)
68	}
69
70	return weighted, nil
71}
72
73// Add is a concurrent safe operation that will allow new endpoints to be added
74// to the cache. If the cache is full, the number of endpoints equal endpointLimit,
75// then this will remove the oldest entry before adding the new endpoint.
76func (c *EndpointCache) Add(endpoint Endpoint) {
77	// de-dups multiple adds of an endpoint with a pre-existing key
78	if iface, ok := c.endpoints.Load(endpoint.Key); ok {
79		e := iface.(Endpoint)
80		if e.Len() > 0 {
81			return
82		}
83	}
84	c.endpoints.Store(endpoint.Key, endpoint)
85
86	size := atomic.AddInt64(&c.size, 1)
87	if size > 0 && size > c.endpointLimit {
88		c.deleteRandomKey()
89	}
90}
91
92// deleteRandomKey will delete a random key from the cache. If
93// no key was deleted false will be returned.
94func (c *EndpointCache) deleteRandomKey() bool {
95	atomic.AddInt64(&c.size, -1)
96	found := false
97
98	c.endpoints.Range(func(key, value interface{}) bool {
99		found = true
100		c.endpoints.Delete(key)
101
102		return false
103	})
104
105	return found
106}
107
108// discover will get and store and endpoint using the Discoverer.
109func (c *EndpointCache) discover(d Discoverer, endpointKey string) (Endpoint, error) {
110	endpoint, err := d.Discover()
111	if err != nil {
112		return Endpoint{}, err
113	}
114
115	endpoint.Key = endpointKey
116	c.Add(endpoint)
117
118	return endpoint, nil
119}
120