1package lifecycle
2
3import (
4	"context"
5	"fmt"
6	"os"
7	"os/signal"
8	"syscall"
9	"time"
10
11	"golang.org/x/sync/errgroup"
12)
13
// defaultSignals lists the signals a new manager listens for; New copies this
// slice into each manager it creates.
var defaultSignals = []os.Signal{
	syscall.SIGINT,
	syscall.SIGTERM,
}
15
16type manager struct {
17	group *errgroup.Group
18
19	timeout time.Duration
20	sigs    []os.Signal
21
22	ctx      context.Context
23	cancel   func()
24	gctx     context.Context
25	deferred []func() error
26	panic    chan interface{}
27}
28
29type contextKey struct{}
30
31func fromContext(ctx context.Context) *manager {
32	m, ok := ctx.Value(contextKey{}).(*manager)
33	if !ok {
34		panic(fmt.Errorf("lifecycle: manager not in context"))
35	}
36	return m
37}
38
39// New returns a lifecycle manager with context derived from that
40// provided.
41func New(ctx context.Context, opts ...Option) context.Context {
42	m := &manager{
43		deferred: []func() error{},
44		panic:    make(chan interface{}, 1),
45	}
46
47	ctx = context.WithValue(ctx, contextKey{}, m)
48
49	m.sigs = make([]os.Signal, len(defaultSignals))
50	copy(m.sigs, defaultSignals)
51
52	m.ctx, m.cancel = context.WithCancel(ctx)
53	m.group, m.gctx = errgroup.WithContext(context.Background())
54
55	for _, o := range opts {
56		o(m)
57	}
58
59	return m.ctx
60}
61
62// Exists returns true if the context has a lifecycle manager attached
63func Exists(ctx context.Context) bool {
64	_, ok := ctx.Value(contextKey{}).(*manager)
65	return ok
66}
67
68func wrapFunc(ctx context.Context, fn func() error) func() error {
69	m := fromContext(ctx)
70
71	return func() error {
72		defer func() {
73			if r := recover(); r != nil {
74				m.panic <- r
75			}
76		}()
77		return fn()
78	}
79}
80
81// Go run a function in a new goroutine
82func Go(ctx context.Context, f ...func()) {
83	m := fromContext(ctx)
84
85	for _, t := range f {
86		fn := t
87		m.group.Go(wrapFunc(ctx, func() error {
88			fn()
89			return nil
90		}))
91	}
92}
93
94// GoErr runs a function that returns an error in a new goroutine. If any GoErr
95// or DeferErr func returns an error, only the first one will be returned by
96// Wait()
97func GoErr(ctx context.Context, f ...func() error) {
98	m := fromContext(ctx)
99
100	for _, fn := range f {
101		m.group.Go(wrapFunc(ctx, fn))
102	}
103}
104
105// Defer adds funcs that should be called after the Go funcs complete (either
106// clean or with errors) or a signal is received
107func Defer(ctx context.Context, deferred ...func()) {
108	m := fromContext(ctx)
109
110	for _, t := range deferred {
111		fn := t
112		m.deferred = append(m.deferred, wrapFunc(ctx, func() error {
113			fn()
114			return nil
115		}))
116	}
117}
118
119// DeferErr adds funcs, that return errors, that should be called after the Go
120// funcs complete (either clean or with errors) or a signal is received. If any
121// GoErr or DeferErr func returns an error, only the first one will be returned
122// by Wait()
123func DeferErr(ctx context.Context, deferred ...func() error) {
124	m := fromContext(ctx)
125
126	for _, fn := range deferred {
127		m.deferred = append(m.deferred, wrapFunc(ctx, fn))
128	}
129}
130
131// Wait blocks until all go routines have been completed.
132//
133// All funcs registered with Go and Defer _will_ complete under every
134// circumstance except a panic
135//
136// Funcs passed to Defer begin (and the context returned by New() is canceled)
137// when any of:
138//
139//   - All funcs registered with Go complete successfully
140//   - Any func registered with Go returns an error
141//   - A signal is received (by default SIGINT or SIGTERM, but can be changed by
142//     WithSignals
143//
144// Funcs registered with Go should stop and clean up when the context
145// returned by New() is canceled. If the func accepts a context argument, it
146// will be passed the context returned by New().
147//
148// WithTimeout() can be used to set a maximum amount of time, starting with the
149// context returned by New() is canceled, that Wait will wait before returning.
150//
151// The returned err is the first non-nil error returned by any func registered
152// with Go or Defer, otherwise nil.
153func Wait(ctx context.Context) error {
154	m := fromContext(ctx)
155
156	err := m.runPrimaryGroup()
157	m.cancel()
158	if err != nil {
159		_ = m.runDeferredGroup() // #nosec
160		return err
161	}
162
163	return m.runDeferredGroup()
164}
165
// ErrSignal is the error Wait returns when it stopped because a signal was
// caught. The embedded os.Signal is the signal that was received.
type ErrSignal struct {
	os.Signal
}

// Error implements the error interface, naming the signal that was caught.
func (es ErrSignal) Error() string {
	msg := fmt.Sprintf("lifecycle: caught signal: %v", es.Signal)
	return msg
}
175
176// runPrimaryGroup waits for all registered routines to
177// complete, returning on an error from any of them, or from
178// the receipt of a registered signal, or from a context cancelation.
179func (m *manager) runPrimaryGroup() error {
180	select {
181	case sig := <-m.signalReceived():
182		return ErrSignal{sig}
183	case err := <-m.runPrimaryGroupRoutines():
184		return err
185	case <-m.ctx.Done():
186		return m.ctx.Err()
187	case <-m.gctx.Done():
188		// the error from the gctx errgroup will be returned
189		// from errgroup.Wait() later in runDeferredGroupRoutines
190	case r := <-m.panic:
191		panic(r)
192	}
193	return nil
194}
195
196func (m *manager) runDeferredGroup() error {
197	ctx := context.Background()
198
199	if m.timeout > 0 {
200		var cancel context.CancelFunc
201		ctx, cancel = context.WithTimeout(ctx, m.timeout)
202		defer cancel() // releases resources if deferred functions return early
203	}
204
205	select {
206	case <-ctx.Done():
207		return ctx.Err()
208	case err := <-m.runDeferredGroupRoutines():
209		return err
210	case r := <-m.panic:
211		panic(r)
212	}
213}
214
215// A channel that receives any os signals registered to be received.
216// If not configured to receive signals, it will receive nothing.
217func (m *manager) signalReceived() <-chan os.Signal {
218	sigCh := make(chan os.Signal, 1)
219	if len(m.sigs) > 0 {
220		signal.Notify(sigCh, m.sigs...)
221	}
222	return sigCh
223}
224
225// A channel that notifies of errors caused while waiting for subroutines to finish.
226func (m *manager) runPrimaryGroupRoutines() <-chan error {
227	errs := make(chan error, 1)
228	go func() { errs <- m.group.Wait() }()
229	return errs
230}
231
232func (m *manager) runDeferredGroupRoutines() <-chan error {
233	errs := make(chan error, 1)
234	dg := errgroup.Group{}
235	dg.Go(m.group.Wait) // Wait for the primary group as well
236	for _, f := range m.deferred {
237		dg.Go(f)
238	}
239	go func() {
240		errs <- dg.Wait()
241	}()
242	return errs
243}
244