1# -*- coding: utf-8 -*-
2#
3# Copyright: (c) 2021, Andrew Pantuso (@ajpantuso) <ajpantuso@gmail.com>
4#
5# Ansible is free software: you can redistribute it and/or modify
6# it under the terms of the GNU General Public License as published by
7# the Free Software Foundation, either version 3 of the License, or
8# (at your option) any later version.
9#
10# Ansible is distributed in the hope that it will be useful,
11# but WITHOUT ANY WARRANTY; without even the implied warranty of
12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13# GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License
16# along with Ansible.  If not, see <http://www.gnu.org/licenses/>.
17
18from __future__ import absolute_import, division, print_function
19__metaclass__ = type
20
21# Protocol References
22# -------------------
23# https://datatracker.ietf.org/doc/html/rfc4251
24# https://datatracker.ietf.org/doc/html/rfc4253
25# https://datatracker.ietf.org/doc/html/rfc5656
26# https://datatracker.ietf.org/doc/html/rfc8032
27# https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL.certkeys?annotate=HEAD
28#
29# Inspired by:
30# ------------
31# https://github.com/pyca/cryptography/blob/main/src/cryptography/hazmat/primitives/serialization/ssh.py
32# https://github.com/paramiko/paramiko/blob/master/paramiko/message.py
33
import abc
import binascii
import os
from base64 import b64encode
from datetime import datetime, timedelta
from hashlib import sha256

from ansible.module_utils import six
from ansible.module_utils.common.text.converters import to_text
from ansible_collections.community.crypto.plugins.module_utils.crypto.support import convert_relative_to_datetime
from ansible_collections.community.crypto.plugins.module_utils.openssh.utils import (
    OpensshParser,
    _OpensshWriter,
)
48
# Certificate type codes stored in the certificate's numeric type field.
# See https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL.certkeys?annotate=HEAD
_USER_TYPE = 1
_HOST_TYPE = 2

# Key type name -> public key algorithm identifier as it appears on the wire.
_SSH_TYPE_STRINGS = {
    'rsa': b"ssh-rsa",
    'dsa': b"ssh-dss",
    'ecdsa-nistp256': b"ecdsa-sha2-nistp256",
    'ecdsa-nistp384': b"ecdsa-sha2-nistp384",
    'ecdsa-nistp521': b"ecdsa-sha2-nistp521",
    'ed25519': b"ssh-ed25519",
}
# Appended to an algorithm identifier to form the certificate format identifier.
_CERT_SUFFIX_V01 = b"-cert-v01@openssh.com"

# Key type name -> ECDSA curve identifier.
# See https://datatracker.ietf.org/doc/html/rfc5656#section-6.1
_ECDSA_CURVE_IDENTIFIERS = {
    'ecdsa-nistp256': b'nistp256',
    'ecdsa-nistp384': b'nistp384',
    'ecdsa-nistp521': b'nistp521',
}
# Inverse of _ECDSA_CURVE_IDENTIFIERS: curve identifier -> key type name.
_ECDSA_CURVE_IDENTIFIERS_LOOKUP = dict(
    (curve_id, key_type) for key_type, curve_id in _ECDSA_CURVE_IDENTIFIERS.items()
)

# Sentinels for an unbounded start (the epoch) and end of certificate validity.
_ALWAYS = datetime(1970, 1, 1)
_FOREVER = datetime.max

# Option names whose type ('critical') is inferred without an explicit prefix.
_CRITICAL_OPTIONS = (
    'force-command',
    'source-address',
    'verify-required',
)

# Shorthand entries accepted in option lists: 'clear' drops every default
# extension, each 'no-*' drops the matching 'permit-*' extension.
_DIRECTIVES = (
    'clear',
    'no-x11-forwarding',
    'no-agent-forwarding',
    'no-port-forwarding',
    'no-pty',
    'no-user-rc',
)

# Default extensions; their type ('extension') is inferred without a prefix.
_EXTENSIONS = (
    'permit-x11-forwarding',
    'permit-agent-forwarding',
    'permit-port-forwarding',
    'permit-pty',
    'permit-user-rc',
)

# Python 3 has no separate ``long`` type; alias it so isinstance checks work on both.
if six.PY3:
    long = int
