1package transit 2 3import ( 4 "encoding/base64" 5 "fmt" 6 "os" 7 "path" 8 "strconv" 9 10 log "github.com/hashicorp/go-hclog" 11 "github.com/hashicorp/vault/api" 12) 13 14type transitClientEncryptor interface { 15 Close() 16 Encrypt(plaintext []byte) (ciphertext []byte, err error) 17 Decrypt(ciphertext []byte) (plaintext []byte, err error) 18} 19 20type transitClient struct { 21 client *api.Client 22 renewer *api.Renewer 23 24 mountPath string 25 keyName string 26} 27 28func newTransitClient(logger log.Logger, config map[string]string) (*transitClient, map[string]string, error) { 29 if config == nil { 30 config = map[string]string{} 31 } 32 33 var mountPath, keyName string 34 switch { 35 case os.Getenv("VAULT_TRANSIT_SEAL_MOUNT_PATH") != "": 36 mountPath = os.Getenv("VAULT_TRANSIT_SEAL_MOUNT_PATH") 37 case config["mount_path"] != "": 38 mountPath = config["mount_path"] 39 default: 40 return nil, nil, fmt.Errorf("mount_path is required") 41 } 42 43 switch { 44 case os.Getenv("VAULT_TRANSIT_SEAL_KEY_NAME") != "": 45 keyName = os.Getenv("VAULT_TRANSIT_SEAL_KEY_NAME") 46 case config["key_name"] != "": 47 keyName = config["key_name"] 48 default: 49 return nil, nil, fmt.Errorf("key_name is required") 50 } 51 52 var disableRenewal bool 53 var disableRenewalRaw string 54 switch { 55 case os.Getenv("VAULT_TRANSIT_SEAL_DISABLE_RENEWAL") != "": 56 disableRenewalRaw = os.Getenv("VAULT_TRANSIT_SEAL_DISABLE_RENEWAL") 57 case config["disable_renewal"] != "": 58 disableRenewalRaw = config["disable_renewal"] 59 } 60 if disableRenewalRaw != "" { 61 var err error 62 disableRenewal, err = strconv.ParseBool(disableRenewalRaw) 63 if err != nil { 64 return nil, nil, err 65 } 66 } 67 68 var namespace string 69 switch { 70 case os.Getenv("VAULT_NAMESPACE") != "": 71 namespace = os.Getenv("VAULT_NAMESPACE") 72 case config["namespace"] != "": 73 namespace = config["namespace"] 74 } 75 76 apiConfig := api.DefaultConfig() 77 if config["address"] != "" { 78 apiConfig.Address = config["address"] 79 } 80 if config["tls_ca_cert"] != "" || config["tls_ca_path"] != "" || config["tls_client_cert"] != "" || config["tls_client_key"] != "" || 81 config["tls_server_name"] != "" || config["tls_skip_verify"] != "" { 82 var tlsSkipVerify bool 83 if config["tls_skip_verify"] != "" { 84 var err error 85 tlsSkipVerify, err = strconv.ParseBool(config["tls_skip_verify"]) 86 if err != nil { 87 return nil, nil, err 88 } 89 } 90 91 tlsConfig := &api.TLSConfig{ 92 CACert: config["tls_ca_cert"], 93 CAPath: config["tls_ca_path"], 94 ClientCert: config["tls_client_cert"], 95 ClientKey: config["tls_client_key"], 96 TLSServerName: config["tls_server_name"], 97 Insecure: tlsSkipVerify, 98 } 99 if err := apiConfig.ConfigureTLS(tlsConfig); err != nil { 100 return nil, nil, err 101 } 102 } 103 104 apiClient, err := api.NewClient(apiConfig) 105 if err != nil { 106 return nil, nil, err 107 } 108 if config["token"] != "" { 109 apiClient.SetToken(config["token"]) 110 } 111 if namespace != "" { 112 apiClient.SetNamespace(namespace) 113 } 114 if apiClient.Token() == "" { 115 logger.Info("no token provided to transit auto-seal") 116 } 117 118 client := &transitClient{ 119 client: apiClient, 120 mountPath: mountPath, 121 keyName: keyName, 122 } 123 124 if !disableRenewal && apiClient.Token() != "" { 125 // Renew the token immediately to get a secret to pass to renewer 126 secret, err := apiClient.Auth().Token().RenewTokenAsSelf(apiClient.Token(), 0) 127 // If we don't get an error renewing, set up a renewer. The token may not be renewable or not have 128 // permission to renew-self. 129 if err == nil { 130 renewer, err := apiClient.NewRenewer(&api.RenewerInput{ 131 Secret: secret, 132 }) 133 if err != nil { 134 return nil, nil, err 135 } 136 client.renewer = renewer 137 138 go func() { 139 for { 140 select { 141 case err := <-renewer.DoneCh(): 142 logger.Info("shutting down token renewal") 143 if err != nil { 144 logger.Error("error renewing token", "error", err) 145 } 146 return 147 case <-renewer.RenewCh(): 148 logger.Trace("successfully renewed token") 149 } 150 } 151 }() 152 go renewer.Renew() 153 } else { 154 logger.Info("unable to renew token, disabling renewal", "err", err) 155 } 156 } 157 158 sealInfo := make(map[string]string) 159 sealInfo["address"] = apiClient.Address() 160 sealInfo["mount_path"] = mountPath 161 sealInfo["key_name"] = keyName 162 if namespace != "" { 163 sealInfo["namespace"] = namespace 164 } 165 166 return client, sealInfo, nil 167} 168 169func (c *transitClient) Close() { 170 if c.renewer != nil { 171 c.renewer.Stop() 172 } 173} 174 175func (c *transitClient) Encrypt(plaintext []byte) ([]byte, error) { 176 encPlaintext := base64.StdEncoding.EncodeToString(plaintext) 177 path := path.Join(c.mountPath, "encrypt", c.keyName) 178 secret, err := c.client.Logical().Write(path, map[string]interface{}{ 179 "plaintext": encPlaintext, 180 }) 181 if err != nil { 182 return nil, err 183 } 184 185 return []byte(secret.Data["ciphertext"].(string)), nil 186} 187 188func (c *transitClient) Decrypt(ciphertext []byte) ([]byte, error) { 189 path := path.Join(c.mountPath, "decrypt", c.keyName) 190 secret, err := c.client.Logical().Write(path, map[string]interface{}{ 191 "ciphertext": string(ciphertext), 192 }) 193 if err != nil { 194 return nil, err 195 } 196 197 plaintext, err := base64.StdEncoding.DecodeString(secret.Data["plaintext"].(string)) 198 if err != nil { 199 return nil, err 200 } 201 return plaintext, nil 202} 203