package adal

// Copyright 2017 Microsoft Corporation
//
//  Licensed under the Apache License, Version 2.0 (the "License");
//  you may not use this file except in compliance with the License.
//  You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.

import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha1"
	"crypto/x509"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"math"
	"net"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/Azure/go-autorest/autorest/date"
	"github.com/Azure/go-autorest/tracing"
	"github.com/dgrijalva/jwt-go"
)

const (
	// defaultRefresh is the RefreshWithin window assigned by the constructors in this file:
	// EnsureFresh refreshes a token that will expire within this duration.
	defaultRefresh = 5 * time.Minute

	// OAuthGrantTypeDeviceCode is the "grant_type" identifier used in device flow
	OAuthGrantTypeDeviceCode = "device_code"

	// OAuthGrantTypeClientCredentials is the "grant_type" identifier used in credential flows
	OAuthGrantTypeClientCredentials = "client_credentials"

	// OAuthGrantTypeUserPass is the "grant_type" identifier used in username and password auth flows
	OAuthGrantTypeUserPass = "password"

	// OAuthGrantTypeRefreshToken is the "grant_type" identifier used in refresh token flows
	OAuthGrantTypeRefreshToken = "refresh_token"

	// OAuthGrantTypeAuthorizationCode is the "grant_type" identifier used in authorization code flows
	OAuthGrantTypeAuthorizationCode = "authorization_code"

	// metadataHeader is the header required by the MSI extension; token requests made with an
	// MSI secret set it to "true".
	metadataHeader = "Metadata"

	// msiEndpoint is the well-known endpoint for getting MSI authentication tokens. A token
	// endpoint with the same host and path as this one is treated as IMDS and gets the
	// IMDS-specific retry behaviour.
	msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"

	// defaultMaxMSIRefreshAttempts is the default number of attempts to refresh an MSI
	// authentication token.
	defaultMaxMSIRefreshAttempts = 5
)

// OAuthTokenProvider is an interface which should be implemented by an access token retriever
type OAuthTokenProvider interface {
	OAuthToken() string
}

// TokenRefreshError is an interface used by errors returned during token refresh.
// Response returns the HTTP response that caused the failure; it may be nil when no
// response was received.
type TokenRefreshError interface {
	error
	Response() *http.Response
}

// Refresher is an interface for token refresh functionality
type Refresher interface {
	Refresh() error
	RefreshExchange(resource string) error
	EnsureFresh() error
}

// RefresherWithContext is an interface for token refresh functionality
type RefresherWithContext interface {
	RefreshWithContext(ctx context.Context) error
	RefreshExchangeWithContext(ctx context.Context, resource string) error
	EnsureFreshWithContext(ctx context.Context) error
}

// TokenRefreshCallback is the type representing callbacks that will be called after
// a successful token refresh. A non-nil error returned by a callback is reported as the
// refresh's error.
type TokenRefreshCallback func(Token) error

