1package lifecycle 2 3import ( 4 "context" 5 "fmt" 6 "os" 7 "os/signal" 8 "syscall" 9 "time" 10 11 "golang.org/x/sync/errgroup" 12) 13 14var defaultSignals = []os.Signal{syscall.SIGINT, syscall.SIGTERM} 15 16type manager struct { 17 group *errgroup.Group 18 19 timeout time.Duration 20 sigs []os.Signal 21 22 ctx context.Context 23 cancel func() 24 gctx context.Context 25 deferred []func() error 26 panic chan interface{} 27} 28 29type contextKey struct{} 30 31func fromContext(ctx context.Context) *manager { 32 m, ok := ctx.Value(contextKey{}).(*manager) 33 if !ok { 34 panic(fmt.Errorf("lifecycle: manager not in context")) 35 } 36 return m 37} 38 39// New returns a lifecycle manager with context derived from that 40// provided. 41func New(ctx context.Context, opts ...Option) context.Context { 42 m := &manager{ 43 deferred: []func() error{}, 44 panic: make(chan interface{}, 1), 45 } 46 47 ctx = context.WithValue(ctx, contextKey{}, m) 48 49 m.sigs = make([]os.Signal, len(defaultSignals)) 50 copy(m.sigs, defaultSignals) 51 52 m.ctx, m.cancel = context.WithCancel(ctx) 53 m.group, m.gctx = errgroup.WithContext(context.Background()) 54 55 for _, o := range opts { 56 o(m) 57 } 58 59 return m.ctx 60} 61 62// Exists returns true if the context has a lifecycle manager attached 63func Exists(ctx context.Context) bool { 64 _, ok := ctx.Value(contextKey{}).(*manager) 65 return ok 66} 67 68func wrapFunc(ctx context.Context, fn func() error) func() error { 69 m := fromContext(ctx) 70 71 return func() error { 72 defer func() { 73 if r := recover(); r != nil { 74 m.panic <- r 75 } 76 }() 77 return fn() 78 } 79} 80 81// Go run a function in a new goroutine 82func Go(ctx context.Context, f ...func()) { 83 m := fromContext(ctx) 84 85 for _, t := range f { 86 fn := t 87 m.group.Go(wrapFunc(ctx, func() error { 88 fn() 89 return nil 90 })) 91 } 92} 93 94// GoErr runs a function that returns an error in a new goroutine. If any GoErr 95// or DeferErr func returns an error, only the first one will be returned by 96// Wait() 97func GoErr(ctx context.Context, f ...func() error) { 98 m := fromContext(ctx) 99 100 for _, fn := range f { 101 m.group.Go(wrapFunc(ctx, fn)) 102 } 103} 104 105// Defer adds funcs that should be called after the Go funcs complete (either 106// clean or with errors) or a signal is received 107func Defer(ctx context.Context, deferred ...func()) { 108 m := fromContext(ctx) 109 110 for _, t := range deferred { 111 fn := t 112 m.deferred = append(m.deferred, wrapFunc(ctx, func() error { 113 fn() 114 return nil 115 })) 116 } 117} 118 119// DeferErr adds funcs, that return errors, that should be called after the Go 120// funcs complete (either clean or with errors) or a signal is received. If any 121// GoErr or DeferErr func returns an error, only the first one will be returned 122// by Wait() 123func DeferErr(ctx context.Context, deferred ...func() error) { 124 m := fromContext(ctx) 125 126 for _, fn := range deferred { 127 m.deferred = append(m.deferred, wrapFunc(ctx, fn)) 128 } 129} 130 131// Wait blocks until all go routines have been completed. 132// 133// All funcs registered with Go and Defer _will_ complete under every 134// circumstance except a panic 135// 136// Funcs passed to Defer begin (and the context returned by New() is canceled) 137// when any of: 138// 139// - All funcs registered with Go complete successfully 140// - Any func registered with Go returns an error 141// - A signal is received (by default SIGINT or SIGTERM, but can be changed by 142// WithSignals 143// 144// Funcs registered with Go should stop and clean up when the context 145// returned by New() is canceled. If the func accepts a context argument, it 146// will be passed the context returned by New(). 147// 148// WithTimeout() can be used to set a maximum amount of time, starting with the 149// context returned by New() is canceled, that Wait will wait before returning. 150// 151// The returned err is the first non-nil error returned by any func registered 152// with Go or Defer, otherwise nil. 153func Wait(ctx context.Context) error { 154 m := fromContext(ctx) 155 156 err := m.runPrimaryGroup() 157 m.cancel() 158 if err != nil { 159 _ = m.runDeferredGroup() // #nosec 160 return err 161 } 162 163 return m.runDeferredGroup() 164} 165 166// ErrSignal is returned by Wait if the reason it returned was because a signal 167// was caught 168type ErrSignal struct { 169 os.Signal 170} 171 172func (e ErrSignal) Error() string { 173 return fmt.Sprintf("lifecycle: caught signal: %v", e.Signal) 174} 175 176// runPrimaryGroup waits for all registered routines to 177// complete, returning on an error from any of them, or from 178// the receipt of a registered signal, or from a context cancelation. 179func (m *manager) runPrimaryGroup() error { 180 select { 181 case sig := <-m.signalReceived(): 182 return ErrSignal{sig} 183 case err := <-m.runPrimaryGroupRoutines(): 184 return err 185 case <-m.ctx.Done(): 186 return m.ctx.Err() 187 case <-m.gctx.Done(): 188 // the error from the gctx errgroup will be returned 189 // from errgroup.Wait() later in runDeferredGroupRoutines 190 case r := <-m.panic: 191 panic(r) 192 } 193 return nil 194} 195 196func (m *manager) runDeferredGroup() error { 197 ctx := context.Background() 198 199 if m.timeout > 0 { 200 var cancel context.CancelFunc 201 ctx, cancel = context.WithTimeout(ctx, m.timeout) 202 defer cancel() // releases resources if deferred functions return early 203 } 204 205 select { 206 case <-ctx.Done(): 207 return ctx.Err() 208 case err := <-m.runDeferredGroupRoutines(): 209 return err 210 case r := <-m.panic: 211 panic(r) 212 } 213} 214 215// A channel that receives any os signals registered to be received. 216// If not configured to receive signals, it will receive nothing. 217func (m *manager) signalReceived() <-chan os.Signal { 218 sigCh := make(chan os.Signal, 1) 219 if len(m.sigs) > 0 { 220 signal.Notify(sigCh, m.sigs...) 221 } 222 return sigCh 223} 224 225// A channel that notifies of errors caused while waiting for subroutines to finish. 226func (m *manager) runPrimaryGroupRoutines() <-chan error { 227 errs := make(chan error, 1) 228 go func() { errs <- m.group.Wait() }() 229 return errs 230} 231 232func (m *manager) runDeferredGroupRoutines() <-chan error { 233 errs := make(chan error, 1) 234 dg := errgroup.Group{} 235 dg.Go(m.group.Wait) // Wait for the primary group as well 236 for _, f := range m.deferred { 237 dg.Go(f) 238 } 239 go func() { 240 errs <- dg.Wait() 241 }() 242 return errs 243} 244