1# Copyright (C) 2011  Jeff Forcier <jeff@bitprophet.org>
2#
3# This file is part of ssh.
4#
5# 'ssh' is free software; you can redistribute it and/or modify it under the
6# terms of the GNU Lesser General Public License as published by the Free
7# Software Foundation; either version 2.1 of the License, or (at your option)
8# any later version.
9#
# 'ssh' is distributed in the hope that it will be useful, but WITHOUT ANY
11# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
12# A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
13# details.
14#
15# You should have received a copy of the GNU Lesser General Public License
16# along with 'ssh'; if not, write to the Free Software Foundation, Inc.,
17# 51 Franklin Street, Suite 500, Boston, MA  02110-1335  USA.
18
19import select
20import socket
21import struct
22
23from ssh.common import *
24from ssh import util
25from ssh.channel import Channel
26from ssh.message import Message
27
28
# SFTP packet type codes: 1..20, then 101..105, then the two extension
# packet types at 200 and 201.
(CMD_INIT, CMD_VERSION, CMD_OPEN, CMD_CLOSE, CMD_READ, CMD_WRITE, CMD_LSTAT,
 CMD_FSTAT, CMD_SETSTAT, CMD_FSETSTAT, CMD_OPENDIR, CMD_READDIR, CMD_REMOVE,
 CMD_MKDIR, CMD_RMDIR, CMD_REALPATH, CMD_STAT, CMD_RENAME, CMD_READLINK,
 CMD_SYMLINK) = range(1, 21)
(CMD_STATUS, CMD_HANDLE, CMD_DATA, CMD_NAME, CMD_ATTRS) = range(101, 106)
(CMD_EXTENDED, CMD_EXTENDED_REPLY) = range(200, 202)

# SFTP status codes 0..8.  SFTP_DESC below is indexed by these values, so
# the two must be kept in the same order.
SFTP_OK = 0
(SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, SFTP_FAILURE,
 SFTP_BAD_MESSAGE, SFTP_NO_CONNECTION, SFTP_CONNECTION_LOST,
 SFTP_OP_UNSUPPORTED) = range(1, 9)

SFTP_DESC = [
    'Success',
    'End of file',
    'No such file',
    'Permission denied',
    'Failure',
    'Bad message',
    'No connection',
    'Connection lost',
    'Operation unsupported',
]

# Open-mode flags; each occupies its own bit so they can be OR'ed together.
SFTP_FLAG_READ = 1 << 0
SFTP_FLAG_WRITE = 1 << 1
SFTP_FLAG_APPEND = 1 << 2
SFTP_FLAG_CREATE = 1 << 3
SFTP_FLAG_TRUNC = 1 << 4
SFTP_FLAG_EXCL = 1 << 5

# protocol version this implementation speaks in the INIT/VERSION handshake
_VERSION = 3


# Packet type -> short name, intended for debug output.
CMD_NAMES = dict([
    (CMD_INIT, 'init'),
    (CMD_VERSION, 'version'),
    (CMD_OPEN, 'open'),
    (CMD_CLOSE, 'close'),
    (CMD_READ, 'read'),
    (CMD_WRITE, 'write'),
    (CMD_LSTAT, 'lstat'),
    (CMD_FSTAT, 'fstat'),
    (CMD_SETSTAT, 'setstat'),
    (CMD_FSETSTAT, 'fsetstat'),
    (CMD_OPENDIR, 'opendir'),
    (CMD_READDIR, 'readdir'),
    (CMD_REMOVE, 'remove'),
    (CMD_MKDIR, 'mkdir'),
    (CMD_RMDIR, 'rmdir'),
    (CMD_REALPATH, 'realpath'),
    (CMD_STAT, 'stat'),
    (CMD_RENAME, 'rename'),
    (CMD_READLINK, 'readlink'),
    (CMD_SYMLINK, 'symlink'),
    (CMD_STATUS, 'status'),
    (CMD_HANDLE, 'handle'),
    (CMD_DATA, 'data'),
    (CMD_NAME, 'name'),
    (CMD_ATTRS, 'attrs'),
    (CMD_EXTENDED, 'extended'),
    (CMD_EXTENDED_REPLY, 'extended_reply'),
])
90
91
class SFTPError (Exception):
    """Raised when the SFTP peer violates the protocol (wrong or malformed packet)."""
94
95
class BaseSFTP (object):
    """
    Packet framing and INIT/VERSION handshake shared by SFTP endpoints.

    On the wire each packet is a 4-byte big-endian length, then a one-byte
    packet type, then the payload; the length counts the type byte plus the
    payload.  ``self.sock`` starts out as None and must be set (presumably by
    a subclass) before any packet is exchanged.  It may be an ssh channel or,
    as ``_read_all`` allows, a plain ``socket.socket``.
    """

    def __init__(self):
        self.logger = util.get_logger('ssh.sftp')
        # transport carrying the sftp byte stream; None until someone sets it
        self.sock = None
        # when True, every packet sent or received is hex-dumped at DEBUG
        self.ultra_debug = False


    ###  internals...


    @staticmethod
    def _unpack_version(data):
        """
        Return the 32-bit big-endian version number at the start of ``data``.

        :raises SFTPError: if ``data`` is too short to hold a version number,
            so a truncated handshake packet surfaces as a protocol error
            instead of a bare ``struct.error``.
        """
        if len(data) < 4:
            raise SFTPError('Incompatible sftp protocol: truncated version packet')
        return struct.unpack('>I', data[:4])[0]

    def _send_version(self):
        """
        Client side of the handshake: send INIT carrying our version, then
        read the peer's reply, which must be a VERSION packet.

        :return: the version number reported by the peer.
        :raises SFTPError: if the reply is not VERSION or is truncated.
        """
        self._send_packet(CMD_INIT, struct.pack('>I', _VERSION))
        t, data = self._read_packet()
        if t != CMD_VERSION:
            raise SFTPError('Incompatible sftp protocol')
        # The peer's version is returned as-is; it is deliberately not
        # rejected when it differs from _VERSION.
        return self._unpack_version(data)

    def _send_server_version(self):
        """
        Server side of the handshake: read the client's INIT, then answer
        with a VERSION packet advertising the "check-file" extension
        (md5 and sha1).

        :return: the version number sent by the client.
        :raises SFTPError: if the first packet is not INIT or is truncated.
        """
        # winscp will freak out if the server sends version info before the
        # client finishes sending INIT.
        t, data = self._read_packet()
        if t != CMD_INIT:
            raise SFTPError('Incompatible sftp protocol')
        version = self._unpack_version(data)
        # advertise that we support "check-file"
        extension_pairs = [ 'check-file', 'md5,sha1' ]
        msg = Message()
        msg.add_int(_VERSION)
        msg.add(*extension_pairs)
        self._send_packet(CMD_VERSION, str(msg))
        return version

    def _log(self, level, msg, *args):
        self.logger.log(level, msg, *args)

    def _write_all(self, out):
        """
        Send every byte of ``out`` over ``self.sock``, looping over partial
        sends.

        :raises EOFError: if ``send`` reports that nothing was written.
        """
        while out:
            n = self.sock.send(out)
            if n <= 0:
                raise EOFError()
            out = out[n:]

    def _read_all(self, n):
        """
        Read exactly ``n`` bytes from ``self.sock``.

        :raises EOFError: if the stream ends before ``n`` bytes arrive.
        """
        # Collect the pieces and join once at the end; repeated ``+=`` would
        # copy the growing buffer on every fragmented recv().
        chunks = []
        while n > 0:
            if isinstance(self.sock, socket.socket):
                # sometimes sftp is used directly over a socket instead of
                # through a ssh channel.  in this case, check periodically
                # if the socket is closed.  (for some reason, recv() won't ever
                # return or raise an exception, but calling select on a closed
                # socket will.)
                while True:
                    readable = select.select([ self.sock ], [], [], 0.1)[0]
                    if readable:
                        x = self.sock.recv(n)
                        break
            else:
                x = self.sock.recv(n)

            if len(x) == 0:
                raise EOFError()
            chunks.append(x)
            n -= len(x)
        return ''.join(chunks)

    def _send_packet(self, t, packet):
        """Frame ``packet`` with its length and type byte ``t`` and send it."""
        # the length prefix covers the type byte as well as the payload
        out = struct.pack('>I', len(packet) + 1) + chr(t) + packet
        if self.ultra_debug:
            self._log(DEBUG, util.format_binary(out, 'OUT: '))
        self._write_all(out)

    def _read_packet(self):
        """
        Read one framed packet.

        :return: ``(type, payload)``; ``(0, '')`` for a zero-length frame.
        :raises SFTPError: if the length's high byte is non-zero.
        """
        x = self._read_all(4)
        # most sftp servers won't accept packets larger than about 32k, so
        # anything with the high byte set (> 16MB) is just garbage.
        if x[0] != '\x00':
            raise SFTPError('Garbage packet received')
        size = struct.unpack('>I', x)[0]
        data = self._read_all(size)
        if self.ultra_debug:
            self._log(DEBUG, util.format_binary(data, 'IN: '))
        if size > 0:
            t = ord(data[0])
            return t, data[1:]
        return 0, ''
189