1package auth 2 3import ( 4 "context" 5 "math/rand" 6 "time" 7 8 hclog "github.com/hashicorp/go-hclog" 9 "github.com/hashicorp/vault/api" 10 "github.com/hashicorp/vault/sdk/helper/jsonutil" 11) 12 13type AuthMethod interface { 14 Authenticate(context.Context, *api.Client) (string, map[string]interface{}, error) 15 NewCreds() chan struct{} 16 CredSuccess() 17 Shutdown() 18} 19 20type AuthConfig struct { 21 Logger hclog.Logger 22 MountPath string 23 WrapTTL time.Duration 24 Config map[string]interface{} 25} 26 27// AuthHandler is responsible for keeping a token alive and renewed and passing 28// new tokens to the sink server 29type AuthHandler struct { 30 DoneCh chan struct{} 31 OutputCh chan string 32 TemplateTokenCh chan string 33 logger hclog.Logger 34 client *api.Client 35 random *rand.Rand 36 wrapTTL time.Duration 37 enableReauthOnNewCredentials bool 38 enableTemplateTokenCh bool 39} 40 41type AuthHandlerConfig struct { 42 Logger hclog.Logger 43 Client *api.Client 44 WrapTTL time.Duration 45 EnableReauthOnNewCredentials bool 46 EnableTemplateTokenCh bool 47} 48 49func NewAuthHandler(conf *AuthHandlerConfig) *AuthHandler { 50 ah := &AuthHandler{ 51 DoneCh: make(chan struct{}), 52 // This is buffered so that if we try to output after the sink server 53 // has been shut down, during agent shutdown, we won't block 54 OutputCh: make(chan string, 1), 55 TemplateTokenCh: make(chan string, 1), 56 logger: conf.Logger, 57 client: conf.Client, 58 random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))), 59 wrapTTL: conf.WrapTTL, 60 enableReauthOnNewCredentials: conf.EnableReauthOnNewCredentials, 61 enableTemplateTokenCh: conf.EnableTemplateTokenCh, 62 } 63 64 return ah 65} 66 67func backoffOrQuit(ctx context.Context, backoff time.Duration) { 68 select { 69 case <-time.After(backoff): 70 case <-ctx.Done(): 71 } 72} 73 74func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) { 75 if am == nil { 76 panic("nil auth method") 77 } 78 79 ah.logger.Info("starting auth handler") 80 defer func() { 81 am.Shutdown() 82 close(ah.OutputCh) 83 close(ah.DoneCh) 84 close(ah.TemplateTokenCh) 85 ah.logger.Info("auth handler stopped") 86 }() 87 88 credCh := am.NewCreds() 89 if !ah.enableReauthOnNewCredentials { 90 realCredCh := credCh 91 credCh = nil 92 if realCredCh != nil { 93 go func() { 94 for { 95 select { 96 case <-ctx.Done(): 97 return 98 case <-realCredCh: 99 } 100 } 101 }() 102 } 103 } 104 if credCh == nil { 105 credCh = make(chan struct{}) 106 } 107 108 var watcher *api.LifetimeWatcher 109 110 for { 111 select { 112 case <-ctx.Done(): 113 return 114 115 default: 116 } 117 118 // Create a fresh backoff value 119 backoff := 2*time.Second + time.Duration(ah.random.Int63()%int64(time.Second*2)-int64(time.Second)) 120 121 ah.logger.Info("authenticating") 122 path, data, err := am.Authenticate(ctx, ah.client) 123 if err != nil { 124 ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoff.Seconds()) 125 backoffOrQuit(ctx, backoff) 126 continue 127 } 128 129 clientToUse := ah.client 130 if ah.wrapTTL > 0 { 131 wrapClient, err := ah.client.Clone() 132 if err != nil { 133 ah.logger.Error("error creating client for wrapped call", "error", err, "backoff", backoff.Seconds()) 134 backoffOrQuit(ctx, backoff) 135 continue 136 } 137 wrapClient.SetWrappingLookupFunc(func(string, string) string { 138 return ah.wrapTTL.String() 139 }) 140 clientToUse = wrapClient 141 } 142 143 secret, err := clientToUse.Logical().Write(path, data) 144 // Check errors/sanity 145 if err != nil { 146 ah.logger.Error("error authenticating", "error", err, "backoff", backoff.Seconds()) 147 backoffOrQuit(ctx, backoff) 148 continue 149 } 150 151 switch { 152 case ah.wrapTTL > 0: 153 if secret.WrapInfo == nil { 154 ah.logger.Error("authentication returned nil wrap info", "backoff", backoff.Seconds()) 155 backoffOrQuit(ctx, backoff) 156 continue 157 } 158 if secret.WrapInfo.Token == "" { 159 ah.logger.Error("authentication returned empty wrapped client token", "backoff", backoff.Seconds()) 160 backoffOrQuit(ctx, backoff) 161 continue 162 } 163 wrappedResp, err := jsonutil.EncodeJSON(secret.WrapInfo) 164 if err != nil { 165 ah.logger.Error("failed to encode wrapinfo", "error", err, "backoff", backoff.Seconds()) 166 backoffOrQuit(ctx, backoff) 167 continue 168 } 169 ah.logger.Info("authentication successful, sending wrapped token to sinks and pausing") 170 ah.OutputCh <- string(wrappedResp) 171 if ah.enableTemplateTokenCh { 172 ah.TemplateTokenCh <- string(wrappedResp) 173 } 174 175 am.CredSuccess() 176 177 select { 178 case <-ctx.Done(): 179 ah.logger.Info("shutdown triggered") 180 continue 181 182 case <-credCh: 183 ah.logger.Info("auth method found new credentials, re-authenticating") 184 continue 185 } 186 187 default: 188 if secret == nil || secret.Auth == nil { 189 ah.logger.Error("authentication returned nil auth info", "backoff", backoff.Seconds()) 190 backoffOrQuit(ctx, backoff) 191 continue 192 } 193 if secret.Auth.ClientToken == "" { 194 ah.logger.Error("authentication returned empty client token", "backoff", backoff.Seconds()) 195 backoffOrQuit(ctx, backoff) 196 continue 197 } 198 ah.logger.Info("authentication successful, sending token to sinks") 199 ah.OutputCh <- secret.Auth.ClientToken 200 if ah.enableTemplateTokenCh { 201 ah.TemplateTokenCh <- secret.Auth.ClientToken 202 } 203 204 am.CredSuccess() 205 } 206 207 if watcher != nil { 208 watcher.Stop() 209 } 210 211 watcher, err = ah.client.NewLifetimeWatcher(&api.LifetimeWatcherInput{ 212 Secret: secret, 213 }) 214 if err != nil { 215 ah.logger.Error("error creating lifetime watcher, backing off and retrying", "error", err, "backoff", backoff.Seconds()) 216 backoffOrQuit(ctx, backoff) 217 continue 218 } 219 220 // Start the renewal process 221 ah.logger.Info("starting renewal process") 222 go watcher.Renew() 223 224 LifetimeWatcherLoop: 225 for { 226 select { 227 case <-ctx.Done(): 228 ah.logger.Info("shutdown triggered, stopping lifetime watcher") 229 watcher.Stop() 230 break LifetimeWatcherLoop 231 232 case err := <-watcher.DoneCh(): 233 ah.logger.Info("lifetime watcher done channel triggered") 234 if err != nil { 235 ah.logger.Error("error renewing token", "error", err) 236 } 237 break LifetimeWatcherLoop 238 239 case <-watcher.RenewCh(): 240 ah.logger.Info("renewed auth token") 241 242 case <-credCh: 243 ah.logger.Info("auth method found new credentials, re-authenticating") 244 break LifetimeWatcherLoop 245 } 246 } 247 } 248} 249