1#------------------------------------------------------------------------------
2#
3# Copyright (c) Microsoft Corporation.
4# All rights reserved.
5#
6# This code is licensed under the MIT License.
7#
8# Permission is hereby granted, free of charge, to any person obtaining a copy
9# of this software and associated documentation files(the "Software"), to deal
10# in the Software without restriction, including without limitation the rights
11# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
12# copies of the Software, and to permit persons to whom the Software is
13# furnished to do so, subject to the following conditions :
14#
15# The above copyright notice and this permission notice shall be included in
16# all copies or substantial portions of the Software.
17#
18# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
21# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24# THE SOFTWARE.
25#
26#------------------------------------------------------------------------------
27
28import time
29import datetime
30import uuid
31import base64
32import binascii
33import re
34
35import jwt
36
37from .constants import Jwt
38from .log import Logger
39from .adal_error import AdalError
40
def _get_date_now():
    """Return the current local wall-clock time as a naive ``datetime``.

    Isolated in a helper so the time source used for JWT timestamps is
    defined in exactly one place.
    """
    current = datetime.datetime.now()
    return current
43
def _get_new_jwt_id():
    """Return a fresh random (version 4) UUID, as a string, for the JWT ID claim."""
    jwt_id = uuid.uuid4()
    return str(jwt_id)
46
def _create_x5t_value(thumbprint):
    """Encode a hex certificate thumbprint as the ``x5t`` JWT header value.

    The hex digits are converted back to the raw digest bytes, which are then
    base64url-encoded (keeping the ``=`` padding that ``urlsafe_b64encode``
    emits) and returned as text.
    """
    raw_digest = binascii.unhexlify(thumbprint)
    encoded = base64.urlsafe_b64encode(raw_digest)
    return encoded.decode()
50
def _sign_jwt(header, payload, certificate):
    """Produce a signed compact JWT from *header* and *payload*.

    Any exception raised while encoding/signing is reported as an AdalError
    that blames the certificate; the original exception is passed along as the
    AdalError's second argument. The resulting token is then checked for a
    non-empty signature segment before it is returned.
    """
    try:
        token = _encode_jwt(payload, certificate, header)
    except Exception as cause:
        raise AdalError("Error:Invalid Certificate: Expected Start of Certificate to be '-----BEGIN RSA PRIVATE KEY-----'", cause)
    else:
        # Signing "succeeded"; make sure it actually yielded a signature.
        _raise_on_invalid_jwt_signature(token)
    return token
58
def _encode_jwt(payload, certificate, header):
    """Sign *payload* with RS256 via PyJWT and return the compact token as text.

    PyJWT 1.x returns ``bytes`` from ``jwt.encode``, which has historically been
    decoded to a string here. PyJWT 2 already returns ``str``; it has no
    ``.decode()`` and is passed through unchanged.
    """
    signed = jwt.encode(payload, certificate, algorithm='RS256', headers=header)
    try:
        text = signed.decode()
    except AttributeError:
        text = signed
    return text
65
def _raise_on_invalid_jwt_signature(encoded_jwt):
    """Raise AdalError unless *encoded_jwt* has a non-empty third (signature) segment.

    Extra dot-separated segments beyond the third are tolerated.
    """
    parts = encoded_jwt.split('.')
    signature = parts[2] if len(parts) >= 3 else ''
    if signature:
        return
    raise AdalError('Failed to sign JWT. This is most likely due to an invalid certificate.')
70
def _extract_certs(public_cert_content):
    """Return the base64 bodies of the certificates in *public_cert_content*.

    Every ``-----BEGIN CERTIFICATE----- ... -----END CERTIFICATE-----`` block
    (tags matched case-insensitively) contributes its stripped body, in order.
    If no such block is present, the whole input is treated as a single
    already-bare certificate, unless it mentions "PRIVATE KEY", in which case
    ValueError is raised as a best-effort guard against a private key PEM.

    Usage: headers = {"x5c": _extract_certs(open("my_cert.pem").read())}
    """
    cert_block = re.compile(
        r'-----BEGIN CERTIFICATE-----(?P<cert_value>[^-]+)-----END CERTIFICATE-----',
        re.I)
    bodies = [body.strip() for body in cert_block.findall(public_cert_content)]
    if bodies:
        return bodies
    if "PRIVATE KEY" not in public_cert_content:
        return [public_cert_content.strip()]
    raise ValueError(
        "We expect your public key but detect a private key instead")
