1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package memoize supports memoizing the return values of functions with
6// idempotent results that are expensive to compute.
7//
// To use this package, create a Store, obtain a Generation from it with the
// Store's Generation method, and acquire handles with the Generation's Bind method.
10//
11package memoize
12
13import (
14	"context"
15	"flag"
16	"fmt"
17	"reflect"
18	"sync"
19	"sync/atomic"
20
21	"golang.org/x/tools/internal/xcontext"
22)
23
// panicOnDestroyed is a debugging switch: when enabled, reading through a
// destroyed generation or entry panics instead of returning an error.
var panicOnDestroyed = flag.Bool("memoize_panic_on_destroyed", false,
	"Panic when a destroyed generation is read rather than returning an error. "+
		"Panicking may make it easier to debug lifetime errors, especially when "+
		"used with GOTRACEBACK=crash to see all running goroutines.")
30
31// Store binds keys to functions, returning handles that can be used to access
32// the functions results.
33type Store struct {
34	mu sync.Mutex
35	// handles is the set of values stored.
36	handles map[interface{}]*Handle
37
38	// generations is the set of generations live in this store.
39	generations map[*Generation]struct{}
40}
41
42// Generation creates a new Generation associated with s. Destroy must be
43// called on the returned Generation once it is no longer in use. name is
44// for debugging purposes only.
45func (s *Store) Generation(name string) *Generation {
46	s.mu.Lock()
47	defer s.mu.Unlock()
48	if s.handles == nil {
49		s.handles = map[interface{}]*Handle{}
50		s.generations = map[*Generation]struct{}{}
51	}
52	g := &Generation{store: s, name: name}
53	s.generations[g] = struct{}{}
54	return g
55}
56
57// A Generation is a logical point in time of the cache life-cycle. Cache
58// entries associated with a Generation will not be removed until the
59// Generation is destroyed.
60type Generation struct {
61	// destroyed is 1 after the generation is destroyed. Atomic.
62	destroyed uint32
63	store     *Store
64	name      string
65	// wg tracks the reference count of this generation.
66	wg sync.WaitGroup
67}
68
// Destroy waits for all operations referencing g to complete, then removes
// all references to g from cache entries. Cache entries that no longer
// reference any non-destroyed generation are removed. Destroy must be called
// exactly once for each generation.
//
// A removed entry is marked destroyed and, if it has a cleanup function and a
// non-nil value, cleanup is called on that value. cleanup runs while both the
// store's mutex and the entry's mutex are held, so it must not call anything
// that takes those locks (such as Store or Handle methods) or it will deadlock.
//
// NOTE(review): destroyed is only set after wg.Wait returns, so an Acquire that
// races with Destroy may call wg.Add after Wait has already observed a zero
// count. Callers must ensure g is no longer in use before calling Destroy.
func (g *Generation) Destroy() {
	// Block until every reference taken by Acquire has been released.
	g.wg.Wait()
	// From here on, Bind, Inherit, and Acquire (with a live ctx) on g panic.
	atomic.StoreUint32(&g.destroyed, 1)
	g.store.mu.Lock()
	defer g.store.mu.Unlock()
	for k, e := range g.store.handles {
		e.mu.Lock()
		if _, ok := e.generations[g]; ok {
			delete(e.generations, g) // delete even if it's dead, in case of dangling references to the entry.
			if len(e.generations) == 0 {
				// g was the last generation keeping e alive: drop it from the
				// store and release its value.
				delete(g.store.handles, k)
				e.state = stateDestroyed
				if e.cleanup != nil && e.value != nil {
					e.cleanup(e.value)
				}
			}
		}
		e.mu.Unlock()
	}
	delete(g.store.generations, g)
}
94
95// Acquire creates a new reference to g, and returns a func to release that
96// reference.
97func (g *Generation) Acquire(ctx context.Context) func() {
98	destroyed := atomic.LoadUint32(&g.destroyed)
99	if ctx.Err() != nil {
100		return func() {}
101	}
102	if destroyed != 0 {
103		panic("acquire on destroyed generation " + g.name)
104	}
105	g.wg.Add(1)
106	return g.wg.Done
107}
108
// Arg is a marker interface. Embedding it in a type declares that the type is
// meant to be passed as the argument of a Function.
type Arg interface {
	memoizeArg()
}
112
113// Function is the type for functions that can be memoized.
114// The result must be a pointer.
115type Function func(ctx context.Context, arg Arg) interface{}
116
// state is the lifecycle stage of a Handle.
type state int

