1package transit 2 3import ( 4 "encoding/base64" 5 "errors" 6 "fmt" 7 "os" 8 "path" 9 "strconv" 10 11 log "github.com/hashicorp/go-hclog" 12 "github.com/hashicorp/vault/api" 13) 14 15type transitClientEncryptor interface { 16 Close() 17 Encrypt(plaintext []byte) (ciphertext []byte, err error) 18 Decrypt(ciphertext []byte) (plaintext []byte, err error) 19} 20 21type transitClient struct { 22 client *api.Client 23 renewer *api.Renewer 24 25 mountPath string 26 keyName string 27} 28 29func newTransitClient(logger log.Logger, config map[string]string) (*transitClient, map[string]string, error) { 30 if config == nil { 31 config = map[string]string{} 32 } 33 34 var mountPath, keyName string 35 switch { 36 case os.Getenv("VAULT_TRANSIT_SEAL_MOUNT_PATH") != "": 37 mountPath = os.Getenv("VAULT_TRANSIT_SEAL_MOUNT_PATH") 38 case config["mount_path"] != "": 39 mountPath = config["mount_path"] 40 default: 41 return nil, nil, fmt.Errorf("mount_path is required") 42 } 43 44 switch { 45 case os.Getenv("VAULT_TRANSIT_SEAL_KEY_NAME") != "": 46 keyName = os.Getenv("VAULT_TRANSIT_SEAL_KEY_NAME") 47 case config["key_name"] != "": 48 keyName = config["key_name"] 49 default: 50 return nil, nil, fmt.Errorf("key_name is required") 51 } 52 53 var disableRenewal bool 54 var disableRenewalRaw string 55 switch { 56 case os.Getenv("VAULT_TRANSIT_SEAL_DISABLE_RENEWAL") != "": 57 disableRenewalRaw = os.Getenv("VAULT_TRANSIT_SEAL_DISABLE_RENEWAL") 58 case config["disable_renewal"] != "": 59 disableRenewalRaw = config["disable_renewal"] 60 } 61 if disableRenewalRaw != "" { 62 var err error 63 disableRenewal, err = strconv.ParseBool(disableRenewalRaw) 64 if err != nil { 65 return nil, nil, err 66 } 67 } 68 69 var namespace string 70 switch { 71 case os.Getenv("VAULT_NAMESPACE") != "": 72 namespace = os.Getenv("VAULT_NAMESPACE") 73 case config["namespace"] != "": 74 namespace = config["namespace"] 75 } 76 77 apiConfig := api.DefaultConfig() 78 if config["address"] != "" { 79 apiConfig.Address = config["address"] 80 } 81 if config["tls_ca_cert"] != "" || config["tls_ca_path"] != "" || config["tls_client_cert"] != "" || config["tls_client_key"] != "" || 82 config["tls_server_name"] != "" || config["tls_skip_verify"] != "" { 83 var tlsSkipVerify bool 84 if config["tls_skip_verify"] != "" { 85 var err error 86 tlsSkipVerify, err = strconv.ParseBool(config["tls_skip_verify"]) 87 if err != nil { 88 return nil, nil, err 89 } 90 } 91 92 tlsConfig := &api.TLSConfig{ 93 CACert: config["tls_ca_cert"], 94 CAPath: config["tls_ca_path"], 95 ClientCert: config["tls_client_cert"], 96 ClientKey: config["tls_client_key"], 97 TLSServerName: config["tls_server_name"], 98 Insecure: tlsSkipVerify, 99 } 100 if err := apiConfig.ConfigureTLS(tlsConfig); err != nil { 101 return nil, nil, err 102 } 103 } 104 105 apiClient, err := api.NewClient(apiConfig) 106 if err != nil { 107 return nil, nil, err 108 } 109 if config["token"] != "" { 110 apiClient.SetToken(config["token"]) 111 } 112 if namespace != "" { 113 apiClient.SetNamespace(namespace) 114 } 115 if apiClient.Token() == "" { 116 return nil, nil, errors.New("missing token") 117 } 118 119 client := &transitClient{ 120 client: apiClient, 121 mountPath: mountPath, 122 keyName: keyName, 123 } 124 125 if !disableRenewal { 126 // Renew the token immediately to get a secret to pass to renewer 127 secret, err := apiClient.Auth().Token().RenewTokenAsSelf(apiClient.Token(), 0) 128 // If we don't get an error renewing, set up a renewer. The token may not be renewable or not have 129 // permission to renew-self. 130 if err == nil { 131 renewer, err := apiClient.NewRenewer(&api.RenewerInput{ 132 Secret: secret, 133 }) 134 if err != nil { 135 return nil, nil, err 136 } 137 client.renewer = renewer 138 139 go func() { 140 for { 141 select { 142 case err := <-renewer.DoneCh(): 143 logger.Info("shutting down token renewal") 144 if err != nil { 145 logger.Error("error renewing token", "error", err) 146 } 147 return 148 case <-renewer.RenewCh(): 149 logger.Trace("successfully renewed token") 150 } 151 } 152 }() 153 go renewer.Renew() 154 } else { 155 logger.Info("unable to renew token, disabling renewal", "err", err) 156 } 157 } 158 159 sealInfo := make(map[string]string) 160 sealInfo["address"] = apiClient.Address() 161 sealInfo["mount_path"] = mountPath 162 sealInfo["key_name"] = keyName 163 if namespace != "" { 164 sealInfo["namespace"] = namespace 165 } 166 167 return client, sealInfo, nil 168} 169 170func (c *transitClient) Close() { 171 if c.renewer != nil { 172 c.renewer.Stop() 173 } 174} 175 176func (c *transitClient) Encrypt(plaintext []byte) ([]byte, error) { 177 encPlaintext := base64.StdEncoding.EncodeToString(plaintext) 178 path := path.Join(c.mountPath, "encrypt", c.keyName) 179 secret, err := c.client.Logical().Write(path, map[string]interface{}{ 180 "plaintext": encPlaintext, 181 }) 182 if err != nil { 183 return nil, err 184 } 185 186 return []byte(secret.Data["ciphertext"].(string)), nil 187} 188 189func (c *transitClient) Decrypt(ciphertext []byte) ([]byte, error) { 190 path := path.Join(c.mountPath, "decrypt", c.keyName) 191 secret, err := c.client.Logical().Write(path, map[string]interface{}{ 192 "ciphertext": string(ciphertext), 193 }) 194 if err != nil { 195 return nil, err 196 } 197 198 plaintext, err := base64.StdEncoding.DecodeString(secret.Data["plaintext"].(string)) 199 if err != nil { 200 return nil, err 201 } 202 return plaintext, nil 203} 204