1// +build go1.12
2
3/*
4 *
5 * Copyright 2020 gRPC authors.
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 *     http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 *
19 */
20
21package advancedtls
22
23import (
24	"context"
25	"crypto/tls"
26	"crypto/x509"
27	"fmt"
28	"io/ioutil"
29	"net"
30	"os"
31	"sync"
32	"testing"
33	"time"
34
35	"google.golang.org/grpc"
36	"google.golang.org/grpc/credentials"
37	"google.golang.org/grpc/credentials/tls/certprovider"
38	"google.golang.org/grpc/credentials/tls/certprovider/pemfile"
39	pb "google.golang.org/grpc/examples/helloworld/helloworld"
40	"google.golang.org/grpc/security/advancedtls/internal/testutils"
41	"google.golang.org/grpc/security/advancedtls/testdata"
42)
43
const (
	// defaultTestTimeout bounds connections that are expected to succeed.
	defaultTestTimeout = 5000 * time.Millisecond
	// defaultTestShortTimeout bounds connections that are expected to fail,
	// so those scenarios do not stall the test.
	defaultTestShortTimeout = 10 * time.Millisecond
	// credRefreshingInterval is the RefreshDuration handed to the PEM file
	// providers, i.e. how often they re-read the files they watch.
	credRefreshingInterval = 200 * time.Millisecond
	// sleepInterval is how long the tests wait for a credential update to be
	// picked up; it spans two refresh intervals.
	sleepInterval = 2 * credRefreshingInterval
)
54
// stageInfo holds the stage number of the current phase of an integration
// test. Test-case callbacks (certificate getters, root getters, custom
// verification functions) read the stage to decide which certificates or
// verification results to hand out, while the test body advances it; the
// counter is therefore guarded by a mutex.
type stageInfo struct {
	mutex sync.Mutex
	stage int
}

// increase advances the stage by one.
func (s *stageInfo) increase() {
	s.mutex.Lock()
	s.stage++
	s.mutex.Unlock()
}

// read returns the current stage.
func (s *stageInfo) read() int {
	s.mutex.Lock()
	current := s.stage
	s.mutex.Unlock()
	return current
}

