1package awsauth
2
3import (
4	"context"
5	"fmt"
6
7	"github.com/aws/aws-sdk-go/aws"
8	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
9	"github.com/aws/aws-sdk-go/aws/session"
10	"github.com/aws/aws-sdk-go/service/ec2"
11	"github.com/aws/aws-sdk-go/service/iam"
12	"github.com/aws/aws-sdk-go/service/sts"
13	cleanhttp "github.com/hashicorp/go-cleanhttp"
14	"github.com/hashicorp/vault/sdk/helper/awsutil"
15	"github.com/hashicorp/vault/sdk/logical"
16)
17
18// getRawClientConfig creates a aws-sdk-go config, which is used to create client
19// that can interact with AWS API. This builds credentials in the following
20// order of preference:
21//
22// * Static credentials from 'config/client'
23// * Environment variables
24// * Instance metadata role
25func (b *backend) getRawClientConfig(ctx context.Context, s logical.Storage, region, clientType string) (*aws.Config, error) {
26	credsConfig := &awsutil.CredentialsConfig{
27		Region: region,
28		Logger: b.Logger(),
29	}
30
31	// Read the configured secret key and access key
32	config, err := b.nonLockedClientConfigEntry(ctx, s)
33	if err != nil {
34		return nil, err
35	}
36
37	endpoint := aws.String("")
38	var maxRetries int = aws.UseServiceDefaultRetries
39	if config != nil {
40		// Override the defaults with configured values.
41		switch {
42		case clientType == "ec2" && config.Endpoint != "":
43			endpoint = aws.String(config.Endpoint)
44		case clientType == "iam" && config.IAMEndpoint != "":
45			endpoint = aws.String(config.IAMEndpoint)
46		case clientType == "sts":
47			if config.STSEndpoint != "" {
48				endpoint = aws.String(config.STSEndpoint)
49			}
50			if config.STSRegion != "" {
51				region = config.STSRegion
52			}
53		}
54
55		credsConfig.AccessKey = config.AccessKey
56		credsConfig.SecretKey = config.SecretKey
57		maxRetries = config.MaxRetries
58	}
59
60	credsConfig.HTTPClient = cleanhttp.DefaultClient()
61
62	creds, err := credsConfig.GenerateCredentialChain()
63	if err != nil {
64		return nil, err
65	}
66	if creds == nil {
67		return nil, fmt.Errorf("could not compile valid credential providers from static config, environment, shared, or instance metadata")
68	}
69
70	// Create a config that can be used to make the API calls.
71	return &aws.Config{
72		Credentials: creds,
73		Region:      aws.String(region),
74		HTTPClient:  cleanhttp.DefaultClient(),
75		Endpoint:    endpoint,
76		MaxRetries:  aws.Int(maxRetries),
77	}, nil
78}
79
80// getClientConfig returns an aws-sdk-go config, with optionally assumed credentials
81// It uses getRawClientConfig to obtain config for the runtime environment, and if
82// stsRole is a non-empty string, it will use AssumeRole to obtain a set of assumed
83// credentials. The credentials will expire after 15 minutes but will auto-refresh.
84func (b *backend) getClientConfig(ctx context.Context, s logical.Storage, region, stsRole, accountID, clientType string) (*aws.Config, error) {
85	config, err := b.getRawClientConfig(ctx, s, region, clientType)
86	if err != nil {
87		return nil, err
88	}
89	if config == nil {
90		return nil, fmt.Errorf("could not compile valid credentials through the default provider chain")
91	}
92
93	stsConfig, err := b.getRawClientConfig(ctx, s, region, "sts")
94	if stsConfig == nil {
95		return nil, fmt.Errorf("could not configure STS client")
96	}
97	if err != nil {
98		return nil, err
99	}
100	if stsRole != "" {
101		sess, err := session.NewSession(stsConfig)
102		if err != nil {
103			return nil, err
104		}
105		assumedCredentials := stscreds.NewCredentials(sess, stsRole)
106		// Test that we actually have permissions to assume the role
107		if _, err = assumedCredentials.Get(); err != nil {
108			return nil, err
109		}
110		config.Credentials = assumedCredentials
111	} else {
112		if b.defaultAWSAccountID == "" {
113			sess, err := session.NewSession(stsConfig)
114			if err != nil {
115				return nil, err
116			}
117			client := sts.New(sess)
118			if client == nil {
119				return nil, fmt.Errorf("could not obtain sts client: %w", err)
120			}
121			inputParams := &sts.GetCallerIdentityInput{}
122			identity, err := client.GetCallerIdentity(inputParams)
123			if err != nil {
124				return nil, fmt.Errorf("unable to fetch current caller: %w", err)
125			}
126			if identity == nil {
127				return nil, fmt.Errorf("got nil result from GetCallerIdentity")
128			}
129			b.defaultAWSAccountID = *identity.Account
130		}
131		if b.defaultAWSAccountID != accountID {
132			return nil, fmt.Errorf("unable to fetch client for account ID %q -- default client is for account %q", accountID, b.defaultAWSAccountID)
133		}
134	}
135
136	return config, nil
137}
138
139// flushCachedEC2Clients deletes all the cached ec2 client objects from the backend.
140// If the client credentials configuration is deleted or updated in the backend, all
141// the cached EC2 client objects will be flushed. Config mutex lock should be
142// acquired for write operation before calling this method.
143func (b *backend) flushCachedEC2Clients() {
144	// deleting items in map during iteration is safe
145	for region := range b.EC2ClientsMap {
146		delete(b.EC2ClientsMap, region)
147	}
148}
149
150// flushCachedIAMClients deletes all the cached iam client objects from the
151// backend. If the client credentials configuration is deleted or updated in
152// the backend, all the cached IAM client objects will be flushed. Config mutex
153// lock should be acquired for write operation before calling this method.
154func (b *backend) flushCachedIAMClients() {
155	// deleting items in map during iteration is safe
156	for region := range b.IAMClientsMap {
157		delete(b.IAMClientsMap, region)
158	}
159}
160
161// Gets an entry out of the user ID cache
162func (b *backend) getCachedUserId(userId string) string {
163	if userId == "" {
164		return ""
165	}
166	if entry, ok := b.iamUserIdToArnCache.Get(userId); ok {
167		b.iamUserIdToArnCache.SetDefault(userId, entry)
168		return entry.(string)
169	}
170	return ""
171}
172
173// Sets an entry in the user ID cache
174func (b *backend) setCachedUserId(userId, arn string) {
175	if userId != "" {
176		b.iamUserIdToArnCache.SetDefault(userId, arn)
177	}
178}
179
180func (b *backend) stsRoleForAccount(ctx context.Context, s logical.Storage, accountID string) (string, error) {
181	// Check if an STS configuration exists for the AWS account
182	sts, err := b.lockedAwsStsEntry(ctx, s, accountID)
183	if err != nil {
184		return "", fmt.Errorf("error fetching STS config for account ID %q: %w", accountID, err)
185	}
186	// An empty STS role signifies the master account
187	if sts != nil {
188		return sts.StsRole, nil
189	}
190	return "", nil
191}
192
193// clientEC2 creates a client to interact with AWS EC2 API
194func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, accountID string) (*ec2.EC2, error) {
195	stsRole, err := b.stsRoleForAccount(ctx, s, accountID)
196	if err != nil {
197		return nil, err
198	}
199	b.configMutex.RLock()
200	if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil {
201		defer b.configMutex.RUnlock()
202		// If the client object was already created, return it
203		return b.EC2ClientsMap[region][stsRole], nil
204	}
205
206	// Release the read lock and acquire the write lock
207	b.configMutex.RUnlock()
208	b.configMutex.Lock()
209	defer b.configMutex.Unlock()
210
211	// If the client gets created while switching the locks, return it
212	if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil {
213		return b.EC2ClientsMap[region][stsRole], nil
214	}
215
216	// Create an AWS config object using a chain of providers
217	var awsConfig *aws.Config
218	awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "ec2")
219
220	if err != nil {
221		return nil, err
222	}
223
224	if awsConfig == nil {
225		return nil, fmt.Errorf("could not retrieve valid assumed credentials")
226	}
227
228	// Create a new EC2 client object, cache it and return the same
229	sess, err := session.NewSession(awsConfig)
230	if err != nil {
231		return nil, err
232	}
233	client := ec2.New(sess)
234	if client == nil {
235		return nil, fmt.Errorf("could not obtain ec2 client")
236	}
237	if _, ok := b.EC2ClientsMap[region]; !ok {
238		b.EC2ClientsMap[region] = map[string]*ec2.EC2{stsRole: client}
239	} else {
240		b.EC2ClientsMap[region][stsRole] = client
241	}
242
243	return b.EC2ClientsMap[region][stsRole], nil
244}
245
246// clientIAM creates a client to interact with AWS IAM API
247func (b *backend) clientIAM(ctx context.Context, s logical.Storage, region, accountID string) (*iam.IAM, error) {
248	stsRole, err := b.stsRoleForAccount(ctx, s, accountID)
249	if err != nil {
250		return nil, err
251	}
252	if stsRole == "" {
253		b.Logger().Debug(fmt.Sprintf("no stsRole found for %s", accountID))
254	} else {
255		b.Logger().Debug(fmt.Sprintf("found stsRole %s for account %s", stsRole, accountID))
256	}
257	b.configMutex.RLock()
258	if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil {
259		defer b.configMutex.RUnlock()
260		// If the client object was already created, return it
261		b.Logger().Debug(fmt.Sprintf("returning cached client for region %s and stsRole %s", region, stsRole))
262		return b.IAMClientsMap[region][stsRole], nil
263	}
264	b.Logger().Debug(fmt.Sprintf("no cached client for region %s and stsRole %s", region, stsRole))
265
266	// Release the read lock and acquire the write lock
267	b.configMutex.RUnlock()
268	b.configMutex.Lock()
269	defer b.configMutex.Unlock()
270
271	// If the client gets created while switching the locks, return it
272	if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil {
273		return b.IAMClientsMap[region][stsRole], nil
274	}
275
276	// Create an AWS config object using a chain of providers
277	var awsConfig *aws.Config
278	awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "iam")
279
280	if err != nil {
281		return nil, err
282	}
283
284	if awsConfig == nil {
285		return nil, fmt.Errorf("could not retrieve valid assumed credentials")
286	}
287
288	// Create a new IAM client object, cache it and return the same
289	sess, err := session.NewSession(awsConfig)
290	if err != nil {
291		return nil, err
292	}
293	client := iam.New(sess)
294	if client == nil {
295		return nil, fmt.Errorf("could not obtain iam client")
296	}
297	if _, ok := b.IAMClientsMap[region]; !ok {
298		b.IAMClientsMap[region] = map[string]*iam.IAM{stsRole: client}
299	} else {
300		b.IAMClientsMap[region][stsRole] = client
301	}
302	return b.IAMClientsMap[region][stsRole], nil
303}
304