1package try
2
3import (
4	"errors"
5	"fmt"
6	"io"
7	"net/http"
8	"reflect"
9	"strings"
10
11	"github.com/kvtools/valkeyrie/store"
12)
13
// ResponseCondition is a retry condition evaluated against an HTTP response.
// It receives the response and returns a non-nil error if the response failed
// the condition.
type ResponseCondition func(*http.Response) error
17
18// BodyContains returns a retry condition function.
19// The condition returns an error if the request body does not contain all the given strings.
20func BodyContains(values ...string) ResponseCondition {
21	return func(res *http.Response) error {
22		body, err := io.ReadAll(res.Body)
23		if err != nil {
24			return fmt.Errorf("failed to read response body: %w", err)
25		}
26
27		for _, value := range values {
28			if !strings.Contains(string(body), value) {
29				return fmt.Errorf("could not find '%s' in body '%s'", value, string(body))
30			}
31		}
32		return nil
33	}
34}
35
36// BodyNotContains returns a retry condition function.
37// The condition returns an error if the request body  contain one of the given strings.
38func BodyNotContains(values ...string) ResponseCondition {
39	return func(res *http.Response) error {
40		body, err := io.ReadAll(res.Body)
41		if err != nil {
42			return fmt.Errorf("failed to read response body: %w", err)
43		}
44
45		for _, value := range values {
46			if strings.Contains(string(body), value) {
47				return fmt.Errorf("find '%s' in body '%s'", value, string(body))
48			}
49		}
50		return nil
51	}
52}
53
54// BodyContainsOr returns a retry condition function.
55// The condition returns an error if the request body does not contain one of the given strings.
56func BodyContainsOr(values ...string) ResponseCondition {
57	return func(res *http.Response) error {
58		body, err := io.ReadAll(res.Body)
59		if err != nil {
60			return fmt.Errorf("failed to read response body: %w", err)
61		}
62
63		for _, value := range values {
64			if strings.Contains(string(body), value) {
65				return nil
66			}
67		}
68		return fmt.Errorf("could not find '%v' in body '%s'", values, string(body))
69	}
70}
71
72// HasBody returns a retry condition function.
73// The condition returns an error if the request body does not have body content.
74func HasBody() ResponseCondition {
75	return func(res *http.Response) error {
76		body, err := io.ReadAll(res.Body)
77		if err != nil {
78			return fmt.Errorf("failed to read response body: %w", err)
79		}
80
81		if len(body) == 0 {
82			return errors.New("response doesn't have body content")
83		}
84		return nil
85	}
86}
87
88// HasCn returns a retry condition function.
89// The condition returns an error if the cn is not correct.
90func HasCn(cn string) ResponseCondition {
91	return func(res *http.Response) error {
92		if res.TLS == nil {
93			return errors.New("response doesn't have TLS")
94		}
95
96		if len(res.TLS.PeerCertificates) == 0 {
97			return errors.New("response TLS doesn't have peer certificates")
98		}
99
100		if res.TLS.PeerCertificates[0] == nil {
101			return errors.New("first peer certificate is nil")
102		}
103
104		commonName := res.TLS.PeerCertificates[0].Subject.CommonName
105		if cn != commonName {
106			return fmt.Errorf("common name don't match: %s != %s", cn, commonName)
107		}
108
109		return nil
110	}
111}
112
113// StatusCodeIs returns a retry condition function.
114// The condition returns an error if the given response's status code is not the
115// given HTTP status code.
116func StatusCodeIs(status int) ResponseCondition {
117	return func(res *http.Response) error {
118		if res.StatusCode != status {
119			return fmt.Errorf("got status code %d, wanted %d", res.StatusCode, status)
120		}
121		return nil
122	}
123}
124
125// HasHeader returns a retry condition function.
126// The condition returns an error if the response does not have a header set.
127func HasHeader(header string) ResponseCondition {
128	return func(res *http.Response) error {
129		if _, ok := res.Header[header]; !ok {
130			return errors.New("response doesn't contain header: " + header)
131		}
132		return nil
133	}
134}
135
136// HasHeaderValue returns a retry condition function.
137// The condition returns an error if the response does not have a header set, and a value for that header.
138// Has an option to test for an exact header match only, not just contains.
139func HasHeaderValue(header, value string, exactMatch bool) ResponseCondition {
140	return func(res *http.Response) error {
141		if _, ok := res.Header[header]; !ok {
142			return errors.New("response doesn't contain header: " + header)
143		}
144
145		matchFound := false
146		for _, hdr := range res.Header[header] {
147			if value != hdr && exactMatch {
148				return fmt.Errorf("got header %s with value %s, wanted %s", header, hdr, value)
149			}
150			if value == hdr {
151				matchFound = true
152			}
153		}
154
155		if !matchFound {
156			return fmt.Errorf("response doesn't contain header %s with value %s", header, value)
157		}
158		return nil
159	}
160}
161
162// HasHeaderStruct returns a retry condition function.
163// The condition returns an error if the response does contain the headers set, and matching contents.
164func HasHeaderStruct(header http.Header) ResponseCondition {
165	return func(res *http.Response) error {
166		for key := range header {
167			if _, ok := res.Header[key]; !ok {
168				return fmt.Errorf("header %s not present in the response. Expected headers: %v Got response headers: %v", key, header, res.Header)
169			}
170
171			// Header exists in the response, test it.
172			if !reflect.DeepEqual(header[key], res.Header[key]) {
173				return fmt.Errorf("for header %s got values %v, wanted %v", key, res.Header[key], header[key])
174			}
175		}
176		return nil
177	}
178}
179
// DoCondition is a retry condition function.
// It takes no arguments and reports its outcome solely through the returned
// error.
type DoCondition func() error
183
184// KVExists is a retry condition function.
185// Verify if a Key exists in the store.
186func KVExists(kv store.Store, key string) DoCondition {
187	return func() error {
188		_, err := kv.Exists(key, nil)
189		return err
190	}
191}
192