1"""
2    salt.utils.vt_helper
3    ~~~~~~~~~~~~~~~~~~~~
4
5    VT Helper
6
    This module provides the SSHConnection class, which exposes an SSH
    connection object that lets users programmatically execute commands on a
    remote server using Salt VT.
10"""
11
import logging
import os
import re
import shlex

from salt.utils.vt import Terminal, TerminalException
17
# Matches a password prompt such as "salt@host's password:" or
# "[sudo] password for salt:".
SSH_PASSWORD_PROMPT_RE = re.compile(r"(?:.*)[Pp]assword(?: for .*)?:", re.M)
# Matches ssh's first-connection host-key confirmation question. Older OpenSSH
# releases ask "(yes/no)?"; OpenSSH 8.0+ asks "(yes/no/[fingerprint])?", so the
# fingerprint alternative is optional. Without it the question goes unanswered
# and the login loop waits forever.
KEY_VALID_RE = re.compile(r".*\(yes/no(?:/\[fingerprint\])?\).*")

log = logging.getLogger(__name__)
22
23
class SSHConnection:
    """
    SSH connection to a remote server, driven through a Salt VT ``Terminal``.

    The constructor starts ``ssh`` and answers its password / host-key
    prompts; ``sendline`` then sends commands and collects their output until
    the configured shell prompt is seen.
    """

    def __init__(
        self,
        username="salt",
        password="password",
        host="localhost",
        key_accept=False,
        prompt=r"(Cmd)",
        passwd_retries=3,
        linesep=os.linesep,
        ssh_args="",
    ):
        """
        Establishes a connection to the remote server.

        The format for parameters is:

        username (string): The username to use for this
            ssh connection. Defaults to salt.
        password (string): The password to use for this
            ssh connection. Defaults to password.
        host (string): The host to connect to.
            Defaults to localhost.
        key_accept (boolean): Should we accept this host's key
            and add it to the known_hosts file? Defaults to False.
        prompt (string): The shell prompt (regex) on the server.
            Prompt is compiled into a regular expression, so the default
            ``(Cmd)`` is a regex group matching the text ``Cmd``.
        passwd_retries (int): How many times to send the password before
            giving up. Defaults to 3.
        linesep (string): The line separator to use when sending commands
            and prompt answers to the server. Defaults to os.linesep.
        ssh_args (string): Extra ssh args to use with ssh. They are inserted
            into the shell command verbatim.
             Example: '-o PubkeyAuthentication=no'

        Raises TerminalException when a password is requested but
        ``password`` is empty, or when the password prompt is still shown
        after ``passwd_retries`` attempts. On any error raised while logging
        in, the terminal is closed before the exception propagates.
        """
        # Compile first so an invalid prompt regex fails before ssh is spawned.
        self.prompt_re = re.compile(prompt)
        self.linesep = linesep

        self.conn = Terminal(
            self._ssh_command(ssh_args, username, host),
            shell=True,
            log_stdout=True,
            log_stdout_level="trace",
            log_stderr=True,
            log_stderr_level="trace",
            stream_stdout=False,
            stream_stderr=False,
        )

        try:
            self._login(password, passwd_retries, key_accept, host)
        except BaseException:
            # The caller never receives this object when __init__ fails, so it
            # could not call close_connection(); don't leave ssh running.
            self.close_connection()
            raise

    @staticmethod
    def _ssh_command(ssh_args, username, host):
        """
        Build the shell command line used to start ssh.

        ``ssh_args`` is inserted verbatim so callers can pass several options.
        ``username`` and ``host`` are shell-quoted because the command runs
        with ``shell=True``.
        """
        return "ssh {} -l {} {}".format(
            ssh_args, shlex.quote(str(username)), shlex.quote(str(host))
        )

    def _login(self, password, passwd_retries, key_accept, host):
        """
        Answer ssh's start-up prompts until the shell prompt appears.

        Reads from ``self.conn`` while it reports unread data. It sends the
        password (at most ``passwd_retries`` times) and the host-key answer
        when asked. It returns when a stdout chunk matches ``self.prompt_re``
        or when the terminal has no more unread data.
        """
        sent_passwd = 0
        while self.conn.has_unread_data:
            stdout, _ = self.conn.recv()
            if not stdout:
                continue

            if SSH_PASSWORD_PROMPT_RE.search(stdout):
                if not password:
                    log.error("Failure while authentication.")
                    raise TerminalException(
                        "Permission denied, no authentication information"
                    )
                if sent_passwd >= passwd_retries:
                    # Still being asked for a password after every retry.
                    raise TerminalException("Password authentication failed")
                self.conn.sendline(password, self.linesep)
                sent_passwd += 1
            elif KEY_VALID_RE.search(stdout):
                # Connecting to this server for the first time:
                # ssh wants the host key confirmed.
                if key_accept:
                    log.info("Adding %s to known_hosts", host)
                    self.conn.sendline("yes", self.linesep)
                else:
                    # NOTE: after refusing, this loop keeps reading until the
                    # terminal has no unread data; no exception is raised here.
                    self.conn.sendline("no", self.linesep)
            elif self.prompt_re.search(stdout):
                # Auth success: the remote shell prompt is visible.
                break

    def sendline(self, cmd):
        """
        Send this command to the server and
        return a tuple of the output and the stderr.

        The format for parameters is:

        cmd (string): The command to send to the server.

        Output is collected until a stdout chunk matches the prompt regex, or
        until the terminal has no more unread data. The matching chunk (prompt
        included) is part of the returned stdout. NOTE: each chunk is matched
        on its own, so a prompt split across two reads is not detected.
        """
        self.conn.sendline(cmd, self.linesep)

        ret_stdout = []
        ret_stderr = []
        while self.conn.has_unread_data:
            stdout, stderr = self.conn.recv()

            if stdout:
                ret_stdout.append(stdout)
            if stderr:
                log.debug("Error while executing command.")
                ret_stderr.append(stderr)

            if stdout and self.prompt_re.search(stdout):
                break

        return "".join(ret_stdout), "".join(ret_stderr)

    def close_connection(self):
        """
        Close the server connection, terminating and killing the
        underlying terminal process.
        """
        self.conn.close(terminate=True, kill=True)
141