1package awsauth 2 3import ( 4 "context" 5 "fmt" 6 7 "github.com/aws/aws-sdk-go/aws" 8 "github.com/aws/aws-sdk-go/aws/credentials/stscreds" 9 "github.com/aws/aws-sdk-go/aws/session" 10 "github.com/aws/aws-sdk-go/service/ec2" 11 "github.com/aws/aws-sdk-go/service/iam" 12 "github.com/aws/aws-sdk-go/service/sts" 13 "github.com/hashicorp/errwrap" 14 cleanhttp "github.com/hashicorp/go-cleanhttp" 15 "github.com/hashicorp/vault/sdk/helper/awsutil" 16 "github.com/hashicorp/vault/sdk/logical" 17) 18 19// getRawClientConfig creates a aws-sdk-go config, which is used to create client 20// that can interact with AWS API. This builds credentials in the following 21// order of preference: 22// 23// * Static credentials from 'config/client' 24// * Environment variables 25// * Instance metadata role 26func (b *backend) getRawClientConfig(ctx context.Context, s logical.Storage, region, clientType string) (*aws.Config, error) { 27 credsConfig := &awsutil.CredentialsConfig{ 28 Region: region, 29 } 30 31 // Read the configured secret key and access key 32 config, err := b.nonLockedClientConfigEntry(ctx, s) 33 if err != nil { 34 return nil, err 35 } 36 37 endpoint := aws.String("") 38 var maxRetries int = aws.UseServiceDefaultRetries 39 if config != nil { 40 // Override the defaults with configured values. 41 switch { 42 case clientType == "ec2" && config.Endpoint != "": 43 endpoint = aws.String(config.Endpoint) 44 case clientType == "iam" && config.IAMEndpoint != "": 45 endpoint = aws.String(config.IAMEndpoint) 46 case clientType == "sts": 47 if config.STSEndpoint != "" { 48 endpoint = aws.String(config.STSEndpoint) 49 } 50 if config.STSRegion != "" { 51 region = config.STSRegion 52 } 53 } 54 55 credsConfig.AccessKey = config.AccessKey 56 credsConfig.SecretKey = config.SecretKey 57 maxRetries = config.MaxRetries 58 } 59 60 credsConfig.HTTPClient = cleanhttp.DefaultClient() 61 62 creds, err := credsConfig.GenerateCredentialChain() 63 if err != nil { 64 return nil, err 65 } 66 if creds == nil { 67 return nil, fmt.Errorf("could not compile valid credential providers from static config, environment, shared, or instance metadata") 68 } 69 70 // Create a config that can be used to make the API calls. 71 return &aws.Config{ 72 Credentials: creds, 73 Region: aws.String(region), 74 HTTPClient: cleanhttp.DefaultClient(), 75 Endpoint: endpoint, 76 MaxRetries: aws.Int(maxRetries), 77 }, nil 78} 79 80// getClientConfig returns an aws-sdk-go config, with optionally assumed credentials 81// It uses getRawClientConfig to obtain config for the runtime environment, and if 82// stsRole is a non-empty string, it will use AssumeRole to obtain a set of assumed 83// credentials. The credentials will expire after 15 minutes but will auto-refresh. 84func (b *backend) getClientConfig(ctx context.Context, s logical.Storage, region, stsRole, accountID, clientType string) (*aws.Config, error) { 85 86 config, err := b.getRawClientConfig(ctx, s, region, clientType) 87 if err != nil { 88 return nil, err 89 } 90 if config == nil { 91 return nil, fmt.Errorf("could not compile valid credentials through the default provider chain") 92 } 93 94 stsConfig, err := b.getRawClientConfig(ctx, s, region, "sts") 95 if stsConfig == nil { 96 return nil, fmt.Errorf("could not configure STS client") 97 } 98 if err != nil { 99 return nil, err 100 } 101 if stsRole != "" { 102 sess, err := session.NewSession(stsConfig) 103 if err != nil { 104 return nil, err 105 } 106 assumedCredentials := stscreds.NewCredentials(sess, stsRole) 107 // Test that we actually have permissions to assume the role 108 if _, err = assumedCredentials.Get(); err != nil { 109 return nil, err 110 } 111 config.Credentials = assumedCredentials 112 } else { 113 if b.defaultAWSAccountID == "" { 114 sess, err := session.NewSession(stsConfig) 115 if err != nil { 116 return nil, err 117 } 118 client := sts.New(sess) 119 if client == nil { 120 return nil, errwrap.Wrapf("could not obtain sts client: {{err}}", err) 121 } 122 inputParams := &sts.GetCallerIdentityInput{} 123 identity, err := client.GetCallerIdentity(inputParams) 124 if err != nil { 125 return nil, errwrap.Wrapf("unable to fetch current caller: {{err}}", err) 126 } 127 if identity == nil { 128 return nil, fmt.Errorf("got nil result from GetCallerIdentity") 129 } 130 b.defaultAWSAccountID = *identity.Account 131 } 132 if b.defaultAWSAccountID != accountID { 133 return nil, fmt.Errorf("unable to fetch client for account ID %q -- default client is for account %q", accountID, b.defaultAWSAccountID) 134 } 135 } 136 137 return config, nil 138} 139 140// flushCachedEC2Clients deletes all the cached ec2 client objects from the backend. 141// If the client credentials configuration is deleted or updated in the backend, all 142// the cached EC2 client objects will be flushed. Config mutex lock should be 143// acquired for write operation before calling this method. 144func (b *backend) flushCachedEC2Clients() { 145 // deleting items in map during iteration is safe 146 for region, _ := range b.EC2ClientsMap { 147 delete(b.EC2ClientsMap, region) 148 } 149} 150 151// flushCachedIAMClients deletes all the cached iam client objects from the 152// backend. If the client credentials configuration is deleted or updated in 153// the backend, all the cached IAM client objects will be flushed. Config mutex 154// lock should be acquired for write operation before calling this method. 155func (b *backend) flushCachedIAMClients() { 156 // deleting items in map during iteration is safe 157 for region, _ := range b.IAMClientsMap { 158 delete(b.IAMClientsMap, region) 159 } 160} 161 162// Gets an entry out of the user ID cache 163func (b *backend) getCachedUserId(userId string) string { 164 if userId == "" { 165 return "" 166 } 167 if entry, ok := b.iamUserIdToArnCache.Get(userId); ok { 168 b.iamUserIdToArnCache.SetDefault(userId, entry) 169 return entry.(string) 170 } 171 return "" 172} 173 174// Sets an entry in the user ID cache 175func (b *backend) setCachedUserId(userId, arn string) { 176 if userId != "" { 177 b.iamUserIdToArnCache.SetDefault(userId, arn) 178 } 179} 180 181func (b *backend) stsRoleForAccount(ctx context.Context, s logical.Storage, accountID string) (string, error) { 182 // Check if an STS configuration exists for the AWS account 183 sts, err := b.lockedAwsStsEntry(ctx, s, accountID) 184 if err != nil { 185 return "", errwrap.Wrapf(fmt.Sprintf("error fetching STS config for account ID %q: {{err}}", accountID), err) 186 } 187 // An empty STS role signifies the master account 188 if sts != nil { 189 return sts.StsRole, nil 190 } 191 return "", nil 192} 193 194// clientEC2 creates a client to interact with AWS EC2 API 195func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, accountID string) (*ec2.EC2, error) { 196 stsRole, err := b.stsRoleForAccount(ctx, s, accountID) 197 if err != nil { 198 return nil, err 199 } 200 b.configMutex.RLock() 201 if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil { 202 defer b.configMutex.RUnlock() 203 // If the client object was already created, return it 204 return b.EC2ClientsMap[region][stsRole], nil 205 } 206 207 // Release the read lock and acquire the write lock 208 b.configMutex.RUnlock() 209 b.configMutex.Lock() 210 defer b.configMutex.Unlock() 211 212 // If the client gets created while switching the locks, return it 213 if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil { 214 return b.EC2ClientsMap[region][stsRole], nil 215 } 216 217 // Create an AWS config object using a chain of providers 218 var awsConfig *aws.Config 219 awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "ec2") 220 221 if err != nil { 222 return nil, err 223 } 224 225 if awsConfig == nil { 226 return nil, fmt.Errorf("could not retrieve valid assumed credentials") 227 } 228 229 // Create a new EC2 client object, cache it and return the same 230 sess, err := session.NewSession(awsConfig) 231 if err != nil { 232 return nil, err 233 } 234 client := ec2.New(sess) 235 if client == nil { 236 return nil, fmt.Errorf("could not obtain ec2 client") 237 } 238 if _, ok := b.EC2ClientsMap[region]; !ok { 239 b.EC2ClientsMap[region] = map[string]*ec2.EC2{stsRole: client} 240 } else { 241 b.EC2ClientsMap[region][stsRole] = client 242 } 243 244 return b.EC2ClientsMap[region][stsRole], nil 245} 246 247// clientIAM creates a client to interact with AWS IAM API 248func (b *backend) clientIAM(ctx context.Context, s logical.Storage, region, accountID string) (*iam.IAM, error) { 249 stsRole, err := b.stsRoleForAccount(ctx, s, accountID) 250 if err != nil { 251 return nil, err 252 } 253 b.configMutex.RLock() 254 if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil { 255 defer b.configMutex.RUnlock() 256 // If the client object was already created, return it 257 return b.IAMClientsMap[region][stsRole], nil 258 } 259 260 // Release the read lock and acquire the write lock 261 b.configMutex.RUnlock() 262 b.configMutex.Lock() 263 defer b.configMutex.Unlock() 264 265 // If the client gets created while switching the locks, return it 266 if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil { 267 return b.IAMClientsMap[region][stsRole], nil 268 } 269 270 // Create an AWS config object using a chain of providers 271 var awsConfig *aws.Config 272 awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "iam") 273 274 if err != nil { 275 return nil, err 276 } 277 278 if awsConfig == nil { 279 return nil, fmt.Errorf("could not retrieve valid assumed credentials") 280 } 281 282 // Create a new IAM client object, cache it and return the same 283 sess, err := session.NewSession(awsConfig) 284 if err != nil { 285 return nil, err 286 } 287 client := iam.New(sess) 288 if client == nil { 289 return nil, fmt.Errorf("could not obtain iam client") 290 } 291 if _, ok := b.IAMClientsMap[region]; !ok { 292 b.IAMClientsMap[region] = map[string]*iam.IAM{stsRole: client} 293 } else { 294 b.IAMClientsMap[region][stsRole] = client 295 } 296 return b.IAMClientsMap[region][stsRole], nil 297} 298