103
104
class OpensshCertificateTimeParameters(object):
    """Validity interval (``valid_from`` .. ``valid_to``) of an OpenSSH certificate.

    Both bounds are normalized to naive datetimes by :meth:`to_datetime`. ``_ALWAYS``
    (1970-01-01) and ``_FOREVER`` (``datetime.max``) stand for an unbounded start/end.
    """
    def __init__(self, valid_from, valid_to):
        """
        :param valid_from: start of validity; a time string or integer timestamp.
        :param valid_to: end of validity; a time string or integer timestamp.
        :raises ValueError: if a value cannot be converted or ``valid_from`` is after ``valid_to``.
        """
        self._valid_from = self.to_datetime(valid_from)
        self._valid_to = self.to_datetime(valid_to)

        if self._valid_from > self._valid_to:
            raise ValueError("Valid from: %s must not be greater than Valid to: %s" % (valid_from, valid_to))

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        return self._valid_from == other._valid_from and self._valid_to == other._valid_to

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        # Defined alongside __eq__ so equal instances hash equally; without it Python 3
        # makes the class unhashable and Python 2 falls back to identity hashing.
        return hash((self._valid_from, self._valid_to))

    @property
    def validity_string(self):
        """``'<from>:<to>'`` in the 'openssh' date format, or '' when unbounded on both ends."""
        if not (self._valid_from == _ALWAYS and self._valid_to == _FOREVER):
            return "%s:%s" % (
                self.valid_from(date_format='openssh'), self.valid_to(date_format='openssh')
            )
        return ""

    def valid_from(self, date_format):
        """Return the start of validity rendered by :meth:`format_datetime`."""
        return self.format_datetime(self._valid_from, date_format)

    def valid_to(self, date_format):
        """Return the end of validity rendered by :meth:`format_datetime`."""
        return self.format_datetime(self._valid_to, date_format)

    def within_range(self, valid_at):
        """Return True if ``valid_at`` lies within the inclusive interval; None always matches."""
        if valid_at is not None:
            valid_at_datetime = self.to_datetime(valid_at)
            return self._valid_from <= valid_at_datetime <= self._valid_to
        return True

    @staticmethod
    def format_datetime(dt, date_format):
        """Render ``dt`` as 'human_readable' (isoformat), 'openssh' (YYYYMMDDHHMMSS) or 'timestamp'.

        The two string formats render the sentinels as 'always' / 'forever'. 'timestamp'
        returns whole seconds since the epoch, truncated toward zero.

        :raises ValueError: for an unknown ``date_format``.
        """
        if date_format in ('human_readable', 'openssh'):
            if dt == _ALWAYS:
                result = 'always'
            elif dt == _FOREVER:
                result = 'forever'
            else:
                result = dt.isoformat() if date_format == 'human_readable' else dt.strftime("%Y%m%d%H%M%S")
        elif date_format == 'timestamp':
            td = dt - _ALWAYS
            total_microseconds = td.microseconds + (td.seconds + td.days * 24 * 3600) * 10 ** 6
            # Exact integer arithmetic: true division goes through a float, which cannot
            # hold far-future values exactly (datetime.max used to come out one second late).
            whole_seconds = abs(total_microseconds) // 10 ** 6
            result = whole_seconds if total_microseconds >= 0 else -whole_seconds
        else:
            raise ValueError("%s is not a valid format" % date_format)
        return result

    @staticmethod
    def to_datetime(time_string_or_timestamp):
        """Convert a time string or integer timestamp to a naive datetime.

        Strings may be 'always', 'forever', a relative time starting with '+' or '-', or an
        absolute date; integers are seconds since the epoch, where 0 and 0xFFFFFFFFFFFFFFFF
        mean 'always' and 'forever'.

        :raises ValueError: if the value has an unsupported type or cannot be parsed.
        """
        if isinstance(time_string_or_timestamp, six.string_types):
            return OpensshCertificateTimeParameters._time_string_to_datetime(time_string_or_timestamp.strip())
        if isinstance(time_string_or_timestamp, (long, int)):
            return OpensshCertificateTimeParameters._timestamp_to_datetime(time_string_or_timestamp)
        raise ValueError(
            "Value must be of type (str, unicode, int, long) not %s" % type(time_string_or_timestamp)
        )

    @staticmethod
    def _timestamp_to_datetime(timestamp):
        """Map an integer timestamp to a naive datetime, honoring the always/forever sentinels."""
        if timestamp == 0x0:
            return _ALWAYS
        if timestamp == 0xFFFFFFFFFFFFFFFF:
            return _FOREVER
        try:
            # Epoch + timedelta equals utcfromtimestamp() for integers, but is not bound by
            # the platform gmtime() range (which could surface as OSError) and is not deprecated.
            return _ALWAYS + timedelta(seconds=timestamp)
        except OverflowError:
            raise ValueError("%s is not a valid timestamp" % timestamp)

    @staticmethod
    def _time_string_to_datetime(time_string):
        """Parse an already-stripped time string into a naive datetime."""
        if time_string == 'always':
            return _ALWAYS
        if time_string == 'forever':
            return _FOREVER
        if is_relative_time_string(time_string):
            result = convert_relative_to_datetime(time_string)
            if result is None:
                # NOTE(review): a None result from the helper is treated as unparseable input
                # instead of leaking None into datetime comparisons.
                raise ValueError("%s is not a valid relative time string" % time_string)
            return result

        time_formats = ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S")
        for time_format in time_formats:
            try:
                return datetime.strptime(time_string, time_format)
            except ValueError:
                continue
        raise ValueError(
            "%s is not a valid time string; expected 'always', 'forever', a relative time "
            "starting with '+' or '-', or one of the formats %s" % (time_string, ", ".join(time_formats))
        )
