1package awsauth
2
3import (
4	"context"
5	"fmt"
6
7	"github.com/aws/aws-sdk-go/aws"
8	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
9	"github.com/aws/aws-sdk-go/aws/session"
10	"github.com/aws/aws-sdk-go/service/ec2"
11	"github.com/aws/aws-sdk-go/service/iam"
12	"github.com/aws/aws-sdk-go/service/sts"
13	"github.com/hashicorp/errwrap"
14	cleanhttp "github.com/hashicorp/go-cleanhttp"
15	"github.com/hashicorp/vault/helper/awsutil"
16	"github.com/hashicorp/vault/sdk/logical"
17)
18
19// getRawClientConfig creates a aws-sdk-go config, which is used to create client
20// that can interact with AWS API. This builds credentials in the following
21// order of preference:
22//
23// * Static credentials from 'config/client'
24// * Environment variables
25// * Instance metadata role
26func (b *backend) getRawClientConfig(ctx context.Context, s logical.Storage, region, clientType string) (*aws.Config, error) {
27	credsConfig := &awsutil.CredentialsConfig{
28		Region: region,
29	}
30
31	// Read the configured secret key and access key
32	config, err := b.nonLockedClientConfigEntry(ctx, s)
33	if err != nil {
34		return nil, err
35	}
36
37	endpoint := aws.String("")
38	var maxRetries int = aws.UseServiceDefaultRetries
39	if config != nil {
40		// Override the default endpoint with the configured endpoint.
41		switch {
42		case clientType == "ec2" && config.Endpoint != "":
43			endpoint = aws.String(config.Endpoint)
44		case clientType == "iam" && config.IAMEndpoint != "":
45			endpoint = aws.String(config.IAMEndpoint)
46		case clientType == "sts" && config.STSEndpoint != "":
47			endpoint = aws.String(config.STSEndpoint)
48		}
49
50		credsConfig.AccessKey = config.AccessKey
51		credsConfig.SecretKey = config.SecretKey
52		maxRetries = config.MaxRetries
53	}
54
55	credsConfig.HTTPClient = cleanhttp.DefaultClient()
56
57	creds, err := credsConfig.GenerateCredentialChain()
58	if err != nil {
59		return nil, err
60	}
61	if creds == nil {
62		return nil, fmt.Errorf("could not compile valid credential providers from static config, environment, shared, or instance metadata")
63	}
64
65	// Create a config that can be used to make the API calls.
66	return &aws.Config{
67		Credentials: creds,
68		Region:      aws.String(region),
69		HTTPClient:  cleanhttp.DefaultClient(),
70		Endpoint:    endpoint,
71		MaxRetries:  aws.Int(maxRetries),
72	}, nil
73}
74
75// getClientConfig returns an aws-sdk-go config, with optionally assumed credentials
76// It uses getRawClientConfig to obtain config for the runtime environment, and if
77// stsRole is a non-empty string, it will use AssumeRole to obtain a set of assumed
78// credentials. The credentials will expire after 15 minutes but will auto-refresh.
79func (b *backend) getClientConfig(ctx context.Context, s logical.Storage, region, stsRole, accountID, clientType string) (*aws.Config, error) {
80
81	config, err := b.getRawClientConfig(ctx, s, region, clientType)
82	if err != nil {
83		return nil, err
84	}
85	if config == nil {
86		return nil, fmt.Errorf("could not compile valid credentials through the default provider chain")
87	}
88
89	stsConfig, err := b.getRawClientConfig(ctx, s, region, "sts")
90	if stsConfig == nil {
91		return nil, fmt.Errorf("could not configure STS client")
92	}
93	if err != nil {
94		return nil, err
95	}
96	if stsRole != "" {
97		assumedCredentials := stscreds.NewCredentials(session.New(stsConfig), stsRole)
98		// Test that we actually have permissions to assume the role
99		if _, err = assumedCredentials.Get(); err != nil {
100			return nil, err
101		}
102		config.Credentials = assumedCredentials
103	} else {
104		if b.defaultAWSAccountID == "" {
105			client := sts.New(session.New(stsConfig))
106			if client == nil {
107				return nil, errwrap.Wrapf("could not obtain sts client: {{err}}", err)
108			}
109			inputParams := &sts.GetCallerIdentityInput{}
110			identity, err := client.GetCallerIdentity(inputParams)
111			if err != nil {
112				return nil, errwrap.Wrapf("unable to fetch current caller: {{err}}", err)
113			}
114			if identity == nil {
115				return nil, fmt.Errorf("got nil result from GetCallerIdentity")
116			}
117			b.defaultAWSAccountID = *identity.Account
118		}
119		if b.defaultAWSAccountID != accountID {
120			return nil, fmt.Errorf("unable to fetch client for account ID %q -- default client is for account %q", accountID, b.defaultAWSAccountID)
121		}
122	}
123
124	return config, nil
125}
126
127// flushCachedEC2Clients deletes all the cached ec2 client objects from the backend.
128// If the client credentials configuration is deleted or updated in the backend, all
129// the cached EC2 client objects will be flushed. Config mutex lock should be
130// acquired for write operation before calling this method.
131func (b *backend) flushCachedEC2Clients() {
132	// deleting items in map during iteration is safe
133	for region, _ := range b.EC2ClientsMap {
134		delete(b.EC2ClientsMap, region)
135	}
136}
137
138// flushCachedIAMClients deletes all the cached iam client objects from the
139// backend. If the client credentials configuration is deleted or updated in
140// the backend, all the cached IAM client objects will be flushed. Config mutex
141// lock should be acquired for write operation before calling this method.
142func (b *backend) flushCachedIAMClients() {
143	// deleting items in map during iteration is safe
144	for region, _ := range b.IAMClientsMap {
145		delete(b.IAMClientsMap, region)
146	}
147}
148
149// Gets an entry out of the user ID cache
150func (b *backend) getCachedUserId(userId string) string {
151	if userId == "" {
152		return ""
153	}
154	if entry, ok := b.iamUserIdToArnCache.Get(userId); ok {
155		b.iamUserIdToArnCache.SetDefault(userId, entry)
156		return entry.(string)
157	}
158	return ""
159}
160
161// Sets an entry in the user ID cache
162func (b *backend) setCachedUserId(userId, arn string) {
163	if userId != "" {
164		b.iamUserIdToArnCache.SetDefault(userId, arn)
165	}
166}
167
168func (b *backend) stsRoleForAccount(ctx context.Context, s logical.Storage, accountID string) (string, error) {
169	// Check if an STS configuration exists for the AWS account
170	sts, err := b.lockedAwsStsEntry(ctx, s, accountID)
171	if err != nil {
172		return "", errwrap.Wrapf(fmt.Sprintf("error fetching STS config for account ID %q: {{err}}", accountID), err)
173	}
174	// An empty STS role signifies the master account
175	if sts != nil {
176		return sts.StsRole, nil
177	}
178	return "", nil
179}
180
181// clientEC2 creates a client to interact with AWS EC2 API
182func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, accountID string) (*ec2.EC2, error) {
183	stsRole, err := b.stsRoleForAccount(ctx, s, accountID)
184	if err != nil {
185		return nil, err
186	}
187	b.configMutex.RLock()
188	if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil {
189		defer b.configMutex.RUnlock()
190		// If the client object was already created, return it
191		return b.EC2ClientsMap[region][stsRole], nil
192	}
193
194	// Release the read lock and acquire the write lock
195	b.configMutex.RUnlock()
196	b.configMutex.Lock()
197	defer b.configMutex.Unlock()
198
199	// If the client gets created while switching the locks, return it
200	if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil {
201		return b.EC2ClientsMap[region][stsRole], nil
202	}
203
204	// Create an AWS config object using a chain of providers
205	var awsConfig *aws.Config
206	awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "ec2")
207
208	if err != nil {
209		return nil, err
210	}
211
212	if awsConfig == nil {
213		return nil, fmt.Errorf("could not retrieve valid assumed credentials")
214	}
215
216	// Create a new EC2 client object, cache it and return the same
217	client := ec2.New(session.New(awsConfig))
218	if client == nil {
219		return nil, fmt.Errorf("could not obtain ec2 client")
220	}
221	if _, ok := b.EC2ClientsMap[region]; !ok {
222		b.EC2ClientsMap[region] = map[string]*ec2.EC2{stsRole: client}
223	} else {
224		b.EC2ClientsMap[region][stsRole] = client
225	}
226
227	return b.EC2ClientsMap[region][stsRole], nil
228}
229
230// clientIAM creates a client to interact with AWS IAM API
231func (b *backend) clientIAM(ctx context.Context, s logical.Storage, region, accountID string) (*iam.IAM, error) {
232	stsRole, err := b.stsRoleForAccount(ctx, s, accountID)
233	if err != nil {
234		return nil, err
235	}
236	b.configMutex.RLock()
237	if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil {
238		defer b.configMutex.RUnlock()
239		// If the client object was already created, return it
240		return b.IAMClientsMap[region][stsRole], nil
241	}
242
243	// Release the read lock and acquire the write lock
244	b.configMutex.RUnlock()
245	b.configMutex.Lock()
246	defer b.configMutex.Unlock()
247
248	// If the client gets created while switching the locks, return it
249	if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil {
250		return b.IAMClientsMap[region][stsRole], nil
251	}
252
253	// Create an AWS config object using a chain of providers
254	var awsConfig *aws.Config
255	awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "iam")
256
257	if err != nil {
258		return nil, err
259	}
260
261	if awsConfig == nil {
262		return nil, fmt.Errorf("could not retrieve valid assumed credentials")
263	}
264
265	// Create a new IAM client object, cache it and return the same
266	client := iam.New(session.New(awsConfig))
267	if client == nil {
268		return nil, fmt.Errorf("could not obtain iam client")
269	}
270	if _, ok := b.IAMClientsMap[region]; !ok {
271		b.IAMClientsMap[region] = map[string]*iam.IAM{stsRole: client}
272	} else {
273		b.IAMClientsMap[region][stsRole] = client
274	}
275	return b.IAMClientsMap[region][stsRole], nil
276}
277