1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5
6import os
7import os.path
8
9from knack.util import CLIError
10from knack.log import get_logger
11
# Module-level logger (e.g. generate_ssh_keys warns through it when reusing existing key files).
logger = get_logger(__name__)
13
14
def is_valid_ssh_rsa_public_key(openssh_pubkey):
    """Return True if ``openssh_pubkey`` looks like a well-formed OpenSSH public key line.

    This is a "good enough" structural check rather than a full parse. The
    line must contain at least ``<key-type> <base64-blob>`` and the blob must
    decode as base64. The decoded blob must start with a 4-byte big-endian
    length, followed by that many bytes equal to the key-type field.

    Note that the key type itself is not restricted to ``ssh-rsa``. Any line
    whose blob header agrees with its type field passes.

    :param openssh_pubkey: Public key text, e.g. the contents of an ``id_rsa.pub`` file.
    :return: True if the header check passes. Otherwise False, including when
        the blob is not valid base64 or is too short to hold the length prefix.
    """
    # http://stackoverflow.com/questions/2494450/ssh-rsa-public-key-validation-using-a-regular-expression # pylint: disable=line-too-long
    import binascii
    import struct
    try:
        from base64 import decodebytes as base64_decode
    except ImportError:
        # Python 2 only provides decodestring (a deprecated alias of decodebytes on Python 3).
        from base64 import decodestring as base64_decode

    parts = openssh_pubkey.split()
    if len(parts) < 2:
        return False
    key_type = parts[0]
    key_string = parts[1]

    int_len = 4  # size of the big-endian uint32 length prefix inside the blob
    try:
        data = base64_decode(key_string.encode())  # pylint:disable=deprecated-method
        # Length of the key-type string embedded in the blob (7 for b'ssh-rsa').
        str_len = struct.unpack('>I', data[:int_len])[0]
    except (binascii.Error, ValueError, struct.error):
        # Malformed base64, or a blob shorter than the length prefix: this is a
        # validator, so report "not valid" instead of propagating the exception.
        return False
    return data[int_len:int_len + str_len] == key_type.encode()
35
36
def generate_ssh_keys(private_key_filepath, public_key_filepath):
    """Return an OpenSSH public key string, reusing or creating key files as needed.

    Resolution order:

    1. If ``public_key_filepath`` already exists, its contents are returned
       as-is and no files are written.
    2. Otherwise, if ``private_key_filepath`` exists, it is loaded with
       ``paramiko.RSAKey``. A matching public key file is then written.
    3. Otherwise, a new 2048-bit RSA private key is generated and written
       (mode 0o600), and its public key file is written.

    The public key file is written with mode 0o644. If the private key's
    directory does not exist, it is created with mode 0o700.

    :param private_key_filepath: Path of the private key file to reuse or create.
    :param public_key_filepath: Path of the public key file to reuse or create.
    :return: The public key as ``"<key name> <base64>"``, or the raw contents
        of a pre-existing public key file.
    :raises CLIError: If the existing public key file cannot be read, or if
        loading the existing private key raises PasswordRequiredException,
        SSHException or IOError.
    """
    import paramiko
    from paramiko.ssh_exception import PasswordRequiredException, SSHException

    if os.path.isfile(public_key_filepath):
        try:
            with open(public_key_filepath, 'r') as public_key_file:
                public_key = public_key_file.read()
        except IOError as e:
            raise CLIError(e)
        pub_ssh_dir = os.path.dirname(public_key_filepath)
        logger.warning("Public SSH key file '%s' already exists in the directory: '%s'. "
                       "New SSH key files will not be generated.",
                       public_key_filepath, pub_ssh_dir)
        return public_key

    ssh_dir = os.path.dirname(private_key_filepath)
    # dirname() is '' for a bare file name (key in the current directory), and
    # os.makedirs('') raises. Only create a directory that is real and missing.
    if ssh_dir and not os.path.exists(ssh_dir):
        os.makedirs(ssh_dir)
        # Restrict a freshly created key directory to its owner.
        os.chmod(ssh_dir, 0o700)

    if os.path.isfile(private_key_filepath):
        # Reuse the existing private key and derive its public half.
        try:
            key = paramiko.RSAKey(filename=private_key_filepath)
        except (PasswordRequiredException, SSHException, IOError) as e:
            raise CLIError(e)
        logger.warning("Private SSH key file '%s' was found in the directory: '%s'. "
                       "A paired public key file '%s' will be generated.",
                       private_key_filepath, ssh_dir, public_key_filepath)
    else:
        # No usable key material exists: generate a new private key.
        key = paramiko.RSAKey.generate(2048)
        key.write_private_key_file(private_key_filepath)
        # Private keys must not be readable by other users.
        os.chmod(private_key_filepath, 0o600)

    public_key = '{} {}'.format(key.get_name(), key.get_base64())
    with open(public_key_filepath, 'w') as public_key_file:
        public_key_file.write(public_key)
    os.chmod(public_key_filepath, 0o644)

    return public_key
81