1package wait
2
3import (
4	"context"
5	"crypto/tls"
6	"errors"
7	"fmt"
8	"net"
9	"net/http"
10	"strconv"
11	"time"
12
13	"github.com/docker/go-connections/nat"
14)
15
// Compile-time check that *HTTPStrategy satisfies the Strategy interface.
var _ Strategy = (*HTTPStrategy)(nil)
18
// HTTPStrategy is a wait Strategy that polls an HTTP(S) endpoint with GET
// requests until the response status code is accepted by StatusCodeMatcher.
type HTTPStrategy struct {
	// all Strategies should have a startupTimeout to avoid waiting infinitely
	// (used both as the deadline of the wait context and as the HTTP client timeout)
	startupTimeout time.Duration

	// additional properties

	// Port is the container port that is resolved to its mapped port before probing; it must be a TCP port
	Port nat.Port
	// Path is appended verbatim after "host:port" when the probe URL is built
	Path string
	// StatusCodeMatcher receives the response status code and returns true when it counts as ready
	StatusCodeMatcher func(status int) bool
	// UseTLS selects the "https" scheme instead of "http"
	UseTLS bool
	// AllowInsecure skips TLS certificate verification (InsecureSkipVerify) when probing
	AllowInsecure bool
}
30
31// NewHTTPStrategy constructs a HTTP strategy waiting on port 80 and status code 200
32func NewHTTPStrategy(path string) *HTTPStrategy {
33	return &HTTPStrategy{
34		startupTimeout:    defaultStartupTimeout(),
35		Port:              "80/tcp",
36		Path:              path,
37		StatusCodeMatcher: defaultStatusCodeMatcher,
38		UseTLS:            false,
39	}
40
41}
42
// defaultStatusCodeMatcher reports whether status is exactly 200 (http.StatusOK).
func defaultStatusCodeMatcher(status int) bool {
	switch status {
	case http.StatusOK:
		return true
	default:
		return false
	}
}
46
// Fluent builders for each property.
// Since Go has neither covariance nor generics, the return type must be the type of the concrete implementation.
// This is true for all properties, even the "shared" ones like startupTimeout.
50
// WithStartupTimeout sets the maximum time WaitUntilReady may spend waiting
// and returns the same strategy so calls can be chained.
func (ws *HTTPStrategy) WithStartupTimeout(startupTimeout time.Duration) *HTTPStrategy {
	ws.startupTimeout = startupTimeout
	return ws
}
55
// WithPort sets the container port to probe and returns the same strategy so
// calls can be chained.
func (ws *HTTPStrategy) WithPort(port nat.Port) *HTTPStrategy {
	ws.Port = port
	return ws
}
60
// WithStatusCodeMatcher replaces the function that decides whether a response
// status code means the target is ready, and returns the same strategy so
// calls can be chained.
func (ws *HTTPStrategy) WithStatusCodeMatcher(statusCodeMatcher func(status int) bool) *HTTPStrategy {
	ws.StatusCodeMatcher = statusCodeMatcher
	return ws
}
65
// WithTLS selects whether the probe uses "https" (true) or "http" (false) and
// returns the same strategy so calls can be chained.
func (ws *HTTPStrategy) WithTLS(useTLS bool) *HTTPStrategy {
	ws.UseTLS = useTLS
	return ws
}
70
// WithAllowInsecure selects whether TLS certificate verification is skipped
// when probing, and returns the same strategy so calls can be chained.
func (ws *HTTPStrategy) WithAllowInsecure(allowInsecure bool) *HTTPStrategy {
	ws.AllowInsecure = allowInsecure
	return ws
}
75
// ForHTTP is a convenience constructor similar to the helpers in Wait.java of
// testcontainers-java; it simply returns NewHTTPStrategy(path).
// https://github.com/testcontainers/testcontainers-java/blob/1d85a3834bd937f80aad3a4cec249c027f31aeb4/core/src/main/java/org/testcontainers/containers/wait/strategy/Wait.java
func ForHTTP(path string) *HTTPStrategy {
	return NewHTTPStrategy(path)
}
81
82// WaitUntilReady implements Strategy.WaitUntilReady
83func (ws *HTTPStrategy) WaitUntilReady(ctx context.Context, target StrategyTarget) (err error) {
84	// limit context to startupTimeout
85	ctx, cancelContext := context.WithTimeout(ctx, ws.startupTimeout)
86	defer cancelContext()
87
88	ipAddress, err := target.Host(ctx)
89	if err != nil {
90		return
91	}
92
93	port, err := target.MappedPort(ctx, ws.Port)
94	if err != nil {
95		return
96	}
97
98	if port.Proto() != "tcp" {
99		return errors.New("Cannot use HTTP client on non-TCP ports")
100	}
101
102	portNumber := port.Int()
103	portString := strconv.Itoa(portNumber)
104
105	address := net.JoinHostPort(ipAddress, portString)
106
107	var proto string
108	if ws.UseTLS {
109		proto = "https"
110	} else {
111		proto = "http"
112	}
113
114	url := fmt.Sprintf("%s://%s%s", proto, address, ws.Path)
115
116	tripper := http.DefaultTransport
117
118	if ws.AllowInsecure {
119		tripper.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
120	}
121
122	client := http.Client{Timeout: ws.startupTimeout, Transport: tripper}
123	req, err := http.NewRequest("GET", url, nil)
124	if err != nil {
125		return err
126	}
127
128	req = req.WithContext(ctx)
129
130Retry:
131	for {
132		select {
133		case <-ctx.Done():
134			break Retry
135		default:
136			resp, err := client.Do(req)
137			if err != nil || !ws.StatusCodeMatcher(resp.StatusCode) {
138				time.Sleep(100 * time.Millisecond)
139				continue
140			}
141
142			break Retry
143		}
144	}
145
146	return nil
147}
148