1package dns01
2
3import (
4	"crypto/rand"
5	"crypto/rsa"
6	"errors"
7	"net/http"
8	"testing"
9	"time"
10
11	"github.com/go-acme/lego/v3/acme"
12	"github.com/go-acme/lego/v3/acme/api"
13	"github.com/go-acme/lego/v3/challenge"
14	"github.com/go-acme/lego/v3/platform/tester"
15	"github.com/stretchr/testify/require"
16)
17
// providerMock is a minimal provider stub for the tests in this file.
// Present and CleanUp ignore their arguments and return whatever error the
// corresponding field holds (a nil field means the call succeeds).
type providerMock struct {
	// present is the value returned by every Present call.
	present error
	// cleanUp is the value returned by every CleanUp call.
	cleanUp error
}

// Present returns the configured present error without side effects.
func (m *providerMock) Present(domain, token, keyAuth string) error {
	return m.present
}

// CleanUp returns the configured cleanUp error without side effects.
func (m *providerMock) CleanUp(domain, token, keyAuth string) error {
	return m.cleanUp
}
24
// providerTimeoutMock is a provider stub that, in addition to returning the
// configured errors from Present and CleanUp, exposes a Timeout method
// reporting a fixed timeout and interval. NOTE(review): presumably the
// challenge code consults Timeout to bound how long it waits; confirm
// against the Challenge implementation.
type providerTimeoutMock struct {
	// present and cleanUp are returned verbatim by Present and CleanUp.
	present error
	cleanUp error

	// timeout and interval are returned verbatim by Timeout.
	timeout  time.Duration
	interval time.Duration
}

// Present returns the configured present error without side effects.
func (m *providerTimeoutMock) Present(domain, token, keyAuth string) error {
	return m.present
}

// CleanUp returns the configured cleanUp error without side effects.
func (m *providerTimeoutMock) CleanUp(domain, token, keyAuth string) error {
	return m.cleanUp
}