// Token encapsulates the access token used to authorize Azure requests.
// https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response
type Token struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`

	ExpiresIn json.Number `json:"expires_in"`
	// ExpiresOn is parsed by Expires() as a number of Unix seconds.
	ExpiresOn json.Number `json:"expires_on"`
	NotBefore json.Number `json:"not_before"`

	Resource string `json:"resource"`
	Type     string `json:"token_type"`
}

// newToken returns a Token whose numeric fields hold "0" instead of "", so ExpiresOn always
// parses (as the Unix epoch, i.e. already expired). Such a Token is not IsZero().
func newToken() Token {
	return Token{
		ExpiresIn: "0",
		ExpiresOn: "0",
		NotBefore: "0",
	}
}

// IsZero returns true if the token object is zero-initialized.
121func (t Token) IsZero() bool { 122 return t == Token{} 123} 124 125// Expires returns the time.Time when the Token expires. 126func (t Token) Expires() time.Time { 127 s, err := t.ExpiresOn.Float64() 128 if err != nil { 129 s = -3600 130 } 131 132 expiration := date.NewUnixTimeFromSeconds(s) 133 134 return time.Time(expiration).UTC() 135} 136 137// IsExpired returns true if the Token is expired, false otherwise. 138func (t Token) IsExpired() bool { 139 return t.WillExpireIn(0) 140} 141 142// WillExpireIn returns true if the Token will expire after the passed time.Duration interval 143// from now, false otherwise. 144func (t Token) WillExpireIn(d time.Duration) bool { 145 return !t.Expires().After(time.Now().Add(d)) 146} 147 148//OAuthToken return the current access token 149func (t *Token) OAuthToken() string { 150 return t.AccessToken 151} 152 153// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form 154// that is submitted when acquiring an oAuth token. 155type ServicePrincipalSecret interface { 156 SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error 157} 158 159// ServicePrincipalNoSecret represents a secret type that contains no secret 160// meaning it is not valid for fetching a fresh token. This is used by Manual 161type ServicePrincipalNoSecret struct { 162} 163 164// SetAuthenticationValues is a method of the interface ServicePrincipalSecret 165// It only returns an error for the ServicePrincipalNoSecret type 166func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 167 return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token") 168} 169 170// MarshalJSON implements the json.Marshaler interface. 
171func (noSecret ServicePrincipalNoSecret) MarshalJSON() ([]byte, error) { 172 type tokenType struct { 173 Type string `json:"type"` 174 } 175 return json.Marshal(tokenType{ 176 Type: "ServicePrincipalNoSecret", 177 }) 178} 179 180// ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization. 181type ServicePrincipalTokenSecret struct { 182 ClientSecret string `json:"value"` 183} 184 185// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. 186// It will populate the form submitted during oAuth Token Acquisition using the client_secret. 187func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 188 v.Set("client_secret", tokenSecret.ClientSecret) 189 return nil 190} 191 192// MarshalJSON implements the json.Marshaler interface. 193func (tokenSecret ServicePrincipalTokenSecret) MarshalJSON() ([]byte, error) { 194 type tokenType struct { 195 Type string `json:"type"` 196 Value string `json:"value"` 197 } 198 return json.Marshal(tokenType{ 199 Type: "ServicePrincipalTokenSecret", 200 Value: tokenSecret.ClientSecret, 201 }) 202} 203 204// ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs. 205type ServicePrincipalCertificateSecret struct { 206 Certificate *x509.Certificate 207 PrivateKey *rsa.PrivateKey 208} 209 210// SignJwt returns the JWT signed with the certificate's private key. 211func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) { 212 hasher := sha1.New() 213 _, err := hasher.Write(secret.Certificate.Raw) 214 if err != nil { 215 return "", err 216 } 217 218 thumbprint := base64.URLEncoding.EncodeToString(hasher.Sum(nil)) 219 220 // The jti (JWT ID) claim provides a unique identifier for the JWT. 
221 jti := make([]byte, 20) 222 _, err = rand.Read(jti) 223 if err != nil { 224 return "", err 225 } 226 227 token := jwt.New(jwt.SigningMethodRS256) 228 token.Header["x5t"] = thumbprint 229 x5c := []string{base64.StdEncoding.EncodeToString(secret.Certificate.Raw)} 230 token.Header["x5c"] = x5c 231 token.Claims = jwt.MapClaims{ 232 "aud": spt.inner.OauthConfig.TokenEndpoint.String(), 233 "iss": spt.inner.ClientID, 234 "sub": spt.inner.ClientID, 235 "jti": base64.URLEncoding.EncodeToString(jti), 236 "nbf": time.Now().Unix(), 237 "exp": time.Now().Add(time.Hour * 24).Unix(), 238 } 239 240 signedString, err := token.SignedString(secret.PrivateKey) 241 return signedString, err 242} 243 244// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. 245// It will populate the form submitted during oAuth Token Acquisition using a JWT signed with a certificate. 246func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 247 jwt, err := secret.SignJwt(spt) 248 if err != nil { 249 return err 250 } 251 252 v.Set("client_assertion", jwt) 253 v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") 254 return nil 255} 256 257// MarshalJSON implements the json.Marshaler interface. 258func (secret ServicePrincipalCertificateSecret) MarshalJSON() ([]byte, error) { 259 return nil, errors.New("marshalling ServicePrincipalCertificateSecret is not supported") 260} 261 262// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension. 263type ServicePrincipalMSISecret struct { 264} 265 266// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. 267func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 268 return nil 269} 270 271// MarshalJSON implements the json.Marshaler interface. 
272func (msiSecret ServicePrincipalMSISecret) MarshalJSON() ([]byte, error) { 273 return nil, errors.New("marshalling ServicePrincipalMSISecret is not supported") 274} 275 276// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth. 277type ServicePrincipalUsernamePasswordSecret struct { 278 Username string `json:"username"` 279 Password string `json:"password"` 280} 281 282// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. 283func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 284 v.Set("username", secret.Username) 285 v.Set("password", secret.Password) 286 return nil 287} 288 289// MarshalJSON implements the json.Marshaler interface. 290func (secret ServicePrincipalUsernamePasswordSecret) MarshalJSON() ([]byte, error) { 291 type tokenType struct { 292 Type string `json:"type"` 293 Username string `json:"username"` 294 Password string `json:"password"` 295 } 296 return json.Marshal(tokenType{ 297 Type: "ServicePrincipalUsernamePasswordSecret", 298 Username: secret.Username, 299 Password: secret.Password, 300 }) 301} 302 303// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth. 304type ServicePrincipalAuthorizationCodeSecret struct { 305 ClientSecret string `json:"value"` 306 AuthorizationCode string `json:"authCode"` 307 RedirectURI string `json:"redirect"` 308} 309 310// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. 311func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 312 v.Set("code", secret.AuthorizationCode) 313 v.Set("client_secret", secret.ClientSecret) 314 v.Set("redirect_uri", secret.RedirectURI) 315 return nil 316} 317 318// MarshalJSON implements the json.Marshaler interface. 
319func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, error) { 320 type tokenType struct { 321 Type string `json:"type"` 322 Value string `json:"value"` 323 AuthCode string `json:"authCode"` 324 Redirect string `json:"redirect"` 325 } 326 return json.Marshal(tokenType{ 327 Type: "ServicePrincipalAuthorizationCodeSecret", 328 Value: secret.ClientSecret, 329 AuthCode: secret.AuthorizationCode, 330 Redirect: secret.RedirectURI, 331 }) 332} 333 334// ServicePrincipalToken encapsulates a Token created for a Service Principal. 335type ServicePrincipalToken struct { 336 inner servicePrincipalToken 337 refreshLock *sync.RWMutex 338 sender Sender 339 refreshCallbacks []TokenRefreshCallback 340 // MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token. 341 MaxMSIRefreshAttempts int 342} 343 344// MarshalTokenJSON returns the marshalled inner token. 345func (spt ServicePrincipalToken) MarshalTokenJSON() ([]byte, error) { 346 return json.Marshal(spt.inner.Token) 347} 348 349// SetRefreshCallbacks replaces any existing refresh callbacks with the specified callbacks. 350func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCallback) { 351 spt.refreshCallbacks = callbacks 352} 353 354// MarshalJSON implements the json.Marshaler interface. 355func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) { 356 return json.Marshal(spt.inner) 357} 358 359// UnmarshalJSON implements the json.Unmarshaler interface. 
360func (spt *ServicePrincipalToken) UnmarshalJSON(data []byte) error { 361 // need to determine the token type 362 raw := map[string]interface{}{} 363 err := json.Unmarshal(data, &raw) 364 if err != nil { 365 return err 366 } 367 secret := raw["secret"].(map[string]interface{}) 368 switch secret["type"] { 369 case "ServicePrincipalNoSecret": 370 spt.inner.Secret = &ServicePrincipalNoSecret{} 371 case "ServicePrincipalTokenSecret": 372 spt.inner.Secret = &ServicePrincipalTokenSecret{} 373 case "ServicePrincipalCertificateSecret": 374 return errors.New("unmarshalling ServicePrincipalCertificateSecret is not supported") 375 case "ServicePrincipalMSISecret": 376 return errors.New("unmarshalling ServicePrincipalMSISecret is not supported") 377 case "ServicePrincipalUsernamePasswordSecret": 378 spt.inner.Secret = &ServicePrincipalUsernamePasswordSecret{} 379 case "ServicePrincipalAuthorizationCodeSecret": 380 spt.inner.Secret = &ServicePrincipalAuthorizationCodeSecret{} 381 default: 382 return fmt.Errorf("unrecognized token type '%s'", secret["type"]) 383 } 384 err = json.Unmarshal(data, &spt.inner) 385 if err != nil { 386 return err 387 } 388 // Don't override the refreshLock or the sender if those have been already set. 
389 if spt.refreshLock == nil { 390 spt.refreshLock = &sync.RWMutex{} 391 } 392 if spt.sender == nil { 393 spt.sender = &http.Client{Transport: tracing.Transport} 394 } 395 return nil 396} 397 398// internal type used for marshalling/unmarshalling 399type servicePrincipalToken struct { 400 Token Token `json:"token"` 401 Secret ServicePrincipalSecret `json:"secret"` 402 OauthConfig OAuthConfig `json:"oauth"` 403 ClientID string `json:"clientID"` 404 Resource string `json:"resource"` 405 AutoRefresh bool `json:"autoRefresh"` 406 RefreshWithin time.Duration `json:"refreshWithin"` 407} 408 409func validateOAuthConfig(oac OAuthConfig) error { 410 if oac.IsZero() { 411 return fmt.Errorf("parameter 'oauthConfig' cannot be zero-initialized") 412 } 413 return nil 414} 415 416// NewServicePrincipalTokenWithSecret create a ServicePrincipalToken using the supplied ServicePrincipalSecret implementation. 417func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 418 if err := validateOAuthConfig(oauthConfig); err != nil { 419 return nil, err 420 } 421 if err := validateStringParam(id, "id"); err != nil { 422 return nil, err 423 } 424 if err := validateStringParam(resource, "resource"); err != nil { 425 return nil, err 426 } 427 if secret == nil { 428 return nil, fmt.Errorf("parameter 'secret' cannot be nil") 429 } 430 spt := &ServicePrincipalToken{ 431 inner: servicePrincipalToken{ 432 Token: newToken(), 433 OauthConfig: oauthConfig, 434 Secret: secret, 435 ClientID: id, 436 Resource: resource, 437 AutoRefresh: true, 438 RefreshWithin: defaultRefresh, 439 }, 440 refreshLock: &sync.RWMutex{}, 441 sender: &http.Client{Transport: tracing.Transport}, 442 refreshCallbacks: callbacks, 443 } 444 return spt, nil 445} 446 447// NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken using the supplied token 448func 
NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 449 if err := validateOAuthConfig(oauthConfig); err != nil { 450 return nil, err 451 } 452 if err := validateStringParam(clientID, "clientID"); err != nil { 453 return nil, err 454 } 455 if err := validateStringParam(resource, "resource"); err != nil { 456 return nil, err 457 } 458 if token.IsZero() { 459 return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized") 460 } 461 spt, err := NewServicePrincipalTokenWithSecret( 462 oauthConfig, 463 clientID, 464 resource, 465 &ServicePrincipalNoSecret{}, 466 callbacks...) 467 if err != nil { 468 return nil, err 469 } 470 471 spt.inner.Token = token 472 473 return spt, nil 474} 475 476// NewServicePrincipalTokenFromManualTokenSecret creates a ServicePrincipalToken using the supplied token and secret 477func NewServicePrincipalTokenFromManualTokenSecret(oauthConfig OAuthConfig, clientID string, resource string, token Token, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 478 if err := validateOAuthConfig(oauthConfig); err != nil { 479 return nil, err 480 } 481 if err := validateStringParam(clientID, "clientID"); err != nil { 482 return nil, err 483 } 484 if err := validateStringParam(resource, "resource"); err != nil { 485 return nil, err 486 } 487 if secret == nil { 488 return nil, fmt.Errorf("parameter 'secret' cannot be nil") 489 } 490 if token.IsZero() { 491 return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized") 492 } 493 spt, err := NewServicePrincipalTokenWithSecret( 494 oauthConfig, 495 clientID, 496 resource, 497 secret, 498 callbacks...) 
499 if err != nil { 500 return nil, err 501 } 502 503 spt.inner.Token = token 504 505 return spt, nil 506} 507 508// NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal 509// credentials scoped to the named resource. 510func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 511 if err := validateOAuthConfig(oauthConfig); err != nil { 512 return nil, err 513 } 514 if err := validateStringParam(clientID, "clientID"); err != nil { 515 return nil, err 516 } 517 if err := validateStringParam(secret, "secret"); err != nil { 518 return nil, err 519 } 520 if err := validateStringParam(resource, "resource"); err != nil { 521 return nil, err 522 } 523 return NewServicePrincipalTokenWithSecret( 524 oauthConfig, 525 clientID, 526 resource, 527 &ServicePrincipalTokenSecret{ 528 ClientSecret: secret, 529 }, 530 callbacks..., 531 ) 532} 533 534// NewServicePrincipalTokenFromCertificate creates a ServicePrincipalToken from the supplied pkcs12 bytes. 
535func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 536 if err := validateOAuthConfig(oauthConfig); err != nil { 537 return nil, err 538 } 539 if err := validateStringParam(clientID, "clientID"); err != nil { 540 return nil, err 541 } 542 if err := validateStringParam(resource, "resource"); err != nil { 543 return nil, err 544 } 545 if certificate == nil { 546 return nil, fmt.Errorf("parameter 'certificate' cannot be nil") 547 } 548 if privateKey == nil { 549 return nil, fmt.Errorf("parameter 'privateKey' cannot be nil") 550 } 551 return NewServicePrincipalTokenWithSecret( 552 oauthConfig, 553 clientID, 554 resource, 555 &ServicePrincipalCertificateSecret{ 556 PrivateKey: privateKey, 557 Certificate: certificate, 558 }, 559 callbacks..., 560 ) 561} 562 563// NewServicePrincipalTokenFromUsernamePassword creates a ServicePrincipalToken from the username and password. 
564func NewServicePrincipalTokenFromUsernamePassword(oauthConfig OAuthConfig, clientID string, username string, password string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 565 if err := validateOAuthConfig(oauthConfig); err != nil { 566 return nil, err 567 } 568 if err := validateStringParam(clientID, "clientID"); err != nil { 569 return nil, err 570 } 571 if err := validateStringParam(username, "username"); err != nil { 572 return nil, err 573 } 574 if err := validateStringParam(password, "password"); err != nil { 575 return nil, err 576 } 577 if err := validateStringParam(resource, "resource"); err != nil { 578 return nil, err 579 } 580 return NewServicePrincipalTokenWithSecret( 581 oauthConfig, 582 clientID, 583 resource, 584 &ServicePrincipalUsernamePasswordSecret{ 585 Username: username, 586 Password: password, 587 }, 588 callbacks..., 589 ) 590} 591 592// NewServicePrincipalTokenFromAuthorizationCode creates a ServicePrincipalToken from the 593func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clientID string, clientSecret string, authorizationCode string, redirectURI string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 594 595 if err := validateOAuthConfig(oauthConfig); err != nil { 596 return nil, err 597 } 598 if err := validateStringParam(clientID, "clientID"); err != nil { 599 return nil, err 600 } 601 if err := validateStringParam(clientSecret, "clientSecret"); err != nil { 602 return nil, err 603 } 604 if err := validateStringParam(authorizationCode, "authorizationCode"); err != nil { 605 return nil, err 606 } 607 if err := validateStringParam(redirectURI, "redirectURI"); err != nil { 608 return nil, err 609 } 610 if err := validateStringParam(resource, "resource"); err != nil { 611 return nil, err 612 } 613 614 return NewServicePrincipalTokenWithSecret( 615 oauthConfig, 616 clientID, 617 resource, 618 &ServicePrincipalAuthorizationCodeSecret{ 
619 ClientSecret: clientSecret, 620 AuthorizationCode: authorizationCode, 621 RedirectURI: redirectURI, 622 }, 623 callbacks..., 624 ) 625} 626 627// GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines. 628func GetMSIVMEndpoint() (string, error) { 629 return msiEndpoint, nil 630} 631 632// NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension. 633// It will use the system assigned identity when creating the token. 634func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 635 return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, callbacks...) 636} 637 638// NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension. 639// It will use the specified user assigned identity when creating the token. 640func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 641 return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &userAssignedID, callbacks...) 
642} 643 644func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 645 if err := validateStringParam(msiEndpoint, "msiEndpoint"); err != nil { 646 return nil, err 647 } 648 if err := validateStringParam(resource, "resource"); err != nil { 649 return nil, err 650 } 651 if userAssignedID != nil { 652 if err := validateStringParam(*userAssignedID, "userAssignedID"); err != nil { 653 return nil, err 654 } 655 } 656 // We set the oauth config token endpoint to be MSI's endpoint 657 msiEndpointURL, err := url.Parse(msiEndpoint) 658 if err != nil { 659 return nil, err 660 } 661 662 v := url.Values{} 663 v.Set("resource", resource) 664 v.Set("api-version", "2018-02-01") 665 if userAssignedID != nil { 666 v.Set("client_id", *userAssignedID) 667 } 668 msiEndpointURL.RawQuery = v.Encode() 669 670 spt := &ServicePrincipalToken{ 671 inner: servicePrincipalToken{ 672 Token: newToken(), 673 OauthConfig: OAuthConfig{ 674 TokenEndpoint: *msiEndpointURL, 675 }, 676 Secret: &ServicePrincipalMSISecret{}, 677 Resource: resource, 678 AutoRefresh: true, 679 RefreshWithin: defaultRefresh, 680 }, 681 refreshLock: &sync.RWMutex{}, 682 sender: &http.Client{Transport: tracing.Transport}, 683 refreshCallbacks: callbacks, 684 MaxMSIRefreshAttempts: defaultMaxMSIRefreshAttempts, 685 } 686 687 if userAssignedID != nil { 688 spt.inner.ClientID = *userAssignedID 689 } 690 691 return spt, nil 692} 693 694// internal type that implements TokenRefreshError 695type tokenRefreshError struct { 696 message string 697 resp *http.Response 698} 699 700// Error implements the error interface which is part of the TokenRefreshError interface. 701func (tre tokenRefreshError) Error() string { 702 return tre.message 703} 704 705// Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation. 
706func (tre tokenRefreshError) Response() *http.Response { 707 return tre.resp 708} 709 710func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError { 711 return tokenRefreshError{message: message, resp: resp} 712} 713 714// EnsureFresh will refresh the token if it will expire within the refresh window (as set by 715// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use. 716func (spt *ServicePrincipalToken) EnsureFresh() error { 717 return spt.EnsureFreshWithContext(context.Background()) 718} 719 720// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by 721// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use. 722func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error { 723 if spt.inner.AutoRefresh && spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) { 724 // take the write lock then check to see if the token was already refreshed 725 spt.refreshLock.Lock() 726 defer spt.refreshLock.Unlock() 727 if spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) { 728 return spt.refreshInternal(ctx, spt.inner.Resource) 729 } 730 } 731 return nil 732} 733 734// InvokeRefreshCallbacks calls any TokenRefreshCallbacks that were added to the SPT during initialization 735func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error { 736 if spt.refreshCallbacks != nil { 737 for _, callback := range spt.refreshCallbacks { 738 err := callback(spt.inner.Token) 739 if err != nil { 740 return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err) 741 } 742 } 743 } 744 return nil 745} 746 747// Refresh obtains a fresh token for the Service Principal. 748// This method is not safe for concurrent use and should be syncrhonized. 
749func (spt *ServicePrincipalToken) Refresh() error { 750 return spt.RefreshWithContext(context.Background()) 751} 752 753// RefreshWithContext obtains a fresh token for the Service Principal. 754// This method is not safe for concurrent use and should be syncrhonized. 755func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error { 756 spt.refreshLock.Lock() 757 defer spt.refreshLock.Unlock() 758 return spt.refreshInternal(ctx, spt.inner.Resource) 759} 760 761// RefreshExchange refreshes the token, but for a different resource. 762// This method is not safe for concurrent use and should be syncrhonized. 763func (spt *ServicePrincipalToken) RefreshExchange(resource string) error { 764 return spt.RefreshExchangeWithContext(context.Background(), resource) 765} 766 767// RefreshExchangeWithContext refreshes the token, but for a different resource. 768// This method is not safe for concurrent use and should be syncrhonized. 769func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error { 770 spt.refreshLock.Lock() 771 defer spt.refreshLock.Unlock() 772 return spt.refreshInternal(ctx, resource) 773} 774 775func (spt *ServicePrincipalToken) getGrantType() string { 776 switch spt.inner.Secret.(type) { 777 case *ServicePrincipalUsernamePasswordSecret: 778 return OAuthGrantTypeUserPass 779 case *ServicePrincipalAuthorizationCodeSecret: 780 return OAuthGrantTypeAuthorizationCode 781 default: 782 return OAuthGrantTypeClientCredentials 783 } 784} 785 786func isIMDS(u url.URL) bool { 787 imds, err := url.Parse(msiEndpoint) 788 if err != nil { 789 return false 790 } 791 return u.Host == imds.Host && u.Path == imds.Path 792} 793 794func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error { 795 req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil) 796 if err != nil { 797 return fmt.Errorf("adal: Failed to build the refresh request. 
Error = '%v'", err) 798 } 799 req.Header.Add("User-Agent", UserAgent()) 800 req = req.WithContext(ctx) 801 if !isIMDS(spt.inner.OauthConfig.TokenEndpoint) { 802 v := url.Values{} 803 v.Set("client_id", spt.inner.ClientID) 804 v.Set("resource", resource) 805 806 if spt.inner.Token.RefreshToken != "" { 807 v.Set("grant_type", OAuthGrantTypeRefreshToken) 808 v.Set("refresh_token", spt.inner.Token.RefreshToken) 809 // web apps must specify client_secret when refreshing tokens 810 // see https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-protocols-oauth-code#refreshing-the-access-tokens 811 if spt.getGrantType() == OAuthGrantTypeAuthorizationCode { 812 err := spt.inner.Secret.SetAuthenticationValues(spt, &v) 813 if err != nil { 814 return err 815 } 816 } 817 } else { 818 v.Set("grant_type", spt.getGrantType()) 819 err := spt.inner.Secret.SetAuthenticationValues(spt, &v) 820 if err != nil { 821 return err 822 } 823 } 824 825 s := v.Encode() 826 body := ioutil.NopCloser(strings.NewReader(s)) 827 req.ContentLength = int64(len(s)) 828 req.Header.Set(contentType, mimeTypeFormPost) 829 req.Body = body 830 } 831 832 if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok { 833 req.Method = http.MethodGet 834 req.Header.Set(metadataHeader, "true") 835 } 836 837 var resp *http.Response 838 if isIMDS(spt.inner.OauthConfig.TokenEndpoint) { 839 resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts) 840 } else { 841 resp, err = spt.sender.Do(req) 842 } 843 if err != nil { 844 return newTokenRefreshError(fmt.Sprintf("adal: Failed to execute the refresh request. Error = '%v'", err), nil) 845 } 846 847 defer resp.Body.Close() 848 rb, err := ioutil.ReadAll(resp.Body) 849 850 if resp.StatusCode != http.StatusOK { 851 if err != nil { 852 return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. 
Failed reading response body: %v", resp.StatusCode, err), resp) 853 } 854 return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb)), resp) 855 } 856 857 // for the following error cases don't return a TokenRefreshError. the operation succeeded 858 // but some transient failure happened during deserialization. by returning a generic error 859 // the retry logic will kick in (we don't retry on TokenRefreshError). 860 861 if err != nil { 862 return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err) 863 } 864 if len(strings.Trim(string(rb), " ")) == 0 { 865 return fmt.Errorf("adal: Empty service principal token received during refresh") 866 } 867 var token Token 868 err = json.Unmarshal(rb, &token) 869 if err != nil { 870 return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb)) 871 } 872 873 spt.inner.Token = token 874 875 return spt.InvokeRefreshCallbacks(token) 876} 877 878// retry logic specific to retrieving a token from the IMDS endpoint 879func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http.Response, err error) { 880 // copied from client.go due to circular dependency 881 retries := []int{ 882 http.StatusRequestTimeout, // 408 883 http.StatusTooManyRequests, // 429 884 http.StatusInternalServerError, // 500 885 http.StatusBadGateway, // 502 886 http.StatusServiceUnavailable, // 503 887 http.StatusGatewayTimeout, // 504 888 } 889 // extra retry status codes specific to IMDS 890 retries = append(retries, 891 http.StatusNotFound, 892 http.StatusGone, 893 // all remaining 5xx 894 http.StatusNotImplemented, 895 http.StatusHTTPVersionNotSupported, 896 http.StatusVariantAlsoNegotiates, 897 http.StatusInsufficientStorage, 898 http.StatusLoopDetected, 899 http.StatusNotExtended, 900 http.StatusNetworkAuthenticationRequired) 901 902 // see 
https://docs.microsoft.com/en-us/azure/active-directory/managed-service-identity/how-to-use-vm-token#retry-guidance 903 904 const maxDelay time.Duration = 60 * time.Second 905 906 attempt := 0 907 delay := time.Duration(0) 908 909 for attempt < maxAttempts { 910 resp, err = sender.Do(req) 911 // retry on temporary network errors, e.g. transient network failures. 912 // if we don't receive a response then assume we can't connect to the 913 // endpoint so we're likely not running on an Azure VM so don't retry. 914 if (err != nil && !isTemporaryNetworkError(err)) || resp == nil || resp.StatusCode == http.StatusOK || !containsInt(retries, resp.StatusCode) { 915 return 916 } 917 918 // perform exponential backoff with a cap. 919 // must increment attempt before calculating delay. 920 attempt++ 921 // the base value of 2 is the "delta backoff" as specified in the guidance doc 922 delay += (time.Duration(math.Pow(2, float64(attempt))) * time.Second) 923 if delay > maxDelay { 924 delay = maxDelay 925 } 926 927 select { 928 case <-time.After(delay): 929 // intentionally left blank 930 case <-req.Context().Done(): 931 err = req.Context().Err() 932 return 933 } 934 } 935 return 936} 937 938// returns true if the specified error is a temporary network error or false if it's not. 939// if the error doesn't implement the net.Error interface the return value is true. 940func isTemporaryNetworkError(err error) bool { 941 if netErr, ok := err.(net.Error); !ok || (ok && netErr.Temporary()) { 942 return true 943 } 944 return false 945} 946 947// returns true if slice ints contains the value n 948func containsInt(ints []int, n int) bool { 949 for _, i := range ints { 950 if i == n { 951 return true 952 } 953 } 954 return false 955} 956 957// SetAutoRefresh enables or disables automatic refreshing of stale tokens. 
958func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) { 959 spt.inner.AutoRefresh = autoRefresh 960} 961 962// SetRefreshWithin sets the interval within which if the token will expire, EnsureFresh will 963// refresh the token. 964func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) { 965 spt.inner.RefreshWithin = d 966 return 967} 968 969// SetSender sets the http.Client used when obtaining the Service Principal token. An 970// undecorated http.Client is used by default. 971func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s } 972 973// OAuthToken implements the OAuthTokenProvider interface. It returns the current access token. 974func (spt *ServicePrincipalToken) OAuthToken() string { 975 spt.refreshLock.RLock() 976 defer spt.refreshLock.RUnlock() 977 return spt.inner.Token.OAuthToken() 978} 979 980// Token returns a copy of the current token. 981func (spt *ServicePrincipalToken) Token() Token { 982 spt.refreshLock.RLock() 983 defer spt.refreshLock.RUnlock() 984 return spt.inner.Token 985} 986