1package transit
2
3import (
4	"encoding/base64"
5	"fmt"
6	"os"
7	"path"
8	"strconv"
9
10	log "github.com/hashicorp/go-hclog"
11	"github.com/hashicorp/vault/api"
12)
13
// transitClientEncryptor is the set of operations offered by a transit
// encryption client: encrypt and decrypt byte slices, and release whatever
// the implementation holds when it is no longer needed. *transitClient in
// this file is one implementation.
type transitClientEncryptor interface {
	// Encrypt returns the ciphertext for plaintext, or an error.
	Encrypt(plaintext []byte) ([]byte, error)
	// Decrypt returns the plaintext for ciphertext, or an error.
	Decrypt(ciphertext []byte) ([]byte, error)
	// Close releases resources held by the implementation.
	Close()
}
19
20type transitClient struct {
21	client  *api.Client
22	renewer *api.Renewer
23
24	mountPath string
25	keyName   string
26}
27
// newTransitClient builds a transitClient for the transit seal from config.
//
// Settings are resolved as follows:
//   - mount_path, key_name, disable_renewal and namespace: a non-empty
//     environment variable (VAULT_TRANSIT_SEAL_MOUNT_PATH,
//     VAULT_TRANSIT_SEAL_KEY_NAME, VAULT_TRANSIT_SEAL_DISABLE_RENEWAL and
//     VAULT_NAMESPACE respectively) overrides the config key.
//   - address, token and the tls_* keys come from config only; when they are
//     empty, whatever api.DefaultConfig / api.NewClient set up is kept.
//
// mount_path and key_name are required. disable_renewal and tls_skip_verify
// must parse with strconv.ParseBool when set.
//
// Unless renewal is disabled or the client has no token, the token is renewed
// once via renew-self. If that yields a secret, a background renewer is
// started and later stopped by Close. Failure to renew is logged, not fatal.
//
// On success it returns the client and a map of seal details (address,
// mount_path, key_name and, when set, namespace).
func newTransitClient(logger log.Logger, config map[string]string) (*transitClient, map[string]string, error) {
	if config == nil {
		config = map[string]string{}
	}

	mountPath := transitSetting(config, "VAULT_TRANSIT_SEAL_MOUNT_PATH", "mount_path")
	if mountPath == "" {
		return nil, nil, fmt.Errorf("mount_path is required")
	}

	keyName := transitSetting(config, "VAULT_TRANSIT_SEAL_KEY_NAME", "key_name")
	if keyName == "" {
		return nil, nil, fmt.Errorf("key_name is required")
	}

	var disableRenewal bool
	if raw := transitSetting(config, "VAULT_TRANSIT_SEAL_DISABLE_RENEWAL", "disable_renewal"); raw != "" {
		var err error
		disableRenewal, err = strconv.ParseBool(raw)
		if err != nil {
			return nil, nil, fmt.Errorf("invalid disable_renewal value: %w", err)
		}
	}

	namespace := transitSetting(config, "VAULT_NAMESPACE", "namespace")

	apiConfig := api.DefaultConfig()
	if config["address"] != "" {
		apiConfig.Address = config["address"]
	}
	// Only touch TLS settings when at least one tls_* key is present, so the
	// defaults from api.DefaultConfig are otherwise left untouched.
	if config["tls_ca_cert"] != "" || config["tls_ca_path"] != "" || config["tls_client_cert"] != "" || config["tls_client_key"] != "" ||
		config["tls_server_name"] != "" || config["tls_skip_verify"] != "" {
		var tlsSkipVerify bool
		if config["tls_skip_verify"] != "" {
			var err error
			tlsSkipVerify, err = strconv.ParseBool(config["tls_skip_verify"])
			if err != nil {
				return nil, nil, fmt.Errorf("invalid tls_skip_verify value: %w", err)
			}
		}

		tlsConfig := &api.TLSConfig{
			CACert:        config["tls_ca_cert"],
			CAPath:        config["tls_ca_path"],
			ClientCert:    config["tls_client_cert"],
			ClientKey:     config["tls_client_key"],
			TLSServerName: config["tls_server_name"],
			Insecure:      tlsSkipVerify,
		}
		if err := apiConfig.ConfigureTLS(tlsConfig); err != nil {
			return nil, nil, fmt.Errorf("configuring TLS for transit seal: %w", err)
		}
	}

	apiClient, err := api.NewClient(apiConfig)
	if err != nil {
		return nil, nil, fmt.Errorf("creating Vault API client for transit seal: %w", err)
	}
	if config["token"] != "" {
		apiClient.SetToken(config["token"])
	}
	if namespace != "" {
		apiClient.SetNamespace(namespace)
	}
	if apiClient.Token() == "" {
		logger.Info("no token provided to transit auto-seal")
	}

	client := &transitClient{
		client:    apiClient,
		mountPath: mountPath,
		keyName:   keyName,
	}

	if !disableRenewal && apiClient.Token() != "" {
		// Renew the token immediately to get a secret to pass to the renewer.
		// The token may not be renewable or may lack permission to renew-self;
		// in either case run without renewal rather than failing the seal.
		secret, err := apiClient.Auth().Token().RenewTokenAsSelf(apiClient.Token(), 0)
		switch {
		case err != nil:
			logger.Info("unable to renew token, disabling renewal", "err", err)
		case secret == nil:
			// NOTE(review): the API client appears to return (nil, nil) for an
			// empty response body, and NewRenewer rejects a nil secret, so
			// treat this like a non-renewable token instead of aborting.
			logger.Info("token renewal returned no secret, disabling renewal")
		default:
			renewer, err := apiClient.NewRenewer(&api.RenewerInput{
				Secret: secret,
			})
			if err != nil {
				return nil, nil, fmt.Errorf("creating token renewer: %w", err)
			}
			client.renewer = renewer

			// Drain the renewer's channels until it finishes (including after
			// Close calls Stop), logging each outcome.
			go func() {
				for {
					select {
					case err := <-renewer.DoneCh():
						logger.Info("shutting down token renewal")
						if err != nil {
							logger.Error("error renewing token", "error", err)
						}
						return
					case <-renewer.RenewCh():
						logger.Trace("successfully renewed token")
					}
				}
			}()
			go renewer.Renew()
		}
	}

	sealInfo := map[string]string{
		"address":    apiClient.Address(),
		"mount_path": mountPath,
		"key_name":   keyName,
	}
	if namespace != "" {
		sealInfo["namespace"] = namespace
	}

	return client, sealInfo, nil
}

