package adal

// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha1"
	"crypto/x509"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"math"
	"net/http"
	"net/url"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/Azure/go-autorest/autorest/date"
	"github.com/Azure/go-autorest/logger"
	"github.com/form3tech-oss/jwt-go"
)

const (
	// defaultRefresh is the default RefreshWithin window given to new tokens:
	// with AutoRefresh on, EnsureFresh refreshes a token expiring within this window.
	defaultRefresh = 5 * time.Minute

	// OAuthGrantTypeDeviceCode is the "grant_type" identifier used in device flow
	OAuthGrantTypeDeviceCode = "device_code"

	// OAuthGrantTypeClientCredentials is the "grant_type" identifier used in credential flows
	OAuthGrantTypeClientCredentials = "client_credentials"

	// OAuthGrantTypeUserPass is the "grant_type" identifier used in username and password auth flows
	OAuthGrantTypeUserPass = "password"

	// OAuthGrantTypeRefreshToken is the "grant_type" identifier used in refresh token flows
	OAuthGrantTypeRefreshToken = "refresh_token"

	// OAuthGrantTypeAuthorizationCode is the "grant_type" identifier used in authorization code flows
	OAuthGrantTypeAuthorizationCode = "authorization_code"

	// metadataHeader is the header required by MSI extension
	metadataHeader = "Metadata"

	// msiEndpoint is the well known (IMDS) endpoint for getting MSI authentication tokens
	msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"

	// msiAPIVersion is the API version sent to the IMDS MSI endpoint
	msiAPIVersion = "2018-02-01"

	// defaultMaxMSIRefreshAttempts is the default number of attempts to refresh an MSI authentication token
	defaultMaxMSIRefreshAttempts = 5

	// msiEndpointEnv is the environment variable used to store the endpoint on App Service and Functions.
	// When it is set without msiSecretEnv, the environment is treated as Cloud Shell.
	msiEndpointEnv = "MSI_ENDPOINT"

	// msiSecretEnv is the environment variable used to store the request secret on App Service and Functions
	msiSecretEnv = "MSI_SECRET"

	// appServiceAPIVersion2017 is the API version to use for the legacy App Service MSI endpoint
	appServiceAPIVersion2017 = "2017-09-01"

	// secretHeader is the header used when authenticating against the App Service MSI endpoint
	secretHeader = "Secret"

	// expiresOnDateFormatPM is the time layout for expires_on in UTC with an AM/PM marker
	expiresOnDateFormatPM = "1/2/2006 15:04:05 PM +00:00"

	// expiresOnDateFormat is the time layout for expires_on in UTC without an AM/PM marker
	expiresOnDateFormat = "1/2/2006 15:04:05 +00:00"
)

// OAuthTokenProvider is an interface which should be implemented by an access token retriever
type OAuthTokenProvider interface {
	OAuthToken() string
}

// MultitenantOAuthTokenProvider provides tokens used for multi-tenant authorization.
type MultitenantOAuthTokenProvider interface {
	PrimaryOAuthToken() string
	AuxiliaryOAuthTokens() []string
}

// TokenRefreshError is an interface used by errors returned during token refresh.
type TokenRefreshError interface {
	error
	// Response exposes the raw HTTP response associated with the failed refresh.
	Response() *http.Response
}

// Refresher is an interface for token refresh functionality
type Refresher interface {
	Refresh() error
	RefreshExchange(resource string) error
	EnsureFresh() error
}

// RefresherWithContext is an interface for token refresh functionality whose
// operations accept a context.Context.
type RefresherWithContext interface {
	RefreshWithContext(ctx context.Context) error
	RefreshExchangeWithContext(ctx context.Context, resource string) error
	EnsureFreshWithContext(ctx context.Context) error
}

// TokenRefreshCallback is the type representing callbacks that will be called after
// a successful token refresh
type TokenRefreshCallback func(Token) error

// TokenRefresh is a type representing a custom callback to refresh a token
// for the given resource.
type TokenRefresh func(ctx context.Context, resource string) (*Token, error)

