1# -*- coding: utf-8 -*-
2#
3# (c) 2016, Yanis Guenane <yanis+ansible@guenane.org>
4#
5# Ansible is free software: you can redistribute it and/or modify
6# it under the terms of the GNU General Public License as published by
7# the Free Software Foundation, either version 3 of the License, or
8# (at your option) any later version.
9#
10# Ansible is distributed in the hope that it will be useful,
11# but WITHOUT ANY WARRANTY; without even the implied warranty of
12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13# GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License
16# along with Ansible.  If not, see <http://www.gnu.org/licenses/>.
17
18from __future__ import absolute_import, division, print_function
# On Python 2, make classes defined in this module without an explicit
# metaclass new-style classes (no effect on Python 3).
__metaclass__ = type
20
21
22import abc
23import datetime
24import errno
25import hashlib
26import os
27import re
28
29from ansible.module_utils import six
30from ansible.module_utils.common.text.converters import to_native, to_bytes
31
32try:
33    from OpenSSL import crypto
34    HAS_PYOPENSSL = True
35except ImportError:
36    # Error handled in the calling module.
37    HAS_PYOPENSSL = False
38
39try:
40    from cryptography import x509
41    from cryptography.hazmat.backends import default_backend as cryptography_backend
42    from cryptography.hazmat.primitives.serialization import load_pem_private_key
43    from cryptography.hazmat.primitives import hashes
44    from cryptography.hazmat.primitives import serialization
45except ImportError:
46    # Error handled in the calling module.
47    pass
48
49from .basic import (
50    OpenSSLObjectError,
51    OpenSSLBadPassphraseError,
52)
53
54
# Hash algorithms tried first, in this order, when a fingerprinting function
# is called with prefer_one=True; only the first usable algorithm is computed.
PREFERRED_FINGERPRINTS = (
    'sha256',
    'sha3_256',
    'sha512',
    'sha3_512',
    'sha384',
    'sha3_384',
    'sha1',
    'md5',
)
60
61
def get_fingerprint_of_bytes(source, prefer_one=False):
    """Compute colon-separated hex fingerprints of ``source``.

    :param source: the bytes to hash.
    :param prefer_one: if True, return after the first algorithm that can be
        used; the algorithms in PREFERRED_FINGERPRINTS are tried first (in that
        order), then all remaining ones in alphabetical order.
    :return: dict mapping hashlib algorithm name to a fingerprint string such
        as ``'ab:cd:ef:...'``, or None if hashlib exposes neither
        ``algorithms`` nor ``algorithms_guaranteed``.
    """
    for attribute in ('algorithms', 'algorithms_guaranteed'):
        if hasattr(hashlib, attribute):
            candidates = getattr(hashlib, attribute)
            break
    else:
        return None

    if prefer_one:
        preferred = [name for name in PREFERRED_FINGERPRINTS if name in candidates]
        remaining = sorted(name for name in candidates if name not in PREFERRED_FINGERPRINTS)
        candidates = preferred + remaining

    fingerprints = {}
    for name in candidates:
        constructor = getattr(hashlib, name)
        try:
            digest_obj = constructor(source)
        except ValueError:
            # Raised for algorithms that are disabled in FIPS mode
            # (https://github.com/ansible/ansible/issues/67213).
            continue
        try:
            hex_value = digest_obj.hexdigest()
        except TypeError:
            # Variable-length digests (e.g. SHAKE) need an explicit length.
            hex_value = digest_obj.hexdigest(32)
        fingerprints[name] = ':'.join(
            hex_value[pos:pos + 2] for pos in range(0, len(hex_value), 2)
        )
        if prefer_one:
            break

    return fingerprints
