1package awskms 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "os" 8 "sync/atomic" 9 10 "github.com/aws/aws-sdk-go/aws" 11 "github.com/aws/aws-sdk-go/aws/session" 12 "github.com/aws/aws-sdk-go/service/kms" 13 "github.com/aws/aws-sdk-go/service/kms/kmsiface" 14 cleanhttp "github.com/hashicorp/go-cleanhttp" 15 wrapping "github.com/hashicorp/go-kms-wrapping" 16 "github.com/hashicorp/vault/sdk/helper/awsutil" 17) 18 19// These constants contain the accepted env vars; the Vault one is for backwards compat 20const ( 21 EnvAWSKMSWrapperKeyID = "AWSKMS_WRAPPER_KEY_ID" 22 EnvVaultAWSKMSSealKeyID = "VAULT_AWSKMS_SEAL_KEY_ID" 23) 24 25const ( 26 // AWSKMSEncrypt is used to directly encrypt the data with KMS 27 AWSKMSEncrypt = iota 28 // AWSKMSEnvelopeAESGCMEncrypt is when a data encryption key is generated and 29 // the data is encrypted with AESGCM and the key is encrypted with KMS 30 AWSKMSEnvelopeAESGCMEncrypt 31) 32 33// Wrapper represents credentials and Key information for the KMS Key used to 34// encryption and decryption 35type Wrapper struct { 36 accessKey string 37 secretKey string 38 sessionToken string 39 region string 40 keyID string 41 endpoint string 42 43 currentKeyID *atomic.Value 44 45 client kmsiface.KMSAPI 46} 47 48// Ensure that we are implementing Wrapper 49var _ wrapping.Wrapper = (*Wrapper)(nil) 50 51// NewWrapper creates a new AWSKMS wrapper with the provided options 52func NewWrapper(opts *wrapping.WrapperOptions) *Wrapper { 53 if opts == nil { 54 opts = new(wrapping.WrapperOptions) 55 } 56 k := &Wrapper{ 57 currentKeyID: new(atomic.Value), 58 } 59 k.currentKeyID.Store("") 60 return k 61} 62 63// SetConfig sets the fields on the Wrapper object based on 64// values from the config parameter. 65// 66// Order of precedence AWS values: 67// * Environment variable 68// * Value from Vault configuration file 69// * Instance metadata role (access key and secret key) 70// * Default values 71func (k *Wrapper) SetConfig(config map[string]string) (map[string]string, error) { 72 if config == nil { 73 config = map[string]string{} 74 } 75 76 // Check and set KeyID 77 switch { 78 case os.Getenv(EnvAWSKMSWrapperKeyID) != "": 79 k.keyID = os.Getenv(EnvAWSKMSWrapperKeyID) 80 case os.Getenv(EnvVaultAWSKMSSealKeyID) != "": 81 k.keyID = os.Getenv(EnvVaultAWSKMSSealKeyID) 82 case config["kms_key_id"] != "": 83 k.keyID = config["kms_key_id"] 84 default: 85 return nil, fmt.Errorf("'kms_key_id' not found for AWS KMS wrapper configuration") 86 } 87 88 // Please see GetRegion for an explanation of the order in which region is parsed. 89 var err error 90 k.region, err = awsutil.GetRegion(config["region"]) 91 if err != nil { 92 return nil, err 93 } 94 95 // Check and set AWS access key, secret key, and session token 96 k.accessKey = config["access_key"] 97 k.secretKey = config["secret_key"] 98 k.sessionToken = config["session_token"] 99 100 k.endpoint = os.Getenv("AWS_KMS_ENDPOINT") 101 if k.endpoint == "" { 102 if endpoint, ok := config["endpoint"]; ok { 103 k.endpoint = endpoint 104 } 105 } 106 107 // Check and set k.client 108 if k.client == nil { 109 client, err := k.GetAWSKMSClient() 110 if err != nil { 111 return nil, fmt.Errorf("error initializing AWS KMS wrapping client: %w", err) 112 } 113 114 // Test the client connection using provided key ID 115 keyInfo, err := client.DescribeKey(&kms.DescribeKeyInput{ 116 KeyId: aws.String(k.keyID), 117 }) 118 if err != nil { 119 return nil, fmt.Errorf("error fetching AWS KMS wrapping key information: %w", err) 120 } 121 if keyInfo == nil || keyInfo.KeyMetadata == nil || keyInfo.KeyMetadata.KeyId == nil { 122 return nil, errors.New("no key information returned") 123 } 124 k.currentKeyID.Store(aws.StringValue(keyInfo.KeyMetadata.KeyId)) 125 126 k.client = client 127 } 128 129 // Map that holds non-sensitive configuration info 130 wrappingInfo := make(map[string]string) 131 wrappingInfo["region"] = k.region 132 wrappingInfo["kms_key_id"] = k.keyID 133 if k.endpoint != "" { 134 wrappingInfo["endpoint"] = k.endpoint 135 } 136 137 return wrappingInfo, nil 138} 139 140// Init is called during core.Initialize. No-op at the moment. 141func (k *Wrapper) Init(_ context.Context) error { 142 return nil 143} 144 145// Finalize is called during shutdown. This is a no-op since 146// Wrapper doesn't require any cleanup. 147func (k *Wrapper) Finalize(_ context.Context) error { 148 return nil 149} 150 151// Type returns the wrapping type for this particular Wrapper implementation 152func (k *Wrapper) Type() string { 153 return wrapping.AWSKMS 154} 155 156// KeyID returns the last known key id 157func (k *Wrapper) KeyID() string { 158 return k.currentKeyID.Load().(string) 159} 160 161// HMACKeyID returns the last known HMAC key id 162func (k *Wrapper) HMACKeyID() string { 163 return "" 164} 165 166// Encrypt is used to encrypt the master key using the the AWS CMK. 167// This returns the ciphertext, and/or any errors from this 168// call. This should be called after the KMS client has been instantiated. 169func (k *Wrapper) Encrypt(_ context.Context, plaintext, aad []byte) (blob *wrapping.EncryptedBlobInfo, err error) { 170 if plaintext == nil { 171 return nil, fmt.Errorf("given plaintext for encryption is nil") 172 } 173 174 env, err := wrapping.NewEnvelope(nil).Encrypt(plaintext, aad) 175 if err != nil { 176 return nil, fmt.Errorf("error wrapping data: %w", err) 177 } 178 179 if k.client == nil { 180 return nil, fmt.Errorf("nil client") 181 } 182 183 input := &kms.EncryptInput{ 184 KeyId: aws.String(k.keyID), 185 Plaintext: env.Key, 186 } 187 output, err := k.client.Encrypt(input) 188 if err != nil { 189 return nil, fmt.Errorf("error encrypting data: %w", err) 190 } 191 192 // store the current key id 193 keyID := aws.StringValue(output.KeyId) 194 k.currentKeyID.Store(keyID) 195 196 ret := &wrapping.EncryptedBlobInfo{ 197 Ciphertext: env.Ciphertext, 198 IV: env.IV, 199 KeyInfo: &wrapping.KeyInfo{ 200 Mechanism: AWSKMSEnvelopeAESGCMEncrypt, 201 // Even though we do not use the key id during decryption, store it 202 // to know exactly the specific key used in encryption in case we 203 // want to rewrap older entries 204 KeyID: keyID, 205 WrappedKey: output.CiphertextBlob, 206 }, 207 } 208 209 return ret, nil 210} 211 212// Decrypt is used to decrypt the ciphertext. This should be called after Init. 213func (k *Wrapper) Decrypt(_ context.Context, in *wrapping.EncryptedBlobInfo, aad []byte) (pt []byte, err error) { 214 if in == nil { 215 return nil, fmt.Errorf("given input for decryption is nil") 216 } 217 218 // Default to mechanism used before key info was stored 219 if in.KeyInfo == nil { 220 in.KeyInfo = &wrapping.KeyInfo{ 221 Mechanism: AWSKMSEncrypt, 222 } 223 } 224 225 var plaintext []byte 226 switch in.KeyInfo.Mechanism { 227 case AWSKMSEncrypt: 228 input := &kms.DecryptInput{ 229 CiphertextBlob: in.Ciphertext, 230 } 231 232 output, err := k.client.Decrypt(input) 233 if err != nil { 234 return nil, fmt.Errorf("error decrypting data: %w", err) 235 } 236 plaintext = output.Plaintext 237 238 case AWSKMSEnvelopeAESGCMEncrypt: 239 // KeyID is not passed to this call because AWS handles this 240 // internally based on the metadata stored with the encrypted data 241 input := &kms.DecryptInput{ 242 CiphertextBlob: in.KeyInfo.WrappedKey, 243 } 244 output, err := k.client.Decrypt(input) 245 if err != nil { 246 return nil, fmt.Errorf("error decrypting data encryption key: %w", err) 247 } 248 249 envInfo := &wrapping.EnvelopeInfo{ 250 Key: output.Plaintext, 251 IV: in.IV, 252 Ciphertext: in.Ciphertext, 253 } 254 plaintext, err = wrapping.NewEnvelope(nil).Decrypt(envInfo, aad) 255 if err != nil { 256 return nil, fmt.Errorf("error decrypting data: %w", err) 257 } 258 259 default: 260 return nil, fmt.Errorf("invalid mechanism: %d", in.KeyInfo.Mechanism) 261 } 262 263 return plaintext, nil 264} 265 266// GetAWSKMSClient returns an instance of the KMS client. 267func (k *Wrapper) GetAWSKMSClient() (*kms.KMS, error) { 268 credsConfig := &awsutil.CredentialsConfig{} 269 270 credsConfig.AccessKey = k.accessKey 271 credsConfig.SecretKey = k.secretKey 272 credsConfig.SessionToken = k.sessionToken 273 credsConfig.Region = k.region 274 275 credsConfig.HTTPClient = cleanhttp.DefaultClient() 276 277 creds, err := credsConfig.GenerateCredentialChain() 278 if err != nil { 279 return nil, err 280 } 281 282 awsConfig := &aws.Config{ 283 Credentials: creds, 284 Region: aws.String(credsConfig.Region), 285 HTTPClient: cleanhttp.DefaultClient(), 286 } 287 288 if k.endpoint != "" { 289 awsConfig.Endpoint = aws.String(k.endpoint) 290 } 291 292 sess, err := session.NewSession(awsConfig) 293 if err != nil { 294 return nil, err 295 } 296 297 client := kms.New(sess) 298 299 return client, nil 300} 301