1package adal
2
3// Copyright 2017 Microsoft Corporation
4//
5//  Licensed under the Apache License, Version 2.0 (the "License");
6//  you may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at
8//
9//      http://www.apache.org/licenses/LICENSE-2.0
10//
11//  Unless required by applicable law or agreed to in writing, software
12//  distributed under the License is distributed on an "AS IS" BASIS,
13//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14//  See the License for the specific language governing permissions and
15//  limitations under the License.
16
17import (
18	"context"
19	"crypto/rand"
20	"crypto/rsa"
21	"crypto/sha1"
22	"crypto/x509"
23	"encoding/base64"
24	"encoding/json"
25	"errors"
26	"fmt"
27	"io"
28	"io/ioutil"
29	"math"
30	"net/http"
31	"net/url"
32	"os"
33	"strconv"
34	"strings"
35	"sync"
36	"time"
37
38	"github.com/Azure/go-autorest/autorest/date"
39	"github.com/form3tech-oss/jwt-go"
40)
41
// defaultRefresh is the RefreshWithin window given to newly created tokens:
// with AutoRefresh on, EnsureFresh refreshes a token once it will expire
// within this window.
const defaultRefresh = 5 * time.Minute

// OAuth 2.0 "grant_type" values sent with token requests.
const (
	// OAuthGrantTypeDeviceCode is the "grant_type" identifier used in device flow
	OAuthGrantTypeDeviceCode = "device_code"

	// OAuthGrantTypeClientCredentials is the "grant_type" identifier used in credential flows
	OAuthGrantTypeClientCredentials = "client_credentials"

	// OAuthGrantTypeUserPass is the "grant_type" identifier used in username and password auth flows
	OAuthGrantTypeUserPass = "password"

	// OAuthGrantTypeRefreshToken is the "grant_type" identifier used in refresh token flows
	OAuthGrantTypeRefreshToken = "refresh_token"

	// OAuthGrantTypeAuthorizationCode is the "grant_type" identifier used in authorization code flows
	OAuthGrantTypeAuthorizationCode = "authorization_code"
)

// Managed Service Identity (MSI) settings for Azure virtual machines.
const (
	// metadataHeader is the header required by the MSI extension.
	metadataHeader = "Metadata"

	// msiEndpoint is the well-known endpoint for obtaining MSI authentication tokens.
	msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"

	// msiAPIVersion is the API version requested from msiEndpoint.
	msiAPIVersion = "2018-02-01"

	// defaultMaxMSIRefreshAttempts is the default number of attempts to refresh an MSI token.
	defaultMaxMSIRefreshAttempts = 5
)

// MSI settings for App Service and Azure Functions.
const (
	// asMSIEndpointEnv names the environment variable holding the App Service / Functions MSI endpoint.
	asMSIEndpointEnv = "MSI_ENDPOINT"

	// asMSISecretEnv names the environment variable holding the request secret for that endpoint.
	asMSISecretEnv = "MSI_SECRET"

	// appServiceAPIVersion is the API version requested from the App Service MSI endpoint.
	appServiceAPIVersion = "2017-09-01"

	// secretHeader is the request header that carries the App Service MSI secret.
	secretHeader = "Secret"

	// expiresOnDateFormat is the time layout of expires_on in UTC
	// (applicable to MSI via legacy ASE only).
	expiresOnDateFormat = "1/2/2006 15:04:05 PM +00:00"
)
87
// Interfaces for obtaining and refreshing OAuth tokens.
type (
	// OAuthTokenProvider is implemented by anything that can supply an OAuth access token.
	OAuthTokenProvider interface {
		OAuthToken() string
	}

	// MultitenantOAuthTokenProvider supplies the tokens used for multi-tenant
	// authorization: one primary token plus any auxiliary tokens.
	MultitenantOAuthTokenProvider interface {
		PrimaryOAuthToken() string
		AuxiliaryOAuthTokens() []string
	}

	// TokenRefreshError is implemented by errors returned when a token refresh
	// fails; Response exposes the HTTP response associated with the failure.
	TokenRefreshError interface {
		error
		Response() *http.Response
	}

	// Refresher is implemented by tokens that can be refreshed, exchanged for
	// another resource, or refreshed only when needed.
	Refresher interface {
		Refresh() error
		RefreshExchange(resource string) error
		EnsureFresh() error
	}

	// RefresherWithContext is the context-aware counterpart of Refresher.
	RefresherWithContext interface {
		RefreshWithContext(ctx context.Context) error
		RefreshExchangeWithContext(ctx context.Context, resource string) error
		EnsureFreshWithContext(ctx context.Context) error
	}
)
118
// Hooks used around token refreshes.
type (
	// TokenRefreshCallback is invoked after a successful token refresh; a
	// non-nil error it returns is reported to the caller of the refresh.
	TokenRefreshCallback func(Token) error

	// TokenRefresh is a caller-supplied function that obtains a new token for
	// resource, used in place of the built-in refresh request when set.
	TokenRefresh func(ctx context.Context, resource string) (*Token, error)
)

// Token encapsulates the access token used to authorize Azure requests.
// https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response
type Token struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`

	ExpiresIn json.Number `json:"expires_in"`
	ExpiresOn json.Number `json:"expires_on"`
	NotBefore json.Number `json:"not_before"`

	Resource string `json:"resource"`
	Type     string `json:"token_type"`
}

// newToken returns an otherwise empty Token whose numeric fields hold "0"
// instead of the empty string, so they contain valid numeric text.
func newToken() Token {
	const zero json.Number = "0"
	return Token{ExpiresIn: zero, ExpiresOn: zero, NotBefore: zero}
}

// IsZero reports whether every field of t holds its zero value.
func (t Token) IsZero() bool {
	var empty Token
	return t == empty
}
152
153// Expires returns the time.Time when the Token expires.
154func (t Token) Expires() time.Time {
155	s, err := t.ExpiresOn.Float64()
156	if err != nil {
157		s = -3600
158	}
159
160	expiration := date.NewUnixTimeFromSeconds(s)
161
162	return time.Time(expiration).UTC()
163}
164
165// IsExpired returns true if the Token is expired, false otherwise.
166func (t Token) IsExpired() bool {
167	return t.WillExpireIn(0)
168}
169
170// WillExpireIn returns true if the Token will expire after the passed time.Duration interval
171// from now, false otherwise.
172func (t Token) WillExpireIn(d time.Duration) bool {
173	return !t.Expires().After(time.Now().Add(d))
174}
175
176//OAuthToken return the current access token
177func (t *Token) OAuthToken() string {
178	return t.AccessToken
179}
180
181// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form
182// that is submitted when acquiring an oAuth token.
183type ServicePrincipalSecret interface {
184	SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error
185}
186
187// ServicePrincipalNoSecret represents a secret type that contains no secret
188// meaning it is not valid for fetching a fresh token. This is used by Manual
189type ServicePrincipalNoSecret struct {
190}
191
192// SetAuthenticationValues is a method of the interface ServicePrincipalSecret
193// It only returns an error for the ServicePrincipalNoSecret type
194func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
195	return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token")
196}
197
198// MarshalJSON implements the json.Marshaler interface.
199func (noSecret ServicePrincipalNoSecret) MarshalJSON() ([]byte, error) {
200	type tokenType struct {
201		Type string `json:"type"`
202	}
203	return json.Marshal(tokenType{
204		Type: "ServicePrincipalNoSecret",
205	})
206}
207
208// ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization.
209type ServicePrincipalTokenSecret struct {
210	ClientSecret string `json:"value"`
211}
212
213// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
214// It will populate the form submitted during oAuth Token Acquisition using the client_secret.
215func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
216	v.Set("client_secret", tokenSecret.ClientSecret)
217	return nil
218}
219
220// MarshalJSON implements the json.Marshaler interface.
221func (tokenSecret ServicePrincipalTokenSecret) MarshalJSON() ([]byte, error) {
222	type tokenType struct {
223		Type  string `json:"type"`
224		Value string `json:"value"`
225	}
226	return json.Marshal(tokenType{
227		Type:  "ServicePrincipalTokenSecret",
228		Value: tokenSecret.ClientSecret,
229	})
230}
231
232// ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs.
233type ServicePrincipalCertificateSecret struct {
234	Certificate *x509.Certificate
235	PrivateKey  *rsa.PrivateKey
236}
237
238// SignJwt returns the JWT signed with the certificate's private key.
239func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) {
240	hasher := sha1.New()
241	_, err := hasher.Write(secret.Certificate.Raw)
242	if err != nil {
243		return "", err
244	}
245
246	thumbprint := base64.URLEncoding.EncodeToString(hasher.Sum(nil))
247
248	// The jti (JWT ID) claim provides a unique identifier for the JWT.
249	jti := make([]byte, 20)
250	_, err = rand.Read(jti)
251	if err != nil {
252		return "", err
253	}
254
255	token := jwt.New(jwt.SigningMethodRS256)
256	token.Header["x5t"] = thumbprint
257	x5c := []string{base64.StdEncoding.EncodeToString(secret.Certificate.Raw)}
258	token.Header["x5c"] = x5c
259	token.Claims = jwt.MapClaims{
260		"aud": spt.inner.OauthConfig.TokenEndpoint.String(),
261		"iss": spt.inner.ClientID,
262		"sub": spt.inner.ClientID,
263		"jti": base64.URLEncoding.EncodeToString(jti),
264		"nbf": time.Now().Unix(),
265		"exp": time.Now().Add(24 * time.Hour).Unix(),
266	}
267
268	signedString, err := token.SignedString(secret.PrivateKey)
269	return signedString, err
270}
271
272// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
273// It will populate the form submitted during oAuth Token Acquisition using a JWT signed with a certificate.
274func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
275	jwt, err := secret.SignJwt(spt)
276	if err != nil {
277		return err
278	}
279
280	v.Set("client_assertion", jwt)
281	v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
282	return nil
283}
284
285// MarshalJSON implements the json.Marshaler interface.
286func (secret ServicePrincipalCertificateSecret) MarshalJSON() ([]byte, error) {
287	return nil, errors.New("marshalling ServicePrincipalCertificateSecret is not supported")
288}
289
290// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension.
291type ServicePrincipalMSISecret struct {
292}
293
294// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
295func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
296	return nil
297}
298
299// MarshalJSON implements the json.Marshaler interface.
300func (msiSecret ServicePrincipalMSISecret) MarshalJSON() ([]byte, error) {
301	return nil, errors.New("marshalling ServicePrincipalMSISecret is not supported")
302}
303
304// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth.
305type ServicePrincipalUsernamePasswordSecret struct {
306	Username string `json:"username"`
307	Password string `json:"password"`
308}
309
310// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
311func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
312	v.Set("username", secret.Username)
313	v.Set("password", secret.Password)
314	return nil
315}
316
317// MarshalJSON implements the json.Marshaler interface.
318func (secret ServicePrincipalUsernamePasswordSecret) MarshalJSON() ([]byte, error) {
319	type tokenType struct {
320		Type     string `json:"type"`
321		Username string `json:"username"`
322		Password string `json:"password"`
323	}
324	return json.Marshal(tokenType{
325		Type:     "ServicePrincipalUsernamePasswordSecret",
326		Username: secret.Username,
327		Password: secret.Password,
328	})
329}
330
331// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth.
332type ServicePrincipalAuthorizationCodeSecret struct {
333	ClientSecret      string `json:"value"`
334	AuthorizationCode string `json:"authCode"`
335	RedirectURI       string `json:"redirect"`
336}
337
338// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
339func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
340	v.Set("code", secret.AuthorizationCode)
341	v.Set("client_secret", secret.ClientSecret)
342	v.Set("redirect_uri", secret.RedirectURI)
343	return nil
344}
345
346// MarshalJSON implements the json.Marshaler interface.
347func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, error) {
348	type tokenType struct {
349		Type     string `json:"type"`
350		Value    string `json:"value"`
351		AuthCode string `json:"authCode"`
352		Redirect string `json:"redirect"`
353	}
354	return json.Marshal(tokenType{
355		Type:     "ServicePrincipalAuthorizationCodeSecret",
356		Value:    secret.ClientSecret,
357		AuthCode: secret.AuthorizationCode,
358		Redirect: secret.RedirectURI,
359	})
360}
361
362// ServicePrincipalToken encapsulates a Token created for a Service Principal.
363type ServicePrincipalToken struct {
364	inner             servicePrincipalToken
365	refreshLock       *sync.RWMutex
366	sender            Sender
367	customRefreshFunc TokenRefresh
368	refreshCallbacks  []TokenRefreshCallback
369	// MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token.
370	// Settings this to a value less than 1 will use the default value.
371	MaxMSIRefreshAttempts int
372}
373
374// MarshalTokenJSON returns the marshalled inner token.
375func (spt ServicePrincipalToken) MarshalTokenJSON() ([]byte, error) {
376	return json.Marshal(spt.inner.Token)
377}
378
379// SetRefreshCallbacks replaces any existing refresh callbacks with the specified callbacks.
380func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCallback) {
381	spt.refreshCallbacks = callbacks
382}
383
384// SetCustomRefreshFunc sets a custom refresh function used to refresh the token.
385func (spt *ServicePrincipalToken) SetCustomRefreshFunc(customRefreshFunc TokenRefresh) {
386	spt.customRefreshFunc = customRefreshFunc
387}
388
389// MarshalJSON implements the json.Marshaler interface.
390func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) {
391	return json.Marshal(spt.inner)
392}
393
394// UnmarshalJSON implements the json.Unmarshaler interface.
395func (spt *ServicePrincipalToken) UnmarshalJSON(data []byte) error {
396	// need to determine the token type
397	raw := map[string]interface{}{}
398	err := json.Unmarshal(data, &raw)
399	if err != nil {
400		return err
401	}
402	secret := raw["secret"].(map[string]interface{})
403	switch secret["type"] {
404	case "ServicePrincipalNoSecret":
405		spt.inner.Secret = &ServicePrincipalNoSecret{}
406	case "ServicePrincipalTokenSecret":
407		spt.inner.Secret = &ServicePrincipalTokenSecret{}
408	case "ServicePrincipalCertificateSecret":
409		return errors.New("unmarshalling ServicePrincipalCertificateSecret is not supported")
410	case "ServicePrincipalMSISecret":
411		return errors.New("unmarshalling ServicePrincipalMSISecret is not supported")
412	case "ServicePrincipalUsernamePasswordSecret":
413		spt.inner.Secret = &ServicePrincipalUsernamePasswordSecret{}
414	case "ServicePrincipalAuthorizationCodeSecret":
415		spt.inner.Secret = &ServicePrincipalAuthorizationCodeSecret{}
416	default:
417		return fmt.Errorf("unrecognized token type '%s'", secret["type"])
418	}
419	err = json.Unmarshal(data, &spt.inner)
420	if err != nil {
421		return err
422	}
423	// Don't override the refreshLock or the sender if those have been already set.
424	if spt.refreshLock == nil {
425		spt.refreshLock = &sync.RWMutex{}
426	}
427	if spt.sender == nil {
428		spt.sender = sender()
429	}
430	return nil
431}
432
// servicePrincipalToken is the internal, serializable state of a
// ServicePrincipalToken. ServicePrincipalToken.MarshalJSON encodes it and
// UnmarshalJSON decodes into it.
// NOTE: field order determines the key order of the marshalled JSON; keep it stable.
type servicePrincipalToken struct {
	Token         Token                  `json:"token"`
	Secret        ServicePrincipalSecret `json:"secret"`
	OauthConfig   OAuthConfig            `json:"oauth"`
	ClientID      string                 `json:"clientID"`
	Resource      string                 `json:"resource"`
	AutoRefresh   bool                   `json:"autoRefresh"`
	RefreshWithin time.Duration          `json:"refreshWithin"`
}
443
444func validateOAuthConfig(oac OAuthConfig) error {
445	if oac.IsZero() {
446		return fmt.Errorf("parameter 'oauthConfig' cannot be zero-initialized")
447	}
448	return nil
449}
450
451// NewServicePrincipalTokenWithSecret create a ServicePrincipalToken using the supplied ServicePrincipalSecret implementation.
452func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
453	if err := validateOAuthConfig(oauthConfig); err != nil {
454		return nil, err
455	}
456	if err := validateStringParam(id, "id"); err != nil {
457		return nil, err
458	}
459	if err := validateStringParam(resource, "resource"); err != nil {
460		return nil, err
461	}
462	if secret == nil {
463		return nil, fmt.Errorf("parameter 'secret' cannot be nil")
464	}
465	spt := &ServicePrincipalToken{
466		inner: servicePrincipalToken{
467			Token:         newToken(),
468			OauthConfig:   oauthConfig,
469			Secret:        secret,
470			ClientID:      id,
471			Resource:      resource,
472			AutoRefresh:   true,
473			RefreshWithin: defaultRefresh,
474		},
475		refreshLock:      &sync.RWMutex{},
476		sender:           sender(),
477		refreshCallbacks: callbacks,
478	}
479	return spt, nil
480}
481
482// NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken using the supplied token
483func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
484	if err := validateOAuthConfig(oauthConfig); err != nil {
485		return nil, err
486	}
487	if err := validateStringParam(clientID, "clientID"); err != nil {
488		return nil, err
489	}
490	if err := validateStringParam(resource, "resource"); err != nil {
491		return nil, err
492	}
493	if token.IsZero() {
494		return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
495	}
496	spt, err := NewServicePrincipalTokenWithSecret(
497		oauthConfig,
498		clientID,
499		resource,
500		&ServicePrincipalNoSecret{},
501		callbacks...)
502	if err != nil {
503		return nil, err
504	}
505
506	spt.inner.Token = token
507
508	return spt, nil
509}
510
511// NewServicePrincipalTokenFromManualTokenSecret creates a ServicePrincipalToken using the supplied token and secret
512func NewServicePrincipalTokenFromManualTokenSecret(oauthConfig OAuthConfig, clientID string, resource string, token Token, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
513	if err := validateOAuthConfig(oauthConfig); err != nil {
514		return nil, err
515	}
516	if err := validateStringParam(clientID, "clientID"); err != nil {
517		return nil, err
518	}
519	if err := validateStringParam(resource, "resource"); err != nil {
520		return nil, err
521	}
522	if secret == nil {
523		return nil, fmt.Errorf("parameter 'secret' cannot be nil")
524	}
525	if token.IsZero() {
526		return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
527	}
528	spt, err := NewServicePrincipalTokenWithSecret(
529		oauthConfig,
530		clientID,
531		resource,
532		secret,
533		callbacks...)
534	if err != nil {
535		return nil, err
536	}
537
538	spt.inner.Token = token
539
540	return spt, nil
541}
542
543// NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal
544// credentials scoped to the named resource.
545func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
546	if err := validateOAuthConfig(oauthConfig); err != nil {
547		return nil, err
548	}
549	if err := validateStringParam(clientID, "clientID"); err != nil {
550		return nil, err
551	}
552	if err := validateStringParam(secret, "secret"); err != nil {
553		return nil, err
554	}
555	if err := validateStringParam(resource, "resource"); err != nil {
556		return nil, err
557	}
558	return NewServicePrincipalTokenWithSecret(
559		oauthConfig,
560		clientID,
561		resource,
562		&ServicePrincipalTokenSecret{
563			ClientSecret: secret,
564		},
565		callbacks...,
566	)
567}
568
569// NewServicePrincipalTokenFromCertificate creates a ServicePrincipalToken from the supplied pkcs12 bytes.
570func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
571	if err := validateOAuthConfig(oauthConfig); err != nil {
572		return nil, err
573	}
574	if err := validateStringParam(clientID, "clientID"); err != nil {
575		return nil, err
576	}
577	if err := validateStringParam(resource, "resource"); err != nil {
578		return nil, err
579	}
580	if certificate == nil {
581		return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
582	}
583	if privateKey == nil {
584		return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
585	}
586	return NewServicePrincipalTokenWithSecret(
587		oauthConfig,
588		clientID,
589		resource,
590		&ServicePrincipalCertificateSecret{
591			PrivateKey:  privateKey,
592			Certificate: certificate,
593		},
594		callbacks...,
595	)
596}
597
598// NewServicePrincipalTokenFromUsernamePassword creates a ServicePrincipalToken from the username and password.
599func NewServicePrincipalTokenFromUsernamePassword(oauthConfig OAuthConfig, clientID string, username string, password string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
600	if err := validateOAuthConfig(oauthConfig); err != nil {
601		return nil, err
602	}
603	if err := validateStringParam(clientID, "clientID"); err != nil {
604		return nil, err
605	}
606	if err := validateStringParam(username, "username"); err != nil {
607		return nil, err
608	}
609	if err := validateStringParam(password, "password"); err != nil {
610		return nil, err
611	}
612	if err := validateStringParam(resource, "resource"); err != nil {
613		return nil, err
614	}
615	return NewServicePrincipalTokenWithSecret(
616		oauthConfig,
617		clientID,
618		resource,
619		&ServicePrincipalUsernamePasswordSecret{
620			Username: username,
621			Password: password,
622		},
623		callbacks...,
624	)
625}
626
627// NewServicePrincipalTokenFromAuthorizationCode creates a ServicePrincipalToken from the
628func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clientID string, clientSecret string, authorizationCode string, redirectURI string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
629
630	if err := validateOAuthConfig(oauthConfig); err != nil {
631		return nil, err
632	}
633	if err := validateStringParam(clientID, "clientID"); err != nil {
634		return nil, err
635	}
636	if err := validateStringParam(clientSecret, "clientSecret"); err != nil {
637		return nil, err
638	}
639	if err := validateStringParam(authorizationCode, "authorizationCode"); err != nil {
640		return nil, err
641	}
642	if err := validateStringParam(redirectURI, "redirectURI"); err != nil {
643		return nil, err
644	}
645	if err := validateStringParam(resource, "resource"); err != nil {
646		return nil, err
647	}
648
649	return NewServicePrincipalTokenWithSecret(
650		oauthConfig,
651		clientID,
652		resource,
653		&ServicePrincipalAuthorizationCodeSecret{
654			ClientSecret:      clientSecret,
655			AuthorizationCode: authorizationCode,
656			RedirectURI:       redirectURI,
657		},
658		callbacks...,
659	)
660}
661
662// GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines.
663func GetMSIVMEndpoint() (string, error) {
664	return msiEndpoint, nil
665}
666
667// NOTE: this only indicates if the ASE environment credentials have been set
668// which does not necessarily mean that the caller is authenticating via ASE!
669func isAppService() bool {
670	_, asMSIEndpointEnvExists := os.LookupEnv(asMSIEndpointEnv)
671	_, asMSISecretEnvExists := os.LookupEnv(asMSISecretEnv)
672
673	return asMSIEndpointEnvExists && asMSISecretEnvExists
674}
675
676// GetMSIAppServiceEndpoint get the MSI endpoint for App Service and Functions
677func GetMSIAppServiceEndpoint() (string, error) {
678	asMSIEndpoint, asMSIEndpointEnvExists := os.LookupEnv(asMSIEndpointEnv)
679
680	if asMSIEndpointEnvExists {
681		return asMSIEndpoint, nil
682	}
683	return "", errors.New("MSI endpoint not found")
684}
685
686// GetMSIEndpoint get the appropriate MSI endpoint depending on the runtime environment
687func GetMSIEndpoint() (string, error) {
688	if isAppService() {
689		return GetMSIAppServiceEndpoint()
690	}
691	return GetMSIVMEndpoint()
692}
693
694// NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension.
695// It will use the system assigned identity when creating the token.
696func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
697	return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, nil, callbacks...)
698}
699
700// NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension.
701// It will use the clientID of specified user assigned identity when creating the token.
702func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
703	return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &userAssignedID, nil, callbacks...)
704}
705
706// NewServicePrincipalTokenFromMSIWithIdentityResourceID creates a ServicePrincipalToken via the MSI VM Extension.
707// It will use the azure resource id of user assigned identity when creating the token.
708func NewServicePrincipalTokenFromMSIWithIdentityResourceID(msiEndpoint, resource string, identityResourceID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
709	return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, &identityResourceID, callbacks...)
710}
711
712func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, identityResourceID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
713	if err := validateStringParam(msiEndpoint, "msiEndpoint"); err != nil {
714		return nil, err
715	}
716	if err := validateStringParam(resource, "resource"); err != nil {
717		return nil, err
718	}
719	if userAssignedID != nil {
720		if err := validateStringParam(*userAssignedID, "userAssignedID"); err != nil {
721			return nil, err
722		}
723	}
724	if identityResourceID != nil {
725		if err := validateStringParam(*identityResourceID, "identityResourceID"); err != nil {
726			return nil, err
727		}
728	}
729	// We set the oauth config token endpoint to be MSI's endpoint
730	msiEndpointURL, err := url.Parse(msiEndpoint)
731	if err != nil {
732		return nil, err
733	}
734
735	v := url.Values{}
736	v.Set("resource", resource)
737	// we only support token API version 2017-09-01 for app services
738	clientIDParam := "client_id"
739	if isASEEndpoint(*msiEndpointURL) {
740		v.Set("api-version", appServiceAPIVersion)
741		clientIDParam = "clientid"
742	} else {
743		v.Set("api-version", msiAPIVersion)
744	}
745	if userAssignedID != nil {
746		v.Set(clientIDParam, *userAssignedID)
747	}
748	if identityResourceID != nil {
749		v.Set("mi_res_id", *identityResourceID)
750	}
751	msiEndpointURL.RawQuery = v.Encode()
752
753	spt := &ServicePrincipalToken{
754		inner: servicePrincipalToken{
755			Token: newToken(),
756			OauthConfig: OAuthConfig{
757				TokenEndpoint: *msiEndpointURL,
758			},
759			Secret:        &ServicePrincipalMSISecret{},
760			Resource:      resource,
761			AutoRefresh:   true,
762			RefreshWithin: defaultRefresh,
763		},
764		refreshLock:           &sync.RWMutex{},
765		sender:                sender(),
766		refreshCallbacks:      callbacks,
767		MaxMSIRefreshAttempts: defaultMaxMSIRefreshAttempts,
768	}
769
770	if userAssignedID != nil {
771		spt.inner.ClientID = *userAssignedID
772	}
773
774	return spt, nil
775}
776
777// internal type that implements TokenRefreshError
778type tokenRefreshError struct {
779	message string
780	resp    *http.Response
781}
782
783// Error implements the error interface which is part of the TokenRefreshError interface.
784func (tre tokenRefreshError) Error() string {
785	return tre.message
786}
787
788// Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation.
789func (tre tokenRefreshError) Response() *http.Response {
790	return tre.resp
791}
792
793func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError {
794	return tokenRefreshError{message: message, resp: resp}
795}
796
797// EnsureFresh will refresh the token if it will expire within the refresh window (as set by
798// RefreshWithin) and autoRefresh flag is on.  This method is safe for concurrent use.
799func (spt *ServicePrincipalToken) EnsureFresh() error {
800	return spt.EnsureFreshWithContext(context.Background())
801}
802
803// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by
804// RefreshWithin) and autoRefresh flag is on.  This method is safe for concurrent use.
805func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
806	// must take the read lock when initially checking the token's expiration
807	if spt.inner.AutoRefresh && spt.Token().WillExpireIn(spt.inner.RefreshWithin) {
808		// take the write lock then check again to see if the token was already refreshed
809		spt.refreshLock.Lock()
810		defer spt.refreshLock.Unlock()
811		if spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) {
812			return spt.refreshInternal(ctx, spt.inner.Resource)
813		}
814	}
815	return nil
816}
817
818// InvokeRefreshCallbacks calls any TokenRefreshCallbacks that were added to the SPT during initialization
819func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
820	if spt.refreshCallbacks != nil {
821		for _, callback := range spt.refreshCallbacks {
822			err := callback(spt.inner.Token)
823			if err != nil {
824				return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err)
825			}
826		}
827	}
828	return nil
829}
830
831// Refresh obtains a fresh token for the Service Principal.
832// This method is safe for concurrent use.
833func (spt *ServicePrincipalToken) Refresh() error {
834	return spt.RefreshWithContext(context.Background())
835}
836
837// RefreshWithContext obtains a fresh token for the Service Principal.
838// This method is safe for concurrent use.
839func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
840	spt.refreshLock.Lock()
841	defer spt.refreshLock.Unlock()
842	return spt.refreshInternal(ctx, spt.inner.Resource)
843}
844
845// RefreshExchange refreshes the token, but for a different resource.
846// This method is safe for concurrent use.
847func (spt *ServicePrincipalToken) RefreshExchange(resource string) error {
848	return spt.RefreshExchangeWithContext(context.Background(), resource)
849}
850
851// RefreshExchangeWithContext refreshes the token, but for a different resource.
852// This method is safe for concurrent use.
853func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
854	spt.refreshLock.Lock()
855	defer spt.refreshLock.Unlock()
856	return spt.refreshInternal(ctx, resource)
857}
858
859func (spt *ServicePrincipalToken) getGrantType() string {
860	switch spt.inner.Secret.(type) {
861	case *ServicePrincipalUsernamePasswordSecret:
862		return OAuthGrantTypeUserPass
863	case *ServicePrincipalAuthorizationCodeSecret:
864		return OAuthGrantTypeAuthorizationCode
865	default:
866		return OAuthGrantTypeClientCredentials
867	}
868}
869
870func isIMDS(u url.URL) bool {
871	return isMSIEndpoint(u) == true || isASEEndpoint(u) == true
872}
873
874func isMSIEndpoint(endpoint url.URL) bool {
875	msi, err := url.Parse(msiEndpoint)
876	if err != nil {
877		return false
878	}
879	return endpoint.Host == msi.Host && endpoint.Path == msi.Path
880}
881
882func isASEEndpoint(endpoint url.URL) bool {
883	aseEndpoint, err := GetMSIAppServiceEndpoint()
884	if err != nil {
885		// app service environment isn't enabled
886		return false
887	}
888	ase, err := url.Parse(aseEndpoint)
889	if err != nil {
890		return false
891	}
892	return endpoint.Host == ase.Host && endpoint.Path == ase.Path
893}
894
895func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
896	if spt.customRefreshFunc != nil {
897		token, err := spt.customRefreshFunc(ctx, resource)
898		if err != nil {
899			return err
900		}
901		spt.inner.Token = *token
902		return spt.InvokeRefreshCallbacks(spt.inner.Token)
903	}
904	req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil)
905	if err != nil {
906		return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
907	}
908	req.Header.Add("User-Agent", UserAgent())
909	// Add header when runtime is on App Service or Functions
910	if isASEEndpoint(spt.inner.OauthConfig.TokenEndpoint) {
911		asMSISecret, _ := os.LookupEnv(asMSISecretEnv)
912		req.Header.Add(secretHeader, asMSISecret)
913	}
914	req = req.WithContext(ctx)
915	if !isIMDS(spt.inner.OauthConfig.TokenEndpoint) {
916		v := url.Values{}
917		v.Set("client_id", spt.inner.ClientID)
918		v.Set("resource", resource)
919
920		if spt.inner.Token.RefreshToken != "" {
921			v.Set("grant_type", OAuthGrantTypeRefreshToken)
922			v.Set("refresh_token", spt.inner.Token.RefreshToken)
923			// web apps must specify client_secret when refreshing tokens
924			// see https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-protocols-oauth-code#refreshing-the-access-tokens
925			if spt.getGrantType() == OAuthGrantTypeAuthorizationCode {
926				err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
927				if err != nil {
928					return err
929				}
930			}
931		} else {
932			v.Set("grant_type", spt.getGrantType())
933			err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
934			if err != nil {
935				return err
936			}
937		}
938
939		s := v.Encode()
940		body := ioutil.NopCloser(strings.NewReader(s))
941		req.ContentLength = int64(len(s))
942		req.Header.Set(contentType, mimeTypeFormPost)
943		req.Body = body
944	}
945
946	if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok {
947		req.Method = http.MethodGet
948		// the metadata header isn't applicable for ASE
949		if !isASEEndpoint(spt.inner.OauthConfig.TokenEndpoint) {
950			req.Header.Set(metadataHeader, "true")
951		}
952	}
953
954	var resp *http.Response
955	if isMSIEndpoint(spt.inner.OauthConfig.TokenEndpoint) {
956		resp, err = getMSIEndpoint(ctx, spt.sender)
957		if err != nil {
958			// return a TokenRefreshError here so that we don't keep retrying
959			return newTokenRefreshError(fmt.Sprintf("the MSI endpoint is not available. Failed HTTP request to MSI endpoint: %v", err), nil)
960		}
961		resp.Body.Close()
962	}
963	if isIMDS(spt.inner.OauthConfig.TokenEndpoint) {
964		resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts)
965	} else {
966		resp, err = spt.sender.Do(req)
967	}
968	if err != nil {
969		// don't return a TokenRefreshError here; this will allow retry logic to apply
970		return fmt.Errorf("adal: Failed to execute the refresh request. Error = '%v'", err)
971	}
972
973	defer resp.Body.Close()
974	rb, err := ioutil.ReadAll(resp.Body)
975
976	if resp.StatusCode != http.StatusOK {
977		if err != nil {
978			return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v Endpoint %s", resp.StatusCode, err, req.URL.String()), resp)
979		}
980		return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s Endpoint %s", resp.StatusCode, string(rb), req.URL.String()), resp)
981	}
982
983	// for the following error cases don't return a TokenRefreshError.  the operation succeeded
984	// but some transient failure happened during deserialization.  by returning a generic error
985	// the retry logic will kick in (we don't retry on TokenRefreshError).
986
987	if err != nil {
988		return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err)
989	}
990	if len(strings.Trim(string(rb), " ")) == 0 {
991		return fmt.Errorf("adal: Empty service principal token received during refresh")
992	}
993	token := struct {
994		AccessToken  string `json:"access_token"`
995		RefreshToken string `json:"refresh_token"`
996
997		// AAD returns expires_in as a string, ADFS returns it as an int
998		ExpiresIn json.Number `json:"expires_in"`
999		// expires_on can be in two formats, a UTC time stamp or the number of seconds.
1000		ExpiresOn string      `json:"expires_on"`
1001		NotBefore json.Number `json:"not_before"`
1002
1003		Resource string `json:"resource"`
1004		Type     string `json:"token_type"`
1005	}{}
1006	err = json.Unmarshal(rb, &token)
1007	if err != nil {
1008		return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb))
1009	}
1010	expiresOn, err := parseExpiresOn(token.ExpiresOn)
1011	if err != nil {
1012		return err
1013	}
1014	spt.inner.Token.AccessToken = token.AccessToken
1015	spt.inner.Token.RefreshToken = token.RefreshToken
1016	spt.inner.Token.ExpiresIn = token.ExpiresIn
1017	spt.inner.Token.ExpiresOn = expiresOn
1018	spt.inner.Token.NotBefore = token.NotBefore
1019	spt.inner.Token.Resource = token.Resource
1020	spt.inner.Token.Type = token.Type
1021
1022	return spt.InvokeRefreshCallbacks(spt.inner.Token)
1023}
1024
1025// converts expires_on to the number of seconds
1026func parseExpiresOn(s string) (json.Number, error) {
1027	if _, err := strconv.ParseInt(s, 10, 64); err == nil {
1028		// this is the number of seconds case, no conversion required
1029		return json.Number(s), nil
1030	} else if eo, err := time.Parse(expiresOnDateFormat, s); err == nil {
1031		// convert the expiration date to the number of seconds from now
1032		dur := eo.Sub(time.Now().UTC())
1033		return json.Number(strconv.FormatInt(int64(dur.Round(time.Second).Seconds()), 10)), nil
1034	} else {
1035		// unknown format
1036		return json.Number(""), err
1037	}
1038}
1039
1040// retry logic specific to retrieving a token from the IMDS endpoint
1041func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http.Response, err error) {
1042	// copied from client.go due to circular dependency
1043	retries := []int{
1044		http.StatusRequestTimeout,      // 408
1045		http.StatusTooManyRequests,     // 429
1046		http.StatusInternalServerError, // 500
1047		http.StatusBadGateway,          // 502
1048		http.StatusServiceUnavailable,  // 503
1049		http.StatusGatewayTimeout,      // 504
1050	}
1051	// extra retry status codes specific to IMDS
1052	retries = append(retries,
1053		http.StatusNotFound,
1054		http.StatusGone,
1055		// all remaining 5xx
1056		http.StatusNotImplemented,
1057		http.StatusHTTPVersionNotSupported,
1058		http.StatusVariantAlsoNegotiates,
1059		http.StatusInsufficientStorage,
1060		http.StatusLoopDetected,
1061		http.StatusNotExtended,
1062		http.StatusNetworkAuthenticationRequired)
1063
1064	// see https://docs.microsoft.com/en-us/azure/active-directory/managed-service-identity/how-to-use-vm-token#retry-guidance
1065
1066	const maxDelay time.Duration = 60 * time.Second
1067
1068	attempt := 0
1069	delay := time.Duration(0)
1070
1071	// maxAttempts is user-specified, ensure that its value is greater than zero else no request will be made
1072	if maxAttempts < 1 {
1073		maxAttempts = defaultMaxMSIRefreshAttempts
1074	}
1075
1076	for attempt < maxAttempts {
1077		if resp != nil && resp.Body != nil {
1078			io.Copy(ioutil.Discard, resp.Body)
1079			resp.Body.Close()
1080		}
1081		resp, err = sender.Do(req)
1082		// we want to retry if err is not nil or the status code is in the list of retry codes
1083		if err == nil && !responseHasStatusCode(resp, retries...) {
1084			return
1085		}
1086
1087		// perform exponential backoff with a cap.
1088		// must increment attempt before calculating delay.
1089		attempt++
1090		// the base value of 2 is the "delta backoff" as specified in the guidance doc
1091		delay += (time.Duration(math.Pow(2, float64(attempt))) * time.Second)
1092		if delay > maxDelay {
1093			delay = maxDelay
1094		}
1095
1096		select {
1097		case <-time.After(delay):
1098			// intentionally left blank
1099		case <-req.Context().Done():
1100			err = req.Context().Err()
1101			return
1102		}
1103	}
1104	return
1105}
1106
// responseHasStatusCode reports whether resp is non-nil and its StatusCode equals
// any of codes. A nil response, or an empty codes list, never matches.
func responseHasStatusCode(resp *http.Response, codes ...int) bool {
	if resp == nil {
		return false
	}
	for idx := range codes {
		if codes[idx] == resp.StatusCode {
			return true
		}
	}
	return false
}
1117
1118// SetAutoRefresh enables or disables automatic refreshing of stale tokens.
1119func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) {
1120	spt.inner.AutoRefresh = autoRefresh
1121}
1122
1123// SetRefreshWithin sets the interval within which if the token will expire, EnsureFresh will
1124// refresh the token.
1125func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) {
1126	spt.inner.RefreshWithin = d
1127	return
1128}
1129
1130// SetSender sets the http.Client used when obtaining the Service Principal token. An
1131// undecorated http.Client is used by default.
1132func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s }
1133
1134// OAuthToken implements the OAuthTokenProvider interface.  It returns the current access token.
1135func (spt *ServicePrincipalToken) OAuthToken() string {
1136	spt.refreshLock.RLock()
1137	defer spt.refreshLock.RUnlock()
1138	return spt.inner.Token.OAuthToken()
1139}
1140
1141// Token returns a copy of the current token.
1142func (spt *ServicePrincipalToken) Token() Token {
1143	spt.refreshLock.RLock()
1144	defer spt.refreshLock.RUnlock()
1145	return spt.inner.Token
1146}
1147
1148// MultiTenantServicePrincipalToken contains tokens for multi-tenant authorization.
1149type MultiTenantServicePrincipalToken struct {
1150	PrimaryToken    *ServicePrincipalToken
1151	AuxiliaryTokens []*ServicePrincipalToken
1152}
1153
1154// PrimaryOAuthToken returns the primary authorization token.
1155func (mt *MultiTenantServicePrincipalToken) PrimaryOAuthToken() string {
1156	return mt.PrimaryToken.OAuthToken()
1157}
1158
1159// AuxiliaryOAuthTokens returns one to three auxiliary authorization tokens.
1160func (mt *MultiTenantServicePrincipalToken) AuxiliaryOAuthTokens() []string {
1161	tokens := make([]string, len(mt.AuxiliaryTokens))
1162	for i := range mt.AuxiliaryTokens {
1163		tokens[i] = mt.AuxiliaryTokens[i].OAuthToken()
1164	}
1165	return tokens
1166}
1167
1168// NewMultiTenantServicePrincipalToken creates a new MultiTenantServicePrincipalToken with the specified credentials and resource.
1169func NewMultiTenantServicePrincipalToken(multiTenantCfg MultiTenantOAuthConfig, clientID string, secret string, resource string) (*MultiTenantServicePrincipalToken, error) {
1170	if err := validateStringParam(clientID, "clientID"); err != nil {
1171		return nil, err
1172	}
1173	if err := validateStringParam(secret, "secret"); err != nil {
1174		return nil, err
1175	}
1176	if err := validateStringParam(resource, "resource"); err != nil {
1177		return nil, err
1178	}
1179	auxTenants := multiTenantCfg.AuxiliaryTenants()
1180	m := MultiTenantServicePrincipalToken{
1181		AuxiliaryTokens: make([]*ServicePrincipalToken, len(auxTenants)),
1182	}
1183	primary, err := NewServicePrincipalToken(*multiTenantCfg.PrimaryTenant(), clientID, secret, resource)
1184	if err != nil {
1185		return nil, fmt.Errorf("failed to create SPT for primary tenant: %v", err)
1186	}
1187	m.PrimaryToken = primary
1188	for i := range auxTenants {
1189		aux, err := NewServicePrincipalToken(*auxTenants[i], clientID, secret, resource)
1190		if err != nil {
1191			return nil, fmt.Errorf("failed to create SPT for auxiliary tenant: %v", err)
1192		}
1193		m.AuxiliaryTokens[i] = aux
1194	}
1195	return &m, nil
1196}
1197
1198// NewMultiTenantServicePrincipalTokenFromCertificate creates a new MultiTenantServicePrincipalToken with the specified certificate credentials and resource.
1199func NewMultiTenantServicePrincipalTokenFromCertificate(multiTenantCfg MultiTenantOAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string) (*MultiTenantServicePrincipalToken, error) {
1200	if err := validateStringParam(clientID, "clientID"); err != nil {
1201		return nil, err
1202	}
1203	if err := validateStringParam(resource, "resource"); err != nil {
1204		return nil, err
1205	}
1206	if certificate == nil {
1207		return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
1208	}
1209	if privateKey == nil {
1210		return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
1211	}
1212	auxTenants := multiTenantCfg.AuxiliaryTenants()
1213	m := MultiTenantServicePrincipalToken{
1214		AuxiliaryTokens: make([]*ServicePrincipalToken, len(auxTenants)),
1215	}
1216	primary, err := NewServicePrincipalTokenWithSecret(
1217		*multiTenantCfg.PrimaryTenant(),
1218		clientID,
1219		resource,
1220		&ServicePrincipalCertificateSecret{
1221			PrivateKey:  privateKey,
1222			Certificate: certificate,
1223		},
1224	)
1225	if err != nil {
1226		return nil, fmt.Errorf("failed to create SPT for primary tenant: %v", err)
1227	}
1228	m.PrimaryToken = primary
1229	for i := range auxTenants {
1230		aux, err := NewServicePrincipalTokenWithSecret(
1231			*auxTenants[i],
1232			clientID,
1233			resource,
1234			&ServicePrincipalCertificateSecret{
1235				PrivateKey:  privateKey,
1236				Certificate: certificate,
1237			},
1238		)
1239		if err != nil {
1240			return nil, fmt.Errorf("failed to create SPT for auxiliary tenant: %v", err)
1241		}
1242		m.AuxiliaryTokens[i] = aux
1243	}
1244	return &m, nil
1245}
1246
1247// MSIAvailable returns true if the MSI endpoint is available for authentication.
1248func MSIAvailable(ctx context.Context, sender Sender) bool {
1249	resp, err := getMSIEndpoint(ctx, sender)
1250	if err == nil {
1251		resp.Body.Close()
1252	}
1253	return err == nil
1254}
1255