1// +build go1.12
2
3/*
4 *
5 * Copyright 2020 gRPC authors.
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 *     http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 *
19 */
20
21package pemfile
22
23import (
24	"context"
25	"fmt"
26	"io/ioutil"
27	"math/big"
28	"os"
29	"path"
30	"testing"
31	"time"
32
33	"github.com/google/go-cmp/cmp"
34	"github.com/google/go-cmp/cmp/cmpopts"
35
36	"google.golang.org/grpc/credentials/tls/certprovider"
37	"google.golang.org/grpc/internal/grpctest"
38	"google.golang.org/grpc/internal/testutils"
39	"google.golang.org/grpc/testdata"
40)
41
// Names of the files created inside the temporary directories which the
// plugin under test is asked to watch.
const (
	certFile = "cert.pem"
	keyFile  = "key.pem"
	rootFile = "ca.pem"
)

// Timing knobs shared by the tests in this file.
const (
	// defaultTestRefreshDuration is the RefreshDuration configured on the
	// providers created by these tests. Checks for the absence of an update
	// wait for two of these intervals.
	defaultTestRefreshDuration = 100 * time.Millisecond
	// defaultTestTimeout bounds how long a test waits for an expected event.
	defaultTestTimeout = 5 * time.Second
)
52
53type s struct {
54	grpctest.Tester
55}
56
57func Test(t *testing.T) {
58	grpctest.RunSubTests(t, s{})
59}
60
61func compareKeyMaterial(got, want *certprovider.KeyMaterial) error {
62	// x509.Certificate type defines an Equal() method, but does not check for
63	// nil. This has been fixed in
64	// https://github.com/golang/go/commit/89865f8ba64ccb27f439cce6daaa37c9aa38f351,
65	// but this is only available starting go1.14.
66	// TODO(easwars): Remove this check once we remove support for go1.13.
67	if (got.Certs == nil && want.Certs != nil) || (want.Certs == nil && got.Certs != nil) {
68		return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want)
69	}
70	if !cmp.Equal(got.Certs, want.Certs, cmp.AllowUnexported(big.Int{})) {
71		return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want)
72	}
73	// x509.CertPool contains only unexported fields some of which contain other
74	// unexported fields. So usage of cmp.AllowUnexported() or
75	// cmpopts.IgnoreUnexported() does not help us much here. Also, the standard
76	// library does not provide a way to compare CertPool values. Comparing the
77	// subjects field of the certs in the CertPool seems like a reasonable
78	// approach.
79	if gotR, wantR := got.Roots.Subjects(), want.Roots.Subjects(); !cmp.Equal(gotR, wantR, cmpopts.EquateEmpty()) {
80		return fmt.Errorf("keyMaterial roots = %v, want %v", gotR, wantR)
81	}
82	return nil
83}
84
85// TestNewProvider tests the NewProvider() function with different inputs.
86func (s) TestNewProvider(t *testing.T) {
87	tests := []struct {
88		desc      string
89		options   Options
90		wantError bool
91	}{
92		{
93			desc:      "No credential files specified",
94			options:   Options{},
95			wantError: true,
96		},
97		{
98			desc: "Only identity cert is specified",
99			options: Options{
100				CertFile: testdata.Path("x509/client1_cert.pem"),
101			},
102			wantError: true,
103		},
104		{
105			desc: "Only identity key is specified",
106			options: Options{
107				KeyFile: testdata.Path("x509/client1_key.pem"),
108			},
109			wantError: true,
110		},
111		{
112			desc: "Identity cert/key pair is specified",
113			options: Options{
114				KeyFile:  testdata.Path("x509/client1_key.pem"),
115				CertFile: testdata.Path("x509/client1_cert.pem"),
116			},
117		},
118		{
119			desc: "Only root certs are specified",
120			options: Options{
121				RootFile: testdata.Path("x509/client_ca_cert.pem"),
122			},
123		},
124		{
125			desc: "Everything is specified",
126			options: Options{
127				KeyFile:  testdata.Path("x509/client1_key.pem"),
128				CertFile: testdata.Path("x509/client1_cert.pem"),
129				RootFile: testdata.Path("x509/client_ca_cert.pem"),
130			},
131			wantError: false,
132		},
133	}
134	for _, test := range tests {
135		t.Run(test.desc, func(t *testing.T) {
136			provider, err := NewProvider(test.options)
137			if (err != nil) != test.wantError {
138				t.Fatalf("NewProvider(%v) = %v, want %v", test.options, err, test.wantError)
139			}
140			if err != nil {
141				return
142			}
143			provider.Close()
144		})
145	}
146}
147
148// wrappedDistributor wraps a distributor and pushes on a channel whenever new
149// key material is pushed to the distributor.
150type wrappedDistributor struct {
151	*certprovider.Distributor
152	distCh *testutils.Channel
153}
154
155func newWrappedDistributor(distCh *testutils.Channel) *wrappedDistributor {
156	return &wrappedDistributor{
157		distCh:      distCh,
158		Distributor: certprovider.NewDistributor(),
159	}
160}
161
162func (wd *wrappedDistributor) Set(km *certprovider.KeyMaterial, err error) {
163	wd.Distributor.Set(km, err)
164	wd.distCh.Send(nil)
165}
166
// createTmpFile copies the contents of the file at src into dst, creating dst
// or truncating it if it already exists. Any failure is fatal to the test. The
// destination path and the copied contents are logged to aid debugging.
//
// A newly created dst gets 0600 permissions. The files hold certificates and
// private keys that only this test process needs to read, so there is no
// reason to make them executable or writable by others, as os.ModePerm (0777)
// would.
func createTmpFile(t *testing.T, src, dst string) {
	t.Helper()

	data, err := ioutil.ReadFile(src)
	if err != nil {
		t.Fatalf("ioutil.ReadFile(%q) failed: %v", src, err)
	}
	if err := ioutil.WriteFile(dst, data, 0600); err != nil {
		t.Fatalf("ioutil.WriteFile(%q) failed: %v", dst, err)
	}
	t.Logf("Wrote file at: %s", dst)
	t.Logf("%s", data)
}
180
181// createTempDirWithFiles creates a temporary directory under the system default
182// tempDir with the given dirSuffix. It also reads from certSrc, keySrc and
183// rootSrc files are creates appropriate files under the newly create tempDir.
184// Returns the name of the created tempDir.
185func createTmpDirWithFiles(t *testing.T, dirSuffix, certSrc, keySrc, rootSrc string) string {
186	t.Helper()
187
188	// Create a temp directory. Passing an empty string for the first argument
189	// uses the system temp directory.
190	dir, err := ioutil.TempDir("", dirSuffix)
191	if err != nil {
192		t.Fatalf("ioutil.TempDir() failed: %v", err)
193	}
194	t.Logf("Using tmpdir: %s", dir)
195
196	createTmpFile(t, testdata.Path(certSrc), path.Join(dir, certFile))
197	createTmpFile(t, testdata.Path(keySrc), path.Join(dir, keyFile))
198	createTmpFile(t, testdata.Path(rootSrc), path.Join(dir, rootFile))
199	return dir
200}
201
202// initializeProvider performs setup steps common to all tests (except the one
203// which uses symlinks).
204func initializeProvider(t *testing.T, testName string) (string, certprovider.Provider, *testutils.Channel, func()) {
205	t.Helper()
206
207	// Override the newDistributor to one which pushes on a channel that we
208	// can block on.
209	origDistributorFunc := newDistributor
210	distCh := testutils.NewChannel()
211	d := newWrappedDistributor(distCh)
212	newDistributor = func() distributor { return d }
213
214	// Create a new provider to watch the files in tmpdir.
215	dir := createTmpDirWithFiles(t, testName+"*", "x509/client1_cert.pem", "x509/client1_key.pem", "x509/client_ca_cert.pem")
216	opts := Options{
217		CertFile:        path.Join(dir, certFile),
218		KeyFile:         path.Join(dir, keyFile),
219		RootFile:        path.Join(dir, rootFile),
220		RefreshDuration: defaultTestRefreshDuration,
221	}
222	prov, err := NewProvider(opts)
223	if err != nil {
224		t.Fatalf("NewProvider(%+v) failed: %v", opts, err)
225	}
226
227	// Make sure the provider picks up the files and pushes the key material on
228	// to the distributors.
229	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
230	defer cancel()
231	for i := 0; i < 2; i++ {
232		// Since we have root and identity certs, we need to make sure the
233		// update is pushed on both of them.
234		if _, err := distCh.Receive(ctx); err != nil {
235			t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err)
236		}
237	}
238
239	return dir, prov, distCh, func() {
240		newDistributor = origDistributorFunc
241		prov.Close()
242	}
243}
244
245// TestProvider_NoUpdate tests the case where a file watcher plugin is created
246// successfully, and the underlying files do not change. Verifies that the
247// plugin does not push new updates to the distributor in this case.
248func (s) TestProvider_NoUpdate(t *testing.T) {
249	_, prov, distCh, cancel := initializeProvider(t, "no_update")
250	defer cancel()
251
252	// Make sure the provider is healthy and returns key material.
253	ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout)
254	defer cc()
255	if _, err := prov.KeyMaterial(ctx); err != nil {
256		t.Fatalf("provider.KeyMaterial() failed: %v", err)
257	}
258
259	// Files haven't change. Make sure no updates are pushed by the provider.
260	sCtx, sc := context.WithTimeout(context.Background(), 2*defaultTestRefreshDuration)
261	defer sc()
262	if _, err := distCh.Receive(sCtx); err == nil {
263		t.Fatal("new key material pushed to distributor when underlying files did not change")
264	}
265}
266
267// TestProvider_UpdateSuccess tests the case where a file watcher plugin is
268// created successfully and the underlying files change. Verifies that the
269// changes are picked up by the provider.
270func (s) TestProvider_UpdateSuccess(t *testing.T) {
271	dir, prov, distCh, cancel := initializeProvider(t, "update_success")
272	defer cancel()
273
274	// Make sure the provider is healthy and returns key material.
275	ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout)
276	defer cc()
277	km1, err := prov.KeyMaterial(ctx)
278	if err != nil {
279		t.Fatalf("provider.KeyMaterial() failed: %v", err)
280	}
281
282	// Change only the root file.
283	createTmpFile(t, testdata.Path("x509/server_ca_cert.pem"), path.Join(dir, rootFile))
284	if _, err := distCh.Receive(ctx); err != nil {
285		t.Fatal("timeout waiting for new key material to be pushed to the distributor")
286	}
287
288	// Make sure update is picked up.
289	km2, err := prov.KeyMaterial(ctx)
290	if err != nil {
291		t.Fatalf("provider.KeyMaterial() failed: %v", err)
292	}
293	if err := compareKeyMaterial(km1, km2); err == nil {
294		t.Fatal("expected provider to return new key material after update to underlying file")
295	}
296
297	// Change only cert/key files.
298	createTmpFile(t, testdata.Path("x509/client2_cert.pem"), path.Join(dir, certFile))
299	createTmpFile(t, testdata.Path("x509/client2_key.pem"), path.Join(dir, keyFile))
300	if _, err := distCh.Receive(ctx); err != nil {
301		t.Fatal("timeout waiting for new key material to be pushed to the distributor")
302	}
303
304	// Make sure update is picked up.
305	km3, err := prov.KeyMaterial(ctx)
306	if err != nil {
307		t.Fatalf("provider.KeyMaterial() failed: %v", err)
308	}
309	if err := compareKeyMaterial(km2, km3); err == nil {
310		t.Fatal("expected provider to return new key material after update to underlying file")
311	}
312}
313
314// TestProvider_UpdateSuccessWithSymlink tests the case where a file watcher
315// plugin is created successfully to watch files through a symlink and the
316// symlink is updates to point to new files. Verifies that the changes are
317// picked up by the provider.
318func (s) TestProvider_UpdateSuccessWithSymlink(t *testing.T) {
319	// Override the newDistributor to one which pushes on a channel that we
320	// can block on.
321	origDistributorFunc := newDistributor
322	distCh := testutils.NewChannel()
323	d := newWrappedDistributor(distCh)
324	newDistributor = func() distributor { return d }
325	defer func() { newDistributor = origDistributorFunc }()
326
327	// Create two tempDirs with different files.
328	dir1 := createTmpDirWithFiles(t, "update_with_symlink1_*", "x509/client1_cert.pem", "x509/client1_key.pem", "x509/client_ca_cert.pem")
329	dir2 := createTmpDirWithFiles(t, "update_with_symlink2_*", "x509/server1_cert.pem", "x509/server1_key.pem", "x509/server_ca_cert.pem")
330
331	// Create a symlink under a new tempdir, and make it point to dir1.
332	tmpdir, err := ioutil.TempDir("", "test_symlink_*")
333	if err != nil {
334		t.Fatalf("ioutil.TempDir() failed: %v", err)
335	}
336	symLinkName := path.Join(tmpdir, "test_symlink")
337	if err := os.Symlink(dir1, symLinkName); err != nil {
338		t.Fatalf("failed to create symlink to %q: %v", dir1, err)
339	}
340
341	// Create a provider which watches the files pointed to by the symlink.
342	opts := Options{
343		CertFile:        path.Join(symLinkName, certFile),
344		KeyFile:         path.Join(symLinkName, keyFile),
345		RootFile:        path.Join(symLinkName, rootFile),
346		RefreshDuration: defaultTestRefreshDuration,
347	}
348	prov, err := NewProvider(opts)
349	if err != nil {
350		t.Fatalf("NewProvider(%+v) failed: %v", opts, err)
351	}
352	defer prov.Close()
353
354	// Make sure the provider picks up the files and pushes the key material on
355	// to the distributors.
356	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
357	defer cancel()
358	for i := 0; i < 2; i++ {
359		// Since we have root and identity certs, we need to make sure the
360		// update is pushed on both of them.
361		if _, err := distCh.Receive(ctx); err != nil {
362			t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err)
363		}
364	}
365	km1, err := prov.KeyMaterial(ctx)
366	if err != nil {
367		t.Fatalf("provider.KeyMaterial() failed: %v", err)
368	}
369
370	// Update the symlink to point to dir2.
371	symLinkTmpName := path.Join(tmpdir, "test_symlink.tmp")
372	if err := os.Symlink(dir2, symLinkTmpName); err != nil {
373		t.Fatalf("failed to create symlink to %q: %v", dir2, err)
374	}
375	if err := os.Rename(symLinkTmpName, symLinkName); err != nil {
376		t.Fatalf("failed to update symlink: %v", err)
377	}
378
379	// Make sure the provider picks up the new files and pushes the key material
380	// on to the distributors.
381	for i := 0; i < 2; i++ {
382		// Since we have root and identity certs, we need to make sure the
383		// update is pushed on both of them.
384		if _, err := distCh.Receive(ctx); err != nil {
385			t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err)
386		}
387	}
388	km2, err := prov.KeyMaterial(ctx)
389	if err != nil {
390		t.Fatalf("provider.KeyMaterial() failed: %v", err)
391	}
392
393	if err := compareKeyMaterial(km1, km2); err == nil {
394		t.Fatal("expected provider to return new key material after symlink update")
395	}
396}
397
398// TestProvider_UpdateFailure_ThenSuccess tests the case where updating cert/key
399// files fail. Verifies that the failed update does not push anything on the
400// distributor. Then the update succeeds, and the test verifies that the key
401// material is updated.
402func (s) TestProvider_UpdateFailure_ThenSuccess(t *testing.T) {
403	dir, prov, distCh, cancel := initializeProvider(t, "update_failure")
404	defer cancel()
405
406	// Make sure the provider is healthy and returns key material.
407	ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout)
408	defer cc()
409	km1, err := prov.KeyMaterial(ctx)
410	if err != nil {
411		t.Fatalf("provider.KeyMaterial() failed: %v", err)
412	}
413
414	// Update only the cert file. The key file is left unchanged. This should
415	// lead to these two files being not compatible with each other. This
416	// simulates the case where the watching goroutine might catch the files in
417	// the midst of an update.
418	createTmpFile(t, testdata.Path("x509/server1_cert.pem"), path.Join(dir, certFile))
419
420	// Since the last update left the files in an incompatible state, the update
421	// should not be picked up by our provider.
422	sCtx, sc := context.WithTimeout(context.Background(), 2*defaultTestRefreshDuration)
423	defer sc()
424	if _, err := distCh.Receive(sCtx); err == nil {
425		t.Fatal("new key material pushed to distributor when underlying files did not change")
426	}
427
428	// The provider should return key material corresponding to the old state.
429	km2, err := prov.KeyMaterial(ctx)
430	if err != nil {
431		t.Fatalf("provider.KeyMaterial() failed: %v", err)
432	}
433	if err := compareKeyMaterial(km1, km2); err != nil {
434		t.Fatalf("expected provider to not update key material: %v", err)
435	}
436
437	// Update the key file to match the cert file.
438	createTmpFile(t, testdata.Path("x509/server1_key.pem"), path.Join(dir, keyFile))
439
440	// Make sure update is picked up.
441	if _, err := distCh.Receive(ctx); err != nil {
442		t.Fatal("timeout waiting for new key material to be pushed to the distributor")
443	}
444	km3, err := prov.KeyMaterial(ctx)
445	if err != nil {
446		t.Fatalf("provider.KeyMaterial() failed: %v", err)
447	}
448	if err := compareKeyMaterial(km2, km3); err == nil {
449		t.Fatal("expected provider to return new key material after update to underlying file")
450	}
451}
452