1package sessionresolver
2
3import (
4	"context"
5	"errors"
6	"io"
7	"testing"
8	"time"
9
10	"github.com/google/go-cmp/cmp"
11)
12
13type FakeResolver struct {
14	Closed bool
15	Data   []string
16	Err    error
17	Sleep  time.Duration
18}
19
// LookupHost simulates a DNS lookup. It waits for r.Sleep and then returns
// r.Data and r.Err; if ctx is done before the wait ends, it returns nil and
// ctx.Err(). The hostname argument is ignored.
func (r *FakeResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
	// Use an explicit timer and stop it on return. time.After would keep the
	// timer alive until it fires (up to r.Sleep) even when ctx wins the race,
	// which on Go versions before 1.23 retains it long after this call returns
	// for the long sleeps used by the timeout test.
	timer := time.NewTimer(r.Sleep)
	defer timer.Stop()
	select {
	case <-timer.C:
		return r.Data, r.Err
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
28
29func (r *FakeResolver) CloseIdleConnections() {
30	r.Closed = true
31}
32
// TestTimeLimitedLookupSuccess checks that timeLimitedLookup passes through
// the addresses returned by a resolver that answers immediately without error.
func TestTimeLimitedLookupSuccess(t *testing.T) {
	fake := &FakeResolver{
		Data: []string{"8.8.8.8", "8.8.4.4"},
	}
	reso := &Resolver{}
	out, err := reso.timeLimitedLookup(context.Background(), fake, "dns.google")
	if err != nil {
		t.Fatal(err)
	}
	if diff := cmp.Diff(fake.Data, out); diff != "" {
		t.Fatal(diff)
	}
}
47
// TestTimeLimitedLookupFailure checks that an error from the underlying
// resolver is propagated by timeLimitedLookup and that no addresses are
// returned alongside it.
func TestTimeLimitedLookupFailure(t *testing.T) {
	fake := &FakeResolver{Err: io.EOF}
	out, err := (&Resolver{}).timeLimitedLookup(context.Background(), fake, "dns.google")
	// t.Fatal stops the test, so the cases are checked in order exactly like
	// two sequential if statements.
	switch {
	case !errors.Is(err, fake.Err):
		t.Fatal("not the error we expected", err)
	case out != nil:
		t.Fatal("expected nil here")
	}
}
62
// TestTimeLimitedLookupWillTimeout checks that timeLimitedLookup gives up on a
// resolver that sleeps for 20 seconds, reporting context.DeadlineExceeded and
// no addresses. It assumes timeLimitedLookup's internal limit is shorter than
// that sleep (TODO confirm against its implementation). Because the wait is
// real wall-clock time, the test is skipped in -short mode.
func TestTimeLimitedLookupWillTimeout(t *testing.T) {
	if testing.Short() {
		t.Skip("skip test in short mode")
	}
	// The fake eventually fails with io.EOF; seeing that error instead of
	// DeadlineExceeded would mean the time limit never triggered.
	fake := &FakeResolver{
		Sleep: 20 * time.Second,
		Err:   io.EOF,
	}
	reso := &Resolver{}
	out, err := reso.timeLimitedLookup(context.Background(), fake, "dns.google")
	switch {
	case !errors.Is(err, context.DeadlineExceeded):
		t.Fatal("not the error we expected", err)
	case out != nil:
		t.Fatal("expected nil here")
	}
}
81