package sd

import (
	"io"
	"sort"
	"sync"
	"time"

	"github.com/go-kit/kit/endpoint"
	"github.com/go-kit/kit/log"
)

// endpointCache collects the most recent set of instances from a service discovery
// system, creates endpoints for them using a factory function, and makes
// them available to consumers.
type endpointCache struct {
	options            endpointerOptions
	mtx                sync.RWMutex
	factory            Factory
	cache              map[string]endpointCloser
	err                error
	endpoints          []endpoint.Endpoint
	logger             log.Logger
	invalidateDeadline time.Time
	timeNow            func() time.Time
}

type endpointCloser struct {
	endpoint.Endpoint
	io.Closer
}

// newEndpointCache returns a new, empty endpointCache.
func newEndpointCache(factory Factory, logger log.Logger, options endpointerOptions) *endpointCache {
	return &endpointCache{
		options: options,
		factory: factory,
		cache:   map[string]endpointCloser{},
		logger:  logger,
		timeNow: time.Now,
	}
}
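
// A minimal usage sketch for this unexported cache, assuming a hypothetical
// Factory named myFactory; callers normally reach this machinery through the
// exported NewEndpointer rather than constructing it directly:
//
//	c := newEndpointCache(myFactory, log.NewNopLogger(), endpointerOptions{})
//	c.Update(Event{Instances: []string{"10.0.0.1:8080", "10.0.0.2:8080"}})
//	endpoints, err := c.Endpoints() // ordered lexicographically by instance string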

// Update should be invoked by clients with a complete set of current instance
// strings whenever that set changes. The cache manufactures new endpoints via
// the factory, closes old endpoints when they disappear, and persists existing
// endpoints if they survive through an update.
func (c *endpointCache) Update(event Event) {
	c.mtx.Lock()
	defer c.mtx.Unlock()

	// Happy path.
	if event.Err == nil {
		c.updateCache(event.Instances)
		c.err = nil
		return
	}

	// Sad path. Something's gone wrong in sd.
	c.logger.Log("err", event.Err)
	if !c.options.invalidateOnError {
		return // keep returning the last known endpoints on error
	}
	if c.err != nil {
		return // already in the error state; do nothing and keep the original error
	}
	c.err = event.Err
	// Set a new deadline to invalidate the endpoints, unless a non-error Event
	// is received before it passes.
	c.invalidateDeadline = c.timeNow().Add(c.options.invalidateTimeout)
}
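
// To illustrate the sad path above: when invalidateOnError is set (typically
// wired up via the exported InvalidateOnError option), an error Event starts a
// countdown rather than dropping the endpoints immediately. A hedged sketch,
// where errDiscovery and instances are illustrative values:
//
//	c.Update(Event{Err: errDiscovery})    // last known endpoints are still served
//	// ...invalidateTimeout elapses without a good Event...
//	endpoints, err := c.Endpoints()       // returns errDiscovery; cached endpoints are closed
//	c.Update(Event{Instances: instances}) // a non-error Event clears the error state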

func (c *endpointCache) updateCache(instances []string) {
	// Deterministic order (for the ordered endpoints slice built below).
	sort.Strings(instances)

	// Produce the current set of services.
	cache := make(map[string]endpointCloser, len(instances))
	for _, instance := range instances {
		// If it already exists, just copy it over.
		if sc, ok := c.cache[instance]; ok {
			cache[instance] = sc
			delete(c.cache, instance)
			continue
		}

		// If it doesn't exist, create it.
		service, closer, err := c.factory(instance)
		if err != nil {
			c.logger.Log("instance", instance, "err", err)
			continue
		}
		cache[instance] = endpointCloser{service, closer}
	}

	// Close any leftover endpoints.
	for _, sc := range c.cache {
		if sc.Closer != nil {
			sc.Closer.Close()
		}
	}

	// Populate the slice of endpoints.
	endpoints := make([]endpoint.Endpoint, 0, len(cache))
	for _, instance := range instances {
		// A failed factory call above means this instance has no cache entry.
		if _, ok := cache[instance]; !ok {
			continue
		}
		endpoints = append(endpoints, cache[instance].Endpoint)
	}

	// Swap in the new endpoints and cache; the old copies become garbage.
	c.endpoints = endpoints
	c.cache = cache
}
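
// A concrete walk-through of updateCache with illustrative instance names: if
// the previous cache held {"a:1", "b:2", "c:3"} and the new set is
// {"b:2", "c:3", "d:4"}, then "b:2" and "c:3" are carried over untouched,
// "a:1" is left behind in the old map and closed, and "d:4" is freshly
// manufactured via the factory.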

// Endpoints yields the current set of (presumably identical) endpoints, ordered
// lexicographically by the corresponding instance string.
func (c *endpointCache) Endpoints() ([]endpoint.Endpoint, error) {
	// In the steady state we're going to have many goroutines calling
	// Endpoints() concurrently, so to minimize contention we use a shared R-lock.
	c.mtx.RLock()

	if c.err == nil || c.timeNow().Before(c.invalidateDeadline) {
		defer c.mtx.RUnlock()
		return c.endpoints, nil
	}

	c.mtx.RUnlock()

	// In case of an error, switch to an exclusive lock.
	c.mtx.Lock()
	defer c.mtx.Unlock()

	// Re-check the condition, due to the race between RUnlock() and Lock().
	if c.err == nil || c.timeNow().Before(c.invalidateDeadline) {
		return c.endpoints, nil
	}

	c.updateCache(nil) // close any remaining active endpoints
	return nil, c.err
}
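
// A hedged sketch of the public path that ultimately drives this cache and
// calls Endpoints(). Here instancer, myFactory, and logger are assumptions
// standing in for a concrete Instancer, Factory, and log.Logger, and lb is
// assumed to be the github.com/go-kit/kit/sd/lb package:
//
//	endpointer := sd.NewEndpointer(instancer, myFactory, logger, sd.InvalidateOnError(5*time.Second))
//	balancer := lb.NewRoundRobin(endpointer)
//	e, _ := balancer.Endpoint() // one endpoint chosen from the cached, ordered set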