1// Package oidc implements OpenID Connect client logic for the golang.org/x/oauth2 package. 2package oidc 3 4import ( 5 "context" 6 "crypto/sha256" 7 "crypto/sha512" 8 "encoding/base64" 9 "encoding/json" 10 "errors" 11 "fmt" 12 "hash" 13 "io/ioutil" 14 "mime" 15 "net/http" 16 "strings" 17 "time" 18 19 "golang.org/x/oauth2" 20 jose "gopkg.in/square/go-jose.v2" 21) 22 23const ( 24 // ScopeOpenID is the mandatory scope for all OpenID Connect OAuth2 requests. 25 ScopeOpenID = "openid" 26 27 // ScopeOfflineAccess is an optional scope defined by OpenID Connect for requesting 28 // OAuth2 refresh tokens. 29 // 30 // Support for this scope differs between OpenID Connect providers. For instance 31 // Google rejects it, favoring appending "access_type=offline" as part of the 32 // authorization request instead. 33 // 34 // See: https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess 35 ScopeOfflineAccess = "offline_access" 36) 37 38var ( 39 errNoAtHash = errors.New("id token did not have an access token hash") 40 errInvalidAtHash = errors.New("access token hash does not match value in ID token") 41) 42 43// ClientContext returns a new Context that carries the provided HTTP client. 44// 45// This method sets the same context key used by the golang.org/x/oauth2 package, 46// so the returned context works for that package too. 47// 48// myClient := &http.Client{} 49// ctx := oidc.ClientContext(parentContext, myClient) 50// 51// // This will use the custom client 52// provider, err := oidc.NewProvider(ctx, "https://accounts.example.com") 53// 54func ClientContext(ctx context.Context, client *http.Client) context.Context { 55 return context.WithValue(ctx, oauth2.HTTPClient, client) 56} 57 58func doRequest(ctx context.Context, req *http.Request) (*http.Response, error) { 59 client := http.DefaultClient 60 if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok { 61 client = c 62 } 63 return client.Do(req.WithContext(ctx)) 64} 65 66// Provider represents an OpenID Connect server's configuration. 67type Provider struct { 68 issuer string 69 authURL string 70 tokenURL string 71 userInfoURL string 72 73 // Raw claims returned by the server. 74 rawClaims []byte 75 76 remoteKeySet KeySet 77} 78 79type cachedKeys struct { 80 keys []jose.JSONWebKey 81 expiry time.Time 82} 83 84type providerJSON struct { 85 Issuer string `json:"issuer"` 86 AuthURL string `json:"authorization_endpoint"` 87 TokenURL string `json:"token_endpoint"` 88 JWKSURL string `json:"jwks_uri"` 89 UserInfoURL string `json:"userinfo_endpoint"` 90} 91 92// NewProvider uses the OpenID Connect discovery mechanism to construct a Provider. 93// 94// The issuer is the URL identifier for the service. For example: "https://accounts.google.com" 95// or "https://login.salesforce.com". 96func NewProvider(ctx context.Context, issuer string) (*Provider, error) { 97 wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration" 98 req, err := http.NewRequest("GET", wellKnown, nil) 99 if err != nil { 100 return nil, err 101 } 102 resp, err := doRequest(ctx, req) 103 if err != nil { 104 return nil, err 105 } 106 defer resp.Body.Close() 107 108 body, err := ioutil.ReadAll(resp.Body) 109 if err != nil { 110 return nil, fmt.Errorf("unable to read response body: %v", err) 111 } 112 113 if resp.StatusCode != http.StatusOK { 114 return nil, fmt.Errorf("%s: %s", resp.Status, body) 115 } 116 117 var p providerJSON 118 err = unmarshalResp(resp, body, &p) 119 if err != nil { 120 return nil, fmt.Errorf("oidc: failed to decode provider discovery object: %v", err) 121 } 122 123 if p.Issuer != issuer { 124 return nil, fmt.Errorf("oidc: issuer did not match the issuer returned by provider, expected %q got %q", issuer, p.Issuer) 125 } 126 return &Provider{ 127 issuer: p.Issuer, 128 authURL: p.AuthURL, 129 tokenURL: p.TokenURL, 130 userInfoURL: p.UserInfoURL, 131 rawClaims: body, 132 remoteKeySet: NewRemoteKeySet(ctx, p.JWKSURL), 133 }, nil 134} 135 136// Claims unmarshals raw fields returned by the server during discovery. 137// 138// var claims struct { 139// ScopesSupported []string `json:"scopes_supported"` 140// ClaimsSupported []string `json:"claims_supported"` 141// } 142// 143// if err := provider.Claims(&claims); err != nil { 144// // handle unmarshaling error 145// } 146// 147// For a list of fields defined by the OpenID Connect spec see: 148// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata 149func (p *Provider) Claims(v interface{}) error { 150 if p.rawClaims == nil { 151 return errors.New("oidc: claims not set") 152 } 153 return json.Unmarshal(p.rawClaims, v) 154} 155 156// Endpoint returns the OAuth2 auth and token endpoints for the given provider. 157func (p *Provider) Endpoint() oauth2.Endpoint { 158 return oauth2.Endpoint{AuthURL: p.authURL, TokenURL: p.tokenURL} 159} 160 161// UserInfo represents the OpenID Connect userinfo claims. 162type UserInfo struct { 163 Subject string `json:"sub"` 164 Profile string `json:"profile"` 165 Email string `json:"email"` 166 EmailVerified bool `json:"email_verified"` 167 168 claims []byte 169} 170 171// Claims unmarshals the raw JSON object claims into the provided object. 172func (u *UserInfo) Claims(v interface{}) error { 173 if u.claims == nil { 174 return errors.New("oidc: claims not set") 175 } 176 return json.Unmarshal(u.claims, v) 177} 178 179// UserInfo uses the token source to query the provider's user info endpoint. 180func (p *Provider) UserInfo(ctx context.Context, tokenSource oauth2.TokenSource) (*UserInfo, error) { 181 if p.userInfoURL == "" { 182 return nil, errors.New("oidc: user info endpoint is not supported by this provider") 183 } 184 185 req, err := http.NewRequest("GET", p.userInfoURL, nil) 186 if err != nil { 187 return nil, fmt.Errorf("oidc: create GET request: %v", err) 188 } 189 190 token, err := tokenSource.Token() 191 if err != nil { 192 return nil, fmt.Errorf("oidc: get access token: %v", err) 193 } 194 token.SetAuthHeader(req) 195 196 resp, err := doRequest(ctx, req) 197 if err != nil { 198 return nil, err 199 } 200 defer resp.Body.Close() 201 body, err := ioutil.ReadAll(resp.Body) 202 if err != nil { 203 return nil, err 204 } 205 if resp.StatusCode != http.StatusOK { 206 return nil, fmt.Errorf("%s: %s", resp.Status, body) 207 } 208 209 var userInfo UserInfo 210 if err := json.Unmarshal(body, &userInfo); err != nil { 211 return nil, fmt.Errorf("oidc: failed to decode userinfo: %v", err) 212 } 213 userInfo.claims = body 214 return &userInfo, nil 215} 216 217// IDToken is an OpenID Connect extension that provides a predictable representation 218// of an authorization event. 219// 220// The ID Token only holds fields OpenID Connect requires. To access additional 221// claims returned by the server, use the Claims method. 222type IDToken struct { 223 // The URL of the server which issued this token. OpenID Connect 224 // requires this value always be identical to the URL used for 225 // initial discovery. 226 // 227 // Note: Because of a known issue with Google Accounts' implementation 228 // this value may differ when using Google. 229 // 230 // See: https://developers.google.com/identity/protocols/OpenIDConnect#obtainuserinfo 231 Issuer string 232 233 // The client ID, or set of client IDs, that this token is issued for. For 234 // common uses, this is the client that initialized the auth flow. 235 // 236 // This package ensures the audience contains an expected value. 237 Audience []string 238 239 // A unique string which identifies the end user. 240 Subject string 241 242 // Expiry of the token. Ths package will not process tokens that have 243 // expired unless that validation is explicitly turned off. 244 Expiry time.Time 245 // When the token was issued by the provider. 246 IssuedAt time.Time 247 248 // Initial nonce provided during the authentication redirect. 249 // 250 // This package does NOT provided verification on the value of this field 251 // and it's the user's responsibility to ensure it contains a valid value. 252 Nonce string 253 254 // at_hash claim, if set in the ID token. Callers can verify an access token 255 // that corresponds to the ID token using the VerifyAccessToken method. 256 AccessTokenHash string 257 258 // signature algorithm used for ID token, needed to compute a verification hash of an 259 // access token 260 sigAlgorithm string 261 262 // Raw payload of the id_token. 263 claims []byte 264 265 // Map of distributed claim names to claim sources 266 distributedClaims map[string]claimSource 267} 268 269// Claims unmarshals the raw JSON payload of the ID Token into a provided struct. 270// 271// idToken, err := idTokenVerifier.Verify(rawIDToken) 272// if err != nil { 273// // handle error 274// } 275// var claims struct { 276// Email string `json:"email"` 277// EmailVerified bool `json:"email_verified"` 278// } 279// if err := idToken.Claims(&claims); err != nil { 280// // handle error 281// } 282// 283func (i *IDToken) Claims(v interface{}) error { 284 if i.claims == nil { 285 return errors.New("oidc: claims not set") 286 } 287 return json.Unmarshal(i.claims, v) 288} 289 290// VerifyAccessToken verifies that the hash of the access token that corresponds to the iD token 291// matches the hash in the id token. It returns an error if the hashes don't match. 292// It is the caller's responsibility to ensure that the optional access token hash is present for the ID token 293// before calling this method. See https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken 294func (i *IDToken) VerifyAccessToken(accessToken string) error { 295 if i.AccessTokenHash == "" { 296 return errNoAtHash 297 } 298 var h hash.Hash 299 switch i.sigAlgorithm { 300 case RS256, ES256, PS256: 301 h = sha256.New() 302 case RS384, ES384, PS384: 303 h = sha512.New384() 304 case RS512, ES512, PS512: 305 h = sha512.New() 306 default: 307 return fmt.Errorf("oidc: unsupported signing algorithm %q", i.sigAlgorithm) 308 } 309 h.Write([]byte(accessToken)) // hash documents that Write will never return an error 310 sum := h.Sum(nil)[:h.Size()/2] 311 actual := base64.RawURLEncoding.EncodeToString(sum) 312 if actual != i.AccessTokenHash { 313 return errInvalidAtHash 314 } 315 return nil 316} 317 318type idToken struct { 319 Issuer string `json:"iss"` 320 Subject string `json:"sub"` 321 Audience audience `json:"aud"` 322 Expiry jsonTime `json:"exp"` 323 IssuedAt jsonTime `json:"iat"` 324 NotBefore *jsonTime `json:"nbf"` 325 Nonce string `json:"nonce"` 326 AtHash string `json:"at_hash"` 327 ClaimNames map[string]string `json:"_claim_names"` 328 ClaimSources map[string]claimSource `json:"_claim_sources"` 329} 330 331type claimSource struct { 332 Endpoint string `json:"endpoint"` 333 AccessToken string `json:"access_token"` 334} 335 336type audience []string 337 338func (a *audience) UnmarshalJSON(b []byte) error { 339 var s string 340 if json.Unmarshal(b, &s) == nil { 341 *a = audience{s} 342 return nil 343 } 344 var auds []string 345 if err := json.Unmarshal(b, &auds); err != nil { 346 return err 347 } 348 *a = audience(auds) 349 return nil 350} 351 352type jsonTime time.Time 353 354func (j *jsonTime) UnmarshalJSON(b []byte) error { 355 var n json.Number 356 if err := json.Unmarshal(b, &n); err != nil { 357 return err 358 } 359 var unix int64 360 361 if t, err := n.Int64(); err == nil { 362 unix = t 363 } else { 364 f, err := n.Float64() 365 if err != nil { 366 return err 367 } 368 unix = int64(f) 369 } 370 *j = jsonTime(time.Unix(unix, 0)) 371 return nil 372} 373 374func unmarshalResp(r *http.Response, body []byte, v interface{}) error { 375 err := json.Unmarshal(body, &v) 376 if err == nil { 377 return nil 378 } 379 ct := r.Header.Get("Content-Type") 380 mediaType, _, parseErr := mime.ParseMediaType(ct) 381 if parseErr == nil && mediaType == "application/json" { 382 return fmt.Errorf("got Content-Type = application/json, but could not unmarshal as JSON: %v", err) 383 } 384 return fmt.Errorf("expected Content-Type = application/json, got %q: %v", ct, err) 385} 386