1package ssocreds
2
3import (
4	"crypto/sha1"
5	"encoding/hex"
6	"encoding/json"
7	"fmt"
8	"io/ioutil"
9	"path/filepath"
10	"strings"
11	"time"
12
13	"github.com/aws/aws-sdk-go/aws"
14	"github.com/aws/aws-sdk-go/aws/awserr"
15	"github.com/aws/aws-sdk-go/aws/client"
16	"github.com/aws/aws-sdk-go/aws/credentials"
17	"github.com/aws/aws-sdk-go/service/sso"
18	"github.com/aws/aws-sdk-go/service/sso/ssoiface"
19)
20
const (
	// ErrCodeSSOProviderInvalidToken is the error code returned when the loaded SSO token has expired or is
	// otherwise invalid. To refresh the SSO session, run `aws sso login` with the corresponding profile.
	ErrCodeSSOProviderInvalidToken = "SSOProviderInvalidToken"

	// invalidTokenMessage is the message paired with ErrCodeSSOProviderInvalidToken when loadTokenFile
	// rejects a cached token.
	invalidTokenMessage = "the SSO session has expired or is invalid"
)
26
27func init() {
28	nowTime = time.Now
29	defaultCacheLocation = defaultCacheLocationImpl
30}
31
32var nowTime func() time.Time
33
34// ProviderName is the name of the provider used to specify the source of credentials.
35const ProviderName = "SSOProvider"
36
37var defaultCacheLocation func() string
38
39func defaultCacheLocationImpl() string {
40	return filepath.Join(getHomeDirectory(), ".aws", "sso", "cache")
41}
42
// Provider is an AWS credential provider that retrieves temporary AWS credentials by exchanging an SSO login token.
//
// On each retrieval the cached access token for StartURL is read from disk and exchanged, via Client,
// for the role credentials of RoleName in AccountID. The embedded credentials.Expiry records when the
// returned credentials expire.
type Provider struct {
	// Expiry is updated by RetrieveWithContext (through SetExpiration) with the expiration time
	// reported alongside the retrieved role credentials.
	credentials.Expiry

	// The Client which is configured for the AWS Region where the AWS SSO user portal is located.
	// RetrieveWithContext calls its GetRoleCredentialsWithContext method.
	Client ssoiface.SSOAPI

	// The AWS account that is assigned to the user.
	// Sent as AccountId in the GetRoleCredentials request.
	AccountID string

	// The role name that is assigned to the user.
	// Sent as RoleName in the GetRoleCredentials request.
	RoleName string

	// The URL that points to the organization's AWS Single Sign-On (AWS SSO) user portal.
	// Its SHA-1 hash (hex-encoded, with a ".json" suffix) names the cached token file that is read.
	StartURL string
}
59
60// NewCredentials returns a new AWS Single Sign-On (AWS SSO) credential provider. The ConfigProvider is expected to be configured
61// for the AWS Region where the AWS SSO user portal is located.
62func NewCredentials(configProvider client.ConfigProvider, accountID, roleName, startURL string, optFns ...func(provider *Provider)) *credentials.Credentials {
63	return NewCredentialsWithClient(sso.New(configProvider), accountID, roleName, startURL, optFns...)
64}
65
66// NewCredentialsWithClient returns a new AWS Single Sign-On (AWS SSO) credential provider. The provided client is expected to be configured
67// for the AWS Region where the AWS SSO user portal is located.
68func NewCredentialsWithClient(client ssoiface.SSOAPI, accountID, roleName, startURL string, optFns ...func(provider *Provider)) *credentials.Credentials {
69	p := &Provider{
70		Client:    client,
71		AccountID: accountID,
72		RoleName:  roleName,
73		StartURL:  startURL,
74	}
75
76	for _, fn := range optFns {
77		fn(p)
78	}
79
80	return credentials.NewCredentials(p)
81}
82
83// Retrieve retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal
84// by exchanging the accessToken present in ~/.aws/sso/cache.
85func (p *Provider) Retrieve() (credentials.Value, error) {
86	return p.RetrieveWithContext(aws.BackgroundContext())
87}
88
// RetrieveWithContext retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal
// by exchanging the accessToken present in ~/.aws/sso/cache.
//
// The cached token is located by hashing p.StartURL. If the token file is missing, malformed, has an empty
// access token, or has expired, an awserr.Error with code ErrCodeSSOProviderInvalidToken is returned.
// Errors from GetRoleCredentials are returned unchanged. A response without role credentials is reported
// as an error instead of causing a nil-pointer panic. On success, the provider's expiry is set from the
// returned expiration, which is interpreted as milliseconds since the Unix epoch.
func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
	tokenFile, err := loadTokenFile(p.StartURL)
	if err != nil {
		return credentials.Value{}, err
	}

	output, err := p.Client.GetRoleCredentialsWithContext(ctx, &sso.GetRoleCredentialsInput{
		AccessToken: &tokenFile.AccessToken,
		AccountId:   &p.AccountID,
		RoleName:    &p.RoleName,
	})
	if err != nil {
		return credentials.Value{}, err
	}

	// RoleCredentials is a pointer field. A response (or a stubbed client) that omits it
	// would otherwise panic on the field accesses below.
	if output == nil || output.RoleCredentials == nil {
		return credentials.Value{}, awserr.New("SSOProviderMissingRoleCredentials",
			"the GetRoleCredentials response did not contain role credentials", nil)
	}
	roleCreds := output.RoleCredentials

	// Expiration is expressed in milliseconds since the Unix epoch.
	expireTime := time.Unix(0, aws.Int64Value(roleCreds.Expiration)*int64(time.Millisecond)).UTC()
	p.SetExpiration(expireTime, 0)

	return credentials.Value{
		AccessKeyID:     aws.StringValue(roleCreds.AccessKeyId),
		SecretAccessKey: aws.StringValue(roleCreds.SecretAccessKey),
		SessionToken:    aws.StringValue(roleCreds.SessionToken),
		ProviderName:    ProviderName,
	}, nil
}
116
// getCacheFileName returns the SSO token cache file name for url: the lowercase
// hex-encoded SHA-1 digest of url followed by ".json". The error result is always
// nil because computing the digest cannot fail; it is kept for the existing signature.
func getCacheFileName(url string) (string, error) {
	digest := sha1.Sum([]byte(url))
	name := hex.EncodeToString(digest[:])
	return strings.ToLower(name) + ".json", nil
}
125
126type rfc3339 time.Time
127
128func (r *rfc3339) UnmarshalJSON(bytes []byte) error {
129	var value string
130
131	if err := json.Unmarshal(bytes, &value); err != nil {
132		return err
133	}
134
135	parse, err := time.Parse(time.RFC3339, value)
136	if err != nil {
137		return fmt.Errorf("expected RFC3339 timestamp: %v", err)
138	}
139
140	*r = rfc3339(parse)
141
142	return nil
143}
144
// token mirrors the JSON document stored in an SSO token cache file.
// loadTokenFile rejects a token whose AccessToken is empty or which Expired reports as expired.
type token struct {
	// AccessToken is sent as the AccessToken of the GetRoleCredentials request.
	AccessToken string  `json:"accessToken"`
	// ExpiresAt, when present, must be a JSON string in RFC 3339 form. If absent it stays the
	// zero time, which Expired treats as already expired.
	ExpiresAt   rfc3339 `json:"expiresAt"`
	// Region and StartURL are decoded when present. Neither loadTokenFile nor RetrieveWithContext reads them.
	Region      string  `json:"region,omitempty"`
	StartURL    string  `json:"startUrl,omitempty"`
}
151
// Expired reports whether the current time, as given by nowTime with any monotonic
// clock reading stripped, is strictly later than the token's ExpiresAt.
func (t token) Expired() bool {
	expiresAt := time.Time(t.ExpiresAt)
	now := nowTime().Round(0)
	return expiresAt.Before(now)
}
155
156func loadTokenFile(startURL string) (t token, err error) {
157	key, err := getCacheFileName(startURL)
158	if err != nil {
159		return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err)
160	}
161
162	fileBytes, err := ioutil.ReadFile(filepath.Join(defaultCacheLocation(), key))
163	if err != nil {
164		return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err)
165	}
166
167	if err := json.Unmarshal(fileBytes, &t); err != nil {
168		return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err)
169	}
170
171	if len(t.AccessToken) == 0 {
172		return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, nil)
173	}
174
175	if t.Expired() {
176		return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, nil)
177	}
178
179	return t, nil
180}
181