1package kerberos
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"net/http"
8
9	"github.com/hashicorp/go-hclog"
10	kerberos "github.com/hashicorp/vault-plugin-auth-kerberos"
11	"github.com/hashicorp/vault/api"
12	"github.com/hashicorp/vault/command/agent/auth"
13	"github.com/hashicorp/vault/sdk/helper/parseutil"
14	"github.com/jcmturner/gokrb5/v8/spnego"
15)
16
17type kerberosMethod struct {
18	logger    hclog.Logger
19	mountPath string
20	loginCfg  *kerberos.LoginCfg
21}
22
23func NewKerberosAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
24	if conf == nil {
25		return nil, errors.New("empty config")
26	}
27	if conf.Config == nil {
28		return nil, errors.New("empty config data")
29	}
30	username, err := read("username", conf.Config)
31	if err != nil {
32		return nil, err
33	}
34	service, err := read("service", conf.Config)
35	if err != nil {
36		return nil, err
37	}
38	realm, err := read("realm", conf.Config)
39	if err != nil {
40		return nil, err
41	}
42	keytabPath, err := read("keytab_path", conf.Config)
43	if err != nil {
44		return nil, err
45	}
46	krb5ConfPath, err := read("krb5conf_path", conf.Config)
47	if err != nil {
48		return nil, err
49	}
50
51	disableFast := false
52	disableFastRaw, ok := conf.Config["disable_fast_negotiation"]
53	if ok {
54		disableFast, err = parseutil.ParseBool(disableFastRaw)
55		if err != nil {
56			return nil, fmt.Errorf("error parsing 'disable_fast_negotiation': %s", err)
57		}
58	}
59
60	return &kerberosMethod{
61		logger:    conf.Logger,
62		mountPath: conf.MountPath,
63		loginCfg: &kerberos.LoginCfg{
64			Username:               username,
65			Service:                service,
66			Realm:                  realm,
67			KeytabPath:             keytabPath,
68			Krb5ConfPath:           krb5ConfPath,
69			DisableFASTNegotiation: disableFast,
70		},
71	}, nil
72}
73
74func (k *kerberosMethod) Authenticate(context.Context, *api.Client) (string, http.Header, map[string]interface{}, error) {
75	k.logger.Trace("beginning authentication")
76	authHeaderVal, err := kerberos.GetAuthHeaderVal(k.loginCfg)
77	if err != nil {
78		return "", nil, nil, err
79	}
80	var header http.Header
81	header = make(map[string][]string)
82	header.Set(spnego.HTTPHeaderAuthRequest, authHeaderVal)
83	return k.mountPath + "/login", header, make(map[string]interface{}), nil
84}
85
86// These functions are implemented to meet the AuthHandler interface,
87// but we don't need to take advantage of them.
88func (k *kerberosMethod) NewCreds() chan struct{} { return nil }
89func (k *kerberosMethod) CredSuccess()            {}
90func (k *kerberosMethod) Shutdown()               {}
91
// read looks up key in m and returns the stored value as a string.
// It fails with a "%q is required" error when the key is absent, and with a
// "%q must be a string" error when the key is present but its value is not a
// string (including an explicit nil value).
func read(key string, m map[string]interface{}) (string, error) {
	raw, present := m[key]
	if !present {
		return "", fmt.Errorf("%q is required", key)
	}
	switch s := raw.(type) {
	case string:
		return s, nil
	default:
		return "", fmt.Errorf("%q must be a string", key)
	}
}
104