1# Copyright 2016 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Pure-Python RSA cryptography implementation.
16
17Uses the ``rsa``, ``pyasn1`` and ``pyasn1_modules`` packages
18to parse PEM files storing PKCS#1 or PKCS#8 keys as well as
19certificates. There is no support for p12 files.
20"""
21
22from __future__ import absolute_import
23
24from pyasn1.codec.der import decoder
25from pyasn1_modules import pem
26from pyasn1_modules.rfc2459 import Certificate
27from pyasn1_modules.rfc5208 import PrivateKeyInfo
28import rsa
29import six
30
31from google.auth import _helpers
32from google.auth.crypt import base
33
# Bit weights, most significant bit first: (128, 64, ..., 2, 1).
_POW2 = tuple(1 << shift for shift in range(7, -1, -1))
# Header line that identifies a PEM-encoded X.509 certificate. It is bytes
# because RSAVerifier.from_string searches for it in byte input.
_CERTIFICATE_MARKER = b"-----BEGIN CERTIFICATE-----"
# (BEGIN, END) marker pairs handed to pem.readPemBlocksFromFile; the
# position of each pair determines the marker index it reports.
_PKCS1_MARKER = (
    "-----BEGIN RSA PRIVATE KEY-----",
    "-----END RSA PRIVATE KEY-----",
)
_PKCS8_MARKER = (
    "-----BEGIN PRIVATE KEY-----",
    "-----END PRIVATE KEY-----",
)
# ASN.1 spec used to decode PKCS#8 PrivateKeyInfo structures.
_PKCS8_SPEC = PrivateKeyInfo()
39
40
41def _bit_list_to_bytes(bit_list):
42    """Converts an iterable of 1s and 0s to bytes.
43
44    Combines the list 8 at a time, treating each group of 8 bits
45    as a single byte.
46
47    Args:
48        bit_list (Sequence): Sequence of 1s and 0s.
49
50    Returns:
51        bytes: The decoded bytes.
52    """
53    num_bits = len(bit_list)
54    byte_vals = bytearray()
55    for start in six.moves.xrange(0, num_bits, 8):
56        curr_bits = bit_list[start : start + 8]
57        char_val = sum(val * digit for val, digit in six.moves.zip(_POW2, curr_bits))
58        byte_vals.append(char_val)
59    return bytes(byte_vals)
60
61
class RSAVerifier(base.Verifier):
    """Verifies RSA cryptographic signatures using public keys.

    Only PKCS#1 v1.5 signatures over a SHA-256 digest are accepted, matching
    what ``RSASigner.sign`` produces.

    Args:
        public_key (rsa.key.PublicKey): The public key used to verify
            signatures.
    """

    def __init__(self, public_key):
        self._pubkey = public_key

    # NOTE: no local docstring; the documentation is taken from
    # base.Verifier through _helpers.copy_docstring.
    @_helpers.copy_docstring(base.Verifier)
    def verify(self, message, signature):
        message = _helpers.to_bytes(message)
        try:
            hash_method = rsa.pkcs1.verify(message, signature, self._pubkey)
        except (ValueError, OverflowError, rsa.pkcs1.VerificationError):
            # rsa raises OverflowError when the signature, read as an
            # integer, is larger than the modulus; that is simply an invalid
            # signature, so report it the same way as any other mismatch.
            return False
        # rsa >= 4.0 returns the name of the hash embedded in the signature
        # (and accepts MD5, SHA-1, ...); older releases return True. Coerce
        # to a real bool and only accept the SHA-256 digest that
        # RSASigner.sign uses.
        return hash_method is True or hash_method == "SHA-256"

    @classmethod
    def from_string(cls, public_key):
        """Construct an Verifier instance from a public key or public
        certificate string.

        Args:
            public_key (Union[str, bytes]): The public key in PEM format or the
                x509 public key certificate.

        Returns:
            Verifier: The constructed verifier.

        Raises:
            ValueError: If the public_key can't be parsed.
        """
        public_key = _helpers.to_bytes(public_key)
        is_x509_cert = _CERTIFICATE_MARKER in public_key

        # If this is a certificate, extract the public key info.
        if is_x509_cert:
            der = rsa.pem.load_pem(public_key, "CERTIFICATE")
            asn1_cert, remaining = decoder.decode(der, asn1Spec=Certificate())
            if remaining != b"":
                raise ValueError("Unused bytes", remaining)

            # subjectPublicKey is a BIT STRING; pack it back into the DER
            # bytes of the PKCS#1 RSAPublicKey it wraps.
            cert_info = asn1_cert["tbsCertificate"]["subjectPublicKeyInfo"]
            key_bytes = _bit_list_to_bytes(cert_info["subjectPublicKey"])
            pubkey = rsa.PublicKey.load_pkcs1(key_bytes, "DER")
        else:
            pubkey = rsa.PublicKey.load_pkcs1(public_key, "PEM")
        return cls(pubkey)
112
113
class RSASigner(base.Signer, base.FromServiceAccountMixin):
    """Signs messages with an RSA private key.

    Args:
        private_key (rsa.key.PrivateKey): The private key to sign with.
        key_id (str): Optional key ID used to identify this private key. This
            can be useful to associate the private key with its associated
            public key or certificate.
    """

    def __init__(self, private_key, key_id=None):
        self._key = private_key
        self._key_id = key_id

    # NOTE: key_id and sign take their documentation from base.Signer via
    # _helpers.copy_docstring, so no local docstrings are declared.
    @property
    @_helpers.copy_docstring(base.Signer)
    def key_id(self):
        return self._key_id

    @_helpers.copy_docstring(base.Signer)
    def sign(self, message):
        message = _helpers.to_bytes(message)
        # PKCS#1 v1.5 signature over the SHA-256 digest of the message.
        return rsa.pkcs1.sign(message, self._key, "SHA-256")

    @classmethod
    def from_string(cls, key, key_id=None):
        """Construct an Signer instance from a private key in PEM format.

        Args:
            key (Union[str, bytes]): Private key in PEM format.
            key_id (str): An optional key id used to identify the private key.

        Returns:
            google.auth.crypt.Signer: The constructed signer.

        Raises:
            ValueError: If the key cannot be parsed as PKCS#1 or PKCS#8 in
                PEM format.
        """
        key = _helpers.from_bytes(key)  # PEM expects str in Python 3
        marker_id, key_bytes = pem.readPemBlocksFromFile(
            six.StringIO(key), _PKCS1_MARKER, _PKCS8_MARKER
        )

        # An empty payload means either no recognised BEGIN marker was found,
        # or a block was found but was empty / never closed by its END
        # marker. Reject it here so a truncated key is reported as ValueError
        # rather than being handed to the DER decoder as empty input.
        if not key_bytes:
            raise ValueError("No key could be detected.")

        if marker_id == 0:
            # PKCS#1 ("RSA PRIVATE KEY"): the payload is the RSA key itself.
            private_key = rsa.key.PrivateKey.load_pkcs1(key_bytes, format="DER")
        elif marker_id == 1:
            # PKCS#8 ("PRIVATE KEY"): unwrap PrivateKeyInfo; its privateKey
            # octet string holds the PKCS#1-encoded RSA key.
            key_info, remaining = decoder.decode(key_bytes, asn1Spec=_PKCS8_SPEC)
            if remaining != b"":
                raise ValueError("Unused bytes", remaining)
            private_key_info = key_info.getComponentByName("privateKey")
            private_key = rsa.key.PrivateKey.load_pkcs1(
                private_key_info.asOctets(), format="DER"
            )
        else:
            # Defensive: any other marker index is not a key we can load.
            raise ValueError("No key could be detected.")

        return cls(private_key, key_id=key_id)
174