1# coding: utf-8 2# Licensed under the Apache License, Version 2.0 (the "License"); 3# you may not use this file except in compliance with the License. 4# You may obtain a copy of the License at 5# 6# http://www.apache.org/licenses/LICENSE-2.0 7# 8# Unless required by applicable law or agreed to in writing, software 9# distributed under the License is distributed on an "AS IS" BASIS, 10# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11# See the License for the specific language governing permissions and 12# limitations under the License. 13 14from __future__ import unicode_literals 15 16import base64 17 18from testinfra.backend import base 19 20 21class SshBackend(base.BaseBackend): 22 """Run command through ssh command""" 23 NAME = "ssh" 24 25 def __init__(self, hostspec, ssh_config=None, ssh_identity_file=None, 26 timeout=10, controlpersist=60, ssh_extra_args=None, 27 *args, **kwargs): 28 self.host = self.parse_hostspec(hostspec) 29 self.ssh_config = ssh_config 30 self.ssh_identity_file = ssh_identity_file 31 self.timeout = int(timeout) 32 self.controlpersist = int(controlpersist) 33 self.ssh_extra_args = ssh_extra_args 34 super(SshBackend, self).__init__(self.host.name, *args, **kwargs) 35 36 def run(self, command, *args, **kwargs): 37 return self.run_ssh(self.get_command(command, *args)) 38 39 def _build_ssh_command(self, command): 40 cmd = ["ssh"] 41 cmd_args = [] 42 if self.ssh_extra_args: 43 cmd.append(self.ssh_extra_args.replace('%', '%%')) 44 if self.ssh_config: 45 cmd.append("-F %s") 46 cmd_args.append(self.ssh_config) 47 if self.host.user: 48 cmd.append("-o User=%s") 49 cmd_args.append(self.host.user) 50 if self.host.port: 51 cmd.append("-o Port=%s") 52 cmd_args.append(self.host.port) 53 if self.ssh_identity_file: 54 cmd.append("-i %s") 55 cmd_args.append(self.ssh_identity_file) 56 if 'connecttimeout' not in (self.ssh_extra_args or '').lower(): 57 cmd.append("-o ConnectTimeout={}".format(self.timeout)) 58 if self.controlpersist and ( 59 'controlmaster' not in (self.ssh_extra_args or '').lower() 60 ): 61 cmd.append("-o ControlMaster=auto -o ControlPersist=%ds" % ( 62 self.controlpersist)) 63 cmd.append("%s %s") 64 cmd_args.extend([self.host.name, command]) 65 return cmd, cmd_args 66 67 def run_ssh(self, command): 68 cmd, cmd_args = self._build_ssh_command(command) 69 out = self.run_local( 70 " ".join(cmd), *cmd_args) 71 out.command = self.encode(command) 72 if out.rc == 255: 73 # ssh exits with the exit status of the remote command or with 255 74 # if an error occurred. 75 raise RuntimeError(out) 76 return out 77 78 79class SafeSshBackend(SshBackend): 80 """Run command using ssh command but try to get a more sane output 81 82 When using ssh (or a potentially bugged wrapper) additional output can be 83 added in stdout/stderr and exit status may not be propagate correctly 84 85 To avoid that kind of bugs, we wrap the command to have an output like 86 this: 87 88 TESTINFRA_START;EXIT_STATUS;STDOUT;STDERR;TESTINFRA_END 89 90 where STDOUT/STDERR are base64 encoded, then we parse that magic string to 91 get sanes variables 92 """ 93 NAME = "safe-ssh" 94 95 def run(self, command, *args, **kwargs): 96 orig_command = self.get_command(command, *args) 97 orig_command = self.get_command('sh -c %s', orig_command) 98 99 out = self.run_ssh(( 100 '''of=$(mktemp)&&ef=$(mktemp)&&%s >$of 2>$ef; r=$?;''' 101 '''echo "TESTINFRA_START;$r;$(base64 < $of);$(base64 < $ef);''' 102 '''TESTINFRA_END";rm -f $of $ef''') % (orig_command,)) 103 104 start = out.stdout.find("TESTINFRA_START;") + len("TESTINFRA_START;") 105 end = out.stdout.find("TESTINFRA_END") - 1 106 rc, stdout, stderr = out.stdout[start:end].split(";") 107 rc = int(rc) 108 stdout = base64.b64decode(stdout) 109 stderr = base64.b64decode(stderr) 110 return self.result(rc, self.encode(orig_command), stdout, stderr) 111