1# coding: utf-8
2# Licensed under the Apache License, Version 2.0 (the "License");
3# you may not use this file except in compliance with the License.
4# You may obtain a copy of the License at
5#
6#     http://www.apache.org/licenses/LICENSE-2.0
7#
8# Unless required by applicable law or agreed to in writing, software
9# distributed under the License is distributed on an "AS IS" BASIS,
10# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11# See the License for the specific language governing permissions and
12# limitations under the License.
13
14from __future__ import unicode_literals
15
16import base64
17
18from testinfra.backend import base
19
20
class SshBackend(base.BaseBackend):
    """Backend that executes commands on a remote host via the ssh binary.

    The ssh invocation is assembled as a list of template fragments plus a
    matching list of values, and then handed to ``run_local``.
    """
    NAME = "ssh"

    def __init__(self, hostspec, ssh_config=None, ssh_identity_file=None,
                 timeout=10, controlpersist=60, ssh_extra_args=None,
                 *args, **kwargs):
        """Store the connection settings and initialise the base backend.

        :param hostspec: host specification, parsed with ``parse_hostspec``
        :param ssh_config: optional ssh configuration file (``-F``)
        :param ssh_identity_file: optional identity file (``-i``)
        :param timeout: connection timeout, coerced with ``int``
        :param controlpersist: ControlPersist duration in seconds, coerced
            with ``int``; a falsy value disables the ControlMaster options
        :param ssh_extra_args: optional string of extra ssh arguments,
            inserted right after the ``ssh`` word
        """
        # The parsed host must exist before the parent initialiser is called,
        # since its name is what gets forwarded.
        self.host = self.parse_hostspec(hostspec)
        self.ssh_config = ssh_config
        self.ssh_identity_file = ssh_identity_file
        self.timeout = int(timeout)
        self.controlpersist = int(controlpersist)
        self.ssh_extra_args = ssh_extra_args
        super(SshBackend, self).__init__(self.host.name, *args, **kwargs)

    def run(self, command, *args, **kwargs):
        """Run ``command`` (with ``args`` substituted) on the remote host.

        Keyword arguments are accepted but not used.
        """
        remote_command = self.get_command(command, *args)
        return self.run_ssh(remote_command)

    def _build_ssh_command(self, command):
        """Return the ssh invocation as ``(template_parts, values)``.

        ``template_parts`` is a list of fragments in which every ``%s``
        placeholder corresponds, in order, to one entry of ``values``.
        """
        extra_args = self.ssh_extra_args
        lowered_extra = (extra_args or '').lower()

        parts = ["ssh"]
        values = []

        if extra_args:
            # Double every percent sign so user supplied text cannot be
            # mistaken for a placeholder in the template.
            parts.append(extra_args.replace('%', '%%'))

        # Options that carry a value: emitted only when the value is truthy.
        valued_options = (
            ("-F %s", self.ssh_config),
            ("-o User=%s", self.host.user),
            ("-o Port=%s", self.host.port),
            ("-i %s", self.ssh_identity_file),
        )
        for template, value in valued_options:
            if value:
                parts.append(template)
                values.append(value)

        # Do not override a timeout or a ControlMaster setting that was
        # already given through the extra arguments.
        if 'connecttimeout' not in lowered_extra:
            parts.append("-o ConnectTimeout={}".format(self.timeout))
        if self.controlpersist and 'controlmaster' not in lowered_extra:
            parts.append("-o ControlMaster=auto -o ControlPersist=%ds" % (
                self.controlpersist))

        # Finally the target host followed by the remote command.
        parts.append("%s %s")
        values.append(self.host.name)
        values.append(command)
        return parts, values

    def run_ssh(self, command):
        """Execute ``command`` through ssh and return the command result.

        :raises RuntimeError: when the exit status is 255
        """
        parts, values = self._build_ssh_command(command)
        result = self.run_local(" ".join(parts), *values)
        # Report the remote command rather than the full ssh invocation.
        result.command = self.encode(command)
        if result.rc != 255:
            return result
        # ssh exits with the exit status of the remote command or with 255
        # if an error occurred.
        raise RuntimeError(result)
77
78
class SafeSshBackend(SshBackend):
    """Run commands using the ssh command but try to get a saner output.

    When using ssh (or a potentially buggy wrapper), additional output can be
    added to stdout/stderr and the exit status may not be propagated
    correctly.

    To avoid that kind of bug, the command is wrapped so that its output
    looks like this:

    TESTINFRA_START;EXIT_STATUS;STDOUT;STDERR;TESTINFRA_END

    where STDOUT/STDERR are base64 encoded. That magic string is then parsed
    to recover the real exit status, stdout and stderr.
    """
    NAME = "safe-ssh"

    def run(self, command, *args, **kwargs):
        """Run ``command`` remotely and decode the wrapped result.

        Keyword arguments are accepted but not used.

        :param command: command template, combined with ``args`` through
            ``get_command``
        :returns: whatever ``self.result`` builds from the decoded exit
            status, stdout and stderr
        :raises ValueError: when the TESTINFRA_START/TESTINFRA_END markers
            are missing from the ssh output, or the text between them is
            malformed
        :raises RuntimeError: propagated from ``run_ssh`` when ssh itself
            fails
        """
        orig_command = self.get_command(command, *args)
        orig_command = self.get_command('sh -c %s', orig_command)

        # Remote side: capture stdout/stderr in temporary files, then print
        # the exit status and both base64 encoded streams between markers.
        out = self.run_ssh((
            '''of=$(mktemp)&&ef=$(mktemp)&&%s >$of 2>$ef; r=$?;'''
            '''echo "TESTINFRA_START;$r;$(base64 < $of);$(base64 < $ef);'''
            '''TESTINFRA_END";rm -f $of $ef''') % (orig_command,))

        start_marker = "TESTINFRA_START;"
        end_marker = ";TESTINFRA_END"

        # str.find returns -1 when the marker is absent; using that value
        # unchecked would slice an arbitrary part of the output.
        marker_pos = out.stdout.find(start_marker)
        if marker_pos == -1:
            raise ValueError(
                "TESTINFRA_START marker not found in ssh output: %r" % (
                    out.stdout,))
        start = marker_pos + len(start_marker)

        # Only look for the end marker after the start marker, so stray
        # text printed earlier cannot be mistaken for it.
        end = out.stdout.find(end_marker, start)
        if end == -1:
            raise ValueError(
                "TESTINFRA_END marker not found in ssh output: %r" % (
                    out.stdout,))

        fields = out.stdout[start:end].split(";")
        if len(fields) != 3:
            raise ValueError(
                "Malformed testinfra payload in ssh output: %r" % (
                    out.stdout,))
        rc, stdout, stderr = fields
        rc = int(rc)
        stdout = base64.b64decode(stdout)
        stderr = base64.b64decode(stderr)
        return self.result(rc, self.encode(orig_command), stdout, stderr)
111