1package gcpkms
2
3import (
4	"context"
5	"sync"
6	"time"
7
8	"github.com/hashicorp/errwrap"
9	"github.com/hashicorp/vault/sdk/framework"
10	"github.com/hashicorp/vault/sdk/helper/useragent"
11	"github.com/hashicorp/vault/sdk/logical"
12	"golang.org/x/oauth2/google"
13	"google.golang.org/api/option"
14
15	kmsapi "cloud.google.com/go/kms/apiv1"
16)
17
// defaultClientLifetime bounds how long a cached KMS client is reused before
// it is rebuilt. Constructing a client involves a credential lookup that is
// too slow to repeat for every request on a plugin with this much traffic, so
// the client is cached. The value must stay below 60 minutes; otherwise the
// oauth token behind the cached client expires and later requests fail.
var defaultClientLifetime = 30 * time.Minute
26
27type backend struct {
28	*framework.Backend
29
30	// kmsClient is the actual client for connecting to KMS. It is cached on
31	// the backend for efficiency.
32	kmsClient           *kmsapi.KeyManagementClient
33	kmsClientCreateTime time.Time
34	kmsClientLifetime   time.Duration
35	kmsClientLock       sync.RWMutex
36
37	// ctx and ctxCancel are used to control overall plugin shutdown. These
38	// contexts are given to any client libraries or requests that should be
39	// terminated during plugin termination.
40	ctx       context.Context
41	ctxCancel context.CancelFunc
42	ctxLock   sync.Mutex
43}
44
45// Factory returns a configured instance of the backend.
46func Factory(ctx context.Context, c *logical.BackendConfig) (logical.Backend, error) {
47	b := Backend()
48	if err := b.Setup(ctx, c); err != nil {
49		return nil, err
50	}
51	return b, nil
52}
53
54// Backend returns a configured instance of the backend.
55func Backend() *backend {
56	var b backend
57
58	b.kmsClientLifetime = defaultClientLifetime
59	b.ctx, b.ctxCancel = context.WithCancel(context.Background())
60
61	b.Backend = &framework.Backend{
62		BackendType: logical.TypeLogical,
63		Help: "The GCP KMS secrets engine provides pass-through encryption and " +
64			"decryption to Google Cloud KMS keys.",
65
66		Paths: []*framework.Path{
67			b.pathConfig(),
68
69			b.pathKeys(),
70			b.pathKeysCRUD(),
71			b.pathKeysConfigCRUD(),
72			b.pathKeysDeregister(),
73			b.pathKeysRegister(),
74			b.pathKeysRotate(),
75			b.pathKeysTrim(),
76
77			b.pathDecrypt(),
78			b.pathEncrypt(),
79			b.pathPubkey(),
80			b.pathReencrypt(),
81			b.pathSign(),
82			b.pathVerify(),
83		},
84
85		Invalidate: b.invalidate,
86		Clean:      b.clean,
87	}
88
89	return &b
90}
91
92// clean cancels the shared contexts. This is called just before unmounting
93// the plugin.
94func (b *backend) clean(_ context.Context) {
95	b.ctxLock.Lock()
96	b.ctxCancel()
97	b.ctxLock.Unlock()
98}
99
100// invalidate resets the plugin. This is called when a key is updated via
101// replication.
102func (b *backend) invalidate(ctx context.Context, key string) {
103	switch key {
104	case "config":
105		b.ResetClient()
106	}
107}
108
109// ResetClient closes any connected clients.
110func (b *backend) ResetClient() {
111	b.kmsClientLock.Lock()
112	b.resetClient()
113	b.kmsClientLock.Unlock()
114}
115
116// resetClient rests the underlying client. The caller is responsible for
117// acquiring and releasing locks. This method is not safe to call concurrently.
118func (b *backend) resetClient() {
119	if b.kmsClient != nil {
120		b.kmsClient.Close()
121		b.kmsClient = nil
122	}
123
124	b.kmsClientCreateTime = time.Unix(0, 0).UTC()
125}
126
127// KMSClient creates a new client for talking to the GCP KMS service.
128func (b *backend) KMSClient(s logical.Storage) (*kmsapi.KeyManagementClient, func(), error) {
129	// If the client already exists and is valid, return it
130	b.kmsClientLock.RLock()
131	if b.kmsClient != nil && time.Now().UTC().Sub(b.kmsClientCreateTime) < b.kmsClientLifetime {
132		closer := func() { b.kmsClientLock.RUnlock() }
133		return b.kmsClient, closer, nil
134	}
135	b.kmsClientLock.RUnlock()
136
137	// Acquire a full lock. Since all invocations acquire a read lock and defer
138	// the release of that lock, this will block until all clients are no longer
139	// in use. At that point, we can acquire a globally exclusive lock to close
140	// any connections and create a new client.
141	b.kmsClientLock.Lock()
142
143	b.Logger().Debug("creating new KMS client")
144
145	// Attempt to close an existing client if we have one.
146	b.resetClient()
147
148	// Get the config
149	config, err := b.Config(b.ctx, s)
150	if err != nil {
151		b.kmsClientLock.Unlock()
152		return nil, nil, err
153	}
154
155	// If credentials were provided, use those. Otherwise fall back to the
156	// default application credentials.
157	var creds *google.Credentials
158	if config.Credentials != "" {
159		creds, err = google.CredentialsFromJSON(b.ctx, []byte(config.Credentials), config.Scopes...)
160		if err != nil {
161			b.kmsClientLock.Unlock()
162			return nil, nil, errwrap.Wrapf("failed to parse credentials: {{err}}", err)
163		}
164	} else {
165		creds, err = google.FindDefaultCredentials(b.ctx, config.Scopes...)
166		if err != nil {
167			b.kmsClientLock.Unlock()
168			return nil, nil, errwrap.Wrapf("failed to get default token source: {{err}}", err)
169		}
170	}
171
172	// Create and return the KMS client with a custom user agent.
173	client, err := kmsapi.NewKeyManagementClient(b.ctx,
174		option.WithCredentials(creds),
175		option.WithScopes(config.Scopes...),
176		option.WithUserAgent(useragent.String()),
177	)
178	if err != nil {
179		b.kmsClientLock.Unlock()
180		return nil, nil, errwrap.Wrapf("failed to create KMS client: {{err}}", err)
181	}
182
183	// Cache the client
184	b.kmsClient = client
185	b.kmsClientCreateTime = time.Now().UTC()
186	b.kmsClientLock.Unlock()
187
188	b.kmsClientLock.RLock()
189	closer := func() { b.kmsClientLock.RUnlock() }
190	return client, closer, nil
191}
192
193// Config parses and returns the configuration data from the storage backend.
194// Even when no user-defined data exists in storage, a Config is returned with
195// the default values.
196func (b *backend) Config(ctx context.Context, s logical.Storage) (*Config, error) {
197	c := DefaultConfig()
198
199	entry, err := s.Get(ctx, "config")
200	if err != nil {
201		return nil, errwrap.Wrapf("failed to get configuration from storage: {{err}}", err)
202	}
203	if entry == nil || len(entry.Value) == 0 {
204		return c, nil
205	}
206
207	if err := entry.DecodeJSON(&c); err != nil {
208		return nil, errwrap.Wrapf("failed to decode configuration: {{err}}", err)
209	}
210	return c, nil
211}
212