1// Copyright (C) 2020 Storj Labs, Inc.
2// See LICENSE for copying information.
3
4package sync2
5
6import (
7	"context"
8	"math"
9	"sync"
10	"sync/atomic"
11
12	"github.com/zeebo/errs"
13)
14
// SuccessThreshold tracks a known number of concurrent tasks and notifies
// waiters once a specific number of them has succeeded, or once every task has
// reported back, without interrupting the tasks that are still running.
//
// Create it with NewSuccessThreshold: the zero value has a nil done channel and
// is not usable.
type SuccessThreshold struct {
	noCopy noCopy // nolint: structcheck

	// The int64 counters below are updated and read with sync/atomic after
	// construction. NOTE(review): they are kept together ahead of the
	// non-int64 fields, presumably so they stay 64-bit aligned on 32-bit
	// platforms — do not reorder without confirming.

	// toSucceed is the number of successes still required to reach the threshold.
	toSucceed int64
	// pending is the number of tasks that have not yet reported success or failure.
	pending int64

	// successes counts calls to Success.
	successes int64
	// failures counts calls to Failure.
	failures int64

	// done is closed (exactly once, via once) when the threshold is reached or
	// all tasks have reported.
	done chan struct{}
	once sync.Once
}
30
31// NewSuccessThreshold creates a SuccessThreshold with the tasks number and
32// successThreshold.
33//
34// It returns an error if tasks is less or equal than 1 or successThreshold
35// is less or equal than 0 or greater or equal than 1.
36func NewSuccessThreshold(tasks int, successThreshold float64) (*SuccessThreshold, error) {
37	switch {
38	case tasks <= 1:
39		return nil, errs.New(
40			"invalid number of tasks. It must be greater than 1, got %d", tasks,
41		)
42	case successThreshold <= 0 || successThreshold > 1:
43		return nil, errs.New(
44			"invalid successThreshold. It must be greater than 0 and less or equal to 1, got %f", successThreshold,
45		)
46	}
47
48	tasksToSuccess := int64(math.Ceil(float64(tasks) * successThreshold))
49
50	// just in case of floating point issues
51	if tasksToSuccess > int64(tasks) {
52		tasksToSuccess = int64(tasks)
53	}
54
55	return &SuccessThreshold{
56		toSucceed: tasksToSuccess,
57		pending:   int64(tasks),
58		done:      make(chan struct{}),
59	}, nil
60}
61
62// Success tells the SuccessThreshold that one tasks was successful.
63func (threshold *SuccessThreshold) Success() {
64	atomic.AddInt64(&threshold.successes, 1)
65
66	if atomic.AddInt64(&threshold.toSucceed, -1) <= 0 {
67		threshold.markAsDone()
68	}
69
70	if atomic.AddInt64(&threshold.pending, -1) <= 0 {
71		threshold.markAsDone()
72	}
73}
74
75// Failure tells the SuccessThreshold that one task was a failure.
76func (threshold *SuccessThreshold) Failure() {
77	atomic.AddInt64(&threshold.failures, 1)
78
79	if atomic.AddInt64(&threshold.pending, -1) <= 0 {
80		threshold.markAsDone()
81	}
82}
83
84// Wait blocks the caller until the successThreshold is reached or all the tasks
85// have finished.
86func (threshold *SuccessThreshold) Wait(ctx context.Context) {
87	select {
88	case <-ctx.Done():
89	case <-threshold.done:
90	}
91}
92
93// markAsDone finalizes threshold closing the completed channel just once.
94// It's safe to be called multiple times.
95func (threshold *SuccessThreshold) markAsDone() {
96	threshold.once.Do(func() {
97		close(threshold.done)
98	})
99}
100
101// SuccessCount returns the number of successes so far.
102func (threshold *SuccessThreshold) SuccessCount() int {
103	return int(atomic.LoadInt64(&threshold.successes))
104}
105
106// FailureCount returns the number of failures so far.
107func (threshold *SuccessThreshold) FailureCount() int {
108	return int(atomic.LoadInt64(&threshold.failures))
109}
110