1import base64
2import binascii
3import datetime
4import hashlib
5import hmac
6import os
7import re
8
9import boto3
10import six
11
12from .exceptions import ForceChangePasswordException
13
# SRP group modulus N, written as a hex string (16 fragments of 48 hex digits).
# AWSSRP.__init__ parses it into an integer and uses it as the modulus of the
# modular exponentiations performed by the class.
# https://github.com/aws/amazon-cognito-identity-js/blob/master/src/AuthenticationHelper.js#L22
N_HEX = (
    "FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD1"
    "29024E088A67CC74020BBEA63B139B22514A08798E3404DD"
    "EF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245"
    "E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7ED"
    "EE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3D"
    "C2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F"
    "83655D23DCA3AD961C62F356208552BB9ED529077096966D"
    "670C354E4ABC9804F1746C08CA18217C32905E462E36CE3B"
    "E39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9"
    "DE2BCBF6955817183995497CEA956AE515D2261898FA0510"
    "15728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64"
    "ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7"
    "ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6B"
    "F12FFA06D98A0864D87602733EC86A64521F2B18177B200C"
    "BBE117577A615D6C770988C0BAD946E208E24FA074E5AB31"
    "43DB5BFCE0FD108E4B82D120A93AD2CAFFFFFFFFFFFFFFFF"
)
# SRP generator g, written as a hex string; parsed into an integer by
# AWSSRP.__init__ and used as the base of the modular exponentiations.
# https://github.com/aws/amazon-cognito-identity-js/blob/master/src/AuthenticationHelper.js#L49
G_HEX = "2"
# Bytes that compute_hkdf appends a counter byte to and feeds into its second
# HMAC pass (the HKDF "info" input).
INFO_BITS = bytearray("Caldera Derived Key", "utf-8")
36
37
def hash_sha256(buf):
    """
    Hash a byte buffer with SHA-256 (AuthenticationHelper.hash)
    :param {Buffer} buf Bytes to hash.
    :return {String} Hex digest, left-padded with zeroes to 64 characters.
    """
    digest_hex = hashlib.sha256(buf).hexdigest()
    # Right-justify in a 64-character field filled with "0"
    return digest_hex.rjust(64, "0")
42
43
def hex_hash(hex_string):
    """
    Hash the bytes that a hex string encodes
    :param {String} hex_string Hex digits, decoded with bytearray.fromhex.
    :return {String} Result of hash_sha256 over the decoded bytes.
    """
    decoded = bytearray.fromhex(hex_string)
    return hash_sha256(decoded)
46
47
def hex_to_long(hex_string):
    """
    Parse a hexadecimal string into an integer
    :param {String} hex_string Number written in base 16.
    :return {Long integer} Parsed value.
    """
    radix = 16
    return int(hex_string, radix)
50
51
def long_to_hex(long_num):
    """
    Render an integer as lowercase hexadecimal without any prefix
    :param {Long integer} long_num Number to convert.
    :return {String} Hex digits of the number.
    """
    hex_digits = "%x" % long_num
    return hex_digits
54
55
def get_random(nbytes):
    """
    Draw a random integer from the operating system's entropy source
    :param {Integer} nbytes Number of random bytes to read from os.urandom.
    :return {Long integer} The random bytes, hex-encoded and parsed as one number.
    """
    entropy = os.urandom(nbytes)
    return hex_to_long(binascii.hexlify(entropy))
59
60
def pad_hex(long_int):
    """
    Produce the zero-padded hex form of a number that is used as hash input
    :param {Long integer|String} long_int Number, or an already hex-encoded string.
    :return {String} Hex string with "0" prepended when the digit count is odd,
        otherwise with "00" prepended when the first digit is 8 or above.
    """
    if isinstance(long_int, six.string_types):
        hex_digits = long_int
    else:
        hex_digits = long_to_hex(long_int)

    # Odd number of digits: complete the leading byte.
    if len(hex_digits) % 2 == 1:
        return "0" + hex_digits
    # Even number of digits with the top bit of the first byte set: add a zero byte.
    if hex_digits[0] in "89ABCDEFabcdef":
        return "00" + hex_digits
    return hex_digits
