// +build go1.12

/*
 *
 * Copyright 2020 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package advancedtls

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io/ioutil"
	"net"
	"os"
	"sync"
	"testing"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/tls/certprovider"
	"google.golang.org/grpc/credentials/tls/certprovider/pemfile"
	pb "google.golang.org/grpc/examples/helloworld/helloworld"
	"google.golang.org/grpc/security/advancedtls/internal/testutils"
	"google.golang.org/grpc/security/advancedtls/testdata"
)

const (
	// defaultTestTimeout is the default timeout for dialing connections in
	// the normal cases.
	defaultTestTimeout = 5 * time.Second
	// defaultTestShortTimeout is the dial timeout used when a new connection
	// is expected to fail, so the test does not wait long for it.
	defaultTestShortTimeout = 10 * time.Millisecond
	// credRefreshingInterval is the RefreshDuration passed to the PEM file
	// providers that watch the temporary credential files.
	credRefreshingInterval = 200 * time.Millisecond
	// sleepInterval is how long the tests wait after changing a credential
	// file for the change to be picked up. It is deliberately longer than
	// credRefreshingInterval.
	sleepInterval = 400 * time.Millisecond
)

// stageInfo holds the stage number of the current phase of an integration
// test, guarded by a mutex so that it can be safely read from the credential
// and verification callbacks while the test advances it.
// Depending on the stage, those callbacks return different certificates and
// verification results, which lets each test check that connections behave as
// expected.
60type stageInfo struct { 61 mutex sync.Mutex 62 stage int 63} 64 65func (s *stageInfo) increase() { 66 s.mutex.Lock() 67 defer s.mutex.Unlock() 68 s.stage = s.stage + 1 69} 70 71func (s *stageInfo) read() int { 72 s.mutex.Lock() 73 defer s.mutex.Unlock() 74 return s.stage 75} 76 77func (s *stageInfo) reset() { 78 s.mutex.Lock() 79 defer s.mutex.Unlock() 80 s.stage = 0 81} 82 83type greeterServer struct { 84 pb.UnimplementedGreeterServer 85} 86 87// sayHello is a simple implementation of the pb.GreeterServer SayHello method. 88func (greeterServer) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) { 89 return &pb.HelloReply{Message: "Hello " + in.Name}, nil 90} 91 92// TODO(ZhenLian): remove shouldFail to the function signature to provider 93// tests. 94func callAndVerify(msg string, client pb.GreeterClient, shouldFail bool) error { 95 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 96 defer cancel() 97 _, err := client.SayHello(ctx, &pb.HelloRequest{Name: msg}) 98 if want, got := shouldFail == true, err != nil; got != want { 99 return fmt.Errorf("want and got mismatch, want shouldFail=%v, got fail=%v, rpc error: %v", want, got, err) 100 } 101 return nil 102} 103 104// TODO(ZhenLian): remove shouldFail and add ...DialOption to the function 105// signature to provider cleaner tests. 106func callAndVerifyWithClientConn(connCtx context.Context, address string, msg string, creds credentials.TransportCredentials, shouldFail bool) (*grpc.ClientConn, pb.GreeterClient, error) { 107 var conn *grpc.ClientConn 108 var err error 109 // If we want the test to fail, we establish a non-blocking connection to 110 // avoid it hangs and killed by the context. 111 if shouldFail { 112 conn, err = grpc.DialContext(connCtx, address, grpc.WithTransportCredentials(creds)) 113 if err != nil { 114 return nil, nil, fmt.Errorf("client failed to connect to %s. 
Error: %v", address, err) 115 } 116 } else { 117 conn, err = grpc.DialContext(connCtx, address, grpc.WithTransportCredentials(creds), grpc.WithBlock()) 118 if err != nil { 119 return nil, nil, fmt.Errorf("client failed to connect to %s. Error: %v", address, err) 120 } 121 } 122 greetClient := pb.NewGreeterClient(conn) 123 err = callAndVerify(msg, greetClient, shouldFail) 124 if err != nil { 125 return nil, nil, err 126 } 127 return conn, greetClient, nil 128} 129 130// The advanced TLS features are tested in different stages. 131// At stage 0, we establish a good connection between client and server. 132// At stage 1, we change one factor(it could be we change the server's 133// certificate, or custom verification function, etc), and test if the 134// following connections would be dropped. 135// At stage 2, we re-establish the connection by changing the counterpart of 136// the factor we modified in stage 1. 137// (could be change the client's trust certificate, or change custom 138// verification function, etc) 139func (s) TestEnd2End(t *testing.T) { 140 cs := &testutils.CertStore{} 141 if err := cs.LoadCerts(); err != nil { 142 t.Fatalf("cs.LoadCerts() failed, err: %v", err) 143 } 144 stage := &stageInfo{} 145 for _, test := range []struct { 146 desc string 147 clientCert []tls.Certificate 148 clientGetCert func(*tls.CertificateRequestInfo) (*tls.Certificate, error) 149 clientRoot *x509.CertPool 150 clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error) 151 clientVerifyFunc CustomVerificationFunc 152 clientVType VerificationType 153 serverCert []tls.Certificate 154 serverGetCert func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) 155 serverRoot *x509.CertPool 156 serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error) 157 serverVerifyFunc CustomVerificationFunc 158 serverVType VerificationType 159 }{ 160 // Test Scenarios: 161 // At initialization(stage = 0), client will be initialized with cert 162 // ClientCert1 and 
ClientTrust1, server with ServerCert1 and ServerTrust1. 163 // The mutual authentication works at the beginning, since ClientCert1 is 164 // trusted by ServerTrust1, and ServerCert1 by ClientTrust1. 165 // At stage 1, client changes ClientCert1 to ClientCert2. Since ClientCert2 166 // is not trusted by ServerTrust1, following rpc calls are expected to 167 // fail, while the previous rpc calls are still good because those are 168 // already authenticated. 169 // At stage 2, the server changes ServerTrust1 to ServerTrust2, and we 170 // should see it again accepts the connection, since ClientCert2 is trusted 171 // by ServerTrust2. 172 { 173 desc: "test the reloading feature for client identity callback and server trust callback", 174 clientGetCert: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { 175 switch stage.read() { 176 case 0: 177 return &cs.ClientCert1, nil 178 default: 179 return &cs.ClientCert2, nil 180 } 181 }, 182 clientRoot: cs.ClientTrust1, 183 clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { 184 return &VerificationResults{}, nil 185 }, 186 clientVType: CertVerification, 187 serverCert: []tls.Certificate{cs.ServerCert1}, 188 serverGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) { 189 switch stage.read() { 190 case 0, 1: 191 return &GetRootCAsResults{TrustCerts: cs.ServerTrust1}, nil 192 default: 193 return &GetRootCAsResults{TrustCerts: cs.ServerTrust2}, nil 194 } 195 }, 196 serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { 197 return &VerificationResults{}, nil 198 }, 199 serverVType: CertVerification, 200 }, 201 // Test Scenarios: 202 // At initialization(stage = 0), client will be initialized with cert 203 // ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1. 204 // The mutual authentication works at the beginning, since ClientCert1 is 205 // trusted by ServerTrust1, and ServerCert1 by ClientTrust1. 
206 // At stage 1, server changes ServerCert1 to ServerCert2. Since ServerCert2 207 // is not trusted by ClientTrust1, following rpc calls are expected to 208 // fail, while the previous rpc calls are still good because those are 209 // already authenticated. 210 // At stage 2, the client changes ClientTrust1 to ClientTrust2, and we 211 // should see it again accepts the connection, since ServerCert2 is trusted 212 // by ClientTrust2. 213 { 214 desc: "test the reloading feature for server identity callback and client trust callback", 215 clientCert: []tls.Certificate{cs.ClientCert1}, 216 clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) { 217 switch stage.read() { 218 case 0, 1: 219 return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil 220 default: 221 return &GetRootCAsResults{TrustCerts: cs.ClientTrust2}, nil 222 } 223 }, 224 clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { 225 return &VerificationResults{}, nil 226 }, 227 clientVType: CertVerification, 228 serverGetCert: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) { 229 switch stage.read() { 230 case 0: 231 return []*tls.Certificate{&cs.ServerCert1}, nil 232 default: 233 return []*tls.Certificate{&cs.ServerCert2}, nil 234 } 235 }, 236 serverRoot: cs.ServerTrust1, 237 serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { 238 return &VerificationResults{}, nil 239 }, 240 serverVType: CertVerification, 241 }, 242 // Test Scenarios: 243 // At initialization(stage = 0), client will be initialized with cert 244 // ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1. 245 // The mutual authentication works at the beginning, since ClientCert1 246 // trusted by ServerTrust1, ServerCert1 by ClientTrust1, and also the 247 // custom verification check allows the CommonName on ServerCert1. 
248 // At stage 1, server changes ServerCert1 to ServerCert2, and client 249 // changes ClientTrust1 to ClientTrust2. Although ServerCert2 is trusted by 250 // ClientTrust2, our authorization check only accepts ServerCert1, and 251 // hence the following calls should fail. Previous connections should 252 // not be affected. 253 // At stage 2, the client changes authorization check to only accept 254 // ServerCert2. Now we should see the connection becomes normal again. 255 { 256 desc: "test client custom verification", 257 clientCert: []tls.Certificate{cs.ClientCert1}, 258 clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) { 259 switch stage.read() { 260 case 0: 261 return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil 262 default: 263 return &GetRootCAsResults{TrustCerts: cs.ClientTrust2}, nil 264 } 265 }, 266 clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { 267 if len(params.RawCerts) == 0 { 268 return nil, fmt.Errorf("no peer certs") 269 } 270 cert, err := x509.ParseCertificate(params.RawCerts[0]) 271 if err != nil || cert == nil { 272 return nil, fmt.Errorf("failed to parse certificate: " + err.Error()) 273 } 274 authzCheck := false 275 switch stage.read() { 276 case 0, 1: 277 // foo.bar.com is the common name on ServerCert1 278 if cert.Subject.CommonName == "foo.bar.com" { 279 authzCheck = true 280 } 281 default: 282 // foo.bar.server2.com is the common name on ServerCert2 283 if cert.Subject.CommonName == "foo.bar.server2.com" { 284 authzCheck = true 285 } 286 } 287 if authzCheck { 288 return &VerificationResults{}, nil 289 } 290 return nil, fmt.Errorf("custom authz check fails") 291 }, 292 clientVType: CertVerification, 293 serverGetCert: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) { 294 switch stage.read() { 295 case 0: 296 return []*tls.Certificate{&cs.ServerCert1}, nil 297 default: 298 return []*tls.Certificate{&cs.ServerCert2}, nil 299 } 300 }, 301 serverRoot: cs.ServerTrust1, 
302 serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { 303 return &VerificationResults{}, nil 304 }, 305 serverVType: CertVerification, 306 }, 307 // Test Scenarios: 308 // At initialization(stage = 0), client will be initialized with cert 309 // ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1. 310 // The mutual authentication works at the beginning, since ClientCert1 311 // trusted by ServerTrust1, ServerCert1 by ClientTrust1, and also the 312 // custom verification check on server side allows all connections. 313 // At stage 1, server disallows the the connections by setting custom 314 // verification check. The following calls should fail. Previous 315 // connections should not be affected. 316 // At stage 2, server allows all the connections again and the 317 // authentications should go back to normal. 318 { 319 desc: "TestServerCustomVerification", 320 clientCert: []tls.Certificate{cs.ClientCert1}, 321 clientRoot: cs.ClientTrust1, 322 clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { 323 return &VerificationResults{}, nil 324 }, 325 clientVType: CertVerification, 326 serverCert: []tls.Certificate{cs.ServerCert1}, 327 serverRoot: cs.ServerTrust1, 328 serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { 329 switch stage.read() { 330 case 0, 2: 331 return &VerificationResults{}, nil 332 case 1: 333 return nil, fmt.Errorf("custom authz check fails") 334 default: 335 return nil, fmt.Errorf("custom authz check fails") 336 } 337 }, 338 serverVType: CertVerification, 339 }, 340 } { 341 test := test 342 t.Run(test.desc, func(t *testing.T) { 343 // Start a server using ServerOptions in another goroutine. 
344 serverOptions := &ServerOptions{ 345 IdentityOptions: IdentityCertificateOptions{ 346 Certificates: test.serverCert, 347 GetIdentityCertificatesForServer: test.serverGetCert, 348 }, 349 RootOptions: RootCertificateOptions{ 350 RootCACerts: test.serverRoot, 351 GetRootCertificates: test.serverGetRoot, 352 }, 353 RequireClientCert: true, 354 VerifyPeer: test.serverVerifyFunc, 355 VType: test.serverVType, 356 } 357 serverTLSCreds, err := NewServerCreds(serverOptions) 358 if err != nil { 359 t.Fatalf("failed to create server creds: %v", err) 360 } 361 s := grpc.NewServer(grpc.Creds(serverTLSCreds)) 362 defer s.Stop() 363 lis, err := net.Listen("tcp", "localhost:0") 364 if err != nil { 365 t.Fatalf("failed to listen: %v", err) 366 } 367 defer lis.Close() 368 addr := fmt.Sprintf("localhost:%v", lis.Addr().(*net.TCPAddr).Port) 369 pb.RegisterGreeterServer(s, greeterServer{}) 370 go s.Serve(lis) 371 clientOptions := &ClientOptions{ 372 IdentityOptions: IdentityCertificateOptions{ 373 Certificates: test.clientCert, 374 GetIdentityCertificatesForClient: test.clientGetCert, 375 }, 376 VerifyPeer: test.clientVerifyFunc, 377 RootOptions: RootCertificateOptions{ 378 RootCACerts: test.clientRoot, 379 GetRootCertificates: test.clientGetRoot, 380 }, 381 VType: test.clientVType, 382 } 383 clientTLSCreds, err := NewClientCreds(clientOptions) 384 if err != nil { 385 t.Fatalf("clientTLSCreds failed to create") 386 } 387 // ------------------------Scenario 1------------------------------------ 388 // stage = 0, initial connection should succeed 389 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 390 defer cancel() 391 conn, greetClient, err := callAndVerifyWithClientConn(ctx, addr, "rpc call 1", clientTLSCreds, false) 392 if err != nil { 393 t.Fatal(err) 394 } 395 defer conn.Close() 396 // ---------------------------------------------------------------------- 397 stage.increase() 398 // ------------------------Scenario 
2------------------------------------ 399 // stage = 1, previous connection should still succeed 400 err = callAndVerify("rpc call 2", greetClient, false) 401 if err != nil { 402 t.Fatal(err) 403 } 404 // ------------------------Scenario 3------------------------------------ 405 // stage = 1, new connection should fail 406 shortCtx, shortCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) 407 defer shortCancel() 408 conn2, greetClient, err := callAndVerifyWithClientConn(shortCtx, addr, "rpc call 3", clientTLSCreds, true) 409 if err != nil { 410 t.Fatal(err) 411 } 412 defer conn2.Close() 413 // ---------------------------------------------------------------------- 414 stage.increase() 415 // ------------------------Scenario 4------------------------------------ 416 // stage = 2, new connection should succeed 417 conn3, greetClient, err := callAndVerifyWithClientConn(ctx, addr, "rpc call 4", clientTLSCreds, false) 418 if err != nil { 419 t.Fatal(err) 420 } 421 defer conn3.Close() 422 // ---------------------------------------------------------------------- 423 stage.reset() 424 }) 425 } 426} 427 428type tmpCredsFiles struct { 429 clientCertTmp *os.File 430 clientKeyTmp *os.File 431 clientTrustTmp *os.File 432 serverCertTmp *os.File 433 serverKeyTmp *os.File 434 serverTrustTmp *os.File 435} 436 437// Create temp files that are used to hold credentials. 
438func createTmpFiles() (*tmpCredsFiles, error) { 439 tmpFiles := &tmpCredsFiles{} 440 var err error 441 tmpFiles.clientCertTmp, err = ioutil.TempFile(os.TempDir(), "pre-") 442 if err != nil { 443 return nil, err 444 } 445 tmpFiles.clientKeyTmp, err = ioutil.TempFile(os.TempDir(), "pre-") 446 if err != nil { 447 return nil, err 448 } 449 tmpFiles.clientTrustTmp, err = ioutil.TempFile(os.TempDir(), "pre-") 450 if err != nil { 451 return nil, err 452 } 453 tmpFiles.serverCertTmp, err = ioutil.TempFile(os.TempDir(), "pre-") 454 if err != nil { 455 return nil, err 456 } 457 tmpFiles.serverKeyTmp, err = ioutil.TempFile(os.TempDir(), "pre-") 458 if err != nil { 459 return nil, err 460 } 461 tmpFiles.serverTrustTmp, err = ioutil.TempFile(os.TempDir(), "pre-") 462 if err != nil { 463 return nil, err 464 } 465 return tmpFiles, nil 466} 467 468// Copy the credential contents to the temporary files. 469func (tmpFiles *tmpCredsFiles) copyCredsToTmpFiles() error { 470 if err := copyFileContents(testdata.Path("client_cert_1.pem"), tmpFiles.clientCertTmp.Name()); err != nil { 471 return err 472 } 473 if err := copyFileContents(testdata.Path("client_key_1.pem"), tmpFiles.clientKeyTmp.Name()); err != nil { 474 return err 475 } 476 if err := copyFileContents(testdata.Path("client_trust_cert_1.pem"), tmpFiles.clientTrustTmp.Name()); err != nil { 477 return err 478 } 479 if err := copyFileContents(testdata.Path("server_cert_1.pem"), tmpFiles.serverCertTmp.Name()); err != nil { 480 return err 481 } 482 if err := copyFileContents(testdata.Path("server_key_1.pem"), tmpFiles.serverKeyTmp.Name()); err != nil { 483 return err 484 } 485 if err := copyFileContents(testdata.Path("server_trust_cert_1.pem"), tmpFiles.serverTrustTmp.Name()); err != nil { 486 return err 487 } 488 return nil 489} 490 491func (tmpFiles *tmpCredsFiles) removeFiles() { 492 os.Remove(tmpFiles.clientCertTmp.Name()) 493 os.Remove(tmpFiles.clientKeyTmp.Name()) 494 os.Remove(tmpFiles.clientTrustTmp.Name()) 495 
os.Remove(tmpFiles.serverCertTmp.Name()) 496 os.Remove(tmpFiles.serverKeyTmp.Name()) 497 os.Remove(tmpFiles.serverTrustTmp.Name()) 498} 499 500func copyFileContents(sourceFile, destinationFile string) error { 501 input, err := ioutil.ReadFile(sourceFile) 502 if err != nil { 503 return err 504 } 505 err = ioutil.WriteFile(destinationFile, input, 0644) 506 if err != nil { 507 return err 508 } 509 return nil 510} 511 512// Create PEMFileProvider(s) watching the content changes of temporary 513// files. 514func createProviders(tmpFiles *tmpCredsFiles) (certprovider.Provider, certprovider.Provider, certprovider.Provider, certprovider.Provider, error) { 515 clientIdentityOptions := pemfile.Options{ 516 CertFile: tmpFiles.clientCertTmp.Name(), 517 KeyFile: tmpFiles.clientKeyTmp.Name(), 518 RefreshDuration: credRefreshingInterval, 519 } 520 clientIdentityProvider, err := pemfile.NewProvider(clientIdentityOptions) 521 if err != nil { 522 return nil, nil, nil, nil, err 523 } 524 clientRootOptions := pemfile.Options{ 525 RootFile: tmpFiles.clientTrustTmp.Name(), 526 RefreshDuration: credRefreshingInterval, 527 } 528 clientRootProvider, err := pemfile.NewProvider(clientRootOptions) 529 if err != nil { 530 return nil, nil, nil, nil, err 531 } 532 serverIdentityOptions := pemfile.Options{ 533 CertFile: tmpFiles.serverCertTmp.Name(), 534 KeyFile: tmpFiles.serverKeyTmp.Name(), 535 RefreshDuration: credRefreshingInterval, 536 } 537 serverIdentityProvider, err := pemfile.NewProvider(serverIdentityOptions) 538 if err != nil { 539 return nil, nil, nil, nil, err 540 } 541 serverRootOptions := pemfile.Options{ 542 RootFile: tmpFiles.serverTrustTmp.Name(), 543 RefreshDuration: credRefreshingInterval, 544 } 545 serverRootProvider, err := pemfile.NewProvider(serverRootOptions) 546 if err != nil { 547 return nil, nil, nil, nil, err 548 } 549 return clientIdentityProvider, clientRootProvider, serverIdentityProvider, serverRootProvider, nil 550} 551 552// In order to test advanced TLS 
provider features, we used temporary files to 553// hold credential data, and copy the contents under testdata/ to these tmp 554// files. 555// Initially, we establish a good connection with providers watching contents 556// from tmp files. 557// Next, we change the identity certs that IdentityProvider is watching. Since 558// the identity key is not changed, the IdentityProvider should ignore the 559// update, and the connection should still be good. 560// Then the the identity key is changed. This time IdentityProvider should pick 561// up the update, and the connection should fail, due to the trust certs on the 562// other side is not changed. 563// Finally, the trust certs that other-side's RootProvider is watching get 564// changed. The connection should go back to normal again. 565func (s) TestPEMFileProviderEnd2End(t *testing.T) { 566 tmpFiles, err := createTmpFiles() 567 if err != nil { 568 t.Fatalf("createTmpFiles() failed, error: %v", err) 569 } 570 defer tmpFiles.removeFiles() 571 for _, test := range []struct { 572 desc string 573 certUpdateFunc func() 574 keyUpdateFunc func() 575 trustCertUpdateFunc func() 576 }{ 577 { 578 desc: "test the reloading feature for clientIdentityProvider and serverTrustProvider", 579 certUpdateFunc: func() { 580 err = copyFileContents(testdata.Path("client_cert_2.pem"), tmpFiles.clientCertTmp.Name()) 581 if err != nil { 582 t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("client_cert_2.pem"), tmpFiles.clientCertTmp.Name(), err) 583 } 584 }, 585 keyUpdateFunc: func() { 586 err = copyFileContents(testdata.Path("client_key_2.pem"), tmpFiles.clientKeyTmp.Name()) 587 if err != nil { 588 t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("client_key_2.pem"), tmpFiles.clientKeyTmp.Name(), err) 589 } 590 }, 591 trustCertUpdateFunc: func() { 592 err = copyFileContents(testdata.Path("server_trust_cert_2.pem"), tmpFiles.serverTrustTmp.Name()) 593 if err != nil { 594 t.Fatalf("copyFileContents(%s, %s) failed: 
%v", testdata.Path("server_trust_cert_2.pem"), tmpFiles.serverTrustTmp.Name(), err) 595 } 596 }, 597 }, 598 { 599 desc: "test the reloading feature for serverIdentityProvider and clientTrustProvider", 600 certUpdateFunc: func() { 601 err = copyFileContents(testdata.Path("server_cert_2.pem"), tmpFiles.serverCertTmp.Name()) 602 if err != nil { 603 t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("server_cert_2.pem"), tmpFiles.serverCertTmp.Name(), err) 604 } 605 }, 606 keyUpdateFunc: func() { 607 err = copyFileContents(testdata.Path("server_key_2.pem"), tmpFiles.serverKeyTmp.Name()) 608 if err != nil { 609 t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("server_key_2.pem"), tmpFiles.serverKeyTmp.Name(), err) 610 } 611 }, 612 trustCertUpdateFunc: func() { 613 err = copyFileContents(testdata.Path("client_trust_cert_2.pem"), tmpFiles.clientTrustTmp.Name()) 614 if err != nil { 615 t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("client_trust_cert_2.pem"), tmpFiles.clientTrustTmp.Name(), err) 616 } 617 }, 618 }, 619 } { 620 test := test 621 t.Run(test.desc, func(t *testing.T) { 622 if err := tmpFiles.copyCredsToTmpFiles(); err != nil { 623 t.Fatalf("tmpFiles.copyCredsToTmpFiles() failed, error: %v", err) 624 } 625 clientIdentityProvider, clientRootProvider, serverIdentityProvider, serverRootProvider, err := createProviders(tmpFiles) 626 if err != nil { 627 t.Fatalf("createProviders(%v) failed, error: %v", tmpFiles, err) 628 } 629 defer clientIdentityProvider.Close() 630 defer clientRootProvider.Close() 631 defer serverIdentityProvider.Close() 632 defer serverRootProvider.Close() 633 // Start a server and create a client using advancedtls API with Provider. 
634 serverOptions := &ServerOptions{ 635 IdentityOptions: IdentityCertificateOptions{ 636 IdentityProvider: serverIdentityProvider, 637 }, 638 RootOptions: RootCertificateOptions{ 639 RootProvider: serverRootProvider, 640 }, 641 RequireClientCert: true, 642 VerifyPeer: func(params *VerificationFuncParams) (*VerificationResults, error) { 643 return &VerificationResults{}, nil 644 }, 645 VType: CertVerification, 646 } 647 serverTLSCreds, err := NewServerCreds(serverOptions) 648 if err != nil { 649 t.Fatalf("failed to create server creds: %v", err) 650 } 651 s := grpc.NewServer(grpc.Creds(serverTLSCreds)) 652 defer s.Stop() 653 lis, err := net.Listen("tcp", "localhost:0") 654 if err != nil { 655 t.Fatalf("failed to listen: %v", err) 656 } 657 defer lis.Close() 658 addr := fmt.Sprintf("localhost:%v", lis.Addr().(*net.TCPAddr).Port) 659 pb.RegisterGreeterServer(s, greeterServer{}) 660 go s.Serve(lis) 661 clientOptions := &ClientOptions{ 662 IdentityOptions: IdentityCertificateOptions{ 663 IdentityProvider: clientIdentityProvider, 664 }, 665 VerifyPeer: func(params *VerificationFuncParams) (*VerificationResults, error) { 666 return &VerificationResults{}, nil 667 }, 668 RootOptions: RootCertificateOptions{ 669 RootProvider: clientRootProvider, 670 }, 671 VType: CertVerification, 672 } 673 clientTLSCreds, err := NewClientCreds(clientOptions) 674 if err != nil { 675 t.Fatalf("clientTLSCreds failed to create, error: %v", err) 676 } 677 678 // At initialization, the connection should be good. 679 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 680 defer cancel() 681 conn, greetClient, err := callAndVerifyWithClientConn(ctx, addr, "rpc call 1", clientTLSCreds, false) 682 if err != nil { 683 t.Fatal(err) 684 } 685 defer conn.Close() 686 // Make the identity cert change, and wait 1 second for the provider to 687 // pick up the change. 
688 test.certUpdateFunc() 689 time.Sleep(sleepInterval) 690 // The already-established connection should not be affected. 691 err = callAndVerify("rpc call 2", greetClient, false) 692 if err != nil { 693 t.Fatal(err) 694 } 695 // New connections should still be good, because the Provider didn't pick 696 // up the changes due to key-cert mismatch. 697 conn2, greetClient, err := callAndVerifyWithClientConn(ctx, addr, "rpc call 3", clientTLSCreds, false) 698 if err != nil { 699 t.Fatal(err) 700 } 701 defer conn2.Close() 702 // Make the identity key change, and wait 1 second for the provider to 703 // pick up the change. 704 test.keyUpdateFunc() 705 time.Sleep(sleepInterval) 706 // New connections should fail now, because the Provider picked the 707 // change, and *_cert_2.pem is not trusted by *_trust_cert_1.pem on the 708 // other side. 709 shortCtx, shortCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) 710 defer shortCancel() 711 conn3, greetClient, err := callAndVerifyWithClientConn(shortCtx, addr, "rpc call 4", clientTLSCreds, true) 712 if err != nil { 713 t.Fatal(err) 714 } 715 defer conn3.Close() 716 // Make the trust cert change on the other side, and wait 1 second for 717 // the provider to pick up the change. 718 test.trustCertUpdateFunc() 719 time.Sleep(sleepInterval) 720 // New connections should be good, because the other side is using 721 // *_trust_cert_2.pem now. 
722 conn4, greetClient, err := callAndVerifyWithClientConn(ctx, addr, "rpc call 5", clientTLSCreds, false) 723 if err != nil { 724 t.Fatal(err) 725 } 726 defer conn4.Close() 727 }) 728 } 729} 730 731func (s) TestDefaultHostNameCheck(t *testing.T) { 732 cs := &testutils.CertStore{} 733 if err := cs.LoadCerts(); err != nil { 734 t.Fatalf("cs.LoadCerts() failed, err: %v", err) 735 } 736 for _, test := range []struct { 737 desc string 738 clientRoot *x509.CertPool 739 clientVerifyFunc CustomVerificationFunc 740 clientVType VerificationType 741 serverCert []tls.Certificate 742 serverVType VerificationType 743 expectError bool 744 }{ 745 // Client side sets vType to CertAndHostVerification, and will do 746 // default hostname check. Server uses a cert without "localhost" or 747 // "127.0.0.1" as common name or SAN names, and will hence fail. 748 { 749 desc: "Bad default hostname check", 750 clientRoot: cs.ClientTrust1, 751 clientVType: CertAndHostVerification, 752 serverCert: []tls.Certificate{cs.ServerCert1}, 753 serverVType: CertAndHostVerification, 754 expectError: true, 755 }, 756 // Client side sets vType to CertAndHostVerification, and will do 757 // default hostname check. Server uses a certificate with "localhost" as 758 // common name, and will hence pass the default hostname check. 759 { 760 desc: "Good default hostname check", 761 clientRoot: cs.ClientTrust1, 762 clientVType: CertAndHostVerification, 763 serverCert: []tls.Certificate{cs.ServerPeerLocalhost1}, 764 serverVType: CertAndHostVerification, 765 expectError: false, 766 }, 767 } { 768 test := test 769 t.Run(test.desc, func(t *testing.T) { 770 // Start a server using ServerOptions in another goroutine. 
771 serverOptions := &ServerOptions{ 772 IdentityOptions: IdentityCertificateOptions{ 773 Certificates: test.serverCert, 774 }, 775 RequireClientCert: false, 776 VType: test.serverVType, 777 } 778 serverTLSCreds, err := NewServerCreds(serverOptions) 779 if err != nil { 780 t.Fatalf("failed to create server creds: %v", err) 781 } 782 s := grpc.NewServer(grpc.Creds(serverTLSCreds)) 783 defer s.Stop() 784 lis, err := net.Listen("tcp", "localhost:0") 785 if err != nil { 786 t.Fatalf("failed to listen: %v", err) 787 } 788 defer lis.Close() 789 addr := fmt.Sprintf("localhost:%v", lis.Addr().(*net.TCPAddr).Port) 790 pb.RegisterGreeterServer(s, greeterServer{}) 791 go s.Serve(lis) 792 clientOptions := &ClientOptions{ 793 VerifyPeer: test.clientVerifyFunc, 794 RootOptions: RootCertificateOptions{ 795 RootCACerts: test.clientRoot, 796 }, 797 VType: test.clientVType, 798 } 799 clientTLSCreds, err := NewClientCreds(clientOptions) 800 if err != nil { 801 t.Fatalf("clientTLSCreds failed to create") 802 } 803 shouldFail := false 804 if test.expectError { 805 shouldFail = true 806 } 807 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 808 defer cancel() 809 conn, _, err := callAndVerifyWithClientConn(ctx, addr, "rpc call 1", clientTLSCreds, shouldFail) 810 if err != nil { 811 t.Fatal(err) 812 } 813 defer conn.Close() 814 }) 815 } 816} 817