1package stscreds
2
3import (
4	"fmt"
5	"io/ioutil"
6	"strconv"
7	"time"
8
9	"github.com/aws/aws-sdk-go/aws"
10	"github.com/aws/aws-sdk-go/aws/awserr"
11	"github.com/aws/aws-sdk-go/aws/client"
12	"github.com/aws/aws-sdk-go/aws/credentials"
13	"github.com/aws/aws-sdk-go/service/sts"
14	"github.com/aws/aws-sdk-go/service/sts/stsiface"
15)
16
const (
	// WebIdentityProviderName is reported as the ProviderName of every
	// credentials.Value returned by WebIdentityRoleProvider.
	WebIdentityProviderName = "WebIdentityCredentials"

	// ErrCodeWebIdentity is the awserr code attached to errors produced while
	// reading a web identity token or exchanging it with STS for credentials.
	ErrCodeWebIdentity = "WebIdentityErr"
)
25
var (
	// now reports the current time. It is a variable rather than a direct
	// call to time.Now so that tests can substitute a fixed clock and compare
	// against predictable values.
	now = time.Now
)
30
31// TokenFetcher shuold return WebIdentity token bytes or an error
32type TokenFetcher interface {
33	FetchToken(credentials.Context) ([]byte, error)
34}
35
36// FetchTokenPath is a path to a WebIdentity token file
37type FetchTokenPath string
38
39// FetchToken returns a token by reading from the filesystem
40func (f FetchTokenPath) FetchToken(ctx credentials.Context) ([]byte, error) {
41	data, err := ioutil.ReadFile(string(f))
42	if err != nil {
43		errMsg := fmt.Sprintf("unable to read file at %s", f)
44		return nil, awserr.New(ErrCodeWebIdentity, errMsg, err)
45	}
46	return data, nil
47}
48
49// WebIdentityRoleProvider is used to retrieve credentials using
50// an OIDC token.
51type WebIdentityRoleProvider struct {
52	credentials.Expiry
53	PolicyArns []*sts.PolicyDescriptorType
54
55	// Duration the STS credentials will be valid for. Truncated to seconds.
56	// If unset, the assumed role will use AssumeRoleWithWebIdentity's default
57	// expiry duration. See
58	// https://docs.aws.amazon.com/sdk-for-go/api/service/sts/#STS.AssumeRoleWithWebIdentity
59	// for more information.
60	Duration time.Duration
61
62	// The amount of time the credentials will be refreshed before they expire.
63	// This is useful refresh credentials before they expire to reduce risk of
64	// using credentials as they expire. If unset, will default to no expiry
65	// window.
66	ExpiryWindow time.Duration
67
68	client stsiface.STSAPI
69
70	tokenFetcher    TokenFetcher
71	roleARN         string
72	roleSessionName string
73}
74
75// NewWebIdentityCredentials will return a new set of credentials with a given
76// configuration, role arn, and token file path.
77func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName, path string) *credentials.Credentials {
78	svc := sts.New(c)
79	p := NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path)
80	return credentials.NewCredentials(p)
81}
82
83// NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the
84// provided stsiface.STSAPI
85func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *WebIdentityRoleProvider {
86	return NewWebIdentityRoleProviderWithToken(svc, roleARN, roleSessionName, FetchTokenPath(path))
87}
88
89// NewWebIdentityRoleProviderWithToken will return a new WebIdentityRoleProvider with the
90// provided stsiface.STSAPI and a TokenFetcher
91func NewWebIdentityRoleProviderWithToken(svc stsiface.STSAPI, roleARN, roleSessionName string, tokenFetcher TokenFetcher) *WebIdentityRoleProvider {
92	return &WebIdentityRoleProvider{
93		client:          svc,
94		tokenFetcher:    tokenFetcher,
95		roleARN:         roleARN,
96		roleSessionName: roleSessionName,
97	}
98}
99
100// Retrieve will attempt to assume a role from a token which is located at
101// 'WebIdentityTokenFilePath' specified destination and if that is empty an
102// error will be returned.
103func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) {
104	return p.RetrieveWithContext(aws.BackgroundContext())
105}
106
107// RetrieveWithContext will attempt to assume a role from a token which is located at
108// 'WebIdentityTokenFilePath' specified destination and if that is empty an
109// error will be returned.
110func (p *WebIdentityRoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
111	b, err := p.tokenFetcher.FetchToken(ctx)
112	if err != nil {
113		return credentials.Value{}, awserr.New(ErrCodeWebIdentity, "failed fetching WebIdentity token: ", err)
114	}
115
116	sessionName := p.roleSessionName
117	if len(sessionName) == 0 {
118		// session name is used to uniquely identify a session. This simply
119		// uses unix time in nanoseconds to uniquely identify sessions.
120		sessionName = strconv.FormatInt(now().UnixNano(), 10)
121	}
122
123	var duration *int64
124	if p.Duration != 0 {
125		duration = aws.Int64(int64(p.Duration / time.Second))
126	}
127
128	req, resp := p.client.AssumeRoleWithWebIdentityRequest(&sts.AssumeRoleWithWebIdentityInput{
129		PolicyArns:       p.PolicyArns,
130		RoleArn:          &p.roleARN,
131		RoleSessionName:  &sessionName,
132		WebIdentityToken: aws.String(string(b)),
133		DurationSeconds:  duration,
134	})
135
136	req.SetContext(ctx)
137
138	// InvalidIdentityToken error is a temporary error that can occur
139	// when assuming an Role with a JWT web identity token.
140	req.RetryErrorCodes = append(req.RetryErrorCodes, sts.ErrCodeInvalidIdentityTokenException)
141	if err := req.Send(); err != nil {
142		return credentials.Value{}, awserr.New(ErrCodeWebIdentity, "failed to retrieve credentials", err)
143	}
144
145	p.SetExpiration(aws.TimeValue(resp.Credentials.Expiration), p.ExpiryWindow)
146
147	value := credentials.Value{
148		AccessKeyID:     aws.StringValue(resp.Credentials.AccessKeyId),
149		SecretAccessKey: aws.StringValue(resp.Credentials.SecretAccessKey),
150		SessionToken:    aws.StringValue(resp.Credentials.SessionToken),
151		ProviderName:    WebIdentityProviderName,
152	}
153	return value, nil
154}
155