package adal

// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha1"
	"crypto/x509"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"math"
	"net/http"
	"net/url"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/Azure/go-autorest/autorest/date"
	"github.com/dgrijalva/jwt-go"
)

const (
	// defaultRefresh is the RefreshWithin window the constructors in this file give
	// new tokens. EnsureFresh refreshes a token once it will expire within this window.
	defaultRefresh = 5 * time.Minute

	// OAuthGrantTypeDeviceCode is the "grant_type" identifier used in device flow
	OAuthGrantTypeDeviceCode = "device_code"

	// OAuthGrantTypeClientCredentials is the "grant_type" identifier used in credential flows
	OAuthGrantTypeClientCredentials = "client_credentials"

	// OAuthGrantTypeUserPass is the "grant_type" identifier used in username and password auth flows
	OAuthGrantTypeUserPass = "password"

	// OAuthGrantTypeRefreshToken is the "grant_type" identifier used in refresh token flows
	OAuthGrantTypeRefreshToken = "refresh_token"

	// OAuthGrantTypeAuthorizationCode is the "grant_type" identifier used in authorization code flows
	OAuthGrantTypeAuthorizationCode = "authorization_code"

	// metadataHeader is the header required by MSI extension. Refresh requests made with a
	// ServicePrincipalMSISecret set it to "true".
	metadataHeader = "Metadata"

	// msiEndpoint is the well known endpoint for getting MSI authentications tokens on
	// Virtual Machines. isIMDS compares token endpoints against its host and path.
	msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"

	// the default number of attempts to refresh an MSI authentication token.
	// MSI constructors copy it into ServicePrincipalToken.MaxMSIRefreshAttempts.
	defaultMaxMSIRefreshAttempts = 5

	// asMSIEndpointEnv is the environment variable used to store the endpoint on App Service and Functions
	asMSIEndpointEnv = "MSI_ENDPOINT"

	// asMSISecretEnv is the environment variable used to store the request secret on App Service and Functions.
	// Its value is sent in the "Secret" request header when refreshing on App Service.
	asMSISecretEnv = "MSI_SECRET"
)

// OAuthTokenProvider is an interface which should be implemented by an access token retriever
type OAuthTokenProvider interface {
	OAuthToken() string
}

// MultitenantOAuthTokenProvider provides tokens used for multi-tenant authorization.
// It exposes one primary token and zero or more auxiliary tokens.
type MultitenantOAuthTokenProvider interface {
	PrimaryOAuthToken() string
	AuxiliaryOAuthTokens() []string
}

// TokenRefreshError is an interface used by errors returned during token refresh.
// Response returns the HTTP response that caused the failure.
type TokenRefreshError interface {
	error
	Response() *http.Response
}

// Refresher is an interface for token refresh functionality
type Refresher interface {
	Refresh() error
	RefreshExchange(resource string) error
	EnsureFresh() error
}

// RefresherWithContext is an interface for token refresh functionality.
// It mirrors Refresher, but each method accepts a context.Context.
type RefresherWithContext interface {
	RefreshWithContext(ctx context.Context) error
	RefreshExchangeWithContext(ctx context.Context, resource string) error
	EnsureFreshWithContext(ctx context.Context) error
}

// TokenRefreshCallback is the type representing callbacks that will be called after
// a successful token refresh. A non-nil error is reported by the refresh operation.
type TokenRefreshCallback func(Token) error

// TokenRefresh is a type representing a custom callback to refresh a token.
// It receives the resource to obtain a token for and returns the new token.
type TokenRefresh func(ctx context.Context, resource string) (*Token, error)

