1package transit 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "reflect" 8 "strings" 9 "testing" 10 11 log "github.com/hashicorp/go-hclog" 12 "github.com/hashicorp/vault/sdk/helper/logging" 13 "github.com/hashicorp/vault/sdk/physical" 14 "github.com/hashicorp/vault/vault/seal" 15) 16 17type testTransitClient struct { 18 keyID string 19 seal seal.Access 20} 21 22func newTestTransitClient(keyID string) *testTransitClient { 23 return &testTransitClient{ 24 keyID: keyID, 25 seal: seal.NewTestSeal(nil), 26 } 27} 28 29func (m *testTransitClient) Close() {} 30 31func (m *testTransitClient) Encrypt(plaintext []byte) ([]byte, error) { 32 v, err := m.seal.Encrypt(context.Background(), plaintext) 33 if err != nil { 34 return nil, err 35 } 36 37 return []byte(fmt.Sprintf("v1:%s:%s", m.keyID, string(v.Ciphertext))), nil 38} 39 40func (m *testTransitClient) Decrypt(ciphertext []byte) ([]byte, error) { 41 splitKey := strings.Split(string(ciphertext), ":") 42 if len(splitKey) != 3 { 43 return nil, errors.New("invalid ciphertext returned") 44 } 45 46 data := &physical.EncryptedBlobInfo{ 47 Ciphertext: []byte(splitKey[2]), 48 } 49 v, err := m.seal.Decrypt(context.Background(), data) 50 if err != nil { 51 return nil, err 52 } 53 54 return v, nil 55} 56 57func TestTransitSeal_Lifecycle(t *testing.T) { 58 s := NewSeal(logging.NewVaultLogger(log.Trace)) 59 60 keyID := "test-key" 61 s.client = newTestTransitClient(keyID) 62 63 // Test Encrypt and Decrypt calls 64 input := []byte("foo") 65 swi, err := s.Encrypt(context.Background(), input) 66 if err != nil { 67 t.Fatalf("err: %s", err.Error()) 68 } 69 70 pt, err := s.Decrypt(context.Background(), swi) 71 if err != nil { 72 t.Fatalf("err: %s", err.Error()) 73 } 74 75 if !reflect.DeepEqual(input, pt) { 76 t.Fatalf("expected %s, got %s", input, pt) 77 } 78 79 if s.KeyID() != keyID { 80 t.Fatalf("key id does not match: expected %s, got %s", keyID, s.KeyID()) 81 } 82} 83