1package adal
2
3// Copyright 2017 Microsoft Corporation
4//
5//  Licensed under the Apache License, Version 2.0 (the "License");
6//  you may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at
8//
9//      http://www.apache.org/licenses/LICENSE-2.0
10//
11//  Unless required by applicable law or agreed to in writing, software
12//  distributed under the License is distributed on an "AS IS" BASIS,
13//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14//  See the License for the specific language governing permissions and
15//  limitations under the License.
16
import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha1"
	"crypto/x509"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/Azure/go-autorest/autorest/date"
	"github.com/dgrijalva/jwt-go"
)
38
// defaultRefresh is the refresh window installed by the constructors: EnsureFresh
// refreshes a token once it will expire within this duration.
const defaultRefresh = 5 * time.Minute

// OAuth 2.0 "grant_type" values sent to the token endpoint.
const (
	// OAuthGrantTypeDeviceCode is the "grant_type" identifier used in device flow
	OAuthGrantTypeDeviceCode = "device_code"

	// OAuthGrantTypeClientCredentials is the "grant_type" identifier used in credential flows
	OAuthGrantTypeClientCredentials = "client_credentials"

	// OAuthGrantTypeUserPass is the "grant_type" identifier used in username and password auth flows
	OAuthGrantTypeUserPass = "password"

	// OAuthGrantTypeRefreshToken is the "grant_type" identifier used in refresh token flows
	OAuthGrantTypeRefreshToken = "refresh_token"

	// OAuthGrantTypeAuthorizationCode is the "grant_type" identifier used in authorization code flows
	OAuthGrantTypeAuthorizationCode = "authorization_code"
)

// Managed Service Identity (MSI) settings.
const (
	// metadataHeader is the request header the MSI extension requires; it is
	// sent with the value "true".
	metadataHeader = "Metadata"

	// msiEndpoint is the well-known Instance Metadata Service endpoint that
	// issues MSI access tokens.
	msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
)
63
// OAuthTokenProvider is implemented by types that can supply an OAuth access
// token string.
type OAuthTokenProvider interface {
	OAuthToken() string
}

// TokenRefreshError is the error reported when a token refresh request fails.
// Response returns the HTTP response associated with the failure, which may be
// nil when no response was received.
type TokenRefreshError interface {
	error
	Response() *http.Response
}

// Refresher is implemented by tokens that can be renewed on demand: Refresh
// renews for the token's own resource, RefreshExchange renews for another
// resource, and EnsureFresh renews only when the token needs it.
type Refresher interface {
	Refresh() error
	RefreshExchange(resource string) error
	EnsureFresh() error
}

// RefresherWithContext is the context-aware counterpart of Refresher; every
// method takes a context for the underlying token request.
type RefresherWithContext interface {
	RefreshWithContext(ctx context.Context) error
	RefreshExchangeWithContext(ctx context.Context, resource string) error
	EnsureFreshWithContext(ctx context.Context) error
}
88
// TokenRefreshCallback is invoked with the new token after a successful
// refresh; a non-nil error from the callback is reported as a refresh failure.
type TokenRefreshCallback func(Token) error

