1// +build go1.12 2 3/* 4 * 5 * Copyright 2020 gRPC authors. 6 * 7 * Licensed under the Apache License, Version 2.0 (the "License"); 8 * you may not use this file except in compliance with the License. 9 * You may obtain a copy of the License at 10 * 11 * http://www.apache.org/licenses/LICENSE-2.0 12 * 13 * Unless required by applicable law or agreed to in writing, software 14 * distributed under the License is distributed on an "AS IS" BASIS, 15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 * See the License for the specific language governing permissions and 17 * limitations under the License. 18 * 19 */ 20 21package pemfile 22 23import ( 24 "context" 25 "fmt" 26 "io/ioutil" 27 "math/big" 28 "os" 29 "path" 30 "testing" 31 "time" 32 33 "github.com/google/go-cmp/cmp" 34 "github.com/google/go-cmp/cmp/cmpopts" 35 36 "google.golang.org/grpc/credentials/tls/certprovider" 37 "google.golang.org/grpc/internal/grpctest" 38 "google.golang.org/grpc/internal/testutils" 39 "google.golang.org/grpc/testdata" 40) 41 42const ( 43 // These are the names of files inside temporary directories, which the 44 // plugin is asked to watch. 45 certFile = "cert.pem" 46 keyFile = "key.pem" 47 rootFile = "ca.pem" 48 49 defaultTestRefreshDuration = 100 * time.Millisecond 50 defaultTestTimeout = 5 * time.Second 51) 52 53type s struct { 54 grpctest.Tester 55} 56 57func Test(t *testing.T) { 58 grpctest.RunSubTests(t, s{}) 59} 60 61func compareKeyMaterial(got, want *certprovider.KeyMaterial) error { 62 // x509.Certificate type defines an Equal() method, but does not check for 63 // nil. This has been fixed in 64 // https://github.com/golang/go/commit/89865f8ba64ccb27f439cce6daaa37c9aa38f351, 65 // but this is only available starting go1.14. 66 // TODO(easwars): Remove this check once we remove support for go1.13. 67 if (got.Certs == nil && want.Certs != nil) || (want.Certs == nil && got.Certs != nil) { 68 return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want) 69 } 70 if !cmp.Equal(got.Certs, want.Certs, cmp.AllowUnexported(big.Int{})) { 71 return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want) 72 } 73 // x509.CertPool contains only unexported fields some of which contain other 74 // unexported fields. So usage of cmp.AllowUnexported() or 75 // cmpopts.IgnoreUnexported() does not help us much here. Also, the standard 76 // library does not provide a way to compare CertPool values. Comparing the 77 // subjects field of the certs in the CertPool seems like a reasonable 78 // approach. 79 if gotR, wantR := got.Roots.Subjects(), want.Roots.Subjects(); !cmp.Equal(gotR, wantR, cmpopts.EquateEmpty()) { 80 return fmt.Errorf("keyMaterial roots = %v, want %v", gotR, wantR) 81 } 82 return nil 83} 84 85// TestNewProvider tests the NewProvider() function with different inputs. 86func (s) TestNewProvider(t *testing.T) { 87 tests := []struct { 88 desc string 89 options Options 90 wantError bool 91 }{ 92 { 93 desc: "No credential files specified", 94 options: Options{}, 95 wantError: true, 96 }, 97 { 98 desc: "Only identity cert is specified", 99 options: Options{ 100 CertFile: testdata.Path("x509/client1_cert.pem"), 101 }, 102 wantError: true, 103 }, 104 { 105 desc: "Only identity key is specified", 106 options: Options{ 107 KeyFile: testdata.Path("x509/client1_key.pem"), 108 }, 109 wantError: true, 110 }, 111 { 112 desc: "Identity cert/key pair is specified", 113 options: Options{ 114 KeyFile: testdata.Path("x509/client1_key.pem"), 115 CertFile: testdata.Path("x509/client1_cert.pem"), 116 }, 117 }, 118 { 119 desc: "Only root certs are specified", 120 options: Options{ 121 RootFile: testdata.Path("x509/client_ca_cert.pem"), 122 }, 123 }, 124 { 125 desc: "Everything is specified", 126 options: Options{ 127 KeyFile: testdata.Path("x509/client1_key.pem"), 128 CertFile: testdata.Path("x509/client1_cert.pem"), 129 RootFile: testdata.Path("x509/client_ca_cert.pem"), 130 }, 131 wantError: false, 132 }, 133 } 134 for _, test := range tests { 135 t.Run(test.desc, func(t *testing.T) { 136 provider, err := NewProvider(test.options) 137 if (err != nil) != test.wantError { 138 t.Fatalf("NewProvider(%v) = %v, want %v", test.options, err, test.wantError) 139 } 140 if err != nil { 141 return 142 } 143 provider.Close() 144 }) 145 } 146} 147 148// wrappedDistributor wraps a distributor and pushes on a channel whenever new 149// key material is pushed to the distributor. 150type wrappedDistributor struct { 151 *certprovider.Distributor 152 distCh *testutils.Channel 153} 154 155func newWrappedDistributor(distCh *testutils.Channel) *wrappedDistributor { 156 return &wrappedDistributor{ 157 distCh: distCh, 158 Distributor: certprovider.NewDistributor(), 159 } 160} 161 162func (wd *wrappedDistributor) Set(km *certprovider.KeyMaterial, err error) { 163 wd.Distributor.Set(km, err) 164 wd.distCh.Send(nil) 165} 166 167func createTmpFile(t *testing.T, src, dst string) { 168 t.Helper() 169 170 data, err := ioutil.ReadFile(src) 171 if err != nil { 172 t.Fatalf("ioutil.ReadFile(%q) failed: %v", src, err) 173 } 174 if err := ioutil.WriteFile(dst, data, os.ModePerm); err != nil { 175 t.Fatalf("ioutil.WriteFile(%q) failed: %v", dst, err) 176 } 177 t.Logf("Wrote file at: %s", dst) 178 t.Logf("%s", string(data)) 179} 180 181// createTempDirWithFiles creates a temporary directory under the system default 182// tempDir with the given dirSuffix. It also reads from certSrc, keySrc and 183// rootSrc files are creates appropriate files under the newly create tempDir. 184// Returns the name of the created tempDir. 185func createTmpDirWithFiles(t *testing.T, dirSuffix, certSrc, keySrc, rootSrc string) string { 186 t.Helper() 187 188 // Create a temp directory. Passing an empty string for the first argument 189 // uses the system temp directory. 190 dir, err := ioutil.TempDir("", dirSuffix) 191 if err != nil { 192 t.Fatalf("ioutil.TempDir() failed: %v", err) 193 } 194 t.Logf("Using tmpdir: %s", dir) 195 196 createTmpFile(t, testdata.Path(certSrc), path.Join(dir, certFile)) 197 createTmpFile(t, testdata.Path(keySrc), path.Join(dir, keyFile)) 198 createTmpFile(t, testdata.Path(rootSrc), path.Join(dir, rootFile)) 199 return dir 200} 201 202// initializeProvider performs setup steps common to all tests (except the one 203// which uses symlinks). 204func initializeProvider(t *testing.T, testName string) (string, certprovider.Provider, *testutils.Channel, func()) { 205 t.Helper() 206 207 // Override the newDistributor to one which pushes on a channel that we 208 // can block on. 209 origDistributorFunc := newDistributor 210 distCh := testutils.NewChannel() 211 d := newWrappedDistributor(distCh) 212 newDistributor = func() distributor { return d } 213 214 // Create a new provider to watch the files in tmpdir. 215 dir := createTmpDirWithFiles(t, testName+"*", "x509/client1_cert.pem", "x509/client1_key.pem", "x509/client_ca_cert.pem") 216 opts := Options{ 217 CertFile: path.Join(dir, certFile), 218 KeyFile: path.Join(dir, keyFile), 219 RootFile: path.Join(dir, rootFile), 220 RefreshDuration: defaultTestRefreshDuration, 221 } 222 prov, err := NewProvider(opts) 223 if err != nil { 224 t.Fatalf("NewProvider(%+v) failed: %v", opts, err) 225 } 226 227 // Make sure the provider picks up the files and pushes the key material on 228 // to the distributors. 229 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 230 defer cancel() 231 for i := 0; i < 2; i++ { 232 // Since we have root and identity certs, we need to make sure the 233 // update is pushed on both of them. 234 if _, err := distCh.Receive(ctx); err != nil { 235 t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err) 236 } 237 } 238 239 return dir, prov, distCh, func() { 240 newDistributor = origDistributorFunc 241 prov.Close() 242 } 243} 244 245// TestProvider_NoUpdate tests the case where a file watcher plugin is created 246// successfully, and the underlying files do not change. Verifies that the 247// plugin does not push new updates to the distributor in this case. 248func (s) TestProvider_NoUpdate(t *testing.T) { 249 _, prov, distCh, cancel := initializeProvider(t, "no_update") 250 defer cancel() 251 252 // Make sure the provider is healthy and returns key material. 253 ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout) 254 defer cc() 255 if _, err := prov.KeyMaterial(ctx); err != nil { 256 t.Fatalf("provider.KeyMaterial() failed: %v", err) 257 } 258 259 // Files haven't change. Make sure no updates are pushed by the provider. 260 sCtx, sc := context.WithTimeout(context.Background(), 2*defaultTestRefreshDuration) 261 defer sc() 262 if _, err := distCh.Receive(sCtx); err == nil { 263 t.Fatal("new key material pushed to distributor when underlying files did not change") 264 } 265} 266 267// TestProvider_UpdateSuccess tests the case where a file watcher plugin is 268// created successfully and the underlying files change. Verifies that the 269// changes are picked up by the provider. 270func (s) TestProvider_UpdateSuccess(t *testing.T) { 271 dir, prov, distCh, cancel := initializeProvider(t, "update_success") 272 defer cancel() 273 274 // Make sure the provider is healthy and returns key material. 275 ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout) 276 defer cc() 277 km1, err := prov.KeyMaterial(ctx) 278 if err != nil { 279 t.Fatalf("provider.KeyMaterial() failed: %v", err) 280 } 281 282 // Change only the root file. 283 createTmpFile(t, testdata.Path("x509/server_ca_cert.pem"), path.Join(dir, rootFile)) 284 if _, err := distCh.Receive(ctx); err != nil { 285 t.Fatal("timeout waiting for new key material to be pushed to the distributor") 286 } 287 288 // Make sure update is picked up. 289 km2, err := prov.KeyMaterial(ctx) 290 if err != nil { 291 t.Fatalf("provider.KeyMaterial() failed: %v", err) 292 } 293 if err := compareKeyMaterial(km1, km2); err == nil { 294 t.Fatal("expected provider to return new key material after update to underlying file") 295 } 296 297 // Change only cert/key files. 298 createTmpFile(t, testdata.Path("x509/client2_cert.pem"), path.Join(dir, certFile)) 299 createTmpFile(t, testdata.Path("x509/client2_key.pem"), path.Join(dir, keyFile)) 300 if _, err := distCh.Receive(ctx); err != nil { 301 t.Fatal("timeout waiting for new key material to be pushed to the distributor") 302 } 303 304 // Make sure update is picked up. 305 km3, err := prov.KeyMaterial(ctx) 306 if err != nil { 307 t.Fatalf("provider.KeyMaterial() failed: %v", err) 308 } 309 if err := compareKeyMaterial(km2, km3); err == nil { 310 t.Fatal("expected provider to return new key material after update to underlying file") 311 } 312} 313 314// TestProvider_UpdateSuccessWithSymlink tests the case where a file watcher 315// plugin is created successfully to watch files through a symlink and the 316// symlink is updates to point to new files. Verifies that the changes are 317// picked up by the provider. 318func (s) TestProvider_UpdateSuccessWithSymlink(t *testing.T) { 319 // Override the newDistributor to one which pushes on a channel that we 320 // can block on. 321 origDistributorFunc := newDistributor 322 distCh := testutils.NewChannel() 323 d := newWrappedDistributor(distCh) 324 newDistributor = func() distributor { return d } 325 defer func() { newDistributor = origDistributorFunc }() 326 327 // Create two tempDirs with different files. 328 dir1 := createTmpDirWithFiles(t, "update_with_symlink1_*", "x509/client1_cert.pem", "x509/client1_key.pem", "x509/client_ca_cert.pem") 329 dir2 := createTmpDirWithFiles(t, "update_with_symlink2_*", "x509/server1_cert.pem", "x509/server1_key.pem", "x509/server_ca_cert.pem") 330 331 // Create a symlink under a new tempdir, and make it point to dir1. 332 tmpdir, err := ioutil.TempDir("", "test_symlink_*") 333 if err != nil { 334 t.Fatalf("ioutil.TempDir() failed: %v", err) 335 } 336 symLinkName := path.Join(tmpdir, "test_symlink") 337 if err := os.Symlink(dir1, symLinkName); err != nil { 338 t.Fatalf("failed to create symlink to %q: %v", dir1, err) 339 } 340 341 // Create a provider which watches the files pointed to by the symlink. 342 opts := Options{ 343 CertFile: path.Join(symLinkName, certFile), 344 KeyFile: path.Join(symLinkName, keyFile), 345 RootFile: path.Join(symLinkName, rootFile), 346 RefreshDuration: defaultTestRefreshDuration, 347 } 348 prov, err := NewProvider(opts) 349 if err != nil { 350 t.Fatalf("NewProvider(%+v) failed: %v", opts, err) 351 } 352 defer prov.Close() 353 354 // Make sure the provider picks up the files and pushes the key material on 355 // to the distributors. 356 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 357 defer cancel() 358 for i := 0; i < 2; i++ { 359 // Since we have root and identity certs, we need to make sure the 360 // update is pushed on both of them. 361 if _, err := distCh.Receive(ctx); err != nil { 362 t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err) 363 } 364 } 365 km1, err := prov.KeyMaterial(ctx) 366 if err != nil { 367 t.Fatalf("provider.KeyMaterial() failed: %v", err) 368 } 369 370 // Update the symlink to point to dir2. 371 symLinkTmpName := path.Join(tmpdir, "test_symlink.tmp") 372 if err := os.Symlink(dir2, symLinkTmpName); err != nil { 373 t.Fatalf("failed to create symlink to %q: %v", dir2, err) 374 } 375 if err := os.Rename(symLinkTmpName, symLinkName); err != nil { 376 t.Fatalf("failed to update symlink: %v", err) 377 } 378 379 // Make sure the provider picks up the new files and pushes the key material 380 // on to the distributors. 381 for i := 0; i < 2; i++ { 382 // Since we have root and identity certs, we need to make sure the 383 // update is pushed on both of them. 384 if _, err := distCh.Receive(ctx); err != nil { 385 t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err) 386 } 387 } 388 km2, err := prov.KeyMaterial(ctx) 389 if err != nil { 390 t.Fatalf("provider.KeyMaterial() failed: %v", err) 391 } 392 393 if err := compareKeyMaterial(km1, km2); err == nil { 394 t.Fatal("expected provider to return new key material after symlink update") 395 } 396} 397 398// TestProvider_UpdateFailure_ThenSuccess tests the case where updating cert/key 399// files fail. Verifies that the failed update does not push anything on the 400// distributor. Then the update succeeds, and the test verifies that the key 401// material is updated. 402func (s) TestProvider_UpdateFailure_ThenSuccess(t *testing.T) { 403 dir, prov, distCh, cancel := initializeProvider(t, "update_failure") 404 defer cancel() 405 406 // Make sure the provider is healthy and returns key material. 407 ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout) 408 defer cc() 409 km1, err := prov.KeyMaterial(ctx) 410 if err != nil { 411 t.Fatalf("provider.KeyMaterial() failed: %v", err) 412 } 413 414 // Update only the cert file. The key file is left unchanged. This should 415 // lead to these two files being not compatible with each other. This 416 // simulates the case where the watching goroutine might catch the files in 417 // the midst of an update. 418 createTmpFile(t, testdata.Path("x509/server1_cert.pem"), path.Join(dir, certFile)) 419 420 // Since the last update left the files in an incompatible state, the update 421 // should not be picked up by our provider. 422 sCtx, sc := context.WithTimeout(context.Background(), 2*defaultTestRefreshDuration) 423 defer sc() 424 if _, err := distCh.Receive(sCtx); err == nil { 425 t.Fatal("new key material pushed to distributor when underlying files did not change") 426 } 427 428 // The provider should return key material corresponding to the old state. 429 km2, err := prov.KeyMaterial(ctx) 430 if err != nil { 431 t.Fatalf("provider.KeyMaterial() failed: %v", err) 432 } 433 if err := compareKeyMaterial(km1, km2); err != nil { 434 t.Fatalf("expected provider to not update key material: %v", err) 435 } 436 437 // Update the key file to match the cert file. 438 createTmpFile(t, testdata.Path("x509/server1_key.pem"), path.Join(dir, keyFile)) 439 440 // Make sure update is picked up. 441 if _, err := distCh.Receive(ctx); err != nil { 442 t.Fatal("timeout waiting for new key material to be pushed to the distributor") 443 } 444 km3, err := prov.KeyMaterial(ctx) 445 if err != nil { 446 t.Fatalf("provider.KeyMaterial() failed: %v", err) 447 } 448 if err := compareKeyMaterial(km2, km3); err == nil { 449 t.Fatal("expected provider to return new key material after update to underlying file") 450 } 451} 452