204
205
class OpensshCertificateOption(object):
    """A single critical option or extension carried by an OpenSSH certificate.

    The name is stored lower-cased; ``data`` is the option's value ('' when it has none).
    """
    def __init__(self, option_type, name, data):
        """
        :param option_type: either ``'critical'`` or ``'extension'``.
        :param name: option name (stored lower-cased).
        :param data: option value, '' for flag-style options.
        :raises ValueError: if ``option_type`` is not recognized.
        :raises TypeError: if ``name`` or ``data`` is not a string.
        """
        if option_type not in ('critical', 'extension'):
            raise ValueError("type must be either 'critical' or 'extension'")

        if not isinstance(name, six.string_types):
            raise TypeError("name must be a string not %s" % type(name))

        if not isinstance(data, six.string_types):
            raise TypeError("data must be a string not %s" % type(data))

        self._option_type = option_type
        self._name = name.lower()
        self._data = data

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented

        return all([
            self._option_type == other._option_type,
            self._name == other._name,
            self._data == other._data,
        ])

    def __hash__(self):
        return hash((self._option_type, self._name, self._data))

    def __ne__(self, other):
        return not self == other

    def __str__(self):
        if self._data:
            return "%s=%s" % (self._name, self._data)
        return self._name

    @property
    def data(self):
        return self._data

    @property
    def name(self):
        return self._name

    @property
    def type(self):
        return self._option_type

    @classmethod
    def from_string(cls, option_string):
        """Parse ``[critical:|extension:]name[=data]`` into an option.

        Without a type prefix the type is inferred from the name via ``get_option_type``.

        :raises ValueError: if ``option_string`` is not a string, the prefix is not a valid
            type, or the type cannot be inferred from the name.
        """
        if not isinstance(option_string, six.string_types):
            raise ValueError("option_string must be a string not %s" % type(option_string))
        option_string = option_string.strip()
        option_type = None

        # Only a ':' that comes before any '=' introduces a type prefix. A ':' after the
        # '=' belongs to the data, e.g. IPv6 addresses in source-address or a colon
        # inside a force-command.
        colon_index = option_string.find(':')
        equals_index = option_string.find('=')
        if colon_index != -1 and (equals_index == -1 or colon_index < equals_index):
            option_type, option_string = option_string.split(':', 1)

        if '=' in option_string:
            name, data = option_string.split('=', 1)
        else:
            name, data = option_string, ''

        return cls(
            option_type=option_type or get_option_type(name.lower()),
            name=name,
            data=data
        )