// Token encapsulates the access token used to authorize Azure requests. The
// JSON tags match the fields decoded from the token endpoint's response body.
type Token struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`

	// Timing fields are kept as the raw strings received; Expires interprets
	// ExpiresOn as seconds since the Unix epoch.
	ExpiresIn string `json:"expires_in"`
	ExpiresOn string `json:"expires_on"`
	NotBefore string `json:"not_before"`

	Resource string `json:"resource"`
	Type     string `json:"token_type"`
}

// IsZero reports whether every field of the token is the empty string.
func (t Token) IsZero() bool {
	var zero Token
	return t == zero
}
110
111// Expires returns the time.Time when the Token expires.
112func (t Token) Expires() time.Time {
113	s, err := strconv.Atoi(t.ExpiresOn)
114	if err != nil {
115		s = -3600
116	}
117
118	expiration := date.NewUnixTimeFromSeconds(float64(s))
119
120	return time.Time(expiration).UTC()
121}
122
123// IsExpired returns true if the Token is expired, false otherwise.
124func (t Token) IsExpired() bool {
125	return t.WillExpireIn(0)
126}
127
128// WillExpireIn returns true if the Token will expire after the passed time.Duration interval
129// from now, false otherwise.
130func (t Token) WillExpireIn(d time.Duration) bool {
131	return !t.Expires().After(time.Now().Add(d))
132}
133
134//OAuthToken return the current access token
135func (t *Token) OAuthToken() string {
136	return t.AccessToken
137}
138
139// ServicePrincipalNoSecret represents a secret type that contains no secret
140// meaning it is not valid for fetching a fresh token. This is used by Manual
141type ServicePrincipalNoSecret struct {
142}
143
144// SetAuthenticationValues is a method of the interface ServicePrincipalSecret
145// It only returns an error for the ServicePrincipalNoSecret type
146func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
147	return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token")
148}
149
150// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form
151// that is submitted when acquiring an oAuth token.
152type ServicePrincipalSecret interface {
153	SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error
154}
155
156// ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization.
157type ServicePrincipalTokenSecret struct {
158	ClientSecret string
159}
160
161// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
162// It will populate the form submitted during oAuth Token Acquisition using the client_secret.
163func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
164	v.Set("client_secret", tokenSecret.ClientSecret)
165	return nil
166}
167
168// ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs.
169type ServicePrincipalCertificateSecret struct {
170	Certificate *x509.Certificate
171	PrivateKey  *rsa.PrivateKey
172}
173
174// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension.
175type ServicePrincipalMSISecret struct {
176}
177
178// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth.
179type ServicePrincipalUsernamePasswordSecret struct {
180	Username string
181	Password string
182}
183
184// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth.
185type ServicePrincipalAuthorizationCodeSecret struct {
186	ClientSecret      string
187	AuthorizationCode string
188	RedirectURI       string
189}
190
191// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
192func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
193	v.Set("code", secret.AuthorizationCode)
194	v.Set("client_secret", secret.ClientSecret)
195	v.Set("redirect_uri", secret.RedirectURI)
196	return nil
197}
198
199// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
200func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
201	v.Set("username", secret.Username)
202	v.Set("password", secret.Password)
203	return nil
204}
205
206// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
207func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
208	return nil
209}
210
211// SignJwt returns the JWT signed with the certificate's private key.
212func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) {
213	hasher := sha1.New()
214	_, err := hasher.Write(secret.Certificate.Raw)
215	if err != nil {
216		return "", err
217	}
218
219	thumbprint := base64.URLEncoding.EncodeToString(hasher.Sum(nil))
220
221	// The jti (JWT ID) claim provides a unique identifier for the JWT.
222	jti := make([]byte, 20)
223	_, err = rand.Read(jti)
224	if err != nil {
225		return "", err
226	}
227
228	token := jwt.New(jwt.SigningMethodRS256)
229	token.Header["x5t"] = thumbprint
230	token.Claims = jwt.MapClaims{
231		"aud": spt.oauthConfig.TokenEndpoint.String(),
232		"iss": spt.clientID,
233		"sub": spt.clientID,
234		"jti": base64.URLEncoding.EncodeToString(jti),
235		"nbf": time.Now().Unix(),
236		"exp": time.Now().Add(time.Hour * 24).Unix(),
237	}
238
239	signedString, err := token.SignedString(secret.PrivateKey)
240	return signedString, err
241}
242
243// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
244// It will populate the form submitted during oAuth Token Acquisition using a JWT signed with a certificate.
245func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
246	jwt, err := secret.SignJwt(spt)
247	if err != nil {
248		return err
249	}
250
251	v.Set("client_assertion", jwt)
252	v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
253	return nil
254}
255
// ServicePrincipalToken encapsulates a Token created for a Service Principal.
// Besides the current token it holds what is needed to refresh it. Refreshes
// take refreshLock for writing; OAuthToken and Token read the token under the
// read lock.
type ServicePrincipalToken struct {
	// token is the most recently acquired (or manually supplied) token.
	token         Token
	// secret fills in authentication values for non-refresh-token requests and
	// selects the grant type (see getGrantType).
	secret        ServicePrincipalSecret
	// oauthConfig.TokenEndpoint is the URL refresh requests are sent to.
	oauthConfig   OAuthConfig
	// clientID is sent as "client_id"; for certificate auth it is also the JWT
	// issuer and subject. It is empty for a system-assigned MSI identity.
	clientID      string
	// resource is the resource requested by Refresh and EnsureFresh.
	resource      string
	// autoRefresh gates whether EnsureFresh refreshes at all.
	autoRefresh   bool
	// refreshLock serializes refreshes (write lock) against token reads (read lock).
	refreshLock   *sync.RWMutex
	// refreshWithin is the window before expiry in which EnsureFresh refreshes.
	refreshWithin time.Duration
	// sender performs the token HTTP requests.
	sender        Sender

	// refreshCallbacks run, in order, after each successful refresh.
	refreshCallbacks []TokenRefreshCallback
}
270
271func validateOAuthConfig(oac OAuthConfig) error {
272	if oac.IsZero() {
273		return fmt.Errorf("parameter 'oauthConfig' cannot be zero-initialized")
274	}
275	return nil
276}
277
278// NewServicePrincipalTokenWithSecret create a ServicePrincipalToken using the supplied ServicePrincipalSecret implementation.
279func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
280	if err := validateOAuthConfig(oauthConfig); err != nil {
281		return nil, err
282	}
283	if err := validateStringParam(id, "id"); err != nil {
284		return nil, err
285	}
286	if err := validateStringParam(resource, "resource"); err != nil {
287		return nil, err
288	}
289	if secret == nil {
290		return nil, fmt.Errorf("parameter 'secret' cannot be nil")
291	}
292	spt := &ServicePrincipalToken{
293		oauthConfig:      oauthConfig,
294		secret:           secret,
295		clientID:         id,
296		resource:         resource,
297		autoRefresh:      true,
298		refreshLock:      &sync.RWMutex{},
299		refreshWithin:    defaultRefresh,
300		sender:           &http.Client{},
301		refreshCallbacks: callbacks,
302	}
303	return spt, nil
304}
305
306// NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken using the supplied token
307func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
308	if err := validateOAuthConfig(oauthConfig); err != nil {
309		return nil, err
310	}
311	if err := validateStringParam(clientID, "clientID"); err != nil {
312		return nil, err
313	}
314	if err := validateStringParam(resource, "resource"); err != nil {
315		return nil, err
316	}
317	if token.IsZero() {
318		return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
319	}
320	spt, err := NewServicePrincipalTokenWithSecret(
321		oauthConfig,
322		clientID,
323		resource,
324		&ServicePrincipalNoSecret{},
325		callbacks...)
326	if err != nil {
327		return nil, err
328	}
329
330	spt.token = token
331
332	return spt, nil
333}
334
335// NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal
336// credentials scoped to the named resource.
337func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
338	if err := validateOAuthConfig(oauthConfig); err != nil {
339		return nil, err
340	}
341	if err := validateStringParam(clientID, "clientID"); err != nil {
342		return nil, err
343	}
344	if err := validateStringParam(secret, "secret"); err != nil {
345		return nil, err
346	}
347	if err := validateStringParam(resource, "resource"); err != nil {
348		return nil, err
349	}
350	return NewServicePrincipalTokenWithSecret(
351		oauthConfig,
352		clientID,
353		resource,
354		&ServicePrincipalTokenSecret{
355			ClientSecret: secret,
356		},
357		callbacks...,
358	)
359}
360
361// NewServicePrincipalTokenFromCertificate creates a ServicePrincipalToken from the supplied pkcs12 bytes.
362func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
363	if err := validateOAuthConfig(oauthConfig); err != nil {
364		return nil, err
365	}
366	if err := validateStringParam(clientID, "clientID"); err != nil {
367		return nil, err
368	}
369	if err := validateStringParam(resource, "resource"); err != nil {
370		return nil, err
371	}
372	if certificate == nil {
373		return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
374	}
375	if privateKey == nil {
376		return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
377	}
378	return NewServicePrincipalTokenWithSecret(
379		oauthConfig,
380		clientID,
381		resource,
382		&ServicePrincipalCertificateSecret{
383			PrivateKey:  privateKey,
384			Certificate: certificate,
385		},
386		callbacks...,
387	)
388}
389
390// NewServicePrincipalTokenFromUsernamePassword creates a ServicePrincipalToken from the username and password.
391func NewServicePrincipalTokenFromUsernamePassword(oauthConfig OAuthConfig, clientID string, username string, password string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
392	if err := validateOAuthConfig(oauthConfig); err != nil {
393		return nil, err
394	}
395	if err := validateStringParam(clientID, "clientID"); err != nil {
396		return nil, err
397	}
398	if err := validateStringParam(username, "username"); err != nil {
399		return nil, err
400	}
401	if err := validateStringParam(password, "password"); err != nil {
402		return nil, err
403	}
404	if err := validateStringParam(resource, "resource"); err != nil {
405		return nil, err
406	}
407	return NewServicePrincipalTokenWithSecret(
408		oauthConfig,
409		clientID,
410		resource,
411		&ServicePrincipalUsernamePasswordSecret{
412			Username: username,
413			Password: password,
414		},
415		callbacks...,
416	)
417}
418
419// NewServicePrincipalTokenFromAuthorizationCode creates a ServicePrincipalToken from the
420func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clientID string, clientSecret string, authorizationCode string, redirectURI string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
421
422	if err := validateOAuthConfig(oauthConfig); err != nil {
423		return nil, err
424	}
425	if err := validateStringParam(clientID, "clientID"); err != nil {
426		return nil, err
427	}
428	if err := validateStringParam(clientSecret, "clientSecret"); err != nil {
429		return nil, err
430	}
431	if err := validateStringParam(authorizationCode, "authorizationCode"); err != nil {
432		return nil, err
433	}
434	if err := validateStringParam(redirectURI, "redirectURI"); err != nil {
435		return nil, err
436	}
437	if err := validateStringParam(resource, "resource"); err != nil {
438		return nil, err
439	}
440
441	return NewServicePrincipalTokenWithSecret(
442		oauthConfig,
443		clientID,
444		resource,
445		&ServicePrincipalAuthorizationCodeSecret{
446			ClientSecret:      clientSecret,
447			AuthorizationCode: authorizationCode,
448			RedirectURI:       redirectURI,
449		},
450		callbacks...,
451	)
452}
453
454// GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines.
455func GetMSIVMEndpoint() (string, error) {
456	return msiEndpoint, nil
457}
458
459// NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension.
460// It will use the system assigned identity when creating the token.
461func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
462	return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, callbacks...)
463}
464
465// NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension.
466// It will use the specified user assigned identity when creating the token.
467func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
468	return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &userAssignedID, callbacks...)
469}
470
471func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
472	if err := validateStringParam(msiEndpoint, "msiEndpoint"); err != nil {
473		return nil, err
474	}
475	if err := validateStringParam(resource, "resource"); err != nil {
476		return nil, err
477	}
478	if userAssignedID != nil {
479		if err := validateStringParam(*userAssignedID, "userAssignedID"); err != nil {
480			return nil, err
481		}
482	}
483	// We set the oauth config token endpoint to be MSI's endpoint
484	msiEndpointURL, err := url.Parse(msiEndpoint)
485	if err != nil {
486		return nil, err
487	}
488
489	v := url.Values{}
490	v.Set("resource", resource)
491	v.Set("api-version", "2018-02-01")
492	if userAssignedID != nil {
493		v.Set("client_id", *userAssignedID)
494	}
495	msiEndpointURL.RawQuery = v.Encode()
496
497	spt := &ServicePrincipalToken{
498		oauthConfig: OAuthConfig{
499			TokenEndpoint: *msiEndpointURL,
500		},
501		secret:           &ServicePrincipalMSISecret{},
502		resource:         resource,
503		autoRefresh:      true,
504		refreshLock:      &sync.RWMutex{},
505		refreshWithin:    defaultRefresh,
506		sender:           &http.Client{},
507		refreshCallbacks: callbacks,
508	}
509
510	if userAssignedID != nil {
511		spt.clientID = *userAssignedID
512	}
513
514	return spt, nil
515}
516
517// internal type that implements TokenRefreshError
518type tokenRefreshError struct {
519	message string
520	resp    *http.Response
521}
522
523// Error implements the error interface which is part of the TokenRefreshError interface.
524func (tre tokenRefreshError) Error() string {
525	return tre.message
526}
527
528// Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation.
529func (tre tokenRefreshError) Response() *http.Response {
530	return tre.resp
531}
532
533func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError {
534	return tokenRefreshError{message: message, resp: resp}
535}
536
537// EnsureFresh will refresh the token if it will expire within the refresh window (as set by
538// RefreshWithin) and autoRefresh flag is on.  This method is safe for concurrent use.
539func (spt *ServicePrincipalToken) EnsureFresh() error {
540	return spt.EnsureFreshWithContext(context.Background())
541}
542
543// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by
544// RefreshWithin) and autoRefresh flag is on.  This method is safe for concurrent use.
545func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
546	if spt.autoRefresh && spt.token.WillExpireIn(spt.refreshWithin) {
547		// take the write lock then check to see if the token was already refreshed
548		spt.refreshLock.Lock()
549		defer spt.refreshLock.Unlock()
550		if spt.token.WillExpireIn(spt.refreshWithin) {
551			return spt.refreshInternal(ctx, spt.resource)
552		}
553	}
554	return nil
555}
556
557// InvokeRefreshCallbacks calls any TokenRefreshCallbacks that were added to the SPT during initialization
558func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
559	if spt.refreshCallbacks != nil {
560		for _, callback := range spt.refreshCallbacks {
561			err := callback(spt.token)
562			if err != nil {
563				return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err)
564			}
565		}
566	}
567	return nil
568}
569
570// Refresh obtains a fresh token for the Service Principal.
571// This method is not safe for concurrent use and should be syncrhonized.
572func (spt *ServicePrincipalToken) Refresh() error {
573	return spt.RefreshWithContext(context.Background())
574}
575
576// RefreshWithContext obtains a fresh token for the Service Principal.
577// This method is not safe for concurrent use and should be syncrhonized.
578func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
579	spt.refreshLock.Lock()
580	defer spt.refreshLock.Unlock()
581	return spt.refreshInternal(ctx, spt.resource)
582}
583
584// RefreshExchange refreshes the token, but for a different resource.
585// This method is not safe for concurrent use and should be syncrhonized.
586func (spt *ServicePrincipalToken) RefreshExchange(resource string) error {
587	return spt.RefreshExchangeWithContext(context.Background(), resource)
588}
589
590// RefreshExchangeWithContext refreshes the token, but for a different resource.
591// This method is not safe for concurrent use and should be syncrhonized.
592func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
593	spt.refreshLock.Lock()
594	defer spt.refreshLock.Unlock()
595	return spt.refreshInternal(ctx, resource)
596}
597
598func (spt *ServicePrincipalToken) getGrantType() string {
599	switch spt.secret.(type) {
600	case *ServicePrincipalUsernamePasswordSecret:
601		return OAuthGrantTypeUserPass
602	case *ServicePrincipalAuthorizationCodeSecret:
603		return OAuthGrantTypeAuthorizationCode
604	default:
605		return OAuthGrantTypeClientCredentials
606	}
607}
608
609func isIMDS(u url.URL) bool {
610	imds, err := url.Parse(msiEndpoint)
611	if err != nil {
612		return false
613	}
614	return u.Host == imds.Host && u.Path == imds.Path
615}
616
617func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
618	req, err := http.NewRequest(http.MethodPost, spt.oauthConfig.TokenEndpoint.String(), nil)
619	if err != nil {
620		return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
621	}
622	req = req.WithContext(ctx)
623	if !isIMDS(spt.oauthConfig.TokenEndpoint) {
624		v := url.Values{}
625		v.Set("client_id", spt.clientID)
626		v.Set("resource", resource)
627
628		if spt.token.RefreshToken != "" {
629			v.Set("grant_type", OAuthGrantTypeRefreshToken)
630			v.Set("refresh_token", spt.token.RefreshToken)
631		} else {
632			v.Set("grant_type", spt.getGrantType())
633			err := spt.secret.SetAuthenticationValues(spt, &v)
634			if err != nil {
635				return err
636			}
637		}
638
639		s := v.Encode()
640		body := ioutil.NopCloser(strings.NewReader(s))
641		req.ContentLength = int64(len(s))
642		req.Header.Set(contentType, mimeTypeFormPost)
643		req.Body = body
644	}
645
646	if _, ok := spt.secret.(*ServicePrincipalMSISecret); ok {
647		req.Method = http.MethodGet
648		req.Header.Set(metadataHeader, "true")
649	}
650
651	var resp *http.Response
652	if isIMDS(spt.oauthConfig.TokenEndpoint) {
653		resp, err = retry(spt.sender, req)
654	} else {
655		resp, err = spt.sender.Do(req)
656	}
657	if err != nil {
658		return newTokenRefreshError(fmt.Sprintf("adal: Failed to execute the refresh request. Error = '%v'", err), nil)
659	}
660
661	defer resp.Body.Close()
662	rb, err := ioutil.ReadAll(resp.Body)
663
664	if resp.StatusCode != http.StatusOK {
665		if err != nil {
666			return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v", resp.StatusCode, err), resp)
667		}
668		return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb)), resp)
669	}
670
671	// for the following error cases don't return a TokenRefreshError.  the operation succeeded
672	// but some transient failure happened during deserialization.  by returning a generic error
673	// the retry logic will kick in (we don't retry on TokenRefreshError).
674
675	if err != nil {
676		return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err)
677	}
678	if len(strings.Trim(string(rb), " ")) == 0 {
679		return fmt.Errorf("adal: Empty service principal token received during refresh")
680	}
681	var token Token
682	err = json.Unmarshal(rb, &token)
683	if err != nil {
684		return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb))
685	}
686
687	spt.token = token
688
689	return spt.InvokeRefreshCallbacks(token)
690}
691
692func retry(sender Sender, req *http.Request) (resp *http.Response, err error) {
693	retries := []int{
694		http.StatusRequestTimeout,      // 408
695		http.StatusTooManyRequests,     // 429
696		http.StatusInternalServerError, // 500
697		http.StatusBadGateway,          // 502
698		http.StatusServiceUnavailable,  // 503
699		http.StatusGatewayTimeout,      // 504
700	}
701	// Extra retry status codes requered
702	retries = append(retries, http.StatusNotFound,
703		// all remaining 5xx
704		http.StatusNotImplemented,
705		http.StatusHTTPVersionNotSupported,
706		http.StatusVariantAlsoNegotiates,
707		http.StatusInsufficientStorage,
708		http.StatusLoopDetected,
709		http.StatusNotExtended,
710		http.StatusNetworkAuthenticationRequired)
711
712	attempt := 0
713	maxAttempts := 5
714
715	for attempt < maxAttempts {
716		resp, err = sender.Do(req)
717		// retry on temporary network errors, e.g. transient network failures.
718		if (err != nil && !isTemporaryNetworkError(err)) || resp.StatusCode == http.StatusOK || !containsInt(retries, resp.StatusCode) {
719			return
720		}
721
722		if !delay(resp, req.Context().Done()) {
723			select {
724			case <-time.After(time.Second):
725				attempt++
726			case <-req.Context().Done():
727				err = req.Context().Err()
728				return
729			}
730		}
731	}
732	return
733}
734
// isTemporaryNetworkError reports whether err is itself a net.Error whose
// Temporary method returns true. Wrapped errors are not unwrapped.
func isTemporaryNetworkError(err error) bool {
	netErr, ok := err.(net.Error)
	return ok && netErr.Temporary()
}
741
// containsInt reports whether n occurs anywhere in ints.
func containsInt(ints []int, n int) bool {
	for idx := range ints {
		if ints[idx] == n {
			return true
		}
	}
	return false
}
750
// delay honours a Retry-After header on a 429 response. If resp is a 429
// whose Retry-After is a positive integer number of seconds, it waits that
// long and returns true. It returns false without waiting when no such header
// applies, and returns false early if cancel fires during the wait.
func delay(resp *http.Response, cancel <-chan struct{}) bool {
	if resp == nil || resp.StatusCode != http.StatusTooManyRequests {
		return false
	}
	seconds, _ := strconv.Atoi(resp.Header.Get("Retry-After")) // unparsable => 0 => no wait
	if seconds <= 0 {
		return false
	}
	select {
	case <-cancel:
		return false
	case <-time.After(time.Duration(seconds) * time.Second):
		return true
	}
}
766
767// SetAutoRefresh enables or disables automatic refreshing of stale tokens.
768func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) {
769	spt.autoRefresh = autoRefresh
770}
771
772// SetRefreshWithin sets the interval within which if the token will expire, EnsureFresh will
773// refresh the token.
774func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) {
775	spt.refreshWithin = d
776	return
777}
778
779// SetSender sets the http.Client used when obtaining the Service Principal token. An
780// undecorated http.Client is used by default.
781func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s }
782
783// OAuthToken implements the OAuthTokenProvider interface.  It returns the current access token.
784func (spt *ServicePrincipalToken) OAuthToken() string {
785	spt.refreshLock.RLock()
786	defer spt.refreshLock.RUnlock()
787	return spt.token.OAuthToken()
788}
789
790// Token returns a copy of the current token.
791func (spt *ServicePrincipalToken) Token() Token {
792	spt.refreshLock.RLock()
793	defer spt.refreshLock.RUnlock()
794	return spt.token
795}
796