1package auth
2
3import (
4	"context"
5	"math/rand"
6	"time"
7
8	hclog "github.com/hashicorp/go-hclog"
9	"github.com/hashicorp/vault/api"
10	"github.com/hashicorp/vault/sdk/helper/jsonutil"
11)
12
13type AuthMethod interface {
14	Authenticate(context.Context, *api.Client) (string, map[string]interface{}, error)
15	NewCreds() chan struct{}
16	CredSuccess()
17	Shutdown()
18}
19
// AuthConfig holds common settings for an auth method.
// NOTE(review): not referenced in this file; presumably passed to auth method
// constructors — confirm against callers.
type AuthConfig struct {
	Logger hclog.Logger
	// MountPath is presumably the mount path of the auth backend — TODO confirm.
	MountPath string
	// WrapTTL is presumably the response-wrapping TTL, mirroring
	// AuthHandlerConfig.WrapTTL — verify against callers.
	WrapTTL time.Duration
	// Config presumably holds method-specific options — schema not visible here.
	Config map[string]interface{}
}
26
27// AuthHandler is responsible for keeping a token alive and renewed and passing
28// new tokens to the sink server
29type AuthHandler struct {
30	DoneCh                       chan struct{}
31	OutputCh                     chan string
32	TemplateTokenCh              chan string
33	logger                       hclog.Logger
34	client                       *api.Client
35	random                       *rand.Rand
36	wrapTTL                      time.Duration
37	enableReauthOnNewCredentials bool
38	enableTemplateTokenCh        bool
39}
40
41type AuthHandlerConfig struct {
42	Logger                       hclog.Logger
43	Client                       *api.Client
44	WrapTTL                      time.Duration
45	EnableReauthOnNewCredentials bool
46	EnableTemplateTokenCh        bool
47}
48
// NewAuthHandler builds an AuthHandler from conf.
//
// All of the handler's channels are created here; Run closes them when it
// returns. conf must be non-nil.
func NewAuthHandler(conf *AuthHandlerConfig) *AuthHandler {
	// Seed the backoff-jitter source from the full Unix timestamp.
	// time.Now().Nanosecond() would only yield the sub-second offset, which
	// gives much less variation between seeds.
	seed := time.Now().UnixNano()

	ah := &AuthHandler{
		DoneCh: make(chan struct{}),
		// This is buffered so that if we try to output after the sink server
		// has been shut down, during agent shutdown, we won't block
		OutputCh:                     make(chan string, 1),
		TemplateTokenCh:              make(chan string, 1),
		logger:                       conf.Logger,
		client:                       conf.Client,
		random:                       rand.New(rand.NewSource(seed)),
		wrapTTL:                      conf.WrapTTL,
		enableReauthOnNewCredentials: conf.EnableReauthOnNewCredentials,
		enableTemplateTokenCh:        conf.EnableTemplateTokenCh,
	}

	return ah
}
66
// backoffOrQuit blocks for backoff, returning early if ctx is cancelled
// first.
//
// It reports nothing; callers re-check ctx themselves after it returns.
func backoffOrQuit(ctx context.Context, backoff time.Duration) {
	// Use an explicit timer rather than time.After, so the timer is released
	// as soon as ctx is cancelled. Otherwise it stays pending until the full
	// backoff elapses (before Go 1.23 an unfired time.After timer is not
	// collectable).
	timer := time.NewTimer(backoff)
	defer timer.Stop()

	select {
	case <-timer.C:
	case <-ctx.Done():
	}
}
73
// Run authenticates with am and publishes the result on OutputCh (and on
// TemplateTokenCh when enabled).
//
// The published value is the client token, or the JSON-encoded wrap info when
// a wrap TTL is configured. For unwrapped tokens, Run then keeps the token
// renewed with a LifetimeWatcher. It re-authenticates when the watcher
// finishes, or when am signals new credentials and re-auth on new credentials
// is enabled. Failures are retried after a jittered backoff between 1s and 3s.
//
// Run blocks until ctx is cancelled. On return it stops any running lifetime
// watcher, calls am.Shutdown, and closes OutputCh, DoneCh and
// TemplateTokenCh. It panics if am is nil.
func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) {
	if am == nil {
		panic("nil auth method")
	}

	ah.logger.Info("starting auth handler")
	defer func() {
		am.Shutdown()
		close(ah.OutputCh)
		close(ah.DoneCh)
		close(ah.TemplateTokenCh)
		ah.logger.Info("auth handler stopped")
	}()

	credCh := am.NewCreds()
	if !ah.enableReauthOnNewCredentials {
		// Re-auth on new credentials is disabled. Keep draining the method's
		// channel in the background until ctx is done (presumably so the
		// method's sends never stall). The loop below then watches a channel
		// that never fires.
		realCredCh := credCh
		credCh = nil
		if realCredCh != nil {
			go func() {
				for {
					select {
					case <-ctx.Done():
						return
					case <-realCredCh:
					}
				}
			}()
		}
	}
	if credCh == nil {
		credCh = make(chan struct{})
	}

	var watcher *api.LifetimeWatcher

	// stopWatcher stops the current watcher (if any) and forgets it, so that
	// no watcher is ever stopped twice.
	stopWatcher := func() {
		if watcher != nil {
			watcher.Stop()
			watcher = nil
		}
	}
	// A watcher can still be running when Run returns. For example: we break
	// out for new credentials, the re-auth fails, and ctx is cancelled during
	// the backoff. Always stop it on exit. Registered after the cleanup defer
	// above, so it runs first.
	defer stopWatcher()

	for {
		select {
		case <-ctx.Done():
			return

		default:
		}

		// Create a fresh backoff value: 2s plus uniform jitter in [-1s, 1s).
		backoff := 2*time.Second + time.Duration(ah.random.Int63()%int64(time.Second*2)-int64(time.Second))

		ah.logger.Info("authenticating")
		path, data, err := am.Authenticate(ctx, ah.client)
		if err != nil {
			ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoff.Seconds())
			backoffOrQuit(ctx, backoff)
			continue
		}

		clientToUse := ah.client
		if ah.wrapTTL > 0 {
			// Wrap only this login request: set the wrapping lookup on a
			// clone, leaving the shared client untouched.
			wrapClient, err := ah.client.Clone()
			if err != nil {
				ah.logger.Error("error creating client for wrapped call", "error", err, "backoff", backoff.Seconds())
				backoffOrQuit(ctx, backoff)
				continue
			}
			wrapClient.SetWrappingLookupFunc(func(string, string) string {
				return ah.wrapTTL.String()
			})
			clientToUse = wrapClient
		}

		secret, err := clientToUse.Logical().Write(path, data)
		// Check errors/sanity
		if err != nil {
			ah.logger.Error("error authenticating", "error", err, "backoff", backoff.Seconds())
			backoffOrQuit(ctx, backoff)
			continue
		}

		switch {
		case ah.wrapTTL > 0:
			// Write can yield a nil secret with a nil error (the default
			// branch below guards for it too), so check before dereferencing.
			if secret == nil || secret.WrapInfo == nil {
				ah.logger.Error("authentication returned nil wrap info", "backoff", backoff.Seconds())
				backoffOrQuit(ctx, backoff)
				continue
			}
			if secret.WrapInfo.Token == "" {
				ah.logger.Error("authentication returned empty wrapped client token", "backoff", backoff.Seconds())
				backoffOrQuit(ctx, backoff)
				continue
			}
			wrappedResp, err := jsonutil.EncodeJSON(secret.WrapInfo)
			if err != nil {
				ah.logger.Error("failed to encode wrapinfo", "error", err, "backoff", backoff.Seconds())
				backoffOrQuit(ctx, backoff)
				continue
			}
			ah.logger.Info("authentication successful, sending wrapped token to sinks and pausing")
			// NOTE(review): these sends do not observe ctx. With a buffer of 1
			// they block if a previous value has not been consumed yet.
			ah.OutputCh <- string(wrappedResp)
			if ah.enableTemplateTokenCh {
				ah.TemplateTokenCh <- string(wrappedResp)
			}

			am.CredSuccess()

			// A wrapped token cannot be renewed here, so wait for shutdown or
			// new credentials instead of starting a watcher.
			select {
			case <-ctx.Done():
				ah.logger.Info("shutdown triggered")
				continue

			case <-credCh:
				ah.logger.Info("auth method found new credentials, re-authenticating")
				continue
			}

		default:
			if secret == nil || secret.Auth == nil {
				ah.logger.Error("authentication returned nil auth info", "backoff", backoff.Seconds())
				backoffOrQuit(ctx, backoff)
				continue
			}
			if secret.Auth.ClientToken == "" {
				ah.logger.Error("authentication returned empty client token", "backoff", backoff.Seconds())
				backoffOrQuit(ctx, backoff)
				continue
			}
			ah.logger.Info("authentication successful, sending token to sinks")
			// NOTE(review): see the note on the wrapped-token sends above.
			ah.OutputCh <- secret.Auth.ClientToken
			if ah.enableTemplateTokenCh {
				ah.TemplateTokenCh <- secret.Auth.ClientToken
			}

			am.CredSuccess()
		}

		// Replace any watcher left over from the previous token.
		stopWatcher()

		watcher, err = ah.client.NewLifetimeWatcher(&api.LifetimeWatcherInput{
			Secret: secret,
		})
		if err != nil {
			ah.logger.Error("error creating lifetime watcher, backing off and retrying", "error", err, "backoff", backoff.Seconds())
			backoffOrQuit(ctx, backoff)
			continue
		}

		// Start the renewal process
		ah.logger.Info("starting renewal process")
		go watcher.Renew()

	LifetimeWatcherLoop:
		for {
			select {
			case <-ctx.Done():
				ah.logger.Info("shutdown triggered, stopping lifetime watcher")
				stopWatcher()
				break LifetimeWatcherLoop

			case err := <-watcher.DoneCh():
				ah.logger.Info("lifetime watcher done channel triggered")
				if err != nil {
					ah.logger.Error("error renewing token", "error", err)
				}
				break LifetimeWatcherLoop

			case <-watcher.RenewCh():
				ah.logger.Info("renewed auth token")

			case <-credCh:
				// The watcher keeps running until the next successful login
				// replaces it, or until Run returns.
				ah.logger.Info("auth method found new credentials, re-authenticating")
				break LifetimeWatcherLoop
			}
		}
	}
}
249