1package transit
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"reflect"
8	"strings"
9	"testing"
10
11	log "github.com/hashicorp/go-hclog"
12	"github.com/hashicorp/vault/sdk/helper/logging"
13	"github.com/hashicorp/vault/sdk/physical"
14	"github.com/hashicorp/vault/vault/seal"
15)
16
17type testTransitClient struct {
18	keyID string
19	seal  seal.Access
20}
21
22func newTestTransitClient(keyID string) *testTransitClient {
23	return &testTransitClient{
24		keyID: keyID,
25		seal:  seal.NewTestSeal(nil),
26	}
27}
28
29func (m *testTransitClient) Close() {}
30
31func (m *testTransitClient) Encrypt(plaintext []byte) ([]byte, error) {
32	v, err := m.seal.Encrypt(context.Background(), plaintext)
33	if err != nil {
34		return nil, err
35	}
36
37	return []byte(fmt.Sprintf("v1:%s:%s", m.keyID, string(v.Ciphertext))), nil
38}
39
40func (m *testTransitClient) Decrypt(ciphertext []byte) ([]byte, error) {
41	splitKey := strings.Split(string(ciphertext), ":")
42	if len(splitKey) != 3 {
43		return nil, errors.New("invalid ciphertext returned")
44	}
45
46	data := &physical.EncryptedBlobInfo{
47		Ciphertext: []byte(splitKey[2]),
48	}
49	v, err := m.seal.Decrypt(context.Background(), data)
50	if err != nil {
51		return nil, err
52	}
53
54	return v, nil
55}
56
57func TestTransitSeal_Lifecycle(t *testing.T) {
58	s := NewSeal(logging.NewVaultLogger(log.Trace))
59
60	keyID := "test-key"
61	s.client = newTestTransitClient(keyID)
62
63	// Test Encrypt and Decrypt calls
64	input := []byte("foo")
65	swi, err := s.Encrypt(context.Background(), input)
66	if err != nil {
67		t.Fatalf("err: %s", err.Error())
68	}
69
70	pt, err := s.Decrypt(context.Background(), swi)
71	if err != nil {
72		t.Fatalf("err: %s", err.Error())
73	}
74
75	if !reflect.DeepEqual(input, pt) {
76		t.Fatalf("expected %s, got %s", input, pt)
77	}
78
79	if s.KeyID() != keyID {
80		t.Fatalf("key id does not match: expected %s, got %s", keyID, s.KeyID())
81	}
82}
83