1/*
2 *
3 * Copyright 2019 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
// Package advancedtls is a utility library containing functions to construct
// credentials.TransportCredentials that can perform credential reloading and
// custom verification checks.
22package advancedtls
23
24import (
25	"context"
26	"crypto/tls"
27	"crypto/x509"
28	"fmt"
29	"net"
30	"reflect"
31	"time"
32
33	"google.golang.org/grpc/credentials"
34	"google.golang.org/grpc/credentials/tls/certprovider"
35	credinternal "google.golang.org/grpc/internal/credentials"
36)
37
// VerificationFuncParams contains parameters available to users when
// implementing CustomVerificationFunc.
// The fields in this struct are read-only.
type VerificationFuncParams struct {
	// The target server name that the client connects to when establishing the
	// connection: the configured ServerName or, if none, the dial authority
	// (which may still include a ":port" suffix). This field is only meaningful
	// on the client side; on the server side it is always an empty string.
	ServerName string
	// The raw certificates sent from peer.
	RawCerts [][]byte
	// The verification chains obtained by checking the peer RawCerts against
	// the trust certificate bundle(s) when VType requests certificate
	// verification. With SkipVerification this is simply the value crypto/tls
	// passed to the VerifyPeerCertificate callback.
	VerifiedChains [][]*x509.Certificate
	// The leaf certificate sent from peer, set when VType requests certificate
	// verification. The custom check only runs after that verification has
	// passed, so this is nil only when VType is SkipVerification.
	Leaf *x509.Certificate
}
56
// VerificationResults contains the information about results of
// CustomVerificationFunc.
// VerificationResults is an empty struct for now, and the value returned by a
// CustomVerificationFunc is currently ignored. It may be extended in the
// future to include more information.
type VerificationResults struct{}
62
// CustomVerificationFunc is the function defined by users to perform custom
// verification check.
// It is called after the checks selected by VType have passed. Returning a
// non-nil error rejects the peer and aborts the handshake; returning a nil
// error accepts it. The *VerificationResults return value is currently
// ignored, so returning nil or an empty struct are equivalent.
type CustomVerificationFunc func(params *VerificationFuncParams) (*VerificationResults, error)
68
// GetRootCAsParams contains the parameters available to users when
// implementing GetRootCAs.
//
// RawConn is the underlying connection on which the handshake is running, and
// RawCerts are the raw certificates presented by the peer that are about to
// be verified against the returned roots.
type GetRootCAsParams struct {
	RawConn  net.Conn
	RawCerts [][]byte
}
75
// GetRootCAsResults contains the results of GetRootCAs.
// If users want to reload the root trust certificate, it is required to return
// the proper TrustCerts in GetRootCAs.
type GetRootCAsResults struct {
	// TrustCerts is used as the root pool when verifying the peer chain.
	// NOTE(review): a nil pool is passed through as x509.VerifyOptions.Roots,
	// which makes x509 fall back to the system roots — confirm that is intended.
	TrustCerts *x509.CertPool
}
82
// RootCertificateOptions contains options to obtain root trust certificates
// for both the client and the server.
// At most one option may be set; creating credentials with more than one
// non-nil field returns an error. If none of them are set, the system default
// trust certificates are used whenever certificate verification is enabled
// (on the server side, only when RequireClientCert is true).
type RootCertificateOptions struct {
	// If RootCACerts is set, it will be used every time when verifying
	// the peer certificates, without performing root certificate reloading.
	RootCACerts *x509.CertPool
	// If GetRootCertificates is set, it will be invoked to obtain root certs
	// while verifying the peer of every new connection.
	GetRootCertificates func(params *GetRootCAsParams) (*GetRootCAsResults, error)
	// If RootProvider is set, we will use the root certs from the Provider's
	// KeyMaterial() call in the new connections. The Provider must have initial
	// credentials if specified. Otherwise, KeyMaterial() will block forever.
	RootProvider certprovider.Provider
}
99
100// nonNilFieldCount returns the number of set fields in RootCertificateOptions.
101func (o RootCertificateOptions) nonNilFieldCount() int {
102	cnt := 0
103	rv := reflect.ValueOf(o)
104	for i := 0; i < rv.NumField(); i++ {
105		if !rv.Field(i).IsNil() {
106			cnt++
107		}
108	}
109	return cnt
110}
111
// IdentityCertificateOptions contains options to obtain identity certificates
// for both the client and the server.
// At most one option may be set; creating credentials with more than one
// non-nil field returns an error. On the server side exactly one option is
// required; on the client side none is required.
type IdentityCertificateOptions struct {
	// If Certificates is set, it will be used every time when needed to present
	// identity certificates, without performing identity certificate reloading.
	Certificates []tls.Certificate
	// If GetIdentityCertificatesForClient is set, it will be invoked to obtain
	// identity certs for every new connection.
	// This field may only be set on the client side; server credentials reject
	// it with an error.
	GetIdentityCertificatesForClient func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
	// If GetIdentityCertificatesForServer is set, it will be invoked to obtain
	// identity certs for every new connection.
	// This field may only be set on the server side; client credentials reject
	// it with an error.
	GetIdentityCertificatesForServer func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
	// If IdentityProvider is set, we will use the identity certs from the
	// Provider's KeyMaterial() call in the new connections. The Provider must
	// have initial credentials if specified. Otherwise, KeyMaterial() will block
	// forever. On the client side the key material must contain exactly one
	// certificate chain, otherwise the handshake fails.
	IdentityProvider certprovider.Provider
}
133
134// nonNilFieldCount returns the number of set fields in IdentityCertificateOptions.
135func (o IdentityCertificateOptions) nonNilFieldCount() int {
136	cnt := 0
137	rv := reflect.ValueOf(o)
138	for i := 0; i < rv.NumField(); i++ {
139		if !rv.Field(i).IsNil() {
140			cnt++
141		}
142	}
143	return cnt
144}
145
// VerificationType is the enum type that represents different levels of
// verification users could set, both on client side and on server side.
// The zero value is CertAndHostVerification.
type VerificationType int

const (
	// CertAndHostVerification indicates doing both certificate signature check
	// and hostname check. The hostname check is only performed on the client
	// side; on the server side this behaves like CertVerification.
	CertAndHostVerification VerificationType = iota
	// CertVerification indicates doing certificate signature check only. Setting
	// this field without proper custom verification check would leave the
	// application susceptible to the MITM attack.
	CertVerification
	// SkipVerification indicates skipping both certificate signature check and
	// hostname check. If setting this field, proper custom verification needs to
	// be implemented in order to complete the authentication. Setting this field
	// with a nil custom verification raises an error when the credentials are
	// created (on the server side, only when client certificates are required).
	SkipVerification
)
164
// ClientOptions contains the fields needed to be filled by the client.
type ClientOptions struct {
	// IdentityOptions is OPTIONAL on client side. This field only needs to be
	// set if mutual authentication is required on server side.
	IdentityOptions IdentityCertificateOptions
	// VerifyPeer is a custom verification check after certificate signature
	// check.
	// If this is set, we will perform this customized check after doing the
	// normal check(s) indicated by setting VType. It is REQUIRED when VType is
	// SkipVerification.
	VerifyPeer CustomVerificationFunc
	// ServerNameOverride is for testing only. If set to a non-empty string,
	// it will override the virtual host name of authority (e.g. :authority
	// header field) in requests. It is used as the TLS ServerName and as the
	// name for the hostname check instead of the dial authority.
	ServerNameOverride string
	// RootOptions is OPTIONAL on client side. If not set, we will try to use the
	// default trust certificates in users' OS system (unless VType is
	// SkipVerification).
	RootOptions RootCertificateOptions
	// VType is the verification type on the client side.
	VType VerificationType
}
185
// ServerOptions contains the fields needed to be filled by the server.
type ServerOptions struct {
	// IdentityOptions is REQUIRED on server side.
	IdentityOptions IdentityCertificateOptions
	// VerifyPeer is a custom verification check after certificate signature
	// check.
	// If this is set, we will perform this customized check after doing the
	// normal check(s) indicated by setting VType. It is REQUIRED when VType is
	// SkipVerification and RequireClientCert is true.
	VerifyPeer CustomVerificationFunc
	// RootOptions is OPTIONAL on server side. This field only needs to be set if
	// mutual authentication is required (RequireClientCert is true).
	RootOptions RootCertificateOptions
	// RequireClientCert, when true, makes the server require the client to
	// present at least one certificate, which is then verified according to
	// VType and VerifyPeer.
	RequireClientCert bool
	// VType is the verification type on the server side.
	VType VerificationType
}
203
204func (o *ClientOptions) config() (*tls.Config, error) {
205	if o.VType == SkipVerification && o.VerifyPeer == nil {
206		return nil, fmt.Errorf("client needs to provide custom verification mechanism if choose to skip default verification")
207	}
208	// Make sure users didn't specify more than one fields in
209	// RootCertificateOptions and IdentityCertificateOptions.
210	if num := o.RootOptions.nonNilFieldCount(); num > 1 {
211		return nil, fmt.Errorf("at most one field in RootCertificateOptions could be specified")
212	}
213	if num := o.IdentityOptions.nonNilFieldCount(); num > 1 {
214		return nil, fmt.Errorf("at most one field in IdentityCertificateOptions could be specified")
215	}
216	if o.IdentityOptions.GetIdentityCertificatesForServer != nil {
217		return nil, fmt.Errorf("GetIdentityCertificatesForServer cannot be specified on the client side")
218	}
219	config := &tls.Config{
220		ServerName: o.ServerNameOverride,
221		// We have to set InsecureSkipVerify to true to skip the default checks and
222		// use the verification function we built from buildVerifyFunc.
223		InsecureSkipVerify: true,
224	}
225	// Propagate root-certificate-related fields in tls.Config.
226	switch {
227	case o.RootOptions.RootCACerts != nil:
228		config.RootCAs = o.RootOptions.RootCACerts
229	case o.RootOptions.GetRootCertificates != nil:
230		// In cases when users provide GetRootCertificates callback, since this
231		// callback is not contained in tls.Config, we have nothing to set here.
232		// We will invoke the callback in ClientHandshake.
233	case o.RootOptions.RootProvider != nil:
234		o.RootOptions.GetRootCertificates = func(*GetRootCAsParams) (*GetRootCAsResults, error) {
235			km, err := o.RootOptions.RootProvider.KeyMaterial(context.Background())
236			if err != nil {
237				return nil, err
238			}
239			return &GetRootCAsResults{TrustCerts: km.Roots}, nil
240		}
241	default:
242		// No root certificate options specified by user. Use the certificates
243		// stored in system default path as the last resort.
244		if o.VType != SkipVerification {
245			systemRootCAs, err := x509.SystemCertPool()
246			if err != nil {
247				return nil, err
248			}
249			config.RootCAs = systemRootCAs
250		}
251	}
252	// Propagate identity-certificate-related fields in tls.Config.
253	switch {
254	case o.IdentityOptions.Certificates != nil:
255		config.Certificates = o.IdentityOptions.Certificates
256	case o.IdentityOptions.GetIdentityCertificatesForClient != nil:
257		config.GetClientCertificate = o.IdentityOptions.GetIdentityCertificatesForClient
258	case o.IdentityOptions.IdentityProvider != nil:
259		config.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
260			km, err := o.IdentityOptions.IdentityProvider.KeyMaterial(context.Background())
261			if err != nil {
262				return nil, err
263			}
264			if len(km.Certs) != 1 {
265				return nil, fmt.Errorf("there should always be only one identity cert chain on the client side in IdentityProvider")
266			}
267			return &km.Certs[0], nil
268		}
269	default:
270		// It's fine for users to not specify identity certificate options here.
271	}
272	return config, nil
273}
274
275func (o *ServerOptions) config() (*tls.Config, error) {
276	if o.RequireClientCert && o.VType == SkipVerification && o.VerifyPeer == nil {
277		return nil, fmt.Errorf("server needs to provide custom verification mechanism if choose to skip default verification, but require client certificate(s)")
278	}
279	// Make sure users didn't specify more than one fields in
280	// RootCertificateOptions and IdentityCertificateOptions.
281	if num := o.RootOptions.nonNilFieldCount(); num > 1 {
282		return nil, fmt.Errorf("at most one field in RootCertificateOptions could be specified")
283	}
284	if num := o.IdentityOptions.nonNilFieldCount(); num > 1 {
285		return nil, fmt.Errorf("at most one field in IdentityCertificateOptions could be specified")
286	}
287	if o.IdentityOptions.GetIdentityCertificatesForClient != nil {
288		return nil, fmt.Errorf("GetIdentityCertificatesForClient cannot be specified on the server side")
289	}
290	clientAuth := tls.NoClientCert
291	if o.RequireClientCert {
292		// We have to set clientAuth to RequireAnyClientCert to force underlying
293		// TLS package to use the verification function we built from
294		// buildVerifyFunc.
295		clientAuth = tls.RequireAnyClientCert
296	}
297	config := &tls.Config{
298		ClientAuth: clientAuth,
299	}
300	// Propagate root-certificate-related fields in tls.Config.
301	switch {
302	case o.RootOptions.RootCACerts != nil:
303		config.ClientCAs = o.RootOptions.RootCACerts
304	case o.RootOptions.GetRootCertificates != nil:
305		// In cases when users provide GetRootCertificates callback, since this
306		// callback is not contained in tls.Config, we have nothing to set here.
307		// We will invoke the callback in ServerHandshake.
308	case o.RootOptions.RootProvider != nil:
309		o.RootOptions.GetRootCertificates = func(*GetRootCAsParams) (*GetRootCAsResults, error) {
310			km, err := o.RootOptions.RootProvider.KeyMaterial(context.Background())
311			if err != nil {
312				return nil, err
313			}
314			return &GetRootCAsResults{TrustCerts: km.Roots}, nil
315		}
316	default:
317		// No root certificate options specified by user. Use the certificates
318		// stored in system default path as the last resort.
319		if o.VType != SkipVerification && o.RequireClientCert {
320			systemRootCAs, err := x509.SystemCertPool()
321			if err != nil {
322				return nil, err
323			}
324			config.ClientCAs = systemRootCAs
325		}
326	}
327	// Propagate identity-certificate-related fields in tls.Config.
328	switch {
329	case o.IdentityOptions.Certificates != nil:
330		config.Certificates = o.IdentityOptions.Certificates
331	case o.IdentityOptions.GetIdentityCertificatesForServer != nil:
332		config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
333			return buildGetCertificates(clientHello, o)
334		}
335	case o.IdentityOptions.IdentityProvider != nil:
336		o.IdentityOptions.GetIdentityCertificatesForServer = func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
337			km, err := o.IdentityOptions.IdentityProvider.KeyMaterial(context.Background())
338			if err != nil {
339				return nil, err
340			}
341			var certChains []*tls.Certificate
342			for i := 0; i < len(km.Certs); i++ {
343				certChains = append(certChains, &km.Certs[i])
344			}
345			return certChains, nil
346		}
347		config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
348			return buildGetCertificates(clientHello, o)
349		}
350	default:
351		return nil, fmt.Errorf("needs to specify at least one field in IdentityCertificateOptions")
352	}
353	return config, nil
354}
355
// advancedTLSCreds is the credentials required for authenticating a connection
// using TLS.
//
// Fields:
//   - config: base TLS configuration; each handshake works on a clone of it.
//   - verifyFunc: optional user check run after the VType-selected checks.
//   - getRootCAs: optional callback used to obtain root certs when config
//     carries no static root pool (RootCAs on the client, ClientCAs on the
//     server).
//   - isClient: selects client vs. server behavior during verification.
//   - vType: the verification level applied by buildVerifyFunc.
type advancedTLSCreds struct {
	config     *tls.Config
	verifyFunc CustomVerificationFunc
	getRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error)
	isClient   bool
	vType      VerificationType
}
365
366func (c advancedTLSCreds) Info() credentials.ProtocolInfo {
367	return credentials.ProtocolInfo{
368		SecurityProtocol: "tls",
369		SecurityVersion:  "1.2",
370		ServerName:       c.config.ServerName,
371	}
372}
373
374func (c *advancedTLSCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
375	// Use local cfg to avoid clobbering ServerName if using multiple endpoints.
376	cfg := credinternal.CloneTLSConfig(c.config)
377	// We return the full authority name to users if ServerName is empty without
378	// stripping the trailing port.
379	if cfg.ServerName == "" {
380		cfg.ServerName = authority
381	}
382	cfg.VerifyPeerCertificate = buildVerifyFunc(c, cfg.ServerName, rawConn)
383	conn := tls.Client(rawConn, cfg)
384	errChannel := make(chan error, 1)
385	go func() {
386		errChannel <- conn.Handshake()
387		close(errChannel)
388	}()
389	select {
390	case err := <-errChannel:
391		if err != nil {
392			conn.Close()
393			return nil, nil, err
394		}
395	case <-ctx.Done():
396		conn.Close()
397		return nil, nil, ctx.Err()
398	}
399	info := credentials.TLSInfo{
400		State: conn.ConnectionState(),
401		CommonAuthInfo: credentials.CommonAuthInfo{
402			SecurityLevel: credentials.PrivacyAndIntegrity,
403		},
404	}
405	info.SPIFFEID = credinternal.SPIFFEIDFromState(conn.ConnectionState())
406	return credinternal.WrapSyscallConn(rawConn, conn), info, nil
407}
408
409func (c *advancedTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
410	cfg := credinternal.CloneTLSConfig(c.config)
411	cfg.VerifyPeerCertificate = buildVerifyFunc(c, "", rawConn)
412	conn := tls.Server(rawConn, cfg)
413	if err := conn.Handshake(); err != nil {
414		conn.Close()
415		return nil, nil, err
416	}
417	info := credentials.TLSInfo{
418		State: conn.ConnectionState(),
419		CommonAuthInfo: credentials.CommonAuthInfo{
420			SecurityLevel: credentials.PrivacyAndIntegrity,
421		},
422	}
423	info.SPIFFEID = credinternal.SPIFFEIDFromState(conn.ConnectionState())
424	return credinternal.WrapSyscallConn(rawConn, conn), info, nil
425}
426
427func (c *advancedTLSCreds) Clone() credentials.TransportCredentials {
428	return &advancedTLSCreds{
429		config:     credinternal.CloneTLSConfig(c.config),
430		verifyFunc: c.verifyFunc,
431		getRootCAs: c.getRootCAs,
432		isClient:   c.isClient,
433	}
434}
435
436func (c *advancedTLSCreds) OverrideServerName(serverNameOverride string) error {
437	c.config.ServerName = serverNameOverride
438	return nil
439}
440
441// The function buildVerifyFunc is used when users want root cert reloading,
442// and possibly custom verification check.
443// We have to build our own verification function here because current
444// tls module:
445//   1. does not have a good support on root cert reloading.
446//   2. will ignore basic certificate check when setting InsecureSkipVerify
447//   to true.
448func buildVerifyFunc(c *advancedTLSCreds,
449	serverName string,
450	rawConn net.Conn) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
451	return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
452		chains := verifiedChains
453		var leafCert *x509.Certificate
454		if c.vType == CertAndHostVerification || c.vType == CertVerification {
455			// perform possible trust credential reloading and certificate check
456			rootCAs := c.config.RootCAs
457			if !c.isClient {
458				rootCAs = c.config.ClientCAs
459			}
460			// Reload root CA certs.
461			if rootCAs == nil && c.getRootCAs != nil {
462				results, err := c.getRootCAs(&GetRootCAsParams{
463					RawConn:  rawConn,
464					RawCerts: rawCerts,
465				})
466				if err != nil {
467					return err
468				}
469				rootCAs = results.TrustCerts
470			}
471			// Verify peers' certificates against RootCAs and get verifiedChains.
472			certs := make([]*x509.Certificate, len(rawCerts))
473			for i, asn1Data := range rawCerts {
474				cert, err := x509.ParseCertificate(asn1Data)
475				if err != nil {
476					return err
477				}
478				certs[i] = cert
479			}
480			keyUsages := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
481			if !c.isClient {
482				keyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}
483			}
484			opts := x509.VerifyOptions{
485				Roots:         rootCAs,
486				CurrentTime:   time.Now(),
487				Intermediates: x509.NewCertPool(),
488				KeyUsages:     keyUsages,
489			}
490			for _, cert := range certs[1:] {
491				opts.Intermediates.AddCert(cert)
492			}
493			// Perform default hostname check if specified.
494			if c.isClient && c.vType == CertAndHostVerification && serverName != "" {
495				parsedName, _, err := net.SplitHostPort(serverName)
496				if err != nil {
497					// If the serverName had no host port or if the serverName cannot be
498					// parsed, use it as-is.
499					parsedName = serverName
500				}
501				opts.DNSName = parsedName
502			}
503			var err error
504			chains, err = certs[0].Verify(opts)
505			if err != nil {
506				return err
507			}
508			leafCert = certs[0]
509		}
510		// Perform custom verification check if specified.
511		if c.verifyFunc != nil {
512			_, err := c.verifyFunc(&VerificationFuncParams{
513				ServerName:     serverName,
514				RawCerts:       rawCerts,
515				VerifiedChains: chains,
516				Leaf:           leafCert,
517			})
518			return err
519		}
520		return nil
521	}
522}
523
524// NewClientCreds uses ClientOptions to construct a TransportCredentials based
525// on TLS.
526func NewClientCreds(o *ClientOptions) (credentials.TransportCredentials, error) {
527	conf, err := o.config()
528	if err != nil {
529		return nil, err
530	}
531	tc := &advancedTLSCreds{
532		config:     conf,
533		isClient:   true,
534		getRootCAs: o.RootOptions.GetRootCertificates,
535		verifyFunc: o.VerifyPeer,
536		vType:      o.VType,
537	}
538	tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos)
539	return tc, nil
540}
541
542// NewServerCreds uses ServerOptions to construct a TransportCredentials based
543// on TLS.
544func NewServerCreds(o *ServerOptions) (credentials.TransportCredentials, error) {
545	conf, err := o.config()
546	if err != nil {
547		return nil, err
548	}
549	tc := &advancedTLSCreds{
550		config:     conf,
551		isClient:   false,
552		getRootCAs: o.RootOptions.GetRootCertificates,
553		verifyFunc: o.VerifyPeer,
554		vType:      o.VType,
555	}
556	tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos)
557	return tc, nil
558}
559