1package api 2 3import ( 4 "crypto/tls" 5 "crypto/x509" 6 "encoding/base64" 7 "errors" 8 "flag" 9 "net/url" 10 "os" 11 12 squarejwt "gopkg.in/square/go-jose.v2/jwt" 13 14 "github.com/hashicorp/errwrap" 15) 16 17var ( 18 // PluginMetadataModeEnv is an ENV name used to disable TLS communication 19 // to bootstrap mounting plugins. 20 PluginMetadataModeEnv = "VAULT_PLUGIN_METADATA_MODE" 21 22 // PluginUnwrapTokenEnv is the ENV name used to pass unwrap tokens to the 23 // plugin. 24 PluginUnwrapTokenEnv = "VAULT_UNWRAP_TOKEN" 25) 26 27// PluginAPIClientMeta is a helper that plugins can use to configure TLS connections 28// back to Vault. 29type PluginAPIClientMeta struct { 30 // These are set by the command line flags. 31 flagCACert string 32 flagCAPath string 33 flagClientCert string 34 flagClientKey string 35 flagInsecure bool 36} 37 38// FlagSet returns the flag set for configuring the TLS connection 39func (f *PluginAPIClientMeta) FlagSet() *flag.FlagSet { 40 fs := flag.NewFlagSet("vault plugin settings", flag.ContinueOnError) 41 42 fs.StringVar(&f.flagCACert, "ca-cert", "", "") 43 fs.StringVar(&f.flagCAPath, "ca-path", "", "") 44 fs.StringVar(&f.flagClientCert, "client-cert", "", "") 45 fs.StringVar(&f.flagClientKey, "client-key", "", "") 46 fs.BoolVar(&f.flagInsecure, "tls-skip-verify", false, "") 47 48 return fs 49} 50 51// GetTLSConfig will return a TLSConfig based off the values from the flags 52func (f *PluginAPIClientMeta) GetTLSConfig() *TLSConfig { 53 // If we need custom TLS configuration, then set it 54 if f.flagCACert != "" || f.flagCAPath != "" || f.flagClientCert != "" || f.flagClientKey != "" || f.flagInsecure { 55 t := &TLSConfig{ 56 CACert: f.flagCACert, 57 CAPath: f.flagCAPath, 58 ClientCert: f.flagClientCert, 59 ClientKey: f.flagClientKey, 60 TLSServerName: "", 61 Insecure: f.flagInsecure, 62 } 63 64 return t 65 } 66 67 return nil 68} 69 70// VaultPluginTLSProvider is run inside a plugin and retrieves the response 71// wrapped TLS certificate from vault. It returns a configured TLS Config. 72func VaultPluginTLSProvider(apiTLSConfig *TLSConfig) func() (*tls.Config, error) { 73 if os.Getenv(PluginMetadataModeEnv) == "true" { 74 return nil 75 } 76 77 return func() (*tls.Config, error) { 78 unwrapToken := os.Getenv(PluginUnwrapTokenEnv) 79 80 parsedJWT, err := squarejwt.ParseSigned(unwrapToken) 81 if err != nil { 82 return nil, errwrap.Wrapf("error parsing wrapping token: {{err}}", err) 83 } 84 85 var allClaims = make(map[string]interface{}) 86 if err = parsedJWT.UnsafeClaimsWithoutVerification(&allClaims); err != nil { 87 return nil, errwrap.Wrapf("error parsing claims from wrapping token: {{err}}", err) 88 } 89 90 addrClaimRaw, ok := allClaims["addr"] 91 if !ok { 92 return nil, errors.New("could not validate addr claim") 93 } 94 vaultAddr, ok := addrClaimRaw.(string) 95 if !ok { 96 return nil, errors.New("could not parse addr claim") 97 } 98 if vaultAddr == "" { 99 return nil, errors.New(`no vault api_addr found`) 100 } 101 102 // Sanity check the value 103 if _, err := url.Parse(vaultAddr); err != nil { 104 return nil, errwrap.Wrapf("error parsing the vault api_addr: {{err}}", err) 105 } 106 107 // Unwrap the token 108 clientConf := DefaultConfig() 109 clientConf.Address = vaultAddr 110 if apiTLSConfig != nil { 111 err := clientConf.ConfigureTLS(apiTLSConfig) 112 if err != nil { 113 return nil, errwrap.Wrapf("error configuring api client {{err}}", err) 114 } 115 } 116 client, err := NewClient(clientConf) 117 if err != nil { 118 return nil, errwrap.Wrapf("error during api client creation: {{err}}", err) 119 } 120 121 // Reset token value to make sure nothing has been set by default 122 client.ClearToken() 123 124 secret, err := client.Logical().Unwrap(unwrapToken) 125 if err != nil { 126 return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) 127 } 128 if secret == nil { 129 return nil, errors.New("error during token unwrap request: secret is nil") 130 } 131 132 // Retrieve and parse the server's certificate 133 serverCertBytesRaw, ok := secret.Data["ServerCert"].(string) 134 if !ok { 135 return nil, errors.New("error unmarshalling certificate") 136 } 137 138 serverCertBytes, err := base64.StdEncoding.DecodeString(serverCertBytesRaw) 139 if err != nil { 140 return nil, errwrap.Wrapf("error parsing certificate: {{err}}", err) 141 } 142 143 serverCert, err := x509.ParseCertificate(serverCertBytes) 144 if err != nil { 145 return nil, errwrap.Wrapf("error parsing certificate: {{err}}", err) 146 } 147 148 // Retrieve and parse the server's private key 149 serverKeyB64, ok := secret.Data["ServerKey"].(string) 150 if !ok { 151 return nil, errors.New("error unmarshalling certificate") 152 } 153 154 serverKeyRaw, err := base64.StdEncoding.DecodeString(serverKeyB64) 155 if err != nil { 156 return nil, errwrap.Wrapf("error parsing certificate: {{err}}", err) 157 } 158 159 serverKey, err := x509.ParseECPrivateKey(serverKeyRaw) 160 if err != nil { 161 return nil, errwrap.Wrapf("error parsing certificate: {{err}}", err) 162 } 163 164 // Add CA cert to the cert pool 165 caCertPool := x509.NewCertPool() 166 caCertPool.AddCert(serverCert) 167 168 // Build a certificate object out of the server's cert and private key. 169 cert := tls.Certificate{ 170 Certificate: [][]byte{serverCertBytes}, 171 PrivateKey: serverKey, 172 Leaf: serverCert, 173 } 174 175 // Setup TLS config 176 tlsConfig := &tls.Config{ 177 ClientCAs: caCertPool, 178 RootCAs: caCertPool, 179 ClientAuth: tls.RequireAndVerifyClientCert, 180 // TLS 1.2 minimum 181 MinVersion: tls.VersionTLS12, 182 Certificates: []tls.Certificate{cert}, 183 ServerName: serverCert.Subject.CommonName, 184 } 185 tlsConfig.BuildNameToCertificate() 186 187 return tlsConfig, nil 188 } 189} 190