1// Copyright © 2019, Oracle and/or its affiliates.
2package ocikms
3
4import (
5	"context"
6	"encoding/base64"
7	"errors"
8	"fmt"
9	"math"
10	"os"
11	"strconv"
12	"sync/atomic"
13	"time"
14
15	wrapping "github.com/hashicorp/go-kms-wrapping"
16	"github.com/oracle/oci-go-sdk/common"
17	"github.com/oracle/oci-go-sdk/common/auth"
18	"github.com/oracle/oci-go-sdk/keymanagement"
19)
20
// Environment variables consulted by SetConfig. For every setting, the
// OCIKMS_WRAPPER_* variable is checked first, then the VAULT_OCIKMS_* variable,
// and only then the corresponding entry of the config map.
const (
	// EnvOCIKMSWrapperKeyID names the OCI KMS key used for encryption and decryption.
	EnvOCIKMSWrapperKeyID = "OCIKMS_WRAPPER_KEY_ID"
	// EnvVaultOCIKMSSealKeyID is the Vault-prefixed alternative to EnvOCIKMSWrapperKeyID.
	EnvVaultOCIKMSSealKeyID = "VAULT_OCIKMS_SEAL_KEY_ID"

	// EnvOCIKMSWrapperCryptoEndpoint is the OCI KMS crypto endpoint used for
	// encryption and decryption.
	EnvOCIKMSWrapperCryptoEndpoint = "OCIKMS_WRAPPER_CRYPTO_ENDPOINT"
	// EnvVaultOCIKMSSealCryptoEndpoint is the Vault-prefixed alternative to
	// EnvOCIKMSWrapperCryptoEndpoint.
	EnvVaultOCIKMSSealCryptoEndpoint = "VAULT_OCIKMS_CRYPTO_ENDPOINT"

	// EnvOCIKMSWrapperManagementEndpoint is the OCI KMS management endpoint used
	// to look up key metadata.
	EnvOCIKMSWrapperManagementEndpoint = "OCIKMS_WRAPPER_MANAGEMENT_ENDPOINT"
	// EnvVaultOCIKMSSealManagementEndpoint is the Vault-prefixed alternative to
	// EnvOCIKMSWrapperManagementEndpoint.
	EnvVaultOCIKMSSealManagementEndpoint = "VAULT_OCIKMS_MANAGEMENT_ENDPOINT"
)

// KMSMaximumNumberOfRetries is the attempt limit handed to the OCI SDK retry
// policy attached to every KMS request.
const KMSMaximumNumberOfRetries = 5

