1// +build go1.12 2 3/* 4 * 5 * Copyright 2019 gRPC authors. 6 * 7 * Licensed under the Apache License, Version 2.0 (the "License"); 8 * you may not use this file except in compliance with the License. 9 * You may obtain a copy of the License at 10 * 11 * http://www.apache.org/licenses/LICENSE-2.0 12 * 13 * Unless required by applicable law or agreed to in writing, software 14 * distributed under the License is distributed on an "AS IS" BASIS, 15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 * See the License for the specific language governing permissions and 17 * limitations under the License. 18 * 19 */ 20 21package advancedtls 22 23import ( 24 "context" 25 "crypto/tls" 26 "crypto/x509" 27 "errors" 28 "fmt" 29 "net" 30 "testing" 31 32 "google.golang.org/grpc/credentials" 33 "google.golang.org/grpc/credentials/tls/certprovider" 34 "google.golang.org/grpc/internal/grpctest" 35 "google.golang.org/grpc/security/advancedtls/internal/testutils" 36) 37 38type s struct { 39 grpctest.Tester 40} 41 42func Test(t *testing.T) { 43 grpctest.RunSubTests(t, s{}) 44} 45 46type provType int 47 48const ( 49 provTypeRoot provType = iota 50 provTypeIdentity 51) 52 53type fakeProvider struct { 54 pt provType 55 isClient bool 56 wantMultiCert bool 57 wantError bool 58} 59 60func (f fakeProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) { 61 if f.wantError { 62 return nil, fmt.Errorf("bad fakeProvider") 63 } 64 cs := &testutils.CertStore{} 65 if err := cs.LoadCerts(); err != nil { 66 return nil, fmt.Errorf("cs.LoadCerts() failed, err: %v", err) 67 } 68 if f.pt == provTypeRoot && f.isClient { 69 return &certprovider.KeyMaterial{Roots: cs.ClientTrust1}, nil 70 } 71 if f.pt == provTypeRoot && !f.isClient { 72 return &certprovider.KeyMaterial{Roots: cs.ServerTrust1}, nil 73 } 74 if f.pt == provTypeIdentity && f.isClient { 75 if f.wantMultiCert { 76 return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1, cs.ClientCert2}}, nil 77 } 78 return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}}, nil 79 } 80 if f.wantMultiCert { 81 return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ServerCert1, cs.ServerCert2}}, nil 82 } 83 return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ServerCert1}}, nil 84} 85 86func (f fakeProvider) Close() {} 87 88func (s) TestClientOptionsConfigErrorCases(t *testing.T) { 89 tests := []struct { 90 desc string 91 clientVType VerificationType 92 IdentityOptions IdentityCertificateOptions 93 RootOptions RootCertificateOptions 94 }{ 95 { 96 desc: "Skip default verification and provide no root credentials", 97 clientVType: SkipVerification, 98 }, 99 { 100 desc: "More than one fields in RootCertificateOptions is specified", 101 clientVType: CertVerification, 102 RootOptions: RootCertificateOptions{ 103 RootCACerts: x509.NewCertPool(), 104 RootProvider: fakeProvider{}, 105 }, 106 }, 107 { 108 desc: "More than one fields in IdentityCertificateOptions is specified", 109 clientVType: CertVerification, 110 IdentityOptions: IdentityCertificateOptions{ 111 GetIdentityCertificatesForClient: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { 112 return nil, nil 113 }, 114 IdentityProvider: fakeProvider{pt: provTypeIdentity}, 115 }, 116 }, 117 { 118 desc: "Specify GetIdentityCertificatesForServer", 119 IdentityOptions: IdentityCertificateOptions{ 120 GetIdentityCertificatesForServer: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) { 121 return nil, nil 122 }, 123 }, 124 }, 125 } 126 for _, test := range tests { 127 test := test 128 t.Run(test.desc, func(t *testing.T) { 129 clientOptions := &ClientOptions{ 130 VType: test.clientVType, 131 IdentityOptions: test.IdentityOptions, 132 RootOptions: test.RootOptions, 133 } 134 _, err := clientOptions.config() 135 if err == nil { 136 t.Fatalf("ClientOptions{%v}.config() returns no err, wantErr != nil", clientOptions) 137 } 138 }) 139 } 140} 141 142func (s) TestClientOptionsConfigSuccessCases(t *testing.T) { 143 tests := []struct { 144 desc string 145 clientVType VerificationType 146 IdentityOptions IdentityCertificateOptions 147 RootOptions RootCertificateOptions 148 }{ 149 { 150 desc: "Use system default if no fields in RootCertificateOptions is specified", 151 clientVType: CertVerification, 152 }, 153 { 154 desc: "Good case with mutual TLS", 155 clientVType: CertVerification, 156 RootOptions: RootCertificateOptions{ 157 RootProvider: fakeProvider{}, 158 }, 159 IdentityOptions: IdentityCertificateOptions{ 160 IdentityProvider: fakeProvider{pt: provTypeIdentity}, 161 }, 162 }, 163 } 164 for _, test := range tests { 165 test := test 166 t.Run(test.desc, func(t *testing.T) { 167 clientOptions := &ClientOptions{ 168 VType: test.clientVType, 169 IdentityOptions: test.IdentityOptions, 170 RootOptions: test.RootOptions, 171 } 172 clientConfig, err := clientOptions.config() 173 if err != nil { 174 t.Fatalf("ClientOptions{%v}.config() = %v, wantErr == nil", clientOptions, err) 175 } 176 // Verify that the system-provided certificates would be used 177 // when no verification method was set in clientOptions. 178 if clientOptions.RootOptions.RootCACerts == nil && 179 clientOptions.RootOptions.GetRootCertificates == nil && clientOptions.RootOptions.RootProvider == nil { 180 if clientConfig.RootCAs == nil { 181 t.Fatalf("Failed to assign system-provided certificates on the client side.") 182 } 183 } 184 }) 185 } 186} 187 188func (s) TestServerOptionsConfigErrorCases(t *testing.T) { 189 tests := []struct { 190 desc string 191 requireClientCert bool 192 serverVType VerificationType 193 IdentityOptions IdentityCertificateOptions 194 RootOptions RootCertificateOptions 195 }{ 196 { 197 desc: "Skip default verification and provide no root credentials", 198 requireClientCert: true, 199 serverVType: SkipVerification, 200 }, 201 { 202 desc: "More than one fields in RootCertificateOptions is specified", 203 requireClientCert: true, 204 serverVType: CertVerification, 205 RootOptions: RootCertificateOptions{ 206 RootCACerts: x509.NewCertPool(), 207 GetRootCertificates: func(*GetRootCAsParams) (*GetRootCAsResults, error) { 208 return nil, nil 209 }, 210 }, 211 }, 212 { 213 desc: "More than one fields in IdentityCertificateOptions is specified", 214 serverVType: CertVerification, 215 IdentityOptions: IdentityCertificateOptions{ 216 Certificates: []tls.Certificate{}, 217 IdentityProvider: fakeProvider{pt: provTypeIdentity}, 218 }, 219 }, 220 { 221 desc: "no field in IdentityCertificateOptions is specified", 222 serverVType: CertVerification, 223 }, 224 { 225 desc: "Specify GetIdentityCertificatesForClient", 226 IdentityOptions: IdentityCertificateOptions{ 227 GetIdentityCertificatesForClient: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { 228 return nil, nil 229 }, 230 }, 231 }, 232 } 233 for _, test := range tests { 234 test := test 235 t.Run(test.desc, func(t *testing.T) { 236 serverOptions := &ServerOptions{ 237 VType: test.serverVType, 238 RequireClientCert: test.requireClientCert, 239 IdentityOptions: test.IdentityOptions, 240 RootOptions: test.RootOptions, 241 } 242 _, err := serverOptions.config() 243 if err == nil { 244 t.Fatalf("ServerOptions{%v}.config() returns no err, wantErr != nil", serverOptions) 245 } 246 }) 247 } 248} 249 250func (s) TestServerOptionsConfigSuccessCases(t *testing.T) { 251 tests := []struct { 252 desc string 253 requireClientCert bool 254 serverVType VerificationType 255 IdentityOptions IdentityCertificateOptions 256 RootOptions RootCertificateOptions 257 }{ 258 { 259 desc: "Use system default if no fields in RootCertificateOptions is specified", 260 requireClientCert: true, 261 serverVType: CertVerification, 262 IdentityOptions: IdentityCertificateOptions{ 263 Certificates: []tls.Certificate{}, 264 }, 265 }, 266 { 267 desc: "Good case with mutual TLS", 268 requireClientCert: true, 269 serverVType: CertVerification, 270 RootOptions: RootCertificateOptions{ 271 RootProvider: fakeProvider{}, 272 }, 273 IdentityOptions: IdentityCertificateOptions{ 274 GetIdentityCertificatesForServer: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) { 275 return nil, nil 276 }, 277 }, 278 }, 279 } 280 for _, test := range tests { 281 test := test 282 t.Run(test.desc, func(t *testing.T) { 283 serverOptions := &ServerOptions{ 284 VType: test.serverVType, 285 RequireClientCert: test.requireClientCert, 286 IdentityOptions: test.IdentityOptions, 287 RootOptions: test.RootOptions, 288 } 289 serverConfig, err := serverOptions.config() 290 if err != nil { 291 t.Fatalf("ServerOptions{%v}.config() = %v, wantErr == nil", serverOptions, err) 292 } 293 // Verify that the system-provided certificates would be used 294 // when no verification method was set in serverOptions. 295 if serverOptions.RootOptions.RootCACerts == nil && 296 serverOptions.RootOptions.GetRootCertificates == nil && serverOptions.RootOptions.RootProvider == nil { 297 if serverConfig.ClientCAs == nil { 298 t.Fatalf("Failed to assign system-provided certificates on the server side.") 299 } 300 } 301 }) 302 } 303} 304 305func (s) TestClientServerHandshake(t *testing.T) { 306 cs := &testutils.CertStore{} 307 if err := cs.LoadCerts(); err != nil { 308 t.Fatalf("cs.LoadCerts() failed, err: %v", err) 309 } 310 getRootCAsForClient := func(params *GetRootCAsParams) (*GetRootCAsResults, error) { 311 return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil 312 } 313 clientVerifyFuncGood := func(params *VerificationFuncParams) (*VerificationResults, error) { 314 if params.ServerName == "" { 315 return nil, errors.New("client side server name should have a value") 316 } 317 // "foo.bar.com" is the common name on server certificate server_cert_1.pem. 318 if len(params.VerifiedChains) > 0 && (params.Leaf == nil || params.Leaf.Subject.CommonName != "foo.bar.com") { 319 return nil, errors.New("client side params parsing error") 320 } 321 322 return &VerificationResults{}, nil 323 } 324 verifyFuncBad := func(params *VerificationFuncParams) (*VerificationResults, error) { 325 return nil, fmt.Errorf("custom verification function failed") 326 } 327 getRootCAsForServer := func(params *GetRootCAsParams) (*GetRootCAsResults, error) { 328 return &GetRootCAsResults{TrustCerts: cs.ServerTrust1}, nil 329 } 330 serverVerifyFunc := func(params *VerificationFuncParams) (*VerificationResults, error) { 331 if params.ServerName != "" { 332 return nil, errors.New("server side server name should not have a value") 333 } 334 // "foo.bar.hoo.com" is the common name on client certificate client_cert_1.pem. 335 if len(params.VerifiedChains) > 0 && (params.Leaf == nil || params.Leaf.Subject.CommonName != "foo.bar.hoo.com") { 336 return nil, errors.New("server side params parsing error") 337 } 338 339 return &VerificationResults{}, nil 340 } 341 getRootCAsForServerBad := func(params *GetRootCAsParams) (*GetRootCAsResults, error) { 342 return nil, fmt.Errorf("bad root certificate reloading") 343 } 344 for _, test := range []struct { 345 desc string 346 clientCert []tls.Certificate 347 clientGetCert func(*tls.CertificateRequestInfo) (*tls.Certificate, error) 348 clientRoot *x509.CertPool 349 clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error) 350 clientVerifyFunc CustomVerificationFunc 351 clientVType VerificationType 352 clientRootProvider certprovider.Provider 353 clientIdentityProvider certprovider.Provider 354 clientExpectHandshakeError bool 355 serverMutualTLS bool 356 serverCert []tls.Certificate 357 serverGetCert func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) 358 serverRoot *x509.CertPool 359 serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error) 360 serverVerifyFunc CustomVerificationFunc 361 serverVType VerificationType 362 serverRootProvider certprovider.Provider 363 serverIdentityProvider certprovider.Provider 364 serverExpectError bool 365 }{ 366 // Client: nil setting except verifyFuncGood 367 // Server: only set serverCert with mutual TLS off 368 // Expected Behavior: success 369 // Reason: we will use verifyFuncGood to verify the server, 370 // if either clientCert or clientGetCert is not set 371 { 372 desc: "Client has no trust cert with verifyFuncGood; server sends peer cert", 373 clientVerifyFunc: clientVerifyFuncGood, 374 clientVType: SkipVerification, 375 serverCert: []tls.Certificate{cs.ServerCert1}, 376 serverVType: CertAndHostVerification, 377 }, 378 // Client: set clientGetRoot and clientVerifyFunc 379 // Server: only set serverCert with mutual TLS off 380 // Expected Behavior: success 381 { 382 desc: "Client sets reload root function with verifyFuncGood; server sends peer cert", 383 clientGetRoot: getRootCAsForClient, 384 clientVerifyFunc: clientVerifyFuncGood, 385 clientVType: CertVerification, 386 serverCert: []tls.Certificate{cs.ServerCert1}, 387 serverVType: CertAndHostVerification, 388 }, 389 // Client: set clientGetRoot and bad clientVerifyFunc function 390 // Server: only set serverCert with mutual TLS off 391 // Expected Behavior: server side failure and client handshake failure 392 // Reason: custom verification function is bad 393 { 394 desc: "Client sets reload root function with verifyFuncBad; server sends peer cert", 395 clientGetRoot: getRootCAsForClient, 396 clientVerifyFunc: verifyFuncBad, 397 clientVType: CertVerification, 398 clientExpectHandshakeError: true, 399 serverCert: []tls.Certificate{cs.ServerCert1}, 400 serverVType: CertVerification, 401 serverExpectError: true, 402 }, 403 // Client: set clientGetRoot, clientVerifyFunc and clientCert 404 // Server: set serverRoot and serverCert with mutual TLS on 405 // Expected Behavior: success 406 { 407 desc: "Client sets peer cert, reload root function with verifyFuncGood; server sets peer cert and root cert; mutualTLS", 408 clientCert: []tls.Certificate{cs.ClientCert1}, 409 clientGetRoot: getRootCAsForClient, 410 clientVerifyFunc: clientVerifyFuncGood, 411 clientVType: CertVerification, 412 serverMutualTLS: true, 413 serverCert: []tls.Certificate{cs.ServerCert1}, 414 serverRoot: cs.ServerTrust1, 415 serverVType: CertVerification, 416 }, 417 // Client: set clientGetRoot, clientVerifyFunc and clientCert 418 // Server: set serverGetRoot and serverCert with mutual TLS on 419 // Expected Behavior: success 420 { 421 desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, reload root function; mutualTLS", 422 clientCert: []tls.Certificate{cs.ClientCert1}, 423 clientGetRoot: getRootCAsForClient, 424 clientVerifyFunc: clientVerifyFuncGood, 425 clientVType: CertVerification, 426 serverMutualTLS: true, 427 serverCert: []tls.Certificate{cs.ServerCert1}, 428 serverGetRoot: getRootCAsForServer, 429 serverVType: CertVerification, 430 }, 431 // Client: set clientGetRoot, clientVerifyFunc and clientCert 432 // Server: set serverGetRoot returning error and serverCert with mutual 433 // TLS on 434 // Expected Behavior: server side failure 435 // Reason: server side reloading returns failure 436 { 437 desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, bad reload root function; mutualTLS", 438 clientCert: []tls.Certificate{cs.ClientCert1}, 439 clientGetRoot: getRootCAsForClient, 440 clientVerifyFunc: clientVerifyFuncGood, 441 clientVType: CertVerification, 442 serverMutualTLS: true, 443 serverCert: []tls.Certificate{cs.ServerCert1}, 444 serverGetRoot: getRootCAsForServerBad, 445 serverVType: CertVerification, 446 serverExpectError: true, 447 }, 448 // Client: set clientGetRoot, clientVerifyFunc and clientGetCert 449 // Server: set serverGetRoot and serverGetCert with mutual TLS on 450 // Expected Behavior: success 451 { 452 desc: "Client sets reload peer/root function with verifyFuncGood; Server sets reload peer/root function with verifyFuncGood; mutualTLS", 453 clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { 454 return &cs.ClientCert1, nil 455 }, 456 clientGetRoot: getRootCAsForClient, 457 clientVerifyFunc: clientVerifyFuncGood, 458 clientVType: CertVerification, 459 serverMutualTLS: true, 460 serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { 461 return []*tls.Certificate{&cs.ServerCert1}, nil 462 }, 463 serverGetRoot: getRootCAsForServer, 464 serverVerifyFunc: serverVerifyFunc, 465 serverVType: CertVerification, 466 }, 467 // Client: set everything but with the wrong peer cert not trusted by 468 // server 469 // Server: set serverGetRoot and serverGetCert with mutual TLS on 470 // Expected Behavior: server side returns failure because of 471 // certificate mismatch 472 { 473 desc: "Client sends wrong peer cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS", 474 clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { 475 return &cs.ServerCert1, nil 476 }, 477 clientGetRoot: getRootCAsForClient, 478 clientVerifyFunc: clientVerifyFuncGood, 479 clientVType: CertVerification, 480 serverMutualTLS: true, 481 serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { 482 return []*tls.Certificate{&cs.ServerCert1}, nil 483 }, 484 serverGetRoot: getRootCAsForServer, 485 serverVerifyFunc: serverVerifyFunc, 486 serverVType: CertVerification, 487 serverExpectError: true, 488 }, 489 // Client: set everything but with the wrong trust cert not trusting server 490 // Server: set serverGetRoot and serverGetCert with mutual TLS on 491 // Expected Behavior: server side and client side return failure due to 492 // certificate mismatch and handshake failure 493 { 494 desc: "Client has wrong trust cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS", 495 clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { 496 return &cs.ClientCert1, nil 497 }, 498 clientGetRoot: getRootCAsForServer, 499 clientVerifyFunc: clientVerifyFuncGood, 500 clientVType: CertVerification, 501 clientExpectHandshakeError: true, 502 serverMutualTLS: true, 503 serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { 504 return []*tls.Certificate{&cs.ServerCert1}, nil 505 }, 506 serverGetRoot: getRootCAsForServer, 507 serverVerifyFunc: serverVerifyFunc, 508 serverVType: CertVerification, 509 serverExpectError: true, 510 }, 511 // Client: set clientGetRoot, clientVerifyFunc and clientCert 512 // Server: set everything but with the wrong peer cert not trusted by 513 // client 514 // Expected Behavior: server side and client side return failure due to 515 // certificate mismatch and handshake failure 516 { 517 desc: "Client sets reload peer/root function with verifyFuncGood; Server sends wrong peer cert; mutualTLS", 518 clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { 519 return &cs.ClientCert1, nil 520 }, 521 clientGetRoot: getRootCAsForClient, 522 clientVerifyFunc: clientVerifyFuncGood, 523 clientVType: CertVerification, 524 serverMutualTLS: true, 525 serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { 526 return []*tls.Certificate{&cs.ClientCert1}, nil 527 }, 528 serverGetRoot: getRootCAsForServer, 529 serverVerifyFunc: serverVerifyFunc, 530 serverVType: CertVerification, 531 serverExpectError: true, 532 }, 533 // Client: set clientGetRoot, clientVerifyFunc and clientCert 534 // Server: set everything but with the wrong trust cert not trusting client 535 // Expected Behavior: server side and client side return failure due to 536 // certificate mismatch and handshake failure 537 { 538 desc: "Client sets reload peer/root function with verifyFuncGood; Server has wrong trust cert; mutualTLS", 539 clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { 540 return &cs.ClientCert1, nil 541 }, 542 clientGetRoot: getRootCAsForClient, 543 clientVerifyFunc: clientVerifyFuncGood, 544 clientVType: CertVerification, 545 clientExpectHandshakeError: true, 546 serverMutualTLS: true, 547 serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { 548 return []*tls.Certificate{&cs.ServerCert1}, nil 549 }, 550 serverGetRoot: getRootCAsForClient, 551 serverVerifyFunc: serverVerifyFunc, 552 serverVType: CertVerification, 553 serverExpectError: true, 554 }, 555 // Client: set clientGetRoot, clientVerifyFunc and clientCert 556 // Server: set serverGetRoot and serverCert, but with bad verifyFunc 557 // Expected Behavior: server side and client side return failure due to 558 // server custom check fails 559 { 560 desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets bad custom check; mutualTLS", 561 clientCert: []tls.Certificate{cs.ClientCert1}, 562 clientGetRoot: getRootCAsForClient, 563 clientVerifyFunc: clientVerifyFuncGood, 564 clientVType: CertVerification, 565 clientExpectHandshakeError: true, 566 serverMutualTLS: true, 567 serverCert: []tls.Certificate{cs.ServerCert1}, 568 serverGetRoot: getRootCAsForServer, 569 serverVerifyFunc: verifyFuncBad, 570 serverVType: CertVerification, 571 serverExpectError: true, 572 }, 573 // Client: set a clientIdentityProvider which will get multiple cert chains 574 // Server: set serverIdentityProvider and serverRootProvider with mutual TLS on 575 // Expected Behavior: server side failure due to multiple cert chains in 576 // clientIdentityProvider 577 { 578 desc: "Client sets multiple certs in clientIdentityProvider; Server sets root and identity provider; mutualTLS", 579 clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true, wantMultiCert: true}, 580 clientRootProvider: fakeProvider{isClient: true}, 581 clientVerifyFunc: clientVerifyFuncGood, 582 clientVType: CertVerification, 583 serverMutualTLS: true, 584 serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false}, 585 serverRootProvider: fakeProvider{isClient: false}, 586 serverVType: CertVerification, 587 serverExpectError: true, 588 }, 589 // Client: set a bad clientIdentityProvider 590 // Server: set serverIdentityProvider and serverRootProvider with mutual TLS on 591 // Expected Behavior: server side failure due to bad clientIdentityProvider 592 { 593 desc: "Client sets bad clientIdentityProvider; Server sets root and identity provider; mutualTLS", 594 clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true, wantError: true}, 595 clientRootProvider: fakeProvider{isClient: true}, 596 clientVerifyFunc: clientVerifyFuncGood, 597 clientVType: CertVerification, 598 serverMutualTLS: true, 599 serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false}, 600 serverRootProvider: fakeProvider{isClient: false}, 601 serverVType: CertVerification, 602 serverExpectError: true, 603 }, 604 // Client: set clientIdentityProvider and clientRootProvider 605 // Server: set bad serverRootProvider with mutual TLS on 606 // Expected Behavior: server side failure due to bad serverRootProvider 607 { 608 desc: "Client sets root and identity provider; Server sets bad root provider; mutualTLS", 609 clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true}, 610 clientRootProvider: fakeProvider{isClient: true}, 611 clientVerifyFunc: clientVerifyFuncGood, 612 clientVType: CertVerification, 613 serverMutualTLS: true, 614 serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false}, 615 serverRootProvider: fakeProvider{isClient: false, wantError: true}, 616 serverVType: CertVerification, 617 serverExpectError: true, 618 }, 619 // Client: set clientIdentityProvider and clientRootProvider 620 // Server: set serverIdentityProvider and serverRootProvider with mutual TLS on 621 // Expected Behavior: success 622 { 623 desc: "Client sets root and identity provider; Server sets root and identity provider; mutualTLS", 624 clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true}, 625 clientRootProvider: fakeProvider{isClient: true}, 626 clientVerifyFunc: clientVerifyFuncGood, 627 clientVType: CertVerification, 628 serverMutualTLS: true, 629 serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false}, 630 serverRootProvider: fakeProvider{isClient: false}, 631 serverVType: CertVerification, 632 }, 633 // Client: set clientIdentityProvider and clientRootProvider 634 // Server: set serverIdentityProvider getting multiple cert chains and serverRootProvider with mutual TLS on 635 // Expected Behavior: success, because server side has SNI 636 { 637 desc: "Client sets root and identity provider; Server sets multiple certs in serverIdentityProvider; mutualTLS", 638 clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true}, 639 clientRootProvider: fakeProvider{isClient: true}, 640 clientVerifyFunc: clientVerifyFuncGood, 641 clientVType: CertVerification, 642 serverMutualTLS: true, 643 serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false, wantMultiCert: true}, 644 serverRootProvider: fakeProvider{isClient: false}, 645 serverVType: CertVerification, 646 }, 647 } { 648 test := test 649 t.Run(test.desc, func(t *testing.T) { 650 done := make(chan credentials.AuthInfo, 1) 651 lis, err := net.Listen("tcp", "localhost:0") 652 if err != nil { 653 t.Fatalf("Failed to listen: %v", err) 654 } 655 // Start a server using ServerOptions in another goroutine. 656 serverOptions := &ServerOptions{ 657 IdentityOptions: IdentityCertificateOptions{ 658 Certificates: test.serverCert, 659 GetIdentityCertificatesForServer: test.serverGetCert, 660 IdentityProvider: test.serverIdentityProvider, 661 }, 662 RootOptions: RootCertificateOptions{ 663 RootCACerts: test.serverRoot, 664 GetRootCertificates: test.serverGetRoot, 665 RootProvider: test.serverRootProvider, 666 }, 667 RequireClientCert: test.serverMutualTLS, 668 VerifyPeer: test.serverVerifyFunc, 669 VType: test.serverVType, 670 } 671 go func(done chan credentials.AuthInfo, lis net.Listener, serverOptions *ServerOptions) { 672 serverRawConn, err := lis.Accept() 673 if err != nil { 674 close(done) 675 return 676 } 677 serverTLS, err := NewServerCreds(serverOptions) 678 if err != nil { 679 serverRawConn.Close() 680 close(done) 681 return 682 } 683 _, serverAuthInfo, err := serverTLS.ServerHandshake(serverRawConn) 684 if err != nil { 685 serverRawConn.Close() 686 close(done) 687 return 688 } 689 done <- serverAuthInfo 690 }(done, lis, serverOptions) 691 defer lis.Close() 692 // Start a client using ClientOptions and connects to the server. 693 lisAddr := lis.Addr().String() 694 conn, err := net.Dial("tcp", lisAddr) 695 if err != nil { 696 t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err) 697 } 698 defer conn.Close() 699 clientOptions := &ClientOptions{ 700 IdentityOptions: IdentityCertificateOptions{ 701 Certificates: test.clientCert, 702 GetIdentityCertificatesForClient: test.clientGetCert, 703 IdentityProvider: test.clientIdentityProvider, 704 }, 705 VerifyPeer: test.clientVerifyFunc, 706 RootOptions: RootCertificateOptions{ 707 RootCACerts: test.clientRoot, 708 GetRootCertificates: test.clientGetRoot, 709 RootProvider: test.clientRootProvider, 710 }, 711 VType: test.clientVType, 712 } 713 clientTLS, err := NewClientCreds(clientOptions) 714 if err != nil { 715 t.Fatalf("NewClientCreds failed: %v", err) 716 } 717 _, clientAuthInfo, handshakeErr := clientTLS.ClientHandshake(context.Background(), 718 lisAddr, conn) 719 // wait until server sends serverAuthInfo or fails. 720 serverAuthInfo, ok := <-done 721 if !ok && test.serverExpectError { 722 return 723 } 724 if ok && test.serverExpectError || !ok && !test.serverExpectError { 725 t.Fatalf("Server side error mismatch, got %v, want %v", !ok, test.serverExpectError) 726 } 727 if handshakeErr != nil && test.clientExpectHandshakeError { 728 return 729 } 730 if handshakeErr != nil && !test.clientExpectHandshakeError || 731 handshakeErr == nil && test.clientExpectHandshakeError { 732 t.Fatalf("Expect error: %v, but err is %v", 733 test.clientExpectHandshakeError, handshakeErr) 734 } 735 if !compare(clientAuthInfo, serverAuthInfo) { 736 t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, 737 clientAuthInfo, serverAuthInfo) 738 } 739 }) 740 } 741} 742 743func compare(a1, a2 credentials.AuthInfo) bool { 744 if a1.AuthType() != a2.AuthType() { 745 return false 746 } 747 switch a1.AuthType() { 748 case "tls": 749 state1 := a1.(credentials.TLSInfo).State 750 state2 := a2.(credentials.TLSInfo).State 751 if state1.Version == state2.Version && 752 state1.HandshakeComplete == state2.HandshakeComplete && 753 state1.CipherSuite == state2.CipherSuite && 754 state1.NegotiatedProtocol == state2.NegotiatedProtocol { 755 return true 756 } 757 return false 758 default: 759 return false 760 } 761} 762 763func (s) TestAdvancedTLSOverrideServerName(t *testing.T) { 764 expectedServerName := "server.name" 765 cs := &testutils.CertStore{} 766 if err := cs.LoadCerts(); err != nil { 767 t.Fatalf("cs.LoadCerts() failed, err: %v", err) 768 } 769 clientOptions := &ClientOptions{ 770 RootOptions: RootCertificateOptions{ 771 RootCACerts: cs.ClientTrust1, 772 }, 773 ServerNameOverride: expectedServerName, 774 } 775 c, err := NewClientCreds(clientOptions) 776 if err != nil { 777 t.Fatalf("Client is unable to create credentials. Error: %v", err) 778 } 779 c.OverrideServerName(expectedServerName) 780 if c.Info().ServerName != expectedServerName { 781 t.Fatalf("c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) 782 } 783} 784 785func (s) TestGetCertificatesSNI(t *testing.T) { 786 cs := &testutils.CertStore{} 787 if err := cs.LoadCerts(); err != nil { 788 t.Fatalf("cs.LoadCerts() failed, err: %v", err) 789 } 790 tests := []struct { 791 desc string 792 serverName string 793 // Use Common Name on the certificate to differentiate if we choose the right cert. The common name on all of the three certs are different. 794 wantCommonName string 795 }{ 796 { 797 desc: "Select ServerCert1", 798 // "foo.bar.com" is the common name on server certificate server_cert_1.pem. 799 serverName: "foo.bar.com", 800 wantCommonName: "foo.bar.com", 801 }, 802 { 803 desc: "Select serverCert3", 804 // "foo.bar.server3.com" is the common name on server certificate server_cert_3.pem. 805 // "google.com" is one of the DNS names on server certificate server_cert_3.pem. 806 serverName: "google.com", 807 wantCommonName: "foo.bar.server3.com", 808 }, 809 } 810 for _, test := range tests { 811 test := test 812 t.Run(test.desc, func(t *testing.T) { 813 serverOptions := &ServerOptions{ 814 IdentityOptions: IdentityCertificateOptions{ 815 GetIdentityCertificatesForServer: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { 816 return []*tls.Certificate{&cs.ServerCert1, &cs.ServerCert2, &cs.ServerPeer3}, nil 817 }, 818 }, 819 } 820 serverConfig, err := serverOptions.config() 821 if err != nil { 822 t.Fatalf("serverOptions.config() failed: %v", err) 823 } 824 pointFormatUncompressed := uint8(0) 825 clientHello := &tls.ClientHelloInfo{ 826 CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA}, 827 ServerName: test.serverName, 828 SupportedCurves: []tls.CurveID{tls.CurveP256}, 829 SupportedPoints: []uint8{pointFormatUncompressed}, 830 SupportedVersions: []uint16{tls.VersionTLS10}, 831 } 832 gotCertificate, err := serverConfig.GetCertificate(clientHello) 833 if err != nil { 834 t.Fatalf("serverConfig.GetCertificate(clientHello) failed: %v", err) 835 } 836 if gotCertificate == nil || len(gotCertificate.Certificate) == 0 { 837 t.Fatalf("Got nil or empty Certificate after calling serverConfig.GetCertificate.") 838 } 839 parsedCert, err := x509.ParseCertificate(gotCertificate.Certificate[0]) 840 if err != nil { 841 t.Fatalf("x509.ParseCertificate(%v) failed: %v", gotCertificate.Certificate[0], err) 842 } 843 if parsedCert == nil { 844 t.Fatalf("Got nil Certificate after calling x509.ParseCertificate.") 845 } 846 if parsedCert.Subject.CommonName != test.wantCommonName { 847 t.Errorf("Common name mismatch, got %v, want %v", parsedCert.Subject.CommonName, test.wantCommonName) 848 } 849 }) 850 } 851} 852