1// +build go1.12
2
3/*
4 *
5 * Copyright 2019 gRPC authors.
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 *     http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 *
19 */
20
21package advancedtls
22
23import (
24	"context"
25	"crypto/tls"
26	"crypto/x509"
27	"errors"
28	"fmt"
29	"net"
30	"testing"
31
32	"google.golang.org/grpc/credentials"
33	"google.golang.org/grpc/credentials/tls/certprovider"
34	"google.golang.org/grpc/internal/grpctest"
35	"google.golang.org/grpc/security/advancedtls/internal/testutils"
36)
37
38type s struct {
39	grpctest.Tester
40}
41
42func Test(t *testing.T) {
43	grpctest.RunSubTests(t, s{})
44}
45
// provType selects which kind of key material a fakeProvider hands out.
type provType int

const (
	// provTypeRoot makes a fakeProvider return trust (root) certificates.
	provTypeRoot provType = 0
	// provTypeIdentity makes a fakeProvider return identity certificate chains.
	provTypeIdentity provType = 1
)
52
53type fakeProvider struct {
54	pt            provType
55	isClient      bool
56	wantMultiCert bool
57	wantError     bool
58}
59
// KeyMaterial implements certprovider.Provider for tests.
//
// When wantError is set it fails immediately. Otherwise it loads the test
// certificate store and returns one of the following:
//   - pt == provTypeRoot: the client or server trust pool, chosen by isClient.
//   - any other pt: identity chain(s). The client chains are used only when pt
//     is provTypeIdentity and isClient is set; otherwise the server chains are
//     used. A second chain is appended when wantMultiCert is set.
//
// ctx is unused.
func (f fakeProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) {
	if f.wantError {
		return nil, fmt.Errorf("bad fakeProvider")
	}
	store := &testutils.CertStore{}
	if err := store.LoadCerts(); err != nil {
		return nil, fmt.Errorf("cs.LoadCerts() failed, err: %v", err)
	}

	if f.pt == provTypeRoot {
		roots := store.ServerTrust1
		if f.isClient {
			roots = store.ClientTrust1
		}
		return &certprovider.KeyMaterial{Roots: roots}, nil
	}

	primary, secondary := store.ServerCert1, store.ServerCert2
	if f.pt == provTypeIdentity && f.isClient {
		primary, secondary = store.ClientCert1, store.ClientCert2
	}
	chains := []tls.Certificate{primary}
	if f.wantMultiCert {
		chains = append(chains, secondary)
	}
	return &certprovider.KeyMaterial{Certs: chains}, nil
}
85
86func (f fakeProvider) Close() {}
87
// TestClientOptionsConfigErrorCases checks that ClientOptions.config returns a
// non-nil error for each invalid combination of client options in the table.
func (s) TestClientOptionsConfigErrorCases(t *testing.T) {
	tests := []struct {
		desc            string
		clientVType     VerificationType
		IdentityOptions IdentityCertificateOptions
		RootOptions     RootCertificateOptions
	}{
		{
			desc:        "Skip default verification and provide no root credentials",
			clientVType: SkipVerification,
		},
		{
			desc:        "More than one fields in RootCertificateOptions is specified",
			clientVType: CertVerification,
			RootOptions: RootCertificateOptions{
				RootCACerts:  x509.NewCertPool(),
				RootProvider: fakeProvider{},
			},
		},
		{
			desc:        "More than one fields in IdentityCertificateOptions is specified",
			clientVType: CertVerification,
			IdentityOptions: IdentityCertificateOptions{
				GetIdentityCertificatesForClient: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
					return nil, nil
				},
				IdentityProvider: fakeProvider{pt: provTypeIdentity},
			},
		},
		{
			desc: "Specify GetIdentityCertificatesForServer",
			IdentityOptions: IdentityCertificateOptions{
				GetIdentityCertificatesForServer: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
					return nil, nil
				},
			},
		},
	}
	for _, tc := range tests {
		tc := tc
		t.Run(tc.desc, func(t *testing.T) {
			opts := &ClientOptions{
				VType:           tc.clientVType,
				IdentityOptions: tc.IdentityOptions,
				RootOptions:     tc.RootOptions,
			}
			if _, err := opts.config(); err == nil {
				t.Fatalf("ClientOptions{%v}.config() returns no err, wantErr != nil", opts)
			}
		})
	}
}
141
// TestClientOptionsConfigSuccessCases checks that ClientOptions.config accepts
// valid client options. When no root source is configured, it also checks that
// the resulting tls.Config still carries a RootCAs pool (the system default).
func (s) TestClientOptionsConfigSuccessCases(t *testing.T) {
	tests := []struct {
		desc            string
		clientVType     VerificationType
		IdentityOptions IdentityCertificateOptions
		RootOptions     RootCertificateOptions
	}{
		{
			desc:        "Use system default if no fields in RootCertificateOptions is specified",
			clientVType: CertVerification,
		},
		{
			desc:        "Good case with mutual TLS",
			clientVType: CertVerification,
			RootOptions: RootCertificateOptions{
				RootProvider: fakeProvider{},
			},
			IdentityOptions: IdentityCertificateOptions{
				IdentityProvider: fakeProvider{pt: provTypeIdentity},
			},
		},
	}
	for _, tc := range tests {
		tc := tc
		t.Run(tc.desc, func(t *testing.T) {
			opts := &ClientOptions{
				VType:           tc.clientVType,
				IdentityOptions: tc.IdentityOptions,
				RootOptions:     tc.RootOptions,
			}
			cfg, err := opts.config()
			if err != nil {
				t.Fatalf("ClientOptions{%v}.config() = %v, wantErr == nil", opts, err)
			}
			noRootSource := opts.RootOptions.RootCACerts == nil &&
				opts.RootOptions.GetRootCertificates == nil &&
				opts.RootOptions.RootProvider == nil
			// Without any configured root source, config() is expected to fall
			// back to the system-provided certificate pool.
			if noRootSource && cfg.RootCAs == nil {
				t.Fatalf("Failed to assign system-provided certificates on the client side.")
			}
		})
	}
}
187
// TestServerOptionsConfigErrorCases checks that ServerOptions.config returns a
// non-nil error for each invalid combination of server options in the table.
func (s) TestServerOptionsConfigErrorCases(t *testing.T) {
	tests := []struct {
		desc              string
		requireClientCert bool
		serverVType       VerificationType
		IdentityOptions   IdentityCertificateOptions
		RootOptions       RootCertificateOptions
	}{
		{
			desc:              "Skip default verification and provide no root credentials",
			requireClientCert: true,
			serverVType:       SkipVerification,
		},
		{
			desc:              "More than one fields in RootCertificateOptions is specified",
			requireClientCert: true,
			serverVType:       CertVerification,
			RootOptions: RootCertificateOptions{
				RootCACerts: x509.NewCertPool(),
				GetRootCertificates: func(*GetRootCAsParams) (*GetRootCAsResults, error) {
					return nil, nil
				},
			},
		},
		{
			desc:        "More than one fields in IdentityCertificateOptions is specified",
			serverVType: CertVerification,
			IdentityOptions: IdentityCertificateOptions{
				Certificates:     []tls.Certificate{},
				IdentityProvider: fakeProvider{pt: provTypeIdentity},
			},
		},
		{
			desc:        "no field in IdentityCertificateOptions is specified",
			serverVType: CertVerification,
		},
		{
			desc: "Specify GetIdentityCertificatesForClient",
			IdentityOptions: IdentityCertificateOptions{
				GetIdentityCertificatesForClient: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
					return nil, nil
				},
			},
		},
	}
	for _, tc := range tests {
		tc := tc
		t.Run(tc.desc, func(t *testing.T) {
			opts := &ServerOptions{
				VType:             tc.serverVType,
				RequireClientCert: tc.requireClientCert,
				IdentityOptions:   tc.IdentityOptions,
				RootOptions:       tc.RootOptions,
			}
			if _, err := opts.config(); err == nil {
				t.Fatalf("ServerOptions{%v}.config() returns no err, wantErr != nil", opts)
			}
		})
	}
}
249
// TestServerOptionsConfigSuccessCases checks that ServerOptions.config accepts
// valid server options. When no root source is configured, it also checks that
// the resulting tls.Config still carries a ClientCAs pool (the system default).
func (s) TestServerOptionsConfigSuccessCases(t *testing.T) {
	tests := []struct {
		desc              string
		requireClientCert bool
		serverVType       VerificationType
		IdentityOptions   IdentityCertificateOptions
		RootOptions       RootCertificateOptions
	}{
		{
			desc:              "Use system default if no fields in RootCertificateOptions is specified",
			requireClientCert: true,
			serverVType:       CertVerification,
			IdentityOptions: IdentityCertificateOptions{
				Certificates: []tls.Certificate{},
			},
		},
		{
			desc:              "Good case with mutual TLS",
			requireClientCert: true,
			serverVType:       CertVerification,
			RootOptions: RootCertificateOptions{
				RootProvider: fakeProvider{},
			},
			IdentityOptions: IdentityCertificateOptions{
				GetIdentityCertificatesForServer: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
					return nil, nil
				},
			},
		},
	}
	for _, tc := range tests {
		tc := tc
		t.Run(tc.desc, func(t *testing.T) {
			opts := &ServerOptions{
				VType:             tc.serverVType,
				RequireClientCert: tc.requireClientCert,
				IdentityOptions:   tc.IdentityOptions,
				RootOptions:       tc.RootOptions,
			}
			cfg, err := opts.config()
			if err != nil {
				t.Fatalf("ServerOptions{%v}.config() = %v, wantErr == nil", opts, err)
			}
			noRootSource := opts.RootOptions.RootCACerts == nil &&
				opts.RootOptions.GetRootCertificates == nil &&
				opts.RootOptions.RootProvider == nil
			// Without any configured root source, config() is expected to fall
			// back to the system-provided certificate pool.
			if noRootSource && cfg.ClientCAs == nil {
				t.Fatalf("Failed to assign system-provided certificates on the server side.")
			}
		})
	}
}
304
// TestClientServerHandshake runs a real TLS handshake over a loopback TCP
// connection for each table entry. Both sides use credentials built from the
// entry's options via NewServerCreds and NewClientCreds. The test checks that
// each side fails exactly when the entry says it should. When both sides
// succeed, the client and server AuthInfo must agree (see compare).
func (s) TestClientServerHandshake(t *testing.T) {
	cs := &testutils.CertStore{}
	if err := cs.LoadCerts(); err != nil {
		t.Fatalf("cs.LoadCerts() failed, err: %v", err)
	}
	getRootCAsForClient := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
		return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil
	}
	clientVerifyFuncGood := func(params *VerificationFuncParams) (*VerificationResults, error) {
		if params.ServerName == "" {
			return nil, errors.New("client side server name should have a value")
		}
		// "foo.bar.com" is the common name on server certificate server_cert_1.pem.
		if len(params.VerifiedChains) > 0 && (params.Leaf == nil || params.Leaf.Subject.CommonName != "foo.bar.com") {
			return nil, errors.New("client side params parsing error")
		}

		return &VerificationResults{}, nil
	}
	verifyFuncBad := func(params *VerificationFuncParams) (*VerificationResults, error) {
		return nil, fmt.Errorf("custom verification function failed")
	}
	getRootCAsForServer := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
		return &GetRootCAsResults{TrustCerts: cs.ServerTrust1}, nil
	}
	serverVerifyFunc := func(params *VerificationFuncParams) (*VerificationResults, error) {
		if params.ServerName != "" {
			return nil, errors.New("server side server name should not have a value")
		}
		// "foo.bar.hoo.com" is the common name on client certificate client_cert_1.pem.
		if len(params.VerifiedChains) > 0 && (params.Leaf == nil || params.Leaf.Subject.CommonName != "foo.bar.hoo.com") {
			return nil, errors.New("server side params parsing error")
		}

		return &VerificationResults{}, nil
	}
	getRootCAsForServerBad := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
		return nil, fmt.Errorf("bad root certificate reloading")
	}
	for _, test := range []struct {
		desc                       string
		clientCert                 []tls.Certificate
		clientGetCert              func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
		clientRoot                 *x509.CertPool
		clientGetRoot              func(params *GetRootCAsParams) (*GetRootCAsResults, error)
		clientVerifyFunc           CustomVerificationFunc
		clientVType                VerificationType
		clientRootProvider         certprovider.Provider
		clientIdentityProvider     certprovider.Provider
		clientExpectHandshakeError bool
		serverMutualTLS            bool
		serverCert                 []tls.Certificate
		serverGetCert              func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
		serverRoot                 *x509.CertPool
		serverGetRoot              func(params *GetRootCAsParams) (*GetRootCAsResults, error)
		serverVerifyFunc           CustomVerificationFunc
		serverVType                VerificationType
		serverRootProvider         certprovider.Provider
		serverIdentityProvider     certprovider.Provider
		serverExpectError          bool
	}{
		// Client: nil setting except verifyFuncGood, with default verification
		// skipped
		// Server: only set serverCert with mutual TLS off
		// Expected Behavior: success
		// Reason: the client skips default certificate verification and relies
		// on clientVerifyFuncGood to check the server
		{
			desc:             "Client has no trust cert with verifyFuncGood; server sends peer cert",
			clientVerifyFunc: clientVerifyFuncGood,
			clientVType:      SkipVerification,
			serverCert:       []tls.Certificate{cs.ServerCert1},
			serverVType:      CertAndHostVerification,
		},
		// Client: set clientGetRoot and clientVerifyFunc
		// Server: only set serverCert with mutual TLS off
		// Expected Behavior: success
		{
			desc:             "Client sets reload root function with verifyFuncGood; server sends peer cert",
			clientGetRoot:    getRootCAsForClient,
			clientVerifyFunc: clientVerifyFuncGood,
			clientVType:      CertVerification,
			serverCert:       []tls.Certificate{cs.ServerCert1},
			serverVType:      CertAndHostVerification,
		},
		// Client: set clientGetRoot and bad clientVerifyFunc function
		// Server: only set serverCert with mutual TLS off
		// Expected Behavior: server side failure and client handshake failure
		// Reason: custom verification function is bad
		{
			desc:                       "Client sets reload root function with verifyFuncBad; server sends peer cert",
			clientGetRoot:              getRootCAsForClient,
			clientVerifyFunc:           verifyFuncBad,
			clientVType:                CertVerification,
			clientExpectHandshakeError: true,
			serverCert:                 []tls.Certificate{cs.ServerCert1},
			serverVType:                CertVerification,
			serverExpectError:          true,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientCert
		// Server: set serverRoot and serverCert with mutual TLS on
		// Expected Behavior: success
		{
			desc:             "Client sets peer cert, reload root function with verifyFuncGood; server sets peer cert and root cert; mutualTLS",
			clientCert:       []tls.Certificate{cs.ClientCert1},
			clientGetRoot:    getRootCAsForClient,
			clientVerifyFunc: clientVerifyFuncGood,
			clientVType:      CertVerification,
			serverMutualTLS:  true,
			serverCert:       []tls.Certificate{cs.ServerCert1},
			serverRoot:       cs.ServerTrust1,
			serverVType:      CertVerification,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientCert
		// Server: set serverGetRoot and serverCert with mutual TLS on
		// Expected Behavior: success
		{
			desc:             "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, reload root function; mutualTLS",
			clientCert:       []tls.Certificate{cs.ClientCert1},
			clientGetRoot:    getRootCAsForClient,
			clientVerifyFunc: clientVerifyFuncGood,
			clientVType:      CertVerification,
			serverMutualTLS:  true,
			serverCert:       []tls.Certificate{cs.ServerCert1},
			serverGetRoot:    getRootCAsForServer,
			serverVType:      CertVerification,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientCert
		// Server: set serverGetRoot returning error and serverCert with mutual
		// TLS on
		// Expected Behavior: server side failure
		// Reason: server side reloading returns failure
		{
			desc:              "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, bad reload root function; mutualTLS",
			clientCert:        []tls.Certificate{cs.ClientCert1},
			clientGetRoot:     getRootCAsForClient,
			clientVerifyFunc:  clientVerifyFuncGood,
			clientVType:       CertVerification,
			serverMutualTLS:   true,
			serverCert:        []tls.Certificate{cs.ServerCert1},
			serverGetRoot:     getRootCAsForServerBad,
			serverVType:       CertVerification,
			serverExpectError: true,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientGetCert
		// Server: set serverGetRoot and serverGetCert with mutual TLS on
		// Expected Behavior: success
		{
			desc: "Client sets reload peer/root function with verifyFuncGood; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
			clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
				return &cs.ClientCert1, nil
			},
			clientGetRoot:    getRootCAsForClient,
			clientVerifyFunc: clientVerifyFuncGood,
			clientVType:      CertVerification,
			serverMutualTLS:  true,
			serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
				return []*tls.Certificate{&cs.ServerCert1}, nil
			},
			serverGetRoot:    getRootCAsForServer,
			serverVerifyFunc: serverVerifyFunc,
			serverVType:      CertVerification,
		},
		// Client: set everything but with the wrong peer cert not trusted by
		// server
		// Server: set serverGetRoot and serverGetCert with mutual TLS on
		// Expected Behavior: server side returns failure because of
		// certificate mismatch
		{
			desc: "Client sends wrong peer cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
			clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
				return &cs.ServerCert1, nil
			},
			clientGetRoot:    getRootCAsForClient,
			clientVerifyFunc: clientVerifyFuncGood,
			clientVType:      CertVerification,
			serverMutualTLS:  true,
			serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
				return []*tls.Certificate{&cs.ServerCert1}, nil
			},
			serverGetRoot:     getRootCAsForServer,
			serverVerifyFunc:  serverVerifyFunc,
			serverVType:       CertVerification,
			serverExpectError: true,
		},
		// Client: set everything but with the wrong trust cert not trusting server
		// Server: set serverGetRoot and serverGetCert with mutual TLS on
		// Expected Behavior: server side and client side return failure due to
		// certificate mismatch and handshake failure
		{
			desc: "Client has wrong trust cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
			clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
				return &cs.ClientCert1, nil
			},
			clientGetRoot:              getRootCAsForServer,
			clientVerifyFunc:           clientVerifyFuncGood,
			clientVType:                CertVerification,
			clientExpectHandshakeError: true,
			serverMutualTLS:            true,
			serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
				return []*tls.Certificate{&cs.ServerCert1}, nil
			},
			serverGetRoot:     getRootCAsForServer,
			serverVerifyFunc:  serverVerifyFunc,
			serverVType:       CertVerification,
			serverExpectError: true,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientCert
		// Server: set everything but with the wrong peer cert not trusted by
		// client
		// Expected Behavior: server side and client side return failure due to
		// certificate mismatch and handshake failure
		{
			desc: "Client sets reload peer/root function with verifyFuncGood; Server sends wrong peer cert; mutualTLS",
			clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
				return &cs.ClientCert1, nil
			},
			clientGetRoot:    getRootCAsForClient,
			clientVerifyFunc: clientVerifyFuncGood,
			clientVType:      CertVerification,
			serverMutualTLS:  true,
			serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
				return []*tls.Certificate{&cs.ClientCert1}, nil
			},
			serverGetRoot:     getRootCAsForServer,
			serverVerifyFunc:  serverVerifyFunc,
			serverVType:       CertVerification,
			serverExpectError: true,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientCert
		// Server: set everything but with the wrong trust cert not trusting client
		// Expected Behavior: server side and client side return failure due to
		// certificate mismatch and handshake failure
		{
			desc: "Client sets reload peer/root function with verifyFuncGood; Server has wrong trust cert; mutualTLS",
			clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
				return &cs.ClientCert1, nil
			},
			clientGetRoot:              getRootCAsForClient,
			clientVerifyFunc:           clientVerifyFuncGood,
			clientVType:                CertVerification,
			clientExpectHandshakeError: true,
			serverMutualTLS:            true,
			serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
				return []*tls.Certificate{&cs.ServerCert1}, nil
			},
			serverGetRoot:     getRootCAsForClient,
			serverVerifyFunc:  serverVerifyFunc,
			serverVType:       CertVerification,
			serverExpectError: true,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientCert
		// Server: set serverGetRoot and serverCert, but with bad verifyFunc
		// Expected Behavior: server side and client side return failure due to
		// server custom check fails
		{
			desc:                       "Client sets peer cert, reload root function with verifyFuncGood; Server sets bad custom check; mutualTLS",
			clientCert:                 []tls.Certificate{cs.ClientCert1},
			clientGetRoot:              getRootCAsForClient,
			clientVerifyFunc:           clientVerifyFuncGood,
			clientVType:                CertVerification,
			clientExpectHandshakeError: true,
			serverMutualTLS:            true,
			serverCert:                 []tls.Certificate{cs.ServerCert1},
			serverGetRoot:              getRootCAsForServer,
			serverVerifyFunc:           verifyFuncBad,
			serverVType:                CertVerification,
			serverExpectError:          true,
		},
		// Client: set a clientIdentityProvider which will get multiple cert chains
		// Server: set serverIdentityProvider and serverRootProvider with mutual TLS on
		// Expected Behavior: server side failure due to multiple cert chains in
		// clientIdentityProvider
		{
			desc:                   "Client sets multiple certs in clientIdentityProvider; Server sets root and identity provider; mutualTLS",
			clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true, wantMultiCert: true},
			clientRootProvider:     fakeProvider{isClient: true},
			clientVerifyFunc:       clientVerifyFuncGood,
			clientVType:            CertVerification,
			serverMutualTLS:        true,
			serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false},
			serverRootProvider:     fakeProvider{isClient: false},
			serverVType:            CertVerification,
			serverExpectError:      true,
		},
		// Client: set a bad clientIdentityProvider
		// Server: set serverIdentityProvider and serverRootProvider with mutual TLS on
		// Expected Behavior: server side failure due to bad clientIdentityProvider
		{
			desc:                   "Client sets bad clientIdentityProvider; Server sets root and identity provider; mutualTLS",
			clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true, wantError: true},
			clientRootProvider:     fakeProvider{isClient: true},
			clientVerifyFunc:       clientVerifyFuncGood,
			clientVType:            CertVerification,
			serverMutualTLS:        true,
			serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false},
			serverRootProvider:     fakeProvider{isClient: false},
			serverVType:            CertVerification,
			serverExpectError:      true,
		},
		// Client: set clientIdentityProvider and clientRootProvider
		// Server: set bad serverRootProvider with mutual TLS on
		// Expected Behavior: server side failure due to bad serverRootProvider
		{
			desc:                   "Client sets root and identity provider; Server sets bad root provider; mutualTLS",
			clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true},
			clientRootProvider:     fakeProvider{isClient: true},
			clientVerifyFunc:       clientVerifyFuncGood,
			clientVType:            CertVerification,
			serverMutualTLS:        true,
			serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false},
			serverRootProvider:     fakeProvider{isClient: false, wantError: true},
			serverVType:            CertVerification,
			serverExpectError:      true,
		},
		// Client: set clientIdentityProvider and clientRootProvider
		// Server: set serverIdentityProvider and serverRootProvider with mutual TLS on
		// Expected Behavior: success
		{
			desc:                   "Client sets root and identity provider; Server sets root and identity provider; mutualTLS",
			clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true},
			clientRootProvider:     fakeProvider{isClient: true},
			clientVerifyFunc:       clientVerifyFuncGood,
			clientVType:            CertVerification,
			serverMutualTLS:        true,
			serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false},
			serverRootProvider:     fakeProvider{isClient: false},
			serverVType:            CertVerification,
		},
		// Client: set clientIdentityProvider and clientRootProvider
		// Server: set serverIdentityProvider getting multiple cert chains and serverRootProvider with mutual TLS on
		// Expected Behavior: success, because server side has SNI
		{
			desc:                   "Client sets root and identity provider; Server sets multiple certs in serverIdentityProvider; mutualTLS",
			clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true},
			clientRootProvider:     fakeProvider{isClient: true},
			clientVerifyFunc:       clientVerifyFuncGood,
			clientVType:            CertVerification,
			serverMutualTLS:        true,
			serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false, wantMultiCert: true},
			serverRootProvider:     fakeProvider{isClient: false},
			serverVType:            CertVerification,
		},
	} {
		test := test
		t.Run(test.desc, func(t *testing.T) {
			done := make(chan credentials.AuthInfo, 1)
			lis, err := net.Listen("tcp", "localhost:0")
			if err != nil {
				t.Fatalf("Failed to listen: %v", err)
			}
			// stop is closed when this subtest returns. After a successful
			// handshake the server goroutine waits on it and then closes its end
			// of the connection. This avoids leaking the socket without tearing
			// it down while the subtest may still be using the client side.
			stop := make(chan struct{})
			defer close(stop)
			// Start a server using ServerOptions in another goroutine.
			serverOptions := &ServerOptions{
				IdentityOptions: IdentityCertificateOptions{
					Certificates:                     test.serverCert,
					GetIdentityCertificatesForServer: test.serverGetCert,
					IdentityProvider:                 test.serverIdentityProvider,
				},
				RootOptions: RootCertificateOptions{
					RootCACerts:         test.serverRoot,
					GetRootCertificates: test.serverGetRoot,
					RootProvider:        test.serverRootProvider,
				},
				RequireClientCert: test.serverMutualTLS,
				VerifyPeer:        test.serverVerifyFunc,
				VType:             test.serverVType,
			}
			go func(done chan credentials.AuthInfo, lis net.Listener, serverOptions *ServerOptions) {
				serverRawConn, err := lis.Accept()
				if err != nil {
					close(done)
					return
				}
				serverTLS, err := NewServerCreds(serverOptions)
				if err != nil {
					serverRawConn.Close()
					close(done)
					return
				}
				_, serverAuthInfo, err := serverTLS.ServerHandshake(serverRawConn)
				if err != nil {
					serverRawConn.Close()
					close(done)
					return
				}
				done <- serverAuthInfo
				// Release the server side of the connection once the subtest is
				// finished; previously it was never closed on this path.
				<-stop
				serverRawConn.Close()
			}(done, lis, serverOptions)
			defer lis.Close()
			// Start a client using ClientOptions and connects to the server.
			lisAddr := lis.Addr().String()
			conn, err := net.Dial("tcp", lisAddr)
			if err != nil {
				t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err)
			}
			defer conn.Close()
			clientOptions := &ClientOptions{
				IdentityOptions: IdentityCertificateOptions{
					Certificates:                     test.clientCert,
					GetIdentityCertificatesForClient: test.clientGetCert,
					IdentityProvider:                 test.clientIdentityProvider,
				},
				VerifyPeer: test.clientVerifyFunc,
				RootOptions: RootCertificateOptions{
					RootCACerts:         test.clientRoot,
					GetRootCertificates: test.clientGetRoot,
					RootProvider:        test.clientRootProvider,
				},
				VType: test.clientVType,
			}
			clientTLS, err := NewClientCreds(clientOptions)
			if err != nil {
				t.Fatalf("NewClientCreds failed: %v", err)
			}
			_, clientAuthInfo, handshakeErr := clientTLS.ClientHandshake(context.Background(),
				lisAddr, conn)
			// Wait until the server sends serverAuthInfo or fails; a closed
			// channel (ok == false) means the server side failed.
			serverAuthInfo, ok := <-done
			gotServerErr := !ok
			if gotServerErr != test.serverExpectError {
				t.Fatalf("Server side error mismatch, got %v, want %v", gotServerErr, test.serverExpectError)
			}
			if gotServerErr {
				// Expected server failure; the client outcome is not checked.
				return
			}
			gotClientErr := handshakeErr != nil
			if gotClientErr != test.clientExpectHandshakeError {
				t.Fatalf("Expect error: %v, but err is %v",
					test.clientExpectHandshakeError, handshakeErr)
			}
			if gotClientErr {
				return
			}
			if !compare(clientAuthInfo, serverAuthInfo) {
				t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr,
					clientAuthInfo, serverAuthInfo)
			}
		})
	}
}
742
743func compare(a1, a2 credentials.AuthInfo) bool {
744	if a1.AuthType() != a2.AuthType() {
745		return false
746	}
747	switch a1.AuthType() {
748	case "tls":
749		state1 := a1.(credentials.TLSInfo).State
750		state2 := a2.(credentials.TLSInfo).State
751		if state1.Version == state2.Version &&
752			state1.HandshakeComplete == state2.HandshakeComplete &&
753			state1.CipherSuite == state2.CipherSuite &&
754			state1.NegotiatedProtocol == state2.NegotiatedProtocol {
755			return true
756		}
757		return false
758	default:
759		return false
760	}
761}
762
// TestAdvancedTLSOverrideServerName verifies that OverrideServerName replaces
// the server name reported by the client credentials' Info().
//
// The credentials are created with a different ServerNameOverride. Otherwise
// the test would pass even if OverrideServerName had no effect. The error
// returned by OverrideServerName is checked rather than ignored.
func (s) TestAdvancedTLSOverrideServerName(t *testing.T) {
	initialServerName := "initial.server.name"
	expectedServerName := "server.name"
	cs := &testutils.CertStore{}
	if err := cs.LoadCerts(); err != nil {
		t.Fatalf("cs.LoadCerts() failed, err: %v", err)
	}
	clientOptions := &ClientOptions{
		RootOptions: RootCertificateOptions{
			RootCACerts: cs.ClientTrust1,
		},
		ServerNameOverride: initialServerName,
	}
	c, err := NewClientCreds(clientOptions)
	if err != nil {
		t.Fatalf("Client is unable to create credentials. Error: %v", err)
	}
	if err := c.OverrideServerName(expectedServerName); err != nil {
		t.Fatalf("c.OverrideServerName(%q) failed: %v", expectedServerName, err)
	}
	if got := c.Info().ServerName; got != expectedServerName {
		t.Fatalf("c.Info().ServerName = %v, want %v", got, expectedServerName)
	}
}
784
785func (s) TestGetCertificatesSNI(t *testing.T) {
786	cs := &testutils.CertStore{}
787	if err := cs.LoadCerts(); err != nil {
788		t.Fatalf("cs.LoadCerts() failed, err: %v", err)
789	}
790	tests := []struct {
791		desc       string
792		serverName string
793		// Use Common Name on the certificate to differentiate if we choose the right cert. The common name on all of the three certs are different.
794		wantCommonName string
795	}{
796		{
797			desc: "Select ServerCert1",
798			// "foo.bar.com" is the common name on server certificate server_cert_1.pem.
799			serverName:     "foo.bar.com",
800			wantCommonName: "foo.bar.com",
801		},
802		{
803			desc: "Select serverCert3",
804			// "foo.bar.server3.com" is the common name on server certificate server_cert_3.pem.
805			// "google.com" is one of the DNS names on server certificate server_cert_3.pem.
806			serverName:     "google.com",
807			wantCommonName: "foo.bar.server3.com",
808		},
809	}
810	for _, test := range tests {
811		test := test
812		t.Run(test.desc, func(t *testing.T) {
813			serverOptions := &ServerOptions{
814				IdentityOptions: IdentityCertificateOptions{
815					GetIdentityCertificatesForServer: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
816						return []*tls.Certificate{&cs.ServerCert1, &cs.ServerCert2, &cs.ServerPeer3}, nil
817					},
818				},
819			}
820			serverConfig, err := serverOptions.config()
821			if err != nil {
822				t.Fatalf("serverOptions.config() failed: %v", err)
823			}
824			pointFormatUncompressed := uint8(0)
825			clientHello := &tls.ClientHelloInfo{
826				CipherSuites:      []uint16{tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA},
827				ServerName:        test.serverName,
828				SupportedCurves:   []tls.CurveID{tls.CurveP256},
829				SupportedPoints:   []uint8{pointFormatUncompressed},
830				SupportedVersions: []uint16{tls.VersionTLS10},
831			}
832			gotCertificate, err := serverConfig.GetCertificate(clientHello)
833			if err != nil {
834				t.Fatalf("serverConfig.GetCertificate(clientHello) failed: %v", err)
835			}
836			if gotCertificate == nil || len(gotCertificate.Certificate) == 0 {
837				t.Fatalf("Got nil or empty Certificate after calling serverConfig.GetCertificate.")
838			}
839			parsedCert, err := x509.ParseCertificate(gotCertificate.Certificate[0])
840			if err != nil {
841				t.Fatalf("x509.ParseCertificate(%v) failed: %v", gotCertificate.Certificate[0], err)
842			}
843			if parsedCert == nil {
844				t.Fatalf("Got nil Certificate after calling x509.ParseCertificate.")
845			}
846			if parsedCert.Subject.CommonName != test.wantCommonName {
847				t.Errorf("Common name mismatch, got %v, want %v", parsedCert.Subject.CommonName, test.wantCommonName)
848			}
849		})
850	}
851}
852