1package adal
2
3// Copyright 2017 Microsoft Corporation
4//
5//  Licensed under the Apache License, Version 2.0 (the "License");
6//  you may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at
8//
9//      http://www.apache.org/licenses/LICENSE-2.0
10//
11//  Unless required by applicable law or agreed to in writing, software
12//  distributed under the License is distributed on an "AS IS" BASIS,
13//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14//  See the License for the specific language governing permissions and
15//  limitations under the License.
16
17import (
18	"context"
19	"crypto/rand"
20	"crypto/rsa"
21	"crypto/x509"
22	"crypto/x509/pkix"
23	"encoding/json"
24	"fmt"
25	"io/ioutil"
26	"math/big"
27	"net/http"
28	"net/url"
29	"reflect"
30	"strconv"
31	"strings"
32	"sync"
33	"testing"
34	"time"
35
36	"github.com/Azure/go-autorest/autorest/date"
37	"github.com/Azure/go-autorest/autorest/mocks"
38	jwt "github.com/dgrijalva/jwt-go"
39)
40
// Expected url-encoded request bodies sent by Refresh for the fixture tokens.
const (
	// defaultFormData is the client-credentials body produced for the token
	// built by newServicePrincipalToken (id "id", secret "secret").
	defaultFormData       = "client_id=id&client_secret=secret&grant_type=client_credentials&resource=resource"
	// defaultManualFormData is the refresh-token body produced for the token
	// built by newServicePrincipalTokenManual (refresh token "refreshtoken").
	defaultManualFormData = "client_id=id&grant_type=refresh_token&refresh_token=refreshtoken&resource=resource"
)
45
// TestTokenExpires verifies that Token#Expires reports the instant the token
// was set to expire at. The fixture stores ExpiresOn as whole seconds since the
// Unix epoch (see setTokenToExpireAt), so the reported time may trail the
// requested instant by less than one second but must never be later than it.
func TestTokenExpires(t *testing.T) {
	tt := time.Now().Add(5 * time.Second)
	tk := newTokenExpiresAt(tt)

	// The previous check (`if tk.Expires().Equal(tt)`) was inverted: it failed
	// only on an exact match, which second-truncation makes practically
	// impossible, so it could never detect a miscalculation.
	// NOTE(review): assumes Expires() is derived directly from ExpiresOn seconds.
	if lag := tt.Sub(tk.Expires()); lag < 0 || lag >= time.Second {
		t.Fatalf("adal: Token#Expires miscalculated expiration time -- received %v, expected %v", tk.Expires(), tt)
	}
}
54
// TestTokenIsExpired checks that a token whose expiry lies in the past reports
// itself as expired.
func TestTokenIsExpired(t *testing.T) {
	fiveSecondsAgo := time.Now().Add(-5 * time.Second)
	stale := newTokenExpiresAt(fiveSecondsAgo)

	if expired := stale.IsExpired(); !expired {
		t.Fatalf("adal: Token#IsExpired failed to mark a stale token as expired -- now %v, token expires at %v",
			time.Now().UTC(), stale.Expires())
	}
}
63
64func TestTokenIsExpiredUninitialized(t *testing.T) {
65	tk := &Token{}
66
67	if !tk.IsExpired() {
68		t.Fatalf("adal: An uninitialized Token failed to mark itself as expired (expiration time %v)", tk.Expires())
69	}
70}
71
72func TestTokenIsNoExpired(t *testing.T) {
73	tk := newTokenExpiresAt(time.Now().Add(1000 * time.Second))
74
75	if tk.IsExpired() {
76		t.Fatalf("adal: Token marked a fresh token as expired -- now %v, token expires at %v", time.Now().UTC(), tk.Expires())
77	}
78}
79
80func TestTokenWillExpireIn(t *testing.T) {
81	d := 5 * time.Second
82	tk := newTokenExpiresIn(d)
83
84	if !tk.WillExpireIn(d) {
85		t.Fatal("adal: Token#WillExpireIn mismeasured expiration time")
86	}
87}
88
// TestServicePrincipalTokenSetAutoRefresh checks that auto-refresh is on by
// default and that SetAutoRefresh(false) turns it off.
func TestServicePrincipalTokenSetAutoRefresh(t *testing.T) {
	token := newServicePrincipalToken()

	if enabled := token.inner.AutoRefresh; !enabled {
		t.Fatal("adal: ServicePrincipalToken did not default to automatic token refreshing")
	}

	token.SetAutoRefresh(false)
	if enabled := token.inner.AutoRefresh; enabled {
		t.Fatal("adal: ServicePrincipalToken#SetAutoRefresh did not disable automatic token refreshing")
	}
}
101
102func TestServicePrincipalTokenSetRefreshWithin(t *testing.T) {
103	spt := newServicePrincipalToken()
104
105	if spt.inner.RefreshWithin != defaultRefresh {
106		t.Fatal("adal: ServicePrincipalToken did not correctly set the default refresh interval")
107	}
108
109	spt.SetRefreshWithin(2 * defaultRefresh)
110	if spt.inner.RefreshWithin != 2*defaultRefresh {
111		t.Fatal("adal: ServicePrincipalToken#SetRefreshWithin did not set the refresh interval")
112	}
113}
114
// TestServicePrincipalTokenSetSender checks that SetSender stores the given
// sender on the token.
func TestServicePrincipalTokenSetSender(t *testing.T) {
	token := newServicePrincipalToken()
	client := &http.Client{}

	token.SetSender(client)
	if same := reflect.DeepEqual(client, token.sender); !same {
		t.Fatal("adal: ServicePrincipalToken#SetSender did not set the sender")
	}
}
124
125func TestServicePrincipalTokenRefreshUsesPOST(t *testing.T) {
126	spt := newServicePrincipalToken()
127
128	body := mocks.NewBody(newTokenJSON("test", "test"))
129	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
130
131	c := mocks.NewSender()
132	s := DecorateSender(c,
133		(func() SendDecorator {
134			return func(s Sender) Sender {
135				return SenderFunc(func(r *http.Request) (*http.Response, error) {
136					if r.Method != "POST" {
137						t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "POST", r.Method)
138					}
139					return resp, nil
140				})
141			}
142		})())
143	spt.SetSender(s)
144	err := spt.Refresh()
145	if err != nil {
146		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
147	}
148
149	if body.IsOpen() {
150		t.Fatalf("the response was not closed!")
151	}
152}
153
154func TestServicePrincipalTokenFromMSIRefreshUsesGET(t *testing.T) {
155	resource := "https://resource"
156	cb := func(token Token) error { return nil }
157
158	spt, err := NewServicePrincipalTokenFromMSI("http://msiendpoint/", resource, cb)
159	if err != nil {
160		t.Fatalf("Failed to get MSI SPT: %v", err)
161	}
162
163	body := mocks.NewBody(newTokenJSON("test", "test"))
164	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
165
166	c := mocks.NewSender()
167	s := DecorateSender(c,
168		(func() SendDecorator {
169			return func(s Sender) Sender {
170				return SenderFunc(func(r *http.Request) (*http.Response, error) {
171					if r.Method != "GET" {
172						t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "GET", r.Method)
173					}
174					if h := r.Header.Get("Metadata"); h != "true" {
175						t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set Metadata header for MSI")
176					}
177					return resp, nil
178				})
179			}
180		})())
181	spt.SetSender(s)
182	err = spt.Refresh()
183	if err != nil {
184		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
185	}
186
187	if body.IsOpen() {
188		t.Fatalf("the response was not closed!")
189	}
190}
191
// TestServicePrincipalTokenFromMSIRefreshCancel checks that cancelling the
// context passed to RefreshWithContext makes the refresh return promptly (well
// under a second) with an error. The mock sender only ever answers with HTTP
// 500, so no successful refresh is possible.
func TestServicePrincipalTokenFromMSIRefreshCancel(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())

	endpoint, err := GetMSIVMEndpoint()
	if err != nil {
		t.Fatalf("Failed to get MSI endpoint: %v", err)
	}

	spt, err := NewServicePrincipalTokenFromMSI(endpoint, "https://resource")
	if err != nil {
		t.Fatalf("Failed to get MSI SPT: %v", err)
	}

	c := mocks.NewSender()
	c.AppendAndRepeatResponse(mocks.NewResponseWithStatus("Internal server error", http.StatusInternalServerError), 5)
	// Configure the sender before starting the goroutine so the token is not
	// mutated concurrently with anything else.
	spt.SetSender(c)

	var (
		wg         sync.WaitGroup
		refreshErr error
		end        time.Time
	)
	wg.Add(1)
	start := time.Now()

	// refreshErr and end are written only by this goroutine and read only
	// after wg.Wait, which establishes the happens-before edge.
	go func() {
		defer wg.Done()
		refreshErr = spt.RefreshWithContext(ctx)
		end = time.Now()
	}()

	cancel()
	wg.Wait()

	if refreshErr == nil {
		t.Fatal("TestServicePrincipalTokenFromMSIRefreshCancel expected the cancelled refresh to return an error")
	}
	if end.Sub(start) >= time.Second {
		t.Fatalf("TestServicePrincipalTokenFromMSIRefreshCancel failed to cancel")
	}
}
224
225func TestServicePrincipalTokenRefreshSetsMimeType(t *testing.T) {
226	spt := newServicePrincipalToken()
227
228	body := mocks.NewBody(newTokenJSON("test", "test"))
229	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
230
231	c := mocks.NewSender()
232	s := DecorateSender(c,
233		(func() SendDecorator {
234			return func(s Sender) Sender {
235				return SenderFunc(func(r *http.Request) (*http.Response, error) {
236					if r.Header.Get(http.CanonicalHeaderKey("Content-Type")) != "application/x-www-form-urlencoded" {
237						t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set Content-Type -- expected %v, received %v",
238							"application/x-form-urlencoded",
239							r.Header.Get(http.CanonicalHeaderKey("Content-Type")))
240					}
241					return resp, nil
242				})
243			}
244		})())
245	spt.SetSender(s)
246	err := spt.Refresh()
247	if err != nil {
248		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
249	}
250}
251
252func TestServicePrincipalTokenRefreshSetsURL(t *testing.T) {
253	spt := newServicePrincipalToken()
254
255	body := mocks.NewBody(newTokenJSON("test", "test"))
256	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
257
258	c := mocks.NewSender()
259	s := DecorateSender(c,
260		(func() SendDecorator {
261			return func(s Sender) Sender {
262				return SenderFunc(func(r *http.Request) (*http.Response, error) {
263					if r.URL.String() != TestOAuthConfig.TokenEndpoint.String() {
264						t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the URL -- expected %v, received %v",
265							TestOAuthConfig.TokenEndpoint, r.URL)
266					}
267					return resp, nil
268				})
269			}
270		})())
271	spt.SetSender(s)
272	err := spt.Refresh()
273	if err != nil {
274		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
275	}
276}
277
278func testServicePrincipalTokenRefreshSetsBody(t *testing.T, spt *ServicePrincipalToken, f func(*testing.T, []byte)) {
279	body := mocks.NewBody(newTokenJSON("test", "test"))
280	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
281
282	c := mocks.NewSender()
283	s := DecorateSender(c,
284		(func() SendDecorator {
285			return func(s Sender) Sender {
286				return SenderFunc(func(r *http.Request) (*http.Response, error) {
287					b, err := ioutil.ReadAll(r.Body)
288					if err != nil {
289						t.Fatalf("adal: Failed to read body of Service Principal token request (%v)", err)
290					}
291					f(t, b)
292					return resp, nil
293				})
294			}
295		})())
296	spt.SetSender(s)
297	err := spt.Refresh()
298	if err != nil {
299		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
300	}
301}
302
// TestServicePrincipalTokenManualRefreshSetsBody checks the exact form body
// sent when refreshing a manually constructed token.
func TestServicePrincipalTokenManualRefreshSetsBody(t *testing.T) {
	sptManual := newServicePrincipalTokenManual()

	expectManualForm := func(t *testing.T, b []byte) {
		if got := string(b); got != defaultManualFormData {
			t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the HTTP Request Body -- expected %v, received %v",
				defaultManualFormData, got)
		}
	}
	testServicePrincipalTokenRefreshSetsBody(t, sptManual, expectManualForm)
}
312
// TestServicePrincipalTokenCertficateRefreshSetsBody checks the form fields
// sent when refreshing a certificate-backed token. It also checks that
// client_assertion parses as a JWT whose header carries x5t and x5c.
func TestServicePrincipalTokenCertficateRefreshSetsBody(t *testing.T) {
	sptCert := newServicePrincipalTokenCertificate(t)
	testServicePrincipalTokenRefreshSetsBody(t, sptCert, func(t *testing.T, b []byte) {
		values, err := url.ParseQuery(string(b))
		if err != nil {
			t.Fatalf("adal: ServicePrincipalTokenCertificate#Refresh sent an unparseable HTTP Request Body (%v)", err)
		}
		// Values.Get returns "" for a missing key. The previous values[key][0]
		// panicked with index-out-of-range instead of reporting this failure.
		if values.Get("client_assertion_type") != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" ||
			values.Get("client_id") != "id" ||
			values.Get("grant_type") != "client_credentials" ||
			values.Get("resource") != "resource" {
			t.Fatalf("adal: ServicePrincipalTokenCertificate#Refresh did not correctly set the HTTP Request Body.")
		}

		// No key function is supplied, so the signature is not verified and the
		// parse error is deliberately ignored. Only the decoded header is inspected.
		tok, _ := jwt.Parse(values.Get("client_assertion"), nil)
		if tok == nil {
			t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to be a JWT")
		}
		if _, ok := tok.Header["x5t"]; !ok {
			t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to have an x5t header")
		}
		if _, ok := tok.Header["x5c"]; !ok {
			t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to have an x5c header")
		}
	})
}
338
// TestServicePrincipalTokenUsernamePasswordRefreshSetsBody checks the form
// fields sent when refreshing a username/password token.
func TestServicePrincipalTokenUsernamePasswordRefreshSetsBody(t *testing.T) {
	spt := newServicePrincipalTokenUsernamePassword(t)
	testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) {
		values, err := url.ParseQuery(string(b))
		if err != nil {
			t.Fatalf("adal: ServicePrincipalTokenUsernamePassword#Refresh sent an unparseable HTTP Request Body (%v)", err)
		}
		// Values.Get yields "" for a missing key instead of panicking on
		// values[key][0], so an incomplete body reaches the intended failure.
		if values.Get("client_id") != "id" ||
			values.Get("grant_type") != "password" ||
			values.Get("username") != "username" ||
			values.Get("password") != "password" ||
			values.Get("resource") != "resource" {
			t.Fatalf("adal: ServicePrincipalTokenUsernamePassword#Refresh did not correctly set the HTTP Request Body.")
		}
	})
}
354
// TestServicePrincipalTokenAuthorizationCodeRefreshSetsBody refreshes an
// authorization-code token twice. The first request is expected to use the
// authorization_code grant and the second the refresh_token grant. Every other
// form field is expected to be unchanged.
func TestServicePrincipalTokenAuthorizationCodeRefreshSetsBody(t *testing.T) {
	spt := newServicePrincipalTokenAuthorizationCode(t)

	// expectBody returns a body checker for the given grant type. Values.Get
	// yields "" for a missing key instead of panicking on values[key][0].
	expectBody := func(grantType string) func(*testing.T, []byte) {
		return func(t *testing.T, b []byte) {
			values, err := url.ParseQuery(string(b))
			if err != nil {
				t.Fatalf("adal: ServicePrincipalTokenAuthorizationCode#Refresh sent an unparseable HTTP Request Body (%v)", err)
			}
			if values.Get("client_id") != "id" ||
				values.Get("grant_type") != grantType ||
				values.Get("code") != "code" ||
				values.Get("client_secret") != "clientSecret" ||
				values.Get("redirect_uri") != "http://redirectUri/getToken" ||
				values.Get("resource") != "resource" {
				t.Fatalf("adal: ServicePrincipalTokenAuthorizationCode#Refresh did not correctly set the HTTP Request Body.")
			}
		}
	}

	testServicePrincipalTokenRefreshSetsBody(t, spt, expectBody(OAuthGrantTypeAuthorizationCode))
	testServicePrincipalTokenRefreshSetsBody(t, spt, expectBody(OAuthGrantTypeRefreshToken))
}
384
// TestServicePrincipalTokenSecretRefreshSetsBody checks the exact form body
// sent when refreshing a client-secret token.
func TestServicePrincipalTokenSecretRefreshSetsBody(t *testing.T) {
	spt := newServicePrincipalToken()

	expectSecretForm := func(t *testing.T, b []byte) {
		if got := string(b); got != defaultFormData {
			t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the HTTP Request Body -- expected %v, received %v",
				defaultFormData, got)
		}
	}
	testServicePrincipalTokenRefreshSetsBody(t, spt, expectSecretForm)
}
395
396func TestServicePrincipalTokenRefreshClosesRequestBody(t *testing.T) {
397	spt := newServicePrincipalToken()
398
399	body := mocks.NewBody(newTokenJSON("test", "test"))
400	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
401
402	c := mocks.NewSender()
403	s := DecorateSender(c,
404		(func() SendDecorator {
405			return func(s Sender) Sender {
406				return SenderFunc(func(r *http.Request) (*http.Response, error) {
407					return resp, nil
408				})
409			}
410		})())
411	spt.SetSender(s)
412	err := spt.Refresh()
413	if err != nil {
414		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
415	}
416	if resp.Body.(*mocks.Body).IsOpen() {
417		t.Fatal("adal: ServicePrincipalToken#Refresh failed to close the HTTP Response Body")
418	}
419}
420
421func TestServicePrincipalTokenRefreshRejectsResponsesWithStatusNotOK(t *testing.T) {
422	spt := newServicePrincipalToken()
423
424	body := mocks.NewBody(newTokenJSON("test", "test"))
425	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusUnauthorized, "Unauthorized")
426
427	c := mocks.NewSender()
428	s := DecorateSender(c,
429		(func() SendDecorator {
430			return func(s Sender) Sender {
431				return SenderFunc(func(r *http.Request) (*http.Response, error) {
432					return resp, nil
433				})
434			}
435		})())
436	spt.SetSender(s)
437	err := spt.Refresh()
438	if err == nil {
439		t.Fatalf("adal: ServicePrincipalToken#Refresh should reject a response with status != %d", http.StatusOK)
440	}
441}
442
443func TestServicePrincipalTokenRefreshRejectsEmptyBody(t *testing.T) {
444	spt := newServicePrincipalToken()
445
446	c := mocks.NewSender()
447	s := DecorateSender(c,
448		(func() SendDecorator {
449			return func(s Sender) Sender {
450				return SenderFunc(func(r *http.Request) (*http.Response, error) {
451					return mocks.NewResponse(), nil
452				})
453			}
454		})())
455	spt.SetSender(s)
456	err := spt.Refresh()
457	if err == nil {
458		t.Fatal("adal: ServicePrincipalToken#Refresh should reject an empty token")
459	}
460}
461
// TestServicePrincipalTokenRefreshPropagatesErrors checks that a transport
// error from the sender surfaces from Refresh.
func TestServicePrincipalTokenRefreshPropagatesErrors(t *testing.T) {
	spt := newServicePrincipalToken()

	failing := mocks.NewSender()
	failing.SetError(fmt.Errorf("Faux Error"))
	spt.SetSender(failing)

	if spt.Refresh() == nil {
		t.Fatal("adal: Failed to propagate the request error")
	}
}
474
// TestServicePrincipalTokenRefreshReturnsErrorIfNotOk checks that Refresh
// returns an error when the token endpoint answers 401.
func TestServicePrincipalTokenRefreshReturnsErrorIfNotOk(t *testing.T) {
	spt := newServicePrincipalToken()

	c := mocks.NewSender()
	c.AppendResponse(mocks.NewResponseWithStatus("401 NotAuthorized", http.StatusUnauthorized))
	spt.SetSender(c)

	if err := spt.Refresh(); err == nil {
		// The message previously read "Failed to return an when", missing "error".
		t.Fatalf("adal: Failed to return an error when receiving a status code other than HTTP %d", http.StatusOK)
	}
}
487
488func TestServicePrincipalTokenRefreshUnmarshals(t *testing.T) {
489	spt := newServicePrincipalToken()
490
491	expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds()))
492	j := newTokenJSON(expiresOn, "resource")
493	resp := mocks.NewResponseWithContent(j)
494	c := mocks.NewSender()
495	s := DecorateSender(c,
496		(func() SendDecorator {
497			return func(s Sender) Sender {
498				return SenderFunc(func(r *http.Request) (*http.Response, error) {
499					return resp, nil
500				})
501			}
502		})())
503	spt.SetSender(s)
504
505	err := spt.Refresh()
506	if err != nil {
507		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
508	} else if spt.inner.Token.AccessToken != "accessToken" ||
509		spt.inner.Token.ExpiresIn != "3600" ||
510		spt.inner.Token.ExpiresOn != json.Number(expiresOn) ||
511		spt.inner.Token.NotBefore != json.Number(expiresOn) ||
512		spt.inner.Token.Resource != "resource" ||
513		spt.inner.Token.Type != "Bearer" {
514		t.Fatalf("adal: ServicePrincipalToken#Refresh failed correctly unmarshal the JSON -- expected %v, received %v",
515			j, *spt)
516	}
517}
518
519func TestServicePrincipalTokenEnsureFreshRefreshes(t *testing.T) {
520	spt := newServicePrincipalToken()
521	expireToken(&spt.inner.Token)
522
523	body := mocks.NewBody(newTokenJSON("test", "test"))
524	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
525
526	f := false
527	c := mocks.NewSender()
528	s := DecorateSender(c,
529		(func() SendDecorator {
530			return func(s Sender) Sender {
531				return SenderFunc(func(r *http.Request) (*http.Response, error) {
532					f = true
533					return resp, nil
534				})
535			}
536		})())
537	spt.SetSender(s)
538	err := spt.EnsureFresh()
539	if err != nil {
540		t.Fatalf("adal: ServicePrincipalToken#EnsureFresh returned an unexpected error (%v)", err)
541	}
542	if !f {
543		t.Fatal("adal: ServicePrincipalToken#EnsureFresh failed to call Refresh for stale token")
544	}
545}
546
// TestServicePrincipalTokenEnsureFreshFails checks that a failed refresh of a
// stale token is reported as a TokenRefreshError.
func TestServicePrincipalTokenEnsureFreshFails(t *testing.T) {
	spt := newServicePrincipalToken()
	expireToken(&spt.inner.Token)

	failing := mocks.NewSender()
	failing.SetError(fmt.Errorf("some failure"))
	spt.SetSender(failing)

	err := spt.EnsureFresh()
	if err == nil {
		t.Fatal("adal: ServicePrincipalToken#EnsureFresh didn't return an error")
	}
	if _, isRefreshErr := err.(TokenRefreshError); !isRefreshErr {
		t.Fatal("adal: ServicePrincipalToken#EnsureFresh didn't return a TokenRefreshError")
	}
}
563
564func TestServicePrincipalTokenEnsureFreshSkipsIfFresh(t *testing.T) {
565	spt := newServicePrincipalToken()
566	setTokenToExpireIn(&spt.inner.Token, 1000*time.Second)
567
568	f := false
569	c := mocks.NewSender()
570	s := DecorateSender(c,
571		(func() SendDecorator {
572			return func(s Sender) Sender {
573				return SenderFunc(func(r *http.Request) (*http.Response, error) {
574					f = true
575					return mocks.NewResponse(), nil
576				})
577			}
578		})())
579	spt.SetSender(s)
580	err := spt.EnsureFresh()
581	if err != nil {
582		t.Fatalf("adal: ServicePrincipalToken#EnsureFresh returned an unexpected error (%v)", err)
583	}
584	if f {
585		t.Fatal("adal: ServicePrincipalToken#EnsureFresh invoked Refresh for fresh token")
586	}
587}
588
// TestRefreshCallback checks that a successful Refresh invokes the refresh
// callback registered at construction time.
func TestRefreshCallback(t *testing.T) {
	callbackTriggered := false
	spt := newServicePrincipalToken(func(Token) error {
		callbackTriggered = true
		return nil
	})

	expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds()))

	sender := mocks.NewSender()
	j := newTokenJSON(expiresOn, "resource")
	sender.AppendResponse(mocks.NewResponseWithContent(j))
	spt.SetSender(sender)

	if err := spt.Refresh(); err != nil {
		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
	}
	if !callbackTriggered {
		// The message previously read "failed to trigger call callback".
		t.Fatalf("adal: ServicePrincipalToken#Refresh failed to trigger the refresh callback")
	}
}
610
// TestRefreshCallbackErrorPropagates checks that an error returned by a
// refresh callback is surfaced by Refresh with its text intact.
func TestRefreshCallbackErrorPropagates(t *testing.T) {
	errorText := "this is an error text"
	spt := newServicePrincipalToken(func(Token) error {
		// Use an explicit "%s" verb. Passing errorText as the format string is
		// flagged by go vet and would mangle any '%' the text contained.
		return fmt.Errorf("%s", errorText)
	})

	expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds()))

	sender := mocks.NewSender()
	j := newTokenJSON(expiresOn, "resource")
	sender.AppendResponse(mocks.NewResponseWithContent(j))
	spt.SetSender(sender)
	err := spt.Refresh()

	if err == nil || !strings.Contains(err.Error(), errorText) {
		t.Fatalf("adal: RefreshCallback failed to propagate error")
	}
}
629
// TestServicePrincipalTokenManualRefreshFailsWithoutRefresh demonstrates the
// hazard of a manually supplied token that has no refresh token: Refresh
// cannot succeed and must return an error.
func TestServicePrincipalTokenManualRefreshFailsWithoutRefresh(t *testing.T) {
	spt := newServicePrincipalTokenManual()
	spt.inner.Token.RefreshToken = ""

	if err := spt.Refresh(); err == nil {
		t.Fatalf("adal: ServicePrincipalToken#Refresh should have failed with a ManualTokenSecret without a refresh token")
	}
}
639
640func TestNewServicePrincipalTokenFromMSI(t *testing.T) {
641	resource := "https://resource"
642	cb := func(token Token) error { return nil }
643
644	spt, err := NewServicePrincipalTokenFromMSI("http://msiendpoint/", resource, cb)
645	if err != nil {
646		t.Fatalf("Failed to get MSI SPT: %v", err)
647	}
648
649	// check some of the SPT fields
650	if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); !ok {
651		t.Fatal("SPT secret was not of MSI type")
652	}
653
654	if spt.inner.Resource != resource {
655		t.Fatal("SPT came back with incorrect resource")
656	}
657
658	if len(spt.refreshCallbacks) != 1 {
659		t.Fatal("SPT had incorrect refresh callbacks.")
660	}
661}
662
663func TestNewServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) {
664	resource := "https://resource"
665	userID := "abc123"
666	cb := func(token Token) error { return nil }
667
668	spt, err := NewServicePrincipalTokenFromMSIWithUserAssignedID("http://msiendpoint/", resource, userID, cb)
669	if err != nil {
670		t.Fatalf("Failed to get MSI SPT: %v", err)
671	}
672
673	// check some of the SPT fields
674	if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); !ok {
675		t.Fatal("SPT secret was not of MSI type")
676	}
677
678	if spt.inner.Resource != resource {
679		t.Fatal("SPT came back with incorrect resource")
680	}
681
682	if len(spt.refreshCallbacks) != 1 {
683		t.Fatal("SPT had incorrect refresh callbacks.")
684	}
685
686	if spt.inner.ClientID != userID {
687		t.Fatal("SPT had incorrect client ID")
688	}
689}
690
// TestNewServicePrincipalTokenFromManualTokenSecret checks that the token and
// secret handed to NewServicePrincipalTokenFromManualTokenSecret are stored as-is.
func TestNewServicePrincipalTokenFromManualTokenSecret(t *testing.T) {
	token := newToken()
	secret := &ServicePrincipalAuthorizationCodeSecret{
		ClientSecret:      "clientSecret",
		AuthorizationCode: "code123",
		RedirectURI:       "redirect",
	}

	spt, err := NewServicePrincipalTokenFromManualTokenSecret(TestOAuthConfig, "id", "resource", token, secret, nil)
	if err != nil {
		t.Fatalf("Failed creating new SPT: %s", err)
	}

	if stored := spt.inner.Token; !reflect.DeepEqual(token, stored) {
		t.Fatalf("Tokens do not match: %s, %s", token, stored)
	}
	if stored := spt.inner.Secret; !reflect.DeepEqual(secret, stored) {
		t.Fatalf("Secrets do not match: %s, %s", secret, stored)
	}
}
713
// TestGetVMEndpoint checks that GetMSIVMEndpoint succeeds and returns
// msiEndpoint.
func TestGetVMEndpoint(t *testing.T) {
	endpoint, err := GetMSIVMEndpoint()
	if err != nil {
		// Fixes the "Coudn't" typo and reports the underlying error.
		t.Fatalf("Couldn't get VM endpoint: %v", err)
	}

	if endpoint != msiEndpoint {
		t.Fatalf("Didn't get correct endpoint -- expected %v, received %v", msiEndpoint, endpoint)
	}
}
724
// TestMarshalServicePrincipalNoSecret round-trips a manual token through JSON
// and expects an identical value back.
func TestMarshalServicePrincipalNoSecret(t *testing.T) {
	original := newServicePrincipalTokenManual()

	encoded, err := json.Marshal(original)
	if err != nil {
		t.Fatalf("failed to marshal token: %+v", err)
	}

	var decoded *ServicePrincipalToken
	if err := json.Unmarshal(encoded, &decoded); err != nil {
		t.Fatalf("failed to unmarshal token: %+v", err)
	}

	if !reflect.DeepEqual(original, decoded) {
		t.Fatal("tokens don't match")
	}
}
740
// TestMarshalServicePrincipalTokenSecret round-trips a client-secret token
// through JSON and expects an identical value back.
func TestMarshalServicePrincipalTokenSecret(t *testing.T) {
	original := newServicePrincipalToken()

	encoded, err := json.Marshal(original)
	if err != nil {
		t.Fatalf("failed to marshal token: %+v", err)
	}

	var decoded *ServicePrincipalToken
	if err := json.Unmarshal(encoded, &decoded); err != nil {
		t.Fatalf("failed to unmarshal token: %+v", err)
	}

	if !reflect.DeepEqual(original, decoded) {
		t.Fatal("tokens don't match")
	}
}
756
// TestMarshalServicePrincipalCertificateSecret expects JSON marshalling of a
// certificate-backed token to fail.
func TestMarshalServicePrincipalCertificateSecret(t *testing.T) {
	original := newServicePrincipalTokenCertificate(t)

	encoded, err := json.Marshal(original)
	if err == nil {
		t.Fatal("expected error when marshalling certificate token")
	}

	// NOTE(review): json.Marshal returns a nil slice on error, so this only
	// confirms that empty input is rejected. It does not exercise unmarshalling
	// of a certificate secret.
	var decoded *ServicePrincipalToken
	if err := json.Unmarshal(encoded, &decoded); err == nil {
		t.Fatal("expected error when unmarshalling certificate token")
	}
}
769
// TestMarshalServicePrincipalMSISecret expects JSON marshalling of an MSI
// token to fail.
func TestMarshalServicePrincipalMSISecret(t *testing.T) {
	original, err := newServicePrincipalTokenFromMSI("http://msiendpoint/", "https://resource", nil)
	if err != nil {
		t.Fatalf("failed to get MSI SPT: %+v", err)
	}

	encoded, err := json.Marshal(original)
	if err == nil {
		t.Fatal("expected error when marshalling MSI token")
	}

	// NOTE(review): encoded is nil after a failed Marshal, so this only checks
	// that empty input is rejected.
	var decoded *ServicePrincipalToken
	if err := json.Unmarshal(encoded, &decoded); err == nil {
		t.Fatal("expected error when unmarshalling MSI token")
	}
}
785
// TestMarshalServicePrincipalUsernamePasswordSecret round-trips a
// username/password token through JSON and expects an identical value back.
func TestMarshalServicePrincipalUsernamePasswordSecret(t *testing.T) {
	original := newServicePrincipalTokenUsernamePassword(t)

	encoded, err := json.Marshal(original)
	if err != nil {
		t.Fatalf("failed to marshal token: %+v", err)
	}

	var decoded *ServicePrincipalToken
	if err := json.Unmarshal(encoded, &decoded); err != nil {
		t.Fatalf("failed to unmarshal token: %+v", err)
	}

	if !reflect.DeepEqual(original, decoded) {
		t.Fatal("tokens don't match")
	}
}
801
// TestMarshalServicePrincipalAuthorizationCodeSecret round-trips an
// authorization-code token through JSON and expects an identical value back.
func TestMarshalServicePrincipalAuthorizationCodeSecret(t *testing.T) {
	original := newServicePrincipalTokenAuthorizationCode(t)

	encoded, err := json.Marshal(original)
	if err != nil {
		t.Fatalf("failed to marshal token: %+v", err)
	}

	var decoded *ServicePrincipalToken
	if err := json.Unmarshal(encoded, &decoded); err != nil {
		t.Fatalf("failed to unmarshal token: %+v", err)
	}

	if !reflect.DeepEqual(original, decoded) {
		t.Fatal("tokens don't match")
	}
}
817
// TestMarshalInnerToken checks two things. First, MarshalTokenJSON emits
// exactly the JSON of the inner Token. Second, that JSON decodes back to an
// equal Token.
func TestMarshalInnerToken(t *testing.T) {
	spt := newServicePrincipalTokenManual()
	gotJSON, err := spt.MarshalTokenJSON()
	if err != nil {
		t.Fatalf("failed to marshal token: %+v", err)
	}

	want := newToken()
	want.RefreshToken = "refreshtoken"

	wantJSON, err := json.Marshal(want)
	if err != nil {
		t.Fatalf("failed to marshal test token: %+v", err)
	}

	if !reflect.DeepEqual(gotJSON, wantJSON) {
		t.Fatalf("tokens don't match: %s, %s", gotJSON, wantJSON)
	}

	var roundTripped Token
	if err := json.Unmarshal(gotJSON, &roundTripped); err != nil {
		t.Fatalf("failed to unmarshal token: %+v", err)
	}

	if !reflect.DeepEqual(roundTripped, want) {
		t.Fatalf("tokens don't match: %s, %s", roundTripped, want)
	}
}
847
848func newTokenJSON(expiresOn string, resource string) string {
849	return fmt.Sprintf(`{
850		"access_token" : "accessToken",
851		"expires_in"   : "3600",
852		"expires_on"   : "%s",
853		"not_before"   : "%s",
854		"resource"     : "%s",
855		"token_type"   : "Bearer",
856		"refresh_token": "ABC123"
857		}`,
858		expiresOn, expiresOn, resource)
859}
860
// newTokenExpiresIn returns a fresh fixture token set to expire expireIn from now.
func newTokenExpiresIn(expireIn time.Duration) *Token {
	tok := newToken()
	setTokenToExpireAt(&tok, time.Now().Add(expireIn))
	return &tok
}
865
// newTokenExpiresAt returns a fresh fixture token set to expire at expireAt.
func newTokenExpiresAt(expireAt time.Time) *Token {
	tok := newToken()
	setTokenToExpireAt(&tok, expireAt)
	return &tok
}
870
871func expireToken(t *Token) *Token {
872	return setTokenToExpireIn(t, 0)
873}
874
// setTokenToExpireAt sets t's ExpiresOn and NotBefore to expireAt, expressed as
// whole seconds since date.UnixEpoch() (truncated toward zero). It sets
// ExpiresIn to "3600" and returns t.
func setTokenToExpireAt(t *Token, expireAt time.Time) *Token {
	secondsSinceEpoch := int64(expireAt.Sub(date.UnixEpoch()) / time.Second)
	stamp := json.Number(strconv.FormatInt(secondsSinceEpoch, 10))

	t.ExpiresIn = "3600"
	t.ExpiresOn = stamp
	t.NotBefore = stamp
	return t
}
881
// setTokenToExpireIn sets t to expire expireIn from now and returns it.
func setTokenToExpireIn(t *Token, expireIn time.Duration) *Token {
	deadline := time.Now().Add(expireIn)
	return setTokenToExpireAt(t, deadline)
}
885
// newServicePrincipalToken builds a client-secret token against
// TestOAuthConfig with client id "id", secret "secret" and resource "resource"
// (the values defaultFormData expects), registering the given callbacks.
// It panics on a construction error because it has no *testing.T to report
// through; previously the error was silently discarded.
func newServicePrincipalToken(callbacks ...TokenRefreshCallback) *ServicePrincipalToken {
	spt, err := NewServicePrincipalToken(TestOAuthConfig, "id", "secret", "resource", callbacks...)
	if err != nil {
		panic(fmt.Sprintf("adal: failed to create test ServicePrincipalToken: %v", err))
	}
	return spt
}
890
// newServicePrincipalTokenManual builds a token from a fixture Token whose
// refresh token is "refreshtoken" (the value defaultManualFormData expects).
// It panics on a construction error instead of silently discarding it,
// because it has no *testing.T to report through.
func newServicePrincipalTokenManual() *ServicePrincipalToken {
	token := newToken()
	token.RefreshToken = "refreshtoken"
	spt, err := NewServicePrincipalTokenFromManualToken(TestOAuthConfig, "id", "resource", token)
	if err != nil {
		panic(fmt.Sprintf("adal: failed to create manual test ServicePrincipalToken: %v", err))
	}
	return spt
}
897
// newServicePrincipalTokenCertificate builds a token authenticated by a freshly
// generated self-signed certificate (CN "test", RSA 2048). Any failure, now
// including the token constructor's error (previously discarded), fails t.
func newServicePrincipalTokenCertificate(t *testing.T) *ServicePrincipalToken {
	template := x509.Certificate{
		SerialNumber:          big.NewInt(0),
		Subject:               pkix.Name{CommonName: "test"},
		BasicConstraintsValid: true,
	}
	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatal(err)
	}
	// Self-signed: the template is both the subject and the issuer.
	certificateBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
	if err != nil {
		t.Fatal(err)
	}
	certificate, err := x509.ParseCertificate(certificateBytes)
	if err != nil {
		t.Fatal(err)
	}

	spt, err := NewServicePrincipalTokenFromCertificate(TestOAuthConfig, "id", certificate, privateKey, "resource")
	if err != nil {
		t.Fatalf("adal: failed to create certificate ServicePrincipalToken: %v", err)
	}
	return spt
}
920
// newServicePrincipalTokenUsernamePassword builds a username/password token
// with client id "id", username "username", password "password" and resource
// "resource". A construction error now fails t instead of being discarded.
func newServicePrincipalTokenUsernamePassword(t *testing.T) *ServicePrincipalToken {
	spt, err := NewServicePrincipalTokenFromUsernamePassword(TestOAuthConfig, "id", "username", "password", "resource")
	if err != nil {
		t.Fatalf("adal: failed to create username/password ServicePrincipalToken: %v", err)
	}
	return spt
}
925
// newServicePrincipalTokenAuthorizationCode builds an authorization-code token
// with client id "id", secret "clientSecret", code "code", redirect URI
// "http://redirectUri/getToken" and resource "resource". A construction error
// now fails t instead of being discarded.
func newServicePrincipalTokenAuthorizationCode(t *testing.T) *ServicePrincipalToken {
	spt, err := NewServicePrincipalTokenFromAuthorizationCode(TestOAuthConfig, "id", "clientSecret", "code", "http://redirectUri/getToken", "resource")
	if err != nil {
		t.Fatalf("adal: failed to create authorization-code ServicePrincipalToken: %v", err)
	}
	return spt
}
930