1package awsauth
2
3import (
4	"context"
5	"fmt"
6
7	"github.com/aws/aws-sdk-go/aws"
8	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
9	"github.com/aws/aws-sdk-go/aws/session"
10	"github.com/aws/aws-sdk-go/service/ec2"
11	"github.com/aws/aws-sdk-go/service/iam"
12	"github.com/aws/aws-sdk-go/service/sts"
13	"github.com/hashicorp/errwrap"
14	cleanhttp "github.com/hashicorp/go-cleanhttp"
15	"github.com/hashicorp/vault/sdk/helper/awsutil"
16	"github.com/hashicorp/vault/sdk/logical"
17)
18
19// getRawClientConfig creates a aws-sdk-go config, which is used to create client
20// that can interact with AWS API. This builds credentials in the following
21// order of preference:
22//
23// * Static credentials from 'config/client'
24// * Environment variables
25// * Instance metadata role
26func (b *backend) getRawClientConfig(ctx context.Context, s logical.Storage, region, clientType string) (*aws.Config, error) {
27	credsConfig := &awsutil.CredentialsConfig{
28		Region: region,
29	}
30
31	// Read the configured secret key and access key
32	config, err := b.nonLockedClientConfigEntry(ctx, s)
33	if err != nil {
34		return nil, err
35	}
36
37	endpoint := aws.String("")
38	var maxRetries int = aws.UseServiceDefaultRetries
39	if config != nil {
40		// Override the defaults with configured values.
41		switch {
42		case clientType == "ec2" && config.Endpoint != "":
43			endpoint = aws.String(config.Endpoint)
44		case clientType == "iam" && config.IAMEndpoint != "":
45			endpoint = aws.String(config.IAMEndpoint)
46		case clientType == "sts":
47			if config.STSEndpoint != "" {
48				endpoint = aws.String(config.STSEndpoint)
49			}
50			if config.STSRegion != "" {
51				region = config.STSRegion
52			}
53		}
54
55		credsConfig.AccessKey = config.AccessKey
56		credsConfig.SecretKey = config.SecretKey
57		maxRetries = config.MaxRetries
58	}
59
60	credsConfig.HTTPClient = cleanhttp.DefaultClient()
61
62	creds, err := credsConfig.GenerateCredentialChain()
63	if err != nil {
64		return nil, err
65	}
66	if creds == nil {
67		return nil, fmt.Errorf("could not compile valid credential providers from static config, environment, shared, or instance metadata")
68	}
69
70	// Create a config that can be used to make the API calls.
71	return &aws.Config{
72		Credentials: creds,
73		Region:      aws.String(region),
74		HTTPClient:  cleanhttp.DefaultClient(),
75		Endpoint:    endpoint,
76		MaxRetries:  aws.Int(maxRetries),
77	}, nil
78}
79
80// getClientConfig returns an aws-sdk-go config, with optionally assumed credentials
81// It uses getRawClientConfig to obtain config for the runtime environment, and if
82// stsRole is a non-empty string, it will use AssumeRole to obtain a set of assumed
83// credentials. The credentials will expire after 15 minutes but will auto-refresh.
84func (b *backend) getClientConfig(ctx context.Context, s logical.Storage, region, stsRole, accountID, clientType string) (*aws.Config, error) {
85
86	config, err := b.getRawClientConfig(ctx, s, region, clientType)
87	if err != nil {
88		return nil, err
89	}
90	if config == nil {
91		return nil, fmt.Errorf("could not compile valid credentials through the default provider chain")
92	}
93
94	stsConfig, err := b.getRawClientConfig(ctx, s, region, "sts")
95	if stsConfig == nil {
96		return nil, fmt.Errorf("could not configure STS client")
97	}
98	if err != nil {
99		return nil, err
100	}
101	if stsRole != "" {
102		sess, err := session.NewSession(stsConfig)
103		if err != nil {
104			return nil, err
105		}
106		assumedCredentials := stscreds.NewCredentials(sess, stsRole)
107		// Test that we actually have permissions to assume the role
108		if _, err = assumedCredentials.Get(); err != nil {
109			return nil, err
110		}
111		config.Credentials = assumedCredentials
112	} else {
113		if b.defaultAWSAccountID == "" {
114			sess, err := session.NewSession(stsConfig)
115			if err != nil {
116				return nil, err
117			}
118			client := sts.New(sess)
119			if client == nil {
120				return nil, errwrap.Wrapf("could not obtain sts client: {{err}}", err)
121			}
122			inputParams := &sts.GetCallerIdentityInput{}
123			identity, err := client.GetCallerIdentity(inputParams)
124			if err != nil {
125				return nil, errwrap.Wrapf("unable to fetch current caller: {{err}}", err)
126			}
127			if identity == nil {
128				return nil, fmt.Errorf("got nil result from GetCallerIdentity")
129			}
130			b.defaultAWSAccountID = *identity.Account
131		}
132		if b.defaultAWSAccountID != accountID {
133			return nil, fmt.Errorf("unable to fetch client for account ID %q -- default client is for account %q", accountID, b.defaultAWSAccountID)
134		}
135	}
136
137	return config, nil
138}
139
140// flushCachedEC2Clients deletes all the cached ec2 client objects from the backend.
141// If the client credentials configuration is deleted or updated in the backend, all
142// the cached EC2 client objects will be flushed. Config mutex lock should be
143// acquired for write operation before calling this method.
144func (b *backend) flushCachedEC2Clients() {
145	// deleting items in map during iteration is safe
146	for region, _ := range b.EC2ClientsMap {
147		delete(b.EC2ClientsMap, region)
148	}
149}
150
151// flushCachedIAMClients deletes all the cached iam client objects from the
152// backend. If the client credentials configuration is deleted or updated in
153// the backend, all the cached IAM client objects will be flushed. Config mutex
154// lock should be acquired for write operation before calling this method.
155func (b *backend) flushCachedIAMClients() {
156	// deleting items in map during iteration is safe
157	for region, _ := range b.IAMClientsMap {
158		delete(b.IAMClientsMap, region)
159	}
160}
161
162// Gets an entry out of the user ID cache
163func (b *backend) getCachedUserId(userId string) string {
164	if userId == "" {
165		return ""
166	}
167	if entry, ok := b.iamUserIdToArnCache.Get(userId); ok {
168		b.iamUserIdToArnCache.SetDefault(userId, entry)
169		return entry.(string)
170	}
171	return ""
172}
173
174// Sets an entry in the user ID cache
175func (b *backend) setCachedUserId(userId, arn string) {
176	if userId != "" {
177		b.iamUserIdToArnCache.SetDefault(userId, arn)
178	}
179}
180
181func (b *backend) stsRoleForAccount(ctx context.Context, s logical.Storage, accountID string) (string, error) {
182	// Check if an STS configuration exists for the AWS account
183	sts, err := b.lockedAwsStsEntry(ctx, s, accountID)
184	if err != nil {
185		return "", errwrap.Wrapf(fmt.Sprintf("error fetching STS config for account ID %q: {{err}}", accountID), err)
186	}
187	// An empty STS role signifies the master account
188	if sts != nil {
189		return sts.StsRole, nil
190	}
191	return "", nil
192}
193
194// clientEC2 creates a client to interact with AWS EC2 API
195func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, accountID string) (*ec2.EC2, error) {
196	stsRole, err := b.stsRoleForAccount(ctx, s, accountID)
197	if err != nil {
198		return nil, err
199	}
200	b.configMutex.RLock()
201	if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil {
202		defer b.configMutex.RUnlock()
203		// If the client object was already created, return it
204		return b.EC2ClientsMap[region][stsRole], nil
205	}
206
207	// Release the read lock and acquire the write lock
208	b.configMutex.RUnlock()
209	b.configMutex.Lock()
210	defer b.configMutex.Unlock()
211
212	// If the client gets created while switching the locks, return it
213	if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil {
214		return b.EC2ClientsMap[region][stsRole], nil
215	}
216
217	// Create an AWS config object using a chain of providers
218	var awsConfig *aws.Config
219	awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "ec2")
220
221	if err != nil {
222		return nil, err
223	}
224
225	if awsConfig == nil {
226		return nil, fmt.Errorf("could not retrieve valid assumed credentials")
227	}
228
229	// Create a new EC2 client object, cache it and return the same
230	sess, err := session.NewSession(awsConfig)
231	if err != nil {
232		return nil, err
233	}
234	client := ec2.New(sess)
235	if client == nil {
236		return nil, fmt.Errorf("could not obtain ec2 client")
237	}
238	if _, ok := b.EC2ClientsMap[region]; !ok {
239		b.EC2ClientsMap[region] = map[string]*ec2.EC2{stsRole: client}
240	} else {
241		b.EC2ClientsMap[region][stsRole] = client
242	}
243
244	return b.EC2ClientsMap[region][stsRole], nil
245}
246
247// clientIAM creates a client to interact with AWS IAM API
248func (b *backend) clientIAM(ctx context.Context, s logical.Storage, region, accountID string) (*iam.IAM, error) {
249	stsRole, err := b.stsRoleForAccount(ctx, s, accountID)
250	if err != nil {
251		return nil, err
252	}
253	b.configMutex.RLock()
254	if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil {
255		defer b.configMutex.RUnlock()
256		// If the client object was already created, return it
257		return b.IAMClientsMap[region][stsRole], nil
258	}
259
260	// Release the read lock and acquire the write lock
261	b.configMutex.RUnlock()
262	b.configMutex.Lock()
263	defer b.configMutex.Unlock()
264
265	// If the client gets created while switching the locks, return it
266	if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil {
267		return b.IAMClientsMap[region][stsRole], nil
268	}
269
270	// Create an AWS config object using a chain of providers
271	var awsConfig *aws.Config
272	awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "iam")
273
274	if err != nil {
275		return nil, err
276	}
277
278	if awsConfig == nil {
279		return nil, fmt.Errorf("could not retrieve valid assumed credentials")
280	}
281
282	// Create a new IAM client object, cache it and return the same
283	sess, err := session.NewSession(awsConfig)
284	if err != nil {
285		return nil, err
286	}
287	client := iam.New(sess)
288	if client == nil {
289		return nil, fmt.Errorf("could not obtain iam client")
290	}
291	if _, ok := b.IAMClientsMap[region]; !ok {
292		b.IAMClientsMap[region] = map[string]*iam.IAM{stsRole: client}
293	} else {
294		b.IAMClientsMap[region][stsRole] = client
295	}
296	return b.IAMClientsMap[region][stsRole], nil
297}
298