1package auth 2 3import ( 4 "context" 5 "encoding/json" 6 "errors" 7 "math/rand" 8 "net/http" 9 "time" 10 11 "github.com/hashicorp/go-hclog" 12 "github.com/hashicorp/vault/api" 13 "github.com/hashicorp/vault/sdk/helper/jsonutil" 14) 15 16const ( 17 initialBackoff = 1 * time.Second 18 defaultMaxBackoff = 5 * time.Minute 19) 20 21// AuthMethod is the interface that auto-auth methods implement for the agent 22// to use. 23type AuthMethod interface { 24 // Authenticate returns a mount path, header, request body, and error. 25 // The header may be nil if no special header is needed. 26 Authenticate(context.Context, *api.Client) (string, http.Header, map[string]interface{}, error) 27 NewCreds() chan struct{} 28 CredSuccess() 29 Shutdown() 30} 31 32// AuthMethodWithClient is an extended interface that can return an API client 33// for use during the authentication call. 34type AuthMethodWithClient interface { 35 AuthMethod 36 AuthClient(client *api.Client) (*api.Client, error) 37} 38 39type AuthConfig struct { 40 Logger hclog.Logger 41 MountPath string 42 WrapTTL time.Duration 43 Config map[string]interface{} 44} 45 46// AuthHandler is responsible for keeping a token alive and renewed and passing 47// new tokens to the sink server 48type AuthHandler struct { 49 OutputCh chan string 50 TemplateTokenCh chan string 51 token string 52 logger hclog.Logger 53 client *api.Client 54 random *rand.Rand 55 wrapTTL time.Duration 56 maxBackoff time.Duration 57 enableReauthOnNewCredentials bool 58 enableTemplateTokenCh bool 59} 60 61type AuthHandlerConfig struct { 62 Logger hclog.Logger 63 Client *api.Client 64 WrapTTL time.Duration 65 MaxBackoff time.Duration 66 Token string 67 EnableReauthOnNewCredentials bool 68 EnableTemplateTokenCh bool 69} 70 71func NewAuthHandler(conf *AuthHandlerConfig) *AuthHandler { 72 ah := &AuthHandler{ 73 // This is buffered so that if we try to output after the sink server 74 // has been shut down, during agent shutdown, we won't block 75 OutputCh: make(chan string, 1), 76 TemplateTokenCh: make(chan string, 1), 77 token: conf.Token, 78 logger: conf.Logger, 79 client: conf.Client, 80 random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))), 81 wrapTTL: conf.WrapTTL, 82 maxBackoff: conf.MaxBackoff, 83 enableReauthOnNewCredentials: conf.EnableReauthOnNewCredentials, 84 enableTemplateTokenCh: conf.EnableTemplateTokenCh, 85 } 86 87 return ah 88} 89 90func backoffOrQuit(ctx context.Context, backoff *agentBackoff) { 91 select { 92 case <-time.After(backoff.current): 93 case <-ctx.Done(): 94 } 95 96 // Increase exponential backoff for the next time if we don't 97 // successfully auth/renew/etc. 98 backoff.next() 99} 100 101func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { 102 if am == nil { 103 return errors.New("auth handler: nil auth method") 104 } 105 106 backoff := newAgentBackoff(ah.maxBackoff) 107 108 ah.logger.Info("starting auth handler") 109 defer func() { 110 am.Shutdown() 111 close(ah.OutputCh) 112 close(ah.TemplateTokenCh) 113 ah.logger.Info("auth handler stopped") 114 }() 115 116 credCh := am.NewCreds() 117 if !ah.enableReauthOnNewCredentials { 118 realCredCh := credCh 119 credCh = nil 120 if realCredCh != nil { 121 go func() { 122 for { 123 select { 124 case <-ctx.Done(): 125 return 126 case <-realCredCh: 127 } 128 } 129 }() 130 } 131 } 132 if credCh == nil { 133 credCh = make(chan struct{}) 134 } 135 136 var watcher *api.LifetimeWatcher 137 first := true 138 139 for { 140 select { 141 case <-ctx.Done(): 142 return nil 143 144 default: 145 } 146 147 var clientToUse *api.Client 148 var err error 149 var path string 150 var data map[string]interface{} 151 var header http.Header 152 153 switch am.(type) { 154 case AuthMethodWithClient: 155 clientToUse, err = am.(AuthMethodWithClient).AuthClient(ah.client) 156 if err != nil { 157 ah.logger.Error("error creating client for authentication call", "error", err, "backoff", backoff) 158 backoffOrQuit(ctx, backoff) 159 continue 160 } 161 default: 162 clientToUse = ah.client 163 } 164 165 var secret *api.Secret = new(api.Secret) 166 if first && ah.token != "" { 167 ah.logger.Debug("using preloaded token") 168 169 first = false 170 ah.logger.Debug("lookup-self with preloaded token") 171 clientToUse.SetToken(ah.token) 172 173 secret, err = clientToUse.Logical().Read("auth/token/lookup-self") 174 if err != nil { 175 ah.logger.Error("could not look up token", "err", err, "backoff", backoff) 176 backoffOrQuit(ctx, backoff) 177 continue 178 } 179 180 duration, _ := secret.Data["ttl"].(json.Number).Int64() 181 secret.Auth = &api.SecretAuth{ 182 ClientToken: secret.Data["id"].(string), 183 LeaseDuration: int(duration), 184 Renewable: secret.Data["renewable"].(bool), 185 } 186 } else { 187 ah.logger.Info("authenticating") 188 189 path, header, data, err = am.Authenticate(ctx, ah.client) 190 if err != nil { 191 ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoff) 192 backoffOrQuit(ctx, backoff) 193 continue 194 } 195 } 196 197 if ah.wrapTTL > 0 { 198 wrapClient, err := clientToUse.Clone() 199 if err != nil { 200 ah.logger.Error("error creating client for wrapped call", "error", err, "backoff", backoff) 201 backoffOrQuit(ctx, backoff) 202 continue 203 } 204 wrapClient.SetWrappingLookupFunc(func(string, string) string { 205 return ah.wrapTTL.String() 206 }) 207 clientToUse = wrapClient 208 } 209 for key, values := range header { 210 for _, value := range values { 211 clientToUse.AddHeader(key, value) 212 } 213 } 214 215 // This should only happen if there's no preloaded token (regular auto-auth login) 216 // or if a preloaded token has expired and is now switching to auto-auth. 217 if secret.Auth == nil { 218 secret, err = clientToUse.Logical().Write(path, data) 219 // Check errors/sanity 220 if err != nil { 221 ah.logger.Error("error authenticating", "error", err, "backoff", backoff) 222 backoffOrQuit(ctx, backoff) 223 continue 224 } 225 } 226 227 switch { 228 case ah.wrapTTL > 0: 229 if secret.WrapInfo == nil { 230 ah.logger.Error("authentication returned nil wrap info", "backoff", backoff) 231 backoffOrQuit(ctx, backoff) 232 continue 233 } 234 if secret.WrapInfo.Token == "" { 235 ah.logger.Error("authentication returned empty wrapped client token", "backoff", backoff) 236 backoffOrQuit(ctx, backoff) 237 continue 238 } 239 wrappedResp, err := jsonutil.EncodeJSON(secret.WrapInfo) 240 if err != nil { 241 ah.logger.Error("failed to encode wrapinfo", "error", err, "backoff", backoff) 242 backoffOrQuit(ctx, backoff) 243 continue 244 } 245 ah.logger.Info("authentication successful, sending wrapped token to sinks and pausing") 246 ah.OutputCh <- string(wrappedResp) 247 if ah.enableTemplateTokenCh { 248 ah.TemplateTokenCh <- string(wrappedResp) 249 } 250 251 am.CredSuccess() 252 backoff.reset() 253 254 select { 255 case <-ctx.Done(): 256 ah.logger.Info("shutdown triggered") 257 continue 258 259 case <-credCh: 260 ah.logger.Info("auth method found new credentials, re-authenticating") 261 continue 262 } 263 264 default: 265 if secret == nil || secret.Auth == nil { 266 ah.logger.Error("authentication returned nil auth info", "backoff", backoff) 267 backoffOrQuit(ctx, backoff) 268 continue 269 } 270 if secret.Auth.ClientToken == "" { 271 ah.logger.Error("authentication returned empty client token", "backoff", backoff) 272 backoffOrQuit(ctx, backoff) 273 continue 274 } 275 ah.logger.Info("authentication successful, sending token to sinks") 276 ah.OutputCh <- secret.Auth.ClientToken 277 if ah.enableTemplateTokenCh { 278 ah.TemplateTokenCh <- secret.Auth.ClientToken 279 } 280 281 am.CredSuccess() 282 backoff.reset() 283 } 284 285 if watcher != nil { 286 watcher.Stop() 287 } 288 289 watcher, err = clientToUse.NewLifetimeWatcher(&api.LifetimeWatcherInput{ 290 Secret: secret, 291 }) 292 if err != nil { 293 ah.logger.Error("error creating lifetime watcher, backing off and retrying", "error", err, "backoff", backoff) 294 backoffOrQuit(ctx, backoff) 295 continue 296 } 297 298 // Start the renewal process 299 ah.logger.Info("starting renewal process") 300 go watcher.Renew() 301 302 LifetimeWatcherLoop: 303 for { 304 select { 305 case <-ctx.Done(): 306 ah.logger.Info("shutdown triggered, stopping lifetime watcher") 307 watcher.Stop() 308 break LifetimeWatcherLoop 309 310 case err := <-watcher.DoneCh(): 311 ah.logger.Info("lifetime watcher done channel triggered") 312 if err != nil { 313 ah.logger.Error("error renewing token", "error", err) 314 } 315 break LifetimeWatcherLoop 316 317 case <-watcher.RenewCh(): 318 ah.logger.Info("renewed auth token") 319 320 case <-credCh: 321 ah.logger.Info("auth method found new credentials, re-authenticating") 322 break LifetimeWatcherLoop 323 } 324 } 325 } 326} 327 328// agentBackoff tracks exponential backoff state. 329type agentBackoff struct { 330 max time.Duration 331 current time.Duration 332} 333 334func newAgentBackoff(max time.Duration) *agentBackoff { 335 if max <= 0 { 336 max = defaultMaxBackoff 337 } 338 339 return &agentBackoff{ 340 max: max, 341 current: initialBackoff, 342 } 343} 344 345// next determines the next backoff duration that is roughly twice 346// the current value, capped to a max value, with a measure of randomness. 347func (b *agentBackoff) next() { 348 maxBackoff := 2 * b.current 349 350 if maxBackoff > b.max { 351 maxBackoff = b.max 352 } 353 354 // Trim a random amount (0-25%) off the doubled duration 355 trim := rand.Int63n(int64(maxBackoff) / 4) 356 b.current = maxBackoff - time.Duration(trim) 357} 358 359func (b *agentBackoff) reset() { 360 b.current = initialBackoff 361} 362 363func (b agentBackoff) String() string { 364 return b.current.Truncate(10 * time.Millisecond).String() 365} 366