1package adal 2 3// Copyright 2017 Microsoft Corporation 4// 5// Licensed under the Apache License, Version 2.0 (the "License"); 6// you may not use this file except in compliance with the License. 7// You may obtain a copy of the License at 8// 9// http://www.apache.org/licenses/LICENSE-2.0 10// 11// Unless required by applicable law or agreed to in writing, software 12// distributed under the License is distributed on an "AS IS" BASIS, 13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14// See the License for the specific language governing permissions and 15// limitations under the License. 16 17import ( 18 "context" 19 "crypto/rand" 20 "crypto/rsa" 21 "crypto/x509" 22 "crypto/x509/pkix" 23 "encoding/json" 24 "fmt" 25 "io/ioutil" 26 "math/big" 27 "net/http" 28 "net/url" 29 "os" 30 "reflect" 31 "strconv" 32 "strings" 33 "sync" 34 "testing" 35 "time" 36 37 "github.com/Azure/go-autorest/autorest/date" 38 "github.com/Azure/go-autorest/autorest/mocks" 39 jwt "github.com/dgrijalva/jwt-go" 40) 41 42const ( 43 defaultFormData = "client_id=id&client_secret=secret&grant_type=client_credentials&resource=resource" 44 defaultManualFormData = "client_id=id&grant_type=refresh_token&refresh_token=refreshtoken&resource=resource" 45) 46 47func TestTokenExpires(t *testing.T) { 48 tt := time.Now().Add(5 * time.Second) 49 tk := newTokenExpiresAt(tt) 50 51 if tk.Expires().Equal(tt) { 52 t.Fatalf("adal: Token#Expires miscalculated expiration time -- received %v, expected %v", tk.Expires(), tt) 53 } 54} 55 56func TestTokenIsExpired(t *testing.T) { 57 tk := newTokenExpiresAt(time.Now().Add(-5 * time.Second)) 58 59 if !tk.IsExpired() { 60 t.Fatalf("adal: Token#IsExpired failed to mark a stale token as expired -- now %v, token expires at %v", 61 time.Now().UTC(), tk.Expires()) 62 } 63} 64 65func TestTokenIsExpiredUninitialized(t *testing.T) { 66 tk := &Token{} 67 68 if !tk.IsExpired() { 69 t.Fatalf("adal: An uninitialized Token failed to mark itself as expired (expiration time %v)", tk.Expires()) 70 } 71} 72 73func TestTokenIsNoExpired(t *testing.T) { 74 tk := newTokenExpiresAt(time.Now().Add(1000 * time.Second)) 75 76 if tk.IsExpired() { 77 t.Fatalf("adal: Token marked a fresh token as expired -- now %v, token expires at %v", time.Now().UTC(), tk.Expires()) 78 } 79} 80 81func TestTokenWillExpireIn(t *testing.T) { 82 d := 5 * time.Second 83 tk := newTokenExpiresIn(d) 84 85 if !tk.WillExpireIn(d) { 86 t.Fatal("adal: Token#WillExpireIn mismeasured expiration time") 87 } 88} 89 90func TestServicePrincipalTokenSetAutoRefresh(t *testing.T) { 91 spt := newServicePrincipalToken() 92 93 if !spt.inner.AutoRefresh { 94 t.Fatal("adal: ServicePrincipalToken did not default to automatic token refreshing") 95 } 96 97 spt.SetAutoRefresh(false) 98 if spt.inner.AutoRefresh { 99 t.Fatal("adal: ServicePrincipalToken#SetAutoRefresh did not disable automatic token refreshing") 100 } 101} 102 103func TestServicePrincipalTokenSetCustomRefreshFunc(t *testing.T) { 104 spt := newServicePrincipalToken() 105 106 var refreshFunc TokenRefresh = func(context context.Context, resource string) (*Token, error) { 107 return nil, nil 108 } 109 110 if spt.customRefreshFunc != nil { 111 t.Fatalf("adal: ServicePrincipalToken#SetCustomRefreshFunc had a default custom refresh func when it shouldn't") 112 } 113 114 spt.SetCustomRefreshFunc(refreshFunc) 115 116 if spt.customRefreshFunc == nil { 117 t.Fatalf("adal: ServicePrincipalToken#SetCustomRefreshFunc didn't have a refresh func") 118 } 119} 120 121func TestServicePrincipalTokenSetRefreshWithin(t *testing.T) { 122 spt := newServicePrincipalToken() 123 124 if spt.inner.RefreshWithin != defaultRefresh { 125 t.Fatal("adal: ServicePrincipalToken did not correctly set the default refresh interval") 126 } 127 128 spt.SetRefreshWithin(2 * defaultRefresh) 129 if spt.inner.RefreshWithin != 2*defaultRefresh { 130 t.Fatal("adal: ServicePrincipalToken#SetRefreshWithin did not set the refresh interval") 131 } 132} 133 134func TestServicePrincipalTokenSetSender(t *testing.T) { 135 spt := newServicePrincipalToken() 136 137 c := &http.Client{} 138 spt.SetSender(c) 139 if !reflect.DeepEqual(c, spt.sender) { 140 t.Fatal("adal: ServicePrincipalToken#SetSender did not set the sender") 141 } 142} 143 144func TestServicePrincipalTokenRefreshUsesCustomRefreshFunc(t *testing.T) { 145 spt := newServicePrincipalToken() 146 147 called := false 148 var refreshFunc TokenRefresh = func(context context.Context, resource string) (*Token, error) { 149 called = true 150 return &Token{}, nil 151 } 152 spt.SetCustomRefreshFunc(refreshFunc) 153 if called { 154 t.Fatalf("adal: ServicePrincipalToken#refreshInternal called the refresh function prior to refreshing") 155 } 156 157 spt.refreshInternal(context.Background(), "https://example.com") 158 159 if !called { 160 t.Fatalf("adal: ServicePrincipalToken#refreshInternal didn't call the refresh function") 161 } 162} 163 164func TestServicePrincipalTokenRefreshUsesPOST(t *testing.T) { 165 spt := newServicePrincipalToken() 166 167 body := mocks.NewBody(newTokenJSON("12345", "test")) 168 resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK") 169 170 c := mocks.NewSender() 171 s := DecorateSender(c, 172 (func() SendDecorator { 173 return func(s Sender) Sender { 174 return SenderFunc(func(r *http.Request) (*http.Response, error) { 175 if r.Method != "POST" { 176 t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "POST", r.Method) 177 } 178 return resp, nil 179 }) 180 } 181 })()) 182 spt.SetSender(s) 183 err := spt.Refresh() 184 if err != nil { 185 t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err) 186 } 187 188 if body.IsOpen() { 189 t.Fatalf("the response was not closed!") 190 } 191} 192 193func TestServicePrincipalTokenFromMSIRefreshUsesGET(t *testing.T) { 194 resource := "https://resource" 195 cb := func(token Token) error { return nil } 196 197 spt, err := NewServicePrincipalTokenFromMSI("http://msiendpoint/", resource, cb) 198 if err != nil { 199 t.Fatalf("Failed to get MSI SPT: %v", err) 200 } 201 202 body := mocks.NewBody(newTokenJSON("12345", "test")) 203 resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK") 204 205 c := mocks.NewSender() 206 s := DecorateSender(c, 207 (func() SendDecorator { 208 return func(s Sender) Sender { 209 return SenderFunc(func(r *http.Request) (*http.Response, error) { 210 if r.Method != "GET" { 211 t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "GET", r.Method) 212 } 213 if h := r.Header.Get("Metadata"); h != "true" { 214 t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set Metadata header for MSI") 215 } 216 return resp, nil 217 }) 218 } 219 })()) 220 spt.SetSender(s) 221 err = spt.Refresh() 222 if err != nil { 223 t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err) 224 } 225 226 if body.IsOpen() { 227 t.Fatalf("the response was not closed!") 228 } 229} 230 231func TestServicePrincipalTokenFromMSIRefreshCancel(t *testing.T) { 232 ctx, cancel := context.WithCancel(context.Background()) 233 endpoint, _ := GetMSIVMEndpoint() 234 235 spt, err := NewServicePrincipalTokenFromMSI(endpoint, "https://resource") 236 if err != nil { 237 t.Fatalf("Failed to get MSI SPT: %v", err) 238 } 239 240 c := mocks.NewSender() 241 c.AppendAndRepeatResponse(mocks.NewResponseWithStatus("Internal server error", http.StatusInternalServerError), 5) 242 243 var wg sync.WaitGroup 244 wg.Add(1) 245 start := time.Now() 246 end := time.Now() 247 248 go func() { 249 spt.SetSender(c) 250 err = spt.RefreshWithContext(ctx) 251 end = time.Now() 252 wg.Done() 253 }() 254 255 cancel() 256 wg.Wait() 257 time.Sleep(5 * time.Millisecond) 258 259 if end.Sub(start) >= time.Second { 260 t.Fatalf("TestServicePrincipalTokenFromMSIRefreshCancel failed to cancel") 261 } 262} 263 264func TestServicePrincipalTokenRefreshSetsMimeType(t *testing.T) { 265 spt := newServicePrincipalToken() 266 267 body := mocks.NewBody(newTokenJSON("12345", "test")) 268 resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK") 269 270 c := mocks.NewSender() 271 s := DecorateSender(c, 272 (func() SendDecorator { 273 return func(s Sender) Sender { 274 return SenderFunc(func(r *http.Request) (*http.Response, error) { 275 if r.Header.Get(http.CanonicalHeaderKey("Content-Type")) != "application/x-www-form-urlencoded" { 276 t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set Content-Type -- expected %v, received %v", 277 "application/x-form-urlencoded", 278 r.Header.Get(http.CanonicalHeaderKey("Content-Type"))) 279 } 280 return resp, nil 281 }) 282 } 283 })()) 284 spt.SetSender(s) 285 err := spt.Refresh() 286 if err != nil { 287 t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err) 288 } 289} 290 291func TestServicePrincipalTokenRefreshSetsURL(t *testing.T) { 292 spt := newServicePrincipalToken() 293 294 body := mocks.NewBody(newTokenJSON("12345", "test")) 295 resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK") 296 297 c := mocks.NewSender() 298 s := DecorateSender(c, 299 (func() SendDecorator { 300 return func(s Sender) Sender { 301 return SenderFunc(func(r *http.Request) (*http.Response, error) { 302 if r.URL.String() != TestOAuthConfig.TokenEndpoint.String() { 303 t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the URL -- expected %v, received %v", 304 TestOAuthConfig.TokenEndpoint, r.URL) 305 } 306 return resp, nil 307 }) 308 } 309 })()) 310 spt.SetSender(s) 311 err := spt.Refresh() 312 if err != nil { 313 t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err) 314 } 315} 316 317func testServicePrincipalTokenRefreshSetsBody(t *testing.T, spt *ServicePrincipalToken, f func(*testing.T, []byte)) { 318 body := mocks.NewBody(newTokenJSON("12345", "test")) 319 resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK") 320 321 c := mocks.NewSender() 322 s := DecorateSender(c, 323 (func() SendDecorator { 324 return func(s Sender) Sender { 325 return SenderFunc(func(r *http.Request) (*http.Response, error) { 326 b, err := ioutil.ReadAll(r.Body) 327 if err != nil { 328 t.Fatalf("adal: Failed to read body of Service Principal token request (%v)", err) 329 } 330 f(t, b) 331 return resp, nil 332 }) 333 } 334 })()) 335 spt.SetSender(s) 336 err := spt.Refresh() 337 if err != nil { 338 t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err) 339 } 340} 341 342func TestServicePrincipalTokenManualRefreshSetsBody(t *testing.T) { 343 sptManual := newServicePrincipalTokenManual() 344 testServicePrincipalTokenRefreshSetsBody(t, sptManual, func(t *testing.T, b []byte) { 345 if string(b) != defaultManualFormData { 346 t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the HTTP Request Body -- expected %v, received %v", 347 defaultManualFormData, string(b)) 348 } 349 }) 350} 351 352func TestServicePrincipalTokenCertficateRefreshSetsBody(t *testing.T) { 353 sptCert := newServicePrincipalTokenCertificate(t) 354 testServicePrincipalTokenRefreshSetsBody(t, sptCert, func(t *testing.T, b []byte) { 355 body := string(b) 356 357 values, _ := url.ParseQuery(body) 358 if values["client_assertion_type"][0] != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" || 359 values["client_id"][0] != "id" || 360 values["grant_type"][0] != "client_credentials" || 361 values["resource"][0] != "resource" { 362 t.Fatalf("adal: ServicePrincipalTokenCertificate#Refresh did not correctly set the HTTP Request Body.") 363 } 364 365 tok, _ := jwt.Parse(values["client_assertion"][0], nil) 366 if tok == nil { 367 t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to be a JWT") 368 } 369 if _, ok := tok.Header["x5t"]; !ok { 370 t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to have an x5t header") 371 } 372 if _, ok := tok.Header["x5c"]; !ok { 373 t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to have an x5c header") 374 } 375 }) 376} 377 378func TestServicePrincipalTokenUsernamePasswordRefreshSetsBody(t *testing.T) { 379 spt := newServicePrincipalTokenUsernamePassword(t) 380 testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) { 381 body := string(b) 382 383 values, _ := url.ParseQuery(body) 384 if values["client_id"][0] != "id" || 385 values["grant_type"][0] != "password" || 386 values["username"][0] != "username" || 387 values["password"][0] != "password" || 388 values["resource"][0] != "resource" { 389 t.Fatalf("adal: ServicePrincipalTokenUsernamePassword#Refresh did not correctly set the HTTP Request Body.") 390 } 391 }) 392} 393 394func TestServicePrincipalTokenAuthorizationCodeRefreshSetsBody(t *testing.T) { 395 spt := newServicePrincipalTokenAuthorizationCode(t) 396 testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) { 397 body := string(b) 398 399 values, _ := url.ParseQuery(body) 400 if values["client_id"][0] != "id" || 401 values["grant_type"][0] != OAuthGrantTypeAuthorizationCode || 402 values["code"][0] != "code" || 403 values["client_secret"][0] != "clientSecret" || 404 values["redirect_uri"][0] != "http://redirectUri/getToken" || 405 values["resource"][0] != "resource" { 406 t.Fatalf("adal: ServicePrincipalTokenAuthorizationCode#Refresh did not correctly set the HTTP Request Body.") 407 } 408 }) 409 testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) { 410 body := string(b) 411 412 values, _ := url.ParseQuery(body) 413 if values["client_id"][0] != "id" || 414 values["grant_type"][0] != OAuthGrantTypeRefreshToken || 415 values["code"][0] != "code" || 416 values["client_secret"][0] != "clientSecret" || 417 values["redirect_uri"][0] != "http://redirectUri/getToken" || 418 values["resource"][0] != "resource" { 419 t.Fatalf("adal: ServicePrincipalTokenAuthorizationCode#Refresh did not correctly set the HTTP Request Body.") 420 } 421 }) 422} 423 424func TestServicePrincipalTokenSecretRefreshSetsBody(t *testing.T) { 425 spt := newServicePrincipalToken() 426 testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) { 427 if string(b) != defaultFormData { 428 t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the HTTP Request Body -- expected %v, received %v", 429 defaultFormData, string(b)) 430 } 431 432 }) 433} 434 435func TestServicePrincipalTokenRefreshClosesRequestBody(t *testing.T) { 436 spt := newServicePrincipalToken() 437 438 body := mocks.NewBody(newTokenJSON("12345", "test")) 439 resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK") 440 441 c := mocks.NewSender() 442 s := DecorateSender(c, 443 (func() SendDecorator { 444 return func(s Sender) Sender { 445 return SenderFunc(func(r *http.Request) (*http.Response, error) { 446 return resp, nil 447 }) 448 } 449 })()) 450 spt.SetSender(s) 451 err := spt.Refresh() 452 if err != nil { 453 t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err) 454 } 455 if resp.Body.(*mocks.Body).IsOpen() { 456 t.Fatal("adal: ServicePrincipalToken#Refresh failed to close the HTTP Response Body") 457 } 458} 459 460func TestServicePrincipalTokenRefreshRejectsResponsesWithStatusNotOK(t *testing.T) { 461 spt := newServicePrincipalToken() 462 463 body := mocks.NewBody(newTokenJSON("12345", "test")) 464 resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusUnauthorized, "Unauthorized") 465 466 c := mocks.NewSender() 467 s := DecorateSender(c, 468 (func() SendDecorator { 469 return func(s Sender) Sender { 470 return SenderFunc(func(r *http.Request) (*http.Response, error) { 471 return resp, nil 472 }) 473 } 474 })()) 475 spt.SetSender(s) 476 err := spt.Refresh() 477 if err == nil { 478 t.Fatalf("adal: ServicePrincipalToken#Refresh should reject a response with status != %d", http.StatusOK) 479 } 480} 481 482func TestServicePrincipalTokenRefreshRejectsEmptyBody(t *testing.T) { 483 spt := newServicePrincipalToken() 484 485 c := mocks.NewSender() 486 s := DecorateSender(c, 487 (func() SendDecorator { 488 return func(s Sender) Sender { 489 return SenderFunc(func(r *http.Request) (*http.Response, error) { 490 return mocks.NewResponse(), nil 491 }) 492 } 493 })()) 494 spt.SetSender(s) 495 err := spt.Refresh() 496 if err == nil { 497 t.Fatal("adal: ServicePrincipalToken#Refresh should reject an empty token") 498 } 499} 500 501func TestServicePrincipalTokenRefreshPropagatesErrors(t *testing.T) { 502 spt := newServicePrincipalToken() 503 504 c := mocks.NewSender() 505 c.SetError(fmt.Errorf("Faux Error")) 506 spt.SetSender(c) 507 508 err := spt.Refresh() 509 if err == nil { 510 t.Fatal("adal: Failed to propagate the request error") 511 } 512} 513 514func TestServicePrincipalTokenRefreshReturnsErrorIfNotOk(t *testing.T) { 515 spt := newServicePrincipalToken() 516 517 c := mocks.NewSender() 518 c.AppendResponse(mocks.NewResponseWithStatus("401 NotAuthorized", http.StatusUnauthorized)) 519 spt.SetSender(c) 520 521 err := spt.Refresh() 522 if err == nil { 523 t.Fatalf("adal: Failed to return an when receiving a status code other than HTTP %d", http.StatusOK) 524 } 525} 526 527func TestServicePrincipalTokenRefreshUnmarshals(t *testing.T) { 528 spt := newServicePrincipalToken() 529 530 expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds())) 531 j := newTokenJSON(expiresOn, "resource") 532 resp := mocks.NewResponseWithContent(j) 533 c := mocks.NewSender() 534 s := DecorateSender(c, 535 (func() SendDecorator { 536 return func(s Sender) Sender { 537 return SenderFunc(func(r *http.Request) (*http.Response, error) { 538 return resp, nil 539 }) 540 } 541 })()) 542 spt.SetSender(s) 543 544 err := spt.Refresh() 545 if err != nil { 546 t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err) 547 } else if spt.inner.Token.AccessToken != "accessToken" || 548 spt.inner.Token.ExpiresIn != "3600" || 549 spt.inner.Token.ExpiresOn != json.Number(expiresOn) || 550 spt.inner.Token.NotBefore != json.Number(expiresOn) || 551 spt.inner.Token.Resource != "resource" || 552 spt.inner.Token.Type != "Bearer" { 553 t.Fatalf("adal: ServicePrincipalToken#Refresh failed correctly unmarshal the JSON -- expected %v, received %v", 554 j, *spt) 555 } 556} 557 558func TestServicePrincipalTokenEnsureFreshRefreshes(t *testing.T) { 559 spt := newServicePrincipalToken() 560 expireToken(&spt.inner.Token) 561 562 body := mocks.NewBody(newTokenJSON("12345", "test")) 563 resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK") 564 565 f := false 566 c := mocks.NewSender() 567 s := DecorateSender(c, 568 (func() SendDecorator { 569 return func(s Sender) Sender { 570 return SenderFunc(func(r *http.Request) (*http.Response, error) { 571 f = true 572 return resp, nil 573 }) 574 } 575 })()) 576 spt.SetSender(s) 577 err := spt.EnsureFresh() 578 if err != nil { 579 t.Fatalf("adal: ServicePrincipalToken#EnsureFresh returned an unexpected error (%v)", err) 580 } 581 if !f { 582 t.Fatal("adal: ServicePrincipalToken#EnsureFresh failed to call Refresh for stale token") 583 } 584} 585 586func TestServicePrincipalTokenEnsureFreshFails1(t *testing.T) { 587 spt := newServicePrincipalToken() 588 expireToken(&spt.inner.Token) 589 590 c := mocks.NewSender() 591 c.SetError(fmt.Errorf("some failure")) 592 593 spt.SetSender(c) 594 err := spt.EnsureFresh() 595 if err == nil { 596 t.Fatal("adal: ServicePrincipalToken#EnsureFresh didn't return an error") 597 } 598 if _, ok := err.(TokenRefreshError); ok { 599 t.Fatal("adal: ServicePrincipalToken#EnsureFresh unexpected TokenRefreshError") 600 } 601} 602 603func TestServicePrincipalTokenEnsureFreshFails2(t *testing.T) { 604 spt := newServicePrincipalToken() 605 expireToken(&spt.inner.Token) 606 607 c := mocks.NewSender() 608 c.AppendResponse(mocks.NewResponseWithStatus("bad request", http.StatusBadRequest)) 609 610 spt.SetSender(c) 611 err := spt.EnsureFresh() 612 if err == nil { 613 t.Fatal("adal: ServicePrincipalToken#EnsureFresh didn't return an error") 614 } 615 if _, ok := err.(TokenRefreshError); !ok { 616 t.Fatal("adal: ServicePrincipalToken#EnsureFresh didn't return a TokenRefreshError") 617 } 618} 619 620func TestServicePrincipalTokenEnsureFreshSkipsIfFresh(t *testing.T) { 621 spt := newServicePrincipalToken() 622 setTokenToExpireIn(&spt.inner.Token, 1000*time.Second) 623 624 f := false 625 c := mocks.NewSender() 626 s := DecorateSender(c, 627 (func() SendDecorator { 628 return func(s Sender) Sender { 629 return SenderFunc(func(r *http.Request) (*http.Response, error) { 630 f = true 631 return mocks.NewResponse(), nil 632 }) 633 } 634 })()) 635 spt.SetSender(s) 636 err := spt.EnsureFresh() 637 if err != nil { 638 t.Fatalf("adal: ServicePrincipalToken#EnsureFresh returned an unexpected error (%v)", err) 639 } 640 if f { 641 t.Fatal("adal: ServicePrincipalToken#EnsureFresh invoked Refresh for fresh token") 642 } 643} 644 645func TestRefreshCallback(t *testing.T) { 646 callbackTriggered := false 647 spt := newServicePrincipalToken(func(Token) error { 648 callbackTriggered = true 649 return nil 650 }) 651 652 expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds())) 653 654 sender := mocks.NewSender() 655 j := newTokenJSON(expiresOn, "resource") 656 sender.AppendResponse(mocks.NewResponseWithContent(j)) 657 spt.SetSender(sender) 658 err := spt.Refresh() 659 if err != nil { 660 t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err) 661 } 662 if !callbackTriggered { 663 t.Fatalf("adal: RefreshCallback failed to trigger call callback") 664 } 665} 666 667func TestRefreshCallbackErrorPropagates(t *testing.T) { 668 errorText := "this is an error text" 669 spt := newServicePrincipalToken(func(Token) error { 670 return fmt.Errorf(errorText) 671 }) 672 673 expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds())) 674 675 sender := mocks.NewSender() 676 j := newTokenJSON(expiresOn, "resource") 677 sender.AppendResponse(mocks.NewResponseWithContent(j)) 678 spt.SetSender(sender) 679 err := spt.Refresh() 680 681 if err == nil || !strings.Contains(err.Error(), errorText) { 682 t.Fatalf("adal: RefreshCallback failed to propagate error") 683 } 684} 685 686// This demonstrates the danger of manual token without a refresh token 687func TestServicePrincipalTokenManualRefreshFailsWithoutRefresh(t *testing.T) { 688 spt := newServicePrincipalTokenManual() 689 spt.inner.Token.RefreshToken = "" 690 err := spt.Refresh() 691 if err == nil { 692 t.Fatalf("adal: ServicePrincipalToken#Refresh should have failed with a ManualTokenSecret without a refresh token") 693 } 694} 695 696func TestNewServicePrincipalTokenFromMSI(t *testing.T) { 697 resource := "https://resource" 698 cb := func(token Token) error { return nil } 699 700 spt, err := NewServicePrincipalTokenFromMSI("http://msiendpoint/", resource, cb) 701 if err != nil { 702 t.Fatalf("Failed to get MSI SPT: %v", err) 703 } 704 705 // check some of the SPT fields 706 if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); !ok { 707 t.Fatal("SPT secret was not of MSI type") 708 } 709 710 if spt.inner.Resource != resource { 711 t.Fatal("SPT came back with incorrect resource") 712 } 713 714 if len(spt.refreshCallbacks) != 1 { 715 t.Fatal("SPT had incorrect refresh callbacks.") 716 } 717} 718 719func TestNewServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) { 720 resource := "https://resource" 721 userID := "abc123" 722 cb := func(token Token) error { return nil } 723 724 spt, err := NewServicePrincipalTokenFromMSIWithUserAssignedID("http://msiendpoint/", resource, userID, cb) 725 if err != nil { 726 t.Fatalf("Failed to get MSI SPT: %v", err) 727 } 728 729 // check some of the SPT fields 730 if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); !ok { 731 t.Fatal("SPT secret was not of MSI type") 732 } 733 734 if spt.inner.Resource != resource { 735 t.Fatal("SPT came back with incorrect resource") 736 } 737 738 if len(spt.refreshCallbacks) != 1 { 739 t.Fatal("SPT had incorrect refresh callbacks.") 740 } 741 742 if spt.inner.ClientID != userID { 743 t.Fatal("SPT had incorrect client ID") 744 } 745} 746 747func TestNewServicePrincipalTokenFromManualTokenSecret(t *testing.T) { 748 token := newToken() 749 secret := &ServicePrincipalAuthorizationCodeSecret{ 750 ClientSecret: "clientSecret", 751 AuthorizationCode: "code123", 752 RedirectURI: "redirect", 753 } 754 755 spt, err := NewServicePrincipalTokenFromManualTokenSecret(TestOAuthConfig, "id", "resource", token, secret, nil) 756 if err != nil { 757 t.Fatalf("Failed creating new SPT: %s", err) 758 } 759 760 if !reflect.DeepEqual(token, spt.inner.Token) { 761 t.Fatalf("Tokens do not match: %s, %s", token, spt.inner.Token) 762 } 763 764 if !reflect.DeepEqual(secret, spt.inner.Secret) { 765 t.Fatalf("Secrets do not match: %s, %s", secret, spt.inner.Secret) 766 } 767 768} 769 770func TestGetVMEndpoint(t *testing.T) { 771 endpoint, err := GetMSIVMEndpoint() 772 if err != nil { 773 t.Fatal("Coudn't get VM endpoint") 774 } 775 776 if endpoint != msiEndpoint { 777 t.Fatal("Didn't get correct endpoint") 778 } 779} 780 781func TestGetAppServiceEndpoint(t *testing.T) { 782 const testEndpoint = "http://172.16.1.2:8081/msi/token" 783 if err := os.Setenv(asMSIEndpointEnv, testEndpoint); err != nil { 784 t.Fatalf("os.Setenv: %v", err) 785 } 786 787 endpoint, err := GetMSIAppServiceEndpoint() 788 if err != nil { 789 t.Fatal("Coudn't get App Service endpoint") 790 } 791 792 if endpoint != testEndpoint { 793 t.Fatal("Didn't get correct endpoint") 794 } 795 796 if err := os.Unsetenv(asMSIEndpointEnv); err != nil { 797 t.Fatalf("os.Unsetenv: %v", err) 798 } 799} 800 801func TestGetMSIEndpoint(t *testing.T) { 802 const ( 803 testEndpoint = "http://172.16.1.2:8081/msi/token" 804 testSecret = "DEADBEEF-BBBB-AAAA-DDDD-DDD000000DDD" 805 ) 806 807 // Test VM well-known endpoint is returned 808 if err := os.Unsetenv(asMSIEndpointEnv); err != nil { 809 t.Fatalf("os.Unsetenv: %v", err) 810 } 811 812 if err := os.Unsetenv(asMSISecretEnv); err != nil { 813 t.Fatalf("os.Unsetenv: %v", err) 814 } 815 816 vmEndpoint, err := GetMSIEndpoint() 817 if err != nil { 818 t.Fatal("Coudn't get VM endpoint") 819 } 820 821 if vmEndpoint != msiEndpoint { 822 t.Fatal("Didn't get correct endpoint") 823 } 824 825 // Test App Service endpoint is returned 826 if err := os.Setenv(asMSIEndpointEnv, testEndpoint); err != nil { 827 t.Fatalf("os.Setenv: %v", err) 828 } 829 830 if err := os.Setenv(asMSISecretEnv, testSecret); err != nil { 831 t.Fatalf("os.Setenv: %v", err) 832 } 833 834 asEndpoint, err := GetMSIEndpoint() 835 if err != nil { 836 t.Fatal("Coudn't get App Service endpoint") 837 } 838 839 if asEndpoint != testEndpoint { 840 t.Fatal("Didn't get correct endpoint") 841 } 842 843 if err := os.Unsetenv(asMSIEndpointEnv); err != nil { 844 t.Fatalf("os.Unsetenv: %v", err) 845 } 846 847 if err := os.Unsetenv(asMSISecretEnv); err != nil { 848 t.Fatalf("os.Unsetenv: %v", err) 849 } 850} 851 852func TestMarshalServicePrincipalNoSecret(t *testing.T) { 853 spt := newServicePrincipalTokenManual() 854 b, err := json.Marshal(spt) 855 if err != nil { 856 t.Fatalf("failed to marshal token: %+v", err) 857 } 858 var spt2 *ServicePrincipalToken 859 err = json.Unmarshal(b, &spt2) 860 if err != nil { 861 t.Fatalf("failed to unmarshal token: %+v", err) 862 } 863 if !reflect.DeepEqual(spt, spt2) { 864 t.Fatal("tokens don't match") 865 } 866} 867 868func TestMarshalServicePrincipalTokenSecret(t *testing.T) { 869 spt := newServicePrincipalToken() 870 b, err := json.Marshal(spt) 871 if err != nil { 872 t.Fatalf("failed to marshal token: %+v", err) 873 } 874 var spt2 *ServicePrincipalToken 875 err = json.Unmarshal(b, &spt2) 876 if err != nil { 877 t.Fatalf("failed to unmarshal token: %+v", err) 878 } 879 if !reflect.DeepEqual(spt, spt2) { 880 t.Fatal("tokens don't match") 881 } 882} 883 884func TestMarshalServicePrincipalCertificateSecret(t *testing.T) { 885 spt := newServicePrincipalTokenCertificate(t) 886 b, err := json.Marshal(spt) 887 if err == nil { 888 t.Fatal("expected error when marshalling certificate token") 889 } 890 var spt2 *ServicePrincipalToken 891 err = json.Unmarshal(b, &spt2) 892 if err == nil { 893 t.Fatal("expected error when unmarshalling certificate token") 894 } 895} 896 897func TestMarshalServicePrincipalMSISecret(t *testing.T) { 898 spt, err := newServicePrincipalTokenFromMSI("http://msiendpoint/", "https://resource", nil) 899 if err != nil { 900 t.Fatalf("failed to get MSI SPT: %+v", err) 901 } 902 b, err := json.Marshal(spt) 903 if err == nil { 904 t.Fatal("expected error when marshalling MSI token") 905 } 906 var spt2 *ServicePrincipalToken 907 err = json.Unmarshal(b, &spt2) 908 if err == nil { 909 t.Fatal("expected error when unmarshalling MSI token") 910 } 911} 912 913func TestMarshalServicePrincipalUsernamePasswordSecret(t *testing.T) { 914 spt := newServicePrincipalTokenUsernamePassword(t) 915 b, err := json.Marshal(spt) 916 if err != nil { 917 t.Fatalf("failed to marshal token: %+v", err) 918 } 919 var spt2 *ServicePrincipalToken 920 err = json.Unmarshal(b, &spt2) 921 if err != nil { 922 t.Fatalf("failed to unmarshal token: %+v", err) 923 } 924 if !reflect.DeepEqual(spt, spt2) { 925 t.Fatal("tokens don't match") 926 } 927} 928 929func TestMarshalServicePrincipalAuthorizationCodeSecret(t *testing.T) { 930 spt := newServicePrincipalTokenAuthorizationCode(t) 931 b, err := json.Marshal(spt) 932 if err != nil { 933 t.Fatalf("failed to marshal token: %+v", err) 934 } 935 var spt2 *ServicePrincipalToken 936 err = json.Unmarshal(b, &spt2) 937 if err != nil { 938 t.Fatalf("failed to unmarshal token: %+v", err) 939 } 940 if !reflect.DeepEqual(spt, spt2) { 941 t.Fatal("tokens don't match") 942 } 943} 944 945func TestMarshalInnerToken(t *testing.T) { 946 spt := newServicePrincipalTokenManual() 947 tokenJSON, err := spt.MarshalTokenJSON() 948 if err != nil { 949 t.Fatalf("failed to marshal token: %+v", err) 950 } 951 952 testToken := newToken() 953 testToken.RefreshToken = "refreshtoken" 954 955 testTokenJSON, err := json.Marshal(testToken) 956 if err != nil { 957 t.Fatalf("failed to marshal test token: %+v", err) 958 } 959 960 if !reflect.DeepEqual(tokenJSON, testTokenJSON) { 961 t.Fatalf("tokens don't match: %s, %s", tokenJSON, testTokenJSON) 962 } 963 964 var t1 Token 965 err = json.Unmarshal(tokenJSON, &t1) 966 if err != nil { 967 t.Fatalf("failed to unmarshal token: %+v", err) 968 } 969 970 if !reflect.DeepEqual(t1, testToken) { 971 t.Fatalf("tokens don't match: %s, %s", t1, testToken) 972 } 973} 974 975func TestNewMultiTenantServicePrincipalToken(t *testing.T) { 976 cfg, err := NewMultiTenantOAuthConfig(TestActiveDirectoryEndpoint, TestTenantID, TestAuxTenantIDs, OAuthOptions{}) 977 if err != nil { 978 t.Fatalf("autorest/adal: unexpected error while creating multitenant config: %v", err) 979 } 980 mt, err := NewMultiTenantServicePrincipalToken(cfg, "clientID", "superSecret", "resource") 981 if err != nil { 982 t.Fatalf("autorest/adal: unexpected error while creating multitenant service principal token: %v", err) 983 } 984 if !strings.Contains(mt.PrimaryToken.inner.OauthConfig.AuthorizeEndpoint.String(), TestTenantID) { 985 t.Fatal("didn't find primary tenant ID in primary SPT") 986 } 987 for i := range mt.AuxiliaryTokens { 988 if ep := mt.AuxiliaryTokens[i].inner.OauthConfig.AuthorizeEndpoint.String(); !strings.Contains(ep, fmt.Sprintf("%s%d", TestAuxTenantPrefix, i)) { 989 t.Fatalf("didn't find auxiliary tenant ID in token %s", ep) 990 } 991 } 992} 993 994func newTokenJSON(expiresOn string, resource string) string { 995 return fmt.Sprintf(`{ 996 "access_token" : "accessToken", 997 "expires_in" : "3600", 998 "expires_on" : "%s", 999 "not_before" : "%s", 1000 "resource" : "%s", 1001 "token_type" : "Bearer", 1002 "refresh_token": "ABC123" 1003 }`, 1004 expiresOn, expiresOn, resource) 1005} 1006 1007func newTokenExpiresIn(expireIn time.Duration) *Token { 1008 t := newToken() 1009 return setTokenToExpireIn(&t, expireIn) 1010} 1011 1012func newTokenExpiresAt(expireAt time.Time) *Token { 1013 t := newToken() 1014 return setTokenToExpireAt(&t, expireAt) 1015} 1016 1017func expireToken(t *Token) *Token { 1018 return setTokenToExpireIn(t, 0) 1019} 1020 1021func setTokenToExpireAt(t *Token, expireAt time.Time) *Token { 1022 t.ExpiresIn = "3600" 1023 t.ExpiresOn = json.Number(strconv.FormatInt(int64(expireAt.Sub(date.UnixEpoch())/time.Second), 10)) 1024 t.NotBefore = t.ExpiresOn 1025 return t 1026} 1027 1028func setTokenToExpireIn(t *Token, expireIn time.Duration) *Token { 1029 return setTokenToExpireAt(t, time.Now().Add(expireIn)) 1030} 1031 1032func newServicePrincipalToken(callbacks ...TokenRefreshCallback) *ServicePrincipalToken { 1033 spt, _ := NewServicePrincipalToken(TestOAuthConfig, "id", "secret", "resource", callbacks...) 1034 return spt 1035} 1036 1037func newServicePrincipalTokenManual() *ServicePrincipalToken { 1038 token := newToken() 1039 token.RefreshToken = "refreshtoken" 1040 spt, _ := NewServicePrincipalTokenFromManualToken(TestOAuthConfig, "id", "resource", token) 1041 return spt 1042} 1043 1044func newServicePrincipalTokenCertificate(t *testing.T) *ServicePrincipalToken { 1045 template := x509.Certificate{ 1046 SerialNumber: big.NewInt(0), 1047 Subject: pkix.Name{CommonName: "test"}, 1048 BasicConstraintsValid: true, 1049 } 1050 privateKey, err := rsa.GenerateKey(rand.Reader, 2048) 1051 if err != nil { 1052 t.Fatal(err) 1053 } 1054 certificateBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) 1055 if err != nil { 1056 t.Fatal(err) 1057 } 1058 certificate, err := x509.ParseCertificate(certificateBytes) 1059 if err != nil { 1060 t.Fatal(err) 1061 } 1062 1063 spt, _ := NewServicePrincipalTokenFromCertificate(TestOAuthConfig, "id", certificate, privateKey, "resource") 1064 return spt 1065} 1066 1067func newServicePrincipalTokenUsernamePassword(t *testing.T) *ServicePrincipalToken { 1068 spt, _ := NewServicePrincipalTokenFromUsernamePassword(TestOAuthConfig, "id", "username", "password", "resource") 1069 return spt 1070} 1071 1072func newServicePrincipalTokenAuthorizationCode(t *testing.T) *ServicePrincipalToken { 1073 spt, _ := NewServicePrincipalTokenFromAuthorizationCode(TestOAuthConfig, "id", "clientSecret", "code", "http://redirectUri/getToken", "resource") 1074 return spt 1075} 1076