1package azurekeyvault 2 3import ( 4 "context" 5 "encoding/base64" 6 "errors" 7 "fmt" 8 "os" 9 "strings" 10 "sync/atomic" 11 12 "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.0/keyvault" 13 "github.com/Azure/go-autorest/autorest" 14 "github.com/Azure/go-autorest/autorest/azure" 15 "github.com/Azure/go-autorest/autorest/azure/auth" 16 "github.com/Azure/go-autorest/autorest/to" 17 18 wrapping "github.com/hashicorp/go-kms-wrapping" 19) 20 21const ( 22 EnvAzureKeyVaultWrapperVaultName = "AZUREKEYVAULT_WRAPPER_VAULT_NAME" 23 EnvVaultAzureKeyVaultVaultName = "VAULT_AZUREKEYVAULT_VAULT_NAME" 24 25 EnvAzureKeyVaultWrapperKeyName = "AZUREKEYVAULT_WRAPPER_KEY_NAME" 26 EnvVaultAzureKeyVaultKeyName = "VAULT_AZUREKEYVAULT_KEY_NAME" 27) 28 29// Wrapper is an Wrapper that uses Azure Key Vault 30// for crypto operations. Azure Key Vault currently does not support 31// keys that can encrypt long data (RSA keys). Due to this fact, we generate 32// and AES key and wrap the key using Key Vault and store it with the 33// data 34type Wrapper struct { 35 tenantID string 36 clientID string 37 clientSecret string 38 vaultName string 39 keyName string 40 41 currentKeyID *atomic.Value 42 43 environment azure.Environment 44 client *keyvault.BaseClient 45} 46 47// Ensure that we are implementing Wrapper 48var _ wrapping.Wrapper = (*Wrapper)(nil) 49 50// NewWrapper creates a new wrapper with the given options 51func NewWrapper(opts *wrapping.WrapperOptions) *Wrapper { 52 if opts == nil { 53 opts = new(wrapping.WrapperOptions) 54 } 55 v := &Wrapper{ 56 currentKeyID: new(atomic.Value), 57 } 58 v.currentKeyID.Store("") 59 return v 60} 61 62// SetConfig sets the fields on the Wrapper object based on 63// values from the config parameter. 64// 65// Order of precedence: 66// * Environment variable 67// * Value from Vault configuration file 68// * Managed Service Identity for instance 69func (v *Wrapper) SetConfig(config map[string]string) (map[string]string, error) { 70 if config == nil { 71 config = map[string]string{} 72 } 73 74 switch { 75 case os.Getenv("AZURE_TENANT_ID") != "": 76 v.tenantID = os.Getenv("AZURE_TENANT_ID") 77 case config["tenant_id"] != "": 78 v.tenantID = config["tenant_id"] 79 } 80 81 switch { 82 case os.Getenv("AZURE_CLIENT_ID") != "": 83 v.clientID = os.Getenv("AZURE_CLIENT_ID") 84 case config["client_id"] != "": 85 v.clientID = config["client_id"] 86 } 87 88 switch { 89 case os.Getenv("AZURE_CLIENT_SECRET") != "": 90 v.clientSecret = os.Getenv("AZURE_CLIENT_SECRET") 91 case config["client_secret"] != "": 92 v.clientSecret = config["client_secret"] 93 } 94 95 envName := os.Getenv("AZURE_ENVIRONMENT") 96 if envName == "" { 97 envName = config["environment"] 98 } 99 if envName == "" { 100 v.environment = azure.PublicCloud 101 } else { 102 var err error 103 v.environment, err = azure.EnvironmentFromName(envName) 104 if err != nil { 105 return nil, err 106 } 107 } 108 109 switch { 110 case os.Getenv(EnvAzureKeyVaultWrapperVaultName) != "": 111 v.vaultName = os.Getenv(EnvAzureKeyVaultWrapperVaultName) 112 case os.Getenv(EnvVaultAzureKeyVaultVaultName) != "": 113 v.vaultName = os.Getenv(EnvVaultAzureKeyVaultVaultName) 114 case config["vault_name"] != "": 115 v.vaultName = config["vault_name"] 116 default: 117 return nil, errors.New("vault name is required") 118 } 119 120 switch { 121 case os.Getenv(EnvAzureKeyVaultWrapperKeyName) != "": 122 v.keyName = os.Getenv(EnvAzureKeyVaultWrapperKeyName) 123 case os.Getenv(EnvVaultAzureKeyVaultKeyName) != "": 124 v.keyName = os.Getenv(EnvVaultAzureKeyVaultKeyName) 125 case config["key_name"] != "": 126 v.keyName = config["key_name"] 127 default: 128 return nil, errors.New("key name is required") 129 } 130 131 if v.client == nil { 132 client, err := v.getKeyVaultClient() 133 if err != nil { 134 return nil, fmt.Errorf("error initializing Azure Key Vault wrapper client: %w", err) 135 } 136 137 // Test the client connection using provided key ID 138 keyInfo, err := client.GetKey(context.Background(), v.buildBaseURL(), v.keyName, "") 139 if err != nil { 140 return nil, fmt.Errorf("error fetching Azure Key Vault wrapper key information: %w", err) 141 } 142 if keyInfo.Key == nil { 143 return nil, errors.New("no key information returned") 144 } 145 v.currentKeyID.Store(parseKeyVersion(to.String(keyInfo.Key.Kid))) 146 147 v.client = client 148 } 149 150 // Map that holds non-sensitive configuration info 151 wrapperInfo := make(map[string]string) 152 wrapperInfo["environment"] = v.environment.Name 153 wrapperInfo["vault_name"] = v.vaultName 154 wrapperInfo["key_name"] = v.keyName 155 156 return wrapperInfo, nil 157} 158 159// Init is called during core.Initialize. This is a no-op. 160func (v *Wrapper) Init(context.Context) error { 161 return nil 162} 163 164// Finalize is called during shutdown. This is a no-op. 165func (v *Wrapper) Finalize(context.Context) error { 166 return nil 167} 168 169// Type returns the type for this particular Wrapper implementation 170func (v *Wrapper) Type() string { 171 return wrapping.AzureKeyVault 172} 173 174// KeyID returns the last known key id 175func (v *Wrapper) KeyID() string { 176 return v.currentKeyID.Load().(string) 177} 178 179// HMACKeyID returns the last known HMAC key id 180func (v *Wrapper) HMACKeyID() string { 181 return "" 182} 183 184// Encrypt is used to encrypt using Azure Key Vault. 185// This returns the ciphertext, and/or any errors from this 186// call. 187func (v *Wrapper) Encrypt(ctx context.Context, plaintext, aad []byte) (blob *wrapping.EncryptedBlobInfo, err error) { 188 if plaintext == nil { 189 return nil, errors.New("given plaintext for encryption is nil") 190 } 191 192 env, err := wrapping.NewEnvelope(nil).Encrypt(plaintext, aad) 193 if err != nil { 194 return nil, fmt.Errorf("error wrapping dat: %w", err) 195 } 196 197 // Encrypt the DEK using Key Vault 198 params := keyvault.KeyOperationsParameters{ 199 Algorithm: keyvault.RSAOAEP256, 200 Value: to.StringPtr(base64.URLEncoding.EncodeToString(env.Key)), 201 } 202 // Wrap key with the latest version for the key name 203 resp, err := v.client.WrapKey(ctx, v.buildBaseURL(), v.keyName, "", params) 204 if err != nil { 205 return nil, err 206 } 207 208 // Store the current key version 209 keyVersion := parseKeyVersion(to.String(resp.Kid)) 210 v.currentKeyID.Store(keyVersion) 211 212 ret := &wrapping.EncryptedBlobInfo{ 213 Ciphertext: env.Ciphertext, 214 IV: env.IV, 215 KeyInfo: &wrapping.KeyInfo{ 216 KeyID: keyVersion, 217 WrappedKey: []byte(to.String(resp.Result)), 218 }, 219 } 220 221 return ret, nil 222} 223 224// Decrypt is used to decrypt the ciphertext 225func (v *Wrapper) Decrypt(ctx context.Context, in *wrapping.EncryptedBlobInfo, aad []byte) (pt []byte, err error) { 226 if in == nil { 227 return nil, errors.New("given input for decryption is nil") 228 } 229 230 if in.KeyInfo == nil { 231 return nil, errors.New("key info is nil") 232 } 233 234 // Unwrap the key 235 params := keyvault.KeyOperationsParameters{ 236 Algorithm: keyvault.RSAOAEP256, 237 Value: to.StringPtr(string(in.KeyInfo.WrappedKey)), 238 } 239 resp, err := v.client.UnwrapKey(ctx, v.buildBaseURL(), v.keyName, in.KeyInfo.KeyID, params) 240 if err != nil { 241 return nil, err 242 } 243 244 keyBytes, err := base64.URLEncoding.WithPadding(base64.NoPadding).DecodeString(to.String(resp.Result)) 245 if err != nil { 246 return nil, err 247 } 248 envInfo := &wrapping.EnvelopeInfo{ 249 Key: keyBytes, 250 IV: in.IV, 251 Ciphertext: in.Ciphertext, 252 } 253 return wrapping.NewEnvelope(nil).Decrypt(envInfo, aad) 254} 255 256func (v *Wrapper) buildBaseURL() string { 257 return fmt.Sprintf("https://%s.%s/", v.vaultName, v.environment.KeyVaultDNSSuffix) 258} 259 260func (v *Wrapper) getKeyVaultClient() (*keyvault.BaseClient, error) { 261 var authorizer autorest.Authorizer 262 var err error 263 264 switch { 265 case v.clientID != "" && v.clientSecret != "": 266 config := auth.NewClientCredentialsConfig(v.clientID, v.clientSecret, v.tenantID) 267 config.AADEndpoint = v.environment.ActiveDirectoryEndpoint 268 config.Resource = strings.TrimSuffix(v.environment.KeyVaultEndpoint, "/") 269 authorizer, err = config.Authorizer() 270 if err != nil { 271 return nil, err 272 } 273 // By default use MSI 274 default: 275 config := auth.NewMSIConfig() 276 config.Resource = strings.TrimSuffix(v.environment.KeyVaultEndpoint, "/") 277 authorizer, err = config.Authorizer() 278 if err != nil { 279 return nil, err 280 } 281 } 282 283 client := keyvault.New() 284 client.Authorizer = authorizer 285 return &client, nil 286} 287 288// Kid gets returned as a full URL, get the last bit which is just 289// the version 290func parseKeyVersion(kid string) string { 291 keyVersionParts := strings.Split(kid, "/") 292 return keyVersionParts[len(keyVersionParts)-1] 293} 294