1// Package endpointcreds provides support for retrieving credentials from an
2// arbitrary HTTP endpoint.
3//
4// The credentials endpoint Provider can receive both static and refreshable
5// credentials that will expire. Credentials are static when an "Expiration"
6// value is not provided in the endpoint's response.
7//
8// Static credentials will never expire once they have been retrieved. The format
9// of the static credentials response:
//    {
//        "AccessKeyId" : "MUA...",
//        "SecretAccessKey" : "/7PC5om...."
//    }
14//
15// Refreshable credentials will expire within the "ExpiryWindow" of the Expiration
16// value in the response. The format of the refreshable credentials response:
17//    {
18//        "AccessKeyId" : "MUA...",
19//        "SecretAccessKey" : "/7PC5om....",
20//        "Token" : "AQoDY....=",
21//        "Expiration" : "2016-02-25T06:03:31Z"
22//    }
23//
24// Errors should be returned in the following format and only returned with 400
25// or 500 HTTP status codes.
26//    {
27//        "code": "ErrorCode",
28//        "message": "Helpful error message."
29//    }
30package endpointcreds
31
32import (
33	"encoding/json"
34	"time"
35
36	"github.com/aws/aws-sdk-go/aws"
37	"github.com/aws/aws-sdk-go/aws/awserr"
38	"github.com/aws/aws-sdk-go/aws/client"
39	"github.com/aws/aws-sdk-go/aws/client/metadata"
40	"github.com/aws/aws-sdk-go/aws/credentials"
41	"github.com/aws/aws-sdk-go/aws/request"
42)
43
// ProviderName is the value this package reports in the ProviderName field of
// every credentials.Value returned by Provider.Retrieve, on success and on
// failure alike.
const ProviderName = "CredentialsEndpointProvider"
46
// Provider satisfies the credentials.Provider interface, and is a client to
// retrieve credentials from an arbitrary endpoint.
type Provider struct {
	// staticCreds is set by Retrieve when the endpoint response carries no
	// Expiration; while it is true, IsExpired reports false.
	staticCreds bool
	// Expiry holds the expiration of refreshable credentials. Retrieve sets it
	// via SetExpiration, and IsExpired consults it for non-static credentials.
	credentials.Expiry

	// Client is the AWS client used to make HTTP requests to the endpoint.
	// The endpoint the request is made to is the client's ClientInfo.Endpoint,
	// which NewProviderClient populates from its endpoint argument.
	Client *client.Client

	// ExpiryWindow allows the credentials to trigger refreshing prior to
	// the credentials actually expiring. This is beneficial so that race
	// conditions with expiring credentials do not cause requests to fail
	// unexpectedly due to ExpiredTokenException exceptions.
	//
	// So an ExpiryWindow of 10s would cause calls to IsExpired() to return true
	// 10 seconds before the credentials are actually expired. Retrieve passes
	// this value to SetExpiration.
	//
	// If ExpiryWindow is 0 or less it will be ignored.
	ExpiryWindow time.Duration
}
69
70// NewProviderClient returns a credentials Provider for retrieving AWS credentials
71// from arbitrary endpoint.
72func NewProviderClient(cfg aws.Config, handlers request.Handlers, endpoint string, options ...func(*Provider)) credentials.Provider {
73	p := &Provider{
74		Client: client.New(
75			cfg,
76			metadata.ClientInfo{
77				ServiceName: "CredentialsEndpoint",
78				Endpoint:    endpoint,
79			},
80			handlers,
81		),
82	}
83
84	p.Client.Handlers.Unmarshal.PushBack(unmarshalHandler)
85	p.Client.Handlers.UnmarshalError.PushBack(unmarshalError)
86	p.Client.Handlers.Validate.Clear()
87	p.Client.Handlers.Validate.PushBack(validateEndpointHandler)
88
89	for _, option := range options {
90		option(p)
91	}
92
93	return p
94}
95
96// NewCredentialsClient returns a Credentials wrapper for retrieving credentials
97// from an arbitrary endpoint concurrently. The client will request the
98func NewCredentialsClient(cfg aws.Config, handlers request.Handlers, endpoint string, options ...func(*Provider)) *credentials.Credentials {
99	return credentials.NewCredentials(NewProviderClient(cfg, handlers, endpoint, options...))
100}
101
102// IsExpired returns true if the credentials retrieved are expired, or not yet
103// retrieved.
104func (p *Provider) IsExpired() bool {
105	if p.staticCreds {
106		return false
107	}
108	return p.Expiry.IsExpired()
109}
110
111// Retrieve will attempt to request the credentials from the endpoint the Provider
112// was configured for. And error will be returned if the retrieval fails.
113func (p *Provider) Retrieve() (credentials.Value, error) {
114	resp, err := p.getCredentials()
115	if err != nil {
116		return credentials.Value{ProviderName: ProviderName},
117			awserr.New("CredentialsEndpointError", "failed to load credentials", err)
118	}
119
120	if resp.Expiration != nil {
121		p.SetExpiration(*resp.Expiration, p.ExpiryWindow)
122	} else {
123		p.staticCreds = true
124	}
125
126	return credentials.Value{
127		AccessKeyID:     resp.AccessKeyID,
128		SecretAccessKey: resp.SecretAccessKey,
129		SessionToken:    resp.Token,
130		ProviderName:    ProviderName,
131	}, nil
132}
133
// getCredentialsOutput is the decoded body of a successful credentials
// endpoint response. Expiration is nil when the endpoint returned static
// credentials. The json tags spell out the wire names used in the package
// documentation's examples.
type getCredentialsOutput struct {
	Expiration      *time.Time `json:"Expiration"`
	AccessKeyID     string     `json:"AccessKeyId"`
	SecretAccessKey string     `json:"SecretAccessKey"`
	Token           string     `json:"Token"`
}
140
// errorOutput is the JSON body the endpoint is expected to send with an error
// response. unmarshalError decodes it and uses Code and Message to build the
// request's error.
type errorOutput struct {
	Code    string `json:"code"`
	Message string `json:"message"`
}
145
146func (p *Provider) getCredentials() (*getCredentialsOutput, error) {
147	op := &request.Operation{
148		Name:       "GetCredentials",
149		HTTPMethod: "GET",
150	}
151
152	out := &getCredentialsOutput{}
153	req := p.Client.NewRequest(op, nil, out)
154	req.HTTPRequest.Header.Set("Accept", "application/json")
155
156	return out, req.Send()
157}
158
159func validateEndpointHandler(r *request.Request) {
160	if len(r.ClientInfo.Endpoint) == 0 {
161		r.Error = aws.ErrMissingEndpoint
162	}
163}
164
165func unmarshalHandler(r *request.Request) {
166	defer r.HTTPResponse.Body.Close()
167
168	out := r.Data.(*getCredentialsOutput)
169	if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&out); err != nil {
170		r.Error = awserr.New("SerializationError",
171			"failed to decode endpoint credentials",
172			err,
173		)
174	}
175}
176
177func unmarshalError(r *request.Request) {
178	defer r.HTTPResponse.Body.Close()
179
180	var errOut errorOutput
181	if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&errOut); err != nil {
182		r.Error = awserr.New("SerializationError",
183			"failed to decode endpoint credentials",
184			err,
185		)
186	}
187
188	// Response body format is not consistent between metadata endpoints.
189	// Grab the error message as a string and include that as the source error
190	r.Error = awserr.New(errOut.Code, errOut.Message, nil)
191}
192