1package ssocreds
2
3import (
4	"context"
5	"fmt"
6	"testing"
7	"time"
8
9	"github.com/aws/aws-sdk-go-v2/aws"
10	"github.com/aws/aws-sdk-go-v2/internal/sdk"
11	"github.com/aws/aws-sdk-go-v2/service/sso"
12	"github.com/aws/aws-sdk-go-v2/service/sso/types"
13	"github.com/google/go-cmp/cmp"
14)
15
16type mockClient struct {
17	t *testing.T
18
19	Output *sso.GetRoleCredentialsOutput
20	Err    error
21
22	ExpectedAccountID    string
23	ExpectedAccessToken  string
24	ExpectedRoleName     string
25
26	Response func(mockClient) (*sso.GetRoleCredentialsOutput, error)
27}
28
29func (m mockClient) GetRoleCredentials(ctx context.Context, params *sso.GetRoleCredentialsInput, optFns ...func(options *sso.Options)) (out *sso.GetRoleCredentialsOutput, err error) {
30	m.t.Helper()
31
32	if len(m.ExpectedAccountID) > 0 {
33		if diff := cmp.Diff(m.ExpectedAccountID, aws.ToString(params.AccountId)); len(diff) > 0 {
34			m.t.Error(diff)
35		}
36	}
37
38	if len(m.ExpectedAccessToken) > 0 {
39		if diff := cmp.Diff(m.ExpectedAccessToken, aws.ToString(params.AccessToken)); len(diff) > 0 {
40			m.t.Error(diff)
41		}
42	}
43
44	if len(m.ExpectedRoleName) > 0 {
45		if diff := cmp.Diff(m.ExpectedRoleName, aws.ToString(params.RoleName)); len(diff) > 0 {
46			m.t.Error(diff)
47		}
48	}
49
50	if m.Response == nil {
51		return out, err
52	}
53	return m.Response(m)
54}
55
56func swapCacheLocation(dir string) func() {
57	original := defaultCacheLocation
58	defaultCacheLocation = func() string {
59		return dir
60	}
61	return func() {
62		defaultCacheLocation = original
63	}
64}
65
66func TestProvider(t *testing.T) {
67	restoreCache := swapCacheLocation("testdata")
68	defer restoreCache()
69
70	restoreTime := sdk.TestingUseReferenceTime(time.Date(2021, 01, 19, 19, 50, 0, 0, time.UTC))
71	defer restoreTime()
72
73	cases := map[string]struct {
74		Client    mockClient
75		AccountID string
76		Region    string
77		RoleName  string
78		StartURL  string
79		Options   []func(*Options)
80
81		ExpectedErr         bool
82		ExpectedCredentials aws.Credentials
83	}{
84		"missing required parameter values": {
85			StartURL:    "https://invalid-required",
86			ExpectedErr: true,
87		},
88		"valid required parameter values": {
89			Client: mockClient{
90				ExpectedAccountID:    "012345678901",
91				ExpectedRoleName:     "TestRole",
92				ExpectedAccessToken:  "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl",
93				Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) {
94					return &sso.GetRoleCredentialsOutput{
95						RoleCredentials: &types.RoleCredentials{
96							AccessKeyId:     aws.String("AccessKey"),
97							SecretAccessKey: aws.String("SecretKey"),
98							SessionToken:    aws.String("SessionToken"),
99							Expiration:      1611177743123,
100						},
101					}, nil
102				},
103			},
104			AccountID: "012345678901",
105			Region:    "us-west-2",
106			RoleName:  "TestRole",
107			StartURL:  "https://valid-required-only",
108			ExpectedCredentials: aws.Credentials{
109				AccessKeyID:     "AccessKey",
110				SecretAccessKey: "SecretKey",
111				SessionToken:    "SessionToken",
112				CanExpire:       true,
113				Expires:         time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC),
114				Source:          ProviderName,
115			},
116		},
117		"expired access token": {
118			StartURL:    "https://expired",
119			ExpectedErr: true,
120		},
121		"api error": {
122			Client: mockClient{
123				ExpectedAccountID:    "012345678901",
124				ExpectedRoleName:     "TestRole",
125				ExpectedAccessToken:  "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl",
126				Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) {
127					return nil, fmt.Errorf("api error")
128				},
129			},
130			AccountID:   "012345678901",
131			Region:      "us-west-2",
132			RoleName:    "TestRole",
133			StartURL:    "https://valid-required-only",
134			ExpectedErr: true,
135		},
136	}
137
138	for name, tt := range cases {
139		t.Run(name, func(t *testing.T) {
140			tt.Client.t = t
141
142			provider := New(tt.Client, tt.AccountID, tt.RoleName, tt.StartURL, tt.Options...)
143
144			credentials, err := provider.Retrieve(context.Background())
145			if (err != nil) != tt.ExpectedErr {
146				t.Errorf("expect error: %v", tt.ExpectedErr)
147			}
148
149			if diff := cmp.Diff(tt.ExpectedCredentials, credentials); len(diff) > 0 {
150				t.Errorf(diff)
151			}
152		})
153	}
154}
155