276
277
@six.add_metaclass(abc.ABCMeta)
class OpensshCertificateInfo:
    """Encapsulates all certificate information which is signed by a CA key"""
    def __init__(self,
                 nonce=None,
                 serial=None,
                 cert_type=None,
                 key_id=None,
                 principals=None,
                 valid_after=None,
                 valid_before=None,
                 critical_options=None,
                 extensions=None,
                 reserved=None,
                 signing_key=None):
        self.nonce = nonce
        self.serial = serial
        # Route a supplied value through the validating setter so 'user'/'host' are stored
        # as their numeric codes; assigning the raw value made ``cert_type`` read back as ''
        # for string input and silently accepted invalid values.
        self._cert_type = None
        if cert_type is not None:
            self.cert_type = cert_type
        self.key_id = key_id
        self.principals = principals
        self.valid_after = valid_after
        self.valid_before = valid_before
        self.critical_options = critical_options
        self.extensions = extensions
        self.reserved = reserved
        self.signing_key = signing_key

        # Certificate format identifier (bytes); subclasses set it for their key type.
        self.type_string = None

    @property
    def cert_type(self):
        """'user', 'host', or '' when the type is unset or unknown."""
        if self._cert_type == _USER_TYPE:
            return 'user'
        elif self._cert_type == _HOST_TYPE:
            return 'host'
        else:
            return ''

    @cert_type.setter
    def cert_type(self, cert_type):
        """Accept 'user'/'host' or the matching numeric code; raise ValueError otherwise."""
        if cert_type == 'user' or cert_type == _USER_TYPE:
            self._cert_type = _USER_TYPE
        elif cert_type == 'host' or cert_type == _HOST_TYPE:
            self._cert_type = _HOST_TYPE
        else:
            raise ValueError("%s is not a valid certificate type" % cert_type)

    def signing_key_fingerprint(self):
        """Return the SHA256 fingerprint of the CA signing key blob, or b'' if it is unset."""
        # Mirrors the public_key_fingerprint implementations, which return b'' when data is missing.
        if self.signing_key is None:
            return b''
        return fingerprint(self.signing_key)

    @abc.abstractmethod
    def public_key_fingerprint(self):
        """Return the SHA256 fingerprint of the certified public key (b'' if incomplete)."""
        pass

    @abc.abstractmethod
    def parse_public_numbers(self, parser):
        """Read the key-type-specific public key fields from ``parser``."""
        pass
335
336
class OpensshRSACertificateInfo(OpensshCertificateInfo):
    """Certificate information for ``ssh-rsa`` keys (public exponent ``e`` and modulus ``n``)."""

    def __init__(self, e=None, n=None, **kwargs):
        super(OpensshRSACertificateInfo, self).__init__(**kwargs)
        self.type_string = _SSH_TYPE_STRINGS['rsa'] + _CERT_SUFFIX_V01
        self.e = e
        self.n = n

    # See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
    def public_key_fingerprint(self):
        """Fingerprint of the ``ssh-rsa`` public key blob; b'' until both numbers are known."""
        if self.e is None or self.n is None:
            return b''

        blob = _OpensshWriter()
        blob.string(_SSH_TYPE_STRINGS['rsa'])
        for number in (self.e, self.n):
            blob.mpint(number)
        return fingerprint(blob.bytes())

    def parse_public_numbers(self, parser):
        """Read ``e`` followed by ``n`` as mpints."""
        self.e = parser.mpint()
        self.n = parser.mpint()
359
360
class OpensshDSACertificateInfo(OpensshCertificateInfo):
    """Certificate information for ``ssh-dss`` keys (parameters ``p``, ``q``, ``g`` and ``y``)."""

    def __init__(self, p=None, q=None, g=None, y=None, **kwargs):
        super(OpensshDSACertificateInfo, self).__init__(**kwargs)
        self.type_string = _SSH_TYPE_STRINGS['dsa'] + _CERT_SUFFIX_V01
        self.p = p
        self.q = q
        self.g = g
        self.y = y

    # See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
    def public_key_fingerprint(self):
        """Fingerprint of the ``ssh-dss`` public key blob; b'' until all numbers are known."""
        numbers = (self.p, self.q, self.g, self.y)
        if any(number is None for number in numbers):
            return b''

        blob = _OpensshWriter()
        blob.string(_SSH_TYPE_STRINGS['dsa'])
        for number in numbers:
            blob.mpint(number)
        return fingerprint(blob.bytes())

    def parse_public_numbers(self, parser):
        """Read ``p``, ``q``, ``g`` and ``y`` as mpints, in that order."""
        self.p = parser.mpint()
        self.q = parser.mpint()
        self.g = parser.mpint()
        self.y = parser.mpint()
