1package oidc 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "io/ioutil" 8 "net/http" 9 "sync" 10 "time" 11 12 "github.com/pquerna/cachecontrol" 13 jose "gopkg.in/square/go-jose.v2" 14) 15 16// keysExpiryDelta is the allowed clock skew between a client and the OpenID Connect 17// server. 18// 19// When keys expire, they are valid for this amount of time after. 20// 21// If the keys have not expired, and an ID Token claims it was signed by a key not in 22// the cache, if and only if the keys expire in this amount of time, the keys will be 23// updated. 24const keysExpiryDelta = 30 * time.Second 25 26// NewRemoteKeySet returns a KeySet that can validate JSON web tokens by using HTTP 27// GETs to fetch JSON web token sets hosted at a remote URL. This is automatically 28// used by NewProvider using the URLs returned by OpenID Connect discovery, but is 29// exposed for providers that don't support discovery or to prevent round trips to the 30// discovery URL. 31// 32// The returned KeySet is a long lived verifier that caches keys based on cache-control 33// headers. Reuse a common remote key set instead of creating new ones as needed. 34// 35// The behavior of the returned KeySet is undefined once the context is canceled. 36func NewRemoteKeySet(ctx context.Context, jwksURL string) KeySet { 37 return newRemoteKeySet(ctx, jwksURL, time.Now) 38} 39 40func newRemoteKeySet(ctx context.Context, jwksURL string, now func() time.Time) *remoteKeySet { 41 if now == nil { 42 now = time.Now 43 } 44 return &remoteKeySet{jwksURL: jwksURL, ctx: ctx, now: now} 45} 46 47type remoteKeySet struct { 48 jwksURL string 49 ctx context.Context 50 now func() time.Time 51 52 // guard all other fields 53 mu sync.Mutex 54 55 // inflight suppresses parallel execution of updateKeys and allows 56 // multiple goroutines to wait for its result. 57 inflight *inflight 58 59 // A set of cached keys and their expiry. 60 cachedKeys []jose.JSONWebKey 61 expiry time.Time 62} 63 64// inflight is used to wait on some in-flight request from multiple goroutines. 65type inflight struct { 66 doneCh chan struct{} 67 68 keys []jose.JSONWebKey 69 err error 70} 71 72func newInflight() *inflight { 73 return &inflight{doneCh: make(chan struct{})} 74} 75 76// wait returns a channel that multiple goroutines can receive on. Once it returns 77// a value, the inflight request is done and result() can be inspected. 78func (i *inflight) wait() <-chan struct{} { 79 return i.doneCh 80} 81 82// done can only be called by a single goroutine. It records the result of the 83// inflight request and signals other goroutines that the result is safe to 84// inspect. 85func (i *inflight) done(keys []jose.JSONWebKey, err error) { 86 i.keys = keys 87 i.err = err 88 close(i.doneCh) 89} 90 91// result cannot be called until the wait() channel has returned a value. 92func (i *inflight) result() ([]jose.JSONWebKey, error) { 93 return i.keys, i.err 94} 95 96func (r *remoteKeySet) VerifySignature(ctx context.Context, jwt string) ([]byte, error) { 97 jws, err := jose.ParseSigned(jwt) 98 if err != nil { 99 return nil, fmt.Errorf("oidc: malformed jwt: %v", err) 100 } 101 return r.verify(ctx, jws) 102} 103 104func (r *remoteKeySet) verify(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { 105 // We don't support JWTs signed with multiple signatures. 106 keyID := "" 107 for _, sig := range jws.Signatures { 108 keyID = sig.Header.KeyID 109 break 110 } 111 112 keys, expiry := r.keysFromCache() 113 114 // Don't check expiry yet. This optimizes for when the provider is unavailable. 115 for _, key := range keys { 116 if keyID == "" || key.KeyID == keyID { 117 if payload, err := jws.Verify(&key); err == nil { 118 return payload, nil 119 } 120 } 121 } 122 123 if !r.now().Add(keysExpiryDelta).After(expiry) { 124 // Keys haven't expired, don't refresh. 125 return nil, errors.New("failed to verify id token signature") 126 } 127 128 keys, err := r.keysFromRemote(ctx) 129 if err != nil { 130 return nil, fmt.Errorf("fetching keys %v", err) 131 } 132 133 for _, key := range keys { 134 if keyID == "" || key.KeyID == keyID { 135 if payload, err := jws.Verify(&key); err == nil { 136 return payload, nil 137 } 138 } 139 } 140 return nil, errors.New("failed to verify id token signature") 141} 142 143func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey, expiry time.Time) { 144 r.mu.Lock() 145 defer r.mu.Unlock() 146 return r.cachedKeys, r.expiry 147} 148 149// keysFromRemote syncs the key set from the remote set, records the values in the 150// cache, and returns the key set. 151func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) { 152 // Need to lock to inspect the inflight request field. 153 r.mu.Lock() 154 // If there's not a current inflight request, create one. 155 if r.inflight == nil { 156 r.inflight = newInflight() 157 158 // This goroutine has exclusive ownership over the current inflight 159 // request. It releases the resource by nil'ing the inflight field 160 // once the goroutine is done. 161 go func() { 162 // Sync keys and finish inflight when that's done. 163 keys, expiry, err := r.updateKeys() 164 165 r.inflight.done(keys, err) 166 167 // Lock to update the keys and indicate that there is no longer an 168 // inflight request. 169 r.mu.Lock() 170 defer r.mu.Unlock() 171 172 if err == nil { 173 r.cachedKeys = keys 174 r.expiry = expiry 175 } 176 177 // Free inflight so a different request can run. 178 r.inflight = nil 179 }() 180 } 181 inflight := r.inflight 182 r.mu.Unlock() 183 184 select { 185 case <-ctx.Done(): 186 return nil, ctx.Err() 187 case <-inflight.wait(): 188 return inflight.result() 189 } 190} 191 192func (r *remoteKeySet) updateKeys() ([]jose.JSONWebKey, time.Time, error) { 193 req, err := http.NewRequest("GET", r.jwksURL, nil) 194 if err != nil { 195 return nil, time.Time{}, fmt.Errorf("oidc: can't create request: %v", err) 196 } 197 198 resp, err := doRequest(r.ctx, req) 199 if err != nil { 200 return nil, time.Time{}, fmt.Errorf("oidc: get keys failed %v", err) 201 } 202 defer resp.Body.Close() 203 204 body, err := ioutil.ReadAll(resp.Body) 205 if err != nil { 206 return nil, time.Time{}, fmt.Errorf("unable to read response body: %v", err) 207 } 208 209 if resp.StatusCode != http.StatusOK { 210 return nil, time.Time{}, fmt.Errorf("oidc: get keys failed: %s %s", resp.Status, body) 211 } 212 213 var keySet jose.JSONWebKeySet 214 err = unmarshalResp(resp, body, &keySet) 215 if err != nil { 216 return nil, time.Time{}, fmt.Errorf("oidc: failed to decode keys: %v %s", err, body) 217 } 218 219 // If the server doesn't provide cache control headers, assume the 220 // keys expire immediately. 221 expiry := r.now() 222 223 _, e, err := cachecontrol.CachableResponse(req, resp, cachecontrol.Options{}) 224 if err == nil && e.After(expiry) { 225 expiry = e 226 } 227 return keySet.Keys, expiry, nil 228} 229