1import json
2import base64
3import time
4import random
5import string
6import warnings
7import hashlib
8
9from . import oauth2
10
def decode_part(raw, encoding="utf-8"):
    """Decode one dot-separated segment of a JWT.

    Every JWT segment is base64url-encoded with its trailing "=" padding
    removed, as described in the
    `JWS specs <https://tools.ietf.org/html/rfc7515#appendix-C>`_.
    The padding is restored here before decoding.

    :param raw: One segment of a JWT, as a text string.
    :param encoding:
        The text encoding applied to the decoded bytes.
        The default "utf-8" suits the first 2 parts of a JWT,
        i.e. the header and the payload.
        The last part, i.e. the signature, is binary data,
        so pass `None` here to get the raw bytes back instead.
    :return: The decoded text, or the decoded bytes when ``encoding`` is falsy.
    """
    missing_padding = -len(raw) % 4  # https://stackoverflow.com/a/32517907/728675
    # On Python 2.7, the argument of urlsafe_b64decode must be str, not unicode.
    # The str() call is a no-op on Python 3.
    padded = str(raw + '=' * missing_padding)
    payload = base64.urlsafe_b64decode(padded)
    return payload.decode(encoding) if encoding else payload

# Obsolete name, kept only so that older callers keep working.
base64decode = decode_part
34
def decode_id_token(id_token, client_id=None, issuer=None, nonce=None, now=None):
    """Decodes and validates an id_token and returns its claims as a dictionary.

    ID token claims would at least contain: "iss", "sub", "aud", "exp", "iat",
    per `specs <https://openid.net/specs/openid-connect-core-1_0.html#IDToken>`_
    and it may contain other optional content such as "preferred_username",
    `maybe more <https://openid.net/specs/openid-connect-core-1_0.html#Claims>`_

    Only the claims in the payload segment are checked here.
    The token's signature is not verified by this function.

    :param id_token: The raw ID token, i.e. "header.payload.signature".
    :param client_id:
        If provided, it must equal the "aud" claim,
        or be one of its members when "aud" is a list.
    :param issuer: If provided, it must exactly match the "iss" claim.
    :param nonce: If provided, it must exactly match the "nonce" claim.
    :param now:
        The current time, in seconds since the epoch.
        Defaults to ``time.time()``. Mainly useful for testing.
    :return: The decoded claims, as a dict.
    :raises RuntimeError:
        If any validation fails. The message describes the failed rule
        (the last one checked, if several fail), followed by the decoded claims.
    """
    decoded = json.loads(decode_part(id_token.split('.')[1]))
    err = None  # https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
    # Explicit None check, so that a caller-supplied now=0 is honored
    _now = time.time() if now is None else now
    # Tolerated clock difference between us and the issuer. It is applied
    # to both "nbf" and "exp", as RFC 7519 sections 4.1.4 and 4.1.5 allow.
    skew = 120  # 2 minutes
    if _now + skew < decoded.get("nbf", _now - 1):  # nbf is optional per JWT specs
        # This is not an ID token validation, but a JWT validation
        # https://tools.ietf.org/html/rfc7519#section-4.1.5
        err = "0. The ID token is not yet valid."
    if issuer and issuer != decoded["iss"]:
        # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationResponse
        err = ('2. The Issuer Identifier for the OpenID Provider, "%s", '
            "(which is typically obtained during Discovery), "
            "MUST exactly match the value of the iss (issuer) Claim.") % issuer
    if client_id:
        aud = decoded["aud"]
        # "aud" may be either a single string or a list of strings
        valid_aud = client_id in aud if isinstance(aud, list) else client_id == aud
        if not valid_aud:
            err = (
                "3. The aud (audience) claim must contain this client's client_id "
                '"%s", case-sensitively. Was your client_id in wrong casing?'
                # Some IdP accepts wrong casing request but issues right casing IDT
                ) % client_id
    # Per specs:
    # 6. If the ID Token is received via direct communication between
    # the Client and the Token Endpoint (which it is during _obtain_token()),
    # the TLS server validation MAY be used to validate the issuer
    # in place of checking the token signature.
    if _now - skew > decoded["exp"]:  # Same leeway as the nbf check above
        err = "9. The current time MUST be before the time represented by the exp Claim."
    if nonce and nonce != decoded.get("nonce"):
        err = ("11. Nonce must be the same value "
            "as the one that was sent in the Authentication Request.")
    if err:
        raise RuntimeError("%s The id_token was: %s" % (
            err, json.dumps(decoded, indent=2)))
    return decoded
