1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package autocert
6
7import (
8	"context"
9	"crypto/ecdsa"
10	"crypto/elliptic"
11	"crypto/rand"
12	"crypto/tls"
13	"crypto/x509"
14	"encoding/base64"
15	"fmt"
16	"net/http"
17	"net/http/httptest"
18	"testing"
19	"time"
20
21	"golang.org/x/crypto/acme"
22)
23
24func TestRenewalNext(t *testing.T) {
25	now := time.Now()
26	timeNow = func() time.Time { return now }
27	defer func() { timeNow = time.Now }()
28
29	man := &Manager{RenewBefore: 7 * 24 * time.Hour}
30	defer man.stopRenew()
31	tt := []struct {
32		expiry   time.Time
33		min, max time.Duration
34	}{
35		{now.Add(90 * 24 * time.Hour), 83*24*time.Hour - renewJitter, 83 * 24 * time.Hour},
36		{now.Add(time.Hour), 0, 1},
37		{now, 0, 1},
38		{now.Add(-time.Hour), 0, 1},
39	}
40
41	dr := &domainRenewal{m: man}
42	for i, test := range tt {
43		next := dr.next(test.expiry)
44		if next < test.min || test.max < next {
45			t.Errorf("%d: next = %v; want between %v and %v", i, next, test.min, test.max)
46		}
47	}
48}
49
50func TestRenewFromCache(t *testing.T) {
51	const domain = "example.org"
52
53	// ACME CA server stub
54	var ca *httptest.Server
55	ca = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
56		w.Header().Set("Replay-Nonce", "nonce")
57		if r.Method == "HEAD" {
58			// a nonce request
59			return
60		}
61
62		switch r.URL.Path {
63		// discovery
64		case "/":
65			if err := discoTmpl.Execute(w, ca.URL); err != nil {
66				t.Fatalf("discoTmpl: %v", err)
67			}
68		// client key registration
69		case "/new-reg":
70			w.Write([]byte("{}"))
71		// domain authorization
72		case "/new-authz":
73			w.Header().Set("Location", ca.URL+"/authz/1")
74			w.WriteHeader(http.StatusCreated)
75			w.Write([]byte(`{"status": "valid"}`))
76		// cert request
77		case "/new-cert":
78			var req struct {
79				CSR string `json:"csr"`
80			}
81			decodePayload(&req, r.Body)
82			b, _ := base64.RawURLEncoding.DecodeString(req.CSR)
83			csr, err := x509.ParseCertificateRequest(b)
84			if err != nil {
85				t.Fatalf("new-cert: CSR: %v", err)
86			}
87			der, err := dummyCert(csr.PublicKey, domain)
88			if err != nil {
89				t.Fatalf("new-cert: dummyCert: %v", err)
90			}
91			chainUp := fmt.Sprintf("<%s/ca-cert>; rel=up", ca.URL)
92			w.Header().Set("Link", chainUp)
93			w.WriteHeader(http.StatusCreated)
94			w.Write(der)
95		// CA chain cert
96		case "/ca-cert":
97			der, err := dummyCert(nil, "ca")
98			if err != nil {
99				t.Fatalf("ca-cert: dummyCert: %v", err)
100			}
101			w.Write(der)
102		default:
103			t.Errorf("unrecognized r.URL.Path: %s", r.URL.Path)
104		}
105	}))
106	defer ca.Close()
107
108	// use EC key to run faster on 386
109	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
110	if err != nil {
111		t.Fatal(err)
112	}
113	man := &Manager{
114		Prompt:      AcceptTOS,
115		Cache:       newMemCache(),
116		RenewBefore: 24 * time.Hour,
117		Client: &acme.Client{
118			Key:          key,
119			DirectoryURL: ca.URL,
120		},
121	}
122	defer man.stopRenew()
123
124	// cache an almost expired cert
125	now := time.Now()
126	cert, err := dateDummyCert(key.Public(), now.Add(-2*time.Hour), now.Add(time.Minute), domain)
127	if err != nil {
128		t.Fatal(err)
129	}
130	tlscert := &tls.Certificate{PrivateKey: key, Certificate: [][]byte{cert}}
131	if err := man.cachePut(context.Background(), domain, tlscert); err != nil {
132		t.Fatal(err)
133	}
134
135	// veriy the renewal happened
136	defer func() {
137		testDidRenewLoop = func(next time.Duration, err error) {}
138	}()
139	done := make(chan struct{})
140	testDidRenewLoop = func(next time.Duration, err error) {
141		defer close(done)
142		if err != nil {
143			t.Errorf("testDidRenewLoop: %v", err)
144		}
145		// Next should be about 90 days:
146		// dummyCert creates 90days expiry + account for man.RenewBefore.
147		// Previous expiration was within 1 min.
148		future := 88 * 24 * time.Hour
149		if next < future {
150			t.Errorf("testDidRenewLoop: next = %v; want >= %v", next, future)
151		}
152
153		// ensure the new cert is cached
154		after := time.Now().Add(future)
155		tlscert, err := man.cacheGet(context.Background(), domain)
156		if err != nil {
157			t.Fatalf("man.cacheGet: %v", err)
158		}
159		if !tlscert.Leaf.NotAfter.After(after) {
160			t.Errorf("cache leaf.NotAfter = %v; want > %v", tlscert.Leaf.NotAfter, after)
161		}
162
163		// verify the old cert is also replaced in memory
164		man.stateMu.Lock()
165		defer man.stateMu.Unlock()
166		s := man.state[domain]
167		if s == nil {
168			t.Fatalf("m.state[%q] is nil", domain)
169		}
170		tlscert, err = s.tlscert()
171		if err != nil {
172			t.Fatalf("s.tlscert: %v", err)
173		}
174		if !tlscert.Leaf.NotAfter.After(after) {
175			t.Errorf("state leaf.NotAfter = %v; want > %v", tlscert.Leaf.NotAfter, after)
176		}
177	}
178
179	// trigger renew
180	hello := &tls.ClientHelloInfo{ServerName: domain}
181	if _, err := man.GetCertificate(hello); err != nil {
182		t.Fatal(err)
183	}
184
185	// wait for renew loop
186	select {
187	case <-time.After(10 * time.Second):
188		t.Fatal("renew took too long to occur")
189	case <-done:
190	}
191}
192
193func TestRenewFromCacheAlreadyRenewed(t *testing.T) {
194	const domain = "example.org"
195
196	// use EC key to run faster on 386
197	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
198	if err != nil {
199		t.Fatal(err)
200	}
201	man := &Manager{
202		Prompt:      AcceptTOS,
203		Cache:       newMemCache(),
204		RenewBefore: 24 * time.Hour,
205		Client: &acme.Client{
206			Key:          key,
207			DirectoryURL: "invalid",
208		},
209	}
210	defer man.stopRenew()
211
212	// cache a recently renewed cert with a different private key
213	newKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
214	if err != nil {
215		t.Fatal(err)
216	}
217	now := time.Now()
218	newCert, err := dateDummyCert(newKey.Public(), now.Add(-2*time.Hour), now.Add(time.Hour*24*90), domain)
219	if err != nil {
220		t.Fatal(err)
221	}
222	newLeaf, err := validCert(domain, [][]byte{newCert}, newKey)
223	if err != nil {
224		t.Fatal(err)
225	}
226	newTLSCert := &tls.Certificate{PrivateKey: newKey, Certificate: [][]byte{newCert}, Leaf: newLeaf}
227	if err := man.cachePut(context.Background(), domain, newTLSCert); err != nil {
228		t.Fatal(err)
229	}
230
231	// set internal state to an almost expired cert
232	oldCert, err := dateDummyCert(key.Public(), now.Add(-2*time.Hour), now.Add(time.Minute), domain)
233	if err != nil {
234		t.Fatal(err)
235	}
236	oldLeaf, err := validCert(domain, [][]byte{oldCert}, key)
237	if err != nil {
238		t.Fatal(err)
239	}
240	man.stateMu.Lock()
241	if man.state == nil {
242		man.state = make(map[string]*certState)
243	}
244	s := &certState{
245		key:  key,
246		cert: [][]byte{oldCert},
247		leaf: oldLeaf,
248	}
249	man.state[domain] = s
250	man.stateMu.Unlock()
251
252	// veriy the renewal accepted the newer cached cert
253	defer func() {
254		testDidRenewLoop = func(next time.Duration, err error) {}
255	}()
256	done := make(chan struct{})
257	testDidRenewLoop = func(next time.Duration, err error) {
258		defer close(done)
259		if err != nil {
260			t.Errorf("testDidRenewLoop: %v", err)
261		}
262		// Next should be about 90 days
263		// Previous expiration was within 1 min.
264		future := 88 * 24 * time.Hour
265		if next < future {
266			t.Errorf("testDidRenewLoop: next = %v; want >= %v", next, future)
267		}
268
269		// ensure the cached cert was not modified
270		tlscert, err := man.cacheGet(context.Background(), domain)
271		if err != nil {
272			t.Fatalf("man.cacheGet: %v", err)
273		}
274		if !tlscert.Leaf.NotAfter.Equal(newLeaf.NotAfter) {
275			t.Errorf("cache leaf.NotAfter = %v; want == %v", tlscert.Leaf.NotAfter, newLeaf.NotAfter)
276		}
277
278		// verify the old cert is also replaced in memory
279		man.stateMu.Lock()
280		defer man.stateMu.Unlock()
281		s := man.state[domain]
282		if s == nil {
283			t.Fatalf("m.state[%q] is nil", domain)
284		}
285		stateKey := s.key.Public().(*ecdsa.PublicKey)
286		if stateKey.X.Cmp(newKey.X) != 0 || stateKey.Y.Cmp(newKey.Y) != 0 {
287			t.Fatalf("state key was not updated from cache x: %v y: %v; want x: %v y: %v", stateKey.X, stateKey.Y, newKey.X, newKey.Y)
288		}
289		tlscert, err = s.tlscert()
290		if err != nil {
291			t.Fatalf("s.tlscert: %v", err)
292		}
293		if !tlscert.Leaf.NotAfter.Equal(newLeaf.NotAfter) {
294			t.Errorf("state leaf.NotAfter = %v; want == %v", tlscert.Leaf.NotAfter, newLeaf.NotAfter)
295		}
296
297		// verify the private key is replaced in the renewal state
298		r := man.renewal[domain]
299		if r == nil {
300			t.Fatalf("m.renewal[%q] is nil", domain)
301		}
302		renewalKey := r.key.Public().(*ecdsa.PublicKey)
303		if renewalKey.X.Cmp(newKey.X) != 0 || renewalKey.Y.Cmp(newKey.Y) != 0 {
304			t.Fatalf("renewal private key was not updated from cache x: %v y: %v; want x: %v y: %v", renewalKey.X, renewalKey.Y, newKey.X, newKey.Y)
305		}
306
307	}
308
309	// assert the expiring cert is returned from state
310	hello := &tls.ClientHelloInfo{ServerName: domain}
311	tlscert, err := man.GetCertificate(hello)
312	if err != nil {
313		t.Fatal(err)
314	}
315	if !oldLeaf.NotAfter.Equal(tlscert.Leaf.NotAfter) {
316		t.Errorf("state leaf.NotAfter = %v; want == %v", tlscert.Leaf.NotAfter, oldLeaf.NotAfter)
317	}
318
319	// trigger renew
320	go man.renew(domain, s.key, s.leaf.NotAfter)
321
322	// wait for renew loop
323	select {
324	case <-time.After(10 * time.Second):
325		t.Fatal("renew took too long to occur")
326	case <-done:
327		// assert the new cert is returned from state after renew
328		hello := &tls.ClientHelloInfo{ServerName: domain}
329		tlscert, err := man.GetCertificate(hello)
330		if err != nil {
331			t.Fatal(err)
332		}
333		if !newTLSCert.Leaf.NotAfter.Equal(tlscert.Leaf.NotAfter) {
334			t.Errorf("state leaf.NotAfter = %v; want == %v", tlscert.Leaf.NotAfter, newTLSCert.Leaf.NotAfter)
335		}
336	}
337}
338