85
class SelfSignedJwt(object):
    """Build a self-signed RS256 JWT for a client identified by a certificate.

    The token's issuer and subject are the client id, and its audience is the
    authority's token endpoint. The header identifies the signing certificate
    by thumbprint (``x5t``) and optionally carries the public chain (``x5c``).
    """

    # Accepted thumbprint lengths, in hex characters (128-bit and 160-bit digests).
    # Integer division keeps these ints on Python 3 too (128/8*2 is 32.0 there).
    NumCharIn128BitHexString = 128 // 8 * 2
    numCharIn160BitHexString = 160 // 8 * 2
    # Lower-case ASCII hex only. ``\Z`` (unlike ``$``) does not match before a
    # trailing newline, and an explicit [0-9] avoids ``\d`` matching non-ASCII
    # Unicode digits; either would let through input that binascii.a2b_hex
    # later rejects with a non-AdalError.
    ThumbprintRegEx = r"^[0-9a-f]*\Z"

    def __init__(self, call_context, authority, client_id):
        """
        :param call_context: mapping that must contain ``'log_context'``.
        :param authority: object exposing ``token_endpoint`` (used as audience).
        :param client_id: client id, used as both issuer and subject.
        """
        self._log = Logger('SelfSignedJwt', call_context['log_context'])
        self._call_context = call_context

        self._authority = authority
        # Misspelled name kept as an alias in case other code still reads it.
        self._authortiy = authority
        self._token_endpoint = authority.token_endpoint
        self._client_id = client_id

    def _create_header(self, thumbprint, public_certificate):
        """Return the JWT header dict.

        :param thumbprint: canonical (lower-case hex) certificate thumbprint.
        :param public_certificate: optional public certificate text; when
            truthy, its certificate bodies are placed in ``x5c``.
        """
        x5t = _create_x5t_value(thumbprint)
        header = {'typ': 'JWT', 'alg': 'RS256', 'x5t': x5t}
        if public_certificate:
            header['x5c'] = _extract_certs(public_certificate)
        # Log what actually goes into the header, not the raw input text.
        self._log.debug("Creating self signed JWT header. x5t: %(x5t)s, x5c: %(x5c)s",
                        {"x5t": x5t, "x5c": header.get('x5c')})

        return header

    def _create_payload(self):
        """Return the claim set: audience, issuer, subject, nbf, exp and a new JWT id.

        ``nbf`` is "now" and ``exp`` is ``Jwt.SELF_SIGNED_JWT_LIFETIME`` minutes
        later, both as integer seconds since the epoch.
        """
        now = _get_date_now()
        expires = now + datetime.timedelta(minutes=Jwt.SELF_SIGNED_JWT_LIFETIME)

        self._log.debug(
            'Creating self signed JWT payload. Expires: %(expires)s NotBefore: %(nbf)s',
            {"expires": expires, "nbf": now})

        # time.mktime interprets the naive timetuple as local time.
        # NOTE(review): a naive local time is ambiguous during a DST fall-back
        # hour, so nbf/exp could be off by an hour then; confirm acceptable.
        return {
            Jwt.AUDIENCE: self._token_endpoint,
            Jwt.ISSUER: self._client_id,
            Jwt.SUBJECT: self._client_id,
            Jwt.NOT_BEFORE: int(time.mktime(now.timetuple())),
            Jwt.EXPIRES_ON: int(time.mktime(expires.timetuple())),
            Jwt.JWT_ID: _get_new_jwt_id(),
        }

    def _raise_on_invalid_thumbprint(self, thumbprint):
        """Raise AdalError unless *thumbprint* is 32 or 40 lower-case hex chars."""
        thumbprint_sizes = (self.NumCharIn128BitHexString, self.numCharIn160BitHexString)
        size_ok = len(thumbprint) in thumbprint_sizes
        if not size_ok or not re.search(self.ThumbprintRegEx, thumbprint):
            raise AdalError("The thumbprint does not match a known format")

    def _reduce_thumbprint(self, thumbprint):
        """Lower-case *thumbprint*, drop spaces and colons, validate and return it."""
        canonical = thumbprint.lower().replace(' ', '').replace(':', '')
        self._raise_on_invalid_thumbprint(canonical)
        return canonical

    def create(self, certificate, thumbprint, public_certificate):
        """Build and sign the JWT; return it as a string.

        :param certificate: signing key handed to PyJWT (presumably PEM private
            key text; verify against callers).
        :param thumbprint: certificate thumbprint in hex; spaces/colons allowed.
        :param public_certificate: optional public certificate(s) for ``x5c``.
        :raises AdalError: on a malformed thumbprint or a signing failure.
        """
        thumbprint = self._reduce_thumbprint(thumbprint)

        header = self._create_header(thumbprint, public_certificate)
        payload = self._create_payload()
        return _sign_jwt(header, payload, certificate)
146