389
390
class OpensshECDSACertificateInfo(OpensshCertificateInfo):
    """Certificate information for ``ecdsa-sha2-nistp*`` keys (curve identifier and public point)."""

    def __init__(self, curve=None, public_key=None, **kwargs):
        super(OpensshECDSACertificateInfo, self).__init__(**kwargs)
        self._curve = None
        if curve is not None:
            # The setter validates the curve and derives type_string from it.
            self.curve = curve

        self.public_key = public_key

    @property
    def curve(self):
        """Curve identifier bytes (e.g. b'nistp256'), or None if not yet set."""
        return self._curve

    @curve.setter
    def curve(self, curve):
        if curve not in _ECDSA_CURVE_IDENTIFIERS.values():
            allowed = b','.join(list(_ECDSA_CURVE_IDENTIFIERS.values())).decode('UTF-8')
            raise ValueError("Curve must be one of %s" % allowed)
        self._curve = curve
        self.type_string = _SSH_TYPE_STRINGS[_ECDSA_CURVE_IDENTIFIERS_LOOKUP[curve]] + _CERT_SUFFIX_V01

    # See https://datatracker.ietf.org/doc/html/rfc4253#section-6.6
    def public_key_fingerprint(self):
        """Fingerprint of the ECDSA public key blob; b'' until curve and point are known."""
        if self.curve is None or self.public_key is None:
            return b''

        key_type = _ECDSA_CURVE_IDENTIFIERS_LOOKUP[self.curve]
        blob = _OpensshWriter()
        blob.string(_SSH_TYPE_STRINGS[key_type])
        blob.string(self.curve)
        blob.string(self.public_key)
        return fingerprint(blob.bytes())

    def parse_public_numbers(self, parser):
        """Read the curve identifier, then the public point, as strings."""
        self.curve = parser.string()
        self.public_key = parser.string()
429
430
class OpensshED25519CertificateInfo(OpensshCertificateInfo):
    """Certificate information for ``ssh-ed25519`` keys (raw public key ``pk``)."""

    def __init__(self, pk=None, **kwargs):
        super(OpensshED25519CertificateInfo, self).__init__(**kwargs)
        self.type_string = _SSH_TYPE_STRINGS['ed25519'] + _CERT_SUFFIX_V01
        self.pk = pk

    def public_key_fingerprint(self):
        """Fingerprint of the ``ssh-ed25519`` public key blob; b'' until ``pk`` is known."""
        if self.pk is not None:
            blob = _OpensshWriter()
            blob.string(_SSH_TYPE_STRINGS['ed25519'])
            blob.string(self.pk)
            return fingerprint(blob.bytes())
        return b''

    def parse_public_numbers(self, parser):
        """Read the public key as a single string."""
        self.pk = parser.string()