99
100
def get_fingerprint_of_privatekey(privatekey, backend='pyopenssl', prefer_one=False):
    """Generate the fingerprint of the public key belonging to a private key.

    :param privatekey: private key object for the given backend (as returned
        by load_privatekey() with the same ``backend``).
    :param backend: ``'pyopenssl'`` or ``'cryptography'``.
    :param prefer_one: passed through to get_fingerprint_of_bytes().
    :return: the result of get_fingerprint_of_bytes() applied to the DER
        (ASN.1) encoded public key (SubjectPublicKeyInfo for the cryptography
        backend), or None if an old pyOpenSSL cannot serialize the key.
    :raises ValueError: if ``backend`` is not supported.
    """

    if backend == 'pyopenssl':
        try:
            publickey = crypto.dump_publickey(crypto.FILETYPE_ASN1, privatekey)
        except AttributeError:
            # If PyOpenSSL < 16.0 crypto.dump_publickey() will fail.
            # Fall back to pyOpenSSL's private low-level bindings.
            try:
                bio = crypto._new_mem_buf()
                rc = crypto._lib.i2d_PUBKEY_bio(bio, privatekey._pkey)
                if rc != 1:
                    crypto._raise_current_error()
                publickey = crypto._bio_to_string(bio)
            except AttributeError:
                # By doing this we prevent the code from raising an error
                # yet we return no value in the fingerprint hash.
                return None
    elif backend == 'cryptography':
        publickey = privatekey.public_key().public_bytes(
            serialization.Encoding.DER,
            serialization.PublicFormat.SubjectPublicKeyInfo
        )
    else:
        # Reject unknown backends explicitly; otherwise ``publickey`` would be
        # unbound below and surface as a confusing UnboundLocalError.
        raise ValueError('Unsupported backend: {0!r}'.format(backend))

    return get_fingerprint_of_bytes(publickey, prefer_one=prefer_one)
126
127
def get_fingerprint(path, passphrase=None, content=None, backend='pyopenssl', prefer_one=False):
    """Load a private key and return the fingerprint(s) of its public key.

    The key is taken from ``content`` if given, otherwise read from ``path``.
    The passphrase is only used for decryption; whether the key really is
    (or is not) protected is not verified (check_passphrase=False).
    """
    return get_fingerprint_of_privatekey(
        load_privatekey(
            path,
            passphrase=passphrase,
            content=content,
            check_passphrase=False,
            backend=backend,
        ),
        backend=backend,
        prefer_one=prefer_one,
    )