// Token encapsulates the access token used to authorize Azure requests.
114// https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response 115type Token struct { 116 AccessToken string `json:"access_token"` 117 RefreshToken string `json:"refresh_token"` 118 119 ExpiresIn json.Number `json:"expires_in"` 120 ExpiresOn json.Number `json:"expires_on"` 121 NotBefore json.Number `json:"not_before"` 122 123 Resource string `json:"resource"` 124 Type string `json:"token_type"` 125} 126 127func newToken() Token { 128 return Token{ 129 ExpiresIn: "0", 130 ExpiresOn: "0", 131 NotBefore: "0", 132 } 133} 134 135// IsZero returns true if the token object is zero-initialized. 136func (t Token) IsZero() bool { 137 return t == Token{} 138} 139 140// Expires returns the time.Time when the Token expires. 141func (t Token) Expires() time.Time { 142 s, err := t.ExpiresOn.Float64() 143 if err != nil { 144 s = -3600 145 } 146 147 expiration := date.NewUnixTimeFromSeconds(s) 148 149 return time.Time(expiration).UTC() 150} 151 152// IsExpired returns true if the Token is expired, false otherwise. 153func (t Token) IsExpired() bool { 154 return t.WillExpireIn(0) 155} 156 157// WillExpireIn returns true if the Token will expire after the passed time.Duration interval 158// from now, false otherwise. 159func (t Token) WillExpireIn(d time.Duration) bool { 160 return !t.Expires().After(time.Now().Add(d)) 161} 162 163//OAuthToken return the current access token 164func (t *Token) OAuthToken() string { 165 return t.AccessToken 166} 167 168// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form 169// that is submitted when acquiring an oAuth token. 170type ServicePrincipalSecret interface { 171 SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error 172} 173 174// ServicePrincipalNoSecret represents a secret type that contains no secret 175// meaning it is not valid for fetching a fresh token. 
This is used by Manual 176type ServicePrincipalNoSecret struct { 177} 178 179// SetAuthenticationValues is a method of the interface ServicePrincipalSecret 180// It only returns an error for the ServicePrincipalNoSecret type 181func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 182 return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token") 183} 184 185// MarshalJSON implements the json.Marshaler interface. 186func (noSecret ServicePrincipalNoSecret) MarshalJSON() ([]byte, error) { 187 type tokenType struct { 188 Type string `json:"type"` 189 } 190 return json.Marshal(tokenType{ 191 Type: "ServicePrincipalNoSecret", 192 }) 193} 194 195// ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization. 196type ServicePrincipalTokenSecret struct { 197 ClientSecret string `json:"value"` 198} 199 200// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. 201// It will populate the form submitted during oAuth Token Acquisition using the client_secret. 202func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 203 v.Set("client_secret", tokenSecret.ClientSecret) 204 return nil 205} 206 207// MarshalJSON implements the json.Marshaler interface. 208func (tokenSecret ServicePrincipalTokenSecret) MarshalJSON() ([]byte, error) { 209 type tokenType struct { 210 Type string `json:"type"` 211 Value string `json:"value"` 212 } 213 return json.Marshal(tokenType{ 214 Type: "ServicePrincipalTokenSecret", 215 Value: tokenSecret.ClientSecret, 216 }) 217} 218 219// ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs. 
220type ServicePrincipalCertificateSecret struct { 221 Certificate *x509.Certificate 222 PrivateKey *rsa.PrivateKey 223} 224 225// SignJwt returns the JWT signed with the certificate's private key. 226func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) { 227 hasher := sha1.New() 228 _, err := hasher.Write(secret.Certificate.Raw) 229 if err != nil { 230 return "", err 231 } 232 233 thumbprint := base64.URLEncoding.EncodeToString(hasher.Sum(nil)) 234 235 // The jti (JWT ID) claim provides a unique identifier for the JWT. 236 jti := make([]byte, 20) 237 _, err = rand.Read(jti) 238 if err != nil { 239 return "", err 240 } 241 242 token := jwt.New(jwt.SigningMethodRS256) 243 token.Header["x5t"] = thumbprint 244 x5c := []string{base64.StdEncoding.EncodeToString(secret.Certificate.Raw)} 245 token.Header["x5c"] = x5c 246 token.Claims = jwt.MapClaims{ 247 "aud": spt.inner.OauthConfig.TokenEndpoint.String(), 248 "iss": spt.inner.ClientID, 249 "sub": spt.inner.ClientID, 250 "jti": base64.URLEncoding.EncodeToString(jti), 251 "nbf": time.Now().Unix(), 252 "exp": time.Now().Add(24 * time.Hour).Unix(), 253 } 254 255 signedString, err := token.SignedString(secret.PrivateKey) 256 return signedString, err 257} 258 259// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. 260// It will populate the form submitted during oAuth Token Acquisition using a JWT signed with a certificate. 261func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 262 jwt, err := secret.SignJwt(spt) 263 if err != nil { 264 return err 265 } 266 267 v.Set("client_assertion", jwt) 268 v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") 269 return nil 270} 271 272// MarshalJSON implements the json.Marshaler interface. 
273func (secret ServicePrincipalCertificateSecret) MarshalJSON() ([]byte, error) { 274 return nil, errors.New("marshalling ServicePrincipalCertificateSecret is not supported") 275} 276 277// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension. 278type ServicePrincipalMSISecret struct { 279} 280 281// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. 282func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 283 return nil 284} 285 286// MarshalJSON implements the json.Marshaler interface. 287func (msiSecret ServicePrincipalMSISecret) MarshalJSON() ([]byte, error) { 288 return nil, errors.New("marshalling ServicePrincipalMSISecret is not supported") 289} 290 291// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth. 292type ServicePrincipalUsernamePasswordSecret struct { 293 Username string `json:"username"` 294 Password string `json:"password"` 295} 296 297// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. 298func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 299 v.Set("username", secret.Username) 300 v.Set("password", secret.Password) 301 return nil 302} 303 304// MarshalJSON implements the json.Marshaler interface. 305func (secret ServicePrincipalUsernamePasswordSecret) MarshalJSON() ([]byte, error) { 306 type tokenType struct { 307 Type string `json:"type"` 308 Username string `json:"username"` 309 Password string `json:"password"` 310 } 311 return json.Marshal(tokenType{ 312 Type: "ServicePrincipalUsernamePasswordSecret", 313 Username: secret.Username, 314 Password: secret.Password, 315 }) 316} 317 318// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth. 
319type ServicePrincipalAuthorizationCodeSecret struct { 320 ClientSecret string `json:"value"` 321 AuthorizationCode string `json:"authCode"` 322 RedirectURI string `json:"redirect"` 323} 324 325// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. 326func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { 327 v.Set("code", secret.AuthorizationCode) 328 v.Set("client_secret", secret.ClientSecret) 329 v.Set("redirect_uri", secret.RedirectURI) 330 return nil 331} 332 333// MarshalJSON implements the json.Marshaler interface. 334func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, error) { 335 type tokenType struct { 336 Type string `json:"type"` 337 Value string `json:"value"` 338 AuthCode string `json:"authCode"` 339 Redirect string `json:"redirect"` 340 } 341 return json.Marshal(tokenType{ 342 Type: "ServicePrincipalAuthorizationCodeSecret", 343 Value: secret.ClientSecret, 344 AuthCode: secret.AuthorizationCode, 345 Redirect: secret.RedirectURI, 346 }) 347} 348 349// ServicePrincipalToken encapsulates a Token created for a Service Principal. 350type ServicePrincipalToken struct { 351 inner servicePrincipalToken 352 refreshLock *sync.RWMutex 353 sender Sender 354 customRefreshFunc TokenRefresh 355 refreshCallbacks []TokenRefreshCallback 356 // MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token. 357 MaxMSIRefreshAttempts int 358} 359 360// MarshalTokenJSON returns the marshalled inner token. 361func (spt ServicePrincipalToken) MarshalTokenJSON() ([]byte, error) { 362 return json.Marshal(spt.inner.Token) 363} 364 365// SetRefreshCallbacks replaces any existing refresh callbacks with the specified callbacks. 
366func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCallback) { 367 spt.refreshCallbacks = callbacks 368} 369 370// SetCustomRefreshFunc sets a custom refresh function used to refresh the token. 371func (spt *ServicePrincipalToken) SetCustomRefreshFunc(customRefreshFunc TokenRefresh) { 372 spt.customRefreshFunc = customRefreshFunc 373} 374 375// MarshalJSON implements the json.Marshaler interface. 376func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) { 377 return json.Marshal(spt.inner) 378} 379 380// UnmarshalJSON implements the json.Unmarshaler interface. 381func (spt *ServicePrincipalToken) UnmarshalJSON(data []byte) error { 382 // need to determine the token type 383 raw := map[string]interface{}{} 384 err := json.Unmarshal(data, &raw) 385 if err != nil { 386 return err 387 } 388 secret := raw["secret"].(map[string]interface{}) 389 switch secret["type"] { 390 case "ServicePrincipalNoSecret": 391 spt.inner.Secret = &ServicePrincipalNoSecret{} 392 case "ServicePrincipalTokenSecret": 393 spt.inner.Secret = &ServicePrincipalTokenSecret{} 394 case "ServicePrincipalCertificateSecret": 395 return errors.New("unmarshalling ServicePrincipalCertificateSecret is not supported") 396 case "ServicePrincipalMSISecret": 397 return errors.New("unmarshalling ServicePrincipalMSISecret is not supported") 398 case "ServicePrincipalUsernamePasswordSecret": 399 spt.inner.Secret = &ServicePrincipalUsernamePasswordSecret{} 400 case "ServicePrincipalAuthorizationCodeSecret": 401 spt.inner.Secret = &ServicePrincipalAuthorizationCodeSecret{} 402 default: 403 return fmt.Errorf("unrecognized token type '%s'", secret["type"]) 404 } 405 err = json.Unmarshal(data, &spt.inner) 406 if err != nil { 407 return err 408 } 409 // Don't override the refreshLock or the sender if those have been already set. 
410 if spt.refreshLock == nil { 411 spt.refreshLock = &sync.RWMutex{} 412 } 413 if spt.sender == nil { 414 spt.sender = sender() 415 } 416 return nil 417} 418 419// internal type used for marshalling/unmarshalling 420type servicePrincipalToken struct { 421 Token Token `json:"token"` 422 Secret ServicePrincipalSecret `json:"secret"` 423 OauthConfig OAuthConfig `json:"oauth"` 424 ClientID string `json:"clientID"` 425 Resource string `json:"resource"` 426 AutoRefresh bool `json:"autoRefresh"` 427 RefreshWithin time.Duration `json:"refreshWithin"` 428} 429 430func validateOAuthConfig(oac OAuthConfig) error { 431 if oac.IsZero() { 432 return fmt.Errorf("parameter 'oauthConfig' cannot be zero-initialized") 433 } 434 return nil 435} 436 437// NewServicePrincipalTokenWithSecret create a ServicePrincipalToken using the supplied ServicePrincipalSecret implementation. 438func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 439 if err := validateOAuthConfig(oauthConfig); err != nil { 440 return nil, err 441 } 442 if err := validateStringParam(id, "id"); err != nil { 443 return nil, err 444 } 445 if err := validateStringParam(resource, "resource"); err != nil { 446 return nil, err 447 } 448 if secret == nil { 449 return nil, fmt.Errorf("parameter 'secret' cannot be nil") 450 } 451 spt := &ServicePrincipalToken{ 452 inner: servicePrincipalToken{ 453 Token: newToken(), 454 OauthConfig: oauthConfig, 455 Secret: secret, 456 ClientID: id, 457 Resource: resource, 458 AutoRefresh: true, 459 RefreshWithin: defaultRefresh, 460 }, 461 refreshLock: &sync.RWMutex{}, 462 sender: sender(), 463 refreshCallbacks: callbacks, 464 } 465 return spt, nil 466} 467 468// NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken using the supplied token 469func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource 
string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 470 if err := validateOAuthConfig(oauthConfig); err != nil { 471 return nil, err 472 } 473 if err := validateStringParam(clientID, "clientID"); err != nil { 474 return nil, err 475 } 476 if err := validateStringParam(resource, "resource"); err != nil { 477 return nil, err 478 } 479 if token.IsZero() { 480 return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized") 481 } 482 spt, err := NewServicePrincipalTokenWithSecret( 483 oauthConfig, 484 clientID, 485 resource, 486 &ServicePrincipalNoSecret{}, 487 callbacks...) 488 if err != nil { 489 return nil, err 490 } 491 492 spt.inner.Token = token 493 494 return spt, nil 495} 496 497// NewServicePrincipalTokenFromManualTokenSecret creates a ServicePrincipalToken using the supplied token and secret 498func NewServicePrincipalTokenFromManualTokenSecret(oauthConfig OAuthConfig, clientID string, resource string, token Token, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 499 if err := validateOAuthConfig(oauthConfig); err != nil { 500 return nil, err 501 } 502 if err := validateStringParam(clientID, "clientID"); err != nil { 503 return nil, err 504 } 505 if err := validateStringParam(resource, "resource"); err != nil { 506 return nil, err 507 } 508 if secret == nil { 509 return nil, fmt.Errorf("parameter 'secret' cannot be nil") 510 } 511 if token.IsZero() { 512 return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized") 513 } 514 spt, err := NewServicePrincipalTokenWithSecret( 515 oauthConfig, 516 clientID, 517 resource, 518 secret, 519 callbacks...) 520 if err != nil { 521 return nil, err 522 } 523 524 spt.inner.Token = token 525 526 return spt, nil 527} 528 529// NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal 530// credentials scoped to the named resource. 
531func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 532 if err := validateOAuthConfig(oauthConfig); err != nil { 533 return nil, err 534 } 535 if err := validateStringParam(clientID, "clientID"); err != nil { 536 return nil, err 537 } 538 if err := validateStringParam(secret, "secret"); err != nil { 539 return nil, err 540 } 541 if err := validateStringParam(resource, "resource"); err != nil { 542 return nil, err 543 } 544 return NewServicePrincipalTokenWithSecret( 545 oauthConfig, 546 clientID, 547 resource, 548 &ServicePrincipalTokenSecret{ 549 ClientSecret: secret, 550 }, 551 callbacks..., 552 ) 553} 554 555// NewServicePrincipalTokenFromCertificate creates a ServicePrincipalToken from the supplied pkcs12 bytes. 556func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 557 if err := validateOAuthConfig(oauthConfig); err != nil { 558 return nil, err 559 } 560 if err := validateStringParam(clientID, "clientID"); err != nil { 561 return nil, err 562 } 563 if err := validateStringParam(resource, "resource"); err != nil { 564 return nil, err 565 } 566 if certificate == nil { 567 return nil, fmt.Errorf("parameter 'certificate' cannot be nil") 568 } 569 if privateKey == nil { 570 return nil, fmt.Errorf("parameter 'privateKey' cannot be nil") 571 } 572 return NewServicePrincipalTokenWithSecret( 573 oauthConfig, 574 clientID, 575 resource, 576 &ServicePrincipalCertificateSecret{ 577 PrivateKey: privateKey, 578 Certificate: certificate, 579 }, 580 callbacks..., 581 ) 582} 583 584// NewServicePrincipalTokenFromUsernamePassword creates a ServicePrincipalToken from the username and password. 
585func NewServicePrincipalTokenFromUsernamePassword(oauthConfig OAuthConfig, clientID string, username string, password string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 586 if err := validateOAuthConfig(oauthConfig); err != nil { 587 return nil, err 588 } 589 if err := validateStringParam(clientID, "clientID"); err != nil { 590 return nil, err 591 } 592 if err := validateStringParam(username, "username"); err != nil { 593 return nil, err 594 } 595 if err := validateStringParam(password, "password"); err != nil { 596 return nil, err 597 } 598 if err := validateStringParam(resource, "resource"); err != nil { 599 return nil, err 600 } 601 return NewServicePrincipalTokenWithSecret( 602 oauthConfig, 603 clientID, 604 resource, 605 &ServicePrincipalUsernamePasswordSecret{ 606 Username: username, 607 Password: password, 608 }, 609 callbacks..., 610 ) 611} 612 613// NewServicePrincipalTokenFromAuthorizationCode creates a ServicePrincipalToken from the 614func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clientID string, clientSecret string, authorizationCode string, redirectURI string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 615 616 if err := validateOAuthConfig(oauthConfig); err != nil { 617 return nil, err 618 } 619 if err := validateStringParam(clientID, "clientID"); err != nil { 620 return nil, err 621 } 622 if err := validateStringParam(clientSecret, "clientSecret"); err != nil { 623 return nil, err 624 } 625 if err := validateStringParam(authorizationCode, "authorizationCode"); err != nil { 626 return nil, err 627 } 628 if err := validateStringParam(redirectURI, "redirectURI"); err != nil { 629 return nil, err 630 } 631 if err := validateStringParam(resource, "resource"); err != nil { 632 return nil, err 633 } 634 635 return NewServicePrincipalTokenWithSecret( 636 oauthConfig, 637 clientID, 638 resource, 639 &ServicePrincipalAuthorizationCodeSecret{ 
640 ClientSecret: clientSecret, 641 AuthorizationCode: authorizationCode, 642 RedirectURI: redirectURI, 643 }, 644 callbacks..., 645 ) 646} 647 648// GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines. 649func GetMSIVMEndpoint() (string, error) { 650 return msiEndpoint, nil 651} 652 653func isAppService() bool { 654 _, asMSIEndpointEnvExists := os.LookupEnv(asMSIEndpointEnv) 655 _, asMSISecretEnvExists := os.LookupEnv(asMSISecretEnv) 656 657 return asMSIEndpointEnvExists && asMSISecretEnvExists 658} 659 660// GetMSIAppServiceEndpoint get the MSI endpoint for App Service and Functions 661func GetMSIAppServiceEndpoint() (string, error) { 662 asMSIEndpoint, asMSIEndpointEnvExists := os.LookupEnv(asMSIEndpointEnv) 663 664 if asMSIEndpointEnvExists { 665 return asMSIEndpoint, nil 666 } 667 return "", errors.New("MSI endpoint not found") 668} 669 670// GetMSIEndpoint get the appropriate MSI endpoint depending on the runtime environment 671func GetMSIEndpoint() (string, error) { 672 if isAppService() { 673 return GetMSIAppServiceEndpoint() 674 } 675 return GetMSIVMEndpoint() 676} 677 678// NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension. 679// It will use the system assigned identity when creating the token. 680func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 681 return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, callbacks...) 682} 683 684// NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension. 685// It will use the specified user assigned identity when creating the token. 686func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 687 return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &userAssignedID, callbacks...) 
688} 689 690func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { 691 if err := validateStringParam(msiEndpoint, "msiEndpoint"); err != nil { 692 return nil, err 693 } 694 if err := validateStringParam(resource, "resource"); err != nil { 695 return nil, err 696 } 697 if userAssignedID != nil { 698 if err := validateStringParam(*userAssignedID, "userAssignedID"); err != nil { 699 return nil, err 700 } 701 } 702 // We set the oauth config token endpoint to be MSI's endpoint 703 msiEndpointURL, err := url.Parse(msiEndpoint) 704 if err != nil { 705 return nil, err 706 } 707 708 v := url.Values{} 709 v.Set("resource", resource) 710 // App Service MSI currently only supports token API version 2017-09-01 711 if isAppService() { 712 v.Set("api-version", "2017-09-01") 713 } else { 714 v.Set("api-version", "2018-02-01") 715 } 716 if userAssignedID != nil { 717 v.Set("client_id", *userAssignedID) 718 } 719 msiEndpointURL.RawQuery = v.Encode() 720 721 spt := &ServicePrincipalToken{ 722 inner: servicePrincipalToken{ 723 Token: newToken(), 724 OauthConfig: OAuthConfig{ 725 TokenEndpoint: *msiEndpointURL, 726 }, 727 Secret: &ServicePrincipalMSISecret{}, 728 Resource: resource, 729 AutoRefresh: true, 730 RefreshWithin: defaultRefresh, 731 }, 732 refreshLock: &sync.RWMutex{}, 733 sender: sender(), 734 refreshCallbacks: callbacks, 735 MaxMSIRefreshAttempts: defaultMaxMSIRefreshAttempts, 736 } 737 738 if userAssignedID != nil { 739 spt.inner.ClientID = *userAssignedID 740 } 741 742 return spt, nil 743} 744 745// internal type that implements TokenRefreshError 746type tokenRefreshError struct { 747 message string 748 resp *http.Response 749} 750 751// Error implements the error interface which is part of the TokenRefreshError interface. 
752func (tre tokenRefreshError) Error() string { 753 return tre.message 754} 755 756// Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation. 757func (tre tokenRefreshError) Response() *http.Response { 758 return tre.resp 759} 760 761func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError { 762 return tokenRefreshError{message: message, resp: resp} 763} 764 765// EnsureFresh will refresh the token if it will expire within the refresh window (as set by 766// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use. 767func (spt *ServicePrincipalToken) EnsureFresh() error { 768 return spt.EnsureFreshWithContext(context.Background()) 769} 770 771// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by 772// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use. 773func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error { 774 // must take the read lock when initially checking the token's expiration 775 if spt.inner.AutoRefresh && spt.Token().WillExpireIn(spt.inner.RefreshWithin) { 776 // take the write lock then check again to see if the token was already refreshed 777 spt.refreshLock.Lock() 778 defer spt.refreshLock.Unlock() 779 if spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) { 780 return spt.refreshInternal(ctx, spt.inner.Resource) 781 } 782 } 783 return nil 784} 785 786// InvokeRefreshCallbacks calls any TokenRefreshCallbacks that were added to the SPT during initialization 787func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error { 788 if spt.refreshCallbacks != nil { 789 for _, callback := range spt.refreshCallbacks { 790 err := callback(spt.inner.Token) 791 if err != nil { 792 return fmt.Errorf("adal: TokenRefreshCallback handler failed. 
Error = '%v'", err) 793 } 794 } 795 } 796 return nil 797} 798 799// Refresh obtains a fresh token for the Service Principal. 800// This method is safe for concurrent use. 801func (spt *ServicePrincipalToken) Refresh() error { 802 return spt.RefreshWithContext(context.Background()) 803} 804 805// RefreshWithContext obtains a fresh token for the Service Principal. 806// This method is safe for concurrent use. 807func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error { 808 spt.refreshLock.Lock() 809 defer spt.refreshLock.Unlock() 810 return spt.refreshInternal(ctx, spt.inner.Resource) 811} 812 813// RefreshExchange refreshes the token, but for a different resource. 814// This method is safe for concurrent use. 815func (spt *ServicePrincipalToken) RefreshExchange(resource string) error { 816 return spt.RefreshExchangeWithContext(context.Background(), resource) 817} 818 819// RefreshExchangeWithContext refreshes the token, but for a different resource. 820// This method is safe for concurrent use. 
// It holds the write lock for the duration of the refresh.
func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
	spt.refreshLock.Lock()
	defer spt.refreshLock.Unlock()
	return spt.refreshInternal(ctx, resource)
}