// Keys accepted in the config map passed to SetConfig.
const (
	// KMSConfigKeyID is the config key holding the OCI KMS key ID.
	KMSConfigKeyID = "key_id"
	// KMSConfigCryptoEndpoint is the config key holding the crypto endpoint.
	KMSConfigCryptoEndpoint = "crypto_endpoint"
	// KMSConfigManagementEndpoint is the config key holding the management endpoint.
	KMSConfigManagementEndpoint = "management_endpoint"
	// KMSConfigAuthTypeAPIKey is the config key holding a boolean that selects
	// user-principal (API key) authentication instead of instance principal.
	KMSConfigAuthTypeAPIKey = "auth_type_api_key"
)
42
43type Wrapper struct {
44	authTypeAPIKey bool   // true for user principal, false for instance principal, default is false
45	keyID          string // OCI KMS keyID
46
47	cryptoEndpoint     string // OCI KMS crypto endpoint
48	managementEndpoint string // OCI KMS management endpoint
49
50	cryptoClient     *keymanagement.KmsCryptoClient     // OCI KMS crypto client
51	managementClient *keymanagement.KmsManagementClient // OCI KMS management client
52
53	currentKeyID *atomic.Value // Current key version which is used for encryption/decryption
54}
55
56var _ wrapping.Wrapper = (*Wrapper)(nil)
57
58// NewWrapper creates a new Wrapper seal with the provided logger
59func NewWrapper(opts *wrapping.WrapperOptions) *Wrapper {
60	if opts == nil {
61		opts = new(wrapping.WrapperOptions)
62	}
63	k := &Wrapper{
64		currentKeyID: new(atomic.Value),
65	}
66	k.currentKeyID.Store("")
67	return k
68}
69
70func (k *Wrapper) SetConfig(config map[string]string) (map[string]string, error) {
71	if config == nil {
72		config = map[string]string{}
73	}
74
75	// Check and set KeyID
76	switch {
77	case os.Getenv(EnvOCIKMSWrapperKeyID) != "":
78		k.keyID = os.Getenv(EnvOCIKMSWrapperKeyID)
79	case os.Getenv(EnvVaultOCIKMSSealKeyID) != "":
80		k.keyID = os.Getenv(EnvVaultOCIKMSSealKeyID)
81	case config[KMSConfigKeyID] != "":
82		k.keyID = config[KMSConfigKeyID]
83	default:
84		return nil, fmt.Errorf("'%s' not found for OCI KMS seal configuration", KMSConfigKeyID)
85	}
86	// Check and set cryptoEndpoint
87	switch {
88	case os.Getenv(EnvOCIKMSWrapperCryptoEndpoint) != "":
89		k.cryptoEndpoint = os.Getenv(EnvOCIKMSWrapperCryptoEndpoint)
90	case os.Getenv(EnvVaultOCIKMSSealCryptoEndpoint) != "":
91		k.cryptoEndpoint = os.Getenv(EnvVaultOCIKMSSealCryptoEndpoint)
92	case config[KMSConfigCryptoEndpoint] != "":
93		k.cryptoEndpoint = config[KMSConfigCryptoEndpoint]
94	default:
95		return nil, fmt.Errorf("'%s' not found for OCI KMS seal configuration", KMSConfigCryptoEndpoint)
96	}
97
98	// Check and set managementEndpoint
99	switch {
100	case os.Getenv(EnvOCIKMSWrapperManagementEndpoint) != "":
101		k.managementEndpoint = os.Getenv(EnvOCIKMSWrapperManagementEndpoint)
102	case os.Getenv(EnvVaultOCIKMSSealManagementEndpoint) != "":
103		k.managementEndpoint = os.Getenv(EnvVaultOCIKMSSealManagementEndpoint)
104	case config[KMSConfigManagementEndpoint] != "":
105		k.managementEndpoint = config[KMSConfigManagementEndpoint]
106	default:
107		return nil, fmt.Errorf("'%s' not found for OCI KMS seal configuration", KMSConfigManagementEndpoint)
108	}
109
110	// Check and set authTypeAPIKey
111	var err error
112	k.authTypeAPIKey = false
113	authTypeAPIKeyStr := config[KMSConfigAuthTypeAPIKey]
114	if authTypeAPIKeyStr != "" {
115		k.authTypeAPIKey, err = strconv.ParseBool(authTypeAPIKeyStr)
116		if err != nil {
117			return nil, fmt.Errorf("failed parsing "+KMSConfigAuthTypeAPIKey+" parameter: %w", err)
118		}
119	}
120
121	// Check and set OCI KMS crypto client
122	if k.cryptoClient == nil {
123		kmsCryptoClient, err := k.getOCIKMSCryptoClient()
124		if err != nil {
125			return nil, fmt.Errorf("error initializing OCI KMS client: %w", err)
126		}
127		k.cryptoClient = kmsCryptoClient
128	}
129
130	// Check and set OCI KMS management client
131	if k.managementClient == nil {
132		kmsManagementClient, err := k.getOCIKMSManagementClient()
133		if err != nil {
134			return nil, fmt.Errorf("error initializing OCI KMS client: %w", err)
135		}
136		k.managementClient = kmsManagementClient
137	}
138
139	// Calling Encrypt method with empty string just to validate keyId access and store current keyVersion
140	encryptedBlobInfo, err := k.Encrypt(context.Background(), []byte(""), nil)
141	if err != nil || encryptedBlobInfo == nil {
142		return nil, fmt.Errorf("failed "+KMSConfigKeyID+" validation: %w", err)
143	}
144
145	// Map that holds non-sensitive configuration info
146	wrapperInfo := make(map[string]string)
147	wrapperInfo[KMSConfigKeyID] = k.keyID
148	wrapperInfo[KMSConfigCryptoEndpoint] = k.cryptoEndpoint
149	wrapperInfo[KMSConfigManagementEndpoint] = k.managementEndpoint
150	if k.authTypeAPIKey {
151		wrapperInfo["principal_type"] = "user"
152	} else {
153		wrapperInfo["principal_type"] = "instance"
154	}
155
156	return wrapperInfo, nil
157}
158
159func (k *Wrapper) Type() string {
160	return wrapping.OCIKMS
161}
162
163func (k *Wrapper) KeyID() string {
164	return k.currentKeyID.Load().(string)
165}
166
167func (k *Wrapper) HMACKeyID() string {
168	return ""
169}
170
171func (k *Wrapper) Init(context.Context) error {
172	return nil
173}
174
175func (k *Wrapper) Finalize(context.Context) error {
176	return nil
177}
178
179func (k *Wrapper) Encrypt(ctx context.Context, plaintext, aad []byte) (*wrapping.EncryptedBlobInfo, error) {
180	if plaintext == nil {
181		return nil, errors.New("given plaintext for encryption is nil")
182	}
183
184	env, err := wrapping.NewEnvelope(nil).Encrypt(plaintext, aad)
185	if err != nil {
186		return nil, fmt.Errorf("error wrapping data: %w", err)
187	}
188
189	if k.cryptoClient == nil {
190		return nil, errors.New("nil client")
191	}
192
193	// OCI KMS required base64 encrypted plain text before sending to the service
194	encodedKey := base64.StdEncoding.EncodeToString(env.Key)
195
196	// Build Encrypt Request
197	requestMetadata := k.getRequestMetadata()
198	encryptedDataDetails := keymanagement.EncryptDataDetails{
199		KeyId:     &k.keyID,
200		Plaintext: &encodedKey,
201	}
202
203	input := keymanagement.EncryptRequest{
204		EncryptDataDetails: encryptedDataDetails,
205		RequestMetadata:    requestMetadata,
206	}
207	output, err := k.cryptoClient.Encrypt(ctx, input)
208	if err != nil {
209		return nil, fmt.Errorf("error encrypting data: %w", err)
210	}
211
212	// Note: It is potential a timing issue if the key gets rotated between this
213	// getCurrentKeyVersion operation and above Encrypt operation
214	keyVersion, err := k.getCurrentKeyVersion()
215	if err != nil {
216		return nil, fmt.Errorf("error getting current key version: %w", err)
217	}
218	// Update key version
219	k.currentKeyID.Store(keyVersion)
220
221	ret := &wrapping.EncryptedBlobInfo{
222		Ciphertext: env.Ciphertext,
223		IV:         env.IV,
224		KeyInfo: &wrapping.KeyInfo{
225			// Storing current key version in case we want to re-wrap older entries
226			KeyID:      keyVersion,
227			WrappedKey: []byte(*output.Ciphertext),
228		},
229	}
230
231	return ret, nil
232}
233
234func (k *Wrapper) Decrypt(ctx context.Context, in *wrapping.EncryptedBlobInfo, aad []byte) ([]byte, error) {
235	if in == nil {
236		return nil, fmt.Errorf("given input for decryption is nil")
237	}
238
239	requestMetadata := k.getRequestMetadata()
240	cipherTextBlob := string(in.KeyInfo.WrappedKey)
241	decryptedDataDetails := keymanagement.DecryptDataDetails{
242		KeyId:      &k.keyID,
243		Ciphertext: &cipherTextBlob,
244	}
245	input := keymanagement.DecryptRequest{
246		DecryptDataDetails: decryptedDataDetails,
247		RequestMetadata:    requestMetadata,
248	}
249	output, err := k.cryptoClient.Decrypt(ctx, input)
250	if err != nil {
251		return nil, fmt.Errorf("error decrypting data: %w", err)
252	}
253	envelopeKey, err := base64.StdEncoding.DecodeString(*output.Plaintext)
254	if err != nil {
255		return nil, fmt.Errorf("error base64 decrypting data: %w", err)
256	}
257	envInfo := &wrapping.EnvelopeInfo{
258		Key:        envelopeKey,
259		IV:         in.IV,
260		Ciphertext: in.Ciphertext,
261	}
262
263	plaintext, err := wrapping.NewEnvelope(nil).Decrypt(envInfo, aad)
264	if err != nil {
265		return nil, fmt.Errorf("error decrypting data: %w", err)
266	}
267
268	return plaintext, nil
269}
270
271func (k *Wrapper) getConfigProvider() (common.ConfigurationProvider, error) {
272	var cp common.ConfigurationProvider
273	var err error
274	if k.authTypeAPIKey {
275		cp = common.DefaultConfigProvider()
276	} else {
277		cp, err = auth.InstancePrincipalConfigurationProvider()
278		if err != nil {
279			return nil, fmt.Errorf("failed creating InstancePrincipalConfigurationProvider: %w", err)
280		}
281	}
282	return cp, nil
283}
284
285// Build OCI KMS crypto client
286func (k *Wrapper) getOCIKMSCryptoClient() (*keymanagement.KmsCryptoClient, error) {
287	cp, err := k.getConfigProvider()
288	if err != nil {
289		return nil, fmt.Errorf("failed creating configuration provider: %w", err)
290	}
291
292	// Build crypto client
293	kmsCryptoClient, err := keymanagement.NewKmsCryptoClientWithConfigurationProvider(cp, k.cryptoEndpoint)
294	if err != nil {
295		return nil, fmt.Errorf("failed creating NewKmsCryptoClientWithConfigurationProvider: %w", err)
296	}
297
298	return &kmsCryptoClient, nil
299}
300
301// Build OCI KMS management client
302func (k *Wrapper) getOCIKMSManagementClient() (*keymanagement.KmsManagementClient, error) {
303	cp, err := k.getConfigProvider()
304	if err != nil {
305		return nil, fmt.Errorf("failed creating configuration provider: %w", err)
306	}
307
308	// Build crypto client
309	kmsManagementClient, err := keymanagement.NewKmsManagementClientWithConfigurationProvider(cp, k.managementEndpoint)
310	if err != nil {
311		return nil, fmt.Errorf("failed creating NewKmsCryptoClientWithConfigurationProvider: %w", err)
312	}
313
314	return &kmsManagementClient, nil
315}
316
317// Request metadata includes retry policy
318func (k *Wrapper) getRequestMetadata() common.RequestMetadata {
319	// Only retry for 5xx errors
320	retryOn5xxFunc := func(r common.OCIOperationResponse) bool {
321		return r.Error != nil && r.Response.HTTPResponse().StatusCode >= 500
322	}
323	return getRequestMetadataWithCustomizedRetryPolicy(retryOn5xxFunc)
324}
325
326func getRequestMetadataWithCustomizedRetryPolicy(fn func(r common.OCIOperationResponse) bool) common.RequestMetadata {
327	return common.RequestMetadata{
328		RetryPolicy: getExponentialBackoffRetryPolicy(uint(KMSMaximumNumberOfRetries), fn),
329	}
330}
331
332func getExponentialBackoffRetryPolicy(n uint, fn func(r common.OCIOperationResponse) bool) *common.RetryPolicy {
333	// The duration between each retry operation, you might want to wait longer each time the retry fails
334	exponentialBackoff := func(r common.OCIOperationResponse) time.Duration {
335		return time.Duration(math.Pow(float64(2), float64(r.AttemptNumber-1))) * time.Second
336	}
337	policy := common.NewRetryPolicy(n, fn, exponentialBackoff)
338	return &policy
339}
340
341func (k *Wrapper) getCurrentKeyVersion() (string, error) {
342	if k.managementClient == nil {
343		return "", fmt.Errorf("managementClient has not yet initialized")
344	}
345	requestMetadata := k.getRequestMetadata()
346	getKeyInput := keymanagement.GetKeyRequest{
347		KeyId:           &k.keyID,
348		RequestMetadata: requestMetadata,
349	}
350	getKeyResponse, err := k.managementClient.GetKey(context.Background(), getKeyInput)
351	if err != nil || getKeyResponse.CurrentKeyVersion == nil {
352		return "", fmt.Errorf("failed getting current key version: %w", err)
353	}
354
355	return *getKeyResponse.CurrentKeyVersion, nil
356}
357