134
135
def load_privatekey(path, passphrase=None, check_passphrase=True, content=None, backend='pyopenssl'):
    """Load the specified OpenSSL private key.

    The content can also be specified via content; in that case,
    this function will not load the key from disk.

    :param path: path of a PEM-encoded private key; only read when
        ``content`` is None.
    :param passphrase: passphrase protecting the key, or None.
    :param check_passphrase: pyopenssl backend only. Additionally verify that
        a key loaded with a passphrase really is password-protected, and that
        a key loaded without one is not protected by the empty string.
    :param content: PEM-encoded key bytes to use instead of reading ``path``.
    :param backend: ``'pyopenssl'`` or ``'cryptography'``.
    :return: the loaded private key object of the selected backend.
    :raises OpenSSLObjectError: if the file cannot be read, or (pyopenssl)
        the key cannot be deserialized.
    :raises OpenSSLBadPassphraseError: if the passphrase does not match the
        key's protection.
    :raises ValueError: if ``backend`` is not supported.
    """

    try:
        if content is None:
            with open(path, 'rb') as b_priv_key_fh:
                priv_key_detail = b_priv_key_fh.read()
        else:
            priv_key_detail = content
    except (IOError, OSError) as exc:
        raise OpenSSLObjectError(exc)

    if backend == 'pyopenssl':

        # First try: load with the real passphrase (or the empty string when
        # none is given). Works if this is the correct passphrase, or the key
        # is not password-protected.
        try:
            result = crypto.load_privatekey(crypto.FILETYPE_PEM,
                                            priv_key_detail,
                                            to_bytes(passphrase or ''))
        except crypto.Error as e:
            # e.args[0] is presumably pyOpenSSL's list of OpenSSL error
            # tuples; element [2] of the first one is compared against
            # OpenSSL reason strings below.
            if len(e.args) > 0 and len(e.args[0]) > 0:
                if e.args[0][0][2] in ('bad decrypt', 'bad password read'):
                    # This happens in case we have the wrong passphrase.
                    if passphrase is not None:
                        raise OpenSSLBadPassphraseError('Wrong passphrase provided for private key!')
                    else:
                        raise OpenSSLBadPassphraseError('No passphrase provided, but private key is password-protected!')
            raise OpenSSLObjectError('Error while deserializing key: {0}'.format(e))
        if check_passphrase:
            # Next we want to make sure that the key is actually protected by
            # a passphrase (in case we did try the empty string before, make
            # sure that the key is not protected by the empty string).
            # A dummy passphrase that differs from the real one is used.
            try:
                crypto.load_privatekey(crypto.FILETYPE_PEM,
                                       priv_key_detail,
                                       to_bytes('y' if passphrase == 'x' else 'x'))
                if passphrase is not None:
                    # Since we can load the key without an exception, the
                    # key isn't password-protected
                    raise OpenSSLBadPassphraseError('Passphrase provided, but private key is not password-protected!')
            except crypto.Error as e:
                if passphrase is None and len(e.args) > 0 and len(e.args[0]) > 0:
                    if e.args[0][0][2] in ('bad decrypt', 'bad password read'):
                        # The key is obviously protected by the empty string.
                        # Don't do this at home (if it's possible at all)...
                        raise OpenSSLBadPassphraseError('No passphrase provided, but private key is password-protected!')
    elif backend == 'cryptography':
        try:
            result = load_pem_private_key(priv_key_detail,
                                          None if passphrase is None else to_bytes(passphrase),
                                          cryptography_backend())
        except TypeError:
            # cryptography raises TypeError when a password is supplied for
            # an unencrypted key, or none is supplied for an encrypted key.
            raise OpenSSLBadPassphraseError('Wrong or empty passphrase provided for private key')
        except ValueError:
            # NOTE(review): cryptography also raises ValueError for PEM data
            # that cannot be decoded at all, so this message can be misleading.
            raise OpenSSLBadPassphraseError('Wrong passphrase provided for private key')
    else:
        # Reject unknown backends explicitly; otherwise ``result`` would be
        # unbound below and surface as a confusing UnboundLocalError.
        raise ValueError('Unsupported backend: {0!r}'.format(backend))

    return result
199
200
def load_publickey(path=None, content=None, backend=None):
    """Load a PEM-encoded public key.

    The key is taken from ``content`` if given, otherwise read from ``path``.
    With ``backend='cryptography'`` the cryptography loader is used; any other
    value (including the default None) uses pyOpenSSL.

    :raises OpenSSLObjectError: if neither ``path`` nor ``content`` is given,
        the file cannot be read, or the key cannot be deserialized.
    """
    if content is None and path is None:
        raise OpenSSLObjectError('Must provide either path or content')
    if content is None:
        try:
            with open(path, 'rb') as pubkey_file:
                content = pubkey_file.read()
        except (IOError, OSError) as exc:
            raise OpenSSLObjectError(exc)

    if backend != 'cryptography':
        try:
            return crypto.load_publickey(crypto.FILETYPE_PEM, content)
        except crypto.Error as e:
            raise OpenSSLObjectError('Error while deserializing key: {0}'.format(e))

    try:
        return serialization.load_pem_public_key(content, backend=cryptography_backend())
    except Exception as e:
        raise OpenSSLObjectError('Error while deserializing key: {0}'.format(e))
221
222
def load_certificate(path, content=None, backend='pyopenssl'):
    """Load a PEM-encoded X.509 certificate.

    ``content`` is used when given; otherwise the certificate is read from
    ``path``.

    :param backend: ``'pyopenssl'`` or ``'cryptography'``; any other value
        makes the function return None.
    :raises OpenSSLObjectError: if the file cannot be read, or (cryptography
        backend) the data cannot be parsed. pyOpenSSL parse errors are
        propagated unwrapped.
    """
    if content is not None:
        cert_content = content
    else:
        try:
            with open(path, 'rb') as cert_fh:
                cert_content = cert_fh.read()
        except (IOError, OSError) as exc:
            raise OpenSSLObjectError(exc)

    if backend == 'cryptography':
        try:
            return x509.load_pem_x509_certificate(cert_content, cryptography_backend())
        except ValueError as exc:
            raise OpenSSLObjectError(exc)
    if backend == 'pyopenssl':
        return crypto.load_certificate(crypto.FILETYPE_PEM, cert_content)
    return None