76
77
def compute_hkdf(ikm, salt):
    """
    Derive a 16-byte key with a one-block HKDF built on HMAC-SHA256
    :param {Buffer} ikm Input key material.
    :param {Buffer} salt Salt value, used as the HMAC key of the first pass.
    :return {Buffer} First 16 bytes of the second HMAC output.
    @private
    """
    sha256 = hashlib.sha256
    # Extract step: HMAC keyed by the salt over the input key material.
    pseudo_random_key = hmac.new(salt, ikm, sha256).digest()
    # Expand step, first block only: info bytes followed by the counter byte 0x01.
    first_block_input = INFO_BITS + bytearray(b"\x01")
    expanded = hmac.new(pseudo_random_key, first_block_input, sha256).digest()
    return expanded[:16]
90
91
def calculate_u(big_a, big_b):
    """
    Derive the SRP scrambling value U by hashing the padded hex forms of A and B
    :param {Long integer} big_a Large A value.
    :param {Long integer} big_b Server B value.
    :return {Long integer} Computed U value.
    """
    padded_a = pad_hex(big_a)
    padded_b = pad_hex(big_b)
    digest_hex = hex_hash(padded_a + padded_b)
    return hex_to_long(digest_hex)
101
102
class AWSSRP:
    """
    Client side of the Secure Remote Password (SRP) sign-in exchange with an
    AWS Cognito user pool, using the ``USER_SRP_AUTH`` flow.

    The private value ``a`` and the public value ``A`` are generated once in
    the constructor and reused by every call made through the instance.
    """

    NEW_PASSWORD_REQUIRED_CHALLENGE = "NEW_PASSWORD_REQUIRED"
    PASSWORD_VERIFIER_CHALLENGE = "PASSWORD_VERIFIER"

    # English abbreviations for the challenge timestamp. They are spelled out
    # here because strftime's %a / %b follow the process locale.
    # Indexed by datetime.weekday() (Monday == 0) and by month - 1.
    _WEEKDAY_NAMES = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
    _MONTH_NAMES = (
        "Jan",
        "Feb",
        "Mar",
        "Apr",
        "May",
        "Jun",
        "Jul",
        "Aug",
        "Sep",
        "Oct",
        "Nov",
        "Dec",
    )

    def __init__(
        self,
        username,
        password,
        pool_id,
        client_id,
        pool_region=None,
        client=None,
        client_secret=None,
    ):
        """
        :param {String} username Name sent as USERNAME when starting the flow.
        :param {String} password Password used to derive the SRP key.
        :param {String} pool_id User pool id; the segment after the first
            underscore is used in the SRP computations.
        :param {String} client_id App client id.
        :param {String} pool_region Region for the boto3 client created here;
            must not be combined with ``client``.
        :param {Object} client Ready-made "cognito-idp" boto3 client.
        :param {String} client_secret App client secret; when given, a
            SECRET_HASH is added to the requests.
        :raises ValueError: if both ``pool_region`` and ``client`` are given,
            or if the computed public value A fails its safety check.
        """
        if pool_region is not None and client is not None:
            raise ValueError(
                "pool_region and client should not both be specified "
                "(region should be passed to the boto3 client instead)"
            )

        self.username = username
        self.password = password
        self.pool_id = pool_id
        self.client_id = client_id
        self.client_secret = client_secret
        self.client = (
            client if client else boto3.client("cognito-idp", region_name=pool_region)
        )
        self.big_n = hex_to_long(N_HEX)
        self.val_g = hex_to_long(G_HEX)
        # k = H(N | g). The literal "00" and "0" prefixes are what pad_hex
        # would add to these two constants (N starts with "F", G has one digit).
        self.val_k = hex_to_long(hex_hash("00" + N_HEX + "0" + G_HEX))
        self.small_a_value = self.generate_random_small_a()
        self.large_a_value = self.calculate_a()

    def generate_random_small_a(self):
        """
        helper function to generate a random big integer
        :return {Long integer} a random value, reduced modulo N.
        """
        random_long_int = get_random(128)
        return random_long_int % self.big_n

    def calculate_a(self):
        """
        Calculate the client's public value A = g^a%N
        with the random number a stored in ``self.small_a_value``
        :return {Long integer} Computed large A.
        :raises ValueError: if A is congruent to zero modulo N.
        """
        big_a = pow(self.val_g, self.small_a_value, self.big_n)
        # safety check
        if (big_a % self.big_n) == 0:
            raise ValueError("Safety check for A failed")
        return big_a

    def get_password_authentication_key(self, username, password, server_b_value, salt):
        """
        Calculates the final hkdf based on computed S value, and computed U value and the key
        :param {String} username Username.
        :param {String} password Password.
        :param {Long integer} server_b_value Server B value.
        :param {Long integer|String} salt Salt, as a number or as a hex string.
        :return {Buffer} Computed HKDF value.
        :raises ValueError: if B is congruent to zero modulo N, or U is zero.
        """
        # SRP safety check: a B that is a multiple of N must abort the exchange.
        if server_b_value % self.big_n == 0:
            raise ValueError("B cannot be zero.")
        u_value = calculate_u(self.large_a_value, server_b_value)
        if u_value == 0:
            raise ValueError("U cannot be zero.")
        username_password = "%s%s:%s" % (self.pool_id.split("_")[1], username, password)
        username_password_hash = hash_sha256(username_password.encode("utf-8"))

        # x = H(salt | H(pool_name + username + ":" + password))
        x_value = hex_to_long(hex_hash(pad_hex(salt) + username_password_hash))
        g_mod_pow_xn = pow(self.val_g, x_value, self.big_n)
        # S = (B - k * g^x) ^ (a + u * x) % N
        int_value2 = server_b_value - self.val_k * g_mod_pow_xn
        s_value = pow(int_value2, self.small_a_value + u_value * x_value, self.big_n)
        # S is the input key material, U is the salt of the key derivation.
        hkdf = compute_hkdf(
            bytearray.fromhex(pad_hex(s_value)),
            bytearray.fromhex(pad_hex(long_to_hex(u_value))),
        )
        return hkdf

    def get_auth_params(self):
        """
        Build the AuthParameters that start the USER_SRP_AUTH flow
        :return {Dict} USERNAME and SRP_A (hex of A), plus SECRET_HASH when a
            client secret is configured.
        """
        auth_params = {
            "USERNAME": self.username,
            "SRP_A": long_to_hex(self.large_a_value),
        }
        if self.client_secret is not None:
            auth_params.update(
                {
                    "SECRET_HASH": self.get_secret_hash(
                        self.username, self.client_id, self.client_secret
                    )
                }
            )
        return auth_params

    @staticmethod
    def get_secret_hash(username, client_id, client_secret):
        """
        Compute the SECRET_HASH value for an app client that has a secret
        :param {String} username Username.
        :param {String} client_id App client id.
        :param {String} client_secret App client secret, used as the HMAC key.
        :return {String} Base64 of HMAC-SHA256 over username + client_id.
        """
        message = bytearray(username + client_id, "utf-8")
        hmac_obj = hmac.new(bytearray(client_secret, "utf-8"), message, hashlib.sha256)
        return base64.standard_b64encode(hmac_obj.digest()).decode("utf-8")

    @classmethod
    def _format_timestamp(cls, moment):
        """
        Format a datetime as e.g. "Sun Jan 5 03:04:05 UTC 2020"
        The day of month carries no leading zero (required by AWS Cognito) and
        the day/month names are always English, whatever the process locale.
        :param {datetime} moment Point in time, expected to be expressed in UTC.
        :return {String} Formatted timestamp.
        """
        return "%s %s %d %02d:%02d:%02d UTC %d" % (
            cls._WEEKDAY_NAMES[moment.weekday()],
            cls._MONTH_NAMES[moment.month - 1],
            moment.day,
            moment.hour,
            moment.minute,
            moment.second,
            moment.year,
        )

    def process_challenge(self, challenge_parameters):
        """
        Build the ChallengeResponses answering a PASSWORD_VERIFIER challenge
        :param {Dict} challenge_parameters Challenge parameters; the keys
            USERNAME, USER_ID_FOR_SRP, SALT, SRP_B and SECRET_BLOCK are read.
        :return {Dict} TIMESTAMP, USERNAME, PASSWORD_CLAIM_SECRET_BLOCK and
            PASSWORD_CLAIM_SIGNATURE, plus SECRET_HASH when a client secret
            is configured.
        """
        internal_username = challenge_parameters["USERNAME"]
        user_id_for_srp = challenge_parameters["USER_ID_FOR_SRP"]
        salt_hex = challenge_parameters["SALT"]
        srp_b_hex = challenge_parameters["SRP_B"]
        secret_block_b64 = challenge_parameters["SECRET_BLOCK"]
        timestamp = self._format_timestamp(datetime.datetime.utcnow())
        hkdf = self.get_password_authentication_key(
            user_id_for_srp, self.password, hex_to_long(srp_b_hex), salt_hex
        )
        secret_block_bytes = base64.standard_b64decode(secret_block_b64)
        # Signed message: pool name | user id | secret block | timestamp
        msg = (
            bytearray(self.pool_id.split("_")[1], "utf-8")
            + bytearray(user_id_for_srp, "utf-8")
            + bytearray(secret_block_bytes)
            + bytearray(timestamp, "utf-8")
        )

        hmac_obj = hmac.new(hkdf, msg, digestmod=hashlib.sha256)
        signature_string = base64.standard_b64encode(hmac_obj.digest())
        response = {
            "TIMESTAMP": timestamp,
            "USERNAME": internal_username,
            "PASSWORD_CLAIM_SECRET_BLOCK": secret_block_b64,
            "PASSWORD_CLAIM_SIGNATURE": signature_string.decode("utf-8"),
        }
        if self.client_secret is not None:
            response.update(
                {
                    "SECRET_HASH": self.get_secret_hash(
                        internal_username, self.client_id, self.client_secret
                    )
                }
            )
        return response

    def authenticate_user(self, client=None):
        """
        Run the USER_SRP_AUTH exchange and return the service's final response
        :param {Object} client Fallback client, used only when ``self.client``
            is falsy.
        :return {Dict} Response of ``respond_to_auth_challenge``.
        :raises ForceChangePasswordException: if the service answers with a
            NEW_PASSWORD_REQUIRED challenge.
        :raises NotImplementedError: if the first challenge is not
            PASSWORD_VERIFIER.
        """
        boto_client = self.client or client
        auth_params = self.get_auth_params()
        response = boto_client.initiate_auth(
            AuthFlow="USER_SRP_AUTH",
            AuthParameters=auth_params,
            ClientId=self.client_id,
        )
        if response["ChallengeName"] == self.PASSWORD_VERIFIER_CHALLENGE:
            challenge_response = self.process_challenge(response["ChallengeParameters"])
            tokens = boto_client.respond_to_auth_challenge(
                ClientId=self.client_id,
                ChallengeName=self.PASSWORD_VERIFIER_CHALLENGE,
                ChallengeResponses=challenge_response,
            )

            if tokens.get("ChallengeName") == self.NEW_PASSWORD_REQUIRED_CHALLENGE:
                raise ForceChangePasswordException(
                    "Change password before authenticating"
                )

            return tokens

        raise NotImplementedError(
            "The %s challenge is not supported" % response["ChallengeName"]
        )

    def set_new_password_challenge(self, new_password, client=None):
        """
        Sign in and, if the service demands it, set a new password
        :param {String} new_password Password sent as NEW_PASSWORD when a
            NEW_PASSWORD_REQUIRED challenge is received.
        :param {Object} client Fallback client, used only when ``self.client``
            is falsy.
        :return {Dict} Response to the NEW_PASSWORD_REQUIRED challenge when one
            was issued, otherwise the response to the PASSWORD_VERIFIER
            challenge.
        :raises NotImplementedError: if the first challenge is not
            PASSWORD_VERIFIER.
        """
        boto_client = self.client or client
        auth_params = self.get_auth_params()
        response = boto_client.initiate_auth(
            AuthFlow="USER_SRP_AUTH",
            AuthParameters=auth_params,
            ClientId=self.client_id,
        )
        if response["ChallengeName"] == self.PASSWORD_VERIFIER_CHALLENGE:
            challenge_response = self.process_challenge(response["ChallengeParameters"])
            tokens = boto_client.respond_to_auth_challenge(
                ClientId=self.client_id,
                ChallengeName=self.PASSWORD_VERIFIER_CHALLENGE,
                ChallengeResponses=challenge_response,
            )

            # A response without a further challenge has no "ChallengeName"
            # key, so use .get() and fall through to returning it as-is.
            if tokens.get("ChallengeName") == self.NEW_PASSWORD_REQUIRED_CHALLENGE:
                challenge_parameters = response["ChallengeParameters"]
                challenge_response.update(
                    {
                        "USERNAME": challenge_parameters["USERNAME"],
                        "NEW_PASSWORD": new_password,
                    }
                )
                new_password_response = boto_client.respond_to_auth_challenge(
                    ClientId=self.client_id,
                    ChallengeName=self.NEW_PASSWORD_REQUIRED_CHALLENGE,
                    Session=tokens["Session"],
                    ChallengeResponses=challenge_response,
                )
                return new_password_response
            return tokens

        raise NotImplementedError(
            "The %s challenge is not supported" % response["ChallengeName"]
        )
309