// getGrantType maps the concrete secret type to the OAuth "grant_type" used when
// acquiring a new token. It returns password for username/password secrets,
// authorization_code for authorization-code secrets, and client_credentials otherwise.
func (spt *ServicePrincipalToken) getGrantType() string {
	switch spt.inner.Secret.(type) {
	case *ServicePrincipalUsernamePasswordSecret:
		return OAuthGrantTypeUserPass
	case *ServicePrincipalAuthorizationCodeSecret:
		return OAuthGrantTypeAuthorizationCode
	default:
		return OAuthGrantTypeClientCredentials
	}
}

// isIMDS reports whether u has the same host and path as the well-known msiEndpoint.
// It also returns true for any u when isAppService() is true, so the App Service
// MSI endpoint is handled the same way. Returns false if msiEndpoint fails to parse.
func isIMDS(u url.URL) bool {
	imds, err := url.Parse(msiEndpoint)
	if err != nil {
		return false
	}
	return (u.Host == imds.Host && u.Path == imds.Path) || isAppService()
}

// refreshInternal fetches a new token for resource, stores it in spt.inner.Token and
// then runs the refresh callbacks. The exported callers in this file hold
// refreshLock for writing while calling it.
//
// If a custom refresh function is set, it is used instead of HTTP. Otherwise a
// request is sent to the configured token endpoint:
//   - For non-IMDS endpoints the body is a form. It uses the refresh_token grant when
//     a refresh token is present; otherwise it uses the secret's grant type and
//     authentication values.
//   - MSI secrets switch the request to GET and set the Metadata header.
//   - IMDS endpoints are called through retryForIMDS.
//
// A non-200 response returns a TokenRefreshError. Transport, read and decode failures
// return plain errors so that callers' retry logic applies.
func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
	if spt.customRefreshFunc != nil {
		token, err := spt.customRefreshFunc(ctx, resource)
		if err != nil {
			return err
		}
		// NOTE(review): a custom refresh func returning (nil, nil) would cause a nil dereference here.
		spt.inner.Token = *token
		return spt.InvokeRefreshCallbacks(spt.inner.Token)
	}

	req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil)
	if err != nil {
		return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
	}
	req.Header.Add("User-Agent", UserAgent())
	// Add header when runtime is on App Service or Functions
	if isAppService() {
		asMSISecret, _ := os.LookupEnv(asMSISecretEnv)
		req.Header.Add("Secret", asMSISecret)
	}
	req = req.WithContext(ctx)
	if !isIMDS(spt.inner.OauthConfig.TokenEndpoint) {
		v := url.Values{}
		v.Set("client_id", spt.inner.ClientID)
		v.Set("resource", resource)

		if spt.inner.Token.RefreshToken != "" {
			v.Set("grant_type", OAuthGrantTypeRefreshToken)
			v.Set("refresh_token", spt.inner.Token.RefreshToken)
			// web apps must specify client_secret when refreshing tokens
			// see https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-protocols-oauth-code#refreshing-the-access-tokens
			if spt.getGrantType() == OAuthGrantTypeAuthorizationCode {
				err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
				if err != nil {
					return err
				}
			}
		} else {
			v.Set("grant_type", spt.getGrantType())
			err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
			if err != nil {
				return err
			}
		}

		s := v.Encode()
		body := ioutil.NopCloser(strings.NewReader(s))
		req.ContentLength = int64(len(s))
		req.Header.Set(contentType, mimeTypeFormPost)
		req.Body = body
	}

	// MSI requests are sent as GET with the Metadata header.
	if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok {
		req.Method = http.MethodGet
		req.Header.Set(metadataHeader, "true")
	}

	var resp *http.Response
	if isIMDS(spt.inner.OauthConfig.TokenEndpoint) {
		resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts)
	} else {
		resp, err = spt.sender.Do(req)
	}
	if err != nil {
		// don't return a TokenRefreshError here; this will allow retry logic to apply
		return fmt.Errorf("adal: Failed to execute the refresh request. Error = '%v'", err)
	}

	defer resp.Body.Close()
	rb, err := ioutil.ReadAll(resp.Body)

	if resp.StatusCode != http.StatusOK {
		if err != nil {
			return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v", resp.StatusCode, err), resp)
		}
		return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb)), resp)
	}

	// for the following error cases don't return a TokenRefreshError.  the operation succeeded
	// but some transient failure happened during deserialization.  by returning a generic error
	// the retry logic will kick in (we don't retry on TokenRefreshError).

	if err != nil {
		return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err)
	}
	// NOTE(review): only spaces are trimmed here; a body of only newlines/tabs falls
	// through to json.Unmarshal and is reported as an unmarshal error instead.
	if len(strings.Trim(string(rb), " ")) == 0 {
		return fmt.Errorf("adal: Empty service principal token received during refresh")
	}
	var token Token
	err = json.Unmarshal(rb, &token)
	if err != nil {
		return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb))
	}

	spt.inner.Token = token

	return spt.InvokeRefreshCallbacks(token)
}

