1package ssocreds 2 3import ( 4 "crypto/sha1" 5 "encoding/hex" 6 "encoding/json" 7 "fmt" 8 "io/ioutil" 9 "path/filepath" 10 "strings" 11 "time" 12 13 "github.com/aws/aws-sdk-go/aws" 14 "github.com/aws/aws-sdk-go/aws/awserr" 15 "github.com/aws/aws-sdk-go/aws/client" 16 "github.com/aws/aws-sdk-go/aws/credentials" 17 "github.com/aws/aws-sdk-go/service/sso" 18 "github.com/aws/aws-sdk-go/service/sso/ssoiface" 19) 20 21// ErrCodeSSOProviderInvalidToken is the code type that is returned if loaded token has expired or is otherwise invalid. 22// To refresh the SSO session run aws sso login with the corresponding profile. 23const ErrCodeSSOProviderInvalidToken = "SSOProviderInvalidToken" 24 25const invalidTokenMessage = "the SSO session has expired or is invalid" 26 27func init() { 28 nowTime = time.Now 29 defaultCacheLocation = defaultCacheLocationImpl 30} 31 32var nowTime func() time.Time 33 34// ProviderName is the name of the provider used to specify the source of credentials. 35const ProviderName = "SSOProvider" 36 37var defaultCacheLocation func() string 38 39func defaultCacheLocationImpl() string { 40 return filepath.Join(getHomeDirectory(), ".aws", "sso", "cache") 41} 42 43// Provider is an AWS credential provider that retrieves temporary AWS credentials by exchanging an SSO login token. 44type Provider struct { 45 credentials.Expiry 46 47 // The Client which is configured for the AWS Region where the AWS SSO user portal is located. 48 Client ssoiface.SSOAPI 49 50 // The AWS account that is assigned to the user. 51 AccountID string 52 53 // The role name that is assigned to the user. 54 RoleName string 55 56 // The URL that points to the organization's AWS Single Sign-On (AWS SSO) user portal. 57 StartURL string 58} 59 60// NewCredentials returns a new AWS Single Sign-On (AWS SSO) credential provider. The ConfigProvider is expected to be configured 61// for the AWS Region where the AWS SSO user portal is located. 62func NewCredentials(configProvider client.ConfigProvider, accountID, roleName, startURL string, optFns ...func(provider *Provider)) *credentials.Credentials { 63 return NewCredentialsWithClient(sso.New(configProvider), accountID, roleName, startURL, optFns...) 64} 65 66// NewCredentialsWithClient returns a new AWS Single Sign-On (AWS SSO) credential provider. The provided client is expected to be configured 67// for the AWS Region where the AWS SSO user portal is located. 68func NewCredentialsWithClient(client ssoiface.SSOAPI, accountID, roleName, startURL string, optFns ...func(provider *Provider)) *credentials.Credentials { 69 p := &Provider{ 70 Client: client, 71 AccountID: accountID, 72 RoleName: roleName, 73 StartURL: startURL, 74 } 75 76 for _, fn := range optFns { 77 fn(p) 78 } 79 80 return credentials.NewCredentials(p) 81} 82 83// Retrieve retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal 84// by exchanging the accessToken present in ~/.aws/sso/cache. 85func (p *Provider) Retrieve() (credentials.Value, error) { 86 return p.RetrieveWithContext(aws.BackgroundContext()) 87} 88 89// RetrieveWithContext retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal 90// by exchanging the accessToken present in ~/.aws/sso/cache. 91func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) { 92 tokenFile, err := loadTokenFile(p.StartURL) 93 if err != nil { 94 return credentials.Value{}, err 95 } 96 97 output, err := p.Client.GetRoleCredentialsWithContext(ctx, &sso.GetRoleCredentialsInput{ 98 AccessToken: &tokenFile.AccessToken, 99 AccountId: &p.AccountID, 100 RoleName: &p.RoleName, 101 }) 102 if err != nil { 103 return credentials.Value{}, err 104 } 105 106 expireTime := time.Unix(0, aws.Int64Value(output.RoleCredentials.Expiration)*int64(time.Millisecond)).UTC() 107 p.SetExpiration(expireTime, 0) 108 109 return credentials.Value{ 110 AccessKeyID: aws.StringValue(output.RoleCredentials.AccessKeyId), 111 SecretAccessKey: aws.StringValue(output.RoleCredentials.SecretAccessKey), 112 SessionToken: aws.StringValue(output.RoleCredentials.SessionToken), 113 ProviderName: ProviderName, 114 }, nil 115} 116 117func getCacheFileName(url string) (string, error) { 118 hash := sha1.New() 119 _, err := hash.Write([]byte(url)) 120 if err != nil { 121 return "", err 122 } 123 return strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json", nil 124} 125 126type rfc3339 time.Time 127 128func (r *rfc3339) UnmarshalJSON(bytes []byte) error { 129 var value string 130 131 if err := json.Unmarshal(bytes, &value); err != nil { 132 return err 133 } 134 135 parse, err := time.Parse(time.RFC3339, value) 136 if err != nil { 137 return fmt.Errorf("expected RFC3339 timestamp: %v", err) 138 } 139 140 *r = rfc3339(parse) 141 142 return nil 143} 144 145type token struct { 146 AccessToken string `json:"accessToken"` 147 ExpiresAt rfc3339 `json:"expiresAt"` 148 Region string `json:"region,omitempty"` 149 StartURL string `json:"startUrl,omitempty"` 150} 151 152func (t token) Expired() bool { 153 return nowTime().Round(0).After(time.Time(t.ExpiresAt)) 154} 155 156func loadTokenFile(startURL string) (t token, err error) { 157 key, err := getCacheFileName(startURL) 158 if err != nil { 159 return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err) 160 } 161 162 fileBytes, err := ioutil.ReadFile(filepath.Join(defaultCacheLocation(), key)) 163 if err != nil { 164 return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err) 165 } 166 167 if err := json.Unmarshal(fileBytes, &t); err != nil { 168 return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err) 169 } 170 171 if len(t.AccessToken) == 0 { 172 return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, nil) 173 } 174 175 if t.Expired() { 176 return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, nil) 177 } 178 179 return t, nil 180} 181