1package auth
2
3import (
4	"context"
5	"encoding/json"
6	"errors"
7	"math/rand"
8	"net/http"
9	"time"
10
11	"github.com/hashicorp/go-hclog"
12	"github.com/hashicorp/vault/api"
13	"github.com/hashicorp/vault/sdk/helper/jsonutil"
14)
15
16const (
17	initialBackoff    = 1 * time.Second
18	defaultMaxBackoff = 5 * time.Minute
19)
20
21// AuthMethod is the interface that auto-auth methods implement for the agent
22// to use.
23type AuthMethod interface {
24	// Authenticate returns a mount path, header, request body, and error.
25	// The header may be nil if no special header is needed.
26	Authenticate(context.Context, *api.Client) (string, http.Header, map[string]interface{}, error)
27	NewCreds() chan struct{}
28	CredSuccess()
29	Shutdown()
30}
31
32// AuthMethodWithClient is an extended interface that can return an API client
33// for use during the authentication call.
34type AuthMethodWithClient interface {
35	AuthMethod
36	AuthClient(client *api.Client) (*api.Client, error)
37}
38
// AuthConfig groups common configuration values: a logger, a mount path, a
// response-wrapping TTL, and a free-form config map.
//
// NOTE(review): AuthConfig is not referenced elsewhere in this file;
// presumably it is consumed by the individual auth-method constructors —
// confirm against callers.
type AuthConfig struct {
	Logger    hclog.Logger
	MountPath string
	WrapTTL   time.Duration
	Config    map[string]interface{}
}
45
// AuthHandler is responsible for keeping a token alive and renewed and passing
// new tokens to the sink server. Build one with NewAuthHandler and drive it
// with Run.
//
// Fields:
//   - OutputCh receives each new client token, or the JSON-encoded wrap info
//     when wrapTTL > 0. NewAuthHandler buffers it with capacity 1; Run closes
//     it on exit.
//   - TemplateTokenCh receives the same values, but only when
//     enableTemplateTokenCh is set. It is also buffered (1) and closed by Run.
//   - token is an optional preloaded token. On Run's first iteration it is
//     checked with auth/token/lookup-self instead of logging in.
//   - client is the base API client. It is passed to the auth method and
//     used directly unless the method supplies its own client.
//   - random is seeded by NewAuthHandler. NOTE(review): it is not read
//     anywhere in this file; agentBackoff.next uses the global math/rand
//     source.
//   - wrapTTL: when > 0, login responses are requested wrapped with this TTL.
//   - maxBackoff caps retry backoff; a non-positive value selects
//     defaultMaxBackoff.
//   - enableReauthOnNewCredentials: when false, the auth method's NewCreds
//     signals are drained and ignored.
//   - enableTemplateTokenCh gates the sends on TemplateTokenCh.
type AuthHandler struct {
	OutputCh                     chan string
	TemplateTokenCh              chan string
	token                        string
	logger                       hclog.Logger
	client                       *api.Client
	random                       *rand.Rand
	wrapTTL                      time.Duration
	maxBackoff                   time.Duration
	enableReauthOnNewCredentials bool
	enableTemplateTokenCh        bool
}
60
// AuthHandlerConfig carries the options that NewAuthHandler copies into a new
// AuthHandler. Each field maps onto the AuthHandler field of the same
// (lower-cased) name:
//   - Token is an optional preloaded token, validated with lookup-self on the
//     first Run iteration.
//   - WrapTTL > 0 makes the handler output wrap info instead of a raw token.
//   - MaxBackoff <= 0 selects defaultMaxBackoff.
//   - EnableReauthOnNewCredentials controls whether NewCreds signals trigger
//     re-authentication.
//   - EnableTemplateTokenCh controls whether tokens are also sent on
//     TemplateTokenCh.
type AuthHandlerConfig struct {
	Logger                       hclog.Logger
	Client                       *api.Client
	WrapTTL                      time.Duration
	MaxBackoff                   time.Duration
	Token                        string
	EnableReauthOnNewCredentials bool
	EnableTemplateTokenCh        bool
}
70
// NewAuthHandler builds an AuthHandler from conf; call Run to start
// authenticating.
//
// OutputCh and TemplateTokenCh are created with a buffer of one. A nil
// conf.Logger is replaced with a no-op logger, because Run logs
// unconditionally and would otherwise dereference a nil logger.
func NewAuthHandler(conf *AuthHandlerConfig) *AuthHandler {
	logger := conf.Logger
	if logger == nil {
		logger = hclog.NewNullLogger()
	}

	// Seed from UnixNano. time.Now().Nanosecond() is only the offset within
	// the current second, so it repeats every second and carries far less
	// entropy than the full timestamp.
	ah := &AuthHandler{
		// This is buffered so that if we try to output after the sink server
		// has been shut down, during agent shutdown, we won't block
		OutputCh:                     make(chan string, 1),
		TemplateTokenCh:              make(chan string, 1),
		token:                        conf.Token,
		logger:                       logger,
		client:                       conf.Client,
		random:                       rand.New(rand.NewSource(time.Now().UnixNano())),
		wrapTTL:                      conf.WrapTTL,
		maxBackoff:                   conf.MaxBackoff,
		enableReauthOnNewCredentials: conf.EnableReauthOnNewCredentials,
		enableTemplateTokenCh:        conf.EnableTemplateTokenCh,
	}

	return ah
}
89
// backoffOrQuit blocks for backoff.current or until ctx is done, whichever
// comes first. It then advances backoff so the next failure waits longer.
func backoffOrQuit(ctx context.Context, backoff *agentBackoff) {
	// Use an explicit timer and stop it on return. A time.After timer would
	// stay alive until it fired even after ctx was cancelled, which could be
	// up to the max backoff later. Before Go 1.23, unstopped timers were not
	// garbage collected.
	timer := time.NewTimer(backoff.current)
	defer timer.Stop()

	select {
	case <-timer.C:
	case <-ctx.Done():
	}

	// Increase exponential backoff for the next time if we don't
	// successfully auth/renew/etc.
	backoff.next()
}
100
// Run drives auto-auth until ctx is done.
//
// Each iteration does the following:
//   - Authenticates with am. On the first iteration, when a preloaded token
//     is set, it validates that token via auth/token/lookup-self instead.
//   - Publishes the resulting client token on OutputCh (and on
//     TemplateTokenCh when enabled). When wrapTTL > 0 it publishes the
//     JSON-encoded wrap info instead.
//   - Keeps the token renewed with a LifetimeWatcher until renewal ends, am
//     reports new credentials, or ctx is cancelled.
//
// In wrapping mode no watcher is started; Run just waits for new credentials
// or cancellation. Failures are retried with exponential backoff.
//
// Run returns an error only when am is nil; otherwise it returns nil once ctx
// is done. After the nil check, on return it stops any active watcher, calls
// am.Shutdown, and closes OutputCh and TemplateTokenCh.
func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
	if am == nil {
		return errors.New("auth handler: nil auth method")
	}

	backoff := newAgentBackoff(ah.maxBackoff)

	// Declared before the deferred cleanup so the cleanup can stop whichever
	// watcher is active when Run returns.
	var watcher *api.LifetimeWatcher

	ah.logger.Info("starting auth handler")
	defer func() {
		// A watcher left running after a credCh-triggered re-auth that has
		// not yet succeeded would otherwise keep renewing after Run returns.
		if watcher != nil {
			watcher.Stop()
		}
		am.Shutdown()
		close(ah.OutputCh)
		close(ah.TemplateTokenCh)
		ah.logger.Info("auth handler stopped")
	}()

	credCh := am.NewCreds()
	if !ah.enableReauthOnNewCredentials {
		// Re-auth on new credentials is disabled. Drain the method's channel
		// in the background (presumably so its sends do not block) and wait
		// on a never-firing channel instead.
		realCredCh := credCh
		credCh = nil
		if realCredCh != nil {
			go func() {
				for {
					select {
					case <-ctx.Done():
						return
					case <-realCredCh:
					}
				}
			}()
		}
	}
	if credCh == nil {
		credCh = make(chan struct{})
	}

	first := true

	for {
		select {
		case <-ctx.Done():
			return nil

		default:
		}

		var clientToUse *api.Client
		var err error
		var path string
		var data map[string]interface{}
		var header http.Header

		switch m := am.(type) {
		case AuthMethodWithClient:
			clientToUse, err = m.AuthClient(ah.client)
			if err != nil {
				ah.logger.Error("error creating client for authentication call", "error", err, "backoff", backoff)
				backoffOrQuit(ctx, backoff)
				continue
			}
		default:
			clientToUse = ah.client
		}

		// secret starts with a nil Auth. The login request further down is
		// only made while Auth is still nil, i.e. when no preloaded token
		// was used on this iteration.
		secret := new(api.Secret)
		if first && ah.token != "" {
			ah.logger.Debug("using preloaded token")

			first = false
			ah.logger.Debug("lookup-self with preloaded token")
			clientToUse.SetToken(ah.token)

			secret, err = clientToUse.Logical().Read("auth/token/lookup-self")
			if err != nil {
				ah.logger.Error("could not look up token", "err", err, "backoff", backoff)
				backoffOrQuit(ctx, backoff)
				continue
			}

			// Validate the response rather than type-asserting blindly. A
			// nil secret or a missing or mistyped field used to panic here.
			// On failure, the next iteration falls back to regular auth
			// because first is now false.
			var auth *api.SecretAuth
			auth, err = authFromLookupSelf(secret)
			if err != nil {
				ah.logger.Error("could not use preloaded token", "error", err, "backoff", backoff)
				backoffOrQuit(ctx, backoff)
				continue
			}
			secret.Auth = auth
		} else {
			ah.logger.Info("authenticating")

			path, header, data, err = am.Authenticate(ctx, ah.client)
			if err != nil {
				ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoff)
				backoffOrQuit(ctx, backoff)
				continue
			}
		}

		if ah.wrapTTL > 0 {
			wrapClient, err := clientToUse.Clone()
			if err != nil {
				ah.logger.Error("error creating client for wrapped call", "error", err, "backoff", backoff)
				backoffOrQuit(ctx, backoff)
				continue
			}
			wrapClient.SetWrappingLookupFunc(func(string, string) string {
				return ah.wrapTTL.String()
			})
			clientToUse = wrapClient
		}
		// NOTE(review): when clientToUse is ah.client (no wrapping and no
		// AuthMethodWithClient), these headers are added to the shared client
		// on every iteration. Confirm whether AddHeader accumulates duplicate
		// values across re-authentications.
		for key, values := range header {
			for _, value := range values {
				clientToUse.AddHeader(key, value)
			}
		}

		// This should only happen if there's no preloaded token (regular auto-auth login)
		//  or if a preloaded token has expired and is now switching to auto-auth.
		if secret.Auth == nil {
			secret, err = clientToUse.Logical().Write(path, data)
			// Check errors/sanity
			if err != nil {
				ah.logger.Error("error authenticating", "error", err, "backoff", backoff)
				backoffOrQuit(ctx, backoff)
				continue
			}
		}

		switch {
		case ah.wrapTTL > 0:
			// Write may return a nil secret without an error.
			if secret == nil || secret.WrapInfo == nil {
				ah.logger.Error("authentication returned nil wrap info", "backoff", backoff)
				backoffOrQuit(ctx, backoff)
				continue
			}
			if secret.WrapInfo.Token == "" {
				ah.logger.Error("authentication returned empty wrapped client token", "backoff", backoff)
				backoffOrQuit(ctx, backoff)
				continue
			}
			wrappedResp, err := jsonutil.EncodeJSON(secret.WrapInfo)
			if err != nil {
				ah.logger.Error("failed to encode wrapinfo", "error", err, "backoff", backoff)
				backoffOrQuit(ctx, backoff)
				continue
			}
			ah.logger.Info("authentication successful, sending wrapped token to sinks and pausing")
			ah.OutputCh <- string(wrappedResp)
			if ah.enableTemplateTokenCh {
				ah.TemplateTokenCh <- string(wrappedResp)
			}

			am.CredSuccess()
			backoff.reset()

			select {
			case <-ctx.Done():
				ah.logger.Info("shutdown triggered")
				continue

			case <-credCh:
				ah.logger.Info("auth method found new credentials, re-authenticating")
				continue
			}

		default:
			if secret == nil || secret.Auth == nil {
				ah.logger.Error("authentication returned nil auth info", "backoff", backoff)
				backoffOrQuit(ctx, backoff)
				continue
			}
			if secret.Auth.ClientToken == "" {
				ah.logger.Error("authentication returned empty client token", "backoff", backoff)
				backoffOrQuit(ctx, backoff)
				continue
			}
			ah.logger.Info("authentication successful, sending token to sinks")
			ah.OutputCh <- secret.Auth.ClientToken
			if ah.enableTemplateTokenCh {
				ah.TemplateTokenCh <- secret.Auth.ClientToken
			}

			am.CredSuccess()
			backoff.reset()
		}

		if watcher != nil {
			watcher.Stop()
		}

		watcher, err = clientToUse.NewLifetimeWatcher(&api.LifetimeWatcherInput{
			Secret: secret,
		})
		if err != nil {
			ah.logger.Error("error creating lifetime watcher, backing off and retrying", "error", err, "backoff", backoff)
			backoffOrQuit(ctx, backoff)
			continue
		}

		// Start the renewal process
		ah.logger.Info("starting renewal process")
		go watcher.Renew()

	LifetimeWatcherLoop:
		for {
			select {
			case <-ctx.Done():
				ah.logger.Info("shutdown triggered, stopping lifetime watcher")
				watcher.Stop()
				break LifetimeWatcherLoop

			case err := <-watcher.DoneCh():
				ah.logger.Info("lifetime watcher done channel triggered")
				if err != nil {
					ah.logger.Error("error renewing token", "error", err)
				}
				break LifetimeWatcherLoop

			case <-watcher.RenewCh():
				ah.logger.Info("renewed auth token")

			case <-credCh:
				ah.logger.Info("auth method found new credentials, re-authenticating")
				break LifetimeWatcherLoop
			}
		}
	}
}

// authFromLookupSelf builds the SecretAuth that Run needs from the response
// of an auth/token/lookup-self read. It returns an error instead of panicking
// in these cases:
//   - the response is nil or has no data;
//   - "id" is not a string;
//   - "ttl" is not a json.Number holding an integer;
//   - "renewable" is not a bool.
func authFromLookupSelf(secret *api.Secret) (*api.SecretAuth, error) {
	if secret == nil || secret.Data == nil {
		return nil, errors.New("token lookup returned no data")
	}

	id, ok := secret.Data["id"].(string)
	if !ok {
		return nil, errors.New(`token lookup response has no string "id" field`)
	}

	ttlNum, ok := secret.Data["ttl"].(json.Number)
	if !ok {
		return nil, errors.New(`token lookup response has no numeric "ttl" field`)
	}
	ttl, err := ttlNum.Int64()
	if err != nil {
		return nil, errors.New(`token lookup response has an invalid "ttl" field: ` + err.Error())
	}

	renewable, ok := secret.Data["renewable"].(bool)
	if !ok {
		return nil, errors.New(`token lookup response has no boolean "renewable" field`)
	}

	return &api.SecretAuth{
		ClientToken:   id,
		LeaseDuration: int(ttl),
		Renewable:     renewable,
	}, nil
}
327
// agentBackoff tracks exponential backoff state.
//
// current is how long backoffOrQuit waits on the next failure. next() roughly
// doubles it, caps it at max, and trims a random fraction. reset() returns it
// to initialBackoff. Construct it with newAgentBackoff, which guarantees
// max > 0.
type agentBackoff struct {
	max     time.Duration
	current time.Duration
}
333
// newAgentBackoff returns a backoff that starts at initialBackoff and whose
// growth is capped at max. A non-positive max falls back to
// defaultMaxBackoff.
func newAgentBackoff(max time.Duration) *agentBackoff {
	b := &agentBackoff{
		max:     defaultMaxBackoff,
		current: initialBackoff,
	}
	if max > 0 {
		b.max = max
	}
	return b
}
344
345// next determines the next backoff duration that is roughly twice
346// the current value, capped to a max value, with a measure of randomness.
347func (b *agentBackoff) next() {
348	maxBackoff := 2 * b.current
349
350	if maxBackoff > b.max {
351		maxBackoff = b.max
352	}
353
354	// Trim a random amount (0-25%) off the doubled duration
355	trim := rand.Int63n(int64(maxBackoff) / 4)
356	b.current = maxBackoff - time.Duration(trim)
357}
358
// reset restores the wait to initialBackoff. Run calls it after each
// successful authentication. The cap (max) is left unchanged.
func (b *agentBackoff) reset() {
	b.current = initialBackoff
}
362
363func (b agentBackoff) String() string {
364	return b.current.Truncate(10 * time.Millisecond).String()
365}
366