// Timeout returns the configured timeout followed by the configured interval.
func (m *providerTimeoutMock) Timeout() (timeout, interval time.Duration) {
	return m.timeout, m.interval
}
33
34func TestChallenge_PreSolve(t *testing.T) {
35	_, apiURL, tearDown := tester.SetupFakeAPI()
36	defer tearDown()
37
38	privateKey, err := rsa.GenerateKey(rand.Reader, 512)
39	require.NoError(t, err)
40
41	core, err := api.New(http.DefaultClient, "lego-test", apiURL+"/dir", "", privateKey)
42	require.NoError(t, err)
43
44	testCases := []struct {
45		desc        string
46		validate    ValidateFunc
47		preCheck    WrapPreCheckFunc
48		provider    challenge.Provider
49		expectError bool
50	}{
51		{
52			desc:     "success",
53			validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
54			preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
55			provider: &providerMock{},
56		},
57		{
58			desc:     "validate fail",
59			validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") },
60			preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
61			provider: &providerMock{
62				present: nil,
63				cleanUp: nil,
64			},
65		},
66		{
67			desc:     "preCheck fail",
68			validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
69			preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return false, errors.New("OOPS") },
70			provider: &providerTimeoutMock{
71				timeout:  2 * time.Second,
72				interval: 500 * time.Millisecond,
73			},
74		},
75		{
76			desc:     "present fail",
77			validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
78			preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
79			provider: &providerMock{
80				present: errors.New("OOPS"),
81			},
82			expectError: true,
83		},
84		{
85			desc:     "cleanUp fail",
86			validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
87			preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
88			provider: &providerMock{
89				cleanUp: errors.New("OOPS"),
90			},
91		},
92	}
93
94	for _, test := range testCases {
95		t.Run(test.desc, func(t *testing.T) {
96			chlg := NewChallenge(core, test.validate, test.provider, WrapPreCheck(test.preCheck))
97
98			authz := acme.Authorization{
99				Identifier: acme.Identifier{
100					Value: "example.com",
101				},
102				Challenges: []acme.Challenge{
103					{Type: challenge.DNS01.String()},
104				},
105			}
106
107			err = chlg.PreSolve(authz)
108			if test.expectError {
109				require.Error(t, err)
110			} else {
111				require.NoError(t, err)
112			}
113		})
114	}
115}
116
117func TestChallenge_Solve(t *testing.T) {
118	_, apiURL, tearDown := tester.SetupFakeAPI()
119	defer tearDown()
120
121	privateKey, err := rsa.GenerateKey(rand.Reader, 512)
122	require.NoError(t, err)
123
124	core, err := api.New(http.DefaultClient, "lego-test", apiURL+"/dir", "", privateKey)
125	require.NoError(t, err)
126
127	testCases := []struct {
128		desc        string
129		validate    ValidateFunc
130		preCheck    WrapPreCheckFunc
131		provider    challenge.Provider
132		expectError bool
133	}{
134		{
135			desc:     "success",
136			validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
137			preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
138			provider: &providerMock{},
139		},
140		{
141			desc:     "validate fail",
142			validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") },
143			preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
144			provider: &providerMock{
145				present: nil,
146				cleanUp: nil,
147			},
148			expectError: true,
149		},
150		{
151			desc:     "preCheck fail",
152			validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
153			preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return false, errors.New("OOPS") },
154			provider: &providerTimeoutMock{
155				timeout:  2 * time.Second,
156				interval: 500 * time.Millisecond,
157			},
158			expectError: true,
159		},
160		{
161			desc:     "present fail",
162			validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
163			preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
164			provider: &providerMock{
165				present: errors.New("OOPS"),
166			},
167		},
168		{
169			desc:     "cleanUp fail",
170			validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
171			preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
172			provider: &providerMock{
173				cleanUp: errors.New("OOPS"),
174			},
175		},
176	}
177
178	for _, test := range testCases {
179		t.Run(test.desc, func(t *testing.T) {
180			var options []ChallengeOption
181			if test.preCheck != nil {
182				options = append(options, WrapPreCheck(test.preCheck))
183			}
184			chlg := NewChallenge(core, test.validate, test.provider, options...)
185
186			authz := acme.Authorization{
187				Identifier: acme.Identifier{
188					Value: "example.com",
189				},
190				Challenges: []acme.Challenge{
191					{Type: challenge.DNS01.String()},
192				},
193			}
194
195			err = chlg.Solve(authz)
196			if test.expectError {
197				require.Error(t, err)
198			} else {
199				require.NoError(t, err)
200			}
201		})
202	}
203}
204
205func TestChallenge_CleanUp(t *testing.T) {
206	_, apiURL, tearDown := tester.SetupFakeAPI()
207	defer tearDown()
208
209	privateKey, err := rsa.GenerateKey(rand.Reader, 512)
210	require.NoError(t, err)
211
212	core, err := api.New(http.DefaultClient, "lego-test", apiURL+"/dir", "", privateKey)
213	require.NoError(t, err)
214
215	testCases := []struct {
216		desc        string
217		validate    ValidateFunc
218		preCheck    WrapPreCheckFunc
219		provider    challenge.Provider
220		expectError bool
221	}{
222		{
223			desc:     "success",
224			validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
225			preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
226			provider: &providerMock{},
227		},
228		{
229			desc:     "validate fail",
230			validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") },
231			preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
232			provider: &providerMock{
233				present: nil,
234				cleanUp: nil,
235			},
236		},
237		{
238			desc:     "preCheck fail",
239			validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
240			preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return false, errors.New("OOPS") },
241			provider: &providerTimeoutMock{
242				timeout:  2 * time.Second,
243				interval: 500 * time.Millisecond,
244			},
245		},
246		{
247			desc:     "present fail",
248			validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
249			preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
250			provider: &providerMock{
251				present: errors.New("OOPS"),
252			},
253		},
254		{
255			desc:     "cleanUp fail",
256			validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
257			preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
258			provider: &providerMock{
259				cleanUp: errors.New("OOPS"),
260			},
261			expectError: true,
262		},
263	}
264
265	for _, test := range testCases {
266		t.Run(test.desc, func(t *testing.T) {
267			chlg := NewChallenge(core, test.validate, test.provider, WrapPreCheck(test.preCheck))
268
269			authz := acme.Authorization{
270				Identifier: acme.Identifier{
271					Value: "example.com",
272				},
273				Challenges: []acme.Challenge{
274					{Type: challenge.DNS01.String()},
275				},
276			}
277
278			err = chlg.CleanUp(authz)
279			if test.expectError {
280				require.Error(t, err)
281			} else {
282				require.NoError(t, err)
283			}
284		})
285	}
286}
287