1package awskms
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"os"
8	"sync/atomic"
9
10	"github.com/aws/aws-sdk-go/aws"
11	"github.com/aws/aws-sdk-go/aws/session"
12	"github.com/aws/aws-sdk-go/service/kms"
13	"github.com/aws/aws-sdk-go/service/kms/kmsiface"
14	cleanhttp "github.com/hashicorp/go-cleanhttp"
15	"github.com/hashicorp/go-hclog"
16	wrapping "github.com/hashicorp/go-kms-wrapping"
17	"github.com/hashicorp/vault/sdk/helper/awsutil"
18)
19
// EnvAWSKMSWrapperKeyID names the environment variable that supplies the KMS
// key ID. SetConfig consults it before EnvVaultAWSKMSSealKeyID and before the
// "kms_key_id" config entry.
const EnvAWSKMSWrapperKeyID = "AWSKMS_WRAPPER_KEY_ID"

// EnvVaultAWSKMSSealKeyID is the legacy, Vault-specific name for the key ID
// environment variable, kept for backwards compatibility.
const EnvVaultAWSKMSSealKeyID = "VAULT_AWSKMS_SEAL_KEY_ID"
25
// Mechanism identifiers written into wrapping.KeyInfo.Mechanism by Encrypt and
// dispatched on by Decrypt. NOTE(review): these values are presumably stored
// alongside ciphertext by callers, so treat them as fixed.
const (
	// AWSKMSEncrypt marks data that was encrypted directly by KMS. Decrypt
	// assumes this mechanism for blobs that carry no KeyInfo.
	AWSKMSEncrypt = 0

	// AWSKMSEnvelopeAESGCMEncrypt marks envelope encryption: the data is
	// sealed with AES-GCM under a generated data key, and only that data key
	// is encrypted with KMS.
	AWSKMSEnvelopeAESGCMEncrypt = 1
)
33
34// Wrapper represents credentials and Key information for the KMS Key used to
35// encryption and decryption
36type Wrapper struct {
37	accessKey    string
38	secretKey    string
39	sessionToken string
40	region       string
41	keyID        string
42	endpoint     string
43
44	currentKeyID *atomic.Value
45
46	client kmsiface.KMSAPI
47
48	logger hclog.Logger
49}
50
51// Ensure that we are implementing Wrapper
52var _ wrapping.Wrapper = (*Wrapper)(nil)
53
54// NewWrapper creates a new AWSKMS wrapper with the provided options
55func NewWrapper(opts *wrapping.WrapperOptions) *Wrapper {
56	if opts == nil {
57		opts = new(wrapping.WrapperOptions)
58	}
59	k := &Wrapper{
60		currentKeyID: new(atomic.Value),
61		logger:       opts.Logger,
62	}
63	k.currentKeyID.Store("")
64	return k
65}
66
67// SetConfig sets the fields on the Wrapper object based on
68// values from the config parameter.
69//
70// Order of precedence AWS values:
71// * Environment variable
72// * Value from Vault configuration file
73// * Instance metadata role (access key and secret key)
74// * Default values
75func (k *Wrapper) SetConfig(config map[string]string) (map[string]string, error) {
76	if config == nil {
77		config = map[string]string{}
78	}
79
80	// Check and set KeyID
81	switch {
82	case os.Getenv(EnvAWSKMSWrapperKeyID) != "":
83		k.keyID = os.Getenv(EnvAWSKMSWrapperKeyID)
84	case os.Getenv(EnvVaultAWSKMSSealKeyID) != "":
85		k.keyID = os.Getenv(EnvVaultAWSKMSSealKeyID)
86	case config["kms_key_id"] != "":
87		k.keyID = config["kms_key_id"]
88	default:
89		return nil, fmt.Errorf("'kms_key_id' not found for AWS KMS wrapper configuration")
90	}
91
92	// Please see GetRegion for an explanation of the order in which region is parsed.
93	var err error
94	k.region, err = awsutil.GetRegion(config["region"])
95	if err != nil {
96		return nil, err
97	}
98
99	// Check and set AWS access key, secret key, and session token
100	k.accessKey = config["access_key"]
101	k.secretKey = config["secret_key"]
102	k.sessionToken = config["session_token"]
103
104	k.endpoint = os.Getenv("AWS_KMS_ENDPOINT")
105	if k.endpoint == "" {
106		if endpoint, ok := config["endpoint"]; ok {
107			k.endpoint = endpoint
108		}
109	}
110
111	// Check and set k.client
112	if k.client == nil {
113		client, err := k.GetAWSKMSClient()
114		if err != nil {
115			return nil, fmt.Errorf("error initializing AWS KMS wrapping client: %w", err)
116		}
117
118		// Test the client connection using provided key ID
119		keyInfo, err := client.DescribeKey(&kms.DescribeKeyInput{
120			KeyId: aws.String(k.keyID),
121		})
122		if err != nil {
123			return nil, fmt.Errorf("error fetching AWS KMS wrapping key information: %w", err)
124		}
125		if keyInfo == nil || keyInfo.KeyMetadata == nil || keyInfo.KeyMetadata.KeyId == nil {
126			return nil, errors.New("no key information returned")
127		}
128		k.currentKeyID.Store(aws.StringValue(keyInfo.KeyMetadata.KeyId))
129
130		k.client = client
131	}
132
133	// Map that holds non-sensitive configuration info
134	wrappingInfo := make(map[string]string)
135	wrappingInfo["region"] = k.region
136	wrappingInfo["kms_key_id"] = k.keyID
137	if k.endpoint != "" {
138		wrappingInfo["endpoint"] = k.endpoint
139	}
140
141	return wrappingInfo, nil
142}
143
144// Init is called during core.Initialize. No-op at the moment.
145func (k *Wrapper) Init(_ context.Context) error {
146	return nil
147}
148
149// Finalize is called during shutdown. This is a no-op since
150// Wrapper doesn't require any cleanup.
151func (k *Wrapper) Finalize(_ context.Context) error {
152	return nil
153}
154
155// Type returns the wrapping type for this particular Wrapper implementation
156func (k *Wrapper) Type() string {
157	return wrapping.AWSKMS
158}
159
160// KeyID returns the last known key id
161func (k *Wrapper) KeyID() string {
162	return k.currentKeyID.Load().(string)
163}
164
165// HMACKeyID returns the last known HMAC key id
166func (k *Wrapper) HMACKeyID() string {
167	return ""
168}
169
170// Encrypt is used to encrypt the master key using the the AWS CMK.
171// This returns the ciphertext, and/or any errors from this
172// call. This should be called after the KMS client has been instantiated.
173func (k *Wrapper) Encrypt(_ context.Context, plaintext, aad []byte) (blob *wrapping.EncryptedBlobInfo, err error) {
174	if plaintext == nil {
175		return nil, fmt.Errorf("given plaintext for encryption is nil")
176	}
177
178	env, err := wrapping.NewEnvelope(nil).Encrypt(plaintext, aad)
179	if err != nil {
180		return nil, fmt.Errorf("error wrapping data: %w", err)
181	}
182
183	if k.client == nil {
184		return nil, fmt.Errorf("nil client")
185	}
186
187	input := &kms.EncryptInput{
188		KeyId:     aws.String(k.keyID),
189		Plaintext: env.Key,
190	}
191	output, err := k.client.Encrypt(input)
192	if err != nil {
193		return nil, fmt.Errorf("error encrypting data: %w", err)
194	}
195
196	// Store the current key id
197	//
198	// When using a key alias, this will return the actual underlying key id
199	// used for encryption.  This is helpful if you are looking to reencyrpt
200	// your data when it is not using the latest key id. See these docs relating
201	// to key rotation https://docs.aws.amazon.com/kms/latest/developerguide/rotate-keys.html
202	keyID := aws.StringValue(output.KeyId)
203	k.currentKeyID.Store(keyID)
204
205	ret := &wrapping.EncryptedBlobInfo{
206		Ciphertext: env.Ciphertext,
207		IV:         env.IV,
208		KeyInfo: &wrapping.KeyInfo{
209			Mechanism: AWSKMSEnvelopeAESGCMEncrypt,
210			// Even though we do not use the key id during decryption, store it
211			// to know exactly the specific key used in encryption in case we
212			// want to rewrap older entries
213			KeyID:      keyID,
214			WrappedKey: output.CiphertextBlob,
215		},
216	}
217
218	return ret, nil
219}
220
221// Decrypt is used to decrypt the ciphertext. This should be called after Init.
222func (k *Wrapper) Decrypt(_ context.Context, in *wrapping.EncryptedBlobInfo, aad []byte) (pt []byte, err error) {
223	if in == nil {
224		return nil, fmt.Errorf("given input for decryption is nil")
225	}
226
227	// Default to mechanism used before key info was stored
228	if in.KeyInfo == nil {
229		in.KeyInfo = &wrapping.KeyInfo{
230			Mechanism: AWSKMSEncrypt,
231		}
232	}
233
234	var plaintext []byte
235	switch in.KeyInfo.Mechanism {
236	case AWSKMSEncrypt:
237		input := &kms.DecryptInput{
238			CiphertextBlob: in.Ciphertext,
239		}
240
241		output, err := k.client.Decrypt(input)
242		if err != nil {
243			return nil, fmt.Errorf("error decrypting data: %w", err)
244		}
245		plaintext = output.Plaintext
246
247	case AWSKMSEnvelopeAESGCMEncrypt:
248		// KeyID is not passed to this call because AWS handles this
249		// internally based on the metadata stored with the encrypted data
250		input := &kms.DecryptInput{
251			CiphertextBlob: in.KeyInfo.WrappedKey,
252		}
253		output, err := k.client.Decrypt(input)
254		if err != nil {
255			return nil, fmt.Errorf("error decrypting data encryption key: %w", err)
256		}
257
258		envInfo := &wrapping.EnvelopeInfo{
259			Key:        output.Plaintext,
260			IV:         in.IV,
261			Ciphertext: in.Ciphertext,
262		}
263		plaintext, err = wrapping.NewEnvelope(nil).Decrypt(envInfo, aad)
264		if err != nil {
265			return nil, fmt.Errorf("error decrypting data: %w", err)
266		}
267
268	default:
269		return nil, fmt.Errorf("invalid mechanism: %d", in.KeyInfo.Mechanism)
270	}
271
272	return plaintext, nil
273}
274
275// GetAWSKMSClient returns an instance of the KMS client.
276func (k *Wrapper) GetAWSKMSClient() (*kms.KMS, error) {
277	credsConfig := &awsutil.CredentialsConfig{}
278
279	credsConfig.AccessKey = k.accessKey
280	credsConfig.SecretKey = k.secretKey
281	credsConfig.SessionToken = k.sessionToken
282	credsConfig.Region = k.region
283	credsConfig.Logger = k.logger
284
285	credsConfig.HTTPClient = cleanhttp.DefaultClient()
286
287	creds, err := credsConfig.GenerateCredentialChain()
288	if err != nil {
289		return nil, err
290	}
291
292	awsConfig := &aws.Config{
293		Credentials: creds,
294		Region:      aws.String(credsConfig.Region),
295		HTTPClient:  cleanhttp.DefaultClient(),
296	}
297
298	if k.endpoint != "" {
299		awsConfig.Endpoint = aws.String(k.endpoint)
300	}
301
302	sess, err := session.NewSession(awsConfig)
303	if err != nil {
304		return nil, err
305	}
306
307	client := kms.New(sess)
308
309	return client, nil
310}
311