1package auth 2 3import ( 4 "context" 5 "math/rand" 6 "time" 7 8 hclog "github.com/hashicorp/go-hclog" 9 "github.com/hashicorp/vault/api" 10 "github.com/hashicorp/vault/sdk/helper/jsonutil" 11) 12 13type AuthMethod interface { 14 Authenticate(context.Context, *api.Client) (string, map[string]interface{}, error) 15 NewCreds() chan struct{} 16 CredSuccess() 17 Shutdown() 18} 19 20type AuthConfig struct { 21 Logger hclog.Logger 22 MountPath string 23 WrapTTL time.Duration 24 Config map[string]interface{} 25} 26 27// AuthHandler is responsible for keeping a token alive and renewed and passing 28// new tokens to the sink server 29type AuthHandler struct { 30 DoneCh chan struct{} 31 OutputCh chan string 32 logger hclog.Logger 33 client *api.Client 34 random *rand.Rand 35 wrapTTL time.Duration 36 enableReauthOnNewCredentials bool 37} 38 39type AuthHandlerConfig struct { 40 Logger hclog.Logger 41 Client *api.Client 42 WrapTTL time.Duration 43 EnableReauthOnNewCredentials bool 44} 45 46func NewAuthHandler(conf *AuthHandlerConfig) *AuthHandler { 47 ah := &AuthHandler{ 48 DoneCh: make(chan struct{}), 49 // This is buffered so that if we try to output after the sink server 50 // has been shut down, during agent shutdown, we won't block 51 OutputCh: make(chan string, 1), 52 logger: conf.Logger, 53 client: conf.Client, 54 random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))), 55 wrapTTL: conf.WrapTTL, 56 enableReauthOnNewCredentials: conf.EnableReauthOnNewCredentials, 57 } 58 59 return ah 60} 61 62func backoffOrQuit(ctx context.Context, backoff time.Duration) { 63 select { 64 case <-time.After(backoff): 65 case <-ctx.Done(): 66 } 67} 68 69func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) { 70 if am == nil { 71 panic("nil auth method") 72 } 73 74 ah.logger.Info("starting auth handler") 75 defer func() { 76 am.Shutdown() 77 close(ah.OutputCh) 78 close(ah.DoneCh) 79 ah.logger.Info("auth handler stopped") 80 }() 81 82 credCh := am.NewCreds() 83 if !ah.enableReauthOnNewCredentials { 84 realCredCh := credCh 85 credCh = nil 86 if realCredCh != nil { 87 go func() { 88 for { 89 select { 90 case <-ctx.Done(): 91 return 92 case <-realCredCh: 93 } 94 } 95 }() 96 } 97 } 98 if credCh == nil { 99 credCh = make(chan struct{}) 100 } 101 102 var renewer *api.Renewer 103 104 for { 105 select { 106 case <-ctx.Done(): 107 return 108 109 default: 110 } 111 112 // Create a fresh backoff value 113 backoff := 2*time.Second + time.Duration(ah.random.Int63()%int64(time.Second*2)-int64(time.Second)) 114 115 ah.logger.Info("authenticating") 116 path, data, err := am.Authenticate(ctx, ah.client) 117 if err != nil { 118 ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoff.Seconds()) 119 backoffOrQuit(ctx, backoff) 120 continue 121 } 122 123 clientToUse := ah.client 124 if ah.wrapTTL > 0 { 125 wrapClient, err := ah.client.Clone() 126 if err != nil { 127 ah.logger.Error("error creating client for wrapped call", "error", err, "backoff", backoff.Seconds()) 128 backoffOrQuit(ctx, backoff) 129 continue 130 } 131 wrapClient.SetWrappingLookupFunc(func(string, string) string { 132 return ah.wrapTTL.String() 133 }) 134 clientToUse = wrapClient 135 } 136 137 secret, err := clientToUse.Logical().Write(path, data) 138 // Check errors/sanity 139 if err != nil { 140 ah.logger.Error("error authenticating", "error", err, "backoff", backoff.Seconds()) 141 backoffOrQuit(ctx, backoff) 142 continue 143 } 144 145 switch { 146 case ah.wrapTTL > 0: 147 if secret.WrapInfo == nil { 148 ah.logger.Error("authentication returned nil wrap info", "backoff", backoff.Seconds()) 149 backoffOrQuit(ctx, backoff) 150 continue 151 } 152 if secret.WrapInfo.Token == "" { 153 ah.logger.Error("authentication returned empty wrapped client token", "backoff", backoff.Seconds()) 154 backoffOrQuit(ctx, backoff) 155 continue 156 } 157 wrappedResp, err := jsonutil.EncodeJSON(secret.WrapInfo) 158 if err != nil { 159 ah.logger.Error("failed to encode wrapinfo", "error", err, "backoff", backoff.Seconds()) 160 backoffOrQuit(ctx, backoff) 161 continue 162 } 163 ah.logger.Info("authentication successful, sending wrapped token to sinks and pausing") 164 ah.OutputCh <- string(wrappedResp) 165 166 am.CredSuccess() 167 168 select { 169 case <-ctx.Done(): 170 ah.logger.Info("shutdown triggered") 171 continue 172 173 case <-credCh: 174 ah.logger.Info("auth method found new credentials, re-authenticating") 175 continue 176 } 177 178 default: 179 if secret == nil || secret.Auth == nil { 180 ah.logger.Error("authentication returned nil auth info", "backoff", backoff.Seconds()) 181 backoffOrQuit(ctx, backoff) 182 continue 183 } 184 if secret.Auth.ClientToken == "" { 185 ah.logger.Error("authentication returned empty client token", "backoff", backoff.Seconds()) 186 backoffOrQuit(ctx, backoff) 187 continue 188 } 189 ah.logger.Info("authentication successful, sending token to sinks") 190 ah.OutputCh <- secret.Auth.ClientToken 191 192 am.CredSuccess() 193 } 194 195 if renewer != nil { 196 renewer.Stop() 197 } 198 199 renewer, err = ah.client.NewRenewer(&api.RenewerInput{ 200 Secret: secret, 201 }) 202 if err != nil { 203 ah.logger.Error("error creating renewer, backing off and retrying", "error", err, "backoff", backoff.Seconds()) 204 backoffOrQuit(ctx, backoff) 205 continue 206 } 207 208 // Start the renewal process 209 ah.logger.Info("starting renewal process") 210 go renewer.Renew() 211 212 RenewerLoop: 213 for { 214 select { 215 case <-ctx.Done(): 216 ah.logger.Info("shutdown triggered, stopping renewer") 217 renewer.Stop() 218 break RenewerLoop 219 220 case err := <-renewer.DoneCh(): 221 ah.logger.Info("renewer done channel triggered") 222 if err != nil { 223 ah.logger.Error("error renewing token", "error", err) 224 } 225 break RenewerLoop 226 227 case <-renewer.RenewCh(): 228 ah.logger.Info("renewed auth token") 229 230 case <-credCh: 231 ah.logger.Info("auth method found new credentials, re-authenticating") 232 break RenewerLoop 233 } 234 } 235 } 236} 237