1// Copyright 2012-present Oliver Eilhard. All rights reserved.
2// Use of this source code is governed by a MIT-license.
3// See http://olivere.mit-license.org/license.txt for details.
4
5package elastic
6
7import (
8	"context"
9	"errors"
10	"net/http"
11	"sync/atomic"
12	"testing"
13	"time"
14)
15
// testRetrier wraps a Retrier for use in tests. Its Retry method counts
// every invocation in N (updated with atomic.AddInt64) and, when Err is
// non-nil, returns Err instead of delegating to the embedded Retrier.
type testRetrier struct {
	Retrier
	N   int64
	Err error
}
21
22func (r *testRetrier) Retry(ctx context.Context, retry int, req *http.Request, resp *http.Response, err error) (time.Duration, bool, error) {
23	atomic.AddInt64(&r.N, 1)
24	if r.Err != nil {
25		return 0, false, r.Err
26	}
27	return r.Retrier.Retry(ctx, retry, req, resp, err)
28}
29
30func TestStopRetrier(t *testing.T) {
31	r := NewStopRetrier()
32	wait, ok, err := r.Retry(context.TODO(), 1, nil, nil, nil)
33	if want, got := 0*time.Second, wait; want != got {
34		t.Fatalf("expected %v, got %v", want, got)
35	}
36	if want, got := false, ok; want != got {
37		t.Fatalf("expected %v, got %v", want, got)
38	}
39	if err != nil {
40		t.Fatalf("expected nil, got %v", err)
41	}
42}
43
44func TestRetrier(t *testing.T) {
45	var numFailedReqs int
46	fail := func(r *http.Request) (*http.Response, error) {
47		numFailedReqs += 1
48		//return &http.Response{Request: r, StatusCode: 400}, nil
49		return nil, errors.New("request failed")
50	}
51
52	tr := &failingTransport{path: "/fail", fail: fail}
53	httpClient := &http.Client{Transport: tr}
54
55	retrier := &testRetrier{
56		Retrier: NewBackoffRetrier(NewSimpleBackoff(100, 100, 100, 100, 100)),
57	}
58
59	client, err := NewClient(
60		SetHttpClient(httpClient),
61		SetMaxRetries(5),
62		SetHealthcheck(false),
63		SetRetrier(retrier))
64	if err != nil {
65		t.Fatal(err)
66	}
67
68	res, err := client.PerformRequest(context.TODO(), PerformRequestOptions{
69		Method: "GET",
70		Path:   "/fail",
71	})
72	if err == nil {
73		t.Fatal("expected error")
74	}
75	if res != nil {
76		t.Fatal("expected no response")
77	}
78	// Connection should be marked as dead after it failed
79	if numFailedReqs != 5 {
80		t.Errorf("expected %d failed requests; got: %d", 5, numFailedReqs)
81	}
82	if retrier.N != 5 {
83		t.Errorf("expected %d Retrier calls; got: %d", 5, retrier.N)
84	}
85}
86
87func TestRetrierWithError(t *testing.T) {
88	var numFailedReqs int
89	fail := func(r *http.Request) (*http.Response, error) {
90		numFailedReqs += 1
91		//return &http.Response{Request: r, StatusCode: 400}, nil
92		return nil, errors.New("request failed")
93	}
94
95	tr := &failingTransport{path: "/fail", fail: fail}
96	httpClient := &http.Client{Transport: tr}
97
98	kaboom := errors.New("kaboom")
99	retrier := &testRetrier{
100		Err:     kaboom,
101		Retrier: NewBackoffRetrier(NewSimpleBackoff(100, 100, 100, 100, 100)),
102	}
103
104	client, err := NewClient(
105		SetHttpClient(httpClient),
106		SetMaxRetries(5),
107		SetHealthcheck(false),
108		SetRetrier(retrier))
109	if err != nil {
110		t.Fatal(err)
111	}
112
113	res, err := client.PerformRequest(context.TODO(), PerformRequestOptions{
114		Method: "GET",
115		Path:   "/fail",
116	})
117	if err != kaboom {
118		t.Fatalf("expected %v, got %v", kaboom, err)
119	}
120	if res != nil {
121		t.Fatal("expected no response")
122	}
123	if numFailedReqs != 1 {
124		t.Errorf("expected %d failed requests; got: %d", 1, numFailedReqs)
125	}
126	if retrier.N != 1 {
127		t.Errorf("expected %d Retrier calls; got: %d", 1, retrier.N)
128	}
129}
130
131func TestRetrierOnPerformRequest(t *testing.T) {
132	var numFailedReqs int
133	fail := func(r *http.Request) (*http.Response, error) {
134		numFailedReqs += 1
135		//return &http.Response{Request: r, StatusCode: 400}, nil
136		return nil, errors.New("request failed")
137	}
138
139	tr := &failingTransport{path: "/fail", fail: fail}
140	httpClient := &http.Client{Transport: tr}
141
142	defaultRetrier := &testRetrier{
143		Retrier: NewStopRetrier(),
144	}
145	requestRetrier := &testRetrier{
146		Retrier: NewStopRetrier(),
147	}
148
149	client, err := NewClient(
150		SetHttpClient(httpClient),
151		SetHealthcheck(false),
152		SetRetrier(defaultRetrier))
153	if err != nil {
154		t.Fatal(err)
155	}
156
157	res, err := client.PerformRequest(context.TODO(), PerformRequestOptions{
158		Method:  "GET",
159		Path:    "/fail",
160		Retrier: requestRetrier,
161	})
162	if err == nil {
163		t.Fatal("expected error")
164	}
165	if res != nil {
166		t.Fatal("expected no response")
167	}
168	if want, have := int64(0), defaultRetrier.N; want != have {
169		t.Errorf("defaultRetrier: expected %d calls; got: %d", want, have)
170	}
171	if want, have := int64(1), requestRetrier.N; want != have {
172		t.Errorf("requestRetrier: expected %d calls; got: %d", want, have)
173	}
174}
175