// transitSetting returns the value of environment variable envVar when it is
// non-empty, and config[key] otherwise (which may be ""). Environment
// variables therefore take precedence over the seal configuration.
func transitSetting(config map[string]string, envVar, key string) string {
	if v := os.Getenv(envVar); v != "" {
		return v
	}
	return config[key]
}
168
169func (c *transitClient) Close() {
170	if c.renewer != nil {
171		c.renewer.Stop()
172	}
173}
174
// Encrypt base64-encodes plaintext and writes it to the transit encrypt
// endpoint "<mountPath>/encrypt/<keyName>". It returns the response's
// "ciphertext" string as bytes.
//
// It returns an error if the request fails, if Vault returns no secret, or if
// the response has no string "ciphertext" field. Previously the last two cases
// panicked with a nil dereference or a failed type assertion.
func (c *transitClient) Encrypt(plaintext []byte) ([]byte, error) {
	encPlaintext := base64.StdEncoding.EncodeToString(plaintext)
	// Named reqPath so the path package is not shadowed.
	reqPath := path.Join(c.mountPath, "encrypt", c.keyName)
	secret, err := c.client.Logical().Write(reqPath, map[string]interface{}{
		"plaintext": encPlaintext,
	})
	if err != nil {
		return nil, err
	}
	// Logical().Write may return a nil secret with a nil error (e.g. an empty
	// response body), so guard before dereferencing.
	if secret == nil {
		return nil, fmt.Errorf("transit encrypt at %q returned no data", reqPath)
	}

	ciphertext, ok := secret.Data["ciphertext"].(string)
	if !ok {
		return nil, fmt.Errorf("transit encrypt at %q returned no ciphertext", reqPath)
	}
	return []byte(ciphertext), nil
}
187
// Decrypt writes ciphertext to the transit decrypt endpoint
// "<mountPath>/decrypt/<keyName>" and base64-decodes the response's
// "plaintext" field.
//
// It returns an error if the request fails, if Vault returns no secret, if the
// response has no string "plaintext" field, or if that field is not valid
// standard base64. Previously a missing secret or field caused a panic.
func (c *transitClient) Decrypt(ciphertext []byte) ([]byte, error) {
	// Named reqPath so the path package is not shadowed.
	reqPath := path.Join(c.mountPath, "decrypt", c.keyName)
	secret, err := c.client.Logical().Write(reqPath, map[string]interface{}{
		"ciphertext": string(ciphertext),
	})
	if err != nil {
		return nil, err
	}
	// Logical().Write may return a nil secret with a nil error (e.g. an empty
	// response body), so guard before dereferencing.
	if secret == nil {
		return nil, fmt.Errorf("transit decrypt at %q returned no data", reqPath)
	}

	encPlaintext, ok := secret.Data["plaintext"].(string)
	if !ok {
		return nil, fmt.Errorf("transit decrypt at %q returned no plaintext", reqPath)
	}
	plaintext, err := base64.StdEncoding.DecodeString(encPlaintext)
	if err != nil {
		return nil, fmt.Errorf("decoding transit plaintext: %w", err)
	}
	return plaintext, nil
}
203