1package adal
2
3// Copyright 2017 Microsoft Corporation
4//
5//  Licensed under the Apache License, Version 2.0 (the "License");
6//  you may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at
8//
9//      http://www.apache.org/licenses/LICENSE-2.0
10//
11//  Unless required by applicable law or agreed to in writing, software
12//  distributed under the License is distributed on an "AS IS" BASIS,
13//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14//  See the License for the specific language governing permissions and
15//  limitations under the License.
16
17import (
18	"context"
19	"crypto/rand"
20	"crypto/rsa"
21	"crypto/sha1"
22	"crypto/x509"
23	"encoding/base64"
24	"encoding/json"
25	"errors"
26	"fmt"
27	"io"
28	"io/ioutil"
29	"math"
30	"net/http"
31	"net/url"
32	"os"
33	"strconv"
34	"strings"
35	"sync"
36	"time"
37
38	"github.com/Azure/go-autorest/autorest/date"
39	"github.com/Azure/go-autorest/logger"
40	"github.com/form3tech-oss/jwt-go"
41)
42
const (
	// defaultRefresh is the RefreshWithin window assigned by the constructors: with AutoRefresh
	// on, a token is refreshed once it will expire within this duration.
	defaultRefresh = 5 * time.Minute

	// OAuthGrantTypeDeviceCode is the "grant_type" identifier used in device flow
	OAuthGrantTypeDeviceCode = "device_code"

	// OAuthGrantTypeClientCredentials is the "grant_type" identifier used in credential flows
	OAuthGrantTypeClientCredentials = "client_credentials"

	// OAuthGrantTypeUserPass is the "grant_type" identifier used in username and password auth flows
	OAuthGrantTypeUserPass = "password"

	// OAuthGrantTypeRefreshToken is the "grant_type" identifier used in refresh token flows
	OAuthGrantTypeRefreshToken = "refresh_token"

	// OAuthGrantTypeAuthorizationCode is the "grant_type" identifier used in authorization code flows
	OAuthGrantTypeAuthorizationCode = "authorization_code"

	// metadataHeader is the header required by the MSI extension
	metadataHeader = "Metadata"

	// msiEndpoint is the well-known IMDS endpoint for getting MSI authentication tokens
	msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"

	// msiAPIVersion is the API version to use for the MSI (IMDS) endpoint
	msiAPIVersion = "2018-02-01"

	// defaultMaxMSIRefreshAttempts is the default number of attempts to refresh an MSI authentication token
	defaultMaxMSIRefreshAttempts = 5

	// msiEndpointEnv is the environment variable used to store the endpoint on App Service and Functions.
	// When it is set without msiSecretEnv, the environment is treated as Cloud Shell.
	msiEndpointEnv = "MSI_ENDPOINT"

	// msiSecretEnv is the environment variable used to store the request secret on App Service and Functions
	msiSecretEnv = "MSI_SECRET"

	// appServiceAPIVersion2017 is the API version to use for the legacy App Service MSI endpoint
	appServiceAPIVersion2017 = "2017-09-01"

	// secretHeader is the header carrying the secret when authenticating against the App Service MSI endpoint
	secretHeader = "Secret"

	// expiresOnDateFormatPM is the format for expires_on in UTC with AM/PM.
	// Note that it combines the 24-hour "15" hour field with the "PM" marker.
	expiresOnDateFormatPM = "1/2/2006 15:04:05 PM +00:00"

	// expiresOnDateFormat is the format for expires_on in UTC without AM/PM
	expiresOnDateFormat = "1/2/2006 15:04:05 +00:00"
)
91
92// OAuthTokenProvider is an interface which should be implemented by an access token retriever
93type OAuthTokenProvider interface {
94	OAuthToken() string
95}
96
97// MultitenantOAuthTokenProvider provides tokens used for multi-tenant authorization.
98type MultitenantOAuthTokenProvider interface {
99	PrimaryOAuthToken() string
100	AuxiliaryOAuthTokens() []string
101}
102
103// TokenRefreshError is an interface used by errors returned during token refresh.
104type TokenRefreshError interface {
105	error
106	Response() *http.Response
107}
108
109// Refresher is an interface for token refresh functionality
110type Refresher interface {
111	Refresh() error
112	RefreshExchange(resource string) error
113	EnsureFresh() error
114}
115
116// RefresherWithContext is an interface for token refresh functionality
117type RefresherWithContext interface {
118	RefreshWithContext(ctx context.Context) error
119	RefreshExchangeWithContext(ctx context.Context, resource string) error
120	EnsureFreshWithContext(ctx context.Context) error
121}
122
123// TokenRefreshCallback is the type representing callbacks that will be called after
124// a successful token refresh
125type TokenRefreshCallback func(Token) error
126
127// TokenRefresh is a type representing a custom callback to refresh a token
128type TokenRefresh func(ctx context.Context, resource string) (*Token, error)
129
// Token encapsulates the access token used to authorize Azure requests.
// https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response
//
// The JSON tags follow the token response format linked above. The declaration order is also
// the order in which json.Marshal writes the fields, so it should not be rearranged casually.
type Token struct {
	// AccessToken is the value returned by OAuthToken; RefreshToken holds "refresh_token".
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`

	// The timing fields are kept as json.Number. Expires parses ExpiresOn as a number of
	// seconds (via date.NewUnixTimeFromSeconds); newToken initializes all three to "0".
	ExpiresIn json.Number `json:"expires_in"`
	ExpiresOn json.Number `json:"expires_on"`
	NotBefore json.Number `json:"not_before"`

	// Resource holds "resource" and Type holds "token_type" from the JSON.
	Resource string `json:"resource"`
	Type     string `json:"token_type"`
}
143
144func newToken() Token {
145	return Token{
146		ExpiresIn: "0",
147		ExpiresOn: "0",
148		NotBefore: "0",
149	}
150}
151
152// IsZero returns true if the token object is zero-initialized.
153func (t Token) IsZero() bool {
154	return t == Token{}
155}
156
157// Expires returns the time.Time when the Token expires.
158func (t Token) Expires() time.Time {
159	s, err := t.ExpiresOn.Float64()
160	if err != nil {
161		s = -3600
162	}
163
164	expiration := date.NewUnixTimeFromSeconds(s)
165
166	return time.Time(expiration).UTC()
167}
168
169// IsExpired returns true if the Token is expired, false otherwise.
170func (t Token) IsExpired() bool {
171	return t.WillExpireIn(0)
172}
173
174// WillExpireIn returns true if the Token will expire after the passed time.Duration interval
175// from now, false otherwise.
176func (t Token) WillExpireIn(d time.Duration) bool {
177	return !t.Expires().After(time.Now().Add(d))
178}
179
180//OAuthToken return the current access token
181func (t *Token) OAuthToken() string {
182	return t.AccessToken
183}
184
185// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form
186// that is submitted when acquiring an oAuth token.
187type ServicePrincipalSecret interface {
188	SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error
189}
190
// ServicePrincipalNoSecret represents a secret type that contains no secret,
// meaning it is not valid for fetching a fresh token. It is used by
// NewServicePrincipalTokenFromManualToken.
type ServicePrincipalNoSecret struct{}
195
196// SetAuthenticationValues is a method of the interface ServicePrincipalSecret
197// It only returns an error for the ServicePrincipalNoSecret type
198func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
199	return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token")
200}
201
202// MarshalJSON implements the json.Marshaler interface.
203func (noSecret ServicePrincipalNoSecret) MarshalJSON() ([]byte, error) {
204	type tokenType struct {
205		Type string `json:"type"`
206	}
207	return json.Marshal(tokenType{
208		Type: "ServicePrincipalNoSecret",
209	})
210}
211
// ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization.
// ClientSecret is sent as the "client_secret" form value, and MarshalJSON serializes it in
// clear text under "value".
type ServicePrincipalTokenSecret struct {
	ClientSecret string `json:"value"`
}
216
217// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
218// It will populate the form submitted during oAuth Token Acquisition using the client_secret.
219func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
220	v.Set("client_secret", tokenSecret.ClientSecret)
221	return nil
222}
223
224// MarshalJSON implements the json.Marshaler interface.
225func (tokenSecret ServicePrincipalTokenSecret) MarshalJSON() ([]byte, error) {
226	type tokenType struct {
227		Type  string `json:"type"`
228		Value string `json:"value"`
229	}
230	return json.Marshal(tokenType{
231		Type:  "ServicePrincipalTokenSecret",
232		Value: tokenSecret.ClientSecret,
233	})
234}
235
// ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs.
// SignJwt places Certificate's SHA-1 thumbprint (x5t) and DER bytes (x5c) in the JWT header and
// signs the JWT with PrivateKey using RS256. NewServicePrincipalTokenFromCertificate rejects nil
// values for either field. This secret cannot be marshalled to or unmarshalled from JSON.
type ServicePrincipalCertificateSecret struct {
	Certificate *x509.Certificate
	PrivateKey  *rsa.PrivateKey
}
241
242// SignJwt returns the JWT signed with the certificate's private key.
243func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) {
244	hasher := sha1.New()
245	_, err := hasher.Write(secret.Certificate.Raw)
246	if err != nil {
247		return "", err
248	}
249
250	thumbprint := base64.URLEncoding.EncodeToString(hasher.Sum(nil))
251
252	// The jti (JWT ID) claim provides a unique identifier for the JWT.
253	jti := make([]byte, 20)
254	_, err = rand.Read(jti)
255	if err != nil {
256		return "", err
257	}
258
259	token := jwt.New(jwt.SigningMethodRS256)
260	token.Header["x5t"] = thumbprint
261	x5c := []string{base64.StdEncoding.EncodeToString(secret.Certificate.Raw)}
262	token.Header["x5c"] = x5c
263	token.Claims = jwt.MapClaims{
264		"aud": spt.inner.OauthConfig.TokenEndpoint.String(),
265		"iss": spt.inner.ClientID,
266		"sub": spt.inner.ClientID,
267		"jti": base64.URLEncoding.EncodeToString(jti),
268		"nbf": time.Now().Unix(),
269		"exp": time.Now().Add(24 * time.Hour).Unix(),
270	}
271
272	signedString, err := token.SignedString(secret.PrivateKey)
273	return signedString, err
274}
275
276// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
277// It will populate the form submitted during oAuth Token Acquisition using a JWT signed with a certificate.
278func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
279	jwt, err := secret.SignJwt(spt)
280	if err != nil {
281		return err
282	}
283
284	v.Set("client_assertion", jwt)
285	v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
286	return nil
287}
288
289// MarshalJSON implements the json.Marshaler interface.
290func (secret ServicePrincipalCertificateSecret) MarshalJSON() ([]byte, error) {
291	return nil, errors.New("marshalling ServicePrincipalCertificateSecret is not supported")
292}
293
// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension.
// msiType records the managed identity environment detected when the token was created, and
// clientResourceID holds the user-assigned identity's resource ID (empty when none was given).
// It contributes no form values and cannot be marshalled to or unmarshalled from JSON.
type ServicePrincipalMSISecret struct {
	msiType          msiType
	clientResourceID string
}
299
300// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
301func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
302	return nil
303}
304
305// MarshalJSON implements the json.Marshaler interface.
306func (msiSecret ServicePrincipalMSISecret) MarshalJSON() ([]byte, error) {
307	return nil, errors.New("marshalling ServicePrincipalMSISecret is not supported")
308}
309
// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth.
// Username and Password are sent as the "username" and "password" form values. MarshalJSON
// serializes both in clear text.
type ServicePrincipalUsernamePasswordSecret struct {
	Username string `json:"username"`
	Password string `json:"password"`
}
315
316// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
317func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
318	v.Set("username", secret.Username)
319	v.Set("password", secret.Password)
320	return nil
321}
322
323// MarshalJSON implements the json.Marshaler interface.
324func (secret ServicePrincipalUsernamePasswordSecret) MarshalJSON() ([]byte, error) {
325	type tokenType struct {
326		Type     string `json:"type"`
327		Username string `json:"username"`
328		Password string `json:"password"`
329	}
330	return json.Marshal(tokenType{
331		Type:     "ServicePrincipalUsernamePasswordSecret",
332		Username: secret.Username,
333		Password: secret.Password,
334	})
335}
336
// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth.
// ClientSecret, AuthorizationCode and RedirectURI are sent as the "client_secret", "code" and
// "redirect_uri" form values. MarshalJSON serializes all three in clear text.
type ServicePrincipalAuthorizationCodeSecret struct {
	ClientSecret      string `json:"value"`
	AuthorizationCode string `json:"authCode"`
	RedirectURI       string `json:"redirect"`
}
343
344// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
345func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
346	v.Set("code", secret.AuthorizationCode)
347	v.Set("client_secret", secret.ClientSecret)
348	v.Set("redirect_uri", secret.RedirectURI)
349	return nil
350}
351
352// MarshalJSON implements the json.Marshaler interface.
353func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, error) {
354	type tokenType struct {
355		Type     string `json:"type"`
356		Value    string `json:"value"`
357		AuthCode string `json:"authCode"`
358		Redirect string `json:"redirect"`
359	}
360	return json.Marshal(tokenType{
361		Type:     "ServicePrincipalAuthorizationCodeSecret",
362		Value:    secret.ClientSecret,
363		AuthCode: secret.AuthorizationCode,
364		Redirect: secret.RedirectURI,
365	})
366}
367
// ServicePrincipalToken encapsulates a Token created for a Service Principal.
//
// inner holds the state that MarshalJSON and UnmarshalJSON round-trip. refreshLock is the lock
// taken when checking or refreshing the token. sender is presumably the HTTP sender used for
// refresh requests. customRefreshFunc is set by SetCustomRefreshFunc, and refreshCallbacks run
// after a successful refresh. UnmarshalJSON creates refreshLock and sender only when they are nil.
type ServicePrincipalToken struct {
	inner             servicePrincipalToken
	refreshLock       *sync.RWMutex
	sender            Sender
	customRefreshFunc TokenRefresh
	refreshCallbacks  []TokenRefreshCallback
	// MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token.
	// Setting this to a value less than 1 will use the default value.
	MaxMSIRefreshAttempts int
}
379
380// MarshalTokenJSON returns the marshalled inner token.
381func (spt ServicePrincipalToken) MarshalTokenJSON() ([]byte, error) {
382	return json.Marshal(spt.inner.Token)
383}
384
385// SetRefreshCallbacks replaces any existing refresh callbacks with the specified callbacks.
386func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCallback) {
387	spt.refreshCallbacks = callbacks
388}
389
390// SetCustomRefreshFunc sets a custom refresh function used to refresh the token.
391func (spt *ServicePrincipalToken) SetCustomRefreshFunc(customRefreshFunc TokenRefresh) {
392	spt.customRefreshFunc = customRefreshFunc
393}
394
395// MarshalJSON implements the json.Marshaler interface.
396func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) {
397	return json.Marshal(spt.inner)
398}
399
400// UnmarshalJSON implements the json.Unmarshaler interface.
401func (spt *ServicePrincipalToken) UnmarshalJSON(data []byte) error {
402	// need to determine the token type
403	raw := map[string]interface{}{}
404	err := json.Unmarshal(data, &raw)
405	if err != nil {
406		return err
407	}
408	secret := raw["secret"].(map[string]interface{})
409	switch secret["type"] {
410	case "ServicePrincipalNoSecret":
411		spt.inner.Secret = &ServicePrincipalNoSecret{}
412	case "ServicePrincipalTokenSecret":
413		spt.inner.Secret = &ServicePrincipalTokenSecret{}
414	case "ServicePrincipalCertificateSecret":
415		return errors.New("unmarshalling ServicePrincipalCertificateSecret is not supported")
416	case "ServicePrincipalMSISecret":
417		return errors.New("unmarshalling ServicePrincipalMSISecret is not supported")
418	case "ServicePrincipalUsernamePasswordSecret":
419		spt.inner.Secret = &ServicePrincipalUsernamePasswordSecret{}
420	case "ServicePrincipalAuthorizationCodeSecret":
421		spt.inner.Secret = &ServicePrincipalAuthorizationCodeSecret{}
422	default:
423		return fmt.Errorf("unrecognized token type '%s'", secret["type"])
424	}
425	err = json.Unmarshal(data, &spt.inner)
426	if err != nil {
427		return err
428	}
429	// Don't override the refreshLock or the sender if those have been already set.
430	if spt.refreshLock == nil {
431		spt.refreshLock = &sync.RWMutex{}
432	}
433	if spt.sender == nil {
434		spt.sender = sender()
435	}
436	return nil
437}
438
// internal type used for marshalling/unmarshalling
//
// It is the JSON form of a ServicePrincipalToken: the current Token, the secret (whose own
// MarshalJSON writes the "type" discriminator that ServicePrincipalToken.UnmarshalJSON reads
// back), the OAuth configuration, the client ID, the resource and the auto-refresh settings.
// RefreshWithin is a time.Duration, so it is encoded as an integer count of nanoseconds.
type servicePrincipalToken struct {
	Token         Token                  `json:"token"`
	Secret        ServicePrincipalSecret `json:"secret"`
	OauthConfig   OAuthConfig            `json:"oauth"`
	ClientID      string                 `json:"clientID"`
	Resource      string                 `json:"resource"`
	AutoRefresh   bool                   `json:"autoRefresh"`
	RefreshWithin time.Duration          `json:"refreshWithin"`
}
449
450func validateOAuthConfig(oac OAuthConfig) error {
451	if oac.IsZero() {
452		return fmt.Errorf("parameter 'oauthConfig' cannot be zero-initialized")
453	}
454	return nil
455}
456
457// NewServicePrincipalTokenWithSecret create a ServicePrincipalToken using the supplied ServicePrincipalSecret implementation.
458func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
459	if err := validateOAuthConfig(oauthConfig); err != nil {
460		return nil, err
461	}
462	if err := validateStringParam(id, "id"); err != nil {
463		return nil, err
464	}
465	if err := validateStringParam(resource, "resource"); err != nil {
466		return nil, err
467	}
468	if secret == nil {
469		return nil, fmt.Errorf("parameter 'secret' cannot be nil")
470	}
471	spt := &ServicePrincipalToken{
472		inner: servicePrincipalToken{
473			Token:         newToken(),
474			OauthConfig:   oauthConfig,
475			Secret:        secret,
476			ClientID:      id,
477			Resource:      resource,
478			AutoRefresh:   true,
479			RefreshWithin: defaultRefresh,
480		},
481		refreshLock:      &sync.RWMutex{},
482		sender:           sender(),
483		refreshCallbacks: callbacks,
484	}
485	return spt, nil
486}
487
488// NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken using the supplied token
489func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
490	if err := validateOAuthConfig(oauthConfig); err != nil {
491		return nil, err
492	}
493	if err := validateStringParam(clientID, "clientID"); err != nil {
494		return nil, err
495	}
496	if err := validateStringParam(resource, "resource"); err != nil {
497		return nil, err
498	}
499	if token.IsZero() {
500		return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
501	}
502	spt, err := NewServicePrincipalTokenWithSecret(
503		oauthConfig,
504		clientID,
505		resource,
506		&ServicePrincipalNoSecret{},
507		callbacks...)
508	if err != nil {
509		return nil, err
510	}
511
512	spt.inner.Token = token
513
514	return spt, nil
515}
516
517// NewServicePrincipalTokenFromManualTokenSecret creates a ServicePrincipalToken using the supplied token and secret
518func NewServicePrincipalTokenFromManualTokenSecret(oauthConfig OAuthConfig, clientID string, resource string, token Token, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
519	if err := validateOAuthConfig(oauthConfig); err != nil {
520		return nil, err
521	}
522	if err := validateStringParam(clientID, "clientID"); err != nil {
523		return nil, err
524	}
525	if err := validateStringParam(resource, "resource"); err != nil {
526		return nil, err
527	}
528	if secret == nil {
529		return nil, fmt.Errorf("parameter 'secret' cannot be nil")
530	}
531	if token.IsZero() {
532		return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
533	}
534	spt, err := NewServicePrincipalTokenWithSecret(
535		oauthConfig,
536		clientID,
537		resource,
538		secret,
539		callbacks...)
540	if err != nil {
541		return nil, err
542	}
543
544	spt.inner.Token = token
545
546	return spt, nil
547}
548
549// NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal
550// credentials scoped to the named resource.
551func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
552	if err := validateOAuthConfig(oauthConfig); err != nil {
553		return nil, err
554	}
555	if err := validateStringParam(clientID, "clientID"); err != nil {
556		return nil, err
557	}
558	if err := validateStringParam(secret, "secret"); err != nil {
559		return nil, err
560	}
561	if err := validateStringParam(resource, "resource"); err != nil {
562		return nil, err
563	}
564	return NewServicePrincipalTokenWithSecret(
565		oauthConfig,
566		clientID,
567		resource,
568		&ServicePrincipalTokenSecret{
569			ClientSecret: secret,
570		},
571		callbacks...,
572	)
573}
574
575// NewServicePrincipalTokenFromCertificate creates a ServicePrincipalToken from the supplied pkcs12 bytes.
576func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
577	if err := validateOAuthConfig(oauthConfig); err != nil {
578		return nil, err
579	}
580	if err := validateStringParam(clientID, "clientID"); err != nil {
581		return nil, err
582	}
583	if err := validateStringParam(resource, "resource"); err != nil {
584		return nil, err
585	}
586	if certificate == nil {
587		return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
588	}
589	if privateKey == nil {
590		return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
591	}
592	return NewServicePrincipalTokenWithSecret(
593		oauthConfig,
594		clientID,
595		resource,
596		&ServicePrincipalCertificateSecret{
597			PrivateKey:  privateKey,
598			Certificate: certificate,
599		},
600		callbacks...,
601	)
602}
603
604// NewServicePrincipalTokenFromUsernamePassword creates a ServicePrincipalToken from the username and password.
605func NewServicePrincipalTokenFromUsernamePassword(oauthConfig OAuthConfig, clientID string, username string, password string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
606	if err := validateOAuthConfig(oauthConfig); err != nil {
607		return nil, err
608	}
609	if err := validateStringParam(clientID, "clientID"); err != nil {
610		return nil, err
611	}
612	if err := validateStringParam(username, "username"); err != nil {
613		return nil, err
614	}
615	if err := validateStringParam(password, "password"); err != nil {
616		return nil, err
617	}
618	if err := validateStringParam(resource, "resource"); err != nil {
619		return nil, err
620	}
621	return NewServicePrincipalTokenWithSecret(
622		oauthConfig,
623		clientID,
624		resource,
625		&ServicePrincipalUsernamePasswordSecret{
626			Username: username,
627			Password: password,
628		},
629		callbacks...,
630	)
631}
632
633// NewServicePrincipalTokenFromAuthorizationCode creates a ServicePrincipalToken from the
634func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clientID string, clientSecret string, authorizationCode string, redirectURI string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
635
636	if err := validateOAuthConfig(oauthConfig); err != nil {
637		return nil, err
638	}
639	if err := validateStringParam(clientID, "clientID"); err != nil {
640		return nil, err
641	}
642	if err := validateStringParam(clientSecret, "clientSecret"); err != nil {
643		return nil, err
644	}
645	if err := validateStringParam(authorizationCode, "authorizationCode"); err != nil {
646		return nil, err
647	}
648	if err := validateStringParam(redirectURI, "redirectURI"); err != nil {
649		return nil, err
650	}
651	if err := validateStringParam(resource, "resource"); err != nil {
652		return nil, err
653	}
654
655	return NewServicePrincipalTokenWithSecret(
656		oauthConfig,
657		clientID,
658		resource,
659		&ServicePrincipalAuthorizationCodeSecret{
660			ClientSecret:      clientSecret,
661			AuthorizationCode: authorizationCode,
662			RedirectURI:       redirectURI,
663		},
664		callbacks...,
665	)
666}
667
type msiType int

const (
	msiTypeUnavailable msiType = iota
	msiTypeAppServiceV20170901
	msiTypeCloudShell
	msiTypeIMDS
)

// String returns a human-readable name for the managed identity environment, or
// "unhandled MSI type <n>" for values outside the declared constants.
func (m msiType) String() string {
	names := [...]string{
		msiTypeUnavailable:         "unavailable",
		msiTypeAppServiceV20170901: "AppServiceV20170901",
		msiTypeCloudShell:          "CloudShell",
		msiTypeIMDS:                "IMDS",
	}
	if m >= 0 && int(m) < len(names) {
		return names[m]
	}
	return fmt.Sprintf("unhandled MSI type %d", m)
}
691
692// returns the MSI type and endpoint, or an error
693func getMSIType() (msiType, string, error) {
694	if endpointEnvVar := os.Getenv(msiEndpointEnv); endpointEnvVar != "" {
695		// if the env var MSI_ENDPOINT is set
696		if secretEnvVar := os.Getenv(msiSecretEnv); secretEnvVar != "" {
697			// if BOTH the env vars MSI_ENDPOINT and MSI_SECRET are set the msiType is AppService
698			return msiTypeAppServiceV20170901, endpointEnvVar, nil
699		}
700		// if ONLY the env var MSI_ENDPOINT is set the msiType is CloudShell
701		return msiTypeCloudShell, endpointEnvVar, nil
702	} else if msiAvailableHook(context.Background(), sender()) {
703		// if MSI_ENDPOINT is NOT set AND the IMDS endpoint is available the msiType is IMDS. This will timeout after 500 milliseconds
704		return msiTypeIMDS, msiEndpoint, nil
705	} else {
706		// if MSI_ENDPOINT is NOT set and IMDS endpoint is not available Managed Identity is not available
707		return msiTypeUnavailable, "", errors.New("MSI not available")
708	}
709}
710
711// GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines.
712// NOTE: this always returns the IMDS endpoint, it does not work for app services or cloud shell.
713// Deprecated: NewServicePrincipalTokenFromMSI() and variants will automatically detect the endpoint.
714func GetMSIVMEndpoint() (string, error) {
715	return msiEndpoint, nil
716}
717
718// GetMSIAppServiceEndpoint get the MSI endpoint for App Service and Functions.
719// It will return an error when not running in an app service/functions environment.
720// Deprecated: NewServicePrincipalTokenFromMSI() and variants will automatically detect the endpoint.
721func GetMSIAppServiceEndpoint() (string, error) {
722	msiType, endpoint, err := getMSIType()
723	if err != nil {
724		return "", err
725	}
726	switch msiType {
727	case msiTypeAppServiceV20170901:
728		return endpoint, nil
729	default:
730		return "", fmt.Errorf("%s is not app service environment", msiType)
731	}
732}
733
734// GetMSIEndpoint get the appropriate MSI endpoint depending on the runtime environment
735// Deprecated: NewServicePrincipalTokenFromMSI() and variants will automatically detect the endpoint.
736func GetMSIEndpoint() (string, error) {
737	_, endpoint, err := getMSIType()
738	return endpoint, err
739}
740
741// NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension.
742// It will use the system assigned identity when creating the token.
743// msiEndpoint - empty string, or pass a non-empty string to override the default value.
744// Deprecated: use NewServicePrincipalTokenFromManagedIdentity() instead.
745func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
746	return newServicePrincipalTokenFromMSI(msiEndpoint, resource, "", "", callbacks...)
747}
748
749// NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension.
750// It will use the clientID of specified user assigned identity when creating the token.
751// msiEndpoint - empty string, or pass a non-empty string to override the default value.
752// Deprecated: use NewServicePrincipalTokenFromManagedIdentity() instead.
753func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
754	if err := validateStringParam(userAssignedID, "userAssignedID"); err != nil {
755		return nil, err
756	}
757	return newServicePrincipalTokenFromMSI(msiEndpoint, resource, userAssignedID, "", callbacks...)
758}
759
760// NewServicePrincipalTokenFromMSIWithIdentityResourceID creates a ServicePrincipalToken via the MSI VM Extension.
761// It will use the azure resource id of user assigned identity when creating the token.
762// msiEndpoint - empty string, or pass a non-empty string to override the default value.
763// Deprecated: use NewServicePrincipalTokenFromManagedIdentity() instead.
764func NewServicePrincipalTokenFromMSIWithIdentityResourceID(msiEndpoint, resource string, identityResourceID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
765	if err := validateStringParam(identityResourceID, "identityResourceID"); err != nil {
766		return nil, err
767	}
768	return newServicePrincipalTokenFromMSI(msiEndpoint, resource, "", identityResourceID, callbacks...)
769}
770
// ManagedIdentityOptions contains optional values for configuring managed identity authentication.
// Leaving both fields empty requests a token without selecting a user-assigned identity.
type ManagedIdentityOptions struct {
	// ClientID is the user-assigned identity to use during authentication.
	// It is mutually exclusive with IdentityResourceID; setting both makes token creation fail.
	ClientID string

	// IdentityResourceID is the resource ID of the user-assigned identity to use during authentication.
	// It is mutually exclusive with ClientID; setting both makes token creation fail.
	IdentityResourceID string
}
781
782// NewServicePrincipalTokenFromManagedIdentity creates a ServicePrincipalToken using a managed identity.
783// It supports the following managed identity environments.
784// - App Service Environment (API version 2017-09-01 only)
785// - Cloud shell
786// - IMDS with a system or user assigned identity
787func NewServicePrincipalTokenFromManagedIdentity(resource string, options *ManagedIdentityOptions, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
788	if options == nil {
789		options = &ManagedIdentityOptions{}
790	}
791	return newServicePrincipalTokenFromMSI("", resource, options.ClientID, options.IdentityResourceID, callbacks...)
792}
793
794func newServicePrincipalTokenFromMSI(msiEndpoint, resource, userAssignedID, identityResourceID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
795	if err := validateStringParam(resource, "resource"); err != nil {
796		return nil, err
797	}
798	if userAssignedID != "" && identityResourceID != "" {
799		return nil, errors.New("cannot specify userAssignedID and identityResourceID")
800	}
801	msiType, endpoint, err := getMSIType()
802	if err != nil {
803		logger.Instance.Writef(logger.LogError, "Error determining managed identity environment: %v", err)
804		return nil, err
805	}
806	logger.Instance.Writef(logger.LogInfo, "Managed identity environment is %s, endpoint is %s", msiType, endpoint)
807	if msiEndpoint != "" {
808		endpoint = msiEndpoint
809		logger.Instance.Writef(logger.LogInfo, "Managed identity custom endpoint is %s", endpoint)
810	}
811	msiEndpointURL, err := url.Parse(endpoint)
812	if err != nil {
813		return nil, err
814	}
815	// cloud shell sends its data in the request body
816	if msiType != msiTypeCloudShell {
817		v := url.Values{}
818		v.Set("resource", resource)
819		clientIDParam := "client_id"
820		switch msiType {
821		case msiTypeAppServiceV20170901:
822			clientIDParam = "clientid"
823			v.Set("api-version", appServiceAPIVersion2017)
824			break
825		case msiTypeIMDS:
826			v.Set("api-version", msiAPIVersion)
827		}
828		if userAssignedID != "" {
829			v.Set(clientIDParam, userAssignedID)
830		} else if identityResourceID != "" {
831			v.Set("mi_res_id", identityResourceID)
832		}
833		msiEndpointURL.RawQuery = v.Encode()
834	}
835
836	spt := &ServicePrincipalToken{
837		inner: servicePrincipalToken{
838			Token: newToken(),
839			OauthConfig: OAuthConfig{
840				TokenEndpoint: *msiEndpointURL,
841			},
842			Secret: &ServicePrincipalMSISecret{
843				msiType:          msiType,
844				clientResourceID: identityResourceID,
845			},
846			Resource:      resource,
847			AutoRefresh:   true,
848			RefreshWithin: defaultRefresh,
849			ClientID:      userAssignedID,
850		},
851		refreshLock:           &sync.RWMutex{},
852		sender:                sender(),
853		refreshCallbacks:      callbacks,
854		MaxMSIRefreshAttempts: defaultMaxMSIRefreshAttempts,
855	}
856
857	return spt, nil
858}
859
// tokenRefreshError is the package-internal implementation of TokenRefreshError: a failure
// message paired with the HTTP response that produced it.
type tokenRefreshError struct {
	message string
	resp    *http.Response
}

// Error returns the stored failure message, satisfying the error part of TokenRefreshError.
func (e tokenRefreshError) Error() string { return e.message }

// Response returns the raw HTTP response captured from the failed refresh operation,
// satisfying the TokenRefreshError interface.
func (e tokenRefreshError) Response() *http.Response { return e.resp }
875
876func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError {
877	return tokenRefreshError{message: message, resp: resp}
878}
879
880// EnsureFresh will refresh the token if it will expire within the refresh window (as set by
881// RefreshWithin) and autoRefresh flag is on.  This method is safe for concurrent use.
882func (spt *ServicePrincipalToken) EnsureFresh() error {
883	return spt.EnsureFreshWithContext(context.Background())
884}
885
886// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by
887// RefreshWithin) and autoRefresh flag is on.  This method is safe for concurrent use.
888func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
889	// must take the read lock when initially checking the token's expiration
890	if spt.inner.AutoRefresh && spt.Token().WillExpireIn(spt.inner.RefreshWithin) {
891		// take the write lock then check again to see if the token was already refreshed
892		spt.refreshLock.Lock()
893		defer spt.refreshLock.Unlock()
894		if spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) {
895			return spt.refreshInternal(ctx, spt.inner.Resource)
896		}
897	}
898	return nil
899}
900
901// InvokeRefreshCallbacks calls any TokenRefreshCallbacks that were added to the SPT during initialization
902func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
903	if spt.refreshCallbacks != nil {
904		for _, callback := range spt.refreshCallbacks {
905			err := callback(spt.inner.Token)
906			if err != nil {
907				return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err)
908			}
909		}
910	}
911	return nil
912}
913
914// Refresh obtains a fresh token for the Service Principal.
915// This method is safe for concurrent use.
916func (spt *ServicePrincipalToken) Refresh() error {
917	return spt.RefreshWithContext(context.Background())
918}
919
920// RefreshWithContext obtains a fresh token for the Service Principal.
921// This method is safe for concurrent use.
922func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
923	spt.refreshLock.Lock()
924	defer spt.refreshLock.Unlock()
925	return spt.refreshInternal(ctx, spt.inner.Resource)
926}
927
928// RefreshExchange refreshes the token, but for a different resource.
929// This method is safe for concurrent use.
930func (spt *ServicePrincipalToken) RefreshExchange(resource string) error {
931	return spt.RefreshExchangeWithContext(context.Background(), resource)
932}
933
934// RefreshExchangeWithContext refreshes the token, but for a different resource.
935// This method is safe for concurrent use.
936func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
937	spt.refreshLock.Lock()
938	defer spt.refreshLock.Unlock()
939	return spt.refreshInternal(ctx, resource)
940}
941
942func (spt *ServicePrincipalToken) getGrantType() string {
943	switch spt.inner.Secret.(type) {
944	case *ServicePrincipalUsernamePasswordSecret:
945		return OAuthGrantTypeUserPass
946	case *ServicePrincipalAuthorizationCodeSecret:
947		return OAuthGrantTypeAuthorizationCode
948	default:
949		return OAuthGrantTypeClientCredentials
950	}
951}
952
953func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
954	if spt.customRefreshFunc != nil {
955		token, err := spt.customRefreshFunc(ctx, resource)
956		if err != nil {
957			return err
958		}
959		spt.inner.Token = *token
960		return spt.InvokeRefreshCallbacks(spt.inner.Token)
961	}
962	req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil)
963	if err != nil {
964		return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
965	}
966	req.Header.Add("User-Agent", UserAgent())
967	req = req.WithContext(ctx)
968	var resp *http.Response
969	if msiSecret, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok {
970		switch msiSecret.msiType {
971		case msiTypeAppServiceV20170901:
972			req.Method = http.MethodGet
973			req.Header.Set("secret", os.Getenv(msiSecretEnv))
974			break
975		case msiTypeCloudShell:
976			req.Header.Set("Metadata", "true")
977			data := url.Values{}
978			data.Set("resource", spt.inner.Resource)
979			if spt.inner.ClientID != "" {
980				data.Set("client_id", spt.inner.ClientID)
981			} else if msiSecret.clientResourceID != "" {
982				data.Set("msi_res_id", msiSecret.clientResourceID)
983			}
984			req.Body = ioutil.NopCloser(strings.NewReader(data.Encode()))
985			req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
986			break
987		case msiTypeIMDS:
988			req.Method = http.MethodGet
989			req.Header.Set("Metadata", "true")
990			break
991		}
992		resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts)
993	} else {
994		v := url.Values{}
995		v.Set("client_id", spt.inner.ClientID)
996		v.Set("resource", resource)
997
998		if spt.inner.Token.RefreshToken != "" {
999			v.Set("grant_type", OAuthGrantTypeRefreshToken)
1000			v.Set("refresh_token", spt.inner.Token.RefreshToken)
1001			// web apps must specify client_secret when refreshing tokens
1002			// see https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-protocols-oauth-code#refreshing-the-access-tokens
1003			if spt.getGrantType() == OAuthGrantTypeAuthorizationCode {
1004				err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
1005				if err != nil {
1006					return err
1007				}
1008			}
1009		} else {
1010			v.Set("grant_type", spt.getGrantType())
1011			err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
1012			if err != nil {
1013				return err
1014			}
1015		}
1016
1017		s := v.Encode()
1018		body := ioutil.NopCloser(strings.NewReader(s))
1019		req.ContentLength = int64(len(s))
1020		req.Header.Set(contentType, mimeTypeFormPost)
1021		req.Body = body
1022		resp, err = spt.sender.Do(req)
1023	}
1024
1025	if err != nil {
1026		// don't return a TokenRefreshError here; this will allow retry logic to apply
1027		return fmt.Errorf("adal: Failed to execute the refresh request. Error = '%v'", err)
1028	}
1029
1030	defer resp.Body.Close()
1031	rb, err := ioutil.ReadAll(resp.Body)
1032
1033	if resp.StatusCode != http.StatusOK {
1034		if err != nil {
1035			return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v Endpoint %s", resp.StatusCode, err, req.URL.String()), resp)
1036		}
1037		return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s Endpoint %s", resp.StatusCode, string(rb), req.URL.String()), resp)
1038	}
1039
1040	// for the following error cases don't return a TokenRefreshError.  the operation succeeded
1041	// but some transient failure happened during deserialization.  by returning a generic error
1042	// the retry logic will kick in (we don't retry on TokenRefreshError).
1043
1044	if err != nil {
1045		return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err)
1046	}
1047	if len(strings.Trim(string(rb), " ")) == 0 {
1048		return fmt.Errorf("adal: Empty service principal token received during refresh")
1049	}
1050	token := struct {
1051		AccessToken  string `json:"access_token"`
1052		RefreshToken string `json:"refresh_token"`
1053
1054		// AAD returns expires_in as a string, ADFS returns it as an int
1055		ExpiresIn json.Number `json:"expires_in"`
1056		// expires_on can be in two formats, a UTC time stamp or the number of seconds.
1057		ExpiresOn string      `json:"expires_on"`
1058		NotBefore json.Number `json:"not_before"`
1059
1060		Resource string `json:"resource"`
1061		Type     string `json:"token_type"`
1062	}{}
1063	// return a TokenRefreshError in the follow error cases as the token is in an unexpected format
1064	err = json.Unmarshal(rb, &token)
1065	if err != nil {
1066		return newTokenRefreshError(fmt.Sprintf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb)), resp)
1067	}
1068	expiresOn := json.Number("")
1069	// ADFS doesn't include the expires_on field
1070	if token.ExpiresOn != "" {
1071		if expiresOn, err = parseExpiresOn(token.ExpiresOn); err != nil {
1072			return newTokenRefreshError(fmt.Sprintf("adal: failed to parse expires_on: %v value '%s'", err, token.ExpiresOn), resp)
1073		}
1074	}
1075	spt.inner.Token.AccessToken = token.AccessToken
1076	spt.inner.Token.RefreshToken = token.RefreshToken
1077	spt.inner.Token.ExpiresIn = token.ExpiresIn
1078	spt.inner.Token.ExpiresOn = expiresOn
1079	spt.inner.Token.NotBefore = token.NotBefore
1080	spt.inner.Token.Resource = token.Resource
1081	spt.inner.Token.Type = token.Type
1082
1083	return spt.InvokeRefreshCallbacks(spt.inner.Token)
1084}
1085
1086// converts expires_on to the number of seconds
1087func parseExpiresOn(s string) (json.Number, error) {
1088	// convert the expiration date to the number of seconds from now
1089	timeToDuration := func(t time.Time) json.Number {
1090		dur := t.Sub(time.Now().UTC())
1091		return json.Number(strconv.FormatInt(int64(dur.Round(time.Second).Seconds()), 10))
1092	}
1093	if _, err := strconv.ParseInt(s, 10, 64); err == nil {
1094		// this is the number of seconds case, no conversion required
1095		return json.Number(s), nil
1096	} else if eo, err := time.Parse(expiresOnDateFormatPM, s); err == nil {
1097		return timeToDuration(eo), nil
1098	} else if eo, err := time.Parse(expiresOnDateFormat, s); err == nil {
1099		return timeToDuration(eo), nil
1100	} else {
1101		// unknown format
1102		return json.Number(""), err
1103	}
1104}
1105
1106// retry logic specific to retrieving a token from the IMDS endpoint
1107func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http.Response, err error) {
1108	// copied from client.go due to circular dependency
1109	retries := []int{
1110		http.StatusRequestTimeout,      // 408
1111		http.StatusTooManyRequests,     // 429
1112		http.StatusInternalServerError, // 500
1113		http.StatusBadGateway,          // 502
1114		http.StatusServiceUnavailable,  // 503
1115		http.StatusGatewayTimeout,      // 504
1116	}
1117	// extra retry status codes specific to IMDS
1118	retries = append(retries,
1119		http.StatusNotFound,
1120		http.StatusGone,
1121		// all remaining 5xx
1122		http.StatusNotImplemented,
1123		http.StatusHTTPVersionNotSupported,
1124		http.StatusVariantAlsoNegotiates,
1125		http.StatusInsufficientStorage,
1126		http.StatusLoopDetected,
1127		http.StatusNotExtended,
1128		http.StatusNetworkAuthenticationRequired)
1129
1130	// see https://docs.microsoft.com/en-us/azure/active-directory/managed-service-identity/how-to-use-vm-token#retry-guidance
1131
1132	const maxDelay time.Duration = 60 * time.Second
1133
1134	attempt := 0
1135	delay := time.Duration(0)
1136
1137	// maxAttempts is user-specified, ensure that its value is greater than zero else no request will be made
1138	if maxAttempts < 1 {
1139		maxAttempts = defaultMaxMSIRefreshAttempts
1140	}
1141
1142	for attempt < maxAttempts {
1143		if resp != nil && resp.Body != nil {
1144			io.Copy(ioutil.Discard, resp.Body)
1145			resp.Body.Close()
1146		}
1147		resp, err = sender.Do(req)
1148		// we want to retry if err is not nil or the status code is in the list of retry codes
1149		if err == nil && !responseHasStatusCode(resp, retries...) {
1150			return
1151		}
1152
1153		// perform exponential backoff with a cap.
1154		// must increment attempt before calculating delay.
1155		attempt++
1156		// the base value of 2 is the "delta backoff" as specified in the guidance doc
1157		delay += (time.Duration(math.Pow(2, float64(attempt))) * time.Second)
1158		if delay > maxDelay {
1159			delay = maxDelay
1160		}
1161
1162		select {
1163		case <-time.After(delay):
1164			// intentionally left blank
1165		case <-req.Context().Done():
1166			err = req.Context().Err()
1167			return
1168		}
1169	}
1170	return
1171}
1172
// responseHasStatusCode reports whether resp is non-nil and its StatusCode equals one of codes.
func responseHasStatusCode(resp *http.Response, codes ...int) bool {
	if resp == nil {
		return false
	}
	for i := range codes {
		if codes[i] == resp.StatusCode {
			return true
		}
	}
	return false
}
1183
1184// SetAutoRefresh enables or disables automatic refreshing of stale tokens.
1185func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) {
1186	spt.inner.AutoRefresh = autoRefresh
1187}
1188
1189// SetRefreshWithin sets the interval within which if the token will expire, EnsureFresh will
1190// refresh the token.
1191func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) {
1192	spt.inner.RefreshWithin = d
1193	return
1194}
1195
1196// SetSender sets the http.Client used when obtaining the Service Principal token. An
1197// undecorated http.Client is used by default.
1198func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s }
1199
1200// OAuthToken implements the OAuthTokenProvider interface.  It returns the current access token.
1201func (spt *ServicePrincipalToken) OAuthToken() string {
1202	spt.refreshLock.RLock()
1203	defer spt.refreshLock.RUnlock()
1204	return spt.inner.Token.OAuthToken()
1205}
1206
1207// Token returns a copy of the current token.
1208func (spt *ServicePrincipalToken) Token() Token {
1209	spt.refreshLock.RLock()
1210	defer spt.refreshLock.RUnlock()
1211	return spt.inner.Token
1212}
1213
1214// MultiTenantServicePrincipalToken contains tokens for multi-tenant authorization.
1215type MultiTenantServicePrincipalToken struct {
1216	PrimaryToken    *ServicePrincipalToken
1217	AuxiliaryTokens []*ServicePrincipalToken
1218}
1219
1220// PrimaryOAuthToken returns the primary authorization token.
1221func (mt *MultiTenantServicePrincipalToken) PrimaryOAuthToken() string {
1222	return mt.PrimaryToken.OAuthToken()
1223}
1224
1225// AuxiliaryOAuthTokens returns one to three auxiliary authorization tokens.
1226func (mt *MultiTenantServicePrincipalToken) AuxiliaryOAuthTokens() []string {
1227	tokens := make([]string, len(mt.AuxiliaryTokens))
1228	for i := range mt.AuxiliaryTokens {
1229		tokens[i] = mt.AuxiliaryTokens[i].OAuthToken()
1230	}
1231	return tokens
1232}
1233
1234// NewMultiTenantServicePrincipalToken creates a new MultiTenantServicePrincipalToken with the specified credentials and resource.
1235func NewMultiTenantServicePrincipalToken(multiTenantCfg MultiTenantOAuthConfig, clientID string, secret string, resource string) (*MultiTenantServicePrincipalToken, error) {
1236	if err := validateStringParam(clientID, "clientID"); err != nil {
1237		return nil, err
1238	}
1239	if err := validateStringParam(secret, "secret"); err != nil {
1240		return nil, err
1241	}
1242	if err := validateStringParam(resource, "resource"); err != nil {
1243		return nil, err
1244	}
1245	auxTenants := multiTenantCfg.AuxiliaryTenants()
1246	m := MultiTenantServicePrincipalToken{
1247		AuxiliaryTokens: make([]*ServicePrincipalToken, len(auxTenants)),
1248	}
1249	primary, err := NewServicePrincipalToken(*multiTenantCfg.PrimaryTenant(), clientID, secret, resource)
1250	if err != nil {
1251		return nil, fmt.Errorf("failed to create SPT for primary tenant: %v", err)
1252	}
1253	m.PrimaryToken = primary
1254	for i := range auxTenants {
1255		aux, err := NewServicePrincipalToken(*auxTenants[i], clientID, secret, resource)
1256		if err != nil {
1257			return nil, fmt.Errorf("failed to create SPT for auxiliary tenant: %v", err)
1258		}
1259		m.AuxiliaryTokens[i] = aux
1260	}
1261	return &m, nil
1262}
1263
1264// NewMultiTenantServicePrincipalTokenFromCertificate creates a new MultiTenantServicePrincipalToken with the specified certificate credentials and resource.
1265func NewMultiTenantServicePrincipalTokenFromCertificate(multiTenantCfg MultiTenantOAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string) (*MultiTenantServicePrincipalToken, error) {
1266	if err := validateStringParam(clientID, "clientID"); err != nil {
1267		return nil, err
1268	}
1269	if err := validateStringParam(resource, "resource"); err != nil {
1270		return nil, err
1271	}
1272	if certificate == nil {
1273		return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
1274	}
1275	if privateKey == nil {
1276		return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
1277	}
1278	auxTenants := multiTenantCfg.AuxiliaryTenants()
1279	m := MultiTenantServicePrincipalToken{
1280		AuxiliaryTokens: make([]*ServicePrincipalToken, len(auxTenants)),
1281	}
1282	primary, err := NewServicePrincipalTokenWithSecret(
1283		*multiTenantCfg.PrimaryTenant(),
1284		clientID,
1285		resource,
1286		&ServicePrincipalCertificateSecret{
1287			PrivateKey:  privateKey,
1288			Certificate: certificate,
1289		},
1290	)
1291	if err != nil {
1292		return nil, fmt.Errorf("failed to create SPT for primary tenant: %v", err)
1293	}
1294	m.PrimaryToken = primary
1295	for i := range auxTenants {
1296		aux, err := NewServicePrincipalTokenWithSecret(
1297			*auxTenants[i],
1298			clientID,
1299			resource,
1300			&ServicePrincipalCertificateSecret{
1301				PrivateKey:  privateKey,
1302				Certificate: certificate,
1303			},
1304		)
1305		if err != nil {
1306			return nil, fmt.Errorf("failed to create SPT for auxiliary tenant: %v", err)
1307		}
1308		m.AuxiliaryTokens[i] = aux
1309	}
1310	return &m, nil
1311}
1312
1313// MSIAvailable returns true if the MSI endpoint is available for authentication.
1314func MSIAvailable(ctx context.Context, sender Sender) bool {
1315	resp, err := getMSIEndpoint(ctx, sender)
1316	if err == nil {
1317		resp.Body.Close()
1318	}
1319	return err == nil
1320}
1321
1322// used for testing purposes
1323var msiAvailableHook = func(ctx context.Context, sender Sender) bool {
1324	return MSIAvailable(ctx, sender)
1325}
1326