449
450
451# See https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL.certkeys?annotate=HEAD
class OpensshCertificate(object):
    """Encapsulates a formatted OpenSSH certificate including signature and signing key"""
    def __init__(self, cert_info, signature):

        self._cert_info = cert_info
        self.signature = signature

    @classmethod
    def load(cls, path):
        """Read and parse an OpenSSH certificate file.

        :param path: file containing ``<format identifier> <base64 data> [comment]``.
        :returns: a new ``OpensshCertificate``.
        :raises ValueError: if the path does not exist or cannot be read, or if the content is
            not a well-formed certificate of a supported key type.
        """
        if not os.path.exists(path):
            raise ValueError("%s is not a valid path." % path)

        try:
            with open(path, 'rb') as cert_file:
                data = cert_file.read()
        except (IOError, OSError) as e:
            raise ValueError("%s cannot be opened for reading: %s" % (path, e))

        try:
            # Split on any whitespace run, so tab separators and leading whitespace are
            # accepted. This also drops the trailing newline.
            format_identifier, b64_cert = data.split()[:2]
            cert = binascii.a2b_base64(b64_cert)
        except (binascii.Error, ValueError):
            raise ValueError("Certificate not in OpenSSH format")

        for key_type, string in _SSH_TYPE_STRINGS.items():
            if format_identifier == string + _CERT_SUFFIX_V01:
                pub_key_type = key_type
                break
        else:
            raise ValueError("Invalid certificate format identifier: %s" % to_text(format_identifier))

        parser = OpensshParser(cert)

        # The identifier inside the blob must agree with the one in front of it.
        if format_identifier != parser.string():
            raise ValueError("Certificate formats do not match")

        try:
            cert_info = cls._parse_cert_info(pub_key_type, parser)
            signature = parser.string()
        except (TypeError, ValueError) as e:
            raise ValueError("Invalid certificate data: %s" % e)

        if parser.remaining_bytes():
            raise ValueError(
                "%s bytes of additional data was not parsed while loading %s" % (parser.remaining_bytes(), path)
            )

        return cls(
            cert_info=cert_info,
            signature=signature,
        )

    @property
    def type_string(self):
        return to_text(self._cert_info.type_string)

    @property
    def nonce(self):
        return self._cert_info.nonce

    @property
    def public_key(self):
        """SHA256 fingerprint of the certified public key, as text."""
        return to_text(self._cert_info.public_key_fingerprint())

    @property
    def serial(self):
        return self._cert_info.serial

    @property
    def type(self):
        """'user', 'host', or '' as reported by the certificate info."""
        return self._cert_info.cert_type

    @property
    def key_id(self):
        return to_text(self._cert_info.key_id)

    @property
    def principals(self):
        return [to_text(p) for p in self._cert_info.principals]

    @property
    def valid_after(self):
        return self._cert_info.valid_after

    @property
    def valid_before(self):
        return self._cert_info.valid_before

    @property
    def critical_options(self):
        """Critical options as ``OpensshCertificateOption`` objects."""
        return [
            OpensshCertificateOption('critical', to_text(n), to_text(d)) for n, d in self._cert_info.critical_options
        ]

    @property
    def extensions(self):
        """Extensions as ``OpensshCertificateOption`` objects."""
        return [OpensshCertificateOption('extension', to_text(n), to_text(d)) for n, d in self._cert_info.extensions]

    @property
    def reserved(self):
        return self._cert_info.reserved

    @property
    def signing_key(self):
        """SHA256 fingerprint of the CA signing key, as text."""
        return to_text(self._cert_info.signing_key_fingerprint())

    @property
    def signature_type(self):
        signature_data = OpensshParser.signature_data(self.signature)
        return to_text(signature_data['signature_type'])

    @staticmethod
    def _parse_cert_info(pub_key_type, parser):
        """Read the signed certificate fields, in wire order, into a cert info object."""
        cert_info = get_cert_info_object(pub_key_type)
        cert_info.nonce = parser.string()
        cert_info.parse_public_numbers(parser)
        cert_info.serial = parser.uint64()
        cert_info.cert_type = parser.uint32()
        cert_info.key_id = parser.string()
        cert_info.principals = parser.string_list()
        cert_info.valid_after = parser.uint64()
        cert_info.valid_before = parser.uint64()
        cert_info.critical_options = parser.option_list()
        cert_info.extensions = parser.option_list()
        cert_info.reserved = parser.string()
        cert_info.signing_key = parser.string()

        return cert_info

    def to_dict(self):
        """Summarize the certificate as a dict of text/scalar values (validity in ISO format)."""
        time_parameters = OpensshCertificateTimeParameters(
            valid_from=self.valid_after,
            valid_to=self.valid_before
        )
        return {
            'type_string': self.type_string,
            'nonce': self.nonce,
            'serial': self.serial,
            'cert_type': self.type,
            'identifier': self.key_id,
            'principals': self.principals,
            'valid_after': time_parameters.valid_from(date_format='human_readable'),
            'valid_before': time_parameters.valid_to(date_format='human_readable'),
            'critical_options': [str(critical_option) for critical_option in self.critical_options],
            'extensions': [str(extension) for extension in self.extensions],
            'reserved': self.reserved,
            'public_key': self.public_key,
            'signing_key': self.signing_key,
        }