241
242
def load_certificate_request(path, content=None, backend='pyopenssl'):
    """Load a PEM-encoded certificate signing request (CSR).

    ``content`` is used when given; otherwise the CSR is read from ``path``.

    :param backend: ``'pyopenssl'`` or ``'cryptography'``; any other value
        makes the function return None.
    :raises OpenSSLObjectError: if the file cannot be read, or (cryptography
        backend) the data cannot be parsed. pyOpenSSL parse errors are
        propagated unwrapped.
    """
    csr_content = content
    if csr_content is None:
        try:
            with open(path, 'rb') as csr_file:
                csr_content = csr_file.read()
        except (IOError, OSError) as exc:
            raise OpenSSLObjectError(exc)

    if backend == 'pyopenssl':
        return crypto.load_certificate_request(crypto.FILETYPE_PEM, csr_content)
    if backend == 'cryptography':
        try:
            return x509.load_pem_x509_csr(csr_content, cryptography_backend())
        except ValueError as exc:
            raise OpenSSLObjectError(exc)
    return None
260
261
def parse_name_field(input_dict):
    """Flatten a name mapping into a list of ``(key, value)`` tuples.

    Each value may be a single item or a list of items; a list produces one
    tuple per element (an empty list produces none). Order follows the
    iteration order of ``input_dict``.
    """
    pairs = []
    for name, value in input_dict.items():
        entries = value if isinstance(value, list) else [value]
        pairs.extend((name, entry) for entry in entries)
    return pairs
273
274
def convert_relative_to_datetime(relative_time_string):
    """Turn a relative timespec into an absolute naive UTC datetime.

    The format follows the time format of sshd_config(5): a mandatory ``+``
    or ``-`` sign followed by optional ``<n>w``, ``<n>d``, ``<n>h``, ``<n>m``
    and ``<n>s`` parts in that order (units are case-insensitive; a trailing
    number without a unit counts as seconds), e.g. ``+32w1d2h``.

    :return: ``datetime.datetime.utcnow()`` shifted forward (``+``) or
        backward (``-``) by the offset, or None if the string does not match
        or consists of the sign alone.
    """
    match = re.match(
        r"^(?P<prefix>[+-])((?P<weeks>\d+)[wW])?((?P<days>\d+)[dD])?((?P<hours>\d+)[hH])?((?P<minutes>\d+)[mM])?((?P<seconds>\d+)[sS]?)?$",
        relative_time_string)
    if match is None or len(relative_time_string) == 1:
        return None

    offset_parts = {}
    for unit in ('weeks', 'days', 'hours', 'minutes', 'seconds'):
        amount = match.group(unit)
        if amount is not None:
            offset_parts[unit] = int(amount)
    offset = datetime.timedelta(**offset_parts)

    now = datetime.datetime.utcnow()
    if match.group("prefix") == "+":
        return now + offset
    return now - offset
