1// Package endpointcreds provides support for retrieving credentials from an
2// arbitrary HTTP endpoint.
3//
4// The credentials endpoint Provider can receive both static and refreshable
5// credentials that will expire. Credentials are static when an "Expiration"
6// value is not provided in the endpoint's response.
7//
8// Static credentials will never expire once they have been retrieved. The format
9// of the static credentials response:
10//    {
11//        "AccessKeyId" : "MUA...",
//        "SecretAccessKey" : "/7PC5om...."
13//    }
14//
15// Refreshable credentials will expire within the "ExpiryWindow" of the Expiration
16// value in the response. The format of the refreshable credentials response:
17//    {
18//        "AccessKeyId" : "MUA...",
19//        "SecretAccessKey" : "/7PC5om....",
20//        "Token" : "AQoDY....=",
21//        "Expiration" : "2016-02-25T06:03:31Z"
22//    }
23//
24// Errors should be returned in the following format and only returned with 400
25// or 500 HTTP status codes.
26//    {
27//        "code": "ErrorCode",
28//        "message": "Helpful error message."
29//    }
30package endpointcreds
31
32import (
33	"encoding/json"
34	"time"
35
36	"github.com/aws/aws-sdk-go/aws"
37	"github.com/aws/aws-sdk-go/aws/awserr"
38	"github.com/aws/aws-sdk-go/aws/client"
39	"github.com/aws/aws-sdk-go/aws/client/metadata"
40	"github.com/aws/aws-sdk-go/aws/credentials"
41	"github.com/aws/aws-sdk-go/aws/request"
42	"github.com/aws/aws-sdk-go/private/protocol/json/jsonutil"
43)
44
// ProviderName identifies this provider in the ProviderName field of every
// credentials.Value it returns, including the value returned on error.
const ProviderName = "CredentialsEndpointProvider"
47
// Provider satisfies the credentials.Provider interface, and is a client to
// retrieve credentials from an arbitrary endpoint.
type Provider struct {
	// staticCreds is set when the most recent response carried no
	// "Expiration" value; IsExpired then always reports false.
	staticCreds bool
	// Expiry tracks the expiration time of refreshable credentials and
	// backs IsExpired when the credentials are not static.
	credentials.Expiry

	// Client is the AWS client used to make HTTP requests to the endpoint.
	// The request is sent to the client's ClientInfo.Endpoint, which
	// NewProviderClient sets from its endpoint argument.
	Client *client.Client

	// ExpiryWindow will allow the credentials to trigger refreshing prior to
	// the credentials actually expiring. This is beneficial so race conditions
	// with expiring credentials do not cause requests to fail unexpectedly
	// due to ExpiredTokenException exceptions.
	//
	// So an ExpiryWindow of 10s would cause calls to IsExpired() to return
	// true 10 seconds before the credentials are actually expired.
	//
	// If ExpiryWindow is 0 or less it will be ignored.
	ExpiryWindow time.Duration

	// AuthorizationToken is optional; when non-empty it is sent as the value
	// of the Authorization header of the endpoint credential request.
	AuthorizationToken string
}
74
75// NewProviderClient returns a credentials Provider for retrieving AWS credentials
76// from arbitrary endpoint.
77func NewProviderClient(cfg aws.Config, handlers request.Handlers, endpoint string, options ...func(*Provider)) credentials.Provider {
78	p := &Provider{
79		Client: client.New(
80			cfg,
81			metadata.ClientInfo{
82				ServiceName: "CredentialsEndpoint",
83				Endpoint:    endpoint,
84			},
85			handlers,
86		),
87	}
88
89	p.Client.Handlers.Unmarshal.PushBack(unmarshalHandler)
90	p.Client.Handlers.UnmarshalError.PushBack(unmarshalError)
91	p.Client.Handlers.Validate.Clear()
92	p.Client.Handlers.Validate.PushBack(validateEndpointHandler)
93
94	for _, option := range options {
95		option(p)
96	}
97
98	return p
99}
100
101// NewCredentialsClient returns a pointer to a new Credentials object
102// wrapping the endpoint credentials Provider.
103func NewCredentialsClient(cfg aws.Config, handlers request.Handlers, endpoint string, options ...func(*Provider)) *credentials.Credentials {
104	return credentials.NewCredentials(NewProviderClient(cfg, handlers, endpoint, options...))
105}
106
107// IsExpired returns true if the credentials retrieved are expired, or not yet
108// retrieved.
109func (p *Provider) IsExpired() bool {
110	if p.staticCreds {
111		return false
112	}
113	return p.Expiry.IsExpired()
114}
115
116// Retrieve will attempt to request the credentials from the endpoint the Provider
117// was configured for. And error will be returned if the retrieval fails.
118func (p *Provider) Retrieve() (credentials.Value, error) {
119	return p.RetrieveWithContext(aws.BackgroundContext())
120}
121
122// RetrieveWithContext will attempt to request the credentials from the endpoint the Provider
123// was configured for. And error will be returned if the retrieval fails.
124func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
125	resp, err := p.getCredentials(ctx)
126	if err != nil {
127		return credentials.Value{ProviderName: ProviderName},
128			awserr.New("CredentialsEndpointError", "failed to load credentials", err)
129	}
130
131	if resp.Expiration != nil {
132		p.SetExpiration(*resp.Expiration, p.ExpiryWindow)
133	} else {
134		p.staticCreds = true
135	}
136
137	return credentials.Value{
138		AccessKeyID:     resp.AccessKeyID,
139		SecretAccessKey: resp.SecretAccessKey,
140		SessionToken:    resp.Token,
141		ProviderName:    ProviderName,
142	}, nil
143}
144
// getCredentialsOutput is the decoded body of a successful endpoint response.
// encoding/json matches JSON keys to field names case-insensitively, so a key
// such as "AccessKeyId" populates AccessKeyID.
type getCredentialsOutput struct {
	// Expiration is nil when the response has no "Expiration" value; the
	// credentials are then treated as static by RetrieveWithContext.
	Expiration      *time.Time
	AccessKeyID     string
	SecretAccessKey string
	Token           string
}
151
// errorOutput is the decoded body of an error response from the endpoint,
// e.g. {"code": "ErrorCode", "message": "Helpful error message."}.
type errorOutput struct {
	Code    string `json:"code"`
	Message string `json:"message"`
}
156
157func (p *Provider) getCredentials(ctx aws.Context) (*getCredentialsOutput, error) {
158	op := &request.Operation{
159		Name:       "GetCredentials",
160		HTTPMethod: "GET",
161	}
162
163	out := &getCredentialsOutput{}
164	req := p.Client.NewRequest(op, nil, out)
165	req.SetContext(ctx)
166	req.HTTPRequest.Header.Set("Accept", "application/json")
167	if authToken := p.AuthorizationToken; len(authToken) != 0 {
168		req.HTTPRequest.Header.Set("Authorization", authToken)
169	}
170
171	return out, req.Send()
172}
173
174func validateEndpointHandler(r *request.Request) {
175	if len(r.ClientInfo.Endpoint) == 0 {
176		r.Error = aws.ErrMissingEndpoint
177	}
178}
179
180func unmarshalHandler(r *request.Request) {
181	defer r.HTTPResponse.Body.Close()
182
183	out := r.Data.(*getCredentialsOutput)
184	if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&out); err != nil {
185		r.Error = awserr.New(request.ErrCodeSerialization,
186			"failed to decode endpoint credentials",
187			err,
188		)
189	}
190}
191
192func unmarshalError(r *request.Request) {
193	defer r.HTTPResponse.Body.Close()
194
195	var errOut errorOutput
196	err := jsonutil.UnmarshalJSONError(&errOut, r.HTTPResponse.Body)
197	if err != nil {
198		r.Error = awserr.NewRequestFailure(
199			awserr.New(request.ErrCodeSerialization,
200				"failed to decode error message", err),
201			r.HTTPResponse.StatusCode,
202			r.RequestID,
203		)
204		return
205	}
206
207	// Response body format is not consistent between metadata endpoints.
208	// Grab the error message as a string and include that as the source error
209	r.Error = awserr.New(errOut.Code, errOut.Message, nil)
210}
211