package adal

// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha1"
	"crypto/x509"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"net"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/Azure/go-autorest/autorest/date"
	"github.com/dgrijalva/jwt-go"
)

const (
	// defaultRefresh is the refresh window assigned to newly constructed
	// ServicePrincipalTokens: EnsureFresh refreshes a token that will expire
	// within this duration.
	defaultRefresh = 5 * time.Minute

	// OAuthGrantTypeDeviceCode is the "grant_type" identifier used in device flow
	OAuthGrantTypeDeviceCode = "device_code"

	// OAuthGrantTypeClientCredentials is the "grant_type" identifier used in credential flows
	OAuthGrantTypeClientCredentials = "client_credentials"

	// OAuthGrantTypeUserPass is the "grant_type" identifier used in username and password auth flows
	OAuthGrantTypeUserPass = "password"

	// OAuthGrantTypeRefreshToken is the "grant_type" identifier used in refresh token flows
	OAuthGrantTypeRefreshToken = "refresh_token"

	// OAuthGrantTypeAuthorizationCode is the "grant_type" identifier used in authorization code flows
	OAuthGrantTypeAuthorizationCode = "authorization_code"

	// metadataHeader is the header required by the MSI extension
	metadataHeader = "Metadata"

	// msiEndpoint is the well-known Instance Metadata Service (IMDS) endpoint
	// for getting MSI authentication tokens
	msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
)
63 64// OAuthTokenProvider is an interface which should be implemented by an access token retriever 65type OAuthTokenProvider interface { 66 OAuthToken() string 67} 68 69// TokenRefreshError is an interface used by errors returned during token refresh. 70type TokenRefreshError interface { 71 error 72 Response() *http.Response 73} 74 75// Refresher is an interface for token refresh functionality 76type Refresher interface { 77 Refresh() error 78 RefreshExchange(resource string) error 79 EnsureFresh() error 80} 81 82// RefresherWithContext is an interface for token refresh functionality 83type RefresherWithContext interface { 84 RefreshWithContext(ctx context.Context) error 85 RefreshExchangeWithContext(ctx context.Context, resource string) error 86 EnsureFreshWithContext(ctx context.Context) error 87} 88 89// TokenRefreshCallback is the type representing callbacks that will be called after 90// a successful token refresh 91type TokenRefreshCallback func(Token) error 92 93// Token encapsulates the access token used to authorize Azure requests. 94type Token struct { 95 AccessToken string `json:"access_token"` 96 RefreshToken string `json:"refresh_token"` 97 98 ExpiresIn string `json:"expires_in"` 99 ExpiresOn string `json:"expires_on"` 100 NotBefore string `json:"not_before"` 101 102 Resource string `json:"resource"` 103 Type string `json:"token_type"` 104} 105 106// IsZero returns true if the token object is zero-initialized. 107func (t Token) IsZero() bool { 108 return t == Token{} 109} 110 111// Expires returns the time.Time when the Token expires. 112func (t Token) Expires() time.Time { 113 s, err := strconv.Atoi(t.ExpiresOn) 114 if err != nil { 115 s = -3600 116 } 117 118 expiration := date.NewUnixTimeFromSeconds(float64(s)) 119 120 return time.Time(expiration).UTC() 121} 122 123// IsExpired returns true if the Token is expired, false otherwise. 
124func (t Token) IsExpired() bool { 125 return t.WillExpireIn(0) 126} 127 128// WillExpireIn returns true if the Token will expire after the passed time.Duration interval 129// from now, false otherwise. 130func (t Token) WillExpireIn(d time.Duration) bool { 131 return !t.Expires().After(time.Now().Add(d)) 132} 133 134//OAuthToken return the current access token 135func (t *Token) OAuthToken() string { 136 return t.AccessToken 137} 138 139// ServicePrincipalNoSecret represents a secret type that contains no secret 140// meaning it is not valid for fetching a fresh token. This is used by Manual 141type ServicePrincipalNoSecret struct { 142} 143 144// SetAuthenticationValues is a method of the interface ServicePrincipalSecret 145// It only returns an error for the ServicePrincipalNoSecret type 146func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 147 return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token") 148} 149 150// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form 151// that is submitted when acquiring an oAuth token. 152type ServicePrincipalSecret interface { 153 SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error 154} 155 156// ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization. 157type ServicePrincipalTokenSecret struct { 158 ClientSecret string 159} 160 161// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. 162// It will populate the form submitted during oAuth Token Acquisition using the client_secret. 
163func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 164 v.Set("client_secret", tokenSecret.ClientSecret) 165 return nil 166} 167 168// ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs. 169type ServicePrincipalCertificateSecret struct { 170 Certificate *x509.Certificate 171 PrivateKey *rsa.PrivateKey 172} 173 174// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension. 175type ServicePrincipalMSISecret struct { 176} 177 178// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth. 179type ServicePrincipalUsernamePasswordSecret struct { 180 Username string 181 Password string 182} 183 184// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth. 185type ServicePrincipalAuthorizationCodeSecret struct { 186 ClientSecret string 187 AuthorizationCode string 188 RedirectURI string 189} 190 191// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. 192func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 193 v.Set("code", secret.AuthorizationCode) 194 v.Set("client_secret", secret.ClientSecret) 195 v.Set("redirect_uri", secret.RedirectURI) 196 return nil 197} 198 199// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. 200func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 201 v.Set("username", secret.Username) 202 v.Set("password", secret.Password) 203 return nil 204} 205 206// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. 
207func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 208 return nil 209} 210 211// SignJwt returns the JWT signed with the certificate's private key. 212func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) { 213 hasher := sha1.New() 214 _, err := hasher.Write(secret.Certificate.Raw) 215 if err != nil { 216 return "", err 217 } 218 219 thumbprint := base64.URLEncoding.EncodeToString(hasher.Sum(nil)) 220 221 // The jti (JWT ID) claim provides a unique identifier for the JWT. 222 jti := make([]byte, 20) 223 _, err = rand.Read(jti) 224 if err != nil { 225 return "", err 226 } 227 228 token := jwt.New(jwt.SigningMethodRS256) 229 token.Header["x5t"] = thumbprint 230 token.Claims = jwt.MapClaims{ 231 "aud": spt.oauthConfig.TokenEndpoint.String(), 232 "iss": spt.clientID, 233 "sub": spt.clientID, 234 "jti": base64.URLEncoding.EncodeToString(jti), 235 "nbf": time.Now().Unix(), 236 "exp": time.Now().Add(time.Hour * 24).Unix(), 237 } 238 239 signedString, err := token.SignedString(secret.PrivateKey) 240 return signedString, err 241} 242 243// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. 244// It will populate the form submitted during oAuth Token Acquisition using a JWT signed with a certificate. 245func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 246 jwt, err := secret.SignJwt(spt) 247 if err != nil { 248 return err 249 } 250 251 v.Set("client_assertion", jwt) 252 v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") 253 return nil 254} 255 256// ServicePrincipalToken encapsulates a Token created for a Service Principal. 
257type ServicePrincipalToken struct { 258 token Token 259 secret ServicePrincipalSecret 260 oauthConfig OAuthConfig 261 clientID string 262 resource string 263 autoRefresh bool 264 refreshLock *sync.RWMutex 265 refreshWithin time.Duration 266 sender Sender 267 268 refreshCallbacks []TokenRefreshCallback 269} 270 271func validateOAuthConfig(oac OAuthConfig) error { 272 if oac.IsZero() { 273 return fmt.Errorf("parameter 'oauthConfig' cannot be zero-initialized") 274 } 275 return nil 276} 277 278// NewServicePrincipalTokenWithSecret create a ServicePrincipalToken using the supplied ServicePrincipalSecret implementation. 279func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 280 if err := validateOAuthConfig(oauthConfig); err != nil { 281 return nil, err 282 } 283 if err := validateStringParam(id, "id"); err != nil { 284 return nil, err 285 } 286 if err := validateStringParam(resource, "resource"); err != nil { 287 return nil, err 288 } 289 if secret == nil { 290 return nil, fmt.Errorf("parameter 'secret' cannot be nil") 291 } 292 spt := &ServicePrincipalToken{ 293 oauthConfig: oauthConfig, 294 secret: secret, 295 clientID: id, 296 resource: resource, 297 autoRefresh: true, 298 refreshLock: &sync.RWMutex{}, 299 refreshWithin: defaultRefresh, 300 sender: &http.Client{}, 301 refreshCallbacks: callbacks, 302 } 303 return spt, nil 304} 305 306// NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken using the supplied token 307func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 308 if err := validateOAuthConfig(oauthConfig); err != nil { 309 return nil, err 310 } 311 if err := validateStringParam(clientID, "clientID"); err != nil { 312 return nil, err 313 } 314 if err := 
validateStringParam(resource, "resource"); err != nil { 315 return nil, err 316 } 317 if token.IsZero() { 318 return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized") 319 } 320 spt, err := NewServicePrincipalTokenWithSecret( 321 oauthConfig, 322 clientID, 323 resource, 324 &ServicePrincipalNoSecret{}, 325 callbacks...) 326 if err != nil { 327 return nil, err 328 } 329 330 spt.token = token 331 332 return spt, nil 333} 334 335// NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal 336// credentials scoped to the named resource. 337func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 338 if err := validateOAuthConfig(oauthConfig); err != nil { 339 return nil, err 340 } 341 if err := validateStringParam(clientID, "clientID"); err != nil { 342 return nil, err 343 } 344 if err := validateStringParam(secret, "secret"); err != nil { 345 return nil, err 346 } 347 if err := validateStringParam(resource, "resource"); err != nil { 348 return nil, err 349 } 350 return NewServicePrincipalTokenWithSecret( 351 oauthConfig, 352 clientID, 353 resource, 354 &ServicePrincipalTokenSecret{ 355 ClientSecret: secret, 356 }, 357 callbacks..., 358 ) 359} 360 361// NewServicePrincipalTokenFromCertificate creates a ServicePrincipalToken from the supplied pkcs12 bytes. 
362func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 363 if err := validateOAuthConfig(oauthConfig); err != nil { 364 return nil, err 365 } 366 if err := validateStringParam(clientID, "clientID"); err != nil { 367 return nil, err 368 } 369 if err := validateStringParam(resource, "resource"); err != nil { 370 return nil, err 371 } 372 if certificate == nil { 373 return nil, fmt.Errorf("parameter 'certificate' cannot be nil") 374 } 375 if privateKey == nil { 376 return nil, fmt.Errorf("parameter 'privateKey' cannot be nil") 377 } 378 return NewServicePrincipalTokenWithSecret( 379 oauthConfig, 380 clientID, 381 resource, 382 &ServicePrincipalCertificateSecret{ 383 PrivateKey: privateKey, 384 Certificate: certificate, 385 }, 386 callbacks..., 387 ) 388} 389 390// NewServicePrincipalTokenFromUsernamePassword creates a ServicePrincipalToken from the username and password. 
391func NewServicePrincipalTokenFromUsernamePassword(oauthConfig OAuthConfig, clientID string, username string, password string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 392 if err := validateOAuthConfig(oauthConfig); err != nil { 393 return nil, err 394 } 395 if err := validateStringParam(clientID, "clientID"); err != nil { 396 return nil, err 397 } 398 if err := validateStringParam(username, "username"); err != nil { 399 return nil, err 400 } 401 if err := validateStringParam(password, "password"); err != nil { 402 return nil, err 403 } 404 if err := validateStringParam(resource, "resource"); err != nil { 405 return nil, err 406 } 407 return NewServicePrincipalTokenWithSecret( 408 oauthConfig, 409 clientID, 410 resource, 411 &ServicePrincipalUsernamePasswordSecret{ 412 Username: username, 413 Password: password, 414 }, 415 callbacks..., 416 ) 417} 418 419// NewServicePrincipalTokenFromAuthorizationCode creates a ServicePrincipalToken from the 420func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clientID string, clientSecret string, authorizationCode string, redirectURI string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 421 422 if err := validateOAuthConfig(oauthConfig); err != nil { 423 return nil, err 424 } 425 if err := validateStringParam(clientID, "clientID"); err != nil { 426 return nil, err 427 } 428 if err := validateStringParam(clientSecret, "clientSecret"); err != nil { 429 return nil, err 430 } 431 if err := validateStringParam(authorizationCode, "authorizationCode"); err != nil { 432 return nil, err 433 } 434 if err := validateStringParam(redirectURI, "redirectURI"); err != nil { 435 return nil, err 436 } 437 if err := validateStringParam(resource, "resource"); err != nil { 438 return nil, err 439 } 440 441 return NewServicePrincipalTokenWithSecret( 442 oauthConfig, 443 clientID, 444 resource, 445 &ServicePrincipalAuthorizationCodeSecret{ 
446 ClientSecret: clientSecret, 447 AuthorizationCode: authorizationCode, 448 RedirectURI: redirectURI, 449 }, 450 callbacks..., 451 ) 452} 453 454// GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines. 455func GetMSIVMEndpoint() (string, error) { 456 return msiEndpoint, nil 457} 458 459// NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension. 460// It will use the system assigned identity when creating the token. 461func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 462 return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, callbacks...) 463} 464 465// NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension. 466// It will use the specified user assigned identity when creating the token. 467func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 468 return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &userAssignedID, callbacks...) 
469} 470 471func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 472 if err := validateStringParam(msiEndpoint, "msiEndpoint"); err != nil { 473 return nil, err 474 } 475 if err := validateStringParam(resource, "resource"); err != nil { 476 return nil, err 477 } 478 if userAssignedID != nil { 479 if err := validateStringParam(*userAssignedID, "userAssignedID"); err != nil { 480 return nil, err 481 } 482 } 483 // We set the oauth config token endpoint to be MSI's endpoint 484 msiEndpointURL, err := url.Parse(msiEndpoint) 485 if err != nil { 486 return nil, err 487 } 488 489 v := url.Values{} 490 v.Set("resource", resource) 491 v.Set("api-version", "2018-02-01") 492 if userAssignedID != nil { 493 v.Set("client_id", *userAssignedID) 494 } 495 msiEndpointURL.RawQuery = v.Encode() 496 497 spt := &ServicePrincipalToken{ 498 oauthConfig: OAuthConfig{ 499 TokenEndpoint: *msiEndpointURL, 500 }, 501 secret: &ServicePrincipalMSISecret{}, 502 resource: resource, 503 autoRefresh: true, 504 refreshLock: &sync.RWMutex{}, 505 refreshWithin: defaultRefresh, 506 sender: &http.Client{}, 507 refreshCallbacks: callbacks, 508 } 509 510 if userAssignedID != nil { 511 spt.clientID = *userAssignedID 512 } 513 514 return spt, nil 515} 516 517// internal type that implements TokenRefreshError 518type tokenRefreshError struct { 519 message string 520 resp *http.Response 521} 522 523// Error implements the error interface which is part of the TokenRefreshError interface. 524func (tre tokenRefreshError) Error() string { 525 return tre.message 526} 527 528// Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation. 
529func (tre tokenRefreshError) Response() *http.Response { 530 return tre.resp 531} 532 533func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError { 534 return tokenRefreshError{message: message, resp: resp} 535} 536 537// EnsureFresh will refresh the token if it will expire within the refresh window (as set by 538// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use. 539func (spt *ServicePrincipalToken) EnsureFresh() error { 540 return spt.EnsureFreshWithContext(context.Background()) 541} 542 543// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by 544// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use. 545func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error { 546 if spt.autoRefresh && spt.token.WillExpireIn(spt.refreshWithin) { 547 // take the write lock then check to see if the token was already refreshed 548 spt.refreshLock.Lock() 549 defer spt.refreshLock.Unlock() 550 if spt.token.WillExpireIn(spt.refreshWithin) { 551 return spt.refreshInternal(ctx, spt.resource) 552 } 553 } 554 return nil 555} 556 557// InvokeRefreshCallbacks calls any TokenRefreshCallbacks that were added to the SPT during initialization 558func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error { 559 if spt.refreshCallbacks != nil { 560 for _, callback := range spt.refreshCallbacks { 561 err := callback(spt.token) 562 if err != nil { 563 return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err) 564 } 565 } 566 } 567 return nil 568} 569 570// Refresh obtains a fresh token for the Service Principal. 571// This method is not safe for concurrent use and should be syncrhonized. 572func (spt *ServicePrincipalToken) Refresh() error { 573 return spt.RefreshWithContext(context.Background()) 574} 575 576// RefreshWithContext obtains a fresh token for the Service Principal. 
577// This method is not safe for concurrent use and should be syncrhonized. 578func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error { 579 spt.refreshLock.Lock() 580 defer spt.refreshLock.Unlock() 581 return spt.refreshInternal(ctx, spt.resource) 582} 583 584// RefreshExchange refreshes the token, but for a different resource. 585// This method is not safe for concurrent use and should be syncrhonized. 586func (spt *ServicePrincipalToken) RefreshExchange(resource string) error { 587 return spt.RefreshExchangeWithContext(context.Background(), resource) 588} 589 590// RefreshExchangeWithContext refreshes the token, but for a different resource. 591// This method is not safe for concurrent use and should be syncrhonized. 592func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error { 593 spt.refreshLock.Lock() 594 defer spt.refreshLock.Unlock() 595 return spt.refreshInternal(ctx, resource) 596} 597 598func (spt *ServicePrincipalToken) getGrantType() string { 599 switch spt.secret.(type) { 600 case *ServicePrincipalUsernamePasswordSecret: 601 return OAuthGrantTypeUserPass 602 case *ServicePrincipalAuthorizationCodeSecret: 603 return OAuthGrantTypeAuthorizationCode 604 default: 605 return OAuthGrantTypeClientCredentials 606 } 607} 608 609func isIMDS(u url.URL) bool { 610 imds, err := url.Parse(msiEndpoint) 611 if err != nil { 612 return false 613 } 614 return u.Host == imds.Host && u.Path == imds.Path 615} 616 617func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error { 618 req, err := http.NewRequest(http.MethodPost, spt.oauthConfig.TokenEndpoint.String(), nil) 619 if err != nil { 620 return fmt.Errorf("adal: Failed to build the refresh request. 
Error = '%v'", err) 621 } 622 req = req.WithContext(ctx) 623 if !isIMDS(spt.oauthConfig.TokenEndpoint) { 624 v := url.Values{} 625 v.Set("client_id", spt.clientID) 626 v.Set("resource", resource) 627 628 if spt.token.RefreshToken != "" { 629 v.Set("grant_type", OAuthGrantTypeRefreshToken) 630 v.Set("refresh_token", spt.token.RefreshToken) 631 } else { 632 v.Set("grant_type", spt.getGrantType()) 633 err := spt.secret.SetAuthenticationValues(spt, &v) 634 if err != nil { 635 return err 636 } 637 } 638 639 s := v.Encode() 640 body := ioutil.NopCloser(strings.NewReader(s)) 641 req.ContentLength = int64(len(s)) 642 req.Header.Set(contentType, mimeTypeFormPost) 643 req.Body = body 644 } 645 646 if _, ok := spt.secret.(*ServicePrincipalMSISecret); ok { 647 req.Method = http.MethodGet 648 req.Header.Set(metadataHeader, "true") 649 } 650 651 var resp *http.Response 652 if isIMDS(spt.oauthConfig.TokenEndpoint) { 653 resp, err = retry(spt.sender, req) 654 } else { 655 resp, err = spt.sender.Do(req) 656 } 657 if err != nil { 658 return newTokenRefreshError(fmt.Sprintf("adal: Failed to execute the refresh request. Error = '%v'", err), nil) 659 } 660 661 defer resp.Body.Close() 662 rb, err := ioutil.ReadAll(resp.Body) 663 664 if resp.StatusCode != http.StatusOK { 665 if err != nil { 666 return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v", resp.StatusCode, err), resp) 667 } 668 return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb)), resp) 669 } 670 671 // for the following error cases don't return a TokenRefreshError. the operation succeeded 672 // but some transient failure happened during deserialization. by returning a generic error 673 // the retry logic will kick in (we don't retry on TokenRefreshError). 674 675 if err != nil { 676 return fmt.Errorf("adal: Failed to read a new service principal token during refresh. 
Error = '%v'", err) 677 } 678 if len(strings.Trim(string(rb), " ")) == 0 { 679 return fmt.Errorf("adal: Empty service principal token received during refresh") 680 } 681 var token Token 682 err = json.Unmarshal(rb, &token) 683 if err != nil { 684 return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb)) 685 } 686 687 spt.token = token 688 689 return spt.InvokeRefreshCallbacks(token) 690} 691 692func retry(sender Sender, req *http.Request) (resp *http.Response, err error) { 693 retries := []int{ 694 http.StatusRequestTimeout, // 408 695 http.StatusTooManyRequests, // 429 696 http.StatusInternalServerError, // 500 697 http.StatusBadGateway, // 502 698 http.StatusServiceUnavailable, // 503 699 http.StatusGatewayTimeout, // 504 700 } 701 // Extra retry status codes requered 702 retries = append(retries, http.StatusNotFound, 703 // all remaining 5xx 704 http.StatusNotImplemented, 705 http.StatusHTTPVersionNotSupported, 706 http.StatusVariantAlsoNegotiates, 707 http.StatusInsufficientStorage, 708 http.StatusLoopDetected, 709 http.StatusNotExtended, 710 http.StatusNetworkAuthenticationRequired) 711 712 attempt := 0 713 maxAttempts := 5 714 715 for attempt < maxAttempts { 716 resp, err = sender.Do(req) 717 // retry on temporary network errors, e.g. transient network failures. 
718 if (err != nil && !isTemporaryNetworkError(err)) || resp.StatusCode == http.StatusOK || !containsInt(retries, resp.StatusCode) { 719 return 720 } 721 722 if !delay(resp, req.Context().Done()) { 723 select { 724 case <-time.After(time.Second): 725 attempt++ 726 case <-req.Context().Done(): 727 err = req.Context().Err() 728 return 729 } 730 } 731 } 732 return 733} 734 735func isTemporaryNetworkError(err error) bool { 736 if netErr, ok := err.(net.Error); ok && netErr.Temporary() { 737 return true 738 } 739 return false 740} 741 742func containsInt(ints []int, n int) bool { 743 for _, i := range ints { 744 if i == n { 745 return true 746 } 747 } 748 return false 749} 750 751func delay(resp *http.Response, cancel <-chan struct{}) bool { 752 if resp == nil { 753 return false 754 } 755 retryAfter, _ := strconv.Atoi(resp.Header.Get("Retry-After")) 756 if resp.StatusCode == http.StatusTooManyRequests && retryAfter > 0 { 757 select { 758 case <-time.After(time.Duration(retryAfter) * time.Second): 759 return true 760 case <-cancel: 761 return false 762 } 763 } 764 return false 765} 766 767// SetAutoRefresh enables or disables automatic refreshing of stale tokens. 768func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) { 769 spt.autoRefresh = autoRefresh 770} 771 772// SetRefreshWithin sets the interval within which if the token will expire, EnsureFresh will 773// refresh the token. 774func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) { 775 spt.refreshWithin = d 776 return 777} 778 779// SetSender sets the http.Client used when obtaining the Service Principal token. An 780// undecorated http.Client is used by default. 781func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s } 782 783// OAuthToken implements the OAuthTokenProvider interface. It returns the current access token. 
784func (spt *ServicePrincipalToken) OAuthToken() string { 785 spt.refreshLock.RLock() 786 defer spt.refreshLock.RUnlock() 787 return spt.token.OAuthToken() 788} 789 790// Token returns a copy of the current token. 791func (spt *ServicePrincipalToken) Token() Token { 792 spt.refreshLock.RLock() 793 defer spt.refreshLock.RUnlock() 794 return spt.token 795} 796