1package awskms
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"os"
8	"sync/atomic"
9
10	"github.com/aws/aws-sdk-go/aws"
11	"github.com/aws/aws-sdk-go/aws/session"
12	"github.com/aws/aws-sdk-go/service/kms"
13	"github.com/aws/aws-sdk-go/service/kms/kmsiface"
14	cleanhttp "github.com/hashicorp/go-cleanhttp"
15	wrapping "github.com/hashicorp/go-kms-wrapping"
16	"github.com/hashicorp/vault/sdk/helper/awsutil"
17)
18
// Environment variables from which the AWS KMS key ID may be read. Two names
// are accepted: EnvAWSKMSWrapperKeyID is the primary one, while
// EnvVaultAWSKMSSealKeyID is retained for backwards compatibility with Vault.
const (
	// EnvAWSKMSWrapperKeyID is consulted first when SetConfig resolves the key ID.
	EnvAWSKMSWrapperKeyID = "AWSKMS_WRAPPER_KEY_ID"

	// EnvVaultAWSKMSSealKeyID is the legacy Vault name, consulted second.
	EnvVaultAWSKMSSealKeyID = "VAULT_AWSKMS_SEAL_KEY_ID"
)
24
// Mechanism identifiers recorded in an encrypted blob's KeyInfo so that
// Decrypt knows how the data was protected.
const (
	// AWSKMSEncrypt means the data itself was sent to KMS and encrypted there.
	AWSKMSEncrypt = 0

	// AWSKMSEnvelopeAESGCMEncrypt means a data encryption key was generated,
	// the data was encrypted locally with AES-GCM under that key, and only the
	// data key was encrypted with KMS.
	AWSKMSEnvelopeAESGCMEncrypt = 1
)
32
33// Wrapper represents credentials and Key information for the KMS Key used to
34// encryption and decryption
35type Wrapper struct {
36	accessKey    string
37	secretKey    string
38	sessionToken string
39	region       string
40	keyID        string
41	endpoint     string
42
43	currentKeyID *atomic.Value
44
45	client kmsiface.KMSAPI
46}
47
48// Ensure that we are implementing Wrapper
49var _ wrapping.Wrapper = (*Wrapper)(nil)
50
51// NewWrapper creates a new AWSKMS wrapper with the provided options
52func NewWrapper(opts *wrapping.WrapperOptions) *Wrapper {
53	if opts == nil {
54		opts = new(wrapping.WrapperOptions)
55	}
56	k := &Wrapper{
57		currentKeyID: new(atomic.Value),
58	}
59	k.currentKeyID.Store("")
60	return k
61}
62
63// SetConfig sets the fields on the Wrapper object based on
64// values from the config parameter.
65//
66// Order of precedence AWS values:
67// * Environment variable
68// * Value from Vault configuration file
69// * Instance metadata role (access key and secret key)
70// * Default values
71func (k *Wrapper) SetConfig(config map[string]string) (map[string]string, error) {
72	if config == nil {
73		config = map[string]string{}
74	}
75
76	// Check and set KeyID
77	switch {
78	case os.Getenv(EnvAWSKMSWrapperKeyID) != "":
79		k.keyID = os.Getenv(EnvAWSKMSWrapperKeyID)
80	case os.Getenv(EnvVaultAWSKMSSealKeyID) != "":
81		k.keyID = os.Getenv(EnvVaultAWSKMSSealKeyID)
82	case config["kms_key_id"] != "":
83		k.keyID = config["kms_key_id"]
84	default:
85		return nil, fmt.Errorf("'kms_key_id' not found for AWS KMS wrapper configuration")
86	}
87
88	// Please see GetRegion for an explanation of the order in which region is parsed.
89	var err error
90	k.region, err = awsutil.GetRegion(config["region"])
91	if err != nil {
92		return nil, err
93	}
94
95	// Check and set AWS access key, secret key, and session token
96	k.accessKey = config["access_key"]
97	k.secretKey = config["secret_key"]
98	k.sessionToken = config["session_token"]
99
100	k.endpoint = os.Getenv("AWS_KMS_ENDPOINT")
101	if k.endpoint == "" {
102		if endpoint, ok := config["endpoint"]; ok {
103			k.endpoint = endpoint
104		}
105	}
106
107	// Check and set k.client
108	if k.client == nil {
109		client, err := k.GetAWSKMSClient()
110		if err != nil {
111			return nil, fmt.Errorf("error initializing AWS KMS wrapping client: %w", err)
112		}
113
114		// Test the client connection using provided key ID
115		keyInfo, err := client.DescribeKey(&kms.DescribeKeyInput{
116			KeyId: aws.String(k.keyID),
117		})
118		if err != nil {
119			return nil, fmt.Errorf("error fetching AWS KMS wrapping key information: %w", err)
120		}
121		if keyInfo == nil || keyInfo.KeyMetadata == nil || keyInfo.KeyMetadata.KeyId == nil {
122			return nil, errors.New("no key information returned")
123		}
124		k.currentKeyID.Store(aws.StringValue(keyInfo.KeyMetadata.KeyId))
125
126		k.client = client
127	}
128
129	// Map that holds non-sensitive configuration info
130	wrappingInfo := make(map[string]string)
131	wrappingInfo["region"] = k.region
132	wrappingInfo["kms_key_id"] = k.keyID
133	if k.endpoint != "" {
134		wrappingInfo["endpoint"] = k.endpoint
135	}
136
137	return wrappingInfo, nil
138}
139
140// Init is called during core.Initialize. No-op at the moment.
141func (k *Wrapper) Init(_ context.Context) error {
142	return nil
143}
144
145// Finalize is called during shutdown. This is a no-op since
146// Wrapper doesn't require any cleanup.
147func (k *Wrapper) Finalize(_ context.Context) error {
148	return nil
149}
150
151// Type returns the wrapping type for this particular Wrapper implementation
152func (k *Wrapper) Type() string {
153	return wrapping.AWSKMS
154}
155
156// KeyID returns the last known key id
157func (k *Wrapper) KeyID() string {
158	return k.currentKeyID.Load().(string)
159}
160
161// HMACKeyID returns the last known HMAC key id
162func (k *Wrapper) HMACKeyID() string {
163	return ""
164}
165
166// Encrypt is used to encrypt the master key using the the AWS CMK.
167// This returns the ciphertext, and/or any errors from this
168// call. This should be called after the KMS client has been instantiated.
169func (k *Wrapper) Encrypt(_ context.Context, plaintext, aad []byte) (blob *wrapping.EncryptedBlobInfo, err error) {
170	if plaintext == nil {
171		return nil, fmt.Errorf("given plaintext for encryption is nil")
172	}
173
174	env, err := wrapping.NewEnvelope(nil).Encrypt(plaintext, aad)
175	if err != nil {
176		return nil, fmt.Errorf("error wrapping data: %w", err)
177	}
178
179	if k.client == nil {
180		return nil, fmt.Errorf("nil client")
181	}
182
183	input := &kms.EncryptInput{
184		KeyId:     aws.String(k.keyID),
185		Plaintext: env.Key,
186	}
187	output, err := k.client.Encrypt(input)
188	if err != nil {
189		return nil, fmt.Errorf("error encrypting data: %w", err)
190	}
191
192	// store the current key id
193	keyID := aws.StringValue(output.KeyId)
194	k.currentKeyID.Store(keyID)
195
196	ret := &wrapping.EncryptedBlobInfo{
197		Ciphertext: env.Ciphertext,
198		IV:         env.IV,
199		KeyInfo: &wrapping.KeyInfo{
200			Mechanism: AWSKMSEnvelopeAESGCMEncrypt,
201			// Even though we do not use the key id during decryption, store it
202			// to know exactly the specific key used in encryption in case we
203			// want to rewrap older entries
204			KeyID:      keyID,
205			WrappedKey: output.CiphertextBlob,
206		},
207	}
208
209	return ret, nil
210}
211
212// Decrypt is used to decrypt the ciphertext. This should be called after Init.
213func (k *Wrapper) Decrypt(_ context.Context, in *wrapping.EncryptedBlobInfo, aad []byte) (pt []byte, err error) {
214	if in == nil {
215		return nil, fmt.Errorf("given input for decryption is nil")
216	}
217
218	// Default to mechanism used before key info was stored
219	if in.KeyInfo == nil {
220		in.KeyInfo = &wrapping.KeyInfo{
221			Mechanism: AWSKMSEncrypt,
222		}
223	}
224
225	var plaintext []byte
226	switch in.KeyInfo.Mechanism {
227	case AWSKMSEncrypt:
228		input := &kms.DecryptInput{
229			CiphertextBlob: in.Ciphertext,
230		}
231
232		output, err := k.client.Decrypt(input)
233		if err != nil {
234			return nil, fmt.Errorf("error decrypting data: %w", err)
235		}
236		plaintext = output.Plaintext
237
238	case AWSKMSEnvelopeAESGCMEncrypt:
239		// KeyID is not passed to this call because AWS handles this
240		// internally based on the metadata stored with the encrypted data
241		input := &kms.DecryptInput{
242			CiphertextBlob: in.KeyInfo.WrappedKey,
243		}
244		output, err := k.client.Decrypt(input)
245		if err != nil {
246			return nil, fmt.Errorf("error decrypting data encryption key: %w", err)
247		}
248
249		envInfo := &wrapping.EnvelopeInfo{
250			Key:        output.Plaintext,
251			IV:         in.IV,
252			Ciphertext: in.Ciphertext,
253		}
254		plaintext, err = wrapping.NewEnvelope(nil).Decrypt(envInfo, aad)
255		if err != nil {
256			return nil, fmt.Errorf("error decrypting data: %w", err)
257		}
258
259	default:
260		return nil, fmt.Errorf("invalid mechanism: %d", in.KeyInfo.Mechanism)
261	}
262
263	return plaintext, nil
264}
265
266// GetAWSKMSClient returns an instance of the KMS client.
267func (k *Wrapper) GetAWSKMSClient() (*kms.KMS, error) {
268	credsConfig := &awsutil.CredentialsConfig{}
269
270	credsConfig.AccessKey = k.accessKey
271	credsConfig.SecretKey = k.secretKey
272	credsConfig.SessionToken = k.sessionToken
273	credsConfig.Region = k.region
274
275	credsConfig.HTTPClient = cleanhttp.DefaultClient()
276
277	creds, err := credsConfig.GenerateCredentialChain()
278	if err != nil {
279		return nil, err
280	}
281
282	awsConfig := &aws.Config{
283		Credentials: creds,
284		Region:      aws.String(credsConfig.Region),
285		HTTPClient:  cleanhttp.DefaultClient(),
286	}
287
288	if k.endpoint != "" {
289		awsConfig.Endpoint = aws.String(k.endpoint)
290	}
291
292	sess, err := session.NewSession(awsConfig)
293	if err != nil {
294		return nil, err
295	}
296
297	client := kms.New(sess)
298
299	return client, nil
300}
301