1package transit
2
3import (
4	"context"
5	"errors"
6	"strings"
7	"sync/atomic"
8
9	"github.com/hashicorp/go-hclog"
10	wrapping "github.com/hashicorp/go-kms-wrapping"
11)
12
// Wrapper is a wrapping.Wrapper that leverages Vault's Transit secret
// engine to encrypt and decrypt data.
//
// Construct it with NewWrapper and configure it with SetConfig before
// calling Encrypt or Decrypt.
type Wrapper struct {
	// logger is handed to newTransitClient when SetConfig builds the client.
	logger       hclog.Logger
	// client performs the Transit encrypt/decrypt calls; it is nil until
	// SetConfig constructs one.
	client       transitClientEncryptor
	// currentKeyID holds a string: the middle ':'-separated field of the
	// most recent ciphertext returned by Encrypt (presumably Vault's key
	// version, e.g. "v1" — confirm). NewWrapper initialises it to "".
	currentKeyID *atomic.Value
}

// Compile-time check that *Wrapper satisfies the wrapping.Wrapper interface.
var _ wrapping.Wrapper = (*Wrapper)(nil)
22
23// NewWrapper creates a new transit wrapper
24func NewWrapper(opts *wrapping.WrapperOptions) *Wrapper {
25	if opts == nil {
26		opts = new(wrapping.WrapperOptions)
27	}
28	s := &Wrapper{
29		logger:       opts.Logger,
30		currentKeyID: new(atomic.Value),
31	}
32	s.currentKeyID.Store("")
33	return s
34}
35
36// SetConfig processes the config info from the server config
37func (s *Wrapper) SetConfig(config map[string]string) (map[string]string, error) {
38	client, wrapperInfo, err := newTransitClient(s.logger, config)
39	if err != nil {
40		return nil, err
41	}
42	s.client = client
43
44	// Send a value to test the wrapper and to set the current key id
45	if _, err := s.Encrypt(context.Background(), []byte("a"), nil); err != nil {
46		client.Close()
47		return nil, err
48	}
49
50	return wrapperInfo, nil
51}
52
53// Init is called during core.Initialize
54func (s *Wrapper) Init(_ context.Context) error {
55	return nil
56}
57
58// Finalize is called during shutdown
59func (s *Wrapper) Finalize(_ context.Context) error {
60	s.client.Close()
61	return nil
62}
63
64// Type returns the type for this particular Wrapper implementation
65func (s *Wrapper) Type() string {
66	return wrapping.Transit
67}
68
69// KeyID returns the last known key id
70func (s *Wrapper) KeyID() string {
71	return s.currentKeyID.Load().(string)
72}
73
74// HMACKeyID returns the last known HMAC key id
75func (s *Wrapper) HMACKeyID() string {
76	return ""
77}
78
79// Encrypt is used to encrypt using Vault's Transit engine
80func (s *Wrapper) Encrypt(_ context.Context, plaintext, aad []byte) (blob *wrapping.EncryptedBlobInfo, err error) {
81	ciphertext, err := s.client.Encrypt(plaintext)
82	if err != nil {
83		return nil, err
84	}
85
86	splitKey := strings.Split(string(ciphertext), ":")
87	if len(splitKey) != 3 {
88		return nil, errors.New("invalid ciphertext returned")
89	}
90	keyID := splitKey[1]
91	s.currentKeyID.Store(keyID)
92
93	ret := &wrapping.EncryptedBlobInfo{
94		Ciphertext: ciphertext,
95		KeyInfo: &wrapping.KeyInfo{
96			KeyID: keyID,
97		},
98	}
99	return ret, nil
100}
101
102// Decrypt is used to decrypt the ciphertext
103func (s *Wrapper) Decrypt(_ context.Context, in *wrapping.EncryptedBlobInfo, _ []byte) (pt []byte, err error) {
104	plaintext, err := s.client.Decrypt(in.Ciphertext)
105	if err != nil {
106		return nil, err
107	}
108	return plaintext, nil
109}
110
111// GetClient returns the transit Wrapper's transitClientEncryptor
112func (s *Wrapper) GetClient() transitClientEncryptor {
113	return s.client
114}
115