1package azurekeyvault
2
3import (
4	"context"
5	"encoding/base64"
6	"errors"
7	"fmt"
8	"os"
9	"strings"
10	"sync/atomic"
11
12	"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.0/keyvault"
13	"github.com/Azure/go-autorest/autorest"
14	"github.com/Azure/go-autorest/autorest/azure"
15	"github.com/Azure/go-autorest/autorest/azure/auth"
16	"github.com/Azure/go-autorest/autorest/to"
17
18	wrapping "github.com/hashicorp/go-kms-wrapping"
19)
20
// Environment variables that name the Azure Key Vault holding the wrapping
// key. SetConfig consults the wrapper-specific variable first, then the
// Vault-prefixed one, before falling back to the "vault_name" config entry.
const (
	EnvAzureKeyVaultWrapperVaultName = "AZUREKEYVAULT_WRAPPER_VAULT_NAME"
	EnvVaultAzureKeyVaultVaultName   = "VAULT_AZUREKEYVAULT_VAULT_NAME"
)

// Environment variables that name the key inside the vault. SetConfig
// consults the wrapper-specific variable first, then the Vault-prefixed one,
// before falling back to the "key_name" config entry.
const (
	EnvAzureKeyVaultWrapperKeyName = "AZUREKEYVAULT_WRAPPER_KEY_NAME"
	EnvVaultAzureKeyVaultKeyName   = "VAULT_AZUREKEYVAULT_KEY_NAME"
)
28
29// Wrapper is an Wrapper that uses Azure Key Vault
30// for crypto operations.  Azure Key Vault currently does not support
31// keys that can encrypt long data (RSA keys).  Due to this fact, we generate
32// and AES key and wrap the key using Key Vault and store it with the
33// data
34type Wrapper struct {
35	tenantID     string
36	clientID     string
37	clientSecret string
38	vaultName    string
39	keyName      string
40
41	currentKeyID *atomic.Value
42
43	environment azure.Environment
44	client      *keyvault.BaseClient
45}
46
47// Ensure that we are implementing Wrapper
48var _ wrapping.Wrapper = (*Wrapper)(nil)
49
50// NewWrapper creates a new wrapper with the given options
51func NewWrapper(opts *wrapping.WrapperOptions) *Wrapper {
52	if opts == nil {
53		opts = new(wrapping.WrapperOptions)
54	}
55	v := &Wrapper{
56		currentKeyID: new(atomic.Value),
57	}
58	v.currentKeyID.Store("")
59	return v
60}
61
62// SetConfig sets the fields on the Wrapper object based on
63// values from the config parameter.
64//
65// Order of precedence:
66// * Environment variable
67// * Value from Vault configuration file
68// * Managed Service Identity for instance
69func (v *Wrapper) SetConfig(config map[string]string) (map[string]string, error) {
70	if config == nil {
71		config = map[string]string{}
72	}
73
74	switch {
75	case os.Getenv("AZURE_TENANT_ID") != "":
76		v.tenantID = os.Getenv("AZURE_TENANT_ID")
77	case config["tenant_id"] != "":
78		v.tenantID = config["tenant_id"]
79	}
80
81	switch {
82	case os.Getenv("AZURE_CLIENT_ID") != "":
83		v.clientID = os.Getenv("AZURE_CLIENT_ID")
84	case config["client_id"] != "":
85		v.clientID = config["client_id"]
86	}
87
88	switch {
89	case os.Getenv("AZURE_CLIENT_SECRET") != "":
90		v.clientSecret = os.Getenv("AZURE_CLIENT_SECRET")
91	case config["client_secret"] != "":
92		v.clientSecret = config["client_secret"]
93	}
94
95	envName := os.Getenv("AZURE_ENVIRONMENT")
96	if envName == "" {
97		envName = config["environment"]
98	}
99	if envName == "" {
100		v.environment = azure.PublicCloud
101	} else {
102		var err error
103		v.environment, err = azure.EnvironmentFromName(envName)
104		if err != nil {
105			return nil, err
106		}
107	}
108
109	switch {
110	case os.Getenv(EnvAzureKeyVaultWrapperVaultName) != "":
111		v.vaultName = os.Getenv(EnvAzureKeyVaultWrapperVaultName)
112	case os.Getenv(EnvVaultAzureKeyVaultVaultName) != "":
113		v.vaultName = os.Getenv(EnvVaultAzureKeyVaultVaultName)
114	case config["vault_name"] != "":
115		v.vaultName = config["vault_name"]
116	default:
117		return nil, errors.New("vault name is required")
118	}
119
120	switch {
121	case os.Getenv(EnvAzureKeyVaultWrapperKeyName) != "":
122		v.keyName = os.Getenv(EnvAzureKeyVaultWrapperKeyName)
123	case os.Getenv(EnvVaultAzureKeyVaultKeyName) != "":
124		v.keyName = os.Getenv(EnvVaultAzureKeyVaultKeyName)
125	case config["key_name"] != "":
126		v.keyName = config["key_name"]
127	default:
128		return nil, errors.New("key name is required")
129	}
130
131	if v.client == nil {
132		client, err := v.getKeyVaultClient()
133		if err != nil {
134			return nil, fmt.Errorf("error initializing Azure Key Vault wrapper client: %w", err)
135		}
136
137		// Test the client connection using provided key ID
138		keyInfo, err := client.GetKey(context.Background(), v.buildBaseURL(), v.keyName, "")
139		if err != nil {
140			return nil, fmt.Errorf("error fetching Azure Key Vault wrapper key information: %w", err)
141		}
142		if keyInfo.Key == nil {
143			return nil, errors.New("no key information returned")
144		}
145		v.currentKeyID.Store(parseKeyVersion(to.String(keyInfo.Key.Kid)))
146
147		v.client = client
148	}
149
150	// Map that holds non-sensitive configuration info
151	wrapperInfo := make(map[string]string)
152	wrapperInfo["environment"] = v.environment.Name
153	wrapperInfo["vault_name"] = v.vaultName
154	wrapperInfo["key_name"] = v.keyName
155
156	return wrapperInfo, nil
157}
158
159// Init is called during core.Initialize.  This is a no-op.
160func (v *Wrapper) Init(context.Context) error {
161	return nil
162}
163
164// Finalize is called during shutdown. This is a no-op.
165func (v *Wrapper) Finalize(context.Context) error {
166	return nil
167}
168
169// Type returns the type for this particular Wrapper implementation
170func (v *Wrapper) Type() string {
171	return wrapping.AzureKeyVault
172}
173
174// KeyID returns the last known key id
175func (v *Wrapper) KeyID() string {
176	return v.currentKeyID.Load().(string)
177}
178
179// HMACKeyID returns the last known HMAC key id
180func (v *Wrapper) HMACKeyID() string {
181	return ""
182}
183
184// Encrypt is used to encrypt using Azure Key Vault.
185// This returns the ciphertext, and/or any errors from this
186// call.
187func (v *Wrapper) Encrypt(ctx context.Context, plaintext, aad []byte) (blob *wrapping.EncryptedBlobInfo, err error) {
188	if plaintext == nil {
189		return nil, errors.New("given plaintext for encryption is nil")
190	}
191
192	env, err := wrapping.NewEnvelope(nil).Encrypt(plaintext, aad)
193	if err != nil {
194		return nil, fmt.Errorf("error wrapping dat: %w", err)
195	}
196
197	// Encrypt the DEK using Key Vault
198	params := keyvault.KeyOperationsParameters{
199		Algorithm: keyvault.RSAOAEP256,
200		Value:     to.StringPtr(base64.URLEncoding.EncodeToString(env.Key)),
201	}
202	// Wrap key with the latest version for the key name
203	resp, err := v.client.WrapKey(ctx, v.buildBaseURL(), v.keyName, "", params)
204	if err != nil {
205		return nil, err
206	}
207
208	// Store the current key version
209	keyVersion := parseKeyVersion(to.String(resp.Kid))
210	v.currentKeyID.Store(keyVersion)
211
212	ret := &wrapping.EncryptedBlobInfo{
213		Ciphertext: env.Ciphertext,
214		IV:         env.IV,
215		KeyInfo: &wrapping.KeyInfo{
216			KeyID:      keyVersion,
217			WrappedKey: []byte(to.String(resp.Result)),
218		},
219	}
220
221	return ret, nil
222}
223
224// Decrypt is used to decrypt the ciphertext
225func (v *Wrapper) Decrypt(ctx context.Context, in *wrapping.EncryptedBlobInfo, aad []byte) (pt []byte, err error) {
226	if in == nil {
227		return nil, errors.New("given input for decryption is nil")
228	}
229
230	if in.KeyInfo == nil {
231		return nil, errors.New("key info is nil")
232	}
233
234	// Unwrap the key
235	params := keyvault.KeyOperationsParameters{
236		Algorithm: keyvault.RSAOAEP256,
237		Value:     to.StringPtr(string(in.KeyInfo.WrappedKey)),
238	}
239	resp, err := v.client.UnwrapKey(ctx, v.buildBaseURL(), v.keyName, in.KeyInfo.KeyID, params)
240	if err != nil {
241		return nil, err
242	}
243
244	keyBytes, err := base64.URLEncoding.WithPadding(base64.NoPadding).DecodeString(to.String(resp.Result))
245	if err != nil {
246		return nil, err
247	}
248	envInfo := &wrapping.EnvelopeInfo{
249		Key:        keyBytes,
250		IV:         in.IV,
251		Ciphertext: in.Ciphertext,
252	}
253	return wrapping.NewEnvelope(nil).Decrypt(envInfo, aad)
254}
255
256func (v *Wrapper) buildBaseURL() string {
257	return fmt.Sprintf("https://%s.%s/", v.vaultName, v.environment.KeyVaultDNSSuffix)
258}
259
260func (v *Wrapper) getKeyVaultClient() (*keyvault.BaseClient, error) {
261	var authorizer autorest.Authorizer
262	var err error
263
264	switch {
265	case v.clientID != "" && v.clientSecret != "":
266		config := auth.NewClientCredentialsConfig(v.clientID, v.clientSecret, v.tenantID)
267		config.AADEndpoint = v.environment.ActiveDirectoryEndpoint
268		config.Resource = strings.TrimSuffix(v.environment.KeyVaultEndpoint, "/")
269		authorizer, err = config.Authorizer()
270		if err != nil {
271			return nil, err
272		}
273	// By default use MSI
274	default:
275		config := auth.NewMSIConfig()
276		config.Resource = strings.TrimSuffix(v.environment.KeyVaultEndpoint, "/")
277		authorizer, err = config.Authorizer()
278		if err != nil {
279			return nil, err
280		}
281	}
282
283	client := keyvault.New()
284	client.Authorizer = authorizer
285	return &client, nil
286}
287
// parseKeyVersion extracts the key version from a Key Vault key identifier.
// The kid is returned as a full URL whose final path segment is the version,
// so everything after the last "/" is returned. A kid without any "/" is
// returned unchanged, and a kid ending in "/" yields "".
func parseKeyVersion(kid string) string {
	return kid[strings.LastIndexByte(kid, '/')+1:]
}
294