// The constants are declared with type state, not as untyped integer
// constants, so the compiler rejects accidentally mixing them with
// unrelated ints.
const (
	// stateIdle: no evaluation is in flight and no value is cached.
	stateIdle state = iota
	// stateRunning: an evaluation is in flight.
	stateRunning
	// stateCompleted: the value has been computed and cached.
	stateCompleted
	// stateDestroyed: no live generation references the handle any longer.
	stateDestroyed
)
125
126// Handle is returned from a store when a key is bound to a function.
127// It is then used to access the results of that function.
128//
129// A Handle starts out in idle state, waiting for something to demand its
130// evaluation. It then transitions into running state. While it's running,
131// waiters tracks the number of Get calls waiting for a result, and the done
132// channel is used to notify waiters of the next state transition. Once the
133// evaluation finishes, value is set, state changes to completed, and done
134// is closed, unblocking waiters. Alternatively, as Get calls are cancelled,
135// they decrement waiters. If it drops to zero, the inner context is cancelled,
136// computation is abandoned, and state resets to idle to start the process over
137// again.
138type Handle struct {
139	key interface{}
140	mu  sync.Mutex
141
142	// generations is the set of generations in which this handle is valid.
143	generations map[*Generation]struct{}
144
145	state state
146	// done is set in running state, and closed when exiting it.
147	done chan struct{}
148	// cancel is set in running state. It cancels computation.
149	cancel context.CancelFunc
150	// waiters is the number of Gets outstanding.
151	waiters uint
152	// the function that will be used to populate the value
153	function Function
154	// value is set in completed state.
155	value interface{}
156	// cleanup, if non-nil, is used to perform any necessary clean-up on values
157	// produced by function.
158	cleanup func(interface{})
159}
160
161// Bind returns a handle for the given key and function.
162//
163// Each call to bind will return the same handle if it is already bound. Bind
164// will always return a valid handle, creating one if needed. Each key can
165// only have one handle at any given time. The value will be held at least
166// until the associated generation is destroyed. Bind does not cause the value
167// to be generated.
168//
169// If cleanup is non-nil, it will be called on any non-nil values produced by
170// function when they are no longer referenced.
171func (g *Generation) Bind(key interface{}, function Function, cleanup func(interface{})) *Handle {
172	// panic early if the function is nil
173	// it would panic later anyway, but in a way that was much harder to debug
174	if function == nil {
175		panic("the function passed to bind must not be nil")
176	}
177	if atomic.LoadUint32(&g.destroyed) != 0 {
178		panic("operation on destroyed generation " + g.name)
179	}
180	g.store.mu.Lock()
181	defer g.store.mu.Unlock()
182	h, ok := g.store.handles[key]
183	if !ok {
184		h := &Handle{
185			key:         key,
186			function:    function,
187			generations: map[*Generation]struct{}{g: {}},
188			cleanup:     cleanup,
189		}
190		g.store.handles[key] = h
191		return h
192	}
193	h.mu.Lock()
194	defer h.mu.Unlock()
195	if _, ok := h.generations[g]; !ok {
196		h.generations[g] = struct{}{}
197	}
198	return h
199}
200
201// Stats returns the number of each type of value in the store.
202func (s *Store) Stats() map[reflect.Type]int {
203	s.mu.Lock()
204	defer s.mu.Unlock()
205
206	result := map[reflect.Type]int{}
207	for k := range s.handles {
208		result[reflect.TypeOf(k)]++
209	}
210	return result
211}
212
213// DebugOnlyIterate iterates through all live cache entries and calls f on them.
214// It should only be used for debugging purposes.
215func (s *Store) DebugOnlyIterate(f func(k, v interface{})) {
216	s.mu.Lock()
217	defer s.mu.Unlock()
218
219	for k, e := range s.handles {
220		var v interface{}
221		e.mu.Lock()
222		if e.state == stateCompleted {
223			v = e.value
224		}
225		e.mu.Unlock()
226		if v == nil {
227			continue
228		}
229		f(k, v)
230	}
231}
232
233func (g *Generation) Inherit(hs ...*Handle) {
234	for _, h := range hs {
235		if atomic.LoadUint32(&g.destroyed) != 0 {
236			panic("inherit on destroyed generation " + g.name)
237		}
238
239		h.mu.Lock()
240		defer h.mu.Unlock()
241		if h.state == stateDestroyed {
242			panic(fmt.Sprintf("inheriting destroyed handle %#v (type %T) into generation %v", h.key, h.key, g.name))
243		}
244		h.generations[g] = struct{}{}
245	}
246}
247
248// Cached returns the value associated with a handle.
249//
250// It will never cause the value to be generated.
251// It will return the cached value, if present.
252func (h *Handle) Cached(g *Generation) interface{} {
253	h.mu.Lock()
254	defer h.mu.Unlock()
255	if _, ok := h.generations[g]; !ok {
256		return nil
257	}
258	if h.state == stateCompleted {
259		return h.value
260	}
261	return nil
262}
263
264// Get returns the value associated with a handle.
265//
266// If the value is not yet ready, the underlying function will be invoked.
267// If ctx is cancelled, Get returns nil.
268func (h *Handle) Get(ctx context.Context, g *Generation, arg Arg) (interface{}, error) {
269	release := g.Acquire(ctx)
270	defer release()
271
272	if ctx.Err() != nil {
273		return nil, ctx.Err()
274	}
275	h.mu.Lock()
276	if _, ok := h.generations[g]; !ok {
277		h.mu.Unlock()
278
279		err := fmt.Errorf("reading key %#v: generation %v is not known", h.key, g.name)
280		if *panicOnDestroyed && ctx.Err() != nil {
281			panic(err)
282		}
283		return nil, err
284	}
285	switch h.state {
286	case stateIdle:
287		return h.run(ctx, g, arg)
288	case stateRunning:
289		return h.wait(ctx)
290	case stateCompleted:
291		defer h.mu.Unlock()
292		return h.value, nil
293	case stateDestroyed:
294		h.mu.Unlock()
295		err := fmt.Errorf("Get on destroyed entry %#v (type %T) in generation %v", h.key, h.key, g.name)
296		if *panicOnDestroyed {
297			panic(err)
298		}
299		return nil, err
300	default:
301		panic("unknown state")
302	}
303}
304
// run starts h.function and returns the result. h.mu must be locked.
//
// run moves h from idle to running, launches the evaluation in a new
// goroutine, and then hands off to wait, which releases h.mu. The evaluation
// context childCtx is derived from xcontext.Detach(ctx). Presumably that keeps
// ctx's values but drops its cancellation; confirm in xcontext. childCtx is
// cancelled through h.cancel when the last waiter gives up (see wait).
func (h *Handle) run(ctx context.Context, g *Generation, arg Arg) (interface{}, error) {
	childCtx, cancel := context.WithCancel(xcontext.Detach(ctx))
	h.cancel = cancel
	h.state = stateRunning
	h.done = make(chan struct{})
	function := h.function // Read under the lock

	// Make sure that the generation isn't destroyed while we're running in it.
	release := g.Acquire(ctx)
	go func() {
		defer release()
		// Just in case the function does something expensive without checking
		// the context, double-check we're still alive.
		if childCtx.Err() != nil {
			return
		}
		v := function(childCtx, arg)
		if childCtx.Err() != nil {
			// It's possible that v was computed despite the context cancellation. In
			// this case we should ensure that it is cleaned up.
			if h.cleanup != nil && v != nil {
				h.cleanup(v)
			}
			return
		}

		h.mu.Lock()
		defer h.mu.Unlock()
		// It's theoretically possible that the handle has been cancelled out
		// of the run that started us, and then started running again since we
		// checked childCtx above. Even so, that should be harmless, since each
		// run should produce the same results.
		if h.state != stateRunning {
			// v will never be used, so ensure that it is cleaned up.
			if h.cleanup != nil && v != nil {
				h.cleanup(v)
			}
			return
		}
		// At this point v will be cleaned up whenever h is destroyed.
		h.value = v
		h.function = nil // the function is no longer needed once the value is cached
		h.state = stateCompleted
		close(h.done) // wake every Get blocked in wait
	}()

	// wait releases h.mu, which is still held here.
	return h.wait(ctx)
}
354
// wait waits for the value to be computed, or ctx to be cancelled. h.mu must be locked.
//
// wait releases h.mu before blocking, and h.mu is unlocked when it returns. It
// returns (value, nil) once h completes. It returns (nil, nil) if done was
// closed but h is no longer completed, for example because h has since been
// destroyed. It returns (nil, ctx.Err()) if ctx is cancelled first.
func (h *Handle) wait(ctx context.Context) (interface{}, error) {
	// Register as a waiter and capture this run's done channel while still
	// holding the lock; h.done may be replaced once the lock is released.
	h.waiters++
	done := h.done
	h.mu.Unlock()

	select {
	case <-done:
		h.mu.Lock()
		defer h.mu.Unlock()
		if h.state == stateCompleted {
			return h.value, nil
		}
		return nil, nil
	case <-ctx.Done():
		h.mu.Lock()
		defer h.mu.Unlock()
		h.waiters--
		// If this was the last waiter of a still-running evaluation, abandon
		// it: cancel the computation, wake anything selecting on done, and
		// reset to idle so a later Get starts a fresh run.
		if h.waiters == 0 && h.state == stateRunning {
			h.cancel()
			close(h.done)
			h.state = stateIdle
			h.done = nil
			h.cancel = nil
		}
		return nil, ctx.Err()
	}
}
383