601
602
def apply_directives(directives):
    """Return the default extensions that remain after applying ``directives``.

    :param directives: iterable of names from ``_DIRECTIVES``. ``clear`` removes every
        default extension; each ``no-*`` directive removes the matching ``permit-*`` one.
    :returns: list of ``OpensshCertificateOption`` objects in ``_EXTENSIONS`` order.
    :raises ValueError: if any directive is not recognized.
    """
    # Materialize once: the directives are inspected more than once below.
    directives = list(directives)
    if any(d not in _DIRECTIVES for d in directives):
        raise ValueError("directives must be one of %s" % ", ".join(_DIRECTIVES))

    if 'clear' in directives:
        return []

    directive_to_extension = {
        'no-x11-forwarding': 'permit-x11-forwarding',
        'no-agent-forwarding': 'permit-agent-forwarding',
        'no-port-forwarding': 'permit-port-forwarding',
        'no-pty': 'permit-pty',
        'no-user-rc': 'permit-user-rc',
    }
    removed = set(directive_to_extension[d] for d in directives)

    # Filter the ordered defaults rather than taking a set difference, so the result
    # order is deterministic (set iteration order depends on string hash seeding).
    return [option for option in default_options() if option.name not in removed]
619
620
def default_options():
    """Return a fresh list with one empty-data extension option per name in ``_EXTENSIONS``."""
    return [
        OpensshCertificateOption(option_type='extension', name=extension_name, data='')
        for extension_name in _EXTENSIONS
    ]
623
624
def fingerprint(public_key):
    """Return ``b'SHA256:'`` followed by the unpadded base64 SHA256 digest of ``public_key``.

    The output is formatted to resemble ``ssh-keygen`` fingerprints.
    """
    digest = sha256(public_key).digest()
    encoded = b64encode(digest).rstrip(b'=')
    return b'SHA256:' + encoded
630
631
def get_cert_info_object(key_type):
    """Return an empty certificate-info object for ``key_type`` (a key of ``_SSH_TYPE_STRINGS``).

    :raises ValueError: if ``key_type`` is not supported.
    """
    factories = (
        (('rsa',), OpensshRSACertificateInfo),
        (('dsa',), OpensshDSACertificateInfo),
        (('ecdsa-nistp256', 'ecdsa-nistp384', 'ecdsa-nistp521'), OpensshECDSACertificateInfo),
        (('ed25519',), OpensshED25519CertificateInfo),
    )
    for supported_types, factory in factories:
        if key_type in supported_types:
            return factory()
    raise ValueError("%s is not a valid key type" % key_type)
645
646
def get_option_type(name):
    """Infer 'critical' or 'extension' for a well-known option name.

    :raises ValueError: for any other name; custom options need an explicit type prefix.
    """
    known = (
        ('critical', _CRITICAL_OPTIONS),
        ('extension', _EXTENSIONS),
    )
    for option_type, names in known:
        if name in names:
            return option_type
    raise ValueError("%s is not a valid option. " % name +
                     "Custom options must start with 'critical:' or 'extension:' to indicate type")
656
657
def is_relative_time_string(time_string):
    """Return True when ``time_string`` begins with '+' or '-' (a relative time)."""
    relative_prefixes = ("+", "-")
    return time_string.startswith(relative_prefixes)
660
661
def parse_option_list(option_list):
    """Split user-supplied option strings into critical options and effective extensions.

    Entries matching ``_DIRECTIVES`` (case-insensitive, surrounding whitespace ignored) are
    applied to the default extensions; everything else is parsed as an option.

    :returns: ``(critical_options, extensions)``. ``extensions`` combines the explicitly given
        extensions with the defaults that survive the directives, without duplicates, in
        first-seen order.
    :raises ValueError: propagated from option or directive parsing.
    """
    critical_options = []
    directives = []
    extensions = []

    for option in option_list:
        # Strip before matching, consistent with how from_string treats surrounding whitespace.
        directive = option.strip().lower()
        if directive in _DIRECTIVES:
            directives.append(directive)
        else:
            option_object = OpensshCertificateOption.from_string(option)
            if option_object.type == 'critical':
                critical_options.append(option_object)
            else:
                extensions.append(option_object)

    # Order-preserving de-duplication keeps the output deterministic across runs.
    unique_extensions = []
    seen = set()
    for extension in extensions + apply_directives(directives):
        if extension not in seen:
            seen.add(extension)
            unique_extensions.append(extension)

    return critical_options, unique_extensions
678