1package awsauth 2 3import ( 4 "context" 5 "fmt" 6 7 "github.com/aws/aws-sdk-go/aws" 8 "github.com/aws/aws-sdk-go/aws/credentials/stscreds" 9 "github.com/aws/aws-sdk-go/aws/session" 10 "github.com/aws/aws-sdk-go/service/ec2" 11 "github.com/aws/aws-sdk-go/service/iam" 12 "github.com/aws/aws-sdk-go/service/sts" 13 "github.com/hashicorp/errwrap" 14 cleanhttp "github.com/hashicorp/go-cleanhttp" 15 "github.com/hashicorp/vault/helper/awsutil" 16 "github.com/hashicorp/vault/sdk/logical" 17) 18 19// getRawClientConfig creates a aws-sdk-go config, which is used to create client 20// that can interact with AWS API. This builds credentials in the following 21// order of preference: 22// 23// * Static credentials from 'config/client' 24// * Environment variables 25// * Instance metadata role 26func (b *backend) getRawClientConfig(ctx context.Context, s logical.Storage, region, clientType string) (*aws.Config, error) { 27 credsConfig := &awsutil.CredentialsConfig{ 28 Region: region, 29 } 30 31 // Read the configured secret key and access key 32 config, err := b.nonLockedClientConfigEntry(ctx, s) 33 if err != nil { 34 return nil, err 35 } 36 37 endpoint := aws.String("") 38 var maxRetries int = aws.UseServiceDefaultRetries 39 if config != nil { 40 // Override the default endpoint with the configured endpoint. 41 switch { 42 case clientType == "ec2" && config.Endpoint != "": 43 endpoint = aws.String(config.Endpoint) 44 case clientType == "iam" && config.IAMEndpoint != "": 45 endpoint = aws.String(config.IAMEndpoint) 46 case clientType == "sts" && config.STSEndpoint != "": 47 endpoint = aws.String(config.STSEndpoint) 48 } 49 50 credsConfig.AccessKey = config.AccessKey 51 credsConfig.SecretKey = config.SecretKey 52 maxRetries = config.MaxRetries 53 } 54 55 credsConfig.HTTPClient = cleanhttp.DefaultClient() 56 57 creds, err := credsConfig.GenerateCredentialChain() 58 if err != nil { 59 return nil, err 60 } 61 if creds == nil { 62 return nil, fmt.Errorf("could not compile valid credential providers from static config, environment, shared, or instance metadata") 63 } 64 65 // Create a config that can be used to make the API calls. 66 return &aws.Config{ 67 Credentials: creds, 68 Region: aws.String(region), 69 HTTPClient: cleanhttp.DefaultClient(), 70 Endpoint: endpoint, 71 MaxRetries: aws.Int(maxRetries), 72 }, nil 73} 74 75// getClientConfig returns an aws-sdk-go config, with optionally assumed credentials 76// It uses getRawClientConfig to obtain config for the runtime environment, and if 77// stsRole is a non-empty string, it will use AssumeRole to obtain a set of assumed 78// credentials. The credentials will expire after 15 minutes but will auto-refresh. 79func (b *backend) getClientConfig(ctx context.Context, s logical.Storage, region, stsRole, accountID, clientType string) (*aws.Config, error) { 80 81 config, err := b.getRawClientConfig(ctx, s, region, clientType) 82 if err != nil { 83 return nil, err 84 } 85 if config == nil { 86 return nil, fmt.Errorf("could not compile valid credentials through the default provider chain") 87 } 88 89 stsConfig, err := b.getRawClientConfig(ctx, s, region, "sts") 90 if stsConfig == nil { 91 return nil, fmt.Errorf("could not configure STS client") 92 } 93 if err != nil { 94 return nil, err 95 } 96 if stsRole != "" { 97 assumedCredentials := stscreds.NewCredentials(session.New(stsConfig), stsRole) 98 // Test that we actually have permissions to assume the role 99 if _, err = assumedCredentials.Get(); err != nil { 100 return nil, err 101 } 102 config.Credentials = assumedCredentials 103 } else { 104 if b.defaultAWSAccountID == "" { 105 client := sts.New(session.New(stsConfig)) 106 if client == nil { 107 return nil, errwrap.Wrapf("could not obtain sts client: {{err}}", err) 108 } 109 inputParams := &sts.GetCallerIdentityInput{} 110 identity, err := client.GetCallerIdentity(inputParams) 111 if err != nil { 112 return nil, errwrap.Wrapf("unable to fetch current caller: {{err}}", err) 113 } 114 if identity == nil { 115 return nil, fmt.Errorf("got nil result from GetCallerIdentity") 116 } 117 b.defaultAWSAccountID = *identity.Account 118 } 119 if b.defaultAWSAccountID != accountID { 120 return nil, fmt.Errorf("unable to fetch client for account ID %q -- default client is for account %q", accountID, b.defaultAWSAccountID) 121 } 122 } 123 124 return config, nil 125} 126 127// flushCachedEC2Clients deletes all the cached ec2 client objects from the backend. 128// If the client credentials configuration is deleted or updated in the backend, all 129// the cached EC2 client objects will be flushed. Config mutex lock should be 130// acquired for write operation before calling this method. 131func (b *backend) flushCachedEC2Clients() { 132 // deleting items in map during iteration is safe 133 for region, _ := range b.EC2ClientsMap { 134 delete(b.EC2ClientsMap, region) 135 } 136} 137 138// flushCachedIAMClients deletes all the cached iam client objects from the 139// backend. If the client credentials configuration is deleted or updated in 140// the backend, all the cached IAM client objects will be flushed. Config mutex 141// lock should be acquired for write operation before calling this method. 142func (b *backend) flushCachedIAMClients() { 143 // deleting items in map during iteration is safe 144 for region, _ := range b.IAMClientsMap { 145 delete(b.IAMClientsMap, region) 146 } 147} 148 149// Gets an entry out of the user ID cache 150func (b *backend) getCachedUserId(userId string) string { 151 if userId == "" { 152 return "" 153 } 154 if entry, ok := b.iamUserIdToArnCache.Get(userId); ok { 155 b.iamUserIdToArnCache.SetDefault(userId, entry) 156 return entry.(string) 157 } 158 return "" 159} 160 161// Sets an entry in the user ID cache 162func (b *backend) setCachedUserId(userId, arn string) { 163 if userId != "" { 164 b.iamUserIdToArnCache.SetDefault(userId, arn) 165 } 166} 167 168func (b *backend) stsRoleForAccount(ctx context.Context, s logical.Storage, accountID string) (string, error) { 169 // Check if an STS configuration exists for the AWS account 170 sts, err := b.lockedAwsStsEntry(ctx, s, accountID) 171 if err != nil { 172 return "", errwrap.Wrapf(fmt.Sprintf("error fetching STS config for account ID %q: {{err}}", accountID), err) 173 } 174 // An empty STS role signifies the master account 175 if sts != nil { 176 return sts.StsRole, nil 177 } 178 return "", nil 179} 180 181// clientEC2 creates a client to interact with AWS EC2 API 182func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, accountID string) (*ec2.EC2, error) { 183 stsRole, err := b.stsRoleForAccount(ctx, s, accountID) 184 if err != nil { 185 return nil, err 186 } 187 b.configMutex.RLock() 188 if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil { 189 defer b.configMutex.RUnlock() 190 // If the client object was already created, return it 191 return b.EC2ClientsMap[region][stsRole], nil 192 } 193 194 // Release the read lock and acquire the write lock 195 b.configMutex.RUnlock() 196 b.configMutex.Lock() 197 defer b.configMutex.Unlock() 198 199 // If the client gets created while switching the locks, return it 200 if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil { 201 return b.EC2ClientsMap[region][stsRole], nil 202 } 203 204 // Create an AWS config object using a chain of providers 205 var awsConfig *aws.Config 206 awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "ec2") 207 208 if err != nil { 209 return nil, err 210 } 211 212 if awsConfig == nil { 213 return nil, fmt.Errorf("could not retrieve valid assumed credentials") 214 } 215 216 // Create a new EC2 client object, cache it and return the same 217 client := ec2.New(session.New(awsConfig)) 218 if client == nil { 219 return nil, fmt.Errorf("could not obtain ec2 client") 220 } 221 if _, ok := b.EC2ClientsMap[region]; !ok { 222 b.EC2ClientsMap[region] = map[string]*ec2.EC2{stsRole: client} 223 } else { 224 b.EC2ClientsMap[region][stsRole] = client 225 } 226 227 return b.EC2ClientsMap[region][stsRole], nil 228} 229 230// clientIAM creates a client to interact with AWS IAM API 231func (b *backend) clientIAM(ctx context.Context, s logical.Storage, region, accountID string) (*iam.IAM, error) { 232 stsRole, err := b.stsRoleForAccount(ctx, s, accountID) 233 if err != nil { 234 return nil, err 235 } 236 b.configMutex.RLock() 237 if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil { 238 defer b.configMutex.RUnlock() 239 // If the client object was already created, return it 240 return b.IAMClientsMap[region][stsRole], nil 241 } 242 243 // Release the read lock and acquire the write lock 244 b.configMutex.RUnlock() 245 b.configMutex.Lock() 246 defer b.configMutex.Unlock() 247 248 // If the client gets created while switching the locks, return it 249 if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil { 250 return b.IAMClientsMap[region][stsRole], nil 251 } 252 253 // Create an AWS config object using a chain of providers 254 var awsConfig *aws.Config 255 awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "iam") 256 257 if err != nil { 258 return nil, err 259 } 260 261 if awsConfig == nil { 262 return nil, fmt.Errorf("could not retrieve valid assumed credentials") 263 } 264 265 // Create a new IAM client object, cache it and return the same 266 client := iam.New(session.New(awsConfig)) 267 if client == nil { 268 return nil, fmt.Errorf("could not obtain iam client") 269 } 270 if _, ok := b.IAMClientsMap[region]; !ok { 271 b.IAMClientsMap[region] = map[string]*iam.IAM{stsRole: client} 272 } else { 273 b.IAMClientsMap[region][stsRole] = client 274 } 275 return b.IAMClientsMap[region][stsRole], nil 276} 277