// reset moves the stage back to 0.
func (s *stageInfo) reset() {
	s.mutex.Lock()
	s.stage = 0
	s.mutex.Unlock()
}
82
83type greeterServer struct {
84	pb.UnimplementedGreeterServer
85}
86
87// sayHello is a simple implementation of the pb.GreeterServer SayHello method.
88func (greeterServer) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
89	return &pb.HelloReply{Message: "Hello " + in.Name}, nil
90}
91
92// TODO(ZhenLian): remove shouldFail to the function signature to provider
93// tests.
94func callAndVerify(msg string, client pb.GreeterClient, shouldFail bool) error {
95	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
96	defer cancel()
97	_, err := client.SayHello(ctx, &pb.HelloRequest{Name: msg})
98	if want, got := shouldFail == true, err != nil; got != want {
99		return fmt.Errorf("want and got mismatch,  want shouldFail=%v, got fail=%v, rpc error: %v", want, got, err)
100	}
101	return nil
102}
103
104// TODO(ZhenLian): remove shouldFail and add ...DialOption to the function
105// signature to provider cleaner tests.
106func callAndVerifyWithClientConn(connCtx context.Context, address string, msg string, creds credentials.TransportCredentials, shouldFail bool) (*grpc.ClientConn, pb.GreeterClient, error) {
107	var conn *grpc.ClientConn
108	var err error
109	// If we want the test to fail, we establish a non-blocking connection to
110	// avoid it hangs and killed by the context.
111	if shouldFail {
112		conn, err = grpc.DialContext(connCtx, address, grpc.WithTransportCredentials(creds))
113		if err != nil {
114			return nil, nil, fmt.Errorf("client failed to connect to %s. Error: %v", address, err)
115		}
116	} else {
117		conn, err = grpc.DialContext(connCtx, address, grpc.WithTransportCredentials(creds), grpc.WithBlock())
118		if err != nil {
119			return nil, nil, fmt.Errorf("client failed to connect to %s. Error: %v", address, err)
120		}
121	}
122	greetClient := pb.NewGreeterClient(conn)
123	err = callAndVerify(msg, greetClient, shouldFail)
124	if err != nil {
125		return nil, nil, err
126	}
127	return conn, greetClient, nil
128}
129
130// The advanced TLS features are tested in different stages.
131// At stage 0, we establish a good connection between client and server.
132// At stage 1, we change one factor(it could be we change the server's
133// certificate, or custom verification function, etc), and test if the
134// following connections would be dropped.
135// At stage 2, we re-establish the connection by changing the counterpart of
136// the factor we modified in stage 1.
137// (could be change the client's trust certificate, or change custom
138// verification function, etc)
139func (s) TestEnd2End(t *testing.T) {
140	cs := &testutils.CertStore{}
141	if err := cs.LoadCerts(); err != nil {
142		t.Fatalf("cs.LoadCerts() failed, err: %v", err)
143	}
144	stage := &stageInfo{}
145	for _, test := range []struct {
146		desc             string
147		clientCert       []tls.Certificate
148		clientGetCert    func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
149		clientRoot       *x509.CertPool
150		clientGetRoot    func(params *GetRootCAsParams) (*GetRootCAsResults, error)
151		clientVerifyFunc CustomVerificationFunc
152		clientVType      VerificationType
153		serverCert       []tls.Certificate
154		serverGetCert    func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
155		serverRoot       *x509.CertPool
156		serverGetRoot    func(params *GetRootCAsParams) (*GetRootCAsResults, error)
157		serverVerifyFunc CustomVerificationFunc
158		serverVType      VerificationType
159	}{
160		// Test Scenarios:
161		// At initialization(stage = 0), client will be initialized with cert
162		// ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1.
163		// The mutual authentication works at the beginning, since ClientCert1 is
164		// trusted by ServerTrust1, and ServerCert1 by ClientTrust1.
165		// At stage 1, client changes ClientCert1 to ClientCert2. Since ClientCert2
166		// is not trusted by ServerTrust1, following rpc calls are expected to
167		// fail, while the previous rpc calls are still good because those are
168		// already authenticated.
169		// At stage 2, the server changes ServerTrust1 to ServerTrust2, and we
170		// should see it again accepts the connection, since ClientCert2 is trusted
171		// by ServerTrust2.
172		{
173			desc: "test the reloading feature for client identity callback and server trust callback",
174			clientGetCert: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
175				switch stage.read() {
176				case 0:
177					return &cs.ClientCert1, nil
178				default:
179					return &cs.ClientCert2, nil
180				}
181			},
182			clientRoot: cs.ClientTrust1,
183			clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
184				return &VerificationResults{}, nil
185			},
186			clientVType: CertVerification,
187			serverCert:  []tls.Certificate{cs.ServerCert1},
188			serverGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
189				switch stage.read() {
190				case 0, 1:
191					return &GetRootCAsResults{TrustCerts: cs.ServerTrust1}, nil
192				default:
193					return &GetRootCAsResults{TrustCerts: cs.ServerTrust2}, nil
194				}
195			},
196			serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
197				return &VerificationResults{}, nil
198			},
199			serverVType: CertVerification,
200		},
201		// Test Scenarios:
202		// At initialization(stage = 0), client will be initialized with cert
203		// ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1.
204		// The mutual authentication works at the beginning, since ClientCert1 is
205		// trusted by ServerTrust1, and ServerCert1 by ClientTrust1.
206		// At stage 1, server changes ServerCert1 to ServerCert2. Since ServerCert2
207		// is not trusted by ClientTrust1, following rpc calls are expected to
208		// fail, while the previous rpc calls are still good because those are
209		// already authenticated.
210		// At stage 2, the client changes ClientTrust1 to ClientTrust2, and we
211		// should see it again accepts the connection, since ServerCert2 is trusted
212		// by ClientTrust2.
213		{
214			desc:       "test the reloading feature for server identity callback and client trust callback",
215			clientCert: []tls.Certificate{cs.ClientCert1},
216			clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
217				switch stage.read() {
218				case 0, 1:
219					return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil
220				default:
221					return &GetRootCAsResults{TrustCerts: cs.ClientTrust2}, nil
222				}
223			},
224			clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
225				return &VerificationResults{}, nil
226			},
227			clientVType: CertVerification,
228			serverGetCert: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
229				switch stage.read() {
230				case 0:
231					return []*tls.Certificate{&cs.ServerCert1}, nil
232				default:
233					return []*tls.Certificate{&cs.ServerCert2}, nil
234				}
235			},
236			serverRoot: cs.ServerTrust1,
237			serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
238				return &VerificationResults{}, nil
239			},
240			serverVType: CertVerification,
241		},
242		// Test Scenarios:
243		// At initialization(stage = 0), client will be initialized with cert
244		// ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1.
245		// The mutual authentication works at the beginning, since ClientCert1
246		// trusted by ServerTrust1, ServerCert1 by ClientTrust1, and also the
247		// custom verification check allows the CommonName on ServerCert1.
248		// At stage 1, server changes ServerCert1 to ServerCert2, and client
249		// changes ClientTrust1 to ClientTrust2. Although ServerCert2 is trusted by
250		// ClientTrust2, our authorization check only accepts ServerCert1, and
251		// hence the following calls should fail. Previous connections should
252		// not be affected.
253		// At stage 2, the client changes authorization check to only accept
254		// ServerCert2. Now we should see the connection becomes normal again.
255		{
256			desc:       "test client custom verification",
257			clientCert: []tls.Certificate{cs.ClientCert1},
258			clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
259				switch stage.read() {
260				case 0:
261					return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil
262				default:
263					return &GetRootCAsResults{TrustCerts: cs.ClientTrust2}, nil
264				}
265			},
266			clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
267				if len(params.RawCerts) == 0 {
268					return nil, fmt.Errorf("no peer certs")
269				}
270				cert, err := x509.ParseCertificate(params.RawCerts[0])
271				if err != nil || cert == nil {
272					return nil, fmt.Errorf("failed to parse certificate: " + err.Error())
273				}
274				authzCheck := false
275				switch stage.read() {
276				case 0, 1:
277					// foo.bar.com is the common name on ServerCert1
278					if cert.Subject.CommonName == "foo.bar.com" {
279						authzCheck = true
280					}
281				default:
282					// foo.bar.server2.com is the common name on ServerCert2
283					if cert.Subject.CommonName == "foo.bar.server2.com" {
284						authzCheck = true
285					}
286				}
287				if authzCheck {
288					return &VerificationResults{}, nil
289				}
290				return nil, fmt.Errorf("custom authz check fails")
291			},
292			clientVType: CertVerification,
293			serverGetCert: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
294				switch stage.read() {
295				case 0:
296					return []*tls.Certificate{&cs.ServerCert1}, nil
297				default:
298					return []*tls.Certificate{&cs.ServerCert2}, nil
299				}
300			},
301			serverRoot: cs.ServerTrust1,
302			serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
303				return &VerificationResults{}, nil
304			},
305			serverVType: CertVerification,
306		},
307		// Test Scenarios:
308		// At initialization(stage = 0), client will be initialized with cert
309		// ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1.
310		// The mutual authentication works at the beginning, since ClientCert1
311		// trusted by ServerTrust1, ServerCert1 by ClientTrust1, and also the
312		// custom verification check on server side allows all connections.
313		// At stage 1, server disallows the the connections by setting custom
314		// verification check. The following calls should fail. Previous
315		// connections should not be affected.
316		// At stage 2, server allows all the connections again and the
317		// authentications should go back to normal.
318		{
319			desc:       "TestServerCustomVerification",
320			clientCert: []tls.Certificate{cs.ClientCert1},
321			clientRoot: cs.ClientTrust1,
322			clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
323				return &VerificationResults{}, nil
324			},
325			clientVType: CertVerification,
326			serverCert:  []tls.Certificate{cs.ServerCert1},
327			serverRoot:  cs.ServerTrust1,
328			serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
329				switch stage.read() {
330				case 0, 2:
331					return &VerificationResults{}, nil
332				case 1:
333					return nil, fmt.Errorf("custom authz check fails")
334				default:
335					return nil, fmt.Errorf("custom authz check fails")
336				}
337			},
338			serverVType: CertVerification,
339		},
340	} {
341		test := test
342		t.Run(test.desc, func(t *testing.T) {
343			// Start a server using ServerOptions in another goroutine.
344			serverOptions := &ServerOptions{
345				IdentityOptions: IdentityCertificateOptions{
346					Certificates:                     test.serverCert,
347					GetIdentityCertificatesForServer: test.serverGetCert,
348				},
349				RootOptions: RootCertificateOptions{
350					RootCACerts:         test.serverRoot,
351					GetRootCertificates: test.serverGetRoot,
352				},
353				RequireClientCert: true,
354				VerifyPeer:        test.serverVerifyFunc,
355				VType:             test.serverVType,
356			}
357			serverTLSCreds, err := NewServerCreds(serverOptions)
358			if err != nil {
359				t.Fatalf("failed to create server creds: %v", err)
360			}
361			s := grpc.NewServer(grpc.Creds(serverTLSCreds))
362			defer s.Stop()
363			lis, err := net.Listen("tcp", "localhost:0")
364			if err != nil {
365				t.Fatalf("failed to listen: %v", err)
366			}
367			defer lis.Close()
368			addr := fmt.Sprintf("localhost:%v", lis.Addr().(*net.TCPAddr).Port)
369			pb.RegisterGreeterServer(s, greeterServer{})
370			go s.Serve(lis)
371			clientOptions := &ClientOptions{
372				IdentityOptions: IdentityCertificateOptions{
373					Certificates:                     test.clientCert,
374					GetIdentityCertificatesForClient: test.clientGetCert,
375				},
376				VerifyPeer: test.clientVerifyFunc,
377				RootOptions: RootCertificateOptions{
378					RootCACerts:         test.clientRoot,
379					GetRootCertificates: test.clientGetRoot,
380				},
381				VType: test.clientVType,
382			}
383			clientTLSCreds, err := NewClientCreds(clientOptions)
384			if err != nil {
385				t.Fatalf("clientTLSCreds failed to create")
386			}
387			// ------------------------Scenario 1------------------------------------
388			// stage = 0, initial connection should succeed
389			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
390			defer cancel()
391			conn, greetClient, err := callAndVerifyWithClientConn(ctx, addr, "rpc call 1", clientTLSCreds, false)
392			if err != nil {
393				t.Fatal(err)
394			}
395			defer conn.Close()
396			// ----------------------------------------------------------------------
397			stage.increase()
398			// ------------------------Scenario 2------------------------------------
399			// stage = 1, previous connection should still succeed
400			err = callAndVerify("rpc call 2", greetClient, false)
401			if err != nil {
402				t.Fatal(err)
403			}
404			// ------------------------Scenario 3------------------------------------
405			// stage = 1, new connection should fail
406			shortCtx, shortCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
407			defer shortCancel()
408			conn2, greetClient, err := callAndVerifyWithClientConn(shortCtx, addr, "rpc call 3", clientTLSCreds, true)
409			if err != nil {
410				t.Fatal(err)
411			}
412			defer conn2.Close()
413			// ----------------------------------------------------------------------
414			stage.increase()
415			// ------------------------Scenario 4------------------------------------
416			// stage = 2,  new connection should succeed
417			conn3, greetClient, err := callAndVerifyWithClientConn(ctx, addr, "rpc call 4", clientTLSCreds, false)
418			if err != nil {
419				t.Fatal(err)
420			}
421			defer conn3.Close()
422			// ----------------------------------------------------------------------
423			stage.reset()
424		})
425	}
426}
427
// tmpCredsFiles holds the temporary files that back the PEM file providers in
// the provider end-to-end test. Only the file names are used; the handles
// returned by createTmpFiles are already closed.
type tmpCredsFiles struct {
	clientCertTmp  *os.File
	clientKeyTmp   *os.File
	clientTrustTmp *os.File
	serverCertTmp  *os.File
	serverKeyTmp   *os.File
	serverTrustTmp *os.File
}