304
305
def get_relative_time_option(input_string, input_name, backend='cryptography'):
    """Return an absolute timespec if a relative timespec or an ASN1 formatted
       string is provided.

       The return value will be a datetime object for the cryptography backend,
       and a ASN1 formatted string for the pyopenssl backend. For any other
       backend None is returned.

       :param input_string: a relative timespec starting with ``+`` or ``-``
           (see convert_relative_to_datetime()), or an absolute time string
           such as ``20200101000000Z``.
       :param input_name: option name, used only in error messages.
       :param backend: ``'cryptography'`` or ``'pyopenssl'``.
       :raises OpenSSLObjectError: if a relative timespec is malformed, or
           (cryptography backend) an absolute one matches none of the
           accepted formats. Absolute timespecs are returned unvalidated for
           the pyopenssl backend."""
    result = to_native(input_string)
    if result is None:
        # The format arguments must be a tuple: ``%`` binds tighter than ``,``.
        raise OpenSSLObjectError(
            'The timespec "%s" for %s is not valid' %
            (input_string, input_name))
    # Relative time
    if result.startswith("+") or result.startswith("-"):
        result_datetime = convert_relative_to_datetime(result)
        if result_datetime is None:
            # convert_relative_to_datetime() signals a malformed spec with
            # None; report it instead of failing on None.strftime (pyopenssl)
            # or silently returning None (cryptography).
            raise OpenSSLObjectError(
                'The time spec "%s" for %s is invalid' %
                (input_string, input_name)
            )
        if backend == 'pyopenssl':
            return result_datetime.strftime("%Y%m%d%H%M%SZ")
        elif backend == 'cryptography':
            return result_datetime
    # Absolute time
    if backend == 'pyopenssl':
        return input_string
    elif backend == 'cryptography':
        for date_fmt in ['%Y%m%d%H%M%SZ', '%Y%m%d%H%MZ', '%Y%m%d%H%M%S%z', '%Y%m%d%H%M%z']:
            try:
                return datetime.datetime.strptime(result, date_fmt)
            except ValueError:
                pass

        raise OpenSSLObjectError(
            'The time spec "%s" for %s is invalid' %
            (input_string, input_name)
        )
338
339
def select_message_digest(digest_string):
    """Return a new cryptography hash algorithm instance for a digest name.

    Recognized names are ``sha256``, ``sha384``, ``sha512``, ``sha1`` and
    ``md5`` (exact, lowercase); any other value yields None.
    """
    supported = (
        ('sha256', 'SHA256'),
        ('sha384', 'SHA384'),
        ('sha512', 'SHA512'),
        ('sha1', 'SHA1'),
        ('md5', 'MD5'),
    )
    for name, class_name in supported:
        if digest_string == name:
            return getattr(hashes, class_name)()
    return None
353
354
@six.add_metaclass(abc.ABCMeta)
class OpenSSLObject(object):
    """Abstract base class for an OpenSSL artifact stored at a file path.

    Subclasses must implement dump() and generate(). The base class keeps the
    target path and module options, and records in ``changed`` whether an
    operation modified (or, in check mode, would modify) the file.
    """

    def __init__(self, path, state, force, check_mode):
        # Target file plus the module options controlling how it is managed.
        self.path = path
        self.state = state
        self.force = force
        # Base name of ``path``.
        self.name = os.path.basename(path)
        # Flipped to True by operations that change (or would change) the file.
        self.changed = False
        self.check_mode = check_mode

    def check(self, module, perms_required=True):
        """Ensure the resource is in its desired state.

        Returns True if the file at ``path`` exists and, when
        ``perms_required`` is true, the module's common file arguments would
        not need to change its filesystem attributes.
        """
        if not os.path.exists(self.path):
            return False
        if not perms_required:
            return True

        file_args = module.load_file_common_arguments(module.params)
        if module.check_file_absent_if_check_mode(file_args['path']):
            return False
        # A falsy result is treated as "attributes already match".
        return not module.set_fs_attributes_if_different(file_args, False)

    @abc.abstractmethod
    def dump(self):
        """Serialize the object into a dictionary."""

    @abc.abstractmethod
    def generate(self):
        """Generate the resource."""

    def remove(self, module):
        """Remove the resource from the filesystem.

        A file that is already missing is not an error. In check mode nothing
        is deleted and ``changed`` is set only if the file exists.

        :raises OpenSSLObjectError: if deletion fails for any reason other
            than the file not existing.
        """
        if self.check_mode:
            if os.path.exists(self.path):
                self.changed = True
            return

        try:
            os.remove(self.path)
        except OSError as exc:
            if exc.errno != errno.ENOENT:
                raise OpenSSLObjectError(exc)
        else:
            self.changed = True
410