// Token encapsulates the access token used to authorize Azure requests.
// https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response
type Token struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`

	ExpiresIn json.Number `json:"expires_in"`
	ExpiresOn json.Number `json:"expires_on"`
	NotBefore json.Number `json:"not_before"`

	Resource string `json:"resource"`
	Type     string `json:"token_type"`
}

// newToken returns an otherwise empty Token whose numeric fields hold "0"
// rather than the empty string.
func newToken() Token {
	var tok Token
	tok.ExpiresIn, tok.ExpiresOn, tok.NotBefore = "0", "0", "0"
	return tok
}

// IsZero reports whether every field of the token holds its zero value.
func (t Token) IsZero() bool {
	var zero Token
	return t == zero
}

// Expires returns the time.Time when the Token expires.
158func (t Token) Expires() time.Time { 159 s, err := t.ExpiresOn.Float64() 160 if err != nil { 161 s = -3600 162 } 163 164 expiration := date.NewUnixTimeFromSeconds(s) 165 166 return time.Time(expiration).UTC() 167} 168 169// IsExpired returns true if the Token is expired, false otherwise. 170func (t Token) IsExpired() bool { 171 return t.WillExpireIn(0) 172} 173 174// WillExpireIn returns true if the Token will expire after the passed time.Duration interval 175// from now, false otherwise. 176func (t Token) WillExpireIn(d time.Duration) bool { 177 return !t.Expires().After(time.Now().Add(d)) 178} 179 180//OAuthToken return the current access token 181func (t *Token) OAuthToken() string { 182 return t.AccessToken 183} 184 185// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form 186// that is submitted when acquiring an oAuth token. 187type ServicePrincipalSecret interface { 188 SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error 189} 190 191// ServicePrincipalNoSecret represents a secret type that contains no secret 192// meaning it is not valid for fetching a fresh token. This is used by Manual 193type ServicePrincipalNoSecret struct { 194} 195 196// SetAuthenticationValues is a method of the interface ServicePrincipalSecret 197// It only returns an error for the ServicePrincipalNoSecret type 198func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 199 return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token") 200} 201 202// MarshalJSON implements the json.Marshaler interface. 
203func (noSecret ServicePrincipalNoSecret) MarshalJSON() ([]byte, error) { 204 type tokenType struct { 205 Type string `json:"type"` 206 } 207 return json.Marshal(tokenType{ 208 Type: "ServicePrincipalNoSecret", 209 }) 210} 211 212// ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization. 213type ServicePrincipalTokenSecret struct { 214 ClientSecret string `json:"value"` 215} 216 217// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. 218// It will populate the form submitted during oAuth Token Acquisition using the client_secret. 219func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 220 v.Set("client_secret", tokenSecret.ClientSecret) 221 return nil 222} 223 224// MarshalJSON implements the json.Marshaler interface. 225func (tokenSecret ServicePrincipalTokenSecret) MarshalJSON() ([]byte, error) { 226 type tokenType struct { 227 Type string `json:"type"` 228 Value string `json:"value"` 229 } 230 return json.Marshal(tokenType{ 231 Type: "ServicePrincipalTokenSecret", 232 Value: tokenSecret.ClientSecret, 233 }) 234} 235 236// ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs. 237type ServicePrincipalCertificateSecret struct { 238 Certificate *x509.Certificate 239 PrivateKey *rsa.PrivateKey 240} 241 242// SignJwt returns the JWT signed with the certificate's private key. 243func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) { 244 hasher := sha1.New() 245 _, err := hasher.Write(secret.Certificate.Raw) 246 if err != nil { 247 return "", err 248 } 249 250 thumbprint := base64.URLEncoding.EncodeToString(hasher.Sum(nil)) 251 252 // The jti (JWT ID) claim provides a unique identifier for the JWT. 
253 jti := make([]byte, 20) 254 _, err = rand.Read(jti) 255 if err != nil { 256 return "", err 257 } 258 259 token := jwt.New(jwt.SigningMethodRS256) 260 token.Header["x5t"] = thumbprint 261 x5c := []string{base64.StdEncoding.EncodeToString(secret.Certificate.Raw)} 262 token.Header["x5c"] = x5c 263 token.Claims = jwt.MapClaims{ 264 "aud": spt.inner.OauthConfig.TokenEndpoint.String(), 265 "iss": spt.inner.ClientID, 266 "sub": spt.inner.ClientID, 267 "jti": base64.URLEncoding.EncodeToString(jti), 268 "nbf": time.Now().Unix(), 269 "exp": time.Now().Add(24 * time.Hour).Unix(), 270 } 271 272 signedString, err := token.SignedString(secret.PrivateKey) 273 return signedString, err 274} 275 276// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. 277// It will populate the form submitted during oAuth Token Acquisition using a JWT signed with a certificate. 278func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 279 jwt, err := secret.SignJwt(spt) 280 if err != nil { 281 return err 282 } 283 284 v.Set("client_assertion", jwt) 285 v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") 286 return nil 287} 288 289// MarshalJSON implements the json.Marshaler interface. 290func (secret ServicePrincipalCertificateSecret) MarshalJSON() ([]byte, error) { 291 return nil, errors.New("marshalling ServicePrincipalCertificateSecret is not supported") 292} 293 294// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension. 295type ServicePrincipalMSISecret struct { 296 msiType msiType 297 clientResourceID string 298} 299 300// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. 301func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 302 return nil 303} 304 305// MarshalJSON implements the json.Marshaler interface. 
306func (msiSecret ServicePrincipalMSISecret) MarshalJSON() ([]byte, error) { 307 return nil, errors.New("marshalling ServicePrincipalMSISecret is not supported") 308} 309 310// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth. 311type ServicePrincipalUsernamePasswordSecret struct { 312 Username string `json:"username"` 313 Password string `json:"password"` 314} 315 316// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. 317func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 318 v.Set("username", secret.Username) 319 v.Set("password", secret.Password) 320 return nil 321} 322 323// MarshalJSON implements the json.Marshaler interface. 324func (secret ServicePrincipalUsernamePasswordSecret) MarshalJSON() ([]byte, error) { 325 type tokenType struct { 326 Type string `json:"type"` 327 Username string `json:"username"` 328 Password string `json:"password"` 329 } 330 return json.Marshal(tokenType{ 331 Type: "ServicePrincipalUsernamePasswordSecret", 332 Username: secret.Username, 333 Password: secret.Password, 334 }) 335} 336 337// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth. 338type ServicePrincipalAuthorizationCodeSecret struct { 339 ClientSecret string `json:"value"` 340 AuthorizationCode string `json:"authCode"` 341 RedirectURI string `json:"redirect"` 342} 343 344// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. 345func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 346 v.Set("code", secret.AuthorizationCode) 347 v.Set("client_secret", secret.ClientSecret) 348 v.Set("redirect_uri", secret.RedirectURI) 349 return nil 350} 351 352// MarshalJSON implements the json.Marshaler interface. 
353func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, error) { 354 type tokenType struct { 355 Type string `json:"type"` 356 Value string `json:"value"` 357 AuthCode string `json:"authCode"` 358 Redirect string `json:"redirect"` 359 } 360 return json.Marshal(tokenType{ 361 Type: "ServicePrincipalAuthorizationCodeSecret", 362 Value: secret.ClientSecret, 363 AuthCode: secret.AuthorizationCode, 364 Redirect: secret.RedirectURI, 365 }) 366} 367 368// ServicePrincipalToken encapsulates a Token created for a Service Principal. 369type ServicePrincipalToken struct { 370 inner servicePrincipalToken 371 refreshLock *sync.RWMutex 372 sender Sender 373 customRefreshFunc TokenRefresh 374 refreshCallbacks []TokenRefreshCallback 375 // MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token. 376 // Settings this to a value less than 1 will use the default value. 377 MaxMSIRefreshAttempts int 378} 379 380// MarshalTokenJSON returns the marshalled inner token. 381func (spt ServicePrincipalToken) MarshalTokenJSON() ([]byte, error) { 382 return json.Marshal(spt.inner.Token) 383} 384 385// SetRefreshCallbacks replaces any existing refresh callbacks with the specified callbacks. 386func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCallback) { 387 spt.refreshCallbacks = callbacks 388} 389 390// SetCustomRefreshFunc sets a custom refresh function used to refresh the token. 391func (spt *ServicePrincipalToken) SetCustomRefreshFunc(customRefreshFunc TokenRefresh) { 392 spt.customRefreshFunc = customRefreshFunc 393} 394 395// MarshalJSON implements the json.Marshaler interface. 396func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) { 397 return json.Marshal(spt.inner) 398} 399 400// UnmarshalJSON implements the json.Unmarshaler interface. 
401func (spt *ServicePrincipalToken) UnmarshalJSON(data []byte) error { 402 // need to determine the token type 403 raw := map[string]interface{}{} 404 err := json.Unmarshal(data, &raw) 405 if err != nil { 406 return err 407 } 408 secret := raw["secret"].(map[string]interface{}) 409 switch secret["type"] { 410 case "ServicePrincipalNoSecret": 411 spt.inner.Secret = &ServicePrincipalNoSecret{} 412 case "ServicePrincipalTokenSecret": 413 spt.inner.Secret = &ServicePrincipalTokenSecret{} 414 case "ServicePrincipalCertificateSecret": 415 return errors.New("unmarshalling ServicePrincipalCertificateSecret is not supported") 416 case "ServicePrincipalMSISecret": 417 return errors.New("unmarshalling ServicePrincipalMSISecret is not supported") 418 case "ServicePrincipalUsernamePasswordSecret": 419 spt.inner.Secret = &ServicePrincipalUsernamePasswordSecret{} 420 case "ServicePrincipalAuthorizationCodeSecret": 421 spt.inner.Secret = &ServicePrincipalAuthorizationCodeSecret{} 422 default: 423 return fmt.Errorf("unrecognized token type '%s'", secret["type"]) 424 } 425 err = json.Unmarshal(data, &spt.inner) 426 if err != nil { 427 return err 428 } 429 // Don't override the refreshLock or the sender if those have been already set. 
430 if spt.refreshLock == nil { 431 spt.refreshLock = &sync.RWMutex{} 432 } 433 if spt.sender == nil { 434 spt.sender = sender() 435 } 436 return nil 437} 438 439// internal type used for marshalling/unmarshalling 440type servicePrincipalToken struct { 441 Token Token `json:"token"` 442 Secret ServicePrincipalSecret `json:"secret"` 443 OauthConfig OAuthConfig `json:"oauth"` 444 ClientID string `json:"clientID"` 445 Resource string `json:"resource"` 446 AutoRefresh bool `json:"autoRefresh"` 447 RefreshWithin time.Duration `json:"refreshWithin"` 448} 449 450func validateOAuthConfig(oac OAuthConfig) error { 451 if oac.IsZero() { 452 return fmt.Errorf("parameter 'oauthConfig' cannot be zero-initialized") 453 } 454 return nil 455} 456 457// NewServicePrincipalTokenWithSecret create a ServicePrincipalToken using the supplied ServicePrincipalSecret implementation. 458func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 459 if err := validateOAuthConfig(oauthConfig); err != nil { 460 return nil, err 461 } 462 if err := validateStringParam(id, "id"); err != nil { 463 return nil, err 464 } 465 if err := validateStringParam(resource, "resource"); err != nil { 466 return nil, err 467 } 468 if secret == nil { 469 return nil, fmt.Errorf("parameter 'secret' cannot be nil") 470 } 471 spt := &ServicePrincipalToken{ 472 inner: servicePrincipalToken{ 473 Token: newToken(), 474 OauthConfig: oauthConfig, 475 Secret: secret, 476 ClientID: id, 477 Resource: resource, 478 AutoRefresh: true, 479 RefreshWithin: defaultRefresh, 480 }, 481 refreshLock: &sync.RWMutex{}, 482 sender: sender(), 483 refreshCallbacks: callbacks, 484 } 485 return spt, nil 486} 487 488// NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken using the supplied token 489func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource 
string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 490 if err := validateOAuthConfig(oauthConfig); err != nil { 491 return nil, err 492 } 493 if err := validateStringParam(clientID, "clientID"); err != nil { 494 return nil, err 495 } 496 if err := validateStringParam(resource, "resource"); err != nil { 497 return nil, err 498 } 499 if token.IsZero() { 500 return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized") 501 } 502 spt, err := NewServicePrincipalTokenWithSecret( 503 oauthConfig, 504 clientID, 505 resource, 506 &ServicePrincipalNoSecret{}, 507 callbacks...) 508 if err != nil { 509 return nil, err 510 } 511 512 spt.inner.Token = token 513 514 return spt, nil 515} 516 517// NewServicePrincipalTokenFromManualTokenSecret creates a ServicePrincipalToken using the supplied token and secret 518func NewServicePrincipalTokenFromManualTokenSecret(oauthConfig OAuthConfig, clientID string, resource string, token Token, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 519 if err := validateOAuthConfig(oauthConfig); err != nil { 520 return nil, err 521 } 522 if err := validateStringParam(clientID, "clientID"); err != nil { 523 return nil, err 524 } 525 if err := validateStringParam(resource, "resource"); err != nil { 526 return nil, err 527 } 528 if secret == nil { 529 return nil, fmt.Errorf("parameter 'secret' cannot be nil") 530 } 531 if token.IsZero() { 532 return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized") 533 } 534 spt, err := NewServicePrincipalTokenWithSecret( 535 oauthConfig, 536 clientID, 537 resource, 538 secret, 539 callbacks...) 540 if err != nil { 541 return nil, err 542 } 543 544 spt.inner.Token = token 545 546 return spt, nil 547} 548 549// NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal 550// credentials scoped to the named resource. 
551func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 552 if err := validateOAuthConfig(oauthConfig); err != nil { 553 return nil, err 554 } 555 if err := validateStringParam(clientID, "clientID"); err != nil { 556 return nil, err 557 } 558 if err := validateStringParam(secret, "secret"); err != nil { 559 return nil, err 560 } 561 if err := validateStringParam(resource, "resource"); err != nil { 562 return nil, err 563 } 564 return NewServicePrincipalTokenWithSecret( 565 oauthConfig, 566 clientID, 567 resource, 568 &ServicePrincipalTokenSecret{ 569 ClientSecret: secret, 570 }, 571 callbacks..., 572 ) 573} 574 575// NewServicePrincipalTokenFromCertificate creates a ServicePrincipalToken from the supplied pkcs12 bytes. 576func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 577 if err := validateOAuthConfig(oauthConfig); err != nil { 578 return nil, err 579 } 580 if err := validateStringParam(clientID, "clientID"); err != nil { 581 return nil, err 582 } 583 if err := validateStringParam(resource, "resource"); err != nil { 584 return nil, err 585 } 586 if certificate == nil { 587 return nil, fmt.Errorf("parameter 'certificate' cannot be nil") 588 } 589 if privateKey == nil { 590 return nil, fmt.Errorf("parameter 'privateKey' cannot be nil") 591 } 592 return NewServicePrincipalTokenWithSecret( 593 oauthConfig, 594 clientID, 595 resource, 596 &ServicePrincipalCertificateSecret{ 597 PrivateKey: privateKey, 598 Certificate: certificate, 599 }, 600 callbacks..., 601 ) 602} 603 604// NewServicePrincipalTokenFromUsernamePassword creates a ServicePrincipalToken from the username and password. 
605func NewServicePrincipalTokenFromUsernamePassword(oauthConfig OAuthConfig, clientID string, username string, password string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 606 if err := validateOAuthConfig(oauthConfig); err != nil { 607 return nil, err 608 } 609 if err := validateStringParam(clientID, "clientID"); err != nil { 610 return nil, err 611 } 612 if err := validateStringParam(username, "username"); err != nil { 613 return nil, err 614 } 615 if err := validateStringParam(password, "password"); err != nil { 616 return nil, err 617 } 618 if err := validateStringParam(resource, "resource"); err != nil { 619 return nil, err 620 } 621 return NewServicePrincipalTokenWithSecret( 622 oauthConfig, 623 clientID, 624 resource, 625 &ServicePrincipalUsernamePasswordSecret{ 626 Username: username, 627 Password: password, 628 }, 629 callbacks..., 630 ) 631} 632 633// NewServicePrincipalTokenFromAuthorizationCode creates a ServicePrincipalToken from the 634func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clientID string, clientSecret string, authorizationCode string, redirectURI string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 635 636 if err := validateOAuthConfig(oauthConfig); err != nil { 637 return nil, err 638 } 639 if err := validateStringParam(clientID, "clientID"); err != nil { 640 return nil, err 641 } 642 if err := validateStringParam(clientSecret, "clientSecret"); err != nil { 643 return nil, err 644 } 645 if err := validateStringParam(authorizationCode, "authorizationCode"); err != nil { 646 return nil, err 647 } 648 if err := validateStringParam(redirectURI, "redirectURI"); err != nil { 649 return nil, err 650 } 651 if err := validateStringParam(resource, "resource"); err != nil { 652 return nil, err 653 } 654 655 return NewServicePrincipalTokenWithSecret( 656 oauthConfig, 657 clientID, 658 resource, 659 &ServicePrincipalAuthorizationCodeSecret{ 
660 ClientSecret: clientSecret, 661 AuthorizationCode: authorizationCode, 662 RedirectURI: redirectURI, 663 }, 664 callbacks..., 665 ) 666} 667 668type msiType int 669 670const ( 671 msiTypeUnavailable msiType = iota 672 msiTypeAppServiceV20170901 673 msiTypeCloudShell 674 msiTypeIMDS 675) 676 677func (m msiType) String() string { 678 switch m { 679 case msiTypeUnavailable: 680 return "unavailable" 681 case msiTypeAppServiceV20170901: 682 return "AppServiceV20170901" 683 case msiTypeCloudShell: 684 return "CloudShell" 685 case msiTypeIMDS: 686 return "IMDS" 687 default: 688 return fmt.Sprintf("unhandled MSI type %d", m) 689 } 690} 691 692// returns the MSI type and endpoint, or an error 693func getMSIType() (msiType, string, error) { 694 if endpointEnvVar := os.Getenv(msiEndpointEnv); endpointEnvVar != "" { 695 // if the env var MSI_ENDPOINT is set 696 if secretEnvVar := os.Getenv(msiSecretEnv); secretEnvVar != "" { 697 // if BOTH the env vars MSI_ENDPOINT and MSI_SECRET are set the msiType is AppService 698 return msiTypeAppServiceV20170901, endpointEnvVar, nil 699 } 700 // if ONLY the env var MSI_ENDPOINT is set the msiType is CloudShell 701 return msiTypeCloudShell, endpointEnvVar, nil 702 } else if msiAvailableHook(context.Background(), sender()) { 703 // if MSI_ENDPOINT is NOT set AND the IMDS endpoint is available the msiType is IMDS. This will timeout after 500 milliseconds 704 return msiTypeIMDS, msiEndpoint, nil 705 } else { 706 // if MSI_ENDPOINT is NOT set and IMDS endpoint is not available Managed Identity is not available 707 return msiTypeUnavailable, "", errors.New("MSI not available") 708 } 709} 710 711// GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines. 712// NOTE: this always returns the IMDS endpoint, it does not work for app services or cloud shell. 713// Deprecated: NewServicePrincipalTokenFromMSI() and variants will automatically detect the endpoint. 
714func GetMSIVMEndpoint() (string, error) { 715 return msiEndpoint, nil 716} 717 718// GetMSIAppServiceEndpoint get the MSI endpoint for App Service and Functions. 719// It will return an error when not running in an app service/functions environment. 720// Deprecated: NewServicePrincipalTokenFromMSI() and variants will automatically detect the endpoint. 721func GetMSIAppServiceEndpoint() (string, error) { 722 msiType, endpoint, err := getMSIType() 723 if err != nil { 724 return "", err 725 } 726 switch msiType { 727 case msiTypeAppServiceV20170901: 728 return endpoint, nil 729 default: 730 return "", fmt.Errorf("%s is not app service environment", msiType) 731 } 732} 733 734// GetMSIEndpoint get the appropriate MSI endpoint depending on the runtime environment 735// Deprecated: NewServicePrincipalTokenFromMSI() and variants will automatically detect the endpoint. 736func GetMSIEndpoint() (string, error) { 737 _, endpoint, err := getMSIType() 738 return endpoint, err 739} 740 741// NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension. 742// It will use the system assigned identity when creating the token. 743// msiEndpoint - empty string, or pass a non-empty string to override the default value. 744// Deprecated: use NewServicePrincipalTokenFromManagedIdentity() instead. 745func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 746 return newServicePrincipalTokenFromMSI(msiEndpoint, resource, "", "", callbacks...) 747} 748 749// NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension. 750// It will use the clientID of specified user assigned identity when creating the token. 751// msiEndpoint - empty string, or pass a non-empty string to override the default value. 752// Deprecated: use NewServicePrincipalTokenFromManagedIdentity() instead. 
753func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 754 if err := validateStringParam(userAssignedID, "userAssignedID"); err != nil { 755 return nil, err 756 } 757 return newServicePrincipalTokenFromMSI(msiEndpoint, resource, userAssignedID, "", callbacks...) 758} 759 760// NewServicePrincipalTokenFromMSIWithIdentityResourceID creates a ServicePrincipalToken via the MSI VM Extension. 761// It will use the azure resource id of user assigned identity when creating the token. 762// msiEndpoint - empty string, or pass a non-empty string to override the default value. 763// Deprecated: use NewServicePrincipalTokenFromManagedIdentity() instead. 764func NewServicePrincipalTokenFromMSIWithIdentityResourceID(msiEndpoint, resource string, identityResourceID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 765 if err := validateStringParam(identityResourceID, "identityResourceID"); err != nil { 766 return nil, err 767 } 768 return newServicePrincipalTokenFromMSI(msiEndpoint, resource, "", identityResourceID, callbacks...) 769} 770 771// ManagedIdentityOptions contains optional values for configuring managed identity authentication. 772type ManagedIdentityOptions struct { 773 // ClientID is the user-assigned identity to use during authentication. 774 // It is mutually exclusive with IdentityResourceID. 775 ClientID string 776 777 // IdentityResourceID is the resource ID of the user-assigned identity to use during authentication. 778 // It is mutually exclusive with ClientID. 779 IdentityResourceID string 780} 781 782// NewServicePrincipalTokenFromManagedIdentity creates a ServicePrincipalToken using a managed identity. 783// It supports the following managed identity environments. 
784// - App Service Environment (API version 2017-09-01 only) 785// - Cloud shell 786// - IMDS with a system or user assigned identity 787func NewServicePrincipalTokenFromManagedIdentity(resource string, options *ManagedIdentityOptions, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 788 if options == nil { 789 options = &ManagedIdentityOptions{} 790 } 791 return newServicePrincipalTokenFromMSI("", resource, options.ClientID, options.IdentityResourceID, callbacks...) 792} 793 794func newServicePrincipalTokenFromMSI(msiEndpoint, resource, userAssignedID, identityResourceID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 795 if err := validateStringParam(resource, "resource"); err != nil { 796 return nil, err 797 } 798 if userAssignedID != "" && identityResourceID != "" { 799 return nil, errors.New("cannot specify userAssignedID and identityResourceID") 800 } 801 msiType, endpoint, err := getMSIType() 802 if err != nil { 803 logger.Instance.Writef(logger.LogError, "Error determining managed identity environment: %v\n", err) 804 return nil, err 805 } 806 logger.Instance.Writef(logger.LogInfo, "Managed identity environment is %s, endpoint is %s\n", msiType, endpoint) 807 if msiEndpoint != "" { 808 endpoint = msiEndpoint 809 logger.Instance.Writef(logger.LogInfo, "Managed identity custom endpoint is %s\n", endpoint) 810 } 811 msiEndpointURL, err := url.Parse(endpoint) 812 if err != nil { 813 return nil, err 814 } 815 // cloud shell sends its data in the request body 816 if msiType != msiTypeCloudShell { 817 v := url.Values{} 818 v.Set("resource", resource) 819 clientIDParam := "client_id" 820 switch msiType { 821 case msiTypeAppServiceV20170901: 822 clientIDParam = "clientid" 823 v.Set("api-version", appServiceAPIVersion2017) 824 break 825 case msiTypeIMDS: 826 v.Set("api-version", msiAPIVersion) 827 } 828 if userAssignedID != "" { 829 v.Set(clientIDParam, userAssignedID) 830 } else if identityResourceID != "" { 831 
v.Set("mi_res_id", identityResourceID) 832 } 833 msiEndpointURL.RawQuery = v.Encode() 834 } 835 836 spt := &ServicePrincipalToken{ 837 inner: servicePrincipalToken{ 838 Token: newToken(), 839 OauthConfig: OAuthConfig{ 840 TokenEndpoint: *msiEndpointURL, 841 }, 842 Secret: &ServicePrincipalMSISecret{ 843 msiType: msiType, 844 clientResourceID: identityResourceID, 845 }, 846 Resource: resource, 847 AutoRefresh: true, 848 RefreshWithin: defaultRefresh, 849 ClientID: userAssignedID, 850 }, 851 refreshLock: &sync.RWMutex{}, 852 sender: sender(), 853 refreshCallbacks: callbacks, 854 MaxMSIRefreshAttempts: defaultMaxMSIRefreshAttempts, 855 } 856 857 return spt, nil 858} 859 860// internal type that implements TokenRefreshError 861type tokenRefreshError struct { 862 message string 863 resp *http.Response 864} 865 866// Error implements the error interface which is part of the TokenRefreshError interface. 867func (tre tokenRefreshError) Error() string { 868 return tre.message 869} 870 871// Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation. 872func (tre tokenRefreshError) Response() *http.Response { 873 return tre.resp 874} 875 876func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError { 877 return tokenRefreshError{message: message, resp: resp} 878} 879 880// EnsureFresh will refresh the token if it will expire within the refresh window (as set by 881// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use. 882func (spt *ServicePrincipalToken) EnsureFresh() error { 883 return spt.EnsureFreshWithContext(context.Background()) 884} 885 886// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by 887// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use. 
888func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error { 889 // must take the read lock when initially checking the token's expiration 890 if spt.inner.AutoRefresh && spt.Token().WillExpireIn(spt.inner.RefreshWithin) { 891 // take the write lock then check again to see if the token was already refreshed 892 spt.refreshLock.Lock() 893 defer spt.refreshLock.Unlock() 894 if spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) { 895 return spt.refreshInternal(ctx, spt.inner.Resource) 896 } 897 } 898 return nil 899} 900 901// InvokeRefreshCallbacks calls any TokenRefreshCallbacks that were added to the SPT during initialization 902func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error { 903 if spt.refreshCallbacks != nil { 904 for _, callback := range spt.refreshCallbacks { 905 err := callback(spt.inner.Token) 906 if err != nil { 907 return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err) 908 } 909 } 910 } 911 return nil 912} 913 914// Refresh obtains a fresh token for the Service Principal. 915// This method is safe for concurrent use. 916func (spt *ServicePrincipalToken) Refresh() error { 917 return spt.RefreshWithContext(context.Background()) 918} 919 920// RefreshWithContext obtains a fresh token for the Service Principal. 921// This method is safe for concurrent use. 922func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error { 923 spt.refreshLock.Lock() 924 defer spt.refreshLock.Unlock() 925 return spt.refreshInternal(ctx, spt.inner.Resource) 926} 927 928// RefreshExchange refreshes the token, but for a different resource. 929// This method is safe for concurrent use. 930func (spt *ServicePrincipalToken) RefreshExchange(resource string) error { 931 return spt.RefreshExchangeWithContext(context.Background(), resource) 932} 933 934// RefreshExchangeWithContext refreshes the token, but for a different resource. 935// This method is safe for concurrent use. 
936func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error { 937 spt.refreshLock.Lock() 938 defer spt.refreshLock.Unlock() 939 return spt.refreshInternal(ctx, resource) 940} 941 942func (spt *ServicePrincipalToken) getGrantType() string { 943 switch spt.inner.Secret.(type) { 944 case *ServicePrincipalUsernamePasswordSecret: 945 return OAuthGrantTypeUserPass 946 case *ServicePrincipalAuthorizationCodeSecret: 947 return OAuthGrantTypeAuthorizationCode 948 default: 949 return OAuthGrantTypeClientCredentials 950 } 951} 952 953func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error { 954 if spt.customRefreshFunc != nil { 955 token, err := spt.customRefreshFunc(ctx, resource) 956 if err != nil { 957 return err 958 } 959 spt.inner.Token = *token 960 return spt.InvokeRefreshCallbacks(spt.inner.Token) 961 } 962 req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil) 963 if err != nil { 964 return fmt.Errorf("adal: Failed to build the refresh request. 
Error = '%v'", err) 965 } 966 req.Header.Add("User-Agent", UserAgent()) 967 req = req.WithContext(ctx) 968 var resp *http.Response 969 authBodyFilter := func(b []byte) []byte { 970 if logger.Level() != logger.LogAuth { 971 return []byte("**REDACTED** authentication body") 972 } 973 return b 974 } 975 if msiSecret, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok { 976 switch msiSecret.msiType { 977 case msiTypeAppServiceV20170901: 978 req.Method = http.MethodGet 979 req.Header.Set("secret", os.Getenv(msiSecretEnv)) 980 break 981 case msiTypeCloudShell: 982 req.Header.Set("Metadata", "true") 983 data := url.Values{} 984 data.Set("resource", spt.inner.Resource) 985 if spt.inner.ClientID != "" { 986 data.Set("client_id", spt.inner.ClientID) 987 } else if msiSecret.clientResourceID != "" { 988 data.Set("msi_res_id", msiSecret.clientResourceID) 989 } 990 req.Body = ioutil.NopCloser(strings.NewReader(data.Encode())) 991 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 992 break 993 case msiTypeIMDS: 994 req.Method = http.MethodGet 995 req.Header.Set("Metadata", "true") 996 break 997 } 998 logger.Instance.WriteRequest(req, logger.Filter{Body: authBodyFilter}) 999 resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts) 1000 } else { 1001 v := url.Values{} 1002 v.Set("client_id", spt.inner.ClientID) 1003 v.Set("resource", resource) 1004 1005 if spt.inner.Token.RefreshToken != "" { 1006 v.Set("grant_type", OAuthGrantTypeRefreshToken) 1007 v.Set("refresh_token", spt.inner.Token.RefreshToken) 1008 // web apps must specify client_secret when refreshing tokens 1009 // see https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-protocols-oauth-code#refreshing-the-access-tokens 1010 if spt.getGrantType() == OAuthGrantTypeAuthorizationCode { 1011 err := spt.inner.Secret.SetAuthenticationValues(spt, &v) 1012 if err != nil { 1013 return err 1014 } 1015 } 1016 } else { 1017 v.Set("grant_type", spt.getGrantType()) 1018 err 
:= spt.inner.Secret.SetAuthenticationValues(spt, &v) 1019 if err != nil { 1020 return err 1021 } 1022 } 1023 1024 s := v.Encode() 1025 body := ioutil.NopCloser(strings.NewReader(s)) 1026 req.ContentLength = int64(len(s)) 1027 req.Header.Set(contentType, mimeTypeFormPost) 1028 req.Body = body 1029 logger.Instance.WriteRequest(req, logger.Filter{Body: authBodyFilter}) 1030 resp, err = spt.sender.Do(req) 1031 } 1032 1033 // don't return a TokenRefreshError here; this will allow retry logic to apply 1034 if err != nil { 1035 return fmt.Errorf("adal: Failed to execute the refresh request. Error = '%v'", err) 1036 } else if resp == nil { 1037 return fmt.Errorf("adal: received nil response and error") 1038 } 1039 1040 logger.Instance.WriteResponse(resp, logger.Filter{Body: authBodyFilter}) 1041 defer resp.Body.Close() 1042 rb, err := ioutil.ReadAll(resp.Body) 1043 1044 if resp.StatusCode != http.StatusOK { 1045 if err != nil { 1046 return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v Endpoint %s", resp.StatusCode, err, req.URL.String()), resp) 1047 } 1048 return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s Endpoint %s", resp.StatusCode, string(rb), req.URL.String()), resp) 1049 } 1050 1051 // for the following error cases don't return a TokenRefreshError. the operation succeeded 1052 // but some transient failure happened during deserialization. by returning a generic error 1053 // the retry logic will kick in (we don't retry on TokenRefreshError). 1054 1055 if err != nil { 1056 return fmt.Errorf("adal: Failed to read a new service principal token during refresh. 
Error = '%v'", err) 1057 } 1058 if len(strings.Trim(string(rb), " ")) == 0 { 1059 return fmt.Errorf("adal: Empty service principal token received during refresh") 1060 } 1061 token := struct { 1062 AccessToken string `json:"access_token"` 1063 RefreshToken string `json:"refresh_token"` 1064 1065 // AAD returns expires_in as a string, ADFS returns it as an int 1066 ExpiresIn json.Number `json:"expires_in"` 1067 // expires_on can be in two formats, a UTC time stamp or the number of seconds. 1068 ExpiresOn string `json:"expires_on"` 1069 NotBefore json.Number `json:"not_before"` 1070 1071 Resource string `json:"resource"` 1072 Type string `json:"token_type"` 1073 }{} 1074 // return a TokenRefreshError in the follow error cases as the token is in an unexpected format 1075 err = json.Unmarshal(rb, &token) 1076 if err != nil { 1077 return newTokenRefreshError(fmt.Sprintf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb)), resp) 1078 } 1079 expiresOn := json.Number("") 1080 // ADFS doesn't include the expires_on field 1081 if token.ExpiresOn != "" { 1082 if expiresOn, err = parseExpiresOn(token.ExpiresOn); err != nil { 1083 return newTokenRefreshError(fmt.Sprintf("adal: failed to parse expires_on: %v value '%s'", err, token.ExpiresOn), resp) 1084 } 1085 } 1086 spt.inner.Token.AccessToken = token.AccessToken 1087 spt.inner.Token.RefreshToken = token.RefreshToken 1088 spt.inner.Token.ExpiresIn = token.ExpiresIn 1089 spt.inner.Token.ExpiresOn = expiresOn 1090 spt.inner.Token.NotBefore = token.NotBefore 1091 spt.inner.Token.Resource = token.Resource 1092 spt.inner.Token.Type = token.Type 1093 1094 return spt.InvokeRefreshCallbacks(spt.inner.Token) 1095} 1096 1097// converts expires_on to the number of seconds 1098func parseExpiresOn(s string) (json.Number, error) { 1099 // convert the expiration date to the number of seconds from now 1100 timeToDuration := func(t time.Time) json.Number { 1101 dur := 
t.Sub(time.Now().UTC()) 1102 return json.Number(strconv.FormatInt(int64(dur.Round(time.Second).Seconds()), 10)) 1103 } 1104 if _, err := strconv.ParseInt(s, 10, 64); err == nil { 1105 // this is the number of seconds case, no conversion required 1106 return json.Number(s), nil 1107 } else if eo, err := time.Parse(expiresOnDateFormatPM, s); err == nil { 1108 return timeToDuration(eo), nil 1109 } else if eo, err := time.Parse(expiresOnDateFormat, s); err == nil { 1110 return timeToDuration(eo), nil 1111 } else { 1112 // unknown format 1113 return json.Number(""), err 1114 } 1115} 1116 1117// retry logic specific to retrieving a token from the IMDS endpoint 1118func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http.Response, err error) { 1119 // copied from client.go due to circular dependency 1120 retries := []int{ 1121 http.StatusRequestTimeout, // 408 1122 http.StatusTooManyRequests, // 429 1123 http.StatusInternalServerError, // 500 1124 http.StatusBadGateway, // 502 1125 http.StatusServiceUnavailable, // 503 1126 http.StatusGatewayTimeout, // 504 1127 } 1128 // extra retry status codes specific to IMDS 1129 retries = append(retries, 1130 http.StatusNotFound, 1131 http.StatusGone, 1132 // all remaining 5xx 1133 http.StatusNotImplemented, 1134 http.StatusHTTPVersionNotSupported, 1135 http.StatusVariantAlsoNegotiates, 1136 http.StatusInsufficientStorage, 1137 http.StatusLoopDetected, 1138 http.StatusNotExtended, 1139 http.StatusNetworkAuthenticationRequired) 1140 1141 // see https://docs.microsoft.com/en-us/azure/active-directory/managed-service-identity/how-to-use-vm-token#retry-guidance 1142 1143 const maxDelay time.Duration = 60 * time.Second 1144 1145 attempt := 0 1146 delay := time.Duration(0) 1147 1148 // maxAttempts is user-specified, ensure that its value is greater than zero else no request will be made 1149 if maxAttempts < 1 { 1150 maxAttempts = defaultMaxMSIRefreshAttempts 1151 } 1152 1153 for attempt < maxAttempts { 1154 if 
resp != nil && resp.Body != nil { 1155 io.Copy(ioutil.Discard, resp.Body) 1156 resp.Body.Close() 1157 } 1158 resp, err = sender.Do(req) 1159 // we want to retry if err is not nil or the status code is in the list of retry codes 1160 if err == nil && !responseHasStatusCode(resp, retries...) { 1161 return 1162 } 1163 1164 // perform exponential backoff with a cap. 1165 // must increment attempt before calculating delay. 1166 attempt++ 1167 // the base value of 2 is the "delta backoff" as specified in the guidance doc 1168 delay += (time.Duration(math.Pow(2, float64(attempt))) * time.Second) 1169 if delay > maxDelay { 1170 delay = maxDelay 1171 } 1172 1173 select { 1174 case <-time.After(delay): 1175 // intentionally left blank 1176 case <-req.Context().Done(): 1177 err = req.Context().Err() 1178 return 1179 } 1180 } 1181 return 1182} 1183 1184func responseHasStatusCode(resp *http.Response, codes ...int) bool { 1185 if resp != nil { 1186 for _, i := range codes { 1187 if i == resp.StatusCode { 1188 return true 1189 } 1190 } 1191 } 1192 return false 1193} 1194 1195// SetAutoRefresh enables or disables automatic refreshing of stale tokens. 1196func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) { 1197 spt.inner.AutoRefresh = autoRefresh 1198} 1199 1200// SetRefreshWithin sets the interval within which if the token will expire, EnsureFresh will 1201// refresh the token. 1202func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) { 1203 spt.inner.RefreshWithin = d 1204 return 1205} 1206 1207// SetSender sets the http.Client used when obtaining the Service Principal token. An 1208// undecorated http.Client is used by default. 1209func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s } 1210 1211// OAuthToken implements the OAuthTokenProvider interface. It returns the current access token. 
1212func (spt *ServicePrincipalToken) OAuthToken() string { 1213 spt.refreshLock.RLock() 1214 defer spt.refreshLock.RUnlock() 1215 return spt.inner.Token.OAuthToken() 1216} 1217 1218// Token returns a copy of the current token. 1219func (spt *ServicePrincipalToken) Token() Token { 1220 spt.refreshLock.RLock() 1221 defer spt.refreshLock.RUnlock() 1222 return spt.inner.Token 1223} 1224 1225// MultiTenantServicePrincipalToken contains tokens for multi-tenant authorization. 1226type MultiTenantServicePrincipalToken struct { 1227 PrimaryToken *ServicePrincipalToken 1228 AuxiliaryTokens []*ServicePrincipalToken 1229} 1230 1231// PrimaryOAuthToken returns the primary authorization token. 1232func (mt *MultiTenantServicePrincipalToken) PrimaryOAuthToken() string { 1233 return mt.PrimaryToken.OAuthToken() 1234} 1235 1236// AuxiliaryOAuthTokens returns one to three auxiliary authorization tokens. 1237func (mt *MultiTenantServicePrincipalToken) AuxiliaryOAuthTokens() []string { 1238 tokens := make([]string, len(mt.AuxiliaryTokens)) 1239 for i := range mt.AuxiliaryTokens { 1240 tokens[i] = mt.AuxiliaryTokens[i].OAuthToken() 1241 } 1242 return tokens 1243} 1244 1245// NewMultiTenantServicePrincipalToken creates a new MultiTenantServicePrincipalToken with the specified credentials and resource. 
1246func NewMultiTenantServicePrincipalToken(multiTenantCfg MultiTenantOAuthConfig, clientID string, secret string, resource string) (*MultiTenantServicePrincipalToken, error) { 1247 if err := validateStringParam(clientID, "clientID"); err != nil { 1248 return nil, err 1249 } 1250 if err := validateStringParam(secret, "secret"); err != nil { 1251 return nil, err 1252 } 1253 if err := validateStringParam(resource, "resource"); err != nil { 1254 return nil, err 1255 } 1256 auxTenants := multiTenantCfg.AuxiliaryTenants() 1257 m := MultiTenantServicePrincipalToken{ 1258 AuxiliaryTokens: make([]*ServicePrincipalToken, len(auxTenants)), 1259 } 1260 primary, err := NewServicePrincipalToken(*multiTenantCfg.PrimaryTenant(), clientID, secret, resource) 1261 if err != nil { 1262 return nil, fmt.Errorf("failed to create SPT for primary tenant: %v", err) 1263 } 1264 m.PrimaryToken = primary 1265 for i := range auxTenants { 1266 aux, err := NewServicePrincipalToken(*auxTenants[i], clientID, secret, resource) 1267 if err != nil { 1268 return nil, fmt.Errorf("failed to create SPT for auxiliary tenant: %v", err) 1269 } 1270 m.AuxiliaryTokens[i] = aux 1271 } 1272 return &m, nil 1273} 1274 1275// NewMultiTenantServicePrincipalTokenFromCertificate creates a new MultiTenantServicePrincipalToken with the specified certificate credentials and resource. 
1276func NewMultiTenantServicePrincipalTokenFromCertificate(multiTenantCfg MultiTenantOAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string) (*MultiTenantServicePrincipalToken, error) { 1277 if err := validateStringParam(clientID, "clientID"); err != nil { 1278 return nil, err 1279 } 1280 if err := validateStringParam(resource, "resource"); err != nil { 1281 return nil, err 1282 } 1283 if certificate == nil { 1284 return nil, fmt.Errorf("parameter 'certificate' cannot be nil") 1285 } 1286 if privateKey == nil { 1287 return nil, fmt.Errorf("parameter 'privateKey' cannot be nil") 1288 } 1289 auxTenants := multiTenantCfg.AuxiliaryTenants() 1290 m := MultiTenantServicePrincipalToken{ 1291 AuxiliaryTokens: make([]*ServicePrincipalToken, len(auxTenants)), 1292 } 1293 primary, err := NewServicePrincipalTokenWithSecret( 1294 *multiTenantCfg.PrimaryTenant(), 1295 clientID, 1296 resource, 1297 &ServicePrincipalCertificateSecret{ 1298 PrivateKey: privateKey, 1299 Certificate: certificate, 1300 }, 1301 ) 1302 if err != nil { 1303 return nil, fmt.Errorf("failed to create SPT for primary tenant: %v", err) 1304 } 1305 m.PrimaryToken = primary 1306 for i := range auxTenants { 1307 aux, err := NewServicePrincipalTokenWithSecret( 1308 *auxTenants[i], 1309 clientID, 1310 resource, 1311 &ServicePrincipalCertificateSecret{ 1312 PrivateKey: privateKey, 1313 Certificate: certificate, 1314 }, 1315 ) 1316 if err != nil { 1317 return nil, fmt.Errorf("failed to create SPT for auxiliary tenant: %v", err) 1318 } 1319 m.AuxiliaryTokens[i] = aux 1320 } 1321 return &m, nil 1322} 1323 1324// MSIAvailable returns true if the MSI endpoint is available for authentication. 
1325func MSIAvailable(ctx context.Context, sender Sender) bool { 1326 resp, err := getMSIEndpoint(ctx, sender) 1327 if err == nil { 1328 resp.Body.Close() 1329 } 1330 return err == nil 1331} 1332 1333// used for testing purposes 1334var msiAvailableHook = func(ctx context.Context, sender Sender) bool { 1335 return MSIAvailable(ctx, sender) 1336} 1337