// createTmpFiles creates one empty temporary file, in os.TempDir() and with
// the "pre-" prefix, for each credential tracked by tmpCredsFiles.
//
// Each file is closed right after it is created. Callers only need its name,
// and keeping the handles open would leak six descriptors for the rest of the
// test. If any step fails, the files created so far are removed and the error
// is returned with a nil *tmpCredsFiles.
func createTmpFiles() (*tmpCredsFiles, error) {
	tmpFiles := &tmpCredsFiles{}
	slots := []**os.File{
		&tmpFiles.clientCertTmp,
		&tmpFiles.clientKeyTmp,
		&tmpFiles.clientTrustTmp,
		&tmpFiles.serverCertTmp,
		&tmpFiles.serverKeyTmp,
		&tmpFiles.serverTrustTmp,
	}
	var created []string
	// cleanup removes the files created so far; removal is best effort.
	cleanup := func() {
		for _, name := range created {
			os.Remove(name)
		}
	}
	for _, slot := range slots {
		f, err := ioutil.TempFile(os.TempDir(), "pre-")
		if err != nil {
			cleanup()
			return nil, err
		}
		created = append(created, f.Name())
		if err := f.Close(); err != nil {
			cleanup()
			return nil, err
		}
		*slot = f
	}
	return tmpFiles, nil
}
467
468// Copy the credential contents to the temporary files.
469func (tmpFiles *tmpCredsFiles) copyCredsToTmpFiles() error {
470	if err := copyFileContents(testdata.Path("client_cert_1.pem"), tmpFiles.clientCertTmp.Name()); err != nil {
471		return err
472	}
473	if err := copyFileContents(testdata.Path("client_key_1.pem"), tmpFiles.clientKeyTmp.Name()); err != nil {
474		return err
475	}
476	if err := copyFileContents(testdata.Path("client_trust_cert_1.pem"), tmpFiles.clientTrustTmp.Name()); err != nil {
477		return err
478	}
479	if err := copyFileContents(testdata.Path("server_cert_1.pem"), tmpFiles.serverCertTmp.Name()); err != nil {
480		return err
481	}
482	if err := copyFileContents(testdata.Path("server_key_1.pem"), tmpFiles.serverKeyTmp.Name()); err != nil {
483		return err
484	}
485	if err := copyFileContents(testdata.Path("server_trust_cert_1.pem"), tmpFiles.serverTrustTmp.Name()); err != nil {
486		return err
487	}
488	return nil
489}
490
491func (tmpFiles *tmpCredsFiles) removeFiles() {
492	os.Remove(tmpFiles.clientCertTmp.Name())
493	os.Remove(tmpFiles.clientKeyTmp.Name())
494	os.Remove(tmpFiles.clientTrustTmp.Name())
495	os.Remove(tmpFiles.serverCertTmp.Name())
496	os.Remove(tmpFiles.serverKeyTmp.Name())
497	os.Remove(tmpFiles.serverTrustTmp.Name())
498}
499
// copyFileContents replaces the contents of destinationFile with the full
// contents of sourceFile. The destination is created with mode 0644 (before
// umask) if missing, and truncated otherwise. Read and write errors are
// returned unchanged.
func copyFileContents(sourceFile, destinationFile string) error {
	data, err := ioutil.ReadFile(sourceFile)
	if err != nil {
		return err
	}
	return ioutil.WriteFile(destinationFile, data, 0644)
}
511
512// Create PEMFileProvider(s) watching the content changes of temporary
513// files.
514func createProviders(tmpFiles *tmpCredsFiles) (certprovider.Provider, certprovider.Provider, certprovider.Provider, certprovider.Provider, error) {
515	clientIdentityOptions := pemfile.Options{
516		CertFile:        tmpFiles.clientCertTmp.Name(),
517		KeyFile:         tmpFiles.clientKeyTmp.Name(),
518		RefreshDuration: credRefreshingInterval,
519	}
520	clientIdentityProvider, err := pemfile.NewProvider(clientIdentityOptions)
521	if err != nil {
522		return nil, nil, nil, nil, err
523	}
524	clientRootOptions := pemfile.Options{
525		RootFile:        tmpFiles.clientTrustTmp.Name(),
526		RefreshDuration: credRefreshingInterval,
527	}
528	clientRootProvider, err := pemfile.NewProvider(clientRootOptions)
529	if err != nil {
530		return nil, nil, nil, nil, err
531	}
532	serverIdentityOptions := pemfile.Options{
533		CertFile:        tmpFiles.serverCertTmp.Name(),
534		KeyFile:         tmpFiles.serverKeyTmp.Name(),
535		RefreshDuration: credRefreshingInterval,
536	}
537	serverIdentityProvider, err := pemfile.NewProvider(serverIdentityOptions)
538	if err != nil {
539		return nil, nil, nil, nil, err
540	}
541	serverRootOptions := pemfile.Options{
542		RootFile:        tmpFiles.serverTrustTmp.Name(),
543		RefreshDuration: credRefreshingInterval,
544	}
545	serverRootProvider, err := pemfile.NewProvider(serverRootOptions)
546	if err != nil {
547		return nil, nil, nil, nil, err
548	}
549	return clientIdentityProvider, clientRootProvider, serverIdentityProvider, serverRootProvider, nil
550}
551
552// In order to test advanced TLS provider features, we used temporary files to
553// hold credential data, and copy the contents under testdata/ to these tmp
554// files.
555// Initially, we establish a good connection with providers watching contents
556// from tmp files.
557// Next, we change the identity certs that IdentityProvider is watching. Since
558// the identity key is not changed, the IdentityProvider should ignore the
559// update, and the connection should still be good.
560// Then the the identity key is changed. This time IdentityProvider should pick
561// up the update, and the connection should fail, due to the trust certs on the
562// other side is not changed.
563// Finally, the trust certs that other-side's RootProvider is watching get
564// changed. The connection should go back to normal again.
565func (s) TestPEMFileProviderEnd2End(t *testing.T) {
566	tmpFiles, err := createTmpFiles()
567	if err != nil {
568		t.Fatalf("createTmpFiles() failed, error: %v", err)
569	}
570	defer tmpFiles.removeFiles()
571	for _, test := range []struct {
572		desc                string
573		certUpdateFunc      func()
574		keyUpdateFunc       func()
575		trustCertUpdateFunc func()
576	}{
577		{
578			desc: "test the reloading feature for clientIdentityProvider and serverTrustProvider",
579			certUpdateFunc: func() {
580				err = copyFileContents(testdata.Path("client_cert_2.pem"), tmpFiles.clientCertTmp.Name())
581				if err != nil {
582					t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("client_cert_2.pem"), tmpFiles.clientCertTmp.Name(), err)
583				}
584			},
585			keyUpdateFunc: func() {
586				err = copyFileContents(testdata.Path("client_key_2.pem"), tmpFiles.clientKeyTmp.Name())
587				if err != nil {
588					t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("client_key_2.pem"), tmpFiles.clientKeyTmp.Name(), err)
589				}
590			},
591			trustCertUpdateFunc: func() {
592				err = copyFileContents(testdata.Path("server_trust_cert_2.pem"), tmpFiles.serverTrustTmp.Name())
593				if err != nil {
594					t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("server_trust_cert_2.pem"), tmpFiles.serverTrustTmp.Name(), err)
595				}
596			},
597		},
598		{
599			desc: "test the reloading feature for serverIdentityProvider and clientTrustProvider",
600			certUpdateFunc: func() {
601				err = copyFileContents(testdata.Path("server_cert_2.pem"), tmpFiles.serverCertTmp.Name())
602				if err != nil {
603					t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("server_cert_2.pem"), tmpFiles.serverCertTmp.Name(), err)
604				}
605			},
606			keyUpdateFunc: func() {
607				err = copyFileContents(testdata.Path("server_key_2.pem"), tmpFiles.serverKeyTmp.Name())
608				if err != nil {
609					t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("server_key_2.pem"), tmpFiles.serverKeyTmp.Name(), err)
610				}
611			},
612			trustCertUpdateFunc: func() {
613				err = copyFileContents(testdata.Path("client_trust_cert_2.pem"), tmpFiles.clientTrustTmp.Name())
614				if err != nil {
615					t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("client_trust_cert_2.pem"), tmpFiles.clientTrustTmp.Name(), err)
616				}
617			},
618		},
619	} {
620		test := test
621		t.Run(test.desc, func(t *testing.T) {
622			if err := tmpFiles.copyCredsToTmpFiles(); err != nil {
623				t.Fatalf("tmpFiles.copyCredsToTmpFiles() failed, error: %v", err)
624			}
625			clientIdentityProvider, clientRootProvider, serverIdentityProvider, serverRootProvider, err := createProviders(tmpFiles)
626			if err != nil {
627				t.Fatalf("createProviders(%v) failed, error: %v", tmpFiles, err)
628			}
629			defer clientIdentityProvider.Close()
630			defer clientRootProvider.Close()
631			defer serverIdentityProvider.Close()
632			defer serverRootProvider.Close()
633			// Start a server and create a client using advancedtls API with Provider.
634			serverOptions := &ServerOptions{
635				IdentityOptions: IdentityCertificateOptions{
636					IdentityProvider: serverIdentityProvider,
637				},
638				RootOptions: RootCertificateOptions{
639					RootProvider: serverRootProvider,
640				},
641				RequireClientCert: true,
642				VerifyPeer: func(params *VerificationFuncParams) (*VerificationResults, error) {
643					return &VerificationResults{}, nil
644				},
645				VType: CertVerification,
646			}
647			serverTLSCreds, err := NewServerCreds(serverOptions)
648			if err != nil {
649				t.Fatalf("failed to create server creds: %v", err)
650			}
651			s := grpc.NewServer(grpc.Creds(serverTLSCreds))
652			defer s.Stop()
653			lis, err := net.Listen("tcp", "localhost:0")
654			if err != nil {
655				t.Fatalf("failed to listen: %v", err)
656			}
657			defer lis.Close()
658			addr := fmt.Sprintf("localhost:%v", lis.Addr().(*net.TCPAddr).Port)
659			pb.RegisterGreeterServer(s, greeterServer{})
660			go s.Serve(lis)
661			clientOptions := &ClientOptions{
662				IdentityOptions: IdentityCertificateOptions{
663					IdentityProvider: clientIdentityProvider,
664				},
665				VerifyPeer: func(params *VerificationFuncParams) (*VerificationResults, error) {
666					return &VerificationResults{}, nil
667				},
668				RootOptions: RootCertificateOptions{
669					RootProvider: clientRootProvider,
670				},
671				VType: CertVerification,
672			}
673			clientTLSCreds, err := NewClientCreds(clientOptions)
674			if err != nil {
675				t.Fatalf("clientTLSCreds failed to create, error: %v", err)
676			}
677
678			// At initialization, the connection should be good.
679			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
680			defer cancel()
681			conn, greetClient, err := callAndVerifyWithClientConn(ctx, addr, "rpc call 1", clientTLSCreds, false)
682			if err != nil {
683				t.Fatal(err)
684			}
685			defer conn.Close()
686			// Make the identity cert change, and wait 1 second for the provider to
687			// pick up the change.
688			test.certUpdateFunc()
689			time.Sleep(sleepInterval)
690			// The already-established connection should not be affected.
691			err = callAndVerify("rpc call 2", greetClient, false)
692			if err != nil {
693				t.Fatal(err)
694			}
695			// New connections should still be good, because the Provider didn't pick
696			// up the changes due to key-cert mismatch.
697			conn2, greetClient, err := callAndVerifyWithClientConn(ctx, addr, "rpc call 3", clientTLSCreds, false)
698			if err != nil {
699				t.Fatal(err)
700			}
701			defer conn2.Close()
702			// Make the identity key change, and wait 1 second for the provider to
703			// pick up the change.
704			test.keyUpdateFunc()
705			time.Sleep(sleepInterval)
706			// New connections should fail now, because the Provider picked the
707			// change, and *_cert_2.pem is not trusted by *_trust_cert_1.pem on the
708			// other side.
709			shortCtx, shortCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
710			defer shortCancel()
711			conn3, greetClient, err := callAndVerifyWithClientConn(shortCtx, addr, "rpc call 4", clientTLSCreds, true)
712			if err != nil {
713				t.Fatal(err)
714			}
715			defer conn3.Close()
716			// Make the trust cert change on the other side, and wait 1 second for
717			// the provider to pick up the change.
718			test.trustCertUpdateFunc()
719			time.Sleep(sleepInterval)
720			// New connections should be good, because the other side is using
721			// *_trust_cert_2.pem now.
722			conn4, greetClient, err := callAndVerifyWithClientConn(ctx, addr, "rpc call 5", clientTLSCreds, false)
723			if err != nil {
724				t.Fatal(err)
725			}
726			defer conn4.Close()
727		})
728	}
729}
730
731func (s) TestDefaultHostNameCheck(t *testing.T) {
732	cs := &testutils.CertStore{}
733	if err := cs.LoadCerts(); err != nil {
734		t.Fatalf("cs.LoadCerts() failed, err: %v", err)
735	}
736	for _, test := range []struct {
737		desc             string
738		clientRoot       *x509.CertPool
739		clientVerifyFunc CustomVerificationFunc
740		clientVType      VerificationType
741		serverCert       []tls.Certificate
742		serverVType      VerificationType
743		expectError      bool
744	}{
745		// Client side sets vType to CertAndHostVerification, and will do
746		// default hostname check. Server uses a cert without "localhost" or
747		// "127.0.0.1" as common name or SAN names, and will hence fail.
748		{
749			desc:        "Bad default hostname check",
750			clientRoot:  cs.ClientTrust1,
751			clientVType: CertAndHostVerification,
752			serverCert:  []tls.Certificate{cs.ServerCert1},
753			serverVType: CertAndHostVerification,
754			expectError: true,
755		},
756		// Client side sets vType to CertAndHostVerification, and will do
757		// default hostname check. Server uses a certificate with "localhost" as
758		// common name, and will hence pass the default hostname check.
759		{
760			desc:        "Good default hostname check",
761			clientRoot:  cs.ClientTrust1,
762			clientVType: CertAndHostVerification,
763			serverCert:  []tls.Certificate{cs.ServerPeerLocalhost1},
764			serverVType: CertAndHostVerification,
765			expectError: false,
766		},
767	} {
768		test := test
769		t.Run(test.desc, func(t *testing.T) {
770			// Start a server using ServerOptions in another goroutine.
771			serverOptions := &ServerOptions{
772				IdentityOptions: IdentityCertificateOptions{
773					Certificates: test.serverCert,
774				},
775				RequireClientCert: false,
776				VType:             test.serverVType,
777			}
778			serverTLSCreds, err := NewServerCreds(serverOptions)
779			if err != nil {
780				t.Fatalf("failed to create server creds: %v", err)
781			}
782			s := grpc.NewServer(grpc.Creds(serverTLSCreds))
783			defer s.Stop()
784			lis, err := net.Listen("tcp", "localhost:0")
785			if err != nil {
786				t.Fatalf("failed to listen: %v", err)
787			}
788			defer lis.Close()
789			addr := fmt.Sprintf("localhost:%v", lis.Addr().(*net.TCPAddr).Port)
790			pb.RegisterGreeterServer(s, greeterServer{})
791			go s.Serve(lis)
792			clientOptions := &ClientOptions{
793				VerifyPeer: test.clientVerifyFunc,
794				RootOptions: RootCertificateOptions{
795					RootCACerts: test.clientRoot,
796				},
797				VType: test.clientVType,
798			}
799			clientTLSCreds, err := NewClientCreds(clientOptions)
800			if err != nil {
801				t.Fatalf("clientTLSCreds failed to create")
802			}
803			shouldFail := false
804			if test.expectError {
805				shouldFail = true
806			}
807			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
808			defer cancel()
809			conn, _, err := callAndVerifyWithClientConn(ctx, addr, "rpc call 1", clientTLSCreds, shouldFail)
810			if err != nil {
811				t.Fatal(err)
812			}
813			defer conn.Close()
814		})
815	}
816}
817