1package adal
2
3// Copyright 2017 Microsoft Corporation
4//
5//  Licensed under the Apache License, Version 2.0 (the "License");
6//  you may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at
8//
9//      http://www.apache.org/licenses/LICENSE-2.0
10//
11//  Unless required by applicable law or agreed to in writing, software
12//  distributed under the License is distributed on an "AS IS" BASIS,
13//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14//  See the License for the specific language governing permissions and
15//  limitations under the License.
16
17import (
18	"context"
19	"crypto/rand"
20	"crypto/rsa"
21	"crypto/sha1"
22	"crypto/x509"
23	"encoding/base64"
24	"encoding/json"
25	"errors"
26	"fmt"
27	"io"
28	"io/ioutil"
29	"math"
30	"net/http"
31	"net/url"
32	"os"
33	"strings"
34	"sync"
35	"time"
36
37	"github.com/Azure/go-autorest/autorest/date"
38	"github.com/dgrijalva/jwt-go"
39)
40
// Values sent in the "grant_type" form field of OAuth token requests.
const (
	// OAuthGrantTypeDeviceCode is the "grant_type" identifier used in device flow
	OAuthGrantTypeDeviceCode = "device_code"

	// OAuthGrantTypeClientCredentials is the "grant_type" identifier used in credential flows
	OAuthGrantTypeClientCredentials = "client_credentials"

	// OAuthGrantTypeUserPass is the "grant_type" identifier used in username and password auth flows
	OAuthGrantTypeUserPass = "password"

	// OAuthGrantTypeRefreshToken is the "grant_type" identifier used in refresh token flows
	OAuthGrantTypeRefreshToken = "refresh_token"

	// OAuthGrantTypeAuthorizationCode is the "grant_type" identifier used in authorization code flows
	OAuthGrantTypeAuthorizationCode = "authorization_code"
)

