1package oidc
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"io/ioutil"
8	"net/http"
9	"sync"
10	"time"
11
12	"github.com/pquerna/cachecontrol"
13	jose "gopkg.in/square/go-jose.v2"
14)
15
// keysExpiryDelta is the allowed clock skew between a client and the OpenID Connect
// server.
//
// Cached keys are always tried first, even after they have expired, so that
// verification keeps working while the provider is unavailable. If an ID Token
// cannot be verified with any cached key, the keys are refetched from the
// server only when they expire within this amount of time (or have already
// expired). Otherwise verification fails without contacting the server.
const keysExpiryDelta = 30 * time.Second
25
26// NewRemoteKeySet returns a KeySet that can validate JSON web tokens by using HTTP
27// GETs to fetch JSON web token sets hosted at a remote URL. This is automatically
28// used by NewProvider using the URLs returned by OpenID Connect discovery, but is
29// exposed for providers that don't support discovery or to prevent round trips to the
30// discovery URL.
31//
32// The returned KeySet is a long lived verifier that caches keys based on cache-control
33// headers. Reuse a common remote key set instead of creating new ones as needed.
34//
35// The behavior of the returned KeySet is undefined once the context is canceled.
36func NewRemoteKeySet(ctx context.Context, jwksURL string) KeySet {
37	return newRemoteKeySet(ctx, jwksURL, time.Now)
38}
39
40func newRemoteKeySet(ctx context.Context, jwksURL string, now func() time.Time) *remoteKeySet {
41	if now == nil {
42		now = time.Now
43	}
44	return &remoteKeySet{jwksURL: jwksURL, ctx: ctx, now: now}
45}
46
47type remoteKeySet struct {
48	jwksURL string
49	ctx     context.Context
50	now     func() time.Time
51
52	// guard all other fields
53	mu sync.Mutex
54
55	// inflight suppresses parallel execution of updateKeys and allows
56	// multiple goroutines to wait for its result.
57	inflight *inflight
58
59	// A set of cached keys and their expiry.
60	cachedKeys []jose.JSONWebKey
61	expiry     time.Time
62}
63
64// inflight is used to wait on some in-flight request from multiple goroutines.
65type inflight struct {
66	doneCh chan struct{}
67
68	keys []jose.JSONWebKey
69	err  error
70}
71
72func newInflight() *inflight {
73	return &inflight{doneCh: make(chan struct{})}
74}
75
76// wait returns a channel that multiple goroutines can receive on. Once it returns
77// a value, the inflight request is done and result() can be inspected.
78func (i *inflight) wait() <-chan struct{} {
79	return i.doneCh
80}
81
82// done can only be called by a single goroutine. It records the result of the
83// inflight request and signals other goroutines that the result is safe to
84// inspect.
85func (i *inflight) done(keys []jose.JSONWebKey, err error) {
86	i.keys = keys
87	i.err = err
88	close(i.doneCh)
89}
90
91// result cannot be called until the wait() channel has returned a value.
92func (i *inflight) result() ([]jose.JSONWebKey, error) {
93	return i.keys, i.err
94}
95
96func (r *remoteKeySet) VerifySignature(ctx context.Context, jwt string) ([]byte, error) {
97	jws, err := jose.ParseSigned(jwt)
98	if err != nil {
99		return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
100	}
101	return r.verify(ctx, jws)
102}
103
104func (r *remoteKeySet) verify(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
105	// We don't support JWTs signed with multiple signatures.
106	keyID := ""
107	for _, sig := range jws.Signatures {
108		keyID = sig.Header.KeyID
109		break
110	}
111
112	keys, expiry := r.keysFromCache()
113
114	// Don't check expiry yet. This optimizes for when the provider is unavailable.
115	for _, key := range keys {
116		if keyID == "" || key.KeyID == keyID {
117			if payload, err := jws.Verify(&key); err == nil {
118				return payload, nil
119			}
120		}
121	}
122
123	if !r.now().Add(keysExpiryDelta).After(expiry) {
124		// Keys haven't expired, don't refresh.
125		return nil, errors.New("failed to verify id token signature")
126	}
127
128	keys, err := r.keysFromRemote(ctx)
129	if err != nil {
130		return nil, fmt.Errorf("fetching keys %v", err)
131	}
132
133	for _, key := range keys {
134		if keyID == "" || key.KeyID == keyID {
135			if payload, err := jws.Verify(&key); err == nil {
136				return payload, nil
137			}
138		}
139	}
140	return nil, errors.New("failed to verify id token signature")
141}
142
143func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey, expiry time.Time) {
144	r.mu.Lock()
145	defer r.mu.Unlock()
146	return r.cachedKeys, r.expiry
147}
148
149// keysFromRemote syncs the key set from the remote set, records the values in the
150// cache, and returns the key set.
151func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) {
152	// Need to lock to inspect the inflight request field.
153	r.mu.Lock()
154	// If there's not a current inflight request, create one.
155	if r.inflight == nil {
156		r.inflight = newInflight()
157
158		// This goroutine has exclusive ownership over the current inflight
159		// request. It releases the resource by nil'ing the inflight field
160		// once the goroutine is done.
161		go func() {
162			// Sync keys and finish inflight when that's done.
163			keys, expiry, err := r.updateKeys()
164
165			r.inflight.done(keys, err)
166
167			// Lock to update the keys and indicate that there is no longer an
168			// inflight request.
169			r.mu.Lock()
170			defer r.mu.Unlock()
171
172			if err == nil {
173				r.cachedKeys = keys
174				r.expiry = expiry
175			}
176
177			// Free inflight so a different request can run.
178			r.inflight = nil
179		}()
180	}
181	inflight := r.inflight
182	r.mu.Unlock()
183
184	select {
185	case <-ctx.Done():
186		return nil, ctx.Err()
187	case <-inflight.wait():
188		return inflight.result()
189	}
190}
191
192func (r *remoteKeySet) updateKeys() ([]jose.JSONWebKey, time.Time, error) {
193	req, err := http.NewRequest("GET", r.jwksURL, nil)
194	if err != nil {
195		return nil, time.Time{}, fmt.Errorf("oidc: can't create request: %v", err)
196	}
197
198	resp, err := doRequest(r.ctx, req)
199	if err != nil {
200		return nil, time.Time{}, fmt.Errorf("oidc: get keys failed %v", err)
201	}
202	defer resp.Body.Close()
203
204	body, err := ioutil.ReadAll(resp.Body)
205	if err != nil {
206		return nil, time.Time{}, fmt.Errorf("unable to read response body: %v", err)
207	}
208
209	if resp.StatusCode != http.StatusOK {
210		return nil, time.Time{}, fmt.Errorf("oidc: get keys failed: %s %s", resp.Status, body)
211	}
212
213	var keySet jose.JSONWebKeySet
214	err = unmarshalResp(resp, body, &keySet)
215	if err != nil {
216		return nil, time.Time{}, fmt.Errorf("oidc: failed to decode keys: %v %s", err, body)
217	}
218
219	// If the server doesn't provide cache control headers, assume the
220	// keys expire immediately.
221	expiry := r.now()
222
223	_, e, err := cachecontrol.CachableResponse(req, resp, cachecontrol.Options{})
224	if err == nil && e.After(expiry) {
225		expiry = e
226	}
227	return keySet.Keys, expiry, nil
228}
229