1package wait
2
3import (
4	"context"
5	"database/sql"
6	"fmt"
7	"github.com/docker/go-connections/nat"
8	"time"
9)
10
11//ForSQL constructs a new waitForSql strategy for the given driver
12func ForSQL(port nat.Port, driver string, url func(nat.Port) string) *waitForSql {
13	return &waitForSql{
14		Port:   port,
15		URL:    url,
16		Driver: driver,
17	}
18}
19
// waitForSql is a wait strategy that treats the target as ready once a
// "SELECT 1" query succeeds over a database/sql connection.
type waitForSql struct {
	// URL builds the data source name for sql.Open from the mapped port.
	URL            func(port nat.Port) string
	// Driver is the database/sql driver name passed to sql.Open.
	Driver         string
	// Port is the container port that is resolved via target.MappedPort
	// before connecting.
	Port           nat.Port
	// startupTimeout bounds how long WaitUntilReady waits; zero means
	// WaitUntilReady falls back to a 10-second default.
	startupTimeout time.Duration
}
26
27//Timeout sets the maximum waiting time for the strategy after which it'll give up and return an error
28func (w *waitForSql) Timeout(duration time.Duration) *waitForSql {
29	w.startupTimeout = duration
30	return w
31}
32
33//WaitUntilReady repeatedly tries to run "SELECT 1" query on the given port using sql and driver.
34// If the it doesn't succeed until the timeout value which defaults to 10 seconds, it will return an error
35func (w *waitForSql) WaitUntilReady(ctx context.Context, target StrategyTarget) (err error) {
36	if w.startupTimeout == 0 {
37		w.startupTimeout = time.Second * 10
38	}
39	ctx, cancel := context.WithTimeout(ctx, w.startupTimeout)
40	defer cancel()
41
42	ticker := time.NewTicker(time.Millisecond * 100)
43	defer ticker.Stop()
44
45	port, err := target.MappedPort(ctx, w.Port)
46	if err != nil {
47		return fmt.Errorf("target.MappedPort: %v", err)
48	}
49
50	db, err := sql.Open(w.Driver, w.URL(port))
51	if err != nil {
52		return fmt.Errorf("sql.Open: %v", err)
53	}
54	for {
55		select {
56		case <-ctx.Done():
57			return ctx.Err()
58		case <-ticker.C:
59
60			if _, err := db.ExecContext(ctx, "SELECT 1"); err != nil {
61				continue
62			}
63			return nil
64		}
65	}
66}
67