1package transit 2 3import ( 4 "context" 5 "errors" 6 "strings" 7 "sync/atomic" 8 9 "github.com/hashicorp/go-hclog" 10 wrapping "github.com/hashicorp/go-kms-wrapping" 11) 12 13// Wrapper is a wrapper that leverages Vault's Transit secret 14// engine 15type Wrapper struct { 16 logger hclog.Logger 17 client transitClientEncryptor 18 currentKeyID *atomic.Value 19} 20 21var _ wrapping.Wrapper = (*Wrapper)(nil) 22 23// NewWrapper creates a new transit wrapper 24func NewWrapper(opts *wrapping.WrapperOptions) *Wrapper { 25 if opts == nil { 26 opts = new(wrapping.WrapperOptions) 27 } 28 s := &Wrapper{ 29 logger: opts.Logger, 30 currentKeyID: new(atomic.Value), 31 } 32 s.currentKeyID.Store("") 33 return s 34} 35 36// SetConfig processes the config info from the server config 37func (s *Wrapper) SetConfig(config map[string]string) (map[string]string, error) { 38 client, wrapperInfo, err := newTransitClient(s.logger, config) 39 if err != nil { 40 return nil, err 41 } 42 s.client = client 43 44 // Send a value to test the wrapper and to set the current key id 45 if _, err := s.Encrypt(context.Background(), []byte("a"), nil); err != nil { 46 client.Close() 47 return nil, err 48 } 49 50 return wrapperInfo, nil 51} 52 53// Init is called during core.Initialize 54func (s *Wrapper) Init(_ context.Context) error { 55 return nil 56} 57 58// Finalize is called during shutdown 59func (s *Wrapper) Finalize(_ context.Context) error { 60 s.client.Close() 61 return nil 62} 63 64// Type returns the type for this particular Wrapper implementation 65func (s *Wrapper) Type() string { 66 return wrapping.Transit 67} 68 69// KeyID returns the last known key id 70func (s *Wrapper) KeyID() string { 71 return s.currentKeyID.Load().(string) 72} 73 74// HMACKeyID returns the last known HMAC key id 75func (s *Wrapper) HMACKeyID() string { 76 return "" 77} 78 79// Encrypt is used to encrypt using Vault's Transit engine 80func (s *Wrapper) Encrypt(_ context.Context, plaintext, aad []byte) (blob *wrapping.EncryptedBlobInfo, err error) { 81 ciphertext, err := s.client.Encrypt(plaintext) 82 if err != nil { 83 return nil, err 84 } 85 86 splitKey := strings.Split(string(ciphertext), ":") 87 if len(splitKey) != 3 { 88 return nil, errors.New("invalid ciphertext returned") 89 } 90 keyID := splitKey[1] 91 s.currentKeyID.Store(keyID) 92 93 ret := &wrapping.EncryptedBlobInfo{ 94 Ciphertext: ciphertext, 95 KeyInfo: &wrapping.KeyInfo{ 96 KeyID: keyID, 97 }, 98 } 99 return ret, nil 100} 101 102// Decrypt is used to decrypt the ciphertext 103func (s *Wrapper) Decrypt(_ context.Context, in *wrapping.EncryptedBlobInfo, _ []byte) (pt []byte, err error) { 104 plaintext, err := s.client.Decrypt(in.Ciphertext) 105 if err != nil { 106 return nil, err 107 } 108 return plaintext, nil 109} 110 111// GetClient returns the transit Wrapper's transitClientEncryptor 112func (s *Wrapper) GetClient() transitClientEncryptor { 113 return s.client 114} 115