1package transit
2
3import (
4	"encoding/base64"
5	"errors"
6	"fmt"
7	"os"
8	"path"
9	"strconv"
10
11	log "github.com/hashicorp/go-hclog"
12	"github.com/hashicorp/vault/api"
13)
14
// transitClientEncryptor abstracts the operations implemented by transitClient
// in this file: turning plaintext into ciphertext and back, plus releasing any
// background resources. Keeping it an interface lets a different
// implementation be substituted for the Vault-backed client.
type transitClientEncryptor interface {
	// Encrypt returns the ciphertext for plaintext.
	Encrypt(plaintext []byte) (ciphertext []byte, err error)

	// Decrypt is expected to return the plaintext for ciphertext previously
	// produced by Encrypt.
	Decrypt(ciphertext []byte) (plaintext []byte, err error)

	// Close releases resources held by the implementation (for transitClient,
	// it stops the token renewer if one is running).
	Close()
}
20
21type transitClient struct {
22	client  *api.Client
23	renewer *api.Renewer
24
25	mountPath string
26	keyName   string
27}
28
29func newTransitClient(logger log.Logger, config map[string]string) (*transitClient, map[string]string, error) {
30	if config == nil {
31		config = map[string]string{}
32	}
33
34	var mountPath, keyName string
35	switch {
36	case os.Getenv("VAULT_TRANSIT_SEAL_MOUNT_PATH") != "":
37		mountPath = os.Getenv("VAULT_TRANSIT_SEAL_MOUNT_PATH")
38	case config["mount_path"] != "":
39		mountPath = config["mount_path"]
40	default:
41		return nil, nil, fmt.Errorf("mount_path is required")
42	}
43
44	switch {
45	case os.Getenv("VAULT_TRANSIT_SEAL_KEY_NAME") != "":
46		keyName = os.Getenv("VAULT_TRANSIT_SEAL_KEY_NAME")
47	case config["key_name"] != "":
48		keyName = config["key_name"]
49	default:
50		return nil, nil, fmt.Errorf("key_name is required")
51	}
52
53	var disableRenewal bool
54	var disableRenewalRaw string
55	switch {
56	case os.Getenv("VAULT_TRANSIT_SEAL_DISABLE_RENEWAL") != "":
57		disableRenewalRaw = os.Getenv("VAULT_TRANSIT_SEAL_DISABLE_RENEWAL")
58	case config["disable_renewal"] != "":
59		disableRenewalRaw = config["disable_renewal"]
60	}
61	if disableRenewalRaw != "" {
62		var err error
63		disableRenewal, err = strconv.ParseBool(disableRenewalRaw)
64		if err != nil {
65			return nil, nil, err
66		}
67	}
68
69	var namespace string
70	switch {
71	case os.Getenv("VAULT_NAMESPACE") != "":
72		namespace = os.Getenv("VAULT_NAMESPACE")
73	case config["namespace"] != "":
74		namespace = config["namespace"]
75	}
76
77	apiConfig := api.DefaultConfig()
78	if config["address"] != "" {
79		apiConfig.Address = config["address"]
80	}
81	if config["tls_ca_cert"] != "" || config["tls_ca_path"] != "" || config["tls_client_cert"] != "" || config["tls_client_key"] != "" ||
82		config["tls_server_name"] != "" || config["tls_skip_verify"] != "" {
83		var tlsSkipVerify bool
84		if config["tls_skip_verify"] != "" {
85			var err error
86			tlsSkipVerify, err = strconv.ParseBool(config["tls_skip_verify"])
87			if err != nil {
88				return nil, nil, err
89			}
90		}
91
92		tlsConfig := &api.TLSConfig{
93			CACert:        config["tls_ca_cert"],
94			CAPath:        config["tls_ca_path"],
95			ClientCert:    config["tls_client_cert"],
96			ClientKey:     config["tls_client_key"],
97			TLSServerName: config["tls_server_name"],
98			Insecure:      tlsSkipVerify,
99		}
100		if err := apiConfig.ConfigureTLS(tlsConfig); err != nil {
101			return nil, nil, err
102		}
103	}
104
105	apiClient, err := api.NewClient(apiConfig)
106	if err != nil {
107		return nil, nil, err
108	}
109	if config["token"] != "" {
110		apiClient.SetToken(config["token"])
111	}
112	if namespace != "" {
113		apiClient.SetNamespace(namespace)
114	}
115	if apiClient.Token() == "" {
116		return nil, nil, errors.New("missing token")
117	}
118
119	client := &transitClient{
120		client:    apiClient,
121		mountPath: mountPath,
122		keyName:   keyName,
123	}
124
125	if !disableRenewal {
126		// Renew the token immediately to get a secret to pass to renewer
127		secret, err := apiClient.Auth().Token().RenewTokenAsSelf(apiClient.Token(), 0)
128		// If we don't get an error renewing, set up a renewer.  The token may not be renewable or not have
129		// permission to renew-self.
130		if err == nil {
131			renewer, err := apiClient.NewRenewer(&api.RenewerInput{
132				Secret: secret,
133			})
134			if err != nil {
135				return nil, nil, err
136			}
137			client.renewer = renewer
138
139			go func() {
140				for {
141					select {
142					case err := <-renewer.DoneCh():
143						logger.Info("shutting down token renewal")
144						if err != nil {
145							logger.Error("error renewing token", "error", err)
146						}
147						return
148					case <-renewer.RenewCh():
149						logger.Trace("successfully renewed token")
150					}
151				}
152			}()
153			go renewer.Renew()
154		} else {
155			logger.Info("unable to renew token, disabling renewal", "err", err)
156		}
157	}
158
159	sealInfo := make(map[string]string)
160	sealInfo["address"] = apiClient.Address()
161	sealInfo["mount_path"] = mountPath
162	sealInfo["key_name"] = keyName
163	if namespace != "" {
164		sealInfo["namespace"] = namespace
165	}
166
167	return client, sealInfo, nil
168}
169
170func (c *transitClient) Close() {
171	if c.renewer != nil {
172		c.renewer.Stop()
173	}
174}
175
176func (c *transitClient) Encrypt(plaintext []byte) ([]byte, error) {
177	encPlaintext := base64.StdEncoding.EncodeToString(plaintext)
178	path := path.Join(c.mountPath, "encrypt", c.keyName)
179	secret, err := c.client.Logical().Write(path, map[string]interface{}{
180		"plaintext": encPlaintext,
181	})
182	if err != nil {
183		return nil, err
184	}
185
186	return []byte(secret.Data["ciphertext"].(string)), nil
187}
188
189func (c *transitClient) Decrypt(ciphertext []byte) ([]byte, error) {
190	path := path.Join(c.mountPath, "decrypt", c.keyName)
191	secret, err := c.client.Logical().Write(path, map[string]interface{}{
192		"ciphertext": string(ciphertext),
193	})
194	if err != nil {
195		return nil, err
196	}
197
198	plaintext, err := base64.StdEncoding.DecodeString(secret.Data["plaintext"].(string))
199	if err != nil {
200		return nil, err
201	}
202	return plaintext, nil
203}
204