1package adal
2
3// Copyright 2017 Microsoft Corporation
4//
5//  Licensed under the Apache License, Version 2.0 (the "License");
6//  you may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at
8//
9//      http://www.apache.org/licenses/LICENSE-2.0
10//
11//  Unless required by applicable law or agreed to in writing, software
12//  distributed under the License is distributed on an "AS IS" BASIS,
13//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14//  See the License for the specific language governing permissions and
15//  limitations under the License.
16
17import (
18	"context"
19	"crypto/rand"
20	"crypto/rsa"
21	"crypto/sha1"
22	"crypto/x509"
23	"encoding/base64"
24	"encoding/json"
25	"errors"
26	"fmt"
27	"io/ioutil"
28	"math"
29	"net"
30	"net/http"
31	"net/url"
32	"strings"
33	"sync"
34	"time"
35
36	"github.com/Azure/go-autorest/autorest/date"
37	"github.com/Azure/go-autorest/tracing"
38	"github.com/dgrijalva/jwt-go"
39)
40
const (
	// defaultRefresh is the RefreshWithin window assigned to tokens created by
	// the constructors in this file.
	defaultRefresh = 5 * time.Minute

	// OAuthGrantTypeDeviceCode is the "grant_type" identifier used in device flow
	OAuthGrantTypeDeviceCode = "device_code"

	// OAuthGrantTypeClientCredentials is the "grant_type" identifier used in credential flows
	OAuthGrantTypeClientCredentials = "client_credentials"

	// OAuthGrantTypeUserPass is the "grant_type" identifier used in username and password auth flows
	OAuthGrantTypeUserPass = "password"

	// OAuthGrantTypeRefreshToken is the "grant_type" identifier used in refresh token flows
	OAuthGrantTypeRefreshToken = "refresh_token"

	// OAuthGrantTypeAuthorizationCode is the "grant_type" identifier used in authorization code flows
	OAuthGrantTypeAuthorizationCode = "authorization_code"

	// metadataHeader is the header required by MSI extension
	metadataHeader = "Metadata"

	// msiEndpoint is the well-known endpoint for getting MSI authentication tokens
	msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"

	// defaultMaxMSIRefreshAttempts is the default number of attempts to refresh
	// an MSI authentication token
	defaultMaxMSIRefreshAttempts = 5
)
68
// OAuthTokenProvider is an interface which should be implemented by an access token retriever
type OAuthTokenProvider interface {
	// OAuthToken returns the access token as a string.
	OAuthToken() string
}
73
// TokenRefreshError is an interface used by errors returned during token refresh.
type TokenRefreshError interface {
	error
	// Response returns the HTTP response associated with the failed refresh.
	// It can be nil, e.g. when the refresh request could not be executed at all.
	Response() *http.Response
}
79
// Refresher is an interface for token refresh functionality
type Refresher interface {
	// Refresh obtains a new token.
	Refresh() error
	// RefreshExchange obtains a new token for the specified resource.
	RefreshExchange(resource string) error
	// EnsureFresh refreshes the token only when it needs refreshing.
	EnsureFresh() error
}
86
// RefresherWithContext is an interface for token refresh functionality
// whose operations accept a context.Context.
type RefresherWithContext interface {
	// RefreshWithContext obtains a new token.
	RefreshWithContext(ctx context.Context) error
	// RefreshExchangeWithContext obtains a new token for the specified resource.
	RefreshExchangeWithContext(ctx context.Context, resource string) error
	// EnsureFreshWithContext refreshes the token only when it needs refreshing.
	EnsureFreshWithContext(ctx context.Context) error
}
93
// TokenRefreshCallback is the type representing callbacks that will be called after
// a successful token refresh. A non-nil error returned by a callback is reported
// to the caller of the refresh.
type TokenRefreshCallback func(Token) error
97
// Token encapsulates the access token used to authorize Azure requests.
// https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response
type Token struct {
	// AccessToken is the value returned by OAuthToken.
	AccessToken  string `json:"access_token"`
	// RefreshToken, when non-empty, makes a refresh use the refresh_token grant.
	RefreshToken string `json:"refresh_token"`

	// ExpiresIn is kept as received from the service.
	// NOTE(review): presumably the token lifetime in seconds — confirm against the service docs.
	ExpiresIn json.Number `json:"expires_in"`
	// ExpiresOn is parsed by Expires as a number of seconds to compute the expiration time.
	ExpiresOn json.Number `json:"expires_on"`
	// NotBefore is kept as received from the service.
	// NOTE(review): presumably the earliest time the token is valid — confirm against the service docs.
	NotBefore json.Number `json:"not_before"`

	// Resource is the "resource" value of the token response.
	Resource string `json:"resource"`
	// Type is the "token_type" value of the token response.
	Type     string `json:"token_type"`
}
111
112func newToken() Token {
113	return Token{
114		ExpiresIn: "0",
115		ExpiresOn: "0",
116		NotBefore: "0",
117	}
118}
119
120// IsZero returns true if the token object is zero-initialized.
121func (t Token) IsZero() bool {
122	return t == Token{}
123}
124
125// Expires returns the time.Time when the Token expires.
126func (t Token) Expires() time.Time {
127	s, err := t.ExpiresOn.Float64()
128	if err != nil {
129		s = -3600
130	}
131
132	expiration := date.NewUnixTimeFromSeconds(s)
133
134	return time.Time(expiration).UTC()
135}
136
137// IsExpired returns true if the Token is expired, false otherwise.
138func (t Token) IsExpired() bool {
139	return t.WillExpireIn(0)
140}
141
142// WillExpireIn returns true if the Token will expire after the passed time.Duration interval
143// from now, false otherwise.
144func (t Token) WillExpireIn(d time.Duration) bool {
145	return !t.Expires().After(time.Now().Add(d))
146}
147
148//OAuthToken return the current access token
149func (t *Token) OAuthToken() string {
150	return t.AccessToken
151}
152
// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form
// that is submitted when acquiring an oAuth token.
type ServicePrincipalSecret interface {
	// SetAuthenticationValues adds the secret-specific entries to values, the
	// form being built for the token request of spt. A non-nil error aborts
	// the refresh.
	SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error
}
158
159// ServicePrincipalNoSecret represents a secret type that contains no secret
160// meaning it is not valid for fetching a fresh token. This is used by Manual
161type ServicePrincipalNoSecret struct {
162}
163
164// SetAuthenticationValues is a method of the interface ServicePrincipalSecret
165// It only returns an error for the ServicePrincipalNoSecret type
166func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
167	return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token")
168}
169
170// MarshalJSON implements the json.Marshaler interface.
171func (noSecret ServicePrincipalNoSecret) MarshalJSON() ([]byte, error) {
172	type tokenType struct {
173		Type string `json:"type"`
174	}
175	return json.Marshal(tokenType{
176		Type: "ServicePrincipalNoSecret",
177	})
178}
179
180// ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization.
181type ServicePrincipalTokenSecret struct {
182	ClientSecret string `json:"value"`
183}
184
185// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
186// It will populate the form submitted during oAuth Token Acquisition using the client_secret.
187func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
188	v.Set("client_secret", tokenSecret.ClientSecret)
189	return nil
190}
191
192// MarshalJSON implements the json.Marshaler interface.
193func (tokenSecret ServicePrincipalTokenSecret) MarshalJSON() ([]byte, error) {
194	type tokenType struct {
195		Type  string `json:"type"`
196		Value string `json:"value"`
197	}
198	return json.Marshal(tokenType{
199		Type:  "ServicePrincipalTokenSecret",
200		Value: tokenSecret.ClientSecret,
201	})
202}
203
// ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs.
type ServicePrincipalCertificateSecret struct {
	// Certificate supplies the raw bytes used for the "x5t" and "x5c" JWT headers.
	Certificate *x509.Certificate
	// PrivateKey is the key the JWT is signed with.
	PrivateKey  *rsa.PrivateKey
}
209
210// SignJwt returns the JWT signed with the certificate's private key.
211func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) {
212	hasher := sha1.New()
213	_, err := hasher.Write(secret.Certificate.Raw)
214	if err != nil {
215		return "", err
216	}
217
218	thumbprint := base64.URLEncoding.EncodeToString(hasher.Sum(nil))
219
220	// The jti (JWT ID) claim provides a unique identifier for the JWT.
221	jti := make([]byte, 20)
222	_, err = rand.Read(jti)
223	if err != nil {
224		return "", err
225	}
226
227	token := jwt.New(jwt.SigningMethodRS256)
228	token.Header["x5t"] = thumbprint
229	x5c := []string{base64.StdEncoding.EncodeToString(secret.Certificate.Raw)}
230	token.Header["x5c"] = x5c
231	token.Claims = jwt.MapClaims{
232		"aud": spt.inner.OauthConfig.TokenEndpoint.String(),
233		"iss": spt.inner.ClientID,
234		"sub": spt.inner.ClientID,
235		"jti": base64.URLEncoding.EncodeToString(jti),
236		"nbf": time.Now().Unix(),
237		"exp": time.Now().Add(time.Hour * 24).Unix(),
238	}
239
240	signedString, err := token.SignedString(secret.PrivateKey)
241	return signedString, err
242}
243
244// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
245// It will populate the form submitted during oAuth Token Acquisition using a JWT signed with a certificate.
246func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
247	jwt, err := secret.SignJwt(spt)
248	if err != nil {
249		return err
250	}
251
252	v.Set("client_assertion", jwt)
253	v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
254	return nil
255}
256
257// MarshalJSON implements the json.Marshaler interface.
258func (secret ServicePrincipalCertificateSecret) MarshalJSON() ([]byte, error) {
259	return nil, errors.New("marshalling ServicePrincipalCertificateSecret is not supported")
260}
261
262// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension.
263type ServicePrincipalMSISecret struct {
264}
265
266// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
267func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
268	return nil
269}
270
271// MarshalJSON implements the json.Marshaler interface.
272func (msiSecret ServicePrincipalMSISecret) MarshalJSON() ([]byte, error) {
273	return nil, errors.New("marshalling ServicePrincipalMSISecret is not supported")
274}
275
276// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth.
277type ServicePrincipalUsernamePasswordSecret struct {
278	Username string `json:"username"`
279	Password string `json:"password"`
280}
281
282// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
283func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
284	v.Set("username", secret.Username)
285	v.Set("password", secret.Password)
286	return nil
287}
288
289// MarshalJSON implements the json.Marshaler interface.
290func (secret ServicePrincipalUsernamePasswordSecret) MarshalJSON() ([]byte, error) {
291	type tokenType struct {
292		Type     string `json:"type"`
293		Username string `json:"username"`
294		Password string `json:"password"`
295	}
296	return json.Marshal(tokenType{
297		Type:     "ServicePrincipalUsernamePasswordSecret",
298		Username: secret.Username,
299		Password: secret.Password,
300	})
301}
302
303// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth.
304type ServicePrincipalAuthorizationCodeSecret struct {
305	ClientSecret      string `json:"value"`
306	AuthorizationCode string `json:"authCode"`
307	RedirectURI       string `json:"redirect"`
308}
309
310// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
311func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
312	v.Set("code", secret.AuthorizationCode)
313	v.Set("client_secret", secret.ClientSecret)
314	v.Set("redirect_uri", secret.RedirectURI)
315	return nil
316}
317
318// MarshalJSON implements the json.Marshaler interface.
319func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, error) {
320	type tokenType struct {
321		Type     string `json:"type"`
322		Value    string `json:"value"`
323		AuthCode string `json:"authCode"`
324		Redirect string `json:"redirect"`
325	}
326	return json.Marshal(tokenType{
327		Type:     "ServicePrincipalAuthorizationCodeSecret",
328		Value:    secret.ClientSecret,
329		AuthCode: secret.AuthorizationCode,
330		Redirect: secret.RedirectURI,
331	})
332}
333
// ServicePrincipalToken encapsulates a Token created for a Service Principal.
type ServicePrincipalToken struct {
	// inner holds the state that is marshalled to and from JSON.
	inner            servicePrincipalToken
	// refreshLock is taken in write mode by the refresh methods.
	refreshLock      *sync.RWMutex
	// sender executes the HTTP request issued during a refresh.
	sender           Sender
	// refreshCallbacks are invoked after a successful refresh.
	refreshCallbacks []TokenRefreshCallback
	// MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token.
	MaxMSIRefreshAttempts int
}
343
344// MarshalTokenJSON returns the marshalled inner token.
345func (spt ServicePrincipalToken) MarshalTokenJSON() ([]byte, error) {
346	return json.Marshal(spt.inner.Token)
347}
348
349// SetRefreshCallbacks replaces any existing refresh callbacks with the specified callbacks.
350func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCallback) {
351	spt.refreshCallbacks = callbacks
352}
353
354// MarshalJSON implements the json.Marshaler interface.
355func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) {
356	return json.Marshal(spt.inner)
357}
358
359// UnmarshalJSON implements the json.Unmarshaler interface.
360func (spt *ServicePrincipalToken) UnmarshalJSON(data []byte) error {
361	// need to determine the token type
362	raw := map[string]interface{}{}
363	err := json.Unmarshal(data, &raw)
364	if err != nil {
365		return err
366	}
367	secret := raw["secret"].(map[string]interface{})
368	switch secret["type"] {
369	case "ServicePrincipalNoSecret":
370		spt.inner.Secret = &ServicePrincipalNoSecret{}
371	case "ServicePrincipalTokenSecret":
372		spt.inner.Secret = &ServicePrincipalTokenSecret{}
373	case "ServicePrincipalCertificateSecret":
374		return errors.New("unmarshalling ServicePrincipalCertificateSecret is not supported")
375	case "ServicePrincipalMSISecret":
376		return errors.New("unmarshalling ServicePrincipalMSISecret is not supported")
377	case "ServicePrincipalUsernamePasswordSecret":
378		spt.inner.Secret = &ServicePrincipalUsernamePasswordSecret{}
379	case "ServicePrincipalAuthorizationCodeSecret":
380		spt.inner.Secret = &ServicePrincipalAuthorizationCodeSecret{}
381	default:
382		return fmt.Errorf("unrecognized token type '%s'", secret["type"])
383	}
384	err = json.Unmarshal(data, &spt.inner)
385	if err != nil {
386		return err
387	}
388	// Don't override the refreshLock or the sender if those have been already set.
389	if spt.refreshLock == nil {
390		spt.refreshLock = &sync.RWMutex{}
391	}
392	if spt.sender == nil {
393		spt.sender = &http.Client{Transport: tracing.Transport}
394	}
395	return nil
396}
397
// servicePrincipalToken is the internal type used for marshalling/unmarshalling
// the state of a ServicePrincipalToken.
type servicePrincipalToken struct {
	// Token is the most recently obtained token.
	Token         Token                  `json:"token"`
	// Secret fills in the authentication values of a token request.
	Secret        ServicePrincipalSecret `json:"secret"`
	// OauthConfig supplies the token endpoint that refresh requests are sent to.
	OauthConfig   OAuthConfig            `json:"oauth"`
	// ClientID is sent as the "client_id" form value of a token request.
	ClientID      string                 `json:"clientID"`
	// Resource is the resource a regular (non-exchange) refresh asks for.
	Resource      string                 `json:"resource"`
	// AutoRefresh enables refreshing from EnsureFresh.
	AutoRefresh   bool                   `json:"autoRefresh"`
	// RefreshWithin is the window before expiration in which EnsureFresh refreshes.
	RefreshWithin time.Duration          `json:"refreshWithin"`
}
408
409func validateOAuthConfig(oac OAuthConfig) error {
410	if oac.IsZero() {
411		return fmt.Errorf("parameter 'oauthConfig' cannot be zero-initialized")
412	}
413	return nil
414}
415
416// NewServicePrincipalTokenWithSecret create a ServicePrincipalToken using the supplied ServicePrincipalSecret implementation.
417func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
418	if err := validateOAuthConfig(oauthConfig); err != nil {
419		return nil, err
420	}
421	if err := validateStringParam(id, "id"); err != nil {
422		return nil, err
423	}
424	if err := validateStringParam(resource, "resource"); err != nil {
425		return nil, err
426	}
427	if secret == nil {
428		return nil, fmt.Errorf("parameter 'secret' cannot be nil")
429	}
430	spt := &ServicePrincipalToken{
431		inner: servicePrincipalToken{
432			Token:         newToken(),
433			OauthConfig:   oauthConfig,
434			Secret:        secret,
435			ClientID:      id,
436			Resource:      resource,
437			AutoRefresh:   true,
438			RefreshWithin: defaultRefresh,
439		},
440		refreshLock:      &sync.RWMutex{},
441		sender:           &http.Client{Transport: tracing.Transport},
442		refreshCallbacks: callbacks,
443	}
444	return spt, nil
445}
446
447// NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken using the supplied token
448func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
449	if err := validateOAuthConfig(oauthConfig); err != nil {
450		return nil, err
451	}
452	if err := validateStringParam(clientID, "clientID"); err != nil {
453		return nil, err
454	}
455	if err := validateStringParam(resource, "resource"); err != nil {
456		return nil, err
457	}
458	if token.IsZero() {
459		return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
460	}
461	spt, err := NewServicePrincipalTokenWithSecret(
462		oauthConfig,
463		clientID,
464		resource,
465		&ServicePrincipalNoSecret{},
466		callbacks...)
467	if err != nil {
468		return nil, err
469	}
470
471	spt.inner.Token = token
472
473	return spt, nil
474}
475
476// NewServicePrincipalTokenFromManualTokenSecret creates a ServicePrincipalToken using the supplied token and secret
477func NewServicePrincipalTokenFromManualTokenSecret(oauthConfig OAuthConfig, clientID string, resource string, token Token, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
478	if err := validateOAuthConfig(oauthConfig); err != nil {
479		return nil, err
480	}
481	if err := validateStringParam(clientID, "clientID"); err != nil {
482		return nil, err
483	}
484	if err := validateStringParam(resource, "resource"); err != nil {
485		return nil, err
486	}
487	if secret == nil {
488		return nil, fmt.Errorf("parameter 'secret' cannot be nil")
489	}
490	if token.IsZero() {
491		return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
492	}
493	spt, err := NewServicePrincipalTokenWithSecret(
494		oauthConfig,
495		clientID,
496		resource,
497		secret,
498		callbacks...)
499	if err != nil {
500		return nil, err
501	}
502
503	spt.inner.Token = token
504
505	return spt, nil
506}
507
508// NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal
509// credentials scoped to the named resource.
510func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
511	if err := validateOAuthConfig(oauthConfig); err != nil {
512		return nil, err
513	}
514	if err := validateStringParam(clientID, "clientID"); err != nil {
515		return nil, err
516	}
517	if err := validateStringParam(secret, "secret"); err != nil {
518		return nil, err
519	}
520	if err := validateStringParam(resource, "resource"); err != nil {
521		return nil, err
522	}
523	return NewServicePrincipalTokenWithSecret(
524		oauthConfig,
525		clientID,
526		resource,
527		&ServicePrincipalTokenSecret{
528			ClientSecret: secret,
529		},
530		callbacks...,
531	)
532}
533
534// NewServicePrincipalTokenFromCertificate creates a ServicePrincipalToken from the supplied pkcs12 bytes.
535func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
536	if err := validateOAuthConfig(oauthConfig); err != nil {
537		return nil, err
538	}
539	if err := validateStringParam(clientID, "clientID"); err != nil {
540		return nil, err
541	}
542	if err := validateStringParam(resource, "resource"); err != nil {
543		return nil, err
544	}
545	if certificate == nil {
546		return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
547	}
548	if privateKey == nil {
549		return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
550	}
551	return NewServicePrincipalTokenWithSecret(
552		oauthConfig,
553		clientID,
554		resource,
555		&ServicePrincipalCertificateSecret{
556			PrivateKey:  privateKey,
557			Certificate: certificate,
558		},
559		callbacks...,
560	)
561}
562
563// NewServicePrincipalTokenFromUsernamePassword creates a ServicePrincipalToken from the username and password.
564func NewServicePrincipalTokenFromUsernamePassword(oauthConfig OAuthConfig, clientID string, username string, password string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
565	if err := validateOAuthConfig(oauthConfig); err != nil {
566		return nil, err
567	}
568	if err := validateStringParam(clientID, "clientID"); err != nil {
569		return nil, err
570	}
571	if err := validateStringParam(username, "username"); err != nil {
572		return nil, err
573	}
574	if err := validateStringParam(password, "password"); err != nil {
575		return nil, err
576	}
577	if err := validateStringParam(resource, "resource"); err != nil {
578		return nil, err
579	}
580	return NewServicePrincipalTokenWithSecret(
581		oauthConfig,
582		clientID,
583		resource,
584		&ServicePrincipalUsernamePasswordSecret{
585			Username: username,
586			Password: password,
587		},
588		callbacks...,
589	)
590}
591
592// NewServicePrincipalTokenFromAuthorizationCode creates a ServicePrincipalToken from the
593func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clientID string, clientSecret string, authorizationCode string, redirectURI string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
594
595	if err := validateOAuthConfig(oauthConfig); err != nil {
596		return nil, err
597	}
598	if err := validateStringParam(clientID, "clientID"); err != nil {
599		return nil, err
600	}
601	if err := validateStringParam(clientSecret, "clientSecret"); err != nil {
602		return nil, err
603	}
604	if err := validateStringParam(authorizationCode, "authorizationCode"); err != nil {
605		return nil, err
606	}
607	if err := validateStringParam(redirectURI, "redirectURI"); err != nil {
608		return nil, err
609	}
610	if err := validateStringParam(resource, "resource"); err != nil {
611		return nil, err
612	}
613
614	return NewServicePrincipalTokenWithSecret(
615		oauthConfig,
616		clientID,
617		resource,
618		&ServicePrincipalAuthorizationCodeSecret{
619			ClientSecret:      clientSecret,
620			AuthorizationCode: authorizationCode,
621			RedirectURI:       redirectURI,
622		},
623		callbacks...,
624	)
625}
626
627// GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines.
628func GetMSIVMEndpoint() (string, error) {
629	return msiEndpoint, nil
630}
631
632// NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension.
633// It will use the system assigned identity when creating the token.
634func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
635	return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, callbacks...)
636}
637
638// NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension.
639// It will use the specified user assigned identity when creating the token.
640func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
641	return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &userAssignedID, callbacks...)
642}
643
644func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
645	if err := validateStringParam(msiEndpoint, "msiEndpoint"); err != nil {
646		return nil, err
647	}
648	if err := validateStringParam(resource, "resource"); err != nil {
649		return nil, err
650	}
651	if userAssignedID != nil {
652		if err := validateStringParam(*userAssignedID, "userAssignedID"); err != nil {
653			return nil, err
654		}
655	}
656	// We set the oauth config token endpoint to be MSI's endpoint
657	msiEndpointURL, err := url.Parse(msiEndpoint)
658	if err != nil {
659		return nil, err
660	}
661
662	v := url.Values{}
663	v.Set("resource", resource)
664	v.Set("api-version", "2018-02-01")
665	if userAssignedID != nil {
666		v.Set("client_id", *userAssignedID)
667	}
668	msiEndpointURL.RawQuery = v.Encode()
669
670	spt := &ServicePrincipalToken{
671		inner: servicePrincipalToken{
672			Token: newToken(),
673			OauthConfig: OAuthConfig{
674				TokenEndpoint: *msiEndpointURL,
675			},
676			Secret:        &ServicePrincipalMSISecret{},
677			Resource:      resource,
678			AutoRefresh:   true,
679			RefreshWithin: defaultRefresh,
680		},
681		refreshLock:           &sync.RWMutex{},
682		sender:                &http.Client{Transport: tracing.Transport},
683		refreshCallbacks:      callbacks,
684		MaxMSIRefreshAttempts: defaultMaxMSIRefreshAttempts,
685	}
686
687	if userAssignedID != nil {
688		spt.inner.ClientID = *userAssignedID
689	}
690
691	return spt, nil
692}
693
694// internal type that implements TokenRefreshError
695type tokenRefreshError struct {
696	message string
697	resp    *http.Response
698}
699
700// Error implements the error interface which is part of the TokenRefreshError interface.
701func (tre tokenRefreshError) Error() string {
702	return tre.message
703}
704
705// Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation.
706func (tre tokenRefreshError) Response() *http.Response {
707	return tre.resp
708}
709
710func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError {
711	return tokenRefreshError{message: message, resp: resp}
712}
713
714// EnsureFresh will refresh the token if it will expire within the refresh window (as set by
715// RefreshWithin) and autoRefresh flag is on.  This method is safe for concurrent use.
716func (spt *ServicePrincipalToken) EnsureFresh() error {
717	return spt.EnsureFreshWithContext(context.Background())
718}
719
720// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by
721// RefreshWithin) and autoRefresh flag is on.  This method is safe for concurrent use.
722func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
723	if spt.inner.AutoRefresh && spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) {
724		// take the write lock then check to see if the token was already refreshed
725		spt.refreshLock.Lock()
726		defer spt.refreshLock.Unlock()
727		if spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) {
728			return spt.refreshInternal(ctx, spt.inner.Resource)
729		}
730	}
731	return nil
732}
733
734// InvokeRefreshCallbacks calls any TokenRefreshCallbacks that were added to the SPT during initialization
735func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
736	if spt.refreshCallbacks != nil {
737		for _, callback := range spt.refreshCallbacks {
738			err := callback(spt.inner.Token)
739			if err != nil {
740				return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err)
741			}
742		}
743	}
744	return nil
745}
746
747// Refresh obtains a fresh token for the Service Principal.
748// This method is not safe for concurrent use and should be syncrhonized.
749func (spt *ServicePrincipalToken) Refresh() error {
750	return spt.RefreshWithContext(context.Background())
751}
752
753// RefreshWithContext obtains a fresh token for the Service Principal.
754// This method is not safe for concurrent use and should be syncrhonized.
755func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
756	spt.refreshLock.Lock()
757	defer spt.refreshLock.Unlock()
758	return spt.refreshInternal(ctx, spt.inner.Resource)
759}
760
761// RefreshExchange refreshes the token, but for a different resource.
762// This method is not safe for concurrent use and should be syncrhonized.
763func (spt *ServicePrincipalToken) RefreshExchange(resource string) error {
764	return spt.RefreshExchangeWithContext(context.Background(), resource)
765}
766
767// RefreshExchangeWithContext refreshes the token, but for a different resource.
768// This method is not safe for concurrent use and should be syncrhonized.
769func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
770	spt.refreshLock.Lock()
771	defer spt.refreshLock.Unlock()
772	return spt.refreshInternal(ctx, resource)
773}
774
775func (spt *ServicePrincipalToken) getGrantType() string {
776	switch spt.inner.Secret.(type) {
777	case *ServicePrincipalUsernamePasswordSecret:
778		return OAuthGrantTypeUserPass
779	case *ServicePrincipalAuthorizationCodeSecret:
780		return OAuthGrantTypeAuthorizationCode
781	default:
782		return OAuthGrantTypeClientCredentials
783	}
784}
785
786func isIMDS(u url.URL) bool {
787	imds, err := url.Parse(msiEndpoint)
788	if err != nil {
789		return false
790	}
791	return u.Host == imds.Host && u.Path == imds.Path
792}
793
// refreshInternal requests a new token for resource from the configured token
// endpoint and, on success, stores it and invokes the refresh callbacks.
//
// It takes no lock itself; the callers in this file acquire spt.refreshLock in
// write mode before calling it.
//
// A failure to execute the request and any non-200 response are returned as a
// TokenRefreshError. Every other failure (building the request, filling in the
// authentication values, reading or decoding a 200 response, a failing
// callback) is returned as a plain error.
func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
	req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil)
	if err != nil {
		return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
	}
	req.Header.Add("User-Agent", UserAgent())
	req = req.WithContext(ctx)
	// Every endpoint except the well-known IMDS one receives its parameters
	// as a form-encoded request body.
	if !isIMDS(spt.inner.OauthConfig.TokenEndpoint) {
		v := url.Values{}
		v.Set("client_id", spt.inner.ClientID)
		v.Set("resource", resource)

		if spt.inner.Token.RefreshToken != "" {
			// a refresh token is available, so use the refresh_token grant
			v.Set("grant_type", OAuthGrantTypeRefreshToken)
			v.Set("refresh_token", spt.inner.Token.RefreshToken)
			// web apps must specify client_secret when refreshing tokens
			// see https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-protocols-oauth-code#refreshing-the-access-tokens
			if spt.getGrantType() == OAuthGrantTypeAuthorizationCode {
				err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
				if err != nil {
					return err
				}
			}
		} else {
			// no refresh token: authenticate with the secret's own grant type
			v.Set("grant_type", spt.getGrantType())
			err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
			if err != nil {
				return err
			}
		}

		s := v.Encode()
		body := ioutil.NopCloser(strings.NewReader(s))
		req.ContentLength = int64(len(s))
		req.Header.Set(contentType, mimeTypeFormPost)
		req.Body = body
	}

	// MSI requests are sent as GET and carry the metadata header.
	if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok {
		req.Method = http.MethodGet
		req.Header.Set(metadataHeader, "true")
	}

	// Requests to the IMDS endpoint go through the IMDS-specific retry logic,
	// bounded by MaxMSIRefreshAttempts; all others are sent once.
	var resp *http.Response
	if isIMDS(spt.inner.OauthConfig.TokenEndpoint) {
		resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts)
	} else {
		resp, err = spt.sender.Do(req)
	}
	if err != nil {
		return newTokenRefreshError(fmt.Sprintf("adal: Failed to execute the refresh request. Error = '%v'", err), nil)
	}

	defer resp.Body.Close()
	// The read error is inspected only after the status code, so that a
	// non-200 response is always reported as a TokenRefreshError.
	rb, err := ioutil.ReadAll(resp.Body)

	if resp.StatusCode != http.StatusOK {
		if err != nil {
			return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v", resp.StatusCode, err), resp)
		}
		return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb)), resp)
	}

	// for the following error cases don't return a TokenRefreshError.  the operation succeeded
	// but some transient failure happened during deserialization.  by returning a generic error
	// the retry logic will kick in (we don't retry on TokenRefreshError).

	if err != nil {
		return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err)
	}
	// a body that is empty or consists only of spaces carries no token
	if len(strings.Trim(string(rb), " ")) == 0 {
		return fmt.Errorf("adal: Empty service principal token received during refresh")
	}
	var token Token
	err = json.Unmarshal(rb, &token)
	if err != nil {
		return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb))
	}

	// store the new token before the callbacks are notified
	spt.inner.Token = token

	return spt.InvokeRefreshCallbacks(token)
}
877
878// retry logic specific to retrieving a token from the IMDS endpoint
879func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http.Response, err error) {
880	// copied from client.go due to circular dependency
881	retries := []int{
882		http.StatusRequestTimeout,      // 408
883		http.StatusTooManyRequests,     // 429
884		http.StatusInternalServerError, // 500
885		http.StatusBadGateway,          // 502
886		http.StatusServiceUnavailable,  // 503
887		http.StatusGatewayTimeout,      // 504
888	}
889	// extra retry status codes specific to IMDS
890	retries = append(retries,
891		http.StatusNotFound,
892		http.StatusGone,
893		// all remaining 5xx
894		http.StatusNotImplemented,
895		http.StatusHTTPVersionNotSupported,
896		http.StatusVariantAlsoNegotiates,
897		http.StatusInsufficientStorage,
898		http.StatusLoopDetected,
899		http.StatusNotExtended,
900		http.StatusNetworkAuthenticationRequired)
901
902	// see https://docs.microsoft.com/en-us/azure/active-directory/managed-service-identity/how-to-use-vm-token#retry-guidance
903
904	const maxDelay time.Duration = 60 * time.Second
905
906	attempt := 0
907	delay := time.Duration(0)
908
909	for attempt < maxAttempts {
910		resp, err = sender.Do(req)
911		// retry on temporary network errors, e.g. transient network failures.
912		// if we don't receive a response then assume we can't connect to the
913		// endpoint so we're likely not running on an Azure VM so don't retry.
914		if (err != nil && !isTemporaryNetworkError(err)) || resp == nil || resp.StatusCode == http.StatusOK || !containsInt(retries, resp.StatusCode) {
915			return
916		}
917
918		// perform exponential backoff with a cap.
919		// must increment attempt before calculating delay.
920		attempt++
921		// the base value of 2 is the "delta backoff" as specified in the guidance doc
922		delay += (time.Duration(math.Pow(2, float64(attempt))) * time.Second)
923		if delay > maxDelay {
924			delay = maxDelay
925		}
926
927		select {
928		case <-time.After(delay):
929			// intentionally left blank
930		case <-req.Context().Done():
931			err = req.Context().Err()
932			return
933		}
934	}
935	return
936}
937
// isTemporaryNetworkError reports whether err should be treated as a temporary
// network error.  An error that doesn't implement net.Error is always treated
// as temporary; otherwise the error's own Temporary() verdict is returned.
func isTemporaryNetworkError(err error) bool {
	netErr, isNetErr := err.(net.Error)
	if !isNetErr {
		return true
	}
	return netErr.Temporary()
}
946
// containsInt reports whether n occurs anywhere in ints.
// A nil or empty slice never contains n.
func containsInt(ints []int, n int) bool {
	for idx := 0; idx < len(ints); idx++ {
		if ints[idx] == n {
			return true
		}
	}
	return false
}
956
957// SetAutoRefresh enables or disables automatic refreshing of stale tokens.
958func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) {
959	spt.inner.AutoRefresh = autoRefresh
960}
961
962// SetRefreshWithin sets the interval within which if the token will expire, EnsureFresh will
963// refresh the token.
964func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) {
965	spt.inner.RefreshWithin = d
966	return
967}
968
969// SetSender sets the http.Client used when obtaining the Service Principal token. An
970// undecorated http.Client is used by default.
971func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s }
972
973// OAuthToken implements the OAuthTokenProvider interface.  It returns the current access token.
974func (spt *ServicePrincipalToken) OAuthToken() string {
975	spt.refreshLock.RLock()
976	defer spt.refreshLock.RUnlock()
977	return spt.inner.Token.OAuthToken()
978}
979
980// Token returns a copy of the current token.
981func (spt *ServicePrincipalToken) Token() Token {
982	spt.refreshLock.RLock()
983	defer spt.refreshLock.RUnlock()
984	return spt.inner.Token
985}
986