1// +build go1.13
2
3/*
4 *
5 * Copyright 2020 gRPC authors.
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 *     http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 *
19 */
20
21// Package meshca provides an implementation of the Provider interface which
22// communicates with MeshCA to get certificates signed.
23package meshca
24
25import (
26	"context"
27	"crypto"
28	"crypto/rand"
29	"crypto/rsa"
30	"crypto/tls"
31	"crypto/x509"
32	"encoding/pem"
33	"fmt"
34	"time"
35
36	durationpb "github.com/golang/protobuf/ptypes/duration"
37	"github.com/google/uuid"
38
39	"google.golang.org/grpc"
40	"google.golang.org/grpc/credentials/tls/certprovider"
41	meshgrpc "google.golang.org/grpc/credentials/tls/certprovider/meshca/internal/v1"
42	meshpb "google.golang.org/grpc/credentials/tls/certprovider/meshca/internal/v1"
43	"google.golang.org/grpc/internal/grpclog"
44	"google.golang.org/grpc/metadata"
45)
46
// locationMetadataKey is the key of the metadata header added to requests sent
// to the MeshCA. The header's value is the GCE zone in which the workload is
// running.
const locationMetadataKey = "x-goog-request-params"
50
// newDistributorFunc returns the distributor that a newly created plugin
// embeds to hold its key material. It is a variable, rather than a direct call
// to certprovider.NewDistributor, so that unit tests can override it.
var newDistributorFunc = func() distributor { return certprovider.NewDistributor() }
53
// distributor wraps the methods on certprovider.Distributor which are used by
// the plugin. Declaring these methods as an interface lets tests substitute
// their own implementation (through newDistributorFunc), which is useful for
// tests that need to know exactly when the plugin updates its key material.
type distributor interface {
	KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error)
	Set(km *certprovider.KeyMaterial, err error)
	Stop()
}
62
// providerPlugin is an implementation of the certprovider.Provider interface,
// which gets certificates signed by communicating with the MeshCA.
//
// Instances are created by newProviderPlugin, which also starts the run
// goroutine that keeps the key material refreshed. Close cancels the context
// that goroutine watches, which makes it return.
type providerPlugin struct {
	distributor // Holds the key material.
	cancel      context.CancelFunc
	cc          *grpc.ClientConn          // Connection to MeshCA server.
	cfg         *pluginConfig             // Plugin configuration.
	opts        certprovider.BuildOptions // Key material options.
	logger      *grpclog.PrefixLogger     // Plugin instance specific prefix.
	backoff     func(int) time.Duration   // Exponential backoff.
	doneFunc    func()                    // Notify the builder when done.
}
75
// providerParams wraps params passed to the provider plugin at creation time.
// newProviderPlugin copies every field into the providerPlugin it returns.
type providerParams struct {
	// This ClientConn to the MeshCA server is owned by the builder.
	cc       *grpc.ClientConn
	cfg      *pluginConfig
	opts     certprovider.BuildOptions
	backoff  func(int) time.Duration
	doneFunc func()
}
85
86func newProviderPlugin(params providerParams) *providerPlugin {
87	ctx, cancel := context.WithCancel(context.Background())
88	p := &providerPlugin{
89		cancel:      cancel,
90		cc:          params.cc,
91		cfg:         params.cfg,
92		opts:        params.opts,
93		backoff:     params.backoff,
94		doneFunc:    params.doneFunc,
95		distributor: newDistributorFunc(),
96	}
97	p.logger = prefixLogger((p))
98	p.logger.Infof("plugin created")
99	go p.run(ctx)
100	return p
101}
102
// Close shuts the plugin down. It stops the embedded distributor, cancels the
// context that the run goroutine is watching, and then invokes doneFunc to
// notify the builder.
func (p *providerPlugin) Close() {
	p.logger.Infof("plugin closed")
	p.Stop() // Stop the embedded distributor.
	p.cancel()
	p.doneFunc()
}
109
110// run is a long running goroutine which periodically sends out CSRs to the
111// MeshCA, and updates the underlying Distributor with the new key material.
112func (p *providerPlugin) run(ctx context.Context) {
113	// We need to start fetching key material right away. The next attempt will
114	// be triggered by the timer firing.
115	for {
116		certValidity, err := p.updateKeyMaterial(ctx)
117		if err != nil {
118			return
119		}
120
121		// We request a certificate with the configured validity duration (which
122		// is usually twice as much as the grace period). But the server is free
123		// to return a certificate with whatever validity time it deems right.
124		refreshAfter := p.cfg.certGraceTime
125		if refreshAfter > certValidity {
126			// The default value of cert grace time is half that of the default
127			// cert validity time. So here, when we have to use a non-default
128			// cert life time, we will set the grace time again to half that of
129			// the validity time.
130			refreshAfter = certValidity / 2
131		}
132		timer := time.NewTimer(refreshAfter)
133		select {
134		case <-ctx.Done():
135			return
136		case <-timer.C:
137		}
138	}
139}
140
141// updateKeyMaterial generates a CSR and attempts to get it signed from the
142// MeshCA. It retries with an exponential backoff till it succeeds or the
143// deadline specified in ctx expires. Once it gets the CSR signed from the
144// MeshCA, it updates the Distributor with the new key material.
145//
146// It returns the amount of time the new certificate is valid for.
147func (p *providerPlugin) updateKeyMaterial(ctx context.Context) (time.Duration, error) {
148	client := meshgrpc.NewMeshCertificateServiceClient(p.cc)
149	retries := 0
150	for {
151		if ctx.Err() != nil {
152			return 0, ctx.Err()
153		}
154
155		if retries != 0 {
156			bi := p.backoff(retries)
157			p.logger.Warningf("Backing off for %s before attempting the next CreateCertificate() request", bi)
158			timer := time.NewTimer(bi)
159			select {
160			case <-timer.C:
161			case <-ctx.Done():
162				return 0, ctx.Err()
163			}
164		}
165		retries++
166
167		privKey, err := rsa.GenerateKey(rand.Reader, p.cfg.keySize)
168		if err != nil {
169			p.logger.Warningf("RSA key generation failed: %v", err)
170			continue
171		}
172		// We do not set any fields in the CSR (we use an empty
173		// x509.CertificateRequest as the template) because the MeshCA discards
174		// them anyways, and uses the workload identity from the access token
175		// that we present (as part of the STS call creds).
176		csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{}, crypto.PrivateKey(privKey))
177		if err != nil {
178			p.logger.Warningf("CSR creation failed: %v", err)
179			continue
180		}
181		csrPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csrBytes})
182
183		// Send out the CSR with a call timeout and location metadata, as
184		// specified in the plugin configuration.
185		req := &meshpb.MeshCertificateRequest{
186			RequestId: uuid.New().String(),
187			Csr:       string(csrPEM),
188			Validity:  &durationpb.Duration{Seconds: int64(p.cfg.certLifetime / time.Second)},
189		}
190		p.logger.Debugf("Sending CreateCertificate() request: %v", req)
191
192		callCtx, ctxCancel := context.WithTimeout(context.Background(), p.cfg.callTimeout)
193		callCtx = metadata.NewOutgoingContext(callCtx, metadata.Pairs(locationMetadataKey, p.cfg.location))
194		resp, err := client.CreateCertificate(callCtx, req)
195		if err != nil {
196			p.logger.Warningf("CreateCertificate request failed: %v", err)
197			ctxCancel()
198			continue
199		}
200		ctxCancel()
201
202		// The returned cert chain must contain more than one cert. Leaf cert is
203		// element '0', while root cert is element 'n', and the intermediate
204		// entries form the chain from the root to the leaf.
205		certChain := resp.GetCertChain()
206		if l := len(certChain); l <= 1 {
207			p.logger.Errorf("Received certificate chain contains %d certificates, need more than one", l)
208			continue
209		}
210
211		// We need to explicitly parse the PEM cert contents as an
212		// x509.Certificate to read the certificate validity period. We use this
213		// to decide when to refresh the cert. Even though the call to
214		// tls.X509KeyPair actually parses the PEM contents into an
215		// x509.Certificate, it does not store that in the `Leaf` field. See:
216		// https://golang.org/pkg/crypto/tls/#X509KeyPair.
217		identity, intermediates, roots, err := parseCertChain(certChain)
218		if err != nil {
219			p.logger.Errorf(err.Error())
220			continue
221		}
222		_, err = identity.Verify(x509.VerifyOptions{
223			Intermediates: intermediates,
224			Roots:         roots,
225		})
226		if err != nil {
227			p.logger.Errorf("Certificate verification failed for return certChain: %v", err)
228			continue
229		}
230
231		key := x509.MarshalPKCS1PrivateKey(privKey)
232		keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: key})
233		certPair, err := tls.X509KeyPair([]byte(certChain[0]), keyPEM)
234		if err != nil {
235			p.logger.Errorf("Failed to create x509 key pair: %v", err)
236			continue
237		}
238
239		// At this point, the received response has been deemed good.
240		retries = 0
241
242		// All certs signed by the MeshCA roll up to the same root. And treating
243		// the last element of the returned chain as the root is the only
244		// supported option to get the root certificate. So, we ignore the
245		// options specified in the call to Build(), which contain certificate
246		// name and whether the caller is interested in identity or root cert.
247		p.Set(&certprovider.KeyMaterial{Certs: []tls.Certificate{certPair}, Roots: roots}, nil)
248		return time.Until(identity.NotAfter), nil
249	}
250}
251
252// ParseCertChain parses the result returned by the MeshCA which consists of a
253// list of PEM encoded certs. The first element in the list is the leaf or
254// identity cert, while the last element is the root, and everything in between
255// form the chain of trust.
256//
257// Caller needs to make sure that certChain has at least two elements.
258func parseCertChain(certChain []string) (*x509.Certificate, *x509.CertPool, *x509.CertPool, error) {
259	identity, err := parseCert([]byte(certChain[0]))
260	if err != nil {
261		return nil, nil, nil, err
262	}
263
264	intermediates := x509.NewCertPool()
265	for _, cert := range certChain[1 : len(certChain)-1] {
266		i, err := parseCert([]byte(cert))
267		if err != nil {
268			return nil, nil, nil, err
269		}
270		intermediates.AddCert(i)
271	}
272
273	roots := x509.NewCertPool()
274	root, err := parseCert([]byte(certChain[len(certChain)-1]))
275	if err != nil {
276		return nil, nil, nil, err
277	}
278	roots.AddCert(root)
279
280	return identity, intermediates, roots, nil
281}
282
// parseCert decodes the first PEM block found in certPEM and parses its
// contents as an X.509 certificate.
//
// It returns an error if certPEM contains no PEM block, or if the decoded
// bytes are not a valid certificate.
func parseCert(certPEM []byte) (*x509.Certificate, error) {
	block, _ := pem.Decode(certPEM)
	if block == nil {
		// %q prints the offending payload as readable, quoted text; %v on a
		// byte slice would print a list of decimal byte values instead.
		return nil, fmt.Errorf("failed to decode received PEM data: %q", certPEM)
	}
	return x509.ParseCertificate(block.Bytes)
}
290