1package cert
2
3import (
4	"bytes"
5	"context"
6	"crypto/tls"
7	"crypto/x509"
8	"encoding/asn1"
9	"encoding/base64"
10	"encoding/pem"
11	"errors"
12	"fmt"
13	"strings"
14
15	"github.com/hashicorp/vault/sdk/framework"
16	"github.com/hashicorp/vault/sdk/helper/certutil"
17	"github.com/hashicorp/vault/sdk/helper/policyutil"
18	"github.com/hashicorp/vault/sdk/logical"
19
20	"github.com/hashicorp/vault/sdk/helper/cidrutil"
21	glob "github.com/ryanuber/go-glob"
22)
23
// ParsedCert is a certificate role that has been configured as trusted,
// paired with the x509 certificates decoded from its stored PEM.
type ParsedCert struct {
	// Entry is the stored role configuration the certificates came from.
	Entry        *CertEntry
	// Certificates are the certificates parsed from Entry.Certificate; when
	// loaded by loadTrustedCerts this is never empty, and whether the first
	// element is a CA decides how the role is used during login.
	Certificates []*x509.Certificate
}
29
30func pathLogin(b *backend) *framework.Path {
31	return &framework.Path{
32		Pattern: "login",
33		Fields: map[string]*framework.FieldSchema{
34			"name": &framework.FieldSchema{
35				Type:        framework.TypeString,
36				Description: "The name of the certificate role to authenticate against.",
37			},
38		},
39		Callbacks: map[logical.Operation]framework.OperationFunc{
40			logical.UpdateOperation:         b.pathLogin,
41			logical.AliasLookaheadOperation: b.pathLoginAliasLookahead,
42		},
43	}
44}
45
// pathLoginAliasLookahead serves the alias-lookahead operation for "login".
// It reports the alias name a login over this connection would use, which is
// the Subject Common Name of the first peer certificate. This handler does not
// verify the certificate.
//
// It returns an error when the request carries no TLS connection state or no
// peer certificate.
func (b *backend) pathLoginAliasLookahead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
	// Requests that did not arrive over TLS have no connection state;
	// dereferencing it unconditionally would panic.
	if req.Connection == nil || req.Connection.ConnState == nil {
		return nil, errors.New("no client certificate found")
	}

	clientCerts := req.Connection.ConnState.PeerCertificates
	if len(clientCerts) == 0 {
		return nil, errors.New("no client certificate found")
	}

	return &logical.Response{
		Auth: &logical.Auth{
			Alias: &logical.Alias{
				Name: clientCerts[0].Subject.CommonName,
			},
		},
	}, nil
}
60
61func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
62	var matched *ParsedCert
63	if verifyResp, resp, err := b.verifyCredentials(ctx, req, data); err != nil {
64		return nil, err
65	} else if resp != nil {
66		return resp, nil
67	} else {
68		matched = verifyResp
69	}
70
71	if matched == nil {
72		return nil, nil
73	}
74
75	if len(matched.Entry.TokenBoundCIDRs) > 0 {
76		if req.Connection == nil {
77			b.Logger().Warn("token bound CIDRs found but no connection information available for validation")
78			return nil, logical.ErrPermissionDenied
79		}
80		if !cidrutil.RemoteAddrIsOk(req.Connection.RemoteAddr, matched.Entry.TokenBoundCIDRs) {
81			return nil, logical.ErrPermissionDenied
82		}
83	}
84
85	clientCerts := req.Connection.ConnState.PeerCertificates
86	if len(clientCerts) == 0 {
87		return logical.ErrorResponse("no client certificate found"), nil
88	}
89	skid := base64.StdEncoding.EncodeToString(clientCerts[0].SubjectKeyId)
90	akid := base64.StdEncoding.EncodeToString(clientCerts[0].AuthorityKeyId)
91
92	auth := &logical.Auth{
93		InternalData: map[string]interface{}{
94			"subject_key_id":   skid,
95			"authority_key_id": akid,
96		},
97		DisplayName: matched.Entry.DisplayName,
98		Metadata: map[string]string{
99			"cert_name":        matched.Entry.Name,
100			"common_name":      clientCerts[0].Subject.CommonName,
101			"serial_number":    clientCerts[0].SerialNumber.String(),
102			"subject_key_id":   certutil.GetHexFormatted(clientCerts[0].SubjectKeyId, ":"),
103			"authority_key_id": certutil.GetHexFormatted(clientCerts[0].AuthorityKeyId, ":"),
104		},
105		Alias: &logical.Alias{
106			Name: clientCerts[0].Subject.CommonName,
107		},
108	}
109	matched.Entry.PopulateTokenAuth(auth)
110
111	return &logical.Response{
112		Auth: auth,
113	}, nil
114}
115
116func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
117	config, err := b.Config(ctx, req.Storage)
118	if err != nil {
119		return nil, err
120	}
121
122	if !config.DisableBinding {
123		var matched *ParsedCert
124		if verifyResp, resp, err := b.verifyCredentials(ctx, req, d); err != nil {
125			return nil, err
126		} else if resp != nil {
127			return resp, nil
128		} else {
129			matched = verifyResp
130		}
131
132		if matched == nil {
133			return nil, nil
134		}
135
136		clientCerts := req.Connection.ConnState.PeerCertificates
137		if len(clientCerts) == 0 {
138			return logical.ErrorResponse("no client certificate found"), nil
139		}
140		skid := base64.StdEncoding.EncodeToString(clientCerts[0].SubjectKeyId)
141		akid := base64.StdEncoding.EncodeToString(clientCerts[0].AuthorityKeyId)
142
143		// Certificate should not only match a registered certificate policy.
144		// Also, the identity of the certificate presented should match the identity of the certificate used during login
145		if req.Auth.InternalData["subject_key_id"] != skid && req.Auth.InternalData["authority_key_id"] != akid {
146			return nil, fmt.Errorf("client identity during renewal not matching client identity used during login")
147		}
148
149	}
150	// Get the cert and use its TTL
151	cert, err := b.Cert(ctx, req.Storage, req.Auth.Metadata["cert_name"])
152	if err != nil {
153		return nil, err
154	}
155	if cert == nil {
156		// User no longer exists, do not renew
157		return nil, nil
158	}
159
160	if !policyutil.EquivalentPolicies(cert.TokenPolicies, req.Auth.TokenPolicies) {
161		return nil, fmt.Errorf("policies have changed, not renewing")
162	}
163
164	resp := &logical.Response{Auth: req.Auth}
165	resp.Auth.TTL = cert.TokenTTL
166	resp.Auth.MaxTTL = cert.TokenMaxTTL
167	resp.Auth.Period = cert.TokenPeriod
168	return resp, nil
169}
170
// verifyCredentials authenticates the TLS client certificate on req against
// the trusted certificate roles stored in the backend.
//
// On a renewal (req.Auth != nil) the search is restricted to the role recorded
// at login. Otherwise the optional "name" field restricts it.
//
// Exactly one of the results is meaningful:
//   - *ParsedCert: the matching role on success;
//   - *logical.Response: a user-facing error response when the client could
//     not be authenticated;
//   - error: certificate verification itself failed (see validateConnState).
func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d *framework.FieldData) (*ParsedCert, *logical.Response, error) {
	// Get the connection state
	if req.Connection == nil || req.Connection.ConnState == nil {
		return nil, logical.ErrorResponse("tls connection required"), nil
	}
	connState := req.Connection.ConnState

	if len(connState.PeerCertificates) == 0 {
		return nil, logical.ErrorResponse("client certificate must be supplied"), nil
	}
	clientCert := connState.PeerCertificates[0]

	// Allow constraining the login request to a single CertEntry
	var certName string
	if req.Auth != nil { // It's a renewal, use the saved certName
		certName = req.Auth.Metadata["cert_name"]
	} else {
		certName = d.Get("name").(string)
	}

	// Load the trusted certificates
	roots, trusted, trustedNonCAs := b.loadTrustedCerts(ctx, req.Storage, certName)

	// Get the list of full chains matching the connection and validate the
	// certificate itself
	trustedChains, err := validateConnState(roots, connState)
	if err != nil {
		return nil, nil, err
	}

	// Roles registered with a non-CA certificate: the client must present
	// that exact certificate.
	for _, trustedNonCA := range trustedNonCAs {
		tCert := trustedNonCA.Certificates[0]
		// Serial number and AKID only say which certificate the client claims
		// to be. Both are freely chosen fields in a self-signed certificate.
		// Also require the same public key. The TLS handshake proved the
		// client holds the private key for clientCert's public key, so this
		// binds the login to the holder of the registered key.
		if tCert.SerialNumber.Cmp(clientCert.SerialNumber) == 0 &&
			bytes.Equal(tCert.AuthorityKeyId, clientCert.AuthorityKeyId) &&
			bytes.Equal(tCert.RawSubjectPublicKeyInfo, clientCert.RawSubjectPublicKeyInfo) &&
			b.matchesConstraints(clientCert, trustedNonCA.Certificates, trustedNonCA) {
			return trustedNonCA, nil, nil
		}
	}

	// If no trusted chain was found, client is not authenticated
	// This check happens after checking for a matching configured non-CA certs
	if len(trustedChains) == 0 {
		return nil, logical.ErrorResponse("invalid certificate or no client certificate supplied"), nil
	}

	// Return the first ParsedCert that intersects with a validated chain and
	// satisfies all additional constraints. When several roles match, the
	// first one in iteration order wins, as it always has.
	for _, trust := range trusted { // For each ParsedCert in the config
		for _, tCert := range trust.Certificates { // For each certificate in the entry
			for _, chain := range trustedChains { // For each root chain that we matched
				for _, cCert := range chain { // For each cert in the matched chain
					if tCert.Equal(cCert) && b.matchesConstraints(clientCert, chain, trust) {
						return trust, nil, nil
					}
				}
			}
		}
	}

	return nil, logical.ErrorResponse("no chain matching all constraints could be found for this login certificate"), nil
}
245
// matchesConstraints reports whether clientCert, validated through
// trustedChain, satisfies every restriction configured on config.
//
// No certificate in trustedChain may appear in a loaded CRL. The client
// certificate must also pass each per-field check (names, common name, DNS /
// email / URI SANs, OUs and required extensions). Checks run in that order and
// stop at the first failure.
func (b *backend) matchesConstraints(clientCert *x509.Certificate, trustedChain []*x509.Certificate, config *ParsedCert) bool {
	if b.checkForChainInCRLs(trustedChain) {
		return false
	}

	checks := []func(*x509.Certificate, *ParsedCert) bool{
		b.matchesNames,
		b.matchesCommonName,
		b.matchesDNSSANs,
		b.matchesEmailSANs,
		b.matchesURISANs,
		b.matchesOrganizationalUnits,
		b.matchesCertificateExtensions,
	}
	for _, check := range checks {
		if !check(clientCert, config) {
			return false
		}
	}
	return true
}
256
// matchesNames reports whether any configured AllowedNames glob matches the
// certificate's Subject Common Name, any DNS SAN, or any email SAN.
// With no AllowedNames configured, every certificate matches.
func (b *backend) matchesNames(clientCert *x509.Certificate, config *ParsedCert) bool {
	patterns := config.Entry.AllowedNames
	if len(patterns) == 0 {
		return true
	}

	// Every identity string the patterns may match against.
	candidates := make([]string, 0, 1+len(clientCert.DNSNames)+len(clientCert.EmailAddresses))
	candidates = append(candidates, clientCert.Subject.CommonName)
	candidates = append(candidates, clientCert.DNSNames...)
	candidates = append(candidates, clientCert.EmailAddresses...)

	for _, pattern := range patterns {
		for _, candidate := range candidates {
			if glob.Glob(pattern, candidate) {
				return true
			}
		}
	}
	return false
}
285
// matchesCommonName reports whether the certificate's Subject Common Name
// matches at least one configured AllowedCommonNames glob.
// With no patterns configured, every certificate matches.
func (b *backend) matchesCommonName(clientCert *x509.Certificate, config *ParsedCert) bool {
	patterns := config.Entry.AllowedCommonNames
	if len(patterns) == 0 {
		return true
	}

	commonName := clientCert.Subject.CommonName
	for i := range patterns {
		if glob.Glob(patterns[i], commonName) {
			return true
		}
	}
	return false
}
302
// matchesDNSSANs reports whether at least one DNS name in the certificate's
// subject alternative name extension matches a configured AllowedDNSSANs glob.
// With no patterns configured, every certificate matches; with patterns but
// no DNS SANs, none does.
func (b *backend) matchesDNSSANs(clientCert *x509.Certificate, config *ParsedCert) bool {
	patterns := config.Entry.AllowedDNSSANs
	if len(patterns) == 0 {
		return true
	}

	for _, dnsName := range clientCert.DNSNames {
		for _, pattern := range patterns {
			if glob.Glob(pattern, dnsName) {
				return true
			}
		}
	}
	return false
}
321
// matchesEmailSANs reports whether at least one email address in the
// certificate's subject alternative name extension matches a configured
// AllowedEmailSANs glob. With no patterns configured, every certificate
// matches; with patterns but no email SANs, none does.
func (b *backend) matchesEmailSANs(clientCert *x509.Certificate, config *ParsedCert) bool {
	patterns := config.Entry.AllowedEmailSANs
	if len(patterns) == 0 {
		return true
	}

	for _, address := range clientCert.EmailAddresses {
		for _, pattern := range patterns {
			if glob.Glob(pattern, address) {
				return true
			}
		}
	}
	return false
}
340
// matchesURISANs reports whether at least one URI in the certificate's subject
// alternative name extension, in its String() form, matches a configured
// AllowedURISANs glob. With no patterns configured, every certificate
// matches; with patterns but no URI SANs, none does.
func (b *backend) matchesURISANs(clientCert *x509.Certificate, config *ParsedCert) bool {
	patterns := config.Entry.AllowedURISANs
	if len(patterns) == 0 {
		return true
	}

	for _, uri := range clientCert.URIs {
		rendered := uri.String()
		for _, pattern := range patterns {
			if glob.Glob(pattern, rendered) {
				return true
			}
		}
	}
	return false
}
359
// matchesOrganizationalUnits reports whether at least one Organizational Unit
// in the certificate's Subject matches a configured AllowedOrganizationalUnits
// glob. With no patterns configured, every certificate matches.
func (b *backend) matchesOrganizationalUnits(clientCert *x509.Certificate, config *ParsedCert) bool {
	patterns := config.Entry.AllowedOrganizationalUnits
	if len(patterns) == 0 {
		return true
	}

	for _, unit := range clientCert.Subject.OrganizationalUnit {
		for _, pattern := range patterns {
			if glob.Glob(pattern, unit) {
				return true
			}
		}
	}
	return false
}
378
// matchesCertificateExtensions reports whether the certificate carries every
// extension listed in the role's RequiredExtensions.
//
// Each required entry has the form "<oid>:<glob>". The extension with that
// OID must be present, and its value, decoded as an ASN.1 string, must match
// the glob. A malformed entry without a ":" separator never matches; the
// check fails closed instead of panicking. With no required extensions,
// every certificate matches.
func (b *backend) matchesCertificateExtensions(clientCert *x509.Certificate, config *ParsedCert) bool {
	// If no required extensions, nothing to check here
	if len(config.Entry.RequiredExtensions) == 0 {
		return true
	}
	// Fail fast if we have required extensions but no extensions on the cert
	if len(clientCert.Extensions) == 0 {
		return false
	}

	// Build a map of the client's extensions (OID -> value) for matching.
	// Extension values are raw DER that still carry their ASN.1 tag bytes.
	// For simplicity, assume a string type and let asn1.Unmarshal strip the
	// tag and length.
	clientExtMap := make(map[string]string, len(clientCert.Extensions))
	for _, ext := range clientCert.Extensions {
		var parsedValue string
		// Best effort: a value that is not an ASN.1 string is recorded with
		// whatever Unmarshal left in parsedValue (presumably ""). This still
		// lets an "<oid>:*" constraint match on presence alone.
		_, _ = asn1.Unmarshal(ext.Value, &parsedValue)
		clientExtMap[ext.Id.String()] = parsedValue
	}

	// If any of the required extensions don't match, the constraint fails
	for _, requiredExt := range config.Entry.RequiredExtensions {
		reqExt := strings.SplitN(requiredExt, ":", 2)
		if len(reqExt) != 2 {
			return false
		}
		clientExtValue, ok := clientExtMap[reqExt[0]]
		if !ok || !glob.Glob(reqExt[1], clientExtValue) {
			return false
		}
	}
	return true
}
411
// loadTrustedCerts reads the certificate roles stored under "cert/" and splits
// them into the trust sets used by verifyCredentials. When certName is
// non-empty, only the role with that name is loaded.
//
// Returns:
//   - pool: every certificate from roles whose first certificate is a CA,
//     used as verification roots;
//   - trusted: one ParsedCert per role whose first certificate is a CA;
//   - trustedNonCAs: one ParsedCert per role whose first certificate is not
//     a CA.
//
// Failures to list, read or parse roles are logged and the affected role is
// skipped. All three results are always non-nil, though possibly empty.
func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, certName string) (pool *x509.CertPool, trusted []*ParsedCert, trustedNonCAs []*ParsedCert) {
	pool = x509.NewCertPool()
	trusted = make([]*ParsedCert, 0)
	trustedNonCAs = make([]*ParsedCert, 0)

	names, err := storage.List(ctx, "cert/")
	if err != nil {
		b.Logger().Error("failed to list trusted certs", "error", err)
		return pool, trusted, trustedNonCAs
	}

	for _, name := range names {
		// If we are trying to select a single CertEntry and this isn't it
		if certName != "" && name != certName {
			continue
		}

		entry, err := b.Cert(ctx, storage, strings.TrimPrefix(name, "cert/"))
		if err != nil {
			b.Logger().Error("failed to load trusted cert", "name", name, "error", err)
			continue
		}
		if entry == nil {
			// b.Cert reports a role that no longer exists as (nil, nil), e.g.
			// one deleted after List ran. Skip it rather than dereferencing nil.
			continue
		}

		parsed := parsePEM([]byte(entry.Certificate))
		if len(parsed) == 0 {
			b.Logger().Error("failed to parse certificate", "name", name)
			continue
		}

		pc := &ParsedCert{
			Entry:        entry,
			Certificates: parsed,
		}
		if !parsed[0].IsCA {
			trustedNonCAs = append(trustedNonCAs, pc)
			continue
		}

		// CA roles contribute all of their certificates to the root pool.
		for _, p := range parsed {
			pool.AddCert(p)
		}
		trusted = append(trusted, pc)
	}
	return pool, trusted, trustedNonCAs
}
456
// checkForChainInCRLs reports whether any certificate in chain has a serial
// number listed in at least one CRL known to the backend (per
// b.findSerialInCRLs). It returns true for a revoked ("bad") chain.
func (b *backend) checkForChainInCRLs(chain []*x509.Certificate) bool {
	for _, member := range chain {
		if len(b.findSerialInCRLs(member.SerialNumber)) > 0 {
			return true
		}
	}
	return false
}
468
// checkForValidChain reports whether at least one of chains contains no
// certificate listed in a known CRL (see checkForChainInCRLs).
func (b *backend) checkForValidChain(chains [][]*x509.Certificate) bool {
	for i := range chains {
		if revoked := b.checkForChainInCRLs(chains[i]); !revoked {
			return true
		}
	}
	return false
}
477
478// parsePEM parses a PEM encoded x509 certificate
479func parsePEM(raw []byte) (certs []*x509.Certificate) {
480	for len(raw) > 0 {
481		var block *pem.Block
482		block, raw = pem.Decode(raw)
483		if block == nil {
484			break
485		}
486		if (block.Type != "CERTIFICATE" && block.Type != "TRUSTED CERTIFICATE") || len(block.Headers) != 0 {
487			continue
488		}
489
490		cert, err := x509.ParseCertificate(block.Bytes)
491		if err != nil {
492			continue
493		}
494		certs = append(certs, cert)
495	}
496	return
497}
498
499// validateConnState is used to validate that the TLS client is authorized
500// by at trusted certificate. Most of this logic is lifted from the client
501// verification logic here:  http://golang.org/src/crypto/tls/handshake_server.go
502// The trusted chains are returned.
503func validateConnState(roots *x509.CertPool, cs *tls.ConnectionState) ([][]*x509.Certificate, error) {
504	certs := cs.PeerCertificates
505	if len(certs) == 0 {
506		return nil, nil
507	}
508
509	opts := x509.VerifyOptions{
510		Roots:         roots,
511		Intermediates: x509.NewCertPool(),
512		KeyUsages:     []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
513	}
514
515	if len(certs) > 1 {
516		for _, cert := range certs[1:] {
517			opts.Intermediates.AddCert(cert)
518		}
519	}
520
521	chains, err := certs[0].Verify(opts)
522	if err != nil {
523		if _, ok := err.(x509.UnknownAuthorityError); ok {
524			return nil, nil
525		}
526		return nil, errors.New("failed to verify client's certificate: " + err.Error())
527	}
528
529	return chains, nil
530}
531