1# scp.py
2# Copyright (C) 2008 James Bardin <j.bardin@gmail.com>
3
4"""
5Utilities for sending files over ssh using the scp1 protocol.
6"""
7
8__version__ = '0.13.3'
9
10import locale
11import os
12import re
13from socket import timeout as SocketTimeout
14import types
15
16
17# this is quote from the shlex module, added in py3.3
18_find_unsafe = re.compile(br'[^\w@%+=:,./~-]').search
19
20
21def _sh_quote(s):
22    """Return a shell-escaped version of the string `s`."""
23    if not s:
24        return b""
25    if _find_unsafe(s) is None:
26        return s
27
28    # use single quotes, and put single quotes into double quotes
29    # the string $'b is then quoted as '$'"'"'b'
30    return b"'" + s.replace(b"'", b"'\"'\"'") + b"'"
31
32
33# Unicode conversion functions; assume UTF-8
34
35def asbytes(s):
36    """Turns unicode into bytes, if needed.
37
38    Assumes UTF-8.
39    """
40    if isinstance(s, bytes):
41        return s
42    else:
43        return s.encode('utf-8')
44
45
46def asunicode(s):
47    """Turns bytes into unicode, if needed.
48
49    Uses UTF-8.
50    """
51    if isinstance(s, bytes):
52        return s.decode('utf-8', 'replace')
53    else:
54        return s
55
56
57# os.path.sep is unicode on Python 3, no matter the platform
58bytes_sep = asbytes(os.path.sep)
59
60
61# Unicode conversion function for Windows
62# Used to convert local paths if the local machine is Windows
63
64def asunicode_win(s):
65    """Turns bytes into unicode, if needed.
66    """
67    if isinstance(s, bytes):
68        return s.decode(locale.getpreferredencoding())
69    else:
70        return s
71
72
class SCPClient(object):
    """
    An scp1 implementation, compatible with openssh scp.
    Raises SCPException for all transport related errors. Local filesystem
    and OS errors pass through.

    Main public methods are .put, .putfo and .get
    The get method is controlled by the remote scp instance, and behaves
    accordingly. This means that symlinks are resolved, and the transfer is
    halted after too many levels of symlinks are detected.
    The put method uses os.walk for recursion, and sends files accordingly.
    Since scp doesn't support symlinks, we send file symlinks as the file
    (matching scp behaviour), but we make no attempt at symlinked directories.

    Each public transfer method closes its channel when it returns, whether
    it succeeded or raised, so the client can be reused after a failure.
    """
    def __init__(self, transport, buff_size=16384, socket_timeout=10.0,
                 progress=None, progress4=None, sanitize=_sh_quote):
        """
        Create an scp1 client.

        @param transport: an existing paramiko L{Transport}
        @type transport: L{Transport}
        @param buff_size: size of the scp send buffer.
        @type buff_size: int
        @param socket_timeout: channel socket timeout in seconds
        @type socket_timeout: float
        @param progress: callback - called with (filename, size, sent) during
            transfers
        @type progress: function(string, int, int)
        @param progress4: callback - called with (filename, size, sent,
            peername) during transfers. peername is a tuple containing
            (IP, PORT)
        @type progress4: function(string, int, int, tuple)
        @param sanitize: function - called with filename, should return
            safe or escaped string.  Uses _sh_quote by default.
        @raise TypeError: if both progress and progress4 are given.
        """
        self.transport = transport
        self.buff_size = buff_size
        self.socket_timeout = socket_timeout
        self.channel = None
        self.preserve_times = False
        if progress is not None and progress4 is not None:
            raise TypeError("You may only set one of progress, progress4")
        elif progress4 is not None:
            self._progress = progress4
        elif progress is not None:
            # adapt the 3-argument callback to the 4-argument call sites
            self._progress = lambda *a: progress(*a[:3])
        else:
            self._progress = None
        self._recv_dir = b''
        self._depth = 0
        self._rename = False
        self._utime = None
        self._pushed = 0
        self.sanitize = sanitize
        self._dirtimes = {}
        self.peername = self.transport.getpeername()

    def __enter__(self):
        self.channel = self._open()
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    def put(self, files, remote_path=b'.',
            recursive=False, preserve_times=False):
        """
        Transfer files and directories to remote host.

        @param files: A single path, or a list of paths to be transferred.
            recursive must be True to transfer directories.
        @type files: string OR list of strings
        @param remote_path: path in which to receive the files on the remote
            host. defaults to '.'
        @type remote_path: str
        @param recursive: transfer files and directories recursively
        @type recursive: bool
        @param preserve_times: preserve mtime and atime of transferred files
            and directories.
        @type preserve_times: bool
        """
        self.preserve_times = preserve_times
        self.channel = self._open()
        try:
            self._pushed = 0
            self.channel.settimeout(self.socket_timeout)
            scp_command = (b'scp -t ', b'scp -r -t ')[recursive]
            self.channel.exec_command(scp_command +
                                      self.sanitize(asbytes(remote_path)))
            self._recv_confirm()

            if not isinstance(files, (list, tuple)):
                files = [files]

            if recursive:
                self._send_recursive(files)
            else:
                self._send_files(files)
        finally:
            # a channel whose command already ran cannot be reused
            self.close()

    def putfo(self, fl, remote_path, mode='0644', size=None):
        """
        Transfer file-like object to remote host.

        Data is read from the current position of `fl`; exactly `size` bytes
        are sent.

        @param fl: opened file or file-like object to copy
        @type fl: file-like object
        @param remote_path: full destination path
        @type remote_path: str
        @param mode: permissions (posix-style) for the uploaded file
        @type mode: str
        @param size: size of the file in bytes. If ``None``, the size will be
            computed using `seek()` and `tell()` as the number of bytes from
            the current position to the end of `fl`.
        """
        if size is None:
            pos = fl.tell()
            fl.seek(0, os.SEEK_END)  # Seek to end
            size = fl.tell() - pos
            fl.seek(pos, os.SEEK_SET)  # Seek back

        self.channel = self._open()
        try:
            self.channel.settimeout(self.socket_timeout)
            self.channel.exec_command(b'scp -t ' +
                                      self.sanitize(asbytes(remote_path)))
            self._recv_confirm()
            self._send_file(fl, remote_path, mode, size=size)
        finally:
            self.close()

    def get(self, remote_path, local_path='',
            recursive=False, preserve_times=False):
        """
        Transfer files and directories from remote host to localhost.

        @param remote_path: path to retrieve from remote host. since this is
            evaluated by scp on the remote host, shell wildcards and
            environment variables may be used.
        @type remote_path: str OR list of strings
        @param local_path: path in which to receive files locally
        @type local_path: str
        @param recursive: transfer files and directories recursively
        @type recursive: bool
        @param preserve_times: preserve mtime and atime of transferred files
            and directories.
        @type preserve_times: bool
        """
        if not isinstance(remote_path, (list, tuple)):
            remote_path = [remote_path]
        remote_path = [self.sanitize(asbytes(r)) for r in remote_path]
        self._recv_dir = local_path or os.getcwd()
        self._depth = 0
        # drop timestamps left behind by an earlier, failed get()
        self._dirtimes = {}
        self._utime = None
        # a single source into a non-directory target is saved under that name
        self._rename = (len(remote_path) == 1 and
                        not os.path.isdir(os.path.abspath(local_path)))
        if len(remote_path) > 1:
            if not os.path.exists(self._recv_dir):
                raise SCPException("Local path '%s' does not exist" %
                                   asunicode(self._recv_dir))
            elif not os.path.isdir(self._recv_dir):
                raise SCPException("Local path '%s' is not a directory" %
                                   asunicode(self._recv_dir))
        rcsv = (b'', b' -r')[recursive]
        prsv = (b'', b' -p')[preserve_times]
        self.channel = self._open()
        try:
            self._pushed = 0
            self.channel.settimeout(self.socket_timeout)
            self.channel.exec_command(b"scp" +
                                      rcsv +
                                      prsv +
                                      b" -f " +
                                      b' '.join(remote_path))
            self._recv_all()
        finally:
            self.close()

    def _open(self):
        """open a scp channel, reusing the current one if still open"""
        if self.channel is None or self.channel.closed:
            self.channel = self.transport.open_session()

        return self.channel

    def close(self):
        """close scp channel"""
        if self.channel is not None:
            self.channel.close()
            self.channel = None

    def _read_stats(self, name):
        """Return (mode, size, mtime, atime) for `name`.

        mode is the last four octal digits of st_mode as a string; the
        times are truncated to whole seconds.
        """
        if os.name == 'nt':
            name = asunicode(name)
        stats = os.stat(name)
        mode = oct(stats.st_mode)[-4:]
        size = stats.st_size
        atime = int(stats.st_atime)
        mtime = int(stats.st_mtime)
        return (mode, size, mtime, atime)

    def _send_files(self, files):
        """Send each local file in `files`, preceded by a T record if
        preserve_times is set."""
        for name in files:
            (mode, size, mtime, atime) = self._read_stats(name)
            if self.preserve_times:
                self._send_time(mtime, atime)
            with open(name, 'rb') as fl:
                self._send_file(fl, name, mode, size)

    def _send_file(self, fl, name, mode, size):
        """Send a C record for `name`, then `size` bytes of `fl`, then the
        terminating NUL, waiting for the remote's confirmation around it."""
        basename = asbytes(os.path.basename(name))
        # The protocol can't handle \n in the filename.
        # Quote them as the control sequence \^J for now,
        # which is how openssh handles it.
        self.channel.sendall(("C%s %d " % (mode, size)).encode('ascii') +
                             basename.replace(b'\n', b'\\^J') + b"\n")
        self._recv_confirm()
        if self._progress:
            if size == 0:
                # avoid divide-by-zero
                self._progress(basename, 1, 1, self.peername)
            else:
                self._progress(basename, size, 0, self.peername)
        self._send_data(fl, size, basename)
        self.channel.sendall(b'\x00')
        self._recv_confirm()

    def _send_data(self, fl, size, label):
        """Stream exactly `size` bytes from the current position of `fl`.

        Counts bytes actually sent, since `fl.tell()` is absolute and `fl`
        may not start at offset 0 (putfo). Never reads past `size`, so a
        file that grew since it was measured cannot corrupt the stream.
        Raises SCPException if `fl` ends early.
        """
        chan = self.channel
        buff_size = self.buff_size
        sent = 0
        while sent < size:
            chunk = fl.read(min(buff_size, size - sent))
            if not chunk:
                raise SCPException('Local file ended after %d of %d bytes'
                                   % (sent, size))
            chan.sendall(chunk)
            sent += len(chunk)
            if self._progress:
                self._progress(label, size, sent, self.peername)

    def _chdir(self, from_dir, to_dir):
        # Pop until we're one level up from our next push.
        # Push *once* into to_dir.
        # This is dependent on the depth-first traversal from os.walk

        # add path.sep to each when checking the prefix, so we can use
        # path.dirname after
        common = os.path.commonprefix([from_dir + bytes_sep,
                                       to_dir + bytes_sep])
        # now take the dirname, since commonprefix is character based,
        # and we either have a separator, or a partial name
        common = os.path.dirname(common)
        cur_dir = from_dir.rstrip(bytes_sep)
        while cur_dir != common:
            cur_dir = os.path.split(cur_dir)[0]
            self._send_popd()
        # now we're in our common base directory, so on
        self._send_pushd(to_dir)

    def _send_recursive(self, files):
        """Send files and whole directory trees (walked with os.walk)."""
        for base in files:
            if not os.path.isdir(base):
                # filename mixed into the bunch
                self._send_files([base])
                continue
            last_dir = asbytes(base)
            for root, dirs, fls in os.walk(base):
                self._chdir(last_dir, asbytes(root))
                self._send_files([os.path.join(root, f) for f in fls])
                last_dir = asbytes(root)
            # back out of the directory
            while self._pushed > 0:
                self._send_popd()

    def _send_pushd(self, directory):
        """Send a D record entering `directory` (a bytes path)."""
        (mode, size, mtime, atime) = self._read_stats(directory)
        # strip trailing separators so b'dir/' is announced as b'dir', not b''
        basename = asbytes(os.path.basename(directory.rstrip(bytes_sep) or
                                            directory))
        if self.preserve_times:
            self._send_time(mtime, atime)
        self.channel.sendall(('D%s 0 ' % mode).encode('ascii') +
                             basename.replace(b'\n', b'\\^J') + b'\n')
        self._recv_confirm()
        self._pushed += 1

    def _send_popd(self):
        """Send an E record leaving the current remote directory."""
        self.channel.sendall(b'E\n')
        self._recv_confirm()
        self._pushed -= 1

    def _send_time(self, mtime, atime):
        """Send a T record carrying the next entry's mtime and atime."""
        self.channel.sendall(('T%d 0 %d 0\n' % (mtime, atime)).encode('ascii'))
        self._recv_confirm()

    def _recv_confirm(self):
        """Read one scp response; return on NUL, raise SCPException otherwise."""
        # read scp response
        msg = b''
        try:
            msg = self.channel.recv(512)
        except SocketTimeout:
            raise SCPException('Timeout waiting for scp response')
        # slice off the first byte, so this compare will work in py2 and py3
        if msg and msg[0:1] == b'\x00':
            return
        elif msg and msg[0:1] == b'\x01':
            raise SCPException(asunicode(msg[1:]))
        elif self.channel.recv_stderr_ready():
            msg = self.channel.recv_stderr(512)
            raise SCPException(asunicode(msg))
        elif not msg:
            raise SCPException('No response from server')
        else:
            raise SCPException('Invalid response from server', msg)

    def _recv_all(self):
        """Dispatch C/T/D/E records from the remote until the channel closes."""
        # loop over scp commands, and receive as necessary
        command = {b'C': self._recv_file,
                   b'T': self._set_time,
                   b'D': self._recv_pushd,
                   b'E': self._recv_popd}
        while not self.channel.closed:
            # wait for command as long as we're open
            self.channel.sendall(b'\x00')
            msg = self.channel.recv(1024)
            if not msg:  # chan closed while recving
                break
            # explicit check: an assert would vanish under python -O
            if msg[-1:] != b'\n':
                raise SCPException('Unterminated message from server: %r'
                                   % (msg,))
            msg = msg[:-1]
            code = msg[0:1]
            if code not in command:
                raise SCPException(asunicode(msg[1:]))
            command[code](msg[1:])
        # directory times can't be set until we're done writing files
        self._set_dirtimes()

    @staticmethod
    def _is_safe_name(name):
        """Return True if `name` (bytes or str) is one plain path component.

        Rejects empty names, '..', names containing a path separator, and
        absolute or drive-qualified names, so the remote side cannot make
        the client write outside the directory it is receiving into.
        """
        if isinstance(name, bytes):
            name = name.decode('utf-8', 'replace')
        if name in ('', '..'):
            return False
        for sep in ('/', os.sep, os.altsep):
            if sep and sep in name:
                return False
        return not (os.path.isabs(name) or os.path.splitdrive(name)[0])

    def _set_time(self, cmd):
        """Parse a T record ("mtime 0 atime 0") and keep it for the next entry."""
        try:
            times = cmd.split(b' ')
            mtime = int(times[0])
            atime = int(times[2]) or mtime
        except (ValueError, IndexError):
            self.channel.send(b'\x01')
            raise SCPException('Bad time format')
        # save for later
        self._utime = (atime, mtime)

    def _recv_file(self, cmd):
        """Handle a C record: receive one file into the current local dir."""
        chan = self.channel
        parts = cmd.strip().split(b' ', 2)

        try:
            mode = int(parts[0], 8)
            size = int(parts[1])
            if self._rename:
                path = self._recv_dir
                self._rename = False
            else:
                if os.name == 'nt':
                    name = parts[2].decode('utf-8')
                    base = asunicode_win(self._recv_dir)
                else:
                    name = parts[2]
                    base = asbytes(self._recv_dir)
                if not self._is_safe_name(name):
                    raise ValueError('unsafe file name from server')
                path = os.path.join(base, name)
        except (ValueError, IndexError, UnicodeDecodeError):
            chan.send(b'\x01')
            chan.close()
            raise SCPException('Bad file format')

        try:
            file_hdl = open(path, 'wb')
        except IOError as e:
            chan.send(b'\x01' + str(e).encode('utf-8'))
            chan.close()
            raise

        try:
            if self._progress:
                if size == 0:
                    # avoid divide-by-zero
                    self._progress(path, 1, 1, self.peername)
                else:
                    self._progress(path, size, 0, self.peername)
            buff_size = self.buff_size
            pos = 0
            chan.send(b'\x00')
            try:
                while pos < size:
                    # we have to make sure we don't read the final byte
                    if size - pos <= buff_size:
                        buff_size = size - pos
                    data = chan.recv(buff_size)
                    if not data:
                        raise SCPException("Underlying channel was closed")
                    file_hdl.write(data)
                    pos += len(data)
                    if self._progress:
                        self._progress(path, size, pos, self.peername)
                msg = chan.recv(512)
                if msg and msg[0:1] != b'\x00':
                    raise SCPException(asunicode(msg[1:]))
            except SocketTimeout:
                chan.close()
                raise SCPException('Error receiving, socket.timeout')
            file_hdl.truncate()
        finally:
            # always release the handle, even if the transfer failed
            file_hdl.close()

        # set times only after close, so no later flush can bump mtime
        os.utime(path, self._utime)
        self._utime = None
        os.chmod(path, mode)
        # '\x00' confirmation sent in _recv_all

    def _recv_pushd(self, cmd):
        """Handle a D record: create/enter a local directory."""
        parts = cmd.split(b' ', 2)
        try:
            mode = int(parts[0], 8)
            if self._rename:
                path = self._recv_dir
                self._rename = False
            else:
                if os.name == 'nt':
                    name = parts[2].decode('utf-8')
                    base = asunicode_win(self._recv_dir)
                else:
                    name = parts[2]
                    base = asbytes(self._recv_dir)
                if not self._is_safe_name(name):
                    raise ValueError('unsafe directory name from server')
                path = os.path.join(base, name)
                self._depth += 1
        except (ValueError, IndexError, UnicodeDecodeError):
            self.channel.send(b'\x01')
            raise SCPException('Bad directory format')
        try:
            if not os.path.exists(path):
                os.mkdir(path, mode)
            elif os.path.isdir(path):
                os.chmod(path, mode)
            else:
                raise SCPException('%s: Not a directory' % asunicode(path))
            self._dirtimes[path] = (self._utime)
            self._utime = None
            self._recv_dir = path
        except (OSError, SCPException) as e:
            self.channel.send(b'\x01' + asbytes(str(e)))
            raise

    def _recv_popd(self, *cmd):
        """Handle an E record: leave the current local directory."""
        if self._depth > 0:
            self._depth -= 1
            self._recv_dir = os.path.split(self._recv_dir)[0]

    def _set_dirtimes(self):
        """Apply the saved directory times, then clear them."""
        try:
            for d in self._dirtimes:
                os.utime(d, self._dirtimes[d])
        finally:
            self._dirtimes = {}
517
518
519class SCPException(Exception):
520    """SCP exception class"""
521    pass
522
523
524def put(transport, files, remote_path=b'.',
525        recursive=False, preserve_times=False):
526    """
527    Transfer files and directories to remote host.
528
529    This is a convenience function that creates a SCPClient from the given
530    transport and closes it at the end, useful for one-off transfers.
531
532    @param files: A single path, or a list of paths to be transferred.
533        recursive must be True to transfer directories.
534    @type files: string OR list of strings
535    @param remote_path: path in which to receive the files on the remote host.
536        defaults to '.'
537    @type remote_path: str
538    @param recursive: transfer files and directories recursively
539    @type recursive: bool
540    @param preserve_times: preserve mtime and atime of transferred files and
541        directories.
542    @type preserve_times: bool
543    """
544    with SCPClient(transport) as client:
545        client.put(files, remote_path, recursive, preserve_times)
546
547
548def get(transport, remote_path, local_path='',
549        recursive=False, preserve_times=False):
550    """
551    Transfer files and directories from remote host to localhost.
552
553    This is a convenience function that creates a SCPClient from the given
554    transport and closes it at the end, useful for one-off transfers.
555
556    @param transport: an paramiko L{Transport}
557    @type transport: L{Transport}
558    @param remote_path: path to retrieve from remote host. since this is
559        evaluated by scp on the remote host, shell wildcards and environment
560        variables may be used.
561    @type remote_path: str
562    @param local_path: path in which to receive files locally
563    @type local_path: str
564    @param recursive: transfer files and directories recursively
565    @type recursive: bool
566    @param preserve_times: preserve mtime and atime of transferred files
567        and directories.
568    @type preserve_times: bool
569    """
570    with SCPClient(transport) as client:
571        client.get(remote_path, local_path, recursive, preserve_times)
572