1package stscreds 2 3import ( 4 "fmt" 5 "io/ioutil" 6 "strconv" 7 "time" 8 9 "github.com/aws/aws-sdk-go/aws" 10 "github.com/aws/aws-sdk-go/aws/awserr" 11 "github.com/aws/aws-sdk-go/aws/client" 12 "github.com/aws/aws-sdk-go/aws/credentials" 13 "github.com/aws/aws-sdk-go/service/sts" 14 "github.com/aws/aws-sdk-go/service/sts/stsiface" 15) 16 17const ( 18 // ErrCodeWebIdentity will be used as an error code when constructing 19 // a new error to be returned during session creation or retrieval. 20 ErrCodeWebIdentity = "WebIdentityErr" 21 22 // WebIdentityProviderName is the web identity provider name 23 WebIdentityProviderName = "WebIdentityCredentials" 24) 25 26// now is used to return a time.Time object representing 27// the current time. This can be used to easily test and 28// compare test values. 29var now = time.Now 30 31// TokenFetcher shuold return WebIdentity token bytes or an error 32type TokenFetcher interface { 33 FetchToken(credentials.Context) ([]byte, error) 34} 35 36// FetchTokenPath is a path to a WebIdentity token file 37type FetchTokenPath string 38 39// FetchToken returns a token by reading from the filesystem 40func (f FetchTokenPath) FetchToken(ctx credentials.Context) ([]byte, error) { 41 data, err := ioutil.ReadFile(string(f)) 42 if err != nil { 43 errMsg := fmt.Sprintf("unable to read file at %s", f) 44 return nil, awserr.New(ErrCodeWebIdentity, errMsg, err) 45 } 46 return data, nil 47} 48 49// WebIdentityRoleProvider is used to retrieve credentials using 50// an OIDC token. 51type WebIdentityRoleProvider struct { 52 credentials.Expiry 53 PolicyArns []*sts.PolicyDescriptorType 54 55 // Duration the STS credentials will be valid for. Truncated to seconds. 56 // If unset, the assumed role will use AssumeRoleWithWebIdentity's default 57 // expiry duration. See 58 // https://docs.aws.amazon.com/sdk-for-go/api/service/sts/#STS.AssumeRoleWithWebIdentity 59 // for more information. 60 Duration time.Duration 61 62 // The amount of time the credentials will be refreshed before they expire. 63 // This is useful refresh credentials before they expire to reduce risk of 64 // using credentials as they expire. If unset, will default to no expiry 65 // window. 66 ExpiryWindow time.Duration 67 68 client stsiface.STSAPI 69 70 tokenFetcher TokenFetcher 71 roleARN string 72 roleSessionName string 73} 74 75// NewWebIdentityCredentials will return a new set of credentials with a given 76// configuration, role arn, and token file path. 77func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName, path string) *credentials.Credentials { 78 svc := sts.New(c) 79 p := NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path) 80 return credentials.NewCredentials(p) 81} 82 83// NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the 84// provided stsiface.STSAPI 85func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *WebIdentityRoleProvider { 86 return NewWebIdentityRoleProviderWithToken(svc, roleARN, roleSessionName, FetchTokenPath(path)) 87} 88 89// NewWebIdentityRoleProviderWithToken will return a new WebIdentityRoleProvider with the 90// provided stsiface.STSAPI and a TokenFetcher 91func NewWebIdentityRoleProviderWithToken(svc stsiface.STSAPI, roleARN, roleSessionName string, tokenFetcher TokenFetcher) *WebIdentityRoleProvider { 92 return &WebIdentityRoleProvider{ 93 client: svc, 94 tokenFetcher: tokenFetcher, 95 roleARN: roleARN, 96 roleSessionName: roleSessionName, 97 } 98} 99 100// Retrieve will attempt to assume a role from a token which is located at 101// 'WebIdentityTokenFilePath' specified destination and if that is empty an 102// error will be returned. 103func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) { 104 return p.RetrieveWithContext(aws.BackgroundContext()) 105} 106 107// RetrieveWithContext will attempt to assume a role from a token which is located at 108// 'WebIdentityTokenFilePath' specified destination and if that is empty an 109// error will be returned. 110func (p *WebIdentityRoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) { 111 b, err := p.tokenFetcher.FetchToken(ctx) 112 if err != nil { 113 return credentials.Value{}, awserr.New(ErrCodeWebIdentity, "failed fetching WebIdentity token: ", err) 114 } 115 116 sessionName := p.roleSessionName 117 if len(sessionName) == 0 { 118 // session name is used to uniquely identify a session. This simply 119 // uses unix time in nanoseconds to uniquely identify sessions. 120 sessionName = strconv.FormatInt(now().UnixNano(), 10) 121 } 122 123 var duration *int64 124 if p.Duration != 0 { 125 duration = aws.Int64(int64(p.Duration / time.Second)) 126 } 127 128 req, resp := p.client.AssumeRoleWithWebIdentityRequest(&sts.AssumeRoleWithWebIdentityInput{ 129 PolicyArns: p.PolicyArns, 130 RoleArn: &p.roleARN, 131 RoleSessionName: &sessionName, 132 WebIdentityToken: aws.String(string(b)), 133 DurationSeconds: duration, 134 }) 135 136 req.SetContext(ctx) 137 138 // InvalidIdentityToken error is a temporary error that can occur 139 // when assuming an Role with a JWT web identity token. 140 req.RetryErrorCodes = append(req.RetryErrorCodes, sts.ErrCodeInvalidIdentityTokenException) 141 if err := req.Send(); err != nil { 142 return credentials.Value{}, awserr.New(ErrCodeWebIdentity, "failed to retrieve credentials", err) 143 } 144 145 p.SetExpiration(aws.TimeValue(resp.Credentials.Expiration), p.ExpiryWindow) 146 147 value := credentials.Value{ 148 AccessKeyID: aws.StringValue(resp.Credentials.AccessKeyId), 149 SecretAccessKey: aws.StringValue(resp.Credentials.SecretAccessKey), 150 SessionToken: aws.StringValue(resp.Credentials.SessionToken), 151 ProviderName: WebIdentityProviderName, 152 } 153 return value, nil 154} 155