1package adal
2
3// Copyright 2017 Microsoft Corporation
4//
5//  Licensed under the Apache License, Version 2.0 (the "License");
6//  you may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at
8//
9//      http://www.apache.org/licenses/LICENSE-2.0
10//
11//  Unless required by applicable law or agreed to in writing, software
12//  distributed under the License is distributed on an "AS IS" BASIS,
13//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14//  See the License for the specific language governing permissions and
15//  limitations under the License.
16
17import (
18	"context"
19	"crypto/rand"
20	"crypto/rsa"
21	"crypto/x509"
22	"crypto/x509/pkix"
23	"encoding/json"
24	"fmt"
25	"io/ioutil"
26	"math/big"
27	"net/http"
28	"net/url"
29	"os"
30	"reflect"
31	"strconv"
32	"strings"
33	"sync"
34	"testing"
35	"time"
36
37	"github.com/Azure/go-autorest/autorest/date"
38	"github.com/Azure/go-autorest/autorest/mocks"
39	jwt "github.com/dgrijalva/jwt-go"
40)
41
// defaultFormData is the URL-encoded body expected when a client-secret
// token built by newServicePrincipalToken is refreshed.
const defaultFormData = "client_id=id&client_secret=secret&grant_type=client_credentials&resource=resource"

// defaultManualFormData is the URL-encoded body expected when a manual token
// built by newServicePrincipalTokenManual is refreshed with its refresh token.
const defaultManualFormData = "client_id=id&grant_type=refresh_token&refresh_token=refreshtoken&resource=resource"
46
47func TestTokenExpires(t *testing.T) {
48	tt := time.Now().Add(5 * time.Second)
49	tk := newTokenExpiresAt(tt)
50
51	if tk.Expires().Equal(tt) {
52		t.Fatalf("adal: Token#Expires miscalculated expiration time -- received %v, expected %v", tk.Expires(), tt)
53	}
54}
55
56func TestTokenIsExpired(t *testing.T) {
57	tk := newTokenExpiresAt(time.Now().Add(-5 * time.Second))
58
59	if !tk.IsExpired() {
60		t.Fatalf("adal: Token#IsExpired failed to mark a stale token as expired -- now %v, token expires at %v",
61			time.Now().UTC(), tk.Expires())
62	}
63}
64
65func TestTokenIsExpiredUninitialized(t *testing.T) {
66	tk := &Token{}
67
68	if !tk.IsExpired() {
69		t.Fatalf("adal: An uninitialized Token failed to mark itself as expired (expiration time %v)", tk.Expires())
70	}
71}
72
73func TestTokenIsNoExpired(t *testing.T) {
74	tk := newTokenExpiresAt(time.Now().Add(1000 * time.Second))
75
76	if tk.IsExpired() {
77		t.Fatalf("adal: Token marked a fresh token as expired -- now %v, token expires at %v", time.Now().UTC(), tk.Expires())
78	}
79}
80
81func TestTokenWillExpireIn(t *testing.T) {
82	d := 5 * time.Second
83	tk := newTokenExpiresIn(d)
84
85	if !tk.WillExpireIn(d) {
86		t.Fatal("adal: Token#WillExpireIn mismeasured expiration time")
87	}
88}
89
90func TestServicePrincipalTokenSetAutoRefresh(t *testing.T) {
91	spt := newServicePrincipalToken()
92
93	if !spt.inner.AutoRefresh {
94		t.Fatal("adal: ServicePrincipalToken did not default to automatic token refreshing")
95	}
96
97	spt.SetAutoRefresh(false)
98	if spt.inner.AutoRefresh {
99		t.Fatal("adal: ServicePrincipalToken#SetAutoRefresh did not disable automatic token refreshing")
100	}
101}
102
103func TestServicePrincipalTokenSetCustomRefreshFunc(t *testing.T) {
104	spt := newServicePrincipalToken()
105
106	var refreshFunc TokenRefresh = func(context context.Context, resource string) (*Token, error) {
107		return nil, nil
108	}
109
110	if spt.customRefreshFunc != nil {
111		t.Fatalf("adal: ServicePrincipalToken#SetCustomRefreshFunc had a default custom refresh func when it shouldn't")
112	}
113
114	spt.SetCustomRefreshFunc(refreshFunc)
115
116	if spt.customRefreshFunc == nil {
117		t.Fatalf("adal: ServicePrincipalToken#SetCustomRefreshFunc didn't have a refresh func")
118	}
119}
120
121func TestServicePrincipalTokenSetRefreshWithin(t *testing.T) {
122	spt := newServicePrincipalToken()
123
124	if spt.inner.RefreshWithin != defaultRefresh {
125		t.Fatal("adal: ServicePrincipalToken did not correctly set the default refresh interval")
126	}
127
128	spt.SetRefreshWithin(2 * defaultRefresh)
129	if spt.inner.RefreshWithin != 2*defaultRefresh {
130		t.Fatal("adal: ServicePrincipalToken#SetRefreshWithin did not set the refresh interval")
131	}
132}
133
134func TestServicePrincipalTokenSetSender(t *testing.T) {
135	spt := newServicePrincipalToken()
136
137	c := &http.Client{}
138	spt.SetSender(c)
139	if !reflect.DeepEqual(c, spt.sender) {
140		t.Fatal("adal: ServicePrincipalToken#SetSender did not set the sender")
141	}
142}
143
144func TestServicePrincipalTokenRefreshUsesCustomRefreshFunc(t *testing.T) {
145	spt := newServicePrincipalToken()
146
147	called := false
148	var refreshFunc TokenRefresh = func(context context.Context, resource string) (*Token, error) {
149		called = true
150		return &Token{}, nil
151	}
152	spt.SetCustomRefreshFunc(refreshFunc)
153	if called {
154		t.Fatalf("adal: ServicePrincipalToken#refreshInternal called the refresh function prior to refreshing")
155	}
156
157	spt.refreshInternal(context.Background(), "https://example.com")
158
159	if !called {
160		t.Fatalf("adal: ServicePrincipalToken#refreshInternal didn't call the refresh function")
161	}
162}
163
164func TestServicePrincipalTokenRefreshUsesPOST(t *testing.T) {
165	spt := newServicePrincipalToken()
166
167	body := mocks.NewBody(newTokenJSON("12345", "test"))
168	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
169
170	c := mocks.NewSender()
171	s := DecorateSender(c,
172		(func() SendDecorator {
173			return func(s Sender) Sender {
174				return SenderFunc(func(r *http.Request) (*http.Response, error) {
175					if r.Method != "POST" {
176						t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "POST", r.Method)
177					}
178					return resp, nil
179				})
180			}
181		})())
182	spt.SetSender(s)
183	err := spt.Refresh()
184	if err != nil {
185		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
186	}
187
188	if body.IsOpen() {
189		t.Fatalf("the response was not closed!")
190	}
191}
192
193func TestServicePrincipalTokenFromMSIRefreshUsesGET(t *testing.T) {
194	resource := "https://resource"
195	cb := func(token Token) error { return nil }
196
197	spt, err := NewServicePrincipalTokenFromMSI("http://msiendpoint/", resource, cb)
198	if err != nil {
199		t.Fatalf("Failed to get MSI SPT: %v", err)
200	}
201
202	body := mocks.NewBody(newTokenJSON("12345", "test"))
203	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
204
205	c := mocks.NewSender()
206	s := DecorateSender(c,
207		(func() SendDecorator {
208			return func(s Sender) Sender {
209				return SenderFunc(func(r *http.Request) (*http.Response, error) {
210					if r.Method != "GET" {
211						t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "GET", r.Method)
212					}
213					if h := r.Header.Get("Metadata"); h != "true" {
214						t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set Metadata header for MSI")
215					}
216					return resp, nil
217				})
218			}
219		})())
220	spt.SetSender(s)
221	err = spt.Refresh()
222	if err != nil {
223		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
224	}
225
226	if body.IsOpen() {
227		t.Fatalf("the response was not closed!")
228	}
229}
230
231func TestServicePrincipalTokenFromMSIRefreshCancel(t *testing.T) {
232	ctx, cancel := context.WithCancel(context.Background())
233	endpoint, _ := GetMSIVMEndpoint()
234
235	spt, err := NewServicePrincipalTokenFromMSI(endpoint, "https://resource")
236	if err != nil {
237		t.Fatalf("Failed to get MSI SPT: %v", err)
238	}
239
240	c := mocks.NewSender()
241	c.AppendAndRepeatResponse(mocks.NewResponseWithStatus("Internal server error", http.StatusInternalServerError), 5)
242
243	var wg sync.WaitGroup
244	wg.Add(1)
245	start := time.Now()
246	end := time.Now()
247
248	go func() {
249		spt.SetSender(c)
250		err = spt.RefreshWithContext(ctx)
251		end = time.Now()
252		wg.Done()
253	}()
254
255	cancel()
256	wg.Wait()
257	time.Sleep(5 * time.Millisecond)
258
259	if end.Sub(start) >= time.Second {
260		t.Fatalf("TestServicePrincipalTokenFromMSIRefreshCancel failed to cancel")
261	}
262}
263
264func TestServicePrincipalTokenRefreshSetsMimeType(t *testing.T) {
265	spt := newServicePrincipalToken()
266
267	body := mocks.NewBody(newTokenJSON("12345", "test"))
268	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
269
270	c := mocks.NewSender()
271	s := DecorateSender(c,
272		(func() SendDecorator {
273			return func(s Sender) Sender {
274				return SenderFunc(func(r *http.Request) (*http.Response, error) {
275					if r.Header.Get(http.CanonicalHeaderKey("Content-Type")) != "application/x-www-form-urlencoded" {
276						t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set Content-Type -- expected %v, received %v",
277							"application/x-form-urlencoded",
278							r.Header.Get(http.CanonicalHeaderKey("Content-Type")))
279					}
280					return resp, nil
281				})
282			}
283		})())
284	spt.SetSender(s)
285	err := spt.Refresh()
286	if err != nil {
287		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
288	}
289}
290
291func TestServicePrincipalTokenRefreshSetsURL(t *testing.T) {
292	spt := newServicePrincipalToken()
293
294	body := mocks.NewBody(newTokenJSON("12345", "test"))
295	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
296
297	c := mocks.NewSender()
298	s := DecorateSender(c,
299		(func() SendDecorator {
300			return func(s Sender) Sender {
301				return SenderFunc(func(r *http.Request) (*http.Response, error) {
302					if r.URL.String() != TestOAuthConfig.TokenEndpoint.String() {
303						t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the URL -- expected %v, received %v",
304							TestOAuthConfig.TokenEndpoint, r.URL)
305					}
306					return resp, nil
307				})
308			}
309		})())
310	spt.SetSender(s)
311	err := spt.Refresh()
312	if err != nil {
313		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
314	}
315}
316
317func testServicePrincipalTokenRefreshSetsBody(t *testing.T, spt *ServicePrincipalToken, f func(*testing.T, []byte)) {
318	body := mocks.NewBody(newTokenJSON("12345", "test"))
319	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
320
321	c := mocks.NewSender()
322	s := DecorateSender(c,
323		(func() SendDecorator {
324			return func(s Sender) Sender {
325				return SenderFunc(func(r *http.Request) (*http.Response, error) {
326					b, err := ioutil.ReadAll(r.Body)
327					if err != nil {
328						t.Fatalf("adal: Failed to read body of Service Principal token request (%v)", err)
329					}
330					f(t, b)
331					return resp, nil
332				})
333			}
334		})())
335	spt.SetSender(s)
336	err := spt.Refresh()
337	if err != nil {
338		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
339	}
340}
341
342func TestServicePrincipalTokenManualRefreshSetsBody(t *testing.T) {
343	sptManual := newServicePrincipalTokenManual()
344	testServicePrincipalTokenRefreshSetsBody(t, sptManual, func(t *testing.T, b []byte) {
345		if string(b) != defaultManualFormData {
346			t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the HTTP Request Body -- expected %v, received %v",
347				defaultManualFormData, string(b))
348		}
349	})
350}
351
352func TestServicePrincipalTokenCertficateRefreshSetsBody(t *testing.T) {
353	sptCert := newServicePrincipalTokenCertificate(t)
354	testServicePrincipalTokenRefreshSetsBody(t, sptCert, func(t *testing.T, b []byte) {
355		body := string(b)
356
357		values, _ := url.ParseQuery(body)
358		if values["client_assertion_type"][0] != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" ||
359			values["client_id"][0] != "id" ||
360			values["grant_type"][0] != "client_credentials" ||
361			values["resource"][0] != "resource" {
362			t.Fatalf("adal: ServicePrincipalTokenCertificate#Refresh did not correctly set the HTTP Request Body.")
363		}
364
365		tok, _ := jwt.Parse(values["client_assertion"][0], nil)
366		if tok == nil {
367			t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to be a JWT")
368		}
369		if _, ok := tok.Header["x5t"]; !ok {
370			t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to have an x5t header")
371		}
372		if _, ok := tok.Header["x5c"]; !ok {
373			t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to have an x5c header")
374		}
375	})
376}
377
378func TestServicePrincipalTokenUsernamePasswordRefreshSetsBody(t *testing.T) {
379	spt := newServicePrincipalTokenUsernamePassword(t)
380	testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) {
381		body := string(b)
382
383		values, _ := url.ParseQuery(body)
384		if values["client_id"][0] != "id" ||
385			values["grant_type"][0] != "password" ||
386			values["username"][0] != "username" ||
387			values["password"][0] != "password" ||
388			values["resource"][0] != "resource" {
389			t.Fatalf("adal: ServicePrincipalTokenUsernamePassword#Refresh did not correctly set the HTTP Request Body.")
390		}
391	})
392}
393
394func TestServicePrincipalTokenAuthorizationCodeRefreshSetsBody(t *testing.T) {
395	spt := newServicePrincipalTokenAuthorizationCode(t)
396	testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) {
397		body := string(b)
398
399		values, _ := url.ParseQuery(body)
400		if values["client_id"][0] != "id" ||
401			values["grant_type"][0] != OAuthGrantTypeAuthorizationCode ||
402			values["code"][0] != "code" ||
403			values["client_secret"][0] != "clientSecret" ||
404			values["redirect_uri"][0] != "http://redirectUri/getToken" ||
405			values["resource"][0] != "resource" {
406			t.Fatalf("adal: ServicePrincipalTokenAuthorizationCode#Refresh did not correctly set the HTTP Request Body.")
407		}
408	})
409	testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) {
410		body := string(b)
411
412		values, _ := url.ParseQuery(body)
413		if values["client_id"][0] != "id" ||
414			values["grant_type"][0] != OAuthGrantTypeRefreshToken ||
415			values["code"][0] != "code" ||
416			values["client_secret"][0] != "clientSecret" ||
417			values["redirect_uri"][0] != "http://redirectUri/getToken" ||
418			values["resource"][0] != "resource" {
419			t.Fatalf("adal: ServicePrincipalTokenAuthorizationCode#Refresh did not correctly set the HTTP Request Body.")
420		}
421	})
422}
423
424func TestServicePrincipalTokenSecretRefreshSetsBody(t *testing.T) {
425	spt := newServicePrincipalToken()
426	testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) {
427		if string(b) != defaultFormData {
428			t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the HTTP Request Body -- expected %v, received %v",
429				defaultFormData, string(b))
430		}
431
432	})
433}
434
435func TestServicePrincipalTokenRefreshClosesRequestBody(t *testing.T) {
436	spt := newServicePrincipalToken()
437
438	body := mocks.NewBody(newTokenJSON("12345", "test"))
439	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
440
441	c := mocks.NewSender()
442	s := DecorateSender(c,
443		(func() SendDecorator {
444			return func(s Sender) Sender {
445				return SenderFunc(func(r *http.Request) (*http.Response, error) {
446					return resp, nil
447				})
448			}
449		})())
450	spt.SetSender(s)
451	err := spt.Refresh()
452	if err != nil {
453		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
454	}
455	if resp.Body.(*mocks.Body).IsOpen() {
456		t.Fatal("adal: ServicePrincipalToken#Refresh failed to close the HTTP Response Body")
457	}
458}
459
460func TestServicePrincipalTokenRefreshRejectsResponsesWithStatusNotOK(t *testing.T) {
461	spt := newServicePrincipalToken()
462
463	body := mocks.NewBody(newTokenJSON("12345", "test"))
464	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusUnauthorized, "Unauthorized")
465
466	c := mocks.NewSender()
467	s := DecorateSender(c,
468		(func() SendDecorator {
469			return func(s Sender) Sender {
470				return SenderFunc(func(r *http.Request) (*http.Response, error) {
471					return resp, nil
472				})
473			}
474		})())
475	spt.SetSender(s)
476	err := spt.Refresh()
477	if err == nil {
478		t.Fatalf("adal: ServicePrincipalToken#Refresh should reject a response with status != %d", http.StatusOK)
479	}
480}
481
482func TestServicePrincipalTokenRefreshRejectsEmptyBody(t *testing.T) {
483	spt := newServicePrincipalToken()
484
485	c := mocks.NewSender()
486	s := DecorateSender(c,
487		(func() SendDecorator {
488			return func(s Sender) Sender {
489				return SenderFunc(func(r *http.Request) (*http.Response, error) {
490					return mocks.NewResponse(), nil
491				})
492			}
493		})())
494	spt.SetSender(s)
495	err := spt.Refresh()
496	if err == nil {
497		t.Fatal("adal: ServicePrincipalToken#Refresh should reject an empty token")
498	}
499}
500
501func TestServicePrincipalTokenRefreshPropagatesErrors(t *testing.T) {
502	spt := newServicePrincipalToken()
503
504	c := mocks.NewSender()
505	c.SetError(fmt.Errorf("Faux Error"))
506	spt.SetSender(c)
507
508	err := spt.Refresh()
509	if err == nil {
510		t.Fatal("adal: Failed to propagate the request error")
511	}
512}
513
514func TestServicePrincipalTokenRefreshReturnsErrorIfNotOk(t *testing.T) {
515	spt := newServicePrincipalToken()
516
517	c := mocks.NewSender()
518	c.AppendResponse(mocks.NewResponseWithStatus("401 NotAuthorized", http.StatusUnauthorized))
519	spt.SetSender(c)
520
521	err := spt.Refresh()
522	if err == nil {
523		t.Fatalf("adal: Failed to return an when receiving a status code other than HTTP %d", http.StatusOK)
524	}
525}
526
527func TestServicePrincipalTokenRefreshUnmarshals(t *testing.T) {
528	spt := newServicePrincipalToken()
529
530	expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds()))
531	j := newTokenJSON(expiresOn, "resource")
532	resp := mocks.NewResponseWithContent(j)
533	c := mocks.NewSender()
534	s := DecorateSender(c,
535		(func() SendDecorator {
536			return func(s Sender) Sender {
537				return SenderFunc(func(r *http.Request) (*http.Response, error) {
538					return resp, nil
539				})
540			}
541		})())
542	spt.SetSender(s)
543
544	err := spt.Refresh()
545	if err != nil {
546		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
547	} else if spt.inner.Token.AccessToken != "accessToken" ||
548		spt.inner.Token.ExpiresIn != "3600" ||
549		spt.inner.Token.ExpiresOn != json.Number(expiresOn) ||
550		spt.inner.Token.NotBefore != json.Number(expiresOn) ||
551		spt.inner.Token.Resource != "resource" ||
552		spt.inner.Token.Type != "Bearer" {
553		t.Fatalf("adal: ServicePrincipalToken#Refresh failed correctly unmarshal the JSON -- expected %v, received %v",
554			j, *spt)
555	}
556}
557
558func TestServicePrincipalTokenEnsureFreshRefreshes(t *testing.T) {
559	spt := newServicePrincipalToken()
560	expireToken(&spt.inner.Token)
561
562	body := mocks.NewBody(newTokenJSON("12345", "test"))
563	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
564
565	f := false
566	c := mocks.NewSender()
567	s := DecorateSender(c,
568		(func() SendDecorator {
569			return func(s Sender) Sender {
570				return SenderFunc(func(r *http.Request) (*http.Response, error) {
571					f = true
572					return resp, nil
573				})
574			}
575		})())
576	spt.SetSender(s)
577	err := spt.EnsureFresh()
578	if err != nil {
579		t.Fatalf("adal: ServicePrincipalToken#EnsureFresh returned an unexpected error (%v)", err)
580	}
581	if !f {
582		t.Fatal("adal: ServicePrincipalToken#EnsureFresh failed to call Refresh for stale token")
583	}
584}
585
586func TestServicePrincipalTokenEnsureFreshFails1(t *testing.T) {
587	spt := newServicePrincipalToken()
588	expireToken(&spt.inner.Token)
589
590	c := mocks.NewSender()
591	c.SetError(fmt.Errorf("some failure"))
592
593	spt.SetSender(c)
594	err := spt.EnsureFresh()
595	if err == nil {
596		t.Fatal("adal: ServicePrincipalToken#EnsureFresh didn't return an error")
597	}
598	if _, ok := err.(TokenRefreshError); ok {
599		t.Fatal("adal: ServicePrincipalToken#EnsureFresh unexpected TokenRefreshError")
600	}
601}
602
603func TestServicePrincipalTokenEnsureFreshFails2(t *testing.T) {
604	spt := newServicePrincipalToken()
605	expireToken(&spt.inner.Token)
606
607	c := mocks.NewSender()
608	c.AppendResponse(mocks.NewResponseWithStatus("bad request", http.StatusBadRequest))
609
610	spt.SetSender(c)
611	err := spt.EnsureFresh()
612	if err == nil {
613		t.Fatal("adal: ServicePrincipalToken#EnsureFresh didn't return an error")
614	}
615	if _, ok := err.(TokenRefreshError); !ok {
616		t.Fatal("adal: ServicePrincipalToken#EnsureFresh didn't return a TokenRefreshError")
617	}
618}
619
620func TestServicePrincipalTokenEnsureFreshSkipsIfFresh(t *testing.T) {
621	spt := newServicePrincipalToken()
622	setTokenToExpireIn(&spt.inner.Token, 1000*time.Second)
623
624	f := false
625	c := mocks.NewSender()
626	s := DecorateSender(c,
627		(func() SendDecorator {
628			return func(s Sender) Sender {
629				return SenderFunc(func(r *http.Request) (*http.Response, error) {
630					f = true
631					return mocks.NewResponse(), nil
632				})
633			}
634		})())
635	spt.SetSender(s)
636	err := spt.EnsureFresh()
637	if err != nil {
638		t.Fatalf("adal: ServicePrincipalToken#EnsureFresh returned an unexpected error (%v)", err)
639	}
640	if f {
641		t.Fatal("adal: ServicePrincipalToken#EnsureFresh invoked Refresh for fresh token")
642	}
643}
644
645func TestRefreshCallback(t *testing.T) {
646	callbackTriggered := false
647	spt := newServicePrincipalToken(func(Token) error {
648		callbackTriggered = true
649		return nil
650	})
651
652	expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds()))
653
654	sender := mocks.NewSender()
655	j := newTokenJSON(expiresOn, "resource")
656	sender.AppendResponse(mocks.NewResponseWithContent(j))
657	spt.SetSender(sender)
658	err := spt.Refresh()
659	if err != nil {
660		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
661	}
662	if !callbackTriggered {
663		t.Fatalf("adal: RefreshCallback failed to trigger call callback")
664	}
665}
666
667func TestRefreshCallbackErrorPropagates(t *testing.T) {
668	errorText := "this is an error text"
669	spt := newServicePrincipalToken(func(Token) error {
670		return fmt.Errorf(errorText)
671	})
672
673	expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds()))
674
675	sender := mocks.NewSender()
676	j := newTokenJSON(expiresOn, "resource")
677	sender.AppendResponse(mocks.NewResponseWithContent(j))
678	spt.SetSender(sender)
679	err := spt.Refresh()
680
681	if err == nil || !strings.Contains(err.Error(), errorText) {
682		t.Fatalf("adal: RefreshCallback failed to propagate error")
683	}
684}
685
686// This demonstrates the danger of manual token without a refresh token
687func TestServicePrincipalTokenManualRefreshFailsWithoutRefresh(t *testing.T) {
688	spt := newServicePrincipalTokenManual()
689	spt.inner.Token.RefreshToken = ""
690	err := spt.Refresh()
691	if err == nil {
692		t.Fatalf("adal: ServicePrincipalToken#Refresh should have failed with a ManualTokenSecret without a refresh token")
693	}
694}
695
696func TestNewServicePrincipalTokenFromMSI(t *testing.T) {
697	resource := "https://resource"
698	cb := func(token Token) error { return nil }
699
700	spt, err := NewServicePrincipalTokenFromMSI("http://msiendpoint/", resource, cb)
701	if err != nil {
702		t.Fatalf("Failed to get MSI SPT: %v", err)
703	}
704
705	// check some of the SPT fields
706	if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); !ok {
707		t.Fatal("SPT secret was not of MSI type")
708	}
709
710	if spt.inner.Resource != resource {
711		t.Fatal("SPT came back with incorrect resource")
712	}
713
714	if len(spt.refreshCallbacks) != 1 {
715		t.Fatal("SPT had incorrect refresh callbacks.")
716	}
717}
718
719func TestNewServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) {
720	resource := "https://resource"
721	userID := "abc123"
722	cb := func(token Token) error { return nil }
723
724	spt, err := NewServicePrincipalTokenFromMSIWithUserAssignedID("http://msiendpoint/", resource, userID, cb)
725	if err != nil {
726		t.Fatalf("Failed to get MSI SPT: %v", err)
727	}
728
729	// check some of the SPT fields
730	if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); !ok {
731		t.Fatal("SPT secret was not of MSI type")
732	}
733
734	if spt.inner.Resource != resource {
735		t.Fatal("SPT came back with incorrect resource")
736	}
737
738	if len(spt.refreshCallbacks) != 1 {
739		t.Fatal("SPT had incorrect refresh callbacks.")
740	}
741
742	if spt.inner.ClientID != userID {
743		t.Fatal("SPT had incorrect client ID")
744	}
745}
746
747func TestNewServicePrincipalTokenFromManualTokenSecret(t *testing.T) {
748	token := newToken()
749	secret := &ServicePrincipalAuthorizationCodeSecret{
750		ClientSecret:      "clientSecret",
751		AuthorizationCode: "code123",
752		RedirectURI:       "redirect",
753	}
754
755	spt, err := NewServicePrincipalTokenFromManualTokenSecret(TestOAuthConfig, "id", "resource", token, secret, nil)
756	if err != nil {
757		t.Fatalf("Failed creating new SPT: %s", err)
758	}
759
760	if !reflect.DeepEqual(token, spt.inner.Token) {
761		t.Fatalf("Tokens do not match: %s, %s", token, spt.inner.Token)
762	}
763
764	if !reflect.DeepEqual(secret, spt.inner.Secret) {
765		t.Fatalf("Secrets do not match: %s, %s", secret, spt.inner.Secret)
766	}
767
768}
769
770func TestGetVMEndpoint(t *testing.T) {
771	endpoint, err := GetMSIVMEndpoint()
772	if err != nil {
773		t.Fatal("Coudn't get VM endpoint")
774	}
775
776	if endpoint != msiEndpoint {
777		t.Fatal("Didn't get correct endpoint")
778	}
779}
780
781func TestGetAppServiceEndpoint(t *testing.T) {
782	const testEndpoint = "http://172.16.1.2:8081/msi/token"
783	if err := os.Setenv(asMSIEndpointEnv, testEndpoint); err != nil {
784		t.Fatalf("os.Setenv: %v", err)
785	}
786
787	endpoint, err := GetMSIAppServiceEndpoint()
788	if err != nil {
789		t.Fatal("Coudn't get App Service endpoint")
790	}
791
792	if endpoint != testEndpoint {
793		t.Fatal("Didn't get correct endpoint")
794	}
795
796	if err := os.Unsetenv(asMSIEndpointEnv); err != nil {
797		t.Fatalf("os.Unsetenv: %v", err)
798	}
799}
800
801func TestGetMSIEndpoint(t *testing.T) {
802	const (
803		testEndpoint = "http://172.16.1.2:8081/msi/token"
804		testSecret   = "DEADBEEF-BBBB-AAAA-DDDD-DDD000000DDD"
805	)
806
807	// Test VM well-known endpoint is returned
808	if err := os.Unsetenv(asMSIEndpointEnv); err != nil {
809		t.Fatalf("os.Unsetenv: %v", err)
810	}
811
812	if err := os.Unsetenv(asMSISecretEnv); err != nil {
813		t.Fatalf("os.Unsetenv: %v", err)
814	}
815
816	vmEndpoint, err := GetMSIEndpoint()
817	if err != nil {
818		t.Fatal("Coudn't get VM endpoint")
819	}
820
821	if vmEndpoint != msiEndpoint {
822		t.Fatal("Didn't get correct endpoint")
823	}
824
825	// Test App Service endpoint is returned
826	if err := os.Setenv(asMSIEndpointEnv, testEndpoint); err != nil {
827		t.Fatalf("os.Setenv: %v", err)
828	}
829
830	if err := os.Setenv(asMSISecretEnv, testSecret); err != nil {
831		t.Fatalf("os.Setenv: %v", err)
832	}
833
834	asEndpoint, err := GetMSIEndpoint()
835	if err != nil {
836		t.Fatal("Coudn't get App Service endpoint")
837	}
838
839	if asEndpoint != testEndpoint {
840		t.Fatal("Didn't get correct endpoint")
841	}
842
843	if err := os.Unsetenv(asMSIEndpointEnv); err != nil {
844		t.Fatalf("os.Unsetenv: %v", err)
845	}
846
847	if err := os.Unsetenv(asMSISecretEnv); err != nil {
848		t.Fatalf("os.Unsetenv: %v", err)
849	}
850}
851
852func TestMarshalServicePrincipalNoSecret(t *testing.T) {
853	spt := newServicePrincipalTokenManual()
854	b, err := json.Marshal(spt)
855	if err != nil {
856		t.Fatalf("failed to marshal token: %+v", err)
857	}
858	var spt2 *ServicePrincipalToken
859	err = json.Unmarshal(b, &spt2)
860	if err != nil {
861		t.Fatalf("failed to unmarshal token: %+v", err)
862	}
863	if !reflect.DeepEqual(spt, spt2) {
864		t.Fatal("tokens don't match")
865	}
866}
867
868func TestMarshalServicePrincipalTokenSecret(t *testing.T) {
869	spt := newServicePrincipalToken()
870	b, err := json.Marshal(spt)
871	if err != nil {
872		t.Fatalf("failed to marshal token: %+v", err)
873	}
874	var spt2 *ServicePrincipalToken
875	err = json.Unmarshal(b, &spt2)
876	if err != nil {
877		t.Fatalf("failed to unmarshal token: %+v", err)
878	}
879	if !reflect.DeepEqual(spt, spt2) {
880		t.Fatal("tokens don't match")
881	}
882}
883
884func TestMarshalServicePrincipalCertificateSecret(t *testing.T) {
885	spt := newServicePrincipalTokenCertificate(t)
886	b, err := json.Marshal(spt)
887	if err == nil {
888		t.Fatal("expected error when marshalling certificate token")
889	}
890	var spt2 *ServicePrincipalToken
891	err = json.Unmarshal(b, &spt2)
892	if err == nil {
893		t.Fatal("expected error when unmarshalling certificate token")
894	}
895}
896
897func TestMarshalServicePrincipalMSISecret(t *testing.T) {
898	spt, err := newServicePrincipalTokenFromMSI("http://msiendpoint/", "https://resource", nil)
899	if err != nil {
900		t.Fatalf("failed to get MSI SPT: %+v", err)
901	}
902	b, err := json.Marshal(spt)
903	if err == nil {
904		t.Fatal("expected error when marshalling MSI token")
905	}
906	var spt2 *ServicePrincipalToken
907	err = json.Unmarshal(b, &spt2)
908	if err == nil {
909		t.Fatal("expected error when unmarshalling MSI token")
910	}
911}
912
913func TestMarshalServicePrincipalUsernamePasswordSecret(t *testing.T) {
914	spt := newServicePrincipalTokenUsernamePassword(t)
915	b, err := json.Marshal(spt)
916	if err != nil {
917		t.Fatalf("failed to marshal token: %+v", err)
918	}
919	var spt2 *ServicePrincipalToken
920	err = json.Unmarshal(b, &spt2)
921	if err != nil {
922		t.Fatalf("failed to unmarshal token: %+v", err)
923	}
924	if !reflect.DeepEqual(spt, spt2) {
925		t.Fatal("tokens don't match")
926	}
927}
928
929func TestMarshalServicePrincipalAuthorizationCodeSecret(t *testing.T) {
930	spt := newServicePrincipalTokenAuthorizationCode(t)
931	b, err := json.Marshal(spt)
932	if err != nil {
933		t.Fatalf("failed to marshal token: %+v", err)
934	}
935	var spt2 *ServicePrincipalToken
936	err = json.Unmarshal(b, &spt2)
937	if err != nil {
938		t.Fatalf("failed to unmarshal token: %+v", err)
939	}
940	if !reflect.DeepEqual(spt, spt2) {
941		t.Fatal("tokens don't match")
942	}
943}
944
945func TestMarshalInnerToken(t *testing.T) {
946	spt := newServicePrincipalTokenManual()
947	tokenJSON, err := spt.MarshalTokenJSON()
948	if err != nil {
949		t.Fatalf("failed to marshal token: %+v", err)
950	}
951
952	testToken := newToken()
953	testToken.RefreshToken = "refreshtoken"
954
955	testTokenJSON, err := json.Marshal(testToken)
956	if err != nil {
957		t.Fatalf("failed to marshal test token: %+v", err)
958	}
959
960	if !reflect.DeepEqual(tokenJSON, testTokenJSON) {
961		t.Fatalf("tokens don't match: %s, %s", tokenJSON, testTokenJSON)
962	}
963
964	var t1 Token
965	err = json.Unmarshal(tokenJSON, &t1)
966	if err != nil {
967		t.Fatalf("failed to unmarshal token: %+v", err)
968	}
969
970	if !reflect.DeepEqual(t1, testToken) {
971		t.Fatalf("tokens don't match: %s, %s", t1, testToken)
972	}
973}
974
975func TestNewMultiTenantServicePrincipalToken(t *testing.T) {
976	cfg, err := NewMultiTenantOAuthConfig(TestActiveDirectoryEndpoint, TestTenantID, TestAuxTenantIDs, OAuthOptions{})
977	if err != nil {
978		t.Fatalf("autorest/adal: unexpected error while creating multitenant config: %v", err)
979	}
980	mt, err := NewMultiTenantServicePrincipalToken(cfg, "clientID", "superSecret", "resource")
981	if err != nil {
982		t.Fatalf("autorest/adal: unexpected error while creating multitenant service principal token: %v", err)
983	}
984	if !strings.Contains(mt.PrimaryToken.inner.OauthConfig.AuthorizeEndpoint.String(), TestTenantID) {
985		t.Fatal("didn't find primary tenant ID in primary SPT")
986	}
987	for i := range mt.AuxiliaryTokens {
988		if ep := mt.AuxiliaryTokens[i].inner.OauthConfig.AuthorizeEndpoint.String(); !strings.Contains(ep, fmt.Sprintf("%s%d", TestAuxTenantPrefix, i)) {
989			t.Fatalf("didn't find auxiliary tenant ID in token %s", ep)
990		}
991	}
992}
993
// newTokenJSON renders a token response payload with fixed access/refresh
// tokens. expiresOn is used for both expires_on and not_before, and resource
// fills the resource field.
func newTokenJSON(expiresOn string, resource string) string {
	const payload = `{
		"access_token" : "accessToken",
		"expires_in"   : "3600",
		"expires_on"   : "%[1]s",
		"not_before"   : "%[1]s",
		"resource"     : "%[2]s",
		"token_type"   : "Bearer",
		"refresh_token": "ABC123"
		}`
	return fmt.Sprintf(payload, expiresOn, resource)
}
1006
1007func newTokenExpiresIn(expireIn time.Duration) *Token {
1008	t := newToken()
1009	return setTokenToExpireIn(&t, expireIn)
1010}
1011
1012func newTokenExpiresAt(expireAt time.Time) *Token {
1013	t := newToken()
1014	return setTokenToExpireAt(&t, expireAt)
1015}
1016
1017func expireToken(t *Token) *Token {
1018	return setTokenToExpireIn(t, 0)
1019}
1020
1021func setTokenToExpireAt(t *Token, expireAt time.Time) *Token {
1022	t.ExpiresIn = "3600"
1023	t.ExpiresOn = json.Number(strconv.FormatInt(int64(expireAt.Sub(date.UnixEpoch())/time.Second), 10))
1024	t.NotBefore = t.ExpiresOn
1025	return t
1026}
1027
1028func setTokenToExpireIn(t *Token, expireIn time.Duration) *Token {
1029	return setTokenToExpireAt(t, time.Now().Add(expireIn))
1030}
1031
1032func newServicePrincipalToken(callbacks ...TokenRefreshCallback) *ServicePrincipalToken {
1033	spt, _ := NewServicePrincipalToken(TestOAuthConfig, "id", "secret", "resource", callbacks...)
1034	return spt
1035}
1036
1037func newServicePrincipalTokenManual() *ServicePrincipalToken {
1038	token := newToken()
1039	token.RefreshToken = "refreshtoken"
1040	spt, _ := NewServicePrincipalTokenFromManualToken(TestOAuthConfig, "id", "resource", token)
1041	return spt
1042}
1043
1044func newServicePrincipalTokenCertificate(t *testing.T) *ServicePrincipalToken {
1045	template := x509.Certificate{
1046		SerialNumber:          big.NewInt(0),
1047		Subject:               pkix.Name{CommonName: "test"},
1048		BasicConstraintsValid: true,
1049	}
1050	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
1051	if err != nil {
1052		t.Fatal(err)
1053	}
1054	certificateBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
1055	if err != nil {
1056		t.Fatal(err)
1057	}
1058	certificate, err := x509.ParseCertificate(certificateBytes)
1059	if err != nil {
1060		t.Fatal(err)
1061	}
1062
1063	spt, _ := NewServicePrincipalTokenFromCertificate(TestOAuthConfig, "id", certificate, privateKey, "resource")
1064	return spt
1065}
1066
1067func newServicePrincipalTokenUsernamePassword(t *testing.T) *ServicePrincipalToken {
1068	spt, _ := NewServicePrincipalTokenFromUsernamePassword(TestOAuthConfig, "id", "username", "password", "resource")
1069	return spt
1070}
1071
1072func newServicePrincipalTokenAuthorizationCode(t *testing.T) *ServicePrincipalToken {
1073	spt, _ := NewServicePrincipalTokenFromAuthorizationCode(TestOAuthConfig, "id", "clientSecret", "code", "http://redirectUri/getToken", "resource")
1074	return spt
1075}
1076