1package api
2
3import (
4	"crypto/tls"
5	"crypto/x509"
6	"encoding/base64"
7	"errors"
8	"flag"
9	"net/url"
10	"os"
11
12	squarejwt "gopkg.in/square/go-jose.v2/jwt"
13
14	"github.com/hashicorp/errwrap"
15)
16
// PluginMetadataModeEnv is the name of the environment variable that, when
// set to "true", disables TLS communication while Vault bootstraps (mounts)
// a plugin.
var PluginMetadataModeEnv = "VAULT_PLUGIN_METADATA_MODE"

// PluginUnwrapTokenEnv is the name of the environment variable used to hand
// the response-wrapping (unwrap) token to the plugin.
var PluginUnwrapTokenEnv = "VAULT_UNWRAP_TOKEN"
26
// PluginAPIClientMeta is a helper that plugins can use to configure TLS
// connections back to Vault. Its fields are bound to command line flags by
// FlagSet and turned into a *TLSConfig by GetTLSConfig.
type PluginAPIClientMeta struct {
	// Values of the -ca-cert, -ca-path, -client-cert and -client-key flags.
	flagCACert, flagCAPath, flagClientCert, flagClientKey string

	// Value of the -tls-skip-verify flag.
	flagInsecure bool
}
37
38// FlagSet returns the flag set for configuring the TLS connection
39func (f *PluginAPIClientMeta) FlagSet() *flag.FlagSet {
40	fs := flag.NewFlagSet("vault plugin settings", flag.ContinueOnError)
41
42	fs.StringVar(&f.flagCACert, "ca-cert", "", "")
43	fs.StringVar(&f.flagCAPath, "ca-path", "", "")
44	fs.StringVar(&f.flagClientCert, "client-cert", "", "")
45	fs.StringVar(&f.flagClientKey, "client-key", "", "")
46	fs.BoolVar(&f.flagInsecure, "tls-skip-verify", false, "")
47
48	return fs
49}
50
51// GetTLSConfig will return a TLSConfig based off the values from the flags
52func (f *PluginAPIClientMeta) GetTLSConfig() *TLSConfig {
53	// If we need custom TLS configuration, then set it
54	if f.flagCACert != "" || f.flagCAPath != "" || f.flagClientCert != "" || f.flagClientKey != "" || f.flagInsecure {
55		t := &TLSConfig{
56			CACert:        f.flagCACert,
57			CAPath:        f.flagCAPath,
58			ClientCert:    f.flagClientCert,
59			ClientKey:     f.flagClientKey,
60			TLSServerName: "",
61			Insecure:      f.flagInsecure,
62		}
63
64		return t
65	}
66
67	return nil
68}
69
70// VaultPluginTLSProvider is run inside a plugin and retrieves the response
71// wrapped TLS certificate from vault. It returns a configured TLS Config.
72func VaultPluginTLSProvider(apiTLSConfig *TLSConfig) func() (*tls.Config, error) {
73	if os.Getenv(PluginMetadataModeEnv) == "true" {
74		return nil
75	}
76
77	return func() (*tls.Config, error) {
78		unwrapToken := os.Getenv(PluginUnwrapTokenEnv)
79
80		parsedJWT, err := squarejwt.ParseSigned(unwrapToken)
81		if err != nil {
82			return nil, errwrap.Wrapf("error parsing wrapping token: {{err}}", err)
83		}
84
85		var allClaims = make(map[string]interface{})
86		if err = parsedJWT.UnsafeClaimsWithoutVerification(&allClaims); err != nil {
87			return nil, errwrap.Wrapf("error parsing claims from wrapping token: {{err}}", err)
88		}
89
90		addrClaimRaw, ok := allClaims["addr"]
91		if !ok {
92			return nil, errors.New("could not validate addr claim")
93		}
94		vaultAddr, ok := addrClaimRaw.(string)
95		if !ok {
96			return nil, errors.New("could not parse addr claim")
97		}
98		if vaultAddr == "" {
99			return nil, errors.New(`no vault api_addr found`)
100		}
101
102		// Sanity check the value
103		if _, err := url.Parse(vaultAddr); err != nil {
104			return nil, errwrap.Wrapf("error parsing the vault api_addr: {{err}}", err)
105		}
106
107		// Unwrap the token
108		clientConf := DefaultConfig()
109		clientConf.Address = vaultAddr
110		if apiTLSConfig != nil {
111			err := clientConf.ConfigureTLS(apiTLSConfig)
112			if err != nil {
113				return nil, errwrap.Wrapf("error configuring api client {{err}}", err)
114			}
115		}
116		client, err := NewClient(clientConf)
117		if err != nil {
118			return nil, errwrap.Wrapf("error during api client creation: {{err}}", err)
119		}
120
121		// Reset token value to make sure nothing has been set by default
122		client.ClearToken()
123
124		secret, err := client.Logical().Unwrap(unwrapToken)
125		if err != nil {
126			return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err)
127		}
128		if secret == nil {
129			return nil, errors.New("error during token unwrap request: secret is nil")
130		}
131
132		// Retrieve and parse the server's certificate
133		serverCertBytesRaw, ok := secret.Data["ServerCert"].(string)
134		if !ok {
135			return nil, errors.New("error unmarshalling certificate")
136		}
137
138		serverCertBytes, err := base64.StdEncoding.DecodeString(serverCertBytesRaw)
139		if err != nil {
140			return nil, errwrap.Wrapf("error parsing certificate: {{err}}", err)
141		}
142
143		serverCert, err := x509.ParseCertificate(serverCertBytes)
144		if err != nil {
145			return nil, errwrap.Wrapf("error parsing certificate: {{err}}", err)
146		}
147
148		// Retrieve and parse the server's private key
149		serverKeyB64, ok := secret.Data["ServerKey"].(string)
150		if !ok {
151			return nil, errors.New("error unmarshalling certificate")
152		}
153
154		serverKeyRaw, err := base64.StdEncoding.DecodeString(serverKeyB64)
155		if err != nil {
156			return nil, errwrap.Wrapf("error parsing certificate: {{err}}", err)
157		}
158
159		serverKey, err := x509.ParseECPrivateKey(serverKeyRaw)
160		if err != nil {
161			return nil, errwrap.Wrapf("error parsing certificate: {{err}}", err)
162		}
163
164		// Add CA cert to the cert pool
165		caCertPool := x509.NewCertPool()
166		caCertPool.AddCert(serverCert)
167
168		// Build a certificate object out of the server's cert and private key.
169		cert := tls.Certificate{
170			Certificate: [][]byte{serverCertBytes},
171			PrivateKey:  serverKey,
172			Leaf:        serverCert,
173		}
174
175		// Setup TLS config
176		tlsConfig := &tls.Config{
177			ClientCAs:  caCertPool,
178			RootCAs:    caCertPool,
179			ClientAuth: tls.RequireAndVerifyClientCert,
180			// TLS 1.2 minimum
181			MinVersion:   tls.VersionTLS12,
182			Certificates: []tls.Certificate{cert},
183			ServerName:   serverCert.Subject.CommonName,
184		}
185		tlsConfig.BuildNameToCertificate()
186
187		return tlsConfig, nil
188	}
189}
190