// retry logic specific to retrieving a token from the IMDS endpoint
func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http.Response, err error) {
	// copied from client.go due to circular dependency
	retries := []int{
		http.StatusRequestTimeout,      // 408
		http.StatusTooManyRequests,     // 429
		http.StatusInternalServerError, // 500
		http.StatusBadGateway,          // 502
		http.StatusServiceUnavailable,  // 503
		http.StatusGatewayTimeout,      // 504
	}
	// extra retry status codes specific to IMDS
	retries = append(retries,
		http.StatusNotFound,
		http.StatusGone,
		// all remaining 5xx
		http.StatusNotImplemented,
		http.StatusHTTPVersionNotSupported,
		http.StatusVariantAlsoNegotiates,
		http.StatusInsufficientStorage,
		http.StatusLoopDetected,
		http.StatusNotExtended,
		http.StatusNetworkAuthenticationRequired)

	// see https://docs.microsoft.com/en-us/azure/active-directory/managed-service-identity/how-to-use-vm-token#retry-guidance

	const maxDelay time.Duration = 60 * time.Second

	attempt := 0
	delay := time.Duration(0)

	for attempt < maxAttempts {
		if resp != nil && resp.Body != nil {
			io.Copy(ioutil.Discard, resp.Body)
			resp.Body.Close()
		}
		resp, err = sender.Do(req)
		// we want to retry if err is not nil or the status code is in the list of retry codes
		if err == nil && !responseHasStatusCode(resp, retries...) {
			return
		}

		// perform exponential backoff with a cap.
		// must increment attempt before calculating delay.
989 attempt++ 990 // the base value of 2 is the "delta backoff" as specified in the guidance doc 991 delay += (time.Duration(math.Pow(2, float64(attempt))) * time.Second) 992 if delay > maxDelay { 993 delay = maxDelay 994 } 995 996 select { 997 case <-time.After(delay): 998 // intentionally left blank 999 case <-req.Context().Done(): 1000 err = req.Context().Err() 1001 return 1002 } 1003 } 1004 return 1005} 1006 1007func responseHasStatusCode(resp *http.Response, codes ...int) bool { 1008 if resp != nil { 1009 for _, i := range codes { 1010 if i == resp.StatusCode { 1011 return true 1012 } 1013 } 1014 } 1015 return false 1016} 1017 1018// SetAutoRefresh enables or disables automatic refreshing of stale tokens. 1019func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) { 1020 spt.inner.AutoRefresh = autoRefresh 1021} 1022 1023// SetRefreshWithin sets the interval within which if the token will expire, EnsureFresh will 1024// refresh the token. 1025func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) { 1026 spt.inner.RefreshWithin = d 1027 return 1028} 1029 1030// SetSender sets the http.Client used when obtaining the Service Principal token. An 1031// undecorated http.Client is used by default. 1032func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s } 1033 1034// OAuthToken implements the OAuthTokenProvider interface. It returns the current access token. 1035func (spt *ServicePrincipalToken) OAuthToken() string { 1036 spt.refreshLock.RLock() 1037 defer spt.refreshLock.RUnlock() 1038 return spt.inner.Token.OAuthToken() 1039} 1040 1041// Token returns a copy of the current token. 1042func (spt *ServicePrincipalToken) Token() Token { 1043 spt.refreshLock.RLock() 1044 defer spt.refreshLock.RUnlock() 1045 return spt.inner.Token 1046} 1047 1048// MultiTenantServicePrincipalToken contains tokens for multi-tenant authorization. 
1049type MultiTenantServicePrincipalToken struct { 1050 PrimaryToken *ServicePrincipalToken 1051 AuxiliaryTokens []*ServicePrincipalToken 1052} 1053 1054// PrimaryOAuthToken returns the primary authorization token. 1055func (mt *MultiTenantServicePrincipalToken) PrimaryOAuthToken() string { 1056 return mt.PrimaryToken.OAuthToken() 1057} 1058 1059// AuxiliaryOAuthTokens returns one to three auxiliary authorization tokens. 1060func (mt *MultiTenantServicePrincipalToken) AuxiliaryOAuthTokens() []string { 1061 tokens := make([]string, len(mt.AuxiliaryTokens)) 1062 for i := range mt.AuxiliaryTokens { 1063 tokens[i] = mt.AuxiliaryTokens[i].OAuthToken() 1064 } 1065 return tokens 1066} 1067 1068// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by 1069// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use. 1070func (mt *MultiTenantServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error { 1071 if err := mt.PrimaryToken.EnsureFreshWithContext(ctx); err != nil { 1072 return fmt.Errorf("failed to refresh primary token: %v", err) 1073 } 1074 for _, aux := range mt.AuxiliaryTokens { 1075 if err := aux.EnsureFreshWithContext(ctx); err != nil { 1076 return fmt.Errorf("failed to refresh auxiliary token: %v", err) 1077 } 1078 } 1079 return nil 1080} 1081 1082// RefreshWithContext obtains a fresh token for the Service Principal. 1083func (mt *MultiTenantServicePrincipalToken) RefreshWithContext(ctx context.Context) error { 1084 if err := mt.PrimaryToken.RefreshWithContext(ctx); err != nil { 1085 return fmt.Errorf("failed to refresh primary token: %v", err) 1086 } 1087 for _, aux := range mt.AuxiliaryTokens { 1088 if err := aux.RefreshWithContext(ctx); err != nil { 1089 return fmt.Errorf("failed to refresh auxiliary token: %v", err) 1090 } 1091 } 1092 return nil 1093} 1094 1095// RefreshExchangeWithContext refreshes the token, but for a different resource. 
1096func (mt *MultiTenantServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error { 1097 if err := mt.PrimaryToken.RefreshExchangeWithContext(ctx, resource); err != nil { 1098 return fmt.Errorf("failed to refresh primary token: %v", err) 1099 } 1100 for _, aux := range mt.AuxiliaryTokens { 1101 if err := aux.RefreshExchangeWithContext(ctx, resource); err != nil { 1102 return fmt.Errorf("failed to refresh auxiliary token: %v", err) 1103 } 1104 } 1105 return nil 1106} 1107 1108// NewMultiTenantServicePrincipalToken creates a new MultiTenantServicePrincipalToken with the specified credentials and resource. 1109func NewMultiTenantServicePrincipalToken(multiTenantCfg MultiTenantOAuthConfig, clientID string, secret string, resource string) (*MultiTenantServicePrincipalToken, error) { 1110 if err := validateStringParam(clientID, "clientID"); err != nil { 1111 return nil, err 1112 } 1113 if err := validateStringParam(secret, "secret"); err != nil { 1114 return nil, err 1115 } 1116 if err := validateStringParam(resource, "resource"); err != nil { 1117 return nil, err 1118 } 1119 auxTenants := multiTenantCfg.AuxiliaryTenants() 1120 m := MultiTenantServicePrincipalToken{ 1121 AuxiliaryTokens: make([]*ServicePrincipalToken, len(auxTenants)), 1122 } 1123 primary, err := NewServicePrincipalToken(*multiTenantCfg.PrimaryTenant(), clientID, secret, resource) 1124 if err != nil { 1125 return nil, fmt.Errorf("failed to create SPT for primary tenant: %v", err) 1126 } 1127 m.PrimaryToken = primary 1128 for i := range auxTenants { 1129 aux, err := NewServicePrincipalToken(*auxTenants[i], clientID, secret, resource) 1130 if err != nil { 1131 return nil, fmt.Errorf("failed to create SPT for auxiliary tenant: %v", err) 1132 } 1133 m.AuxiliaryTokens[i] = aux 1134 } 1135 return &m, nil 1136} 1137