79
80
def _nonce_hash(nonce):
    """Return the hexadecimal SHA-256 digest of ``nonce``.

    ``nonce`` must be ASCII-only text, otherwise encoding it raises.
    Callers in this module send this digest in place of the raw nonce,
    and later compare it with the id token's "nonce" claim.
    See https://openid.net/specs/openid-connect-core-1_0.html#NonceNotes
    """
    hasher = hashlib.sha256()
    hasher.update(nonce.encode("ascii"))
    return hasher.hexdigest()
84
85
class Client(oauth2.Client):
    """OpenID Connect is a layer on top of the OAuth2.

    See its specs at https://openid.net/connect/
    """

    def decode_id_token(self, id_token, nonce=None):
        """Decode and validate an ID token against this client's settings.

        It uses this client's ``client_id`` as the expected audience,
        and the "issuer" entry of ``self.configuration`` (if any)
        as the expected issuer.
        See :func:`~decode_id_token` for details and the exception raised.
        """
        return decode_id_token(
            id_token, nonce=nonce,
            client_id=self.client_id, issuer=self.configuration.get("issuer"))

    def _obtain_token(self, grant_type, *args, **kwargs):
        """Obtain a token via the parent class, then decode its ID token.

        When the token response contains an "id_token",
        the result will also contain one more key "id_token_claims",
        whose value will be a dictionary returned by :func:`~decode_id_token`.
        """
        ret = super(Client, self)._obtain_token(grant_type, *args, **kwargs)
        if "id_token" in ret:
            # The nonce is not checked here. The authorization-code methods
            # of this class compare it after this method returns.
            ret["id_token_claims"] = self.decode_id_token(ret["id_token"])
        return ret

    def build_auth_request_uri(self, response_type, nonce=None, **kwargs):
        """Generate an authorization uri to be visited by resource owner.

        Deprecated. Use :func:`initiate_auth_code_flow` instead.

        Return value and all other parameters are the same as
        :func:`oauth2.Client.build_auth_request_uri`, plus new parameter(s):

        :param nonce:
            A hard-to-guess string used to mitigate replay attacks. See also
            `OIDC specs <https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest>`_.
        """
        warnings.warn("Use initiate_auth_code_flow() instead", DeprecationWarning)
        return super(Client, self).build_auth_request_uri(
            response_type, nonce=nonce, **kwargs)

    def obtain_token_by_authorization_code(self, code, nonce=None, **kwargs):
        """Get a token via authorization code. a.k.a. Authorization Code Grant.

        Deprecated. Use :func:`obtain_token_by_auth_code_flow` instead.

        Return value and all other parameters are the same as
        :func:`oauth2.Client.obtain_token_by_authorization_code`,
        plus new parameter(s):

        :param nonce:
            If you provided a nonce when calling :func:`build_auth_request_uri`,
            same nonce should also be provided here, so that we'll validate it.
            An exception will be raised if the nonce in id token mismatches.
        :raises ValueError: If the nonce in the ID token does not match.
        """
        warnings.warn(
            "Use obtain_token_by_auth_code_flow() instead", DeprecationWarning)
        result = super(Client, self).obtain_token_by_authorization_code(
            code, **kwargs)
        nonce_in_id_token = result.get("id_token_claims", {}).get("nonce")
        if "id_token_claims" in result and nonce and nonce != nonce_in_id_token:
            raise ValueError(
                'The nonce in id token ("%s") should match your nonce ("%s")' %
                (nonce_in_id_token, nonce))
        return result

    def initiate_auth_code_flow(
            self,
            scope=None,
            **kwargs):
        """Initiate an auth code flow.

        It provides nonce protection automatically. A random nonce is generated
        and stored in the returned flow under the "nonce" key. Only its SHA-256
        hash is passed on to the parent class's ``initiate_auth_code_flow``.

        :param list scope:
            A list of strings, e.g. ["profile", "email", ...].
            This method will automatically send ["openid"] to the wire,
            although it won't modify your input list.

        See :func:`oauth2.Client.initiate_auth_code_flow` in parent class
        for descriptions on other parameters and return value.

        :raises ValueError: If ``response_type`` asks for an "id_token".
        """
        # Tolerate an explicit response_type=None, which `in` would reject
        response_type = kwargs.get("response_type") or ""
        if "id_token" in response_type:
            # Implicit grant would cause auth response coming back in #fragment,
            # but fragment won't reach a web service.
            raise ValueError('response_type="id_token ..." is not allowed')
        _scope = list(scope) if scope else []  # We won't modify input parameter
        if "openid" not in _scope:
            # "If no openid scope value is present,
            # the request may still be a valid OAuth 2.0 request,
            # but is not an OpenID Connect request." -- OIDC Core Specs, 3.1.2.2
            # https://openid.net/specs/openid-connect-core-1_0.html#AuthRequestValidation
            # Here we just automatically add it. If the caller does not want an
            # id_token, they should simply go with oauth2.Client.
            _scope.append("openid")
        # A nonce must be unpredictable, so use the OS-provided CSPRNG instead
        # of the default (non-cryptographic) Mersenne Twister generator.
        # Drawing with replacement also keeps the full 52**16 search space.
        rng = random.SystemRandom()
        nonce = "".join(rng.choice(string.ascii_letters) for _ in range(16))
        flow = super(Client, self).initiate_auth_code_flow(
            scope=_scope, nonce=_nonce_hash(nonce), **kwargs)
        flow["nonce"] = nonce
        return flow

    def obtain_token_by_auth_code_flow(self, auth_code_flow, auth_response, **kwargs):
        """Validate the auth_response being redirected back, and then obtain tokens,
        including ID token which can be used for user sign in.

        Internally, it implements nonce to mitigate replay attack.
        It also implements PKCE to mitigate the auth code interception attack.

        See :func:`oauth2.Client.obtain_token_by_auth_code_flow` in parent class
        for descriptions on other parameters and return value.

        :raises RuntimeError:
            If the ID token's nonce does not match the hash of the nonce
            stored in ``auth_code_flow`` by :func:`initiate_auth_code_flow`.
        """
        result = super(Client, self).obtain_token_by_auth_code_flow(
            auth_code_flow, auth_response, **kwargs)
        if "id_token_claims" in result:
            nonce_in_id_token = result["id_token_claims"].get("nonce")
            # initiate_auth_code_flow() sent the hash, not the raw nonce
            expected_hash = _nonce_hash(auth_code_flow["nonce"])
            if nonce_in_id_token != expected_hash:
                raise RuntimeError(
                    'The nonce in id token ("%s") should match our nonce ("%s")' %
                    (nonce_in_id_token, expected_hash))
        return result

    def obtain_token_by_browser(
            self,
            display=None,
            prompt=None,
            max_age=None,
            ui_locales=None,
            id_token_hint=None,  # Relevant, because this library exposes raw ID token
            login_hint=None,
            acr_values=None,
            **kwargs):
        """A native app can use this method to obtain token via a local browser.

        Internally, it implements nonce to mitigate replay attack.
        It also implements PKCE to mitigate the auth code interception attack.

        :param string display: Defined in
            `OIDC <https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest>`_.
        :param string prompt: Defined in
            `OIDC <https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest>`_.
        :param int max_age: Defined in
            `OIDC <https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest>`_.
        :param string ui_locales: Defined in
            `OIDC <https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest>`_.
        :param string id_token_hint: Defined in
            `OIDC <https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest>`_.
        :param string login_hint: Defined in
            `OIDC <https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest>`_.
        :param string acr_values: Defined in
            `OIDC <https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest>`_.

        Any of the parameters above that is left as ``None`` is not sent.
        The ones that are given are merged into ``auth_params`` (if any),
        taking precedence over same-named entries there.

        See :func:`oauth2.Client.obtain_token_by_browser` in parent class
        for descriptions on other parameters and return value.
        """
        # NOTE(review): nonce/PKCE handling is not done in this method itself;
        # presumably the parent implementation drives this class's
        # auth-code-flow methods. Confirm against oauth2.Client.
        filtered_params = {k: v for k, v in dict(
            prompt=prompt,
            display=display,
            max_age=max_age,
            ui_locales=ui_locales,
            id_token_hint=id_token_hint,
            login_hint=login_hint,
            acr_values=acr_values,
            ).items() if v is not None}  # Filter out None values
        # Tolerate an explicit auth_params=None, which dict() would reject
        auth_params = dict(kwargs.pop("auth_params", None) or {}, **filtered_params)
        return super(Client, self).obtain_token_by_browser(
            auth_params=auth_params, **kwargs)
246
247