1package auth
2
3import (
4	"context"
5	"math/rand"
6	"time"
7
8	hclog "github.com/hashicorp/go-hclog"
9	"github.com/hashicorp/vault/api"
10	"github.com/hashicorp/vault/sdk/helper/jsonutil"
11)
12
13type AuthMethod interface {
14	Authenticate(context.Context, *api.Client) (string, map[string]interface{}, error)
15	NewCreds() chan struct{}
16	CredSuccess()
17	Shutdown()
18}
19
// AuthConfig appears to bundle the settings handed to an AuthMethod
// implementation.
// NOTE(review): this type is not referenced in this file; presumably auth
// method constructors elsewhere consume it — confirm against callers.
type AuthConfig struct {
	// Logger is presumably the logger the auth method should write to.
	Logger    hclog.Logger
	// MountPath is presumably where the auth method is mounted in Vault,
	// used to build its login path — confirm.
	MountPath string
	// WrapTTL is presumably a response-wrapping TTL — confirm how the
	// auth method uses it.
	WrapTTL   time.Duration
	// Config holds free-form, method-specific settings; the expected keys
	// are not visible here (assumed to depend on the concrete method).
	Config    map[string]interface{}
}
26
// AuthHandler is responsible for keeping a token alive and renewed and passing
// new tokens to the sink server. Build it with NewAuthHandler and drive it
// with Run.
type AuthHandler struct {
	// DoneCh is closed when Run returns.
	DoneCh                       chan struct{}
	// OutputCh receives each newly obtained client token or, when wrapTTL is
	// positive, the JSON-encoded wrap info. It is closed when Run returns.
	OutputCh                     chan string
	logger                       hclog.Logger
	// client is passed to the auth method and used for login and renewal.
	client                       *api.Client
	// random supplies the jitter for Run's retry backoff.
	random                       *rand.Rand
	// wrapTTL, when positive, makes Run request response wrapping with this
	// TTL and emit the wrap info instead of renewing a token.
	wrapTTL                      time.Duration
	// enableReauthOnNewCredentials controls whether signals on the auth
	// method's NewCreds channel trigger re-authentication.
	enableReauthOnNewCredentials bool
}
38
// AuthHandlerConfig holds the settings NewAuthHandler copies into a new
// AuthHandler.
type AuthHandlerConfig struct {
	// Logger receives the handler's log output.
	Logger                       hclog.Logger
	// Client is the Vault API client used to log in and renew tokens.
	Client                       *api.Client
	// WrapTTL, when positive, requests response wrapping of the login
	// response with this TTL; the wrap info is emitted instead of a token.
	WrapTTL                      time.Duration
	// EnableReauthOnNewCredentials makes the handler re-authenticate when
	// the auth method signals new credentials.
	EnableReauthOnNewCredentials bool
}
45
// NewAuthHandler builds an AuthHandler from conf. The handler does nothing
// until Run is called; Run closes DoneCh and OutputCh when it returns.
//
// A nil conf.Logger is replaced with a no-op logger, because Run logs
// unconditionally and would otherwise panic on a nil logger.
func NewAuthHandler(conf *AuthHandlerConfig) *AuthHandler {
	logger := conf.Logger
	if logger == nil {
		logger = hclog.NewNullLogger()
	}

	// Seed the backoff jitter from the full nanosecond timestamp rather than
	// only its sub-second component, so handlers get more varied sequences.
	seed := time.Now().UnixNano()

	ah := &AuthHandler{
		DoneCh: make(chan struct{}),
		// This is buffered so that if we try to output after the sink server
		// has been shut down, during agent shutdown, we won't block
		OutputCh:                     make(chan string, 1),
		logger:                       logger,
		client:                       conf.Client,
		random:                       rand.New(rand.NewSource(seed)),
		wrapTTL:                      conf.WrapTTL,
		enableReauthOnNewCredentials: conf.EnableReauthOnNewCredentials,
	}

	return ah
}
61
// backoffOrQuit blocks for the backoff duration, returning early if ctx is
// cancelled before the wait is over.
func backoffOrQuit(ctx context.Context, backoff time.Duration) {
	timer := time.NewTimer(backoff)
	defer timer.Stop()

	select {
	case <-ctx.Done():
	case <-timer.C:
	}
}
68
// Run authenticates with am, sends the result on OutputCh, and keeps the
// token renewed until ctx is cancelled. It panics if am is nil.
//
// Each pass of the outer loop performs one login:
//   - am.Authenticate supplies the login path and request data, which are
//     written to Vault. When wrapTTL > 0 the write goes through a cloned client
//     that requests response wrapping.
//   - With wrapping, the JSON-encoded wrap info is sent on OutputCh. Run then
//     waits for new credentials (or shutdown) before logging in again; no
//     renewal is attempted.
//   - Without wrapping, the client token is sent on OutputCh and a renewer
//     keeps it alive until renewal ends, new credentials arrive, or ctx is
//     cancelled.
//
// Failures are logged and retried after a jittered backoff in [1s, 3s). On
// return, Run stops any renewer still running, calls am.Shutdown, and closes
// OutputCh and DoneCh.
func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) {
	if am == nil {
		panic("nil auth method")
	}

	// Declared ahead of the deferred cleanup so a renewer that is still
	// running when Run exits is stopped rather than leaked. For example, this
	// happens when new credentials break the renewal loop and ctx is cancelled
	// during the following re-authentication backoff.
	var renewer *api.Renewer

	ah.logger.Info("starting auth handler")
	defer func() {
		if renewer != nil {
			renewer.Stop()
		}
		am.Shutdown()
		close(ah.OutputCh)
		close(ah.DoneCh)
		ah.logger.Info("auth handler stopped")
	}()

	credCh := am.NewCreds()
	if !ah.enableReauthOnNewCredentials {
		// Re-auth on new credentials is disabled. Keep receiving from the
		// method's channel in the background (presumably so a sending auth
		// method is not blocked), but hide it from the main loop.
		realCredCh := credCh
		credCh = nil
		if realCredCh != nil {
			go func() {
				for {
					select {
					case <-ctx.Done():
						return
					case _, ok := <-realCredCh:
						if !ok {
							// A closed channel is always ready; without this
							// the loop would spin until ctx is cancelled.
							return
						}
					}
				}
			}()
		}
	}
	if credCh == nil {
		// Nothing ever sends on this channel, so the credCh cases below
		// never fire.
		credCh = make(chan struct{})
	}

	for {
		select {
		case <-ctx.Done():
			return

		default:
		}

		// Fresh jittered backoff for this attempt: 2s +/- 1s, i.e. [1s, 3s).
		backoff := 2*time.Second + time.Duration(ah.random.Int63()%int64(time.Second*2)-int64(time.Second))

		ah.logger.Info("authenticating")
		path, data, err := am.Authenticate(ctx, ah.client)
		if err != nil {
			ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoff.Seconds())
			backoffOrQuit(ctx, backoff)
			continue
		}

		clientToUse := ah.client
		if ah.wrapTTL > 0 {
			// Use a clone so the wrapping lookup func does not affect the
			// shared client used for renewal.
			wrapClient, err := ah.client.Clone()
			if err != nil {
				ah.logger.Error("error creating client for wrapped call", "error", err, "backoff", backoff.Seconds())
				backoffOrQuit(ctx, backoff)
				continue
			}
			wrapClient.SetWrappingLookupFunc(func(string, string) string {
				return ah.wrapTTL.String()
			})
			clientToUse = wrapClient
		}

		secret, err := clientToUse.Logical().Write(path, data)
		// Check errors/sanity
		if err != nil {
			ah.logger.Error("error authenticating", "error", err, "backoff", backoff.Seconds())
			backoffOrQuit(ctx, backoff)
			continue
		}

		switch {
		case ah.wrapTTL > 0:
			// A nil secret is treated as possible in the non-wrapped branch
			// too; check it here rather than dereferencing nil.
			if secret == nil || secret.WrapInfo == nil {
				ah.logger.Error("authentication returned nil wrap info", "backoff", backoff.Seconds())
				backoffOrQuit(ctx, backoff)
				continue
			}
			if secret.WrapInfo.Token == "" {
				ah.logger.Error("authentication returned empty wrapped client token", "backoff", backoff.Seconds())
				backoffOrQuit(ctx, backoff)
				continue
			}
			wrappedResp, err := jsonutil.EncodeJSON(secret.WrapInfo)
			if err != nil {
				ah.logger.Error("failed to encode wrapinfo", "error", err, "backoff", backoff.Seconds())
				backoffOrQuit(ctx, backoff)
				continue
			}
			ah.logger.Info("authentication successful, sending wrapped token to sinks and pausing")
			ah.OutputCh <- string(wrappedResp)

			am.CredSuccess()

			// Wrapped tokens are not renewed. Wait for a reason to log in
			// again; the loop head handles shutdown.
			select {
			case <-ctx.Done():
				ah.logger.Info("shutdown triggered")
				continue

			case <-credCh:
				ah.logger.Info("auth method found new credentials, re-authenticating")
				continue
			}

		default:
			if secret == nil || secret.Auth == nil {
				ah.logger.Error("authentication returned nil auth info", "backoff", backoff.Seconds())
				backoffOrQuit(ctx, backoff)
				continue
			}
			if secret.Auth.ClientToken == "" {
				ah.logger.Error("authentication returned empty client token", "backoff", backoff.Seconds())
				backoffOrQuit(ctx, backoff)
				continue
			}
			ah.logger.Info("authentication successful, sending token to sinks")
			ah.OutputCh <- secret.Auth.ClientToken

			am.CredSuccess()
		}

		// Retire the renewer for the previous token, if any. Clearing the
		// variable keeps the deferred cleanup from stopping it a second time.
		if renewer != nil {
			renewer.Stop()
			renewer = nil
		}

		renewer, err = ah.client.NewRenewer(&api.RenewerInput{
			Secret: secret,
		})
		if err != nil {
			ah.logger.Error("error creating renewer, backing off and retrying", "error", err, "backoff", backoff.Seconds())
			backoffOrQuit(ctx, backoff)
			continue
		}

		// Start the renewal process
		ah.logger.Info("starting renewal process")
		go renewer.Renew()

	RenewerLoop:
		for {
			select {
			case <-ctx.Done():
				ah.logger.Info("shutdown triggered, stopping renewer")
				renewer.Stop()
				// Already stopped; keep the deferred cleanup from repeating it.
				renewer = nil
				break RenewerLoop

			case err := <-renewer.DoneCh():
				ah.logger.Info("renewer done channel triggered")
				if err != nil {
					ah.logger.Error("error renewing token", "error", err)
				}
				break RenewerLoop

			case <-renewer.RenewCh():
				ah.logger.Info("renewed auth token")

			case <-credCh:
				ah.logger.Info("auth method found new credentials, re-authenticating")
				break RenewerLoop
			}
		}
	}
}
237