1package awsauth 2 3import ( 4 "context" 5 "fmt" 6 7 "github.com/aws/aws-sdk-go/aws" 8 "github.com/aws/aws-sdk-go/aws/credentials/stscreds" 9 "github.com/aws/aws-sdk-go/aws/session" 10 "github.com/aws/aws-sdk-go/service/ec2" 11 "github.com/aws/aws-sdk-go/service/iam" 12 "github.com/aws/aws-sdk-go/service/sts" 13 cleanhttp "github.com/hashicorp/go-cleanhttp" 14 "github.com/hashicorp/vault/sdk/helper/awsutil" 15 "github.com/hashicorp/vault/sdk/logical" 16) 17 18// getRawClientConfig creates a aws-sdk-go config, which is used to create client 19// that can interact with AWS API. This builds credentials in the following 20// order of preference: 21// 22// * Static credentials from 'config/client' 23// * Environment variables 24// * Instance metadata role 25func (b *backend) getRawClientConfig(ctx context.Context, s logical.Storage, region, clientType string) (*aws.Config, error) { 26 credsConfig := &awsutil.CredentialsConfig{ 27 Region: region, 28 Logger: b.Logger(), 29 } 30 31 // Read the configured secret key and access key 32 config, err := b.nonLockedClientConfigEntry(ctx, s) 33 if err != nil { 34 return nil, err 35 } 36 37 endpoint := aws.String("") 38 var maxRetries int = aws.UseServiceDefaultRetries 39 if config != nil { 40 // Override the defaults with configured values. 41 switch { 42 case clientType == "ec2" && config.Endpoint != "": 43 endpoint = aws.String(config.Endpoint) 44 case clientType == "iam" && config.IAMEndpoint != "": 45 endpoint = aws.String(config.IAMEndpoint) 46 case clientType == "sts": 47 if config.STSEndpoint != "" { 48 endpoint = aws.String(config.STSEndpoint) 49 } 50 if config.STSRegion != "" { 51 region = config.STSRegion 52 } 53 } 54 55 credsConfig.AccessKey = config.AccessKey 56 credsConfig.SecretKey = config.SecretKey 57 maxRetries = config.MaxRetries 58 } 59 60 credsConfig.HTTPClient = cleanhttp.DefaultClient() 61 62 creds, err := credsConfig.GenerateCredentialChain() 63 if err != nil { 64 return nil, err 65 } 66 if creds == nil { 67 return nil, fmt.Errorf("could not compile valid credential providers from static config, environment, shared, or instance metadata") 68 } 69 70 // Create a config that can be used to make the API calls. 71 return &aws.Config{ 72 Credentials: creds, 73 Region: aws.String(region), 74 HTTPClient: cleanhttp.DefaultClient(), 75 Endpoint: endpoint, 76 MaxRetries: aws.Int(maxRetries), 77 }, nil 78} 79 80// getClientConfig returns an aws-sdk-go config, with optionally assumed credentials 81// It uses getRawClientConfig to obtain config for the runtime environment, and if 82// stsRole is a non-empty string, it will use AssumeRole to obtain a set of assumed 83// credentials. The credentials will expire after 15 minutes but will auto-refresh. 84func (b *backend) getClientConfig(ctx context.Context, s logical.Storage, region, stsRole, accountID, clientType string) (*aws.Config, error) { 85 config, err := b.getRawClientConfig(ctx, s, region, clientType) 86 if err != nil { 87 return nil, err 88 } 89 if config == nil { 90 return nil, fmt.Errorf("could not compile valid credentials through the default provider chain") 91 } 92 93 stsConfig, err := b.getRawClientConfig(ctx, s, region, "sts") 94 if stsConfig == nil { 95 return nil, fmt.Errorf("could not configure STS client") 96 } 97 if err != nil { 98 return nil, err 99 } 100 if stsRole != "" { 101 sess, err := session.NewSession(stsConfig) 102 if err != nil { 103 return nil, err 104 } 105 assumedCredentials := stscreds.NewCredentials(sess, stsRole) 106 // Test that we actually have permissions to assume the role 107 if _, err = assumedCredentials.Get(); err != nil { 108 return nil, err 109 } 110 config.Credentials = assumedCredentials 111 } else { 112 if b.defaultAWSAccountID == "" { 113 sess, err := session.NewSession(stsConfig) 114 if err != nil { 115 return nil, err 116 } 117 client := sts.New(sess) 118 if client == nil { 119 return nil, fmt.Errorf("could not obtain sts client: %w", err) 120 } 121 inputParams := &sts.GetCallerIdentityInput{} 122 identity, err := client.GetCallerIdentity(inputParams) 123 if err != nil { 124 return nil, fmt.Errorf("unable to fetch current caller: %w", err) 125 } 126 if identity == nil { 127 return nil, fmt.Errorf("got nil result from GetCallerIdentity") 128 } 129 b.defaultAWSAccountID = *identity.Account 130 } 131 if b.defaultAWSAccountID != accountID { 132 return nil, fmt.Errorf("unable to fetch client for account ID %q -- default client is for account %q", accountID, b.defaultAWSAccountID) 133 } 134 } 135 136 return config, nil 137} 138 139// flushCachedEC2Clients deletes all the cached ec2 client objects from the backend. 140// If the client credentials configuration is deleted or updated in the backend, all 141// the cached EC2 client objects will be flushed. Config mutex lock should be 142// acquired for write operation before calling this method. 143func (b *backend) flushCachedEC2Clients() { 144 // deleting items in map during iteration is safe 145 for region := range b.EC2ClientsMap { 146 delete(b.EC2ClientsMap, region) 147 } 148} 149 150// flushCachedIAMClients deletes all the cached iam client objects from the 151// backend. If the client credentials configuration is deleted or updated in 152// the backend, all the cached IAM client objects will be flushed. Config mutex 153// lock should be acquired for write operation before calling this method. 154func (b *backend) flushCachedIAMClients() { 155 // deleting items in map during iteration is safe 156 for region := range b.IAMClientsMap { 157 delete(b.IAMClientsMap, region) 158 } 159} 160 161// Gets an entry out of the user ID cache 162func (b *backend) getCachedUserId(userId string) string { 163 if userId == "" { 164 return "" 165 } 166 if entry, ok := b.iamUserIdToArnCache.Get(userId); ok { 167 b.iamUserIdToArnCache.SetDefault(userId, entry) 168 return entry.(string) 169 } 170 return "" 171} 172 173// Sets an entry in the user ID cache 174func (b *backend) setCachedUserId(userId, arn string) { 175 if userId != "" { 176 b.iamUserIdToArnCache.SetDefault(userId, arn) 177 } 178} 179 180func (b *backend) stsRoleForAccount(ctx context.Context, s logical.Storage, accountID string) (string, error) { 181 // Check if an STS configuration exists for the AWS account 182 sts, err := b.lockedAwsStsEntry(ctx, s, accountID) 183 if err != nil { 184 return "", fmt.Errorf("error fetching STS config for account ID %q: %w", accountID, err) 185 } 186 // An empty STS role signifies the master account 187 if sts != nil { 188 return sts.StsRole, nil 189 } 190 return "", nil 191} 192 193// clientEC2 creates a client to interact with AWS EC2 API 194func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, accountID string) (*ec2.EC2, error) { 195 stsRole, err := b.stsRoleForAccount(ctx, s, accountID) 196 if err != nil { 197 return nil, err 198 } 199 b.configMutex.RLock() 200 if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil { 201 defer b.configMutex.RUnlock() 202 // If the client object was already created, return it 203 return b.EC2ClientsMap[region][stsRole], nil 204 } 205 206 // Release the read lock and acquire the write lock 207 b.configMutex.RUnlock() 208 b.configMutex.Lock() 209 defer b.configMutex.Unlock() 210 211 // If the client gets created while switching the locks, return it 212 if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil { 213 return b.EC2ClientsMap[region][stsRole], nil 214 } 215 216 // Create an AWS config object using a chain of providers 217 var awsConfig *aws.Config 218 awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "ec2") 219 220 if err != nil { 221 return nil, err 222 } 223 224 if awsConfig == nil { 225 return nil, fmt.Errorf("could not retrieve valid assumed credentials") 226 } 227 228 // Create a new EC2 client object, cache it and return the same 229 sess, err := session.NewSession(awsConfig) 230 if err != nil { 231 return nil, err 232 } 233 client := ec2.New(sess) 234 if client == nil { 235 return nil, fmt.Errorf("could not obtain ec2 client") 236 } 237 if _, ok := b.EC2ClientsMap[region]; !ok { 238 b.EC2ClientsMap[region] = map[string]*ec2.EC2{stsRole: client} 239 } else { 240 b.EC2ClientsMap[region][stsRole] = client 241 } 242 243 return b.EC2ClientsMap[region][stsRole], nil 244} 245 246// clientIAM creates a client to interact with AWS IAM API 247func (b *backend) clientIAM(ctx context.Context, s logical.Storage, region, accountID string) (*iam.IAM, error) { 248 stsRole, err := b.stsRoleForAccount(ctx, s, accountID) 249 if err != nil { 250 return nil, err 251 } 252 if stsRole == "" { 253 b.Logger().Debug(fmt.Sprintf("no stsRole found for %s", accountID)) 254 } else { 255 b.Logger().Debug(fmt.Sprintf("found stsRole %s for account %s", stsRole, accountID)) 256 } 257 b.configMutex.RLock() 258 if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil { 259 defer b.configMutex.RUnlock() 260 // If the client object was already created, return it 261 b.Logger().Debug(fmt.Sprintf("returning cached client for region %s and stsRole %s", region, stsRole)) 262 return b.IAMClientsMap[region][stsRole], nil 263 } 264 b.Logger().Debug(fmt.Sprintf("no cached client for region %s and stsRole %s", region, stsRole)) 265 266 // Release the read lock and acquire the write lock 267 b.configMutex.RUnlock() 268 b.configMutex.Lock() 269 defer b.configMutex.Unlock() 270 271 // If the client gets created while switching the locks, return it 272 if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil { 273 return b.IAMClientsMap[region][stsRole], nil 274 } 275 276 // Create an AWS config object using a chain of providers 277 var awsConfig *aws.Config 278 awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "iam") 279 280 if err != nil { 281 return nil, err 282 } 283 284 if awsConfig == nil { 285 return nil, fmt.Errorf("could not retrieve valid assumed credentials") 286 } 287 288 // Create a new IAM client object, cache it and return the same 289 sess, err := session.NewSession(awsConfig) 290 if err != nil { 291 return nil, err 292 } 293 client := iam.New(sess) 294 if client == nil { 295 return nil, fmt.Errorf("could not obtain iam client") 296 } 297 if _, ok := b.IAMClientsMap[region]; !ok { 298 b.IAMClientsMap[region] = map[string]*iam.IAM{stsRole: client} 299 } else { 300 b.IAMClientsMap[region][stsRole] = client 301 } 302 return b.IAMClientsMap[region][stsRole], nil 303} 304