1package backoff
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"io"
8	"log"
9	"testing"
10	"time"
11)
12
// testTimer is the timer handed to RetryNotifyWithTimer in these tests. It
// ignores the requested backoff duration and fires immediately, so retry
// loops run without real waiting.
type testTimer struct {
	timer *time.Timer
}

// Start arms a fresh timer with a zero duration so it fires right away; the
// requested duration is intentionally ignored.
func (tt *testTimer) Start(duration time.Duration) {
	tt.timer = time.NewTimer(0)
}

// Stop halts the current timer. It is a no-op if Start has never been called.
func (tt *testTimer) Stop() {
	if tt.timer == nil {
		return
	}
	tt.timer.Stop()
}

// C exposes the channel of the timer created by the most recent Start. It
// dereferences the underlying timer, so it must not be called before Start.
func (tt *testTimer) C() <-chan time.Time {
	return tt.timer.C
}
30
31func TestRetry(t *testing.T) {
32	const successOn = 3
33	var i = 0
34
35	// This function is successful on "successOn" calls.
36	f := func() error {
37		i++
38		log.Printf("function is called %d. time\n", i)
39
40		if i == successOn {
41			log.Println("OK")
42			return nil
43		}
44
45		log.Println("error")
46		return errors.New("error")
47	}
48
49	err := RetryNotifyWithTimer(f, NewExponentialBackOff(), nil, &testTimer{})
50	if err != nil {
51		t.Errorf("unexpected error: %s", err.Error())
52	}
53	if i != successOn {
54		t.Errorf("invalid number of retries: %d", i)
55	}
56}
57
58func TestRetryContext(t *testing.T) {
59	var cancelOn = 3
60	var i = 0
61
62	ctx, cancel := context.WithCancel(context.Background())
63	defer cancel()
64
65	// This function cancels context on "cancelOn" calls.
66	f := func() error {
67		i++
68		log.Printf("function is called %d. time\n", i)
69
70		// cancelling the context in the operation function is not a typical
71		// use-case, however it allows to get predictable test results.
72		if i == cancelOn {
73			cancel()
74		}
75
76		log.Println("error")
77		return fmt.Errorf("error (%d)", i)
78	}
79
80	err := RetryNotifyWithTimer(f, WithContext(NewConstantBackOff(time.Millisecond), ctx), nil, &testTimer{})
81	if err == nil {
82		t.Errorf("error is unexpectedly nil")
83	}
84	if !errors.Is(err, context.Canceled) {
85		t.Errorf("unexpected error: %s", err.Error())
86	}
87	if i != cancelOn {
88		t.Errorf("invalid number of retries: %d", i)
89	}
90}
91
92func TestRetryPermanent(t *testing.T) {
93	ensureRetries := func(test string, shouldRetry bool, f func() error) {
94		numRetries := -1
95		maxRetries := 1
96
97		_ = RetryNotifyWithTimer(
98			func() error {
99				numRetries++
100				if numRetries >= maxRetries {
101					return Permanent(errors.New("forced"))
102				}
103				return f()
104			},
105			NewExponentialBackOff(),
106			nil,
107			&testTimer{},
108		)
109
110		if shouldRetry && numRetries == 0 {
111			t.Errorf("Test: '%s', backoff should have retried", test)
112		}
113
114		if !shouldRetry && numRetries > 0 {
115			t.Errorf("Test: '%s', backoff should not have retried", test)
116		}
117	}
118
119	for _, testCase := range []struct {
120		name        string
121		f           func() error
122		shouldRetry bool
123	}{
124		{
125			"nil test",
126			func() error {
127				return nil
128			},
129			false,
130		},
131		{
132			"io.EOF",
133			func() error {
134				return io.EOF
135			},
136			true,
137		},
138		{
139			"Permanent(io.EOF)",
140			func() error {
141				return Permanent(io.EOF)
142			},
143			false,
144		},
145		{
146			"Wrapped: Permanent(io.EOF)",
147			func() error {
148				return fmt.Errorf("Wrapped error: %w", Permanent(io.EOF))
149			},
150			false,
151		},
152	} {
153		ensureRetries(testCase.name, testCase.shouldRetry, testCase.f)
154	}
155}
156
157func TestPermanent(t *testing.T) {
158	want := errors.New("foo")
159	other := errors.New("bar")
160	var err error = Permanent(want)
161
162	got := errors.Unwrap(err)
163	if got != want {
164		t.Errorf("got %v, want %v", got, want)
165	}
166
167	if is := errors.Is(err, want); !is {
168		t.Errorf("err: %v is not %v", err, want)
169	}
170
171	if is := errors.Is(err, other); is {
172		t.Errorf("err: %v is %v", err, other)
173	}
174
175	wrapped := fmt.Errorf("wrapped: %w", err)
176	var permanent *PermanentError
177	if !errors.As(wrapped, &permanent) {
178		t.Errorf("errors.As(%v, %v)", wrapped, permanent)
179	}
180
181	err = Permanent(nil)
182	if err != nil {
183		t.Errorf("got %v, want nil", err)
184	}
185}
186