1package awskms 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "os" 8 "sync/atomic" 9 10 "github.com/aws/aws-sdk-go/aws" 11 "github.com/aws/aws-sdk-go/aws/session" 12 "github.com/aws/aws-sdk-go/service/kms" 13 "github.com/aws/aws-sdk-go/service/kms/kmsiface" 14 cleanhttp "github.com/hashicorp/go-cleanhttp" 15 "github.com/hashicorp/go-hclog" 16 wrapping "github.com/hashicorp/go-kms-wrapping" 17 "github.com/hashicorp/vault/sdk/helper/awsutil" 18) 19 20// These constants contain the accepted env vars; the Vault one is for backwards compat 21const ( 22 EnvAWSKMSWrapperKeyID = "AWSKMS_WRAPPER_KEY_ID" 23 EnvVaultAWSKMSSealKeyID = "VAULT_AWSKMS_SEAL_KEY_ID" 24) 25 26const ( 27 // AWSKMSEncrypt is used to directly encrypt the data with KMS 28 AWSKMSEncrypt = iota 29 // AWSKMSEnvelopeAESGCMEncrypt is when a data encryption key is generated and 30 // the data is encrypted with AESGCM and the key is encrypted with KMS 31 AWSKMSEnvelopeAESGCMEncrypt 32) 33 34// Wrapper represents credentials and Key information for the KMS Key used to 35// encryption and decryption 36type Wrapper struct { 37 accessKey string 38 secretKey string 39 sessionToken string 40 region string 41 keyID string 42 endpoint string 43 44 currentKeyID *atomic.Value 45 46 client kmsiface.KMSAPI 47 48 logger hclog.Logger 49} 50 51// Ensure that we are implementing Wrapper 52var _ wrapping.Wrapper = (*Wrapper)(nil) 53 54// NewWrapper creates a new AWSKMS wrapper with the provided options 55func NewWrapper(opts *wrapping.WrapperOptions) *Wrapper { 56 if opts == nil { 57 opts = new(wrapping.WrapperOptions) 58 } 59 k := &Wrapper{ 60 currentKeyID: new(atomic.Value), 61 logger: opts.Logger, 62 } 63 k.currentKeyID.Store("") 64 return k 65} 66 67// SetConfig sets the fields on the Wrapper object based on 68// values from the config parameter. 69// 70// Order of precedence AWS values: 71// * Environment variable 72// * Value from Vault configuration file 73// * Instance metadata role (access key and secret key) 74// * Default values 75func (k *Wrapper) SetConfig(config map[string]string) (map[string]string, error) { 76 if config == nil { 77 config = map[string]string{} 78 } 79 80 // Check and set KeyID 81 switch { 82 case os.Getenv(EnvAWSKMSWrapperKeyID) != "": 83 k.keyID = os.Getenv(EnvAWSKMSWrapperKeyID) 84 case os.Getenv(EnvVaultAWSKMSSealKeyID) != "": 85 k.keyID = os.Getenv(EnvVaultAWSKMSSealKeyID) 86 case config["kms_key_id"] != "": 87 k.keyID = config["kms_key_id"] 88 default: 89 return nil, fmt.Errorf("'kms_key_id' not found for AWS KMS wrapper configuration") 90 } 91 92 // Please see GetRegion for an explanation of the order in which region is parsed. 93 var err error 94 k.region, err = awsutil.GetRegion(config["region"]) 95 if err != nil { 96 return nil, err 97 } 98 99 // Check and set AWS access key, secret key, and session token 100 k.accessKey = config["access_key"] 101 k.secretKey = config["secret_key"] 102 k.sessionToken = config["session_token"] 103 104 k.endpoint = os.Getenv("AWS_KMS_ENDPOINT") 105 if k.endpoint == "" { 106 if endpoint, ok := config["endpoint"]; ok { 107 k.endpoint = endpoint 108 } 109 } 110 111 // Check and set k.client 112 if k.client == nil { 113 client, err := k.GetAWSKMSClient() 114 if err != nil { 115 return nil, fmt.Errorf("error initializing AWS KMS wrapping client: %w", err) 116 } 117 118 // Test the client connection using provided key ID 119 keyInfo, err := client.DescribeKey(&kms.DescribeKeyInput{ 120 KeyId: aws.String(k.keyID), 121 }) 122 if err != nil { 123 return nil, fmt.Errorf("error fetching AWS KMS wrapping key information: %w", err) 124 } 125 if keyInfo == nil || keyInfo.KeyMetadata == nil || keyInfo.KeyMetadata.KeyId == nil { 126 return nil, errors.New("no key information returned") 127 } 128 k.currentKeyID.Store(aws.StringValue(keyInfo.KeyMetadata.KeyId)) 129 130 k.client = client 131 } 132 133 // Map that holds non-sensitive configuration info 134 wrappingInfo := make(map[string]string) 135 wrappingInfo["region"] = k.region 136 wrappingInfo["kms_key_id"] = k.keyID 137 if k.endpoint != "" { 138 wrappingInfo["endpoint"] = k.endpoint 139 } 140 141 return wrappingInfo, nil 142} 143 144// Init is called during core.Initialize. No-op at the moment. 145func (k *Wrapper) Init(_ context.Context) error { 146 return nil 147} 148 149// Finalize is called during shutdown. This is a no-op since 150// Wrapper doesn't require any cleanup. 151func (k *Wrapper) Finalize(_ context.Context) error { 152 return nil 153} 154 155// Type returns the wrapping type for this particular Wrapper implementation 156func (k *Wrapper) Type() string { 157 return wrapping.AWSKMS 158} 159 160// KeyID returns the last known key id 161func (k *Wrapper) KeyID() string { 162 return k.currentKeyID.Load().(string) 163} 164 165// HMACKeyID returns the last known HMAC key id 166func (k *Wrapper) HMACKeyID() string { 167 return "" 168} 169 170// Encrypt is used to encrypt the master key using the the AWS CMK. 171// This returns the ciphertext, and/or any errors from this 172// call. This should be called after the KMS client has been instantiated. 173func (k *Wrapper) Encrypt(_ context.Context, plaintext, aad []byte) (blob *wrapping.EncryptedBlobInfo, err error) { 174 if plaintext == nil { 175 return nil, fmt.Errorf("given plaintext for encryption is nil") 176 } 177 178 env, err := wrapping.NewEnvelope(nil).Encrypt(plaintext, aad) 179 if err != nil { 180 return nil, fmt.Errorf("error wrapping data: %w", err) 181 } 182 183 if k.client == nil { 184 return nil, fmt.Errorf("nil client") 185 } 186 187 input := &kms.EncryptInput{ 188 KeyId: aws.String(k.keyID), 189 Plaintext: env.Key, 190 } 191 output, err := k.client.Encrypt(input) 192 if err != nil { 193 return nil, fmt.Errorf("error encrypting data: %w", err) 194 } 195 196 // Store the current key id 197 // 198 // When using a key alias, this will return the actual underlying key id 199 // used for encryption. This is helpful if you are looking to reencyrpt 200 // your data when it is not using the latest key id. See these docs relating 201 // to key rotation https://docs.aws.amazon.com/kms/latest/developerguide/rotate-keys.html 202 keyID := aws.StringValue(output.KeyId) 203 k.currentKeyID.Store(keyID) 204 205 ret := &wrapping.EncryptedBlobInfo{ 206 Ciphertext: env.Ciphertext, 207 IV: env.IV, 208 KeyInfo: &wrapping.KeyInfo{ 209 Mechanism: AWSKMSEnvelopeAESGCMEncrypt, 210 // Even though we do not use the key id during decryption, store it 211 // to know exactly the specific key used in encryption in case we 212 // want to rewrap older entries 213 KeyID: keyID, 214 WrappedKey: output.CiphertextBlob, 215 }, 216 } 217 218 return ret, nil 219} 220 221// Decrypt is used to decrypt the ciphertext. This should be called after Init. 222func (k *Wrapper) Decrypt(_ context.Context, in *wrapping.EncryptedBlobInfo, aad []byte) (pt []byte, err error) { 223 if in == nil { 224 return nil, fmt.Errorf("given input for decryption is nil") 225 } 226 227 // Default to mechanism used before key info was stored 228 if in.KeyInfo == nil { 229 in.KeyInfo = &wrapping.KeyInfo{ 230 Mechanism: AWSKMSEncrypt, 231 } 232 } 233 234 var plaintext []byte 235 switch in.KeyInfo.Mechanism { 236 case AWSKMSEncrypt: 237 input := &kms.DecryptInput{ 238 CiphertextBlob: in.Ciphertext, 239 } 240 241 output, err := k.client.Decrypt(input) 242 if err != nil { 243 return nil, fmt.Errorf("error decrypting data: %w", err) 244 } 245 plaintext = output.Plaintext 246 247 case AWSKMSEnvelopeAESGCMEncrypt: 248 // KeyID is not passed to this call because AWS handles this 249 // internally based on the metadata stored with the encrypted data 250 input := &kms.DecryptInput{ 251 CiphertextBlob: in.KeyInfo.WrappedKey, 252 } 253 output, err := k.client.Decrypt(input) 254 if err != nil { 255 return nil, fmt.Errorf("error decrypting data encryption key: %w", err) 256 } 257 258 envInfo := &wrapping.EnvelopeInfo{ 259 Key: output.Plaintext, 260 IV: in.IV, 261 Ciphertext: in.Ciphertext, 262 } 263 plaintext, err = wrapping.NewEnvelope(nil).Decrypt(envInfo, aad) 264 if err != nil { 265 return nil, fmt.Errorf("error decrypting data: %w", err) 266 } 267 268 default: 269 return nil, fmt.Errorf("invalid mechanism: %d", in.KeyInfo.Mechanism) 270 } 271 272 return plaintext, nil 273} 274 275// GetAWSKMSClient returns an instance of the KMS client. 276func (k *Wrapper) GetAWSKMSClient() (*kms.KMS, error) { 277 credsConfig := &awsutil.CredentialsConfig{} 278 279 credsConfig.AccessKey = k.accessKey 280 credsConfig.SecretKey = k.secretKey 281 credsConfig.SessionToken = k.sessionToken 282 credsConfig.Region = k.region 283 credsConfig.Logger = k.logger 284 285 credsConfig.HTTPClient = cleanhttp.DefaultClient() 286 287 creds, err := credsConfig.GenerateCredentialChain() 288 if err != nil { 289 return nil, err 290 } 291 292 awsConfig := &aws.Config{ 293 Credentials: creds, 294 Region: aws.String(credsConfig.Region), 295 HTTPClient: cleanhttp.DefaultClient(), 296 } 297 298 if k.endpoint != "" { 299 awsConfig.Endpoint = aws.String(k.endpoint) 300 } 301 302 sess, err := session.NewSession(awsConfig) 303 if err != nil { 304 return nil, err 305 } 306 307 client := kms.New(sess) 308 309 return client, nil 310} 311