1package processcreds_test 2 3import ( 4 "bytes" 5 "encoding/json" 6 "fmt" 7 "io" 8 "io/ioutil" 9 "os" 10 "os/exec" 11 "runtime" 12 "strings" 13 "testing" 14 "time" 15 16 "github.com/aws/aws-sdk-go/aws" 17 "github.com/aws/aws-sdk-go/aws/awserr" 18 "github.com/aws/aws-sdk-go/aws/credentials/processcreds" 19 "github.com/aws/aws-sdk-go/aws/session" 20 "github.com/aws/aws-sdk-go/internal/sdktesting" 21) 22 23func TestProcessProviderFromSessionCfg(t *testing.T) { 24 restoreEnvFn := sdktesting.StashEnv() 25 defer restoreEnvFn() 26 27 os.Setenv("AWS_SDK_LOAD_CONFIG", "1") 28 if runtime.GOOS == "windows" { 29 os.Setenv("AWS_CONFIG_FILE", "testdata\\shconfig_win.ini") 30 } else { 31 os.Setenv("AWS_CONFIG_FILE", "testdata/shconfig.ini") 32 } 33 34 sess, err := session.NewSession(&aws.Config{ 35 Region: aws.String("region")}, 36 ) 37 38 if err != nil { 39 t.Errorf("error getting session: %v", err) 40 } 41 42 creds, err := sess.Config.Credentials.Get() 43 if err != nil { 44 t.Errorf("error getting credentials: %v", err) 45 } 46 47 if e, a := "accessKey", creds.AccessKeyID; e != a { 48 t.Errorf("expected %v, got %v", e, a) 49 } 50 51 if e, a := "secret", creds.SecretAccessKey; e != a { 52 t.Errorf("expected %v, got %v", e, a) 53 } 54 55 if e, a := "tokenDefault", creds.SessionToken; e != a { 56 t.Errorf("expected %v, got %v", e, a) 57 } 58 59} 60 61func TestProcessProviderFromSessionWithProfileCfg(t *testing.T) { 62 restoreEnvFn := sdktesting.StashEnv() 63 defer restoreEnvFn() 64 65 os.Setenv("AWS_SDK_LOAD_CONFIG", "1") 66 os.Setenv("AWS_PROFILE", "non_expire") 67 if runtime.GOOS == "windows" { 68 os.Setenv("AWS_CONFIG_FILE", "testdata\\shconfig_win.ini") 69 } else { 70 os.Setenv("AWS_CONFIG_FILE", "testdata/shconfig.ini") 71 } 72 73 sess, err := session.NewSession(&aws.Config{ 74 Region: aws.String("region")}, 75 ) 76 77 if err != nil { 78 t.Errorf("error getting session: %v", err) 79 } 80 81 creds, err := sess.Config.Credentials.Get() 82 if err != nil { 83 t.Errorf("error getting credentials: %v", err) 84 } 85 86 if e, a := "nonDefaultToken", creds.SessionToken; e != a { 87 t.Errorf("expected %v, got %v", e, a) 88 } 89 90} 91 92func TestProcessProviderNotFromCredProcCfg(t *testing.T) { 93 restoreEnvFn := sdktesting.StashEnv() 94 defer restoreEnvFn() 95 96 os.Setenv("AWS_SDK_LOAD_CONFIG", "1") 97 os.Setenv("AWS_PROFILE", "not_alone") 98 if runtime.GOOS == "windows" { 99 os.Setenv("AWS_CONFIG_FILE", "testdata\\shconfig_win.ini") 100 } else { 101 os.Setenv("AWS_CONFIG_FILE", "testdata/shconfig.ini") 102 } 103 104 sess, err := session.NewSession(&aws.Config{ 105 Region: aws.String("region")}, 106 ) 107 108 if err != nil { 109 t.Errorf("error getting session: %v", err) 110 } 111 112 creds, err := sess.Config.Credentials.Get() 113 if err != nil { 114 t.Errorf("error getting credentials: %v", err) 115 } 116 117 if e, a := "notFromCredProcAccess", creds.AccessKeyID; e != a { 118 t.Errorf("expected %v, got %v", e, a) 119 } 120 121 if e, a := "notFromCredProcSecret", creds.SecretAccessKey; e != a { 122 t.Errorf("expected %v, got %v", e, a) 123 } 124 125} 126 127func TestProcessProviderFromSessionCrd(t *testing.T) { 128 restoreEnvFn := sdktesting.StashEnv() 129 defer restoreEnvFn() 130 131 if runtime.GOOS == "windows" { 132 os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata\\shcred_win.ini") 133 } else { 134 os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata/shcred.ini") 135 } 136 137 sess, err := session.NewSession(&aws.Config{ 138 Region: aws.String("region")}, 139 ) 140 141 if err != nil { 142 t.Errorf("error getting session: %v", err) 143 } 144 145 creds, err := sess.Config.Credentials.Get() 146 if err != nil { 147 t.Errorf("error getting credentials: %v", err) 148 } 149 150 if e, a := "accessKey", creds.AccessKeyID; e != a { 151 t.Errorf("expected %v, got %v", e, a) 152 } 153 154 if e, a := "secret", creds.SecretAccessKey; e != a { 155 t.Errorf("expected %v, got %v", e, a) 156 } 157 158 if e, a := "tokenDefault", creds.SessionToken; e != a { 159 t.Errorf("expected %v, got %v", e, a) 160 } 161 162} 163 164func TestProcessProviderFromSessionWithProfileCrd(t *testing.T) { 165 restoreEnvFn := sdktesting.StashEnv() 166 defer restoreEnvFn() 167 168 os.Setenv("AWS_PROFILE", "non_expire") 169 if runtime.GOOS == "windows" { 170 os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata\\shcred_win.ini") 171 } else { 172 os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata/shcred.ini") 173 } 174 175 sess, err := session.NewSession(&aws.Config{ 176 Region: aws.String("region")}, 177 ) 178 179 if err != nil { 180 t.Errorf("error getting session: %v", err) 181 } 182 183 creds, err := sess.Config.Credentials.Get() 184 if err != nil { 185 t.Errorf("error getting credentials: %v", err) 186 } 187 188 if e, a := "nonDefaultToken", creds.SessionToken; e != a { 189 t.Errorf("expected %v, got %v", e, a) 190 } 191 192} 193 194func TestProcessProviderNotFromCredProcCrd(t *testing.T) { 195 restoreEnvFn := sdktesting.StashEnv() 196 defer restoreEnvFn() 197 198 os.Setenv("AWS_PROFILE", "not_alone") 199 if runtime.GOOS == "windows" { 200 os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata\\shcred_win.ini") 201 } else { 202 os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata/shcred.ini") 203 } 204 205 sess, err := session.NewSession(&aws.Config{ 206 Region: aws.String("region")}, 207 ) 208 209 if err != nil { 210 t.Errorf("error getting session: %v", err) 211 } 212 213 creds, err := sess.Config.Credentials.Get() 214 if err != nil { 215 t.Errorf("error getting credentials: %v", err) 216 } 217 218 if e, a := "notFromCredProcAccess", creds.AccessKeyID; e != a { 219 t.Errorf("expected %v, got %v", e, a) 220 } 221 222 if e, a := "notFromCredProcSecret", creds.SecretAccessKey; e != a { 223 t.Errorf("expected %v, got %v", e, a) 224 } 225 226} 227 228func TestProcessProviderBadCommand(t *testing.T) { 229 restoreEnvFn := sdktesting.StashEnv() 230 defer restoreEnvFn() 231 232 creds := processcreds.NewCredentials("/bad/process") 233 _, err := creds.Get() 234 if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderExecution { 235 t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderExecution, err) 236 } 237} 238 239func TestProcessProviderMoreEmptyCommands(t *testing.T) { 240 restoreEnvFn := sdktesting.StashEnv() 241 defer restoreEnvFn() 242 243 creds := processcreds.NewCredentials("") 244 _, err := creds.Get() 245 if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderExecution { 246 t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderExecution, err) 247 } 248 249} 250 251func TestProcessProviderExpectErrors(t *testing.T) { 252 restoreEnvFn := sdktesting.StashEnv() 253 defer restoreEnvFn() 254 255 creds := processcreds.NewCredentials( 256 fmt.Sprintf( 257 "%s %s", 258 getOSCat(), 259 strings.Join( 260 []string{"testdata", "malformed.json"}, 261 string(os.PathSeparator)))) 262 _, err := creds.Get() 263 if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderParse { 264 t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderParse, err) 265 } 266 267 creds = processcreds.NewCredentials( 268 fmt.Sprintf("%s %s", 269 getOSCat(), 270 strings.Join( 271 []string{"testdata", "wrongversion.json"}, 272 string(os.PathSeparator)))) 273 _, err = creds.Get() 274 if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderVersion { 275 t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderVersion, err) 276 } 277 278 creds = processcreds.NewCredentials( 279 fmt.Sprintf( 280 "%s %s", 281 getOSCat(), 282 strings.Join( 283 []string{"testdata", "missingkey.json"}, 284 string(os.PathSeparator)))) 285 _, err = creds.Get() 286 if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderRequired { 287 t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderRequired, err) 288 } 289 290 creds = processcreds.NewCredentials( 291 fmt.Sprintf( 292 "%s %s", 293 getOSCat(), 294 strings.Join( 295 []string{"testdata", "missingsecret.json"}, 296 string(os.PathSeparator)))) 297 _, err = creds.Get() 298 if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderRequired { 299 t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderRequired, err) 300 } 301 302} 303 304func TestProcessProviderTimeout(t *testing.T) { 305 restoreEnvFn := sdktesting.StashEnv() 306 defer restoreEnvFn() 307 308 command := "/bin/sleep 2" 309 if runtime.GOOS == "windows" { 310 // "timeout" command does not work due to pipe redirection 311 command = "ping -n 2 127.0.0.1>nul" 312 } 313 314 creds := processcreds.NewCredentialsTimeout( 315 command, 316 time.Duration(1)*time.Second) 317 if _, err := creds.Get(); err == nil || err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderExecution || err.(awserr.Error).Message() != "credential process timed out" { 318 t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderExecution, err) 319 } 320 321} 322 323func TestProcessProviderWithLongSessionToken(t *testing.T) { 324 restoreEnvFn := sdktesting.StashEnv() 325 defer restoreEnvFn() 326 327 creds := processcreds.NewCredentials( 328 fmt.Sprintf( 329 "%s %s", 330 getOSCat(), 331 strings.Join( 332 []string{"testdata", "longsessiontoken.json"}, 333 string(os.PathSeparator)))) 334 v, err := creds.Get() 335 if err != nil { 336 t.Errorf("expected %v, got %v", "no error", err) 337 } 338 339 // Text string same length as session token returned by AWS for AssumeRoleWithWebIdentity 340 e := "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" 341 if a := v.SessionToken; e != a { 342 t.Errorf("expected %v, got %v", e, a) 343 } 344} 345 346type credentialTest struct { 347 Version int 348 AccessKeyID string `json:"AccessKeyId"` 349 SecretAccessKey string 350 Expiration string 351} 352 353func TestProcessProviderStatic(t *testing.T) { 354 restoreEnvFn := sdktesting.StashEnv() 355 defer restoreEnvFn() 356 357 // static 358 creds := processcreds.NewCredentials( 359 fmt.Sprintf( 360 "%s %s", 361 getOSCat(), 362 strings.Join( 363 []string{"testdata", "static.json"}, 364 string(os.PathSeparator)))) 365 _, err := creds.Get() 366 if err != nil { 367 t.Errorf("expected %v, got %v", "no error", err) 368 } 369 if creds.IsExpired() { 370 t.Errorf("expected %v, got %v", "static credentials/not expired", "expired") 371 } 372 373} 374 375func TestProcessProviderNotExpired(t *testing.T) { 376 restoreEnvFn := sdktesting.StashEnv() 377 defer restoreEnvFn() 378 379 // non-static, not expired 380 exp := &credentialTest{} 381 exp.Version = 1 382 exp.AccessKeyID = "accesskey" 383 exp.SecretAccessKey = "secretkey" 384 exp.Expiration = time.Now().Add(1 * time.Hour).UTC().Format(time.RFC3339) 385 b, err := json.Marshal(exp) 386 if err != nil { 387 t.Errorf("expected %v, got %v", "no error", err) 388 } 389 390 tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp_expiring") 391 if err != nil { 392 t.Errorf("expected %v, got %v", "no error", err) 393 } 394 if _, err = io.Copy(tmpFile, bytes.NewReader(b)); err != nil { 395 t.Errorf("expected %v, got %v", "no error", err) 396 } 397 defer func() { 398 if err = tmpFile.Close(); err != nil { 399 t.Errorf("expected %v, got %v", "no error", err) 400 } 401 if err = os.Remove(tmpFile.Name()); err != nil { 402 t.Errorf("expected %v, got %v", "no error", err) 403 } 404 }() 405 creds := processcreds.NewCredentials( 406 fmt.Sprintf("%s %s", getOSCat(), tmpFile.Name())) 407 _, err = creds.Get() 408 if err != nil { 409 t.Errorf("expected %v, got %v", "no error", err) 410 } 411 if creds.IsExpired() { 412 t.Errorf("expected %v, got %v", "not expired", "expired") 413 } 414} 415 416func TestProcessProviderExpired(t *testing.T) { 417 restoreEnvFn := sdktesting.StashEnv() 418 defer restoreEnvFn() 419 420 // non-static, expired 421 exp := &credentialTest{} 422 exp.Version = 1 423 exp.AccessKeyID = "accesskey" 424 exp.SecretAccessKey = "secretkey" 425 exp.Expiration = time.Now().Add(-1 * time.Hour).UTC().Format(time.RFC3339) 426 b, err := json.Marshal(exp) 427 if err != nil { 428 t.Errorf("expected %v, got %v", "no error", err) 429 } 430 431 tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp_expired") 432 if err != nil { 433 t.Errorf("expected %v, got %v", "no error", err) 434 } 435 if _, err = io.Copy(tmpFile, bytes.NewReader(b)); err != nil { 436 t.Errorf("expected %v, got %v", "no error", err) 437 } 438 defer func() { 439 if err = tmpFile.Close(); err != nil { 440 t.Errorf("expected %v, got %v", "no error", err) 441 } 442 if err = os.Remove(tmpFile.Name()); err != nil { 443 t.Errorf("expected %v, got %v", "no error", err) 444 } 445 }() 446 creds := processcreds.NewCredentials( 447 fmt.Sprintf("%s %s", getOSCat(), tmpFile.Name())) 448 _, err = creds.Get() 449 if err != nil { 450 t.Errorf("expected %v, got %v", "no error", err) 451 } 452 if !creds.IsExpired() { 453 t.Errorf("expected %v, got %v", "expired", "not expired") 454 } 455} 456 457func TestProcessProviderForceExpire(t *testing.T) { 458 restoreEnvFn := sdktesting.StashEnv() 459 defer restoreEnvFn() 460 461 // non-static, not expired 462 463 // setup test credentials file 464 exp := &credentialTest{} 465 exp.Version = 1 466 exp.AccessKeyID = "accesskey" 467 exp.SecretAccessKey = "secretkey" 468 exp.Expiration = time.Now().Add(1 * time.Hour).UTC().Format(time.RFC3339) 469 b, err := json.Marshal(exp) 470 if err != nil { 471 t.Errorf("expected %v, got %v", "no error", err) 472 } 473 tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp_force_expire") 474 if err != nil { 475 t.Errorf("expected %v, got %v", "no error", err) 476 } 477 if _, err = io.Copy(tmpFile, bytes.NewReader(b)); err != nil { 478 t.Errorf("expected %v, got %v", "no error", err) 479 } 480 defer func() { 481 if err = tmpFile.Close(); err != nil { 482 t.Errorf("expected %v, got %v", "no error", err) 483 } 484 if err = os.Remove(tmpFile.Name()); err != nil { 485 t.Errorf("expected %v, got %v", "no error", err) 486 } 487 }() 488 489 // get credentials from file 490 creds := processcreds.NewCredentials( 491 fmt.Sprintf("%s %s", getOSCat(), tmpFile.Name())) 492 if _, err = creds.Get(); err != nil { 493 t.Errorf("expected %v, got %v", "no error", err) 494 } 495 if creds.IsExpired() { 496 t.Errorf("expected %v, got %v", "not expired", "expired") 497 } 498 499 // force expire creds 500 creds.Expire() 501 if !creds.IsExpired() { 502 t.Errorf("expected %v, got %v", "expired", "not expired") 503 } 504 505 // renew creds 506 if _, err = creds.Get(); err != nil { 507 t.Errorf("expected %v, got %v", "no error", err) 508 } 509 if creds.IsExpired() { 510 t.Errorf("expected %v, got %v", "not expired", "expired") 511 } 512 513} 514 515func TestProcessProviderAltConstruct(t *testing.T) { 516 restoreEnvFn := sdktesting.StashEnv() 517 defer restoreEnvFn() 518 519 // constructing with exec.Cmd instead of string 520 myCommand := exec.Command( 521 fmt.Sprintf( 522 "%s %s", 523 getOSCat(), 524 strings.Join( 525 []string{"testdata", "static.json"}, 526 string(os.PathSeparator)))) 527 creds := processcreds.NewCredentialsCommand(myCommand, func(opt *processcreds.ProcessProvider) { 528 opt.Timeout = time.Duration(1) * time.Second 529 }) 530 _, err := creds.Get() 531 if err != nil { 532 t.Errorf("expected %v, got %v", "no error", err) 533 } 534 if creds.IsExpired() { 535 t.Errorf("expected %v, got %v", "static credentials/not expired", "expired") 536 } 537} 538 539func BenchmarkProcessProvider(b *testing.B) { 540 restoreEnvFn := sdktesting.StashEnv() 541 defer restoreEnvFn() 542 543 creds := processcreds.NewCredentials( 544 fmt.Sprintf( 545 "%s %s", 546 getOSCat(), 547 strings.Join( 548 []string{"testdata", "static.json"}, 549 string(os.PathSeparator)))) 550 _, err := creds.Get() 551 if err != nil { 552 b.Fatal(err) 553 } 554 555 b.ResetTimer() 556 for i := 0; i < b.N; i++ { 557 _, err := creds.Get() 558 if err != nil { 559 b.Fatal(err) 560 } 561 } 562} 563 564func getOSCat() string { 565 if runtime.GOOS == "windows" { 566 return "type" 567 } 568 return "cat" 569} 570