// Package-internal defaults plus the well-known names used to reach the
// Managed Service Identity (MSI) endpoints.
const (
	// defaultRefresh is the RefreshWithin window given to newly constructed
	// tokens: EnsureFresh refreshes once the token will expire within this long.
	defaultRefresh = 5 * time.Minute

	// defaultMaxMSIRefreshAttempts is the default number of attempts to refresh an MSI authentication token.
	defaultMaxMSIRefreshAttempts = 5

	// msiEndpoint is the well-known Virtual Machine endpoint for getting MSI authentication tokens.
	msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"

	// metadataHeader is the request header required by the MSI extension.
	metadataHeader = "Metadata"

	// asMSIEndpointEnv names the environment variable holding the MSI endpoint on App Service and Functions.
	asMSIEndpointEnv = "MSI_ENDPOINT"

	// asMSISecretEnv names the environment variable holding the MSI request secret on App Service and Functions.
	asMSISecretEnv = "MSI_SECRET"
)
74
75// OAuthTokenProvider is an interface which should be implemented by an access token retriever
76type OAuthTokenProvider interface {
77	OAuthToken() string
78}
79
80// MultitenantOAuthTokenProvider provides tokens used for multi-tenant authorization.
81type MultitenantOAuthTokenProvider interface {
82	PrimaryOAuthToken() string
83	AuxiliaryOAuthTokens() []string
84}
85
86// TokenRefreshError is an interface used by errors returned during token refresh.
87type TokenRefreshError interface {
88	error
89	Response() *http.Response
90}
91
92// Refresher is an interface for token refresh functionality
93type Refresher interface {
94	Refresh() error
95	RefreshExchange(resource string) error
96	EnsureFresh() error
97}
98
99// RefresherWithContext is an interface for token refresh functionality
100type RefresherWithContext interface {
101	RefreshWithContext(ctx context.Context) error
102	RefreshExchangeWithContext(ctx context.Context, resource string) error
103	EnsureFreshWithContext(ctx context.Context) error
104}
105
106// TokenRefreshCallback is the type representing callbacks that will be called after
107// a successful token refresh
108type TokenRefreshCallback func(Token) error
109
110// TokenRefresh is a type representing a custom callback to refresh a token
111type TokenRefresh func(ctx context.Context, resource string) (*Token, error)
112
// Token encapsulates the access token used to authorize Azure requests.
// The field order is kept deliberately: it fixes the key order of the JSON encoding.
// https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response
type Token struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`

	ExpiresIn json.Number `json:"expires_in"`
	ExpiresOn json.Number `json:"expires_on"`
	NotBefore json.Number `json:"not_before"`

	Resource string `json:"resource"`
	Type     string `json:"token_type"`
}

// newToken returns an otherwise empty Token whose three timing fields hold the
// numeric string "0" instead of the empty string. Such a token is therefore
// not IsZero.
func newToken() Token {
	var tok Token
	tok.ExpiresIn, tok.ExpiresOn, tok.NotBefore = "0", "0", "0"
	return tok
}

// IsZero returns true if the token object is zero-initialized, i.e. every
// field holds its zero value.
func (t Token) IsZero() bool {
	var zero Token
	return t == zero
}
139
140// Expires returns the time.Time when the Token expires.
141func (t Token) Expires() time.Time {
142	s, err := t.ExpiresOn.Float64()
143	if err != nil {
144		s = -3600
145	}
146
147	expiration := date.NewUnixTimeFromSeconds(s)
148
149	return time.Time(expiration).UTC()
150}
151
152// IsExpired returns true if the Token is expired, false otherwise.
153func (t Token) IsExpired() bool {
154	return t.WillExpireIn(0)
155}
156
157// WillExpireIn returns true if the Token will expire after the passed time.Duration interval
158// from now, false otherwise.
159func (t Token) WillExpireIn(d time.Duration) bool {
160	return !t.Expires().After(time.Now().Add(d))
161}
162
163//OAuthToken return the current access token
164func (t *Token) OAuthToken() string {
165	return t.AccessToken
166}
167
168// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form
169// that is submitted when acquiring an oAuth token.
170type ServicePrincipalSecret interface {
171	SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error
172}
173
174// ServicePrincipalNoSecret represents a secret type that contains no secret
175// meaning it is not valid for fetching a fresh token. This is used by Manual
176type ServicePrincipalNoSecret struct {
177}
178
179// SetAuthenticationValues is a method of the interface ServicePrincipalSecret
180// It only returns an error for the ServicePrincipalNoSecret type
181func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
182	return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token")
183}
184
185// MarshalJSON implements the json.Marshaler interface.
186func (noSecret ServicePrincipalNoSecret) MarshalJSON() ([]byte, error) {
187	type tokenType struct {
188		Type string `json:"type"`
189	}
190	return json.Marshal(tokenType{
191		Type: "ServicePrincipalNoSecret",
192	})
193}
194
195// ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization.
196type ServicePrincipalTokenSecret struct {
197	ClientSecret string `json:"value"`
198}
199
200// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
201// It will populate the form submitted during oAuth Token Acquisition using the client_secret.
202func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
203	v.Set("client_secret", tokenSecret.ClientSecret)
204	return nil
205}
206
207// MarshalJSON implements the json.Marshaler interface.
208func (tokenSecret ServicePrincipalTokenSecret) MarshalJSON() ([]byte, error) {
209	type tokenType struct {
210		Type  string `json:"type"`
211		Value string `json:"value"`
212	}
213	return json.Marshal(tokenType{
214		Type:  "ServicePrincipalTokenSecret",
215		Value: tokenSecret.ClientSecret,
216	})
217}
218
219// ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs.
220type ServicePrincipalCertificateSecret struct {
221	Certificate *x509.Certificate
222	PrivateKey  *rsa.PrivateKey
223}
224
225// SignJwt returns the JWT signed with the certificate's private key.
226func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) {
227	hasher := sha1.New()
228	_, err := hasher.Write(secret.Certificate.Raw)
229	if err != nil {
230		return "", err
231	}
232
233	thumbprint := base64.URLEncoding.EncodeToString(hasher.Sum(nil))
234
235	// The jti (JWT ID) claim provides a unique identifier for the JWT.
236	jti := make([]byte, 20)
237	_, err = rand.Read(jti)
238	if err != nil {
239		return "", err
240	}
241
242	token := jwt.New(jwt.SigningMethodRS256)
243	token.Header["x5t"] = thumbprint
244	x5c := []string{base64.StdEncoding.EncodeToString(secret.Certificate.Raw)}
245	token.Header["x5c"] = x5c
246	token.Claims = jwt.MapClaims{
247		"aud": spt.inner.OauthConfig.TokenEndpoint.String(),
248		"iss": spt.inner.ClientID,
249		"sub": spt.inner.ClientID,
250		"jti": base64.URLEncoding.EncodeToString(jti),
251		"nbf": time.Now().Unix(),
252		"exp": time.Now().Add(24 * time.Hour).Unix(),
253	}
254
255	signedString, err := token.SignedString(secret.PrivateKey)
256	return signedString, err
257}
258
259// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
260// It will populate the form submitted during oAuth Token Acquisition using a JWT signed with a certificate.
261func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
262	jwt, err := secret.SignJwt(spt)
263	if err != nil {
264		return err
265	}
266
267	v.Set("client_assertion", jwt)
268	v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
269	return nil
270}
271
272// MarshalJSON implements the json.Marshaler interface.
273func (secret ServicePrincipalCertificateSecret) MarshalJSON() ([]byte, error) {
274	return nil, errors.New("marshalling ServicePrincipalCertificateSecret is not supported")
275}
276
277// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension.
278type ServicePrincipalMSISecret struct {
279}
280
281// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
282func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
283	return nil
284}
285
286// MarshalJSON implements the json.Marshaler interface.
287func (msiSecret ServicePrincipalMSISecret) MarshalJSON() ([]byte, error) {
288	return nil, errors.New("marshalling ServicePrincipalMSISecret is not supported")
289}
290
291// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth.
292type ServicePrincipalUsernamePasswordSecret struct {
293	Username string `json:"username"`
294	Password string `json:"password"`
295}
296
297// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
298func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
299	v.Set("username", secret.Username)
300	v.Set("password", secret.Password)
301	return nil
302}
303
304// MarshalJSON implements the json.Marshaler interface.
305func (secret ServicePrincipalUsernamePasswordSecret) MarshalJSON() ([]byte, error) {
306	type tokenType struct {
307		Type     string `json:"type"`
308		Username string `json:"username"`
309		Password string `json:"password"`
310	}
311	return json.Marshal(tokenType{
312		Type:     "ServicePrincipalUsernamePasswordSecret",
313		Username: secret.Username,
314		Password: secret.Password,
315	})
316}
317
318// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth.
319type ServicePrincipalAuthorizationCodeSecret struct {
320	ClientSecret      string `json:"value"`
321	AuthorizationCode string `json:"authCode"`
322	RedirectURI       string `json:"redirect"`
323}
324
325// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
326func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
327	v.Set("code", secret.AuthorizationCode)
328	v.Set("client_secret", secret.ClientSecret)
329	v.Set("redirect_uri", secret.RedirectURI)
330	return nil
331}
332
333// MarshalJSON implements the json.Marshaler interface.
334func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, error) {
335	type tokenType struct {
336		Type     string `json:"type"`
337		Value    string `json:"value"`
338		AuthCode string `json:"authCode"`
339		Redirect string `json:"redirect"`
340	}
341	return json.Marshal(tokenType{
342		Type:     "ServicePrincipalAuthorizationCodeSecret",
343		Value:    secret.ClientSecret,
344		AuthCode: secret.AuthorizationCode,
345		Redirect: secret.RedirectURI,
346	})
347}
348
349// ServicePrincipalToken encapsulates a Token created for a Service Principal.
350type ServicePrincipalToken struct {
351	inner             servicePrincipalToken
352	refreshLock       *sync.RWMutex
353	sender            Sender
354	customRefreshFunc TokenRefresh
355	refreshCallbacks  []TokenRefreshCallback
356	// MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token.
357	MaxMSIRefreshAttempts int
358}
359
360// MarshalTokenJSON returns the marshalled inner token.
361func (spt ServicePrincipalToken) MarshalTokenJSON() ([]byte, error) {
362	return json.Marshal(spt.inner.Token)
363}
364
365// SetRefreshCallbacks replaces any existing refresh callbacks with the specified callbacks.
366func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCallback) {
367	spt.refreshCallbacks = callbacks
368}
369
370// SetCustomRefreshFunc sets a custom refresh function used to refresh the token.
371func (spt *ServicePrincipalToken) SetCustomRefreshFunc(customRefreshFunc TokenRefresh) {
372	spt.customRefreshFunc = customRefreshFunc
373}
374
375// MarshalJSON implements the json.Marshaler interface.
376func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) {
377	return json.Marshal(spt.inner)
378}
379
380// UnmarshalJSON implements the json.Unmarshaler interface.
381func (spt *ServicePrincipalToken) UnmarshalJSON(data []byte) error {
382	// need to determine the token type
383	raw := map[string]interface{}{}
384	err := json.Unmarshal(data, &raw)
385	if err != nil {
386		return err
387	}
388	secret := raw["secret"].(map[string]interface{})
389	switch secret["type"] {
390	case "ServicePrincipalNoSecret":
391		spt.inner.Secret = &ServicePrincipalNoSecret{}
392	case "ServicePrincipalTokenSecret":
393		spt.inner.Secret = &ServicePrincipalTokenSecret{}
394	case "ServicePrincipalCertificateSecret":
395		return errors.New("unmarshalling ServicePrincipalCertificateSecret is not supported")
396	case "ServicePrincipalMSISecret":
397		return errors.New("unmarshalling ServicePrincipalMSISecret is not supported")
398	case "ServicePrincipalUsernamePasswordSecret":
399		spt.inner.Secret = &ServicePrincipalUsernamePasswordSecret{}
400	case "ServicePrincipalAuthorizationCodeSecret":
401		spt.inner.Secret = &ServicePrincipalAuthorizationCodeSecret{}
402	default:
403		return fmt.Errorf("unrecognized token type '%s'", secret["type"])
404	}
405	err = json.Unmarshal(data, &spt.inner)
406	if err != nil {
407		return err
408	}
409	// Don't override the refreshLock or the sender if those have been already set.
410	if spt.refreshLock == nil {
411		spt.refreshLock = &sync.RWMutex{}
412	}
413	if spt.sender == nil {
414		spt.sender = sender()
415	}
416	return nil
417}
418
419// internal type used for marshalling/unmarshalling
420type servicePrincipalToken struct {
421	Token         Token                  `json:"token"`
422	Secret        ServicePrincipalSecret `json:"secret"`
423	OauthConfig   OAuthConfig            `json:"oauth"`
424	ClientID      string                 `json:"clientID"`
425	Resource      string                 `json:"resource"`
426	AutoRefresh   bool                   `json:"autoRefresh"`
427	RefreshWithin time.Duration          `json:"refreshWithin"`
428}
429
430func validateOAuthConfig(oac OAuthConfig) error {
431	if oac.IsZero() {
432		return fmt.Errorf("parameter 'oauthConfig' cannot be zero-initialized")
433	}
434	return nil
435}
436
437// NewServicePrincipalTokenWithSecret create a ServicePrincipalToken using the supplied ServicePrincipalSecret implementation.
438func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
439	if err := validateOAuthConfig(oauthConfig); err != nil {
440		return nil, err
441	}
442	if err := validateStringParam(id, "id"); err != nil {
443		return nil, err
444	}
445	if err := validateStringParam(resource, "resource"); err != nil {
446		return nil, err
447	}
448	if secret == nil {
449		return nil, fmt.Errorf("parameter 'secret' cannot be nil")
450	}
451	spt := &ServicePrincipalToken{
452		inner: servicePrincipalToken{
453			Token:         newToken(),
454			OauthConfig:   oauthConfig,
455			Secret:        secret,
456			ClientID:      id,
457			Resource:      resource,
458			AutoRefresh:   true,
459			RefreshWithin: defaultRefresh,
460		},
461		refreshLock:      &sync.RWMutex{},
462		sender:           sender(),
463		refreshCallbacks: callbacks,
464	}
465	return spt, nil
466}
467
468// NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken using the supplied token
469func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
470	if err := validateOAuthConfig(oauthConfig); err != nil {
471		return nil, err
472	}
473	if err := validateStringParam(clientID, "clientID"); err != nil {
474		return nil, err
475	}
476	if err := validateStringParam(resource, "resource"); err != nil {
477		return nil, err
478	}
479	if token.IsZero() {
480		return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
481	}
482	spt, err := NewServicePrincipalTokenWithSecret(
483		oauthConfig,
484		clientID,
485		resource,
486		&ServicePrincipalNoSecret{},
487		callbacks...)
488	if err != nil {
489		return nil, err
490	}
491
492	spt.inner.Token = token
493
494	return spt, nil
495}
496
497// NewServicePrincipalTokenFromManualTokenSecret creates a ServicePrincipalToken using the supplied token and secret
498func NewServicePrincipalTokenFromManualTokenSecret(oauthConfig OAuthConfig, clientID string, resource string, token Token, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
499	if err := validateOAuthConfig(oauthConfig); err != nil {
500		return nil, err
501	}
502	if err := validateStringParam(clientID, "clientID"); err != nil {
503		return nil, err
504	}
505	if err := validateStringParam(resource, "resource"); err != nil {
506		return nil, err
507	}
508	if secret == nil {
509		return nil, fmt.Errorf("parameter 'secret' cannot be nil")
510	}
511	if token.IsZero() {
512		return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
513	}
514	spt, err := NewServicePrincipalTokenWithSecret(
515		oauthConfig,
516		clientID,
517		resource,
518		secret,
519		callbacks...)
520	if err != nil {
521		return nil, err
522	}
523
524	spt.inner.Token = token
525
526	return spt, nil
527}
528
529// NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal
530// credentials scoped to the named resource.
531func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
532	if err := validateOAuthConfig(oauthConfig); err != nil {
533		return nil, err
534	}
535	if err := validateStringParam(clientID, "clientID"); err != nil {
536		return nil, err
537	}
538	if err := validateStringParam(secret, "secret"); err != nil {
539		return nil, err
540	}
541	if err := validateStringParam(resource, "resource"); err != nil {
542		return nil, err
543	}
544	return NewServicePrincipalTokenWithSecret(
545		oauthConfig,
546		clientID,
547		resource,
548		&ServicePrincipalTokenSecret{
549			ClientSecret: secret,
550		},
551		callbacks...,
552	)
553}
554
555// NewServicePrincipalTokenFromCertificate creates a ServicePrincipalToken from the supplied pkcs12 bytes.
556func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
557	if err := validateOAuthConfig(oauthConfig); err != nil {
558		return nil, err
559	}
560	if err := validateStringParam(clientID, "clientID"); err != nil {
561		return nil, err
562	}
563	if err := validateStringParam(resource, "resource"); err != nil {
564		return nil, err
565	}
566	if certificate == nil {
567		return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
568	}
569	if privateKey == nil {
570		return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
571	}
572	return NewServicePrincipalTokenWithSecret(
573		oauthConfig,
574		clientID,
575		resource,
576		&ServicePrincipalCertificateSecret{
577			PrivateKey:  privateKey,
578			Certificate: certificate,
579		},
580		callbacks...,
581	)
582}
583
584// NewServicePrincipalTokenFromUsernamePassword creates a ServicePrincipalToken from the username and password.
585func NewServicePrincipalTokenFromUsernamePassword(oauthConfig OAuthConfig, clientID string, username string, password string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
586	if err := validateOAuthConfig(oauthConfig); err != nil {
587		return nil, err
588	}
589	if err := validateStringParam(clientID, "clientID"); err != nil {
590		return nil, err
591	}
592	if err := validateStringParam(username, "username"); err != nil {
593		return nil, err
594	}
595	if err := validateStringParam(password, "password"); err != nil {
596		return nil, err
597	}
598	if err := validateStringParam(resource, "resource"); err != nil {
599		return nil, err
600	}
601	return NewServicePrincipalTokenWithSecret(
602		oauthConfig,
603		clientID,
604		resource,
605		&ServicePrincipalUsernamePasswordSecret{
606			Username: username,
607			Password: password,
608		},
609		callbacks...,
610	)
611}
612
613// NewServicePrincipalTokenFromAuthorizationCode creates a ServicePrincipalToken from the
614func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clientID string, clientSecret string, authorizationCode string, redirectURI string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
615
616	if err := validateOAuthConfig(oauthConfig); err != nil {
617		return nil, err
618	}
619	if err := validateStringParam(clientID, "clientID"); err != nil {
620		return nil, err
621	}
622	if err := validateStringParam(clientSecret, "clientSecret"); err != nil {
623		return nil, err
624	}
625	if err := validateStringParam(authorizationCode, "authorizationCode"); err != nil {
626		return nil, err
627	}
628	if err := validateStringParam(redirectURI, "redirectURI"); err != nil {
629		return nil, err
630	}
631	if err := validateStringParam(resource, "resource"); err != nil {
632		return nil, err
633	}
634
635	return NewServicePrincipalTokenWithSecret(
636		oauthConfig,
637		clientID,
638		resource,
639		&ServicePrincipalAuthorizationCodeSecret{
640			ClientSecret:      clientSecret,
641			AuthorizationCode: authorizationCode,
642			RedirectURI:       redirectURI,
643		},
644		callbacks...,
645	)
646}
647
648// GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines.
649func GetMSIVMEndpoint() (string, error) {
650	return msiEndpoint, nil
651}
652
653func isAppService() bool {
654	_, asMSIEndpointEnvExists := os.LookupEnv(asMSIEndpointEnv)
655	_, asMSISecretEnvExists := os.LookupEnv(asMSISecretEnv)
656
657	return asMSIEndpointEnvExists && asMSISecretEnvExists
658}
659
660// GetMSIAppServiceEndpoint get the MSI endpoint for App Service and Functions
661func GetMSIAppServiceEndpoint() (string, error) {
662	asMSIEndpoint, asMSIEndpointEnvExists := os.LookupEnv(asMSIEndpointEnv)
663
664	if asMSIEndpointEnvExists {
665		return asMSIEndpoint, nil
666	}
667	return "", errors.New("MSI endpoint not found")
668}
669
670// GetMSIEndpoint get the appropriate MSI endpoint depending on the runtime environment
671func GetMSIEndpoint() (string, error) {
672	if isAppService() {
673		return GetMSIAppServiceEndpoint()
674	}
675	return GetMSIVMEndpoint()
676}
677
678// NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension.
679// It will use the system assigned identity when creating the token.
680func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
681	return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, callbacks...)
682}
683
684// NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension.
685// It will use the specified user assigned identity when creating the token.
686func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
687	return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &userAssignedID, callbacks...)
688}
689
690func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
691	if err := validateStringParam(msiEndpoint, "msiEndpoint"); err != nil {
692		return nil, err
693	}
694	if err := validateStringParam(resource, "resource"); err != nil {
695		return nil, err
696	}
697	if userAssignedID != nil {
698		if err := validateStringParam(*userAssignedID, "userAssignedID"); err != nil {
699			return nil, err
700		}
701	}
702	// We set the oauth config token endpoint to be MSI's endpoint
703	msiEndpointURL, err := url.Parse(msiEndpoint)
704	if err != nil {
705		return nil, err
706	}
707
708	v := url.Values{}
709	v.Set("resource", resource)
710	// App Service MSI currently only supports token API version 2017-09-01
711	if isAppService() {
712		v.Set("api-version", "2017-09-01")
713	} else {
714		v.Set("api-version", "2018-02-01")
715	}
716	if userAssignedID != nil {
717		v.Set("client_id", *userAssignedID)
718	}
719	msiEndpointURL.RawQuery = v.Encode()
720
721	spt := &ServicePrincipalToken{
722		inner: servicePrincipalToken{
723			Token: newToken(),
724			OauthConfig: OAuthConfig{
725				TokenEndpoint: *msiEndpointURL,
726			},
727			Secret:        &ServicePrincipalMSISecret{},
728			Resource:      resource,
729			AutoRefresh:   true,
730			RefreshWithin: defaultRefresh,
731		},
732		refreshLock:           &sync.RWMutex{},
733		sender:                sender(),
734		refreshCallbacks:      callbacks,
735		MaxMSIRefreshAttempts: defaultMaxMSIRefreshAttempts,
736	}
737
738	if userAssignedID != nil {
739		spt.inner.ClientID = *userAssignedID
740	}
741
742	return spt, nil
743}
744
745// internal type that implements TokenRefreshError
746type tokenRefreshError struct {
747	message string
748	resp    *http.Response
749}
750
751// Error implements the error interface which is part of the TokenRefreshError interface.
752func (tre tokenRefreshError) Error() string {
753	return tre.message
754}
755
756// Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation.
757func (tre tokenRefreshError) Response() *http.Response {
758	return tre.resp
759}
760
761func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError {
762	return tokenRefreshError{message: message, resp: resp}
763}
764
765// EnsureFresh will refresh the token if it will expire within the refresh window (as set by
766// RefreshWithin) and autoRefresh flag is on.  This method is safe for concurrent use.
767func (spt *ServicePrincipalToken) EnsureFresh() error {
768	return spt.EnsureFreshWithContext(context.Background())
769}
770
771// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by
772// RefreshWithin) and autoRefresh flag is on.  This method is safe for concurrent use.
773func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
774	// must take the read lock when initially checking the token's expiration
775	if spt.inner.AutoRefresh && spt.Token().WillExpireIn(spt.inner.RefreshWithin) {
776		// take the write lock then check again to see if the token was already refreshed
777		spt.refreshLock.Lock()
778		defer spt.refreshLock.Unlock()
779		if spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) {
780			return spt.refreshInternal(ctx, spt.inner.Resource)
781		}
782	}
783	return nil
784}
785
786// InvokeRefreshCallbacks calls any TokenRefreshCallbacks that were added to the SPT during initialization
787func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
788	if spt.refreshCallbacks != nil {
789		for _, callback := range spt.refreshCallbacks {
790			err := callback(spt.inner.Token)
791			if err != nil {
792				return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err)
793			}
794		}
795	}
796	return nil
797}
798
799// Refresh obtains a fresh token for the Service Principal.
800// This method is safe for concurrent use.
801func (spt *ServicePrincipalToken) Refresh() error {
802	return spt.RefreshWithContext(context.Background())
803}
804
805// RefreshWithContext obtains a fresh token for the Service Principal.
806// This method is safe for concurrent use.
807func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
808	spt.refreshLock.Lock()
809	defer spt.refreshLock.Unlock()
810	return spt.refreshInternal(ctx, spt.inner.Resource)
811}
812
813// RefreshExchange refreshes the token, but for a different resource.
814// This method is safe for concurrent use.
815func (spt *ServicePrincipalToken) RefreshExchange(resource string) error {
816	return spt.RefreshExchangeWithContext(context.Background(), resource)
817}
818
819// RefreshExchangeWithContext refreshes the token, but for a different resource.
820// This method is safe for concurrent use.
821func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
822	spt.refreshLock.Lock()
823	defer spt.refreshLock.Unlock()
824	return spt.refreshInternal(ctx, resource)
825}
826
827func (spt *ServicePrincipalToken) getGrantType() string {
828	switch spt.inner.Secret.(type) {
829	case *ServicePrincipalUsernamePasswordSecret:
830		return OAuthGrantTypeUserPass
831	case *ServicePrincipalAuthorizationCodeSecret:
832		return OAuthGrantTypeAuthorizationCode
833	default:
834		return OAuthGrantTypeClientCredentials
835	}
836}
837
838func isIMDS(u url.URL) bool {
839	imds, err := url.Parse(msiEndpoint)
840	if err != nil {
841		return false
842	}
843	return (u.Host == imds.Host && u.Path == imds.Path) || isAppService()
844}
845
// refreshInternal obtains a new token for resource, stores it in
// spt.inner.Token and then invokes the registered refresh callbacks with it.
//
// If a custom refresh function is configured, it is used exclusively.
// Otherwise a request is sent to the configured token endpoint:
//   - For non-IMDS endpoints, a form body is POSTed. It carries a
//     refresh_token grant when a refresh token is held, or else the secret's
//     own grant type plus the values the secret contributes.
//   - For MSI secrets, the request is switched to GET and the Metadata header
//     is set.
//   - IMDS endpoints are called through retryForIMDS; all others are sent once.
//
// The visible caller (RefreshExchangeWithContext) holds spt.refreshLock for
// writing while this runs.
//
// Non-200 responses produce an error from newTokenRefreshError. Transport,
// read and decode failures produce plain errors so that callers' retry logic
// still applies.
func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
	if spt.customRefreshFunc != nil {
		token, err := spt.customRefreshFunc(ctx, resource)
		if err != nil {
			return err
		}
		// A callback returning (nil, nil) would otherwise panic on the dereference below.
		if token == nil {
			return fmt.Errorf("adal: Custom refresh function returned a nil token")
		}
		spt.inner.Token = *token
		return spt.InvokeRefreshCallbacks(spt.inner.Token)
	}

	req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil)
	if err != nil {
		return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
	}
	req.Header.Add("User-Agent", UserAgent())
	// Add header when runtime is on App Service or Functions
	if isAppService() {
		asMSISecret, _ := os.LookupEnv(asMSISecretEnv)
		req.Header.Add("Secret", asMSISecret)
	}
	req = req.WithContext(ctx)

	// Decide once, so that building the request and choosing how to send it
	// are based on the same answer.
	imds := isIMDS(spt.inner.OauthConfig.TokenEndpoint)
	if !imds {
		v := url.Values{}
		v.Set("client_id", spt.inner.ClientID)
		v.Set("resource", resource)

		if spt.inner.Token.RefreshToken != "" {
			v.Set("grant_type", OAuthGrantTypeRefreshToken)
			v.Set("refresh_token", spt.inner.Token.RefreshToken)
			// web apps must specify client_secret when refreshing tokens
			// see https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-protocols-oauth-code#refreshing-the-access-tokens
			if spt.getGrantType() == OAuthGrantTypeAuthorizationCode {
				if err := spt.inner.Secret.SetAuthenticationValues(spt, &v); err != nil {
					return err
				}
			}
		} else {
			v.Set("grant_type", spt.getGrantType())
			if err := spt.inner.Secret.SetAuthenticationValues(spt, &v); err != nil {
				return err
			}
		}

		s := v.Encode()
		body := ioutil.NopCloser(strings.NewReader(s))
		req.ContentLength = int64(len(s))
		req.Header.Set(contentType, mimeTypeFormPost)
		req.Body = body
	}

	if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok {
		req.Method = http.MethodGet
		req.Header.Set(metadataHeader, "true")
	}

	var resp *http.Response
	if imds {
		resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts)
	} else {
		resp, err = spt.sender.Do(req)
	}
	if err != nil {
		// don't return a TokenRefreshError here; this will allow retry logic to apply
		return fmt.Errorf("adal: Failed to execute the refresh request. Error = '%v'", err)
	}

	defer resp.Body.Close()
	rb, err := ioutil.ReadAll(resp.Body)

	if resp.StatusCode != http.StatusOK {
		if err != nil {
			return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v", resp.StatusCode, err), resp)
		}
		return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb)), resp)
	}

	// for the following error cases don't return a TokenRefreshError.  the operation succeeded
	// but some transient failure happened during deserialization.  by returning a generic error
	// the retry logic will kick in (we don't retry on TokenRefreshError).

	if err != nil {
		return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err)
	}
	// Treat any all-whitespace body (spaces, newlines, tabs) as empty, not just spaces.
	if len(strings.TrimSpace(string(rb))) == 0 {
		return fmt.Errorf("adal: Empty service principal token received during refresh")
	}
	var token Token
	err = json.Unmarshal(rb, &token)
	if err != nil {
		return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb))
	}

	spt.inner.Token = token

	return spt.InvokeRefreshCallbacks(token)
}
944
// retryForIMDS sends req with sender. It retries when the send fails, or when
// the response status code is one treated as transient for IMDS. Between
// attempts it waits with exponential backoff capped at 60 seconds, and it
// stops early if req's context is done while waiting.
//
// At least one attempt is always made: a maxAttempts below 1 is treated as 1,
// so a nil response is never returned together with a nil error. No delay is
// taken after the final attempt.
//
// The response from the final attempt is returned as-is, with its body open,
// even if its status code is retryable. Bodies of responses that are
// discarded in favour of a retry are drained and closed, including when
// waiting is cut short by context cancellation.
//
// NOTE(review): req is re-sent unchanged on each attempt, which assumes it has
// no body to rewind; refreshInternal only attaches a body for non-IMDS endpoints.
func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http.Response, err error) {
	// copied from client.go due to circular dependency
	retries := []int{
		http.StatusRequestTimeout,      // 408
		http.StatusTooManyRequests,     // 429
		http.StatusInternalServerError, // 500
		http.StatusBadGateway,          // 502
		http.StatusServiceUnavailable,  // 503
		http.StatusGatewayTimeout,      // 504
	}
	// extra retry status codes specific to IMDS
	retries = append(retries,
		http.StatusNotFound,
		http.StatusGone,
		// all remaining 5xx
		http.StatusNotImplemented,
		http.StatusHTTPVersionNotSupported,
		http.StatusVariantAlsoNegotiates,
		http.StatusInsufficientStorage,
		http.StatusLoopDetected,
		http.StatusNotExtended,
		http.StatusNetworkAuthenticationRequired)

	// see https://docs.microsoft.com/en-us/azure/active-directory/managed-service-identity/how-to-use-vm-token#retry-guidance

	const maxDelay time.Duration = 60 * time.Second

	// maxAttempts comes from the exported MaxMSIRefreshAttempts field. Without
	// this clamp, a value <= 0 skipped every attempt and returned (nil, nil),
	// which refreshInternal then dereferenced.
	if maxAttempts < 1 {
		maxAttempts = 1
	}

	delay := time.Duration(0)
	for attempt := 1; ; attempt++ {
		resp, err = sender.Do(req)
		// we want to retry if err is not nil or the status code is in the list of retry codes
		if err == nil && !responseHasStatusCode(resp, retries...) {
			return resp, nil
		}
		if attempt >= maxAttempts {
			// Out of attempts: return the last outcome without sleeping first.
			return resp, err
		}

		// perform exponential backoff with a cap.
		// the base value of 2 is the "delta backoff" as specified in the guidance doc
		delay += time.Duration(math.Pow(2, float64(attempt))) * time.Second
		if delay > maxDelay {
			delay = maxDelay
		}

		// This response is being discarded in favour of a retry. Release it now
		// so the connection can be reused and nothing leaks if we bail out below.
		if resp != nil && resp.Body != nil {
			// Best effort: a failed drain only prevents connection reuse.
			_, _ = io.Copy(ioutil.Discard, resp.Body)
			resp.Body.Close()
		}

		timer := time.NewTimer(delay)
		select {
		case <-timer.C:
			// intentionally left blank
		case <-req.Context().Done():
			timer.Stop()
			return resp, req.Context().Err()
		}
	}
}
1006
// responseHasStatusCode reports whether resp is non-nil and its StatusCode
// equals any of codes. A nil response never matches.
func responseHasStatusCode(resp *http.Response, codes ...int) bool {
	if resp == nil {
		return false
	}
	for _, code := range codes {
		if resp.StatusCode == code {
			return true
		}
	}
	return false
}
1017
1018// SetAutoRefresh enables or disables automatic refreshing of stale tokens.
1019func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) {
1020	spt.inner.AutoRefresh = autoRefresh
1021}
1022
// SetRefreshWithin sets the refresh window used by EnsureFresh: if the token
// will expire within d, EnsureFresh refreshes it.
func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) {
	spt.inner.RefreshWithin = d
}
1029
1030// SetSender sets the http.Client used when obtaining the Service Principal token. An
1031// undecorated http.Client is used by default.
1032func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s }
1033
1034// OAuthToken implements the OAuthTokenProvider interface.  It returns the current access token.
1035func (spt *ServicePrincipalToken) OAuthToken() string {
1036	spt.refreshLock.RLock()
1037	defer spt.refreshLock.RUnlock()
1038	return spt.inner.Token.OAuthToken()
1039}
1040
1041// Token returns a copy of the current token.
1042func (spt *ServicePrincipalToken) Token() Token {
1043	spt.refreshLock.RLock()
1044	defer spt.refreshLock.RUnlock()
1045	return spt.inner.Token
1046}
1047
// MultiTenantServicePrincipalToken contains tokens for multi-tenant authorization.
// It pairs one token for the primary tenant with one token per auxiliary tenant.
// Its refresh methods act on the primary token first and then on each auxiliary
// token in slice order, stopping at the first error.
type MultiTenantServicePrincipalToken struct {
	// PrimaryToken is the token for the primary tenant.
	PrimaryToken    *ServicePrincipalToken
	// AuxiliaryTokens holds one token per auxiliary tenant.
	AuxiliaryTokens []*ServicePrincipalToken
}
1053
1054// PrimaryOAuthToken returns the primary authorization token.
1055func (mt *MultiTenantServicePrincipalToken) PrimaryOAuthToken() string {
1056	return mt.PrimaryToken.OAuthToken()
1057}
1058
1059// AuxiliaryOAuthTokens returns one to three auxiliary authorization tokens.
1060func (mt *MultiTenantServicePrincipalToken) AuxiliaryOAuthTokens() []string {
1061	tokens := make([]string, len(mt.AuxiliaryTokens))
1062	for i := range mt.AuxiliaryTokens {
1063		tokens[i] = mt.AuxiliaryTokens[i].OAuthToken()
1064	}
1065	return tokens
1066}
1067
// EnsureFreshWithContext calls EnsureFreshWithContext on the primary token and
// then on every auxiliary token in order. Each of those refreshes its token if
// it will expire within the refresh window (as set by RefreshWithin) and the
// autoRefresh flag is on. Processing stops at the first failure, which is
// returned with context about which kind of token failed. This method is safe
// for concurrent use.
func (mt *MultiTenantServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
	err := mt.PrimaryToken.EnsureFreshWithContext(ctx)
	if err != nil {
		return fmt.Errorf("failed to refresh primary token: %v", err)
	}
	for i := range mt.AuxiliaryTokens {
		if err = mt.AuxiliaryTokens[i].EnsureFreshWithContext(ctx); err != nil {
			return fmt.Errorf("failed to refresh auxiliary token: %v", err)
		}
	}
	return nil
}
1081
// RefreshWithContext unconditionally refreshes the primary token and then
// every auxiliary token in order. It stops at the first failure and returns
// it with context about which kind of token failed.
func (mt *MultiTenantServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
	err := mt.PrimaryToken.RefreshWithContext(ctx)
	if err != nil {
		return fmt.Errorf("failed to refresh primary token: %v", err)
	}
	for i := range mt.AuxiliaryTokens {
		if err = mt.AuxiliaryTokens[i].RefreshWithContext(ctx); err != nil {
			return fmt.Errorf("failed to refresh auxiliary token: %v", err)
		}
	}
	return nil
}
1094
// RefreshExchangeWithContext refreshes the primary token and then every
// auxiliary token in order, each for the given resource rather than the one
// it was created for. It stops at the first failure and returns it with
// context about which kind of token failed.
func (mt *MultiTenantServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
	err := mt.PrimaryToken.RefreshExchangeWithContext(ctx, resource)
	if err != nil {
		return fmt.Errorf("failed to refresh primary token: %v", err)
	}
	for i := range mt.AuxiliaryTokens {
		if err = mt.AuxiliaryTokens[i].RefreshExchangeWithContext(ctx, resource); err != nil {
			return fmt.Errorf("failed to refresh auxiliary token: %v", err)
		}
	}
	return nil
}
1107
// NewMultiTenantServicePrincipalToken creates a new MultiTenantServicePrincipalToken with the specified credentials and resource.
//
// clientID, secret and resource are each checked with validateStringParam
// before any token is built. One ServicePrincipalToken is then created for the
// primary tenant and one for each auxiliary tenant, all sharing the same
// credentials and resource. If any token cannot be created, an error is
// returned and no partial result is produced.
func NewMultiTenantServicePrincipalToken(multiTenantCfg MultiTenantOAuthConfig, clientID string, secret string, resource string) (*MultiTenantServicePrincipalToken, error) {
	params := []struct{ value, name string }{
		{clientID, "clientID"},
		{secret, "secret"},
		{resource, "resource"},
	}
	for _, p := range params {
		if err := validateStringParam(p.value, p.name); err != nil {
			return nil, err
		}
	}

	auxTenants := multiTenantCfg.AuxiliaryTenants()
	primary, err := NewServicePrincipalToken(*multiTenantCfg.PrimaryTenant(), clientID, secret, resource)
	if err != nil {
		return nil, fmt.Errorf("failed to create SPT for primary tenant: %v", err)
	}

	auxTokens := make([]*ServicePrincipalToken, 0, len(auxTenants))
	for _, tenant := range auxTenants {
		aux, err := NewServicePrincipalToken(*tenant, clientID, secret, resource)
		if err != nil {
			return nil, fmt.Errorf("failed to create SPT for auxiliary tenant: %v", err)
		}
		auxTokens = append(auxTokens, aux)
	}

	return &MultiTenantServicePrincipalToken{
		PrimaryToken:    primary,
		AuxiliaryTokens: auxTokens,
	}, nil
}
1137