1"""
2fs.sftpfs
3=========
4
**Currently only available on Python2 due to paramiko not being available for Python3**
6
7Filesystem accessing an SFTP server (via paramiko)
8
9"""
10
11import datetime
12import stat as statinfo
13import threading
14import os
15import paramiko
16from getpass import getuser
17import errno
18
19from fs.base import *
20from fs.path import *
21from fs.errors import *
22from fs.utils import isdir, isfile
23from fs import iotools
24
25
# "No such file or directory" error number; the methods below compare it
# against the ``errno`` attribute of IOErrors raised by SFTP client calls.
ENOENT = errno.ENOENT
27
28
class WrongHostKeyError(RemoteConnectionError):
    """Raised when the server's host key differs from the expected one."""
31
32
# SFTPClient appears to not be thread-safe, so we use an instance per thread.
class ThreadLocalFallback(object):
    """Minimal stand-in for ``threading.local``.

    Attribute values are kept in a single dictionary keyed by
    ``(thread identifier, attribute name)``, so every thread only sees the
    values it stored itself.  Reading an attribute the current thread never
    set raises AttributeError.

    """

    def __init__(self):
        # Bypass our own __setattr__: the storage dict is shared by all
        # threads and must exist before any per-thread value is recorded.
        # (Assigning it through __setattr__ would recurse forever.)
        object.__setattr__(self, '_map', {})

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails.  Guard '_map'
        # explicitly so a missing storage dict cannot cause recursion.
        if attr == '_map':
            raise AttributeError(attr)
        try:
            return self._map[(threading.current_thread().ident, attr)]
        except KeyError:
            raise AttributeError(attr)

    def __setattr__(self, attr, value):
        self._map[(threading.current_thread().ident, attr)] = value


# Prefer the interpreter's own implementation whenever it provides one.
thread_local = getattr(threading, "local", ThreadLocalFallback)
52
53
# Give SFTPFile context-manager support when the installed paramiko lacks it.
if not hasattr(paramiko.SFTPFile, "__enter__"):
    setattr(paramiko.SFTPFile, "__enter__", lambda sftp_file: sftp_file)
    # The result of "close() and False" is always falsy, so an exception
    # raised inside the with-block is never suppressed.
    setattr(paramiko.SFTPFile, "__exit__",
            lambda sftp_file, exc_type, exc_value, traceback: sftp_file.close() and False)
57
58
59class SFTPFS(FS):
60    """A filesystem stored on a remote SFTP server.
61
62    This is basically a compatibility wrapper for the excellent SFTPClient
63    class in the paramiko module.
64
65    """
66
    # Static metadata describing this filesystem's characteristics
    # (presumably exposed through the base class's meta lookup -- verify
    # against fs.base).
    _meta = { 'thread_safe' : True,
              'virtual': False,
              'read_only' : False,
              'unicode_paths' : True,
              'case_insensitive_paths' : False,
              'network' : True,
              'atomic.move' : True,
              'atomic.copy' : True,
              'atomic.makedir' : True,
              'atomic.rename' : True,
              'atomic.setcontents' : False
              }
79
80    def __init__(self,
81                 connection,
82                 root_path="/",
83                 encoding=None,
84                 hostkey=None,
85                 username='',
86                 password=None,
87                 pkey=None,
88                 agent_auth=True,
89                 no_auth=False,
90                 look_for_keys=True):
91        """SFTPFS constructor.
92
93        The only required argument is 'connection', which must be something
94        from which we can construct a paramiko.SFTPClient object.  Possible
95        values include:
96
97            * a hostname string
98            * a (hostname,port) tuple
99            * a paramiko.Transport instance
100            * a paramiko.Channel instance in "sftp" mode
101
102        The keyword argument 'root_path' specifies the root directory on the
103        remote machine - access to files outside this root will be prevented.
104
105        :param connection: a connection string
106        :param root_path: The root path to open
107        :param encoding: String encoding of paths (defaults to UTF-8)
108        :param hostkey: the host key expected from the server or None if you don't require server validation
109        :param username: Name of SFTP user
110        :param password: Password for SFTP user
111        :param pkey: Public key
112        :param agent_auth: attempt to authorize with the user's public keys
113        :param no_auth: attempt to log in without any kind of authorization
114        :param look_for_keys: Look for keys in the same locations as ssh,
115            if other authentication is not succesful
116
117        """
118        credentials = dict(username=username,
119                           password=password,
120                           pkey=pkey)
121        self.credentials = credentials
122
123        if encoding is None:
124            encoding = "utf8"
125        self.encoding = encoding
126        self.closed = False
127        self._owns_transport = False
128        self._credentials = credentials
129        self._tlocal = thread_local()
130        self._transport = None
131        self._client = None
132
133        self.hostname = None
134        if isinstance(connection, basestring):
135            self.hostname = connection
136        elif isinstance(connection, tuple):
137            self.hostname = '%s:%s' % connection
138
139        super(SFTPFS, self).__init__()
140        self.root_path = abspath(normpath(root_path))
141
142        if isinstance(connection,paramiko.Channel):
143            self._transport = None
144            self._client = paramiko.SFTPClient(connection)
145        else:
146            if not isinstance(connection,paramiko.Transport):
147                connection = paramiko.Transport(connection)
148                connection.daemon = True
149                self._owns_transport = True
150
151        if hostkey is not None:
152            key = self.get_remote_server_key()
153            if hostkey != key:
154                raise WrongHostKeyError('Host keys do not match')
155
156        connection.start_client()
157
158        if not connection.is_active():
159            raise RemoteConnectionError(msg='Unable to connect')
160
161        if no_auth:
162            try:
163                connection.auth_none('')
164            except paramiko.SSHException:
165                pass
166
167        elif not connection.is_authenticated():
168            if not username:
169                username = getuser()
170            try:
171                if pkey:
172                    connection.auth_publickey(username, pkey)
173
174                if not connection.is_authenticated() and password:
175                    connection.auth_password(username, password)
176
177                if agent_auth and not connection.is_authenticated():
178                    self._agent_auth(connection, username)
179
180                if look_for_keys and not connection.is_authenticated():
181                    self._userkeys_auth(connection, username, password)
182
183                if not connection.is_authenticated():
184                    try:
185                        connection.auth_none(username)
186                    except paramiko.BadAuthenticationType, e:
187                        self.close()
188                        allowed = ', '.join(e.allowed_types)
189                        raise RemoteConnectionError(msg='no auth - server requires one of the following: %s' % allowed, details=e)
190
191                if not connection.is_authenticated():
192                    self.close()
193                    raise RemoteConnectionError(msg='no auth')
194
195            except paramiko.SSHException, e:
196                self.close()
197                raise RemoteConnectionError(msg='SSH exception (%s)' % str(e), details=e)
198
199        self._transport = connection
200
201    def __unicode__(self):
202        return u'<SFTPFS: %s>' % self.desc('/')
203
204    @classmethod
205    def _agent_auth(cls, transport, username):
206        """
207        Attempt to authenticate to the given transport using any of the private
208        keys available from an SSH agent.
209        """
210
211        agent = paramiko.Agent()
212        agent_keys = agent.get_keys()
213        if not agent_keys:
214            return None
215        for key in agent_keys:
216            try:
217                transport.auth_publickey(username, key)
218                return key
219            except paramiko.SSHException:
220                pass
221        return None
222
223    @classmethod
224    def _userkeys_auth(cls, transport, username, password):
225        """
226        Attempt to authenticate to the given transport using any of the private
227        keys in the users ~/.ssh and ~/ssh dirs
228
229        Derived from http://www.lag.net/paramiko/docs/paramiko.client-pysrc.html
230        """
231
232        keyfiles = []
233        rsa_key = os.path.expanduser('~/.ssh/id_rsa')
234        dsa_key = os.path.expanduser('~/.ssh/id_dsa')
235        if os.path.isfile(rsa_key):
236            keyfiles.append((paramiko.rsakey.RSAKey, rsa_key))
237        if os.path.isfile(dsa_key):
238            keyfiles.append((paramiko.dsskey.DSSKey, dsa_key))
239        # look in ~/ssh/ for windows users:
240        rsa_key = os.path.expanduser('~/ssh/id_rsa')
241        dsa_key = os.path.expanduser('~/ssh/id_dsa')
242        if os.path.isfile(rsa_key):
243            keyfiles.append((paramiko.rsakey.RSAKey, rsa_key))
244        if os.path.isfile(dsa_key):
245            keyfiles.append((paramiko.dsskey.DSSKey, dsa_key))
246
247        for pkey_class, filename in keyfiles:
248            key = pkey_class.from_private_key_file(filename, password)
249            try:
250                transport.auth_publickey(username, key)
251                return key
252            except paramiko.SSHException:
253                pass
254        return None
255
    def __del__(self):
        """Close the filesystem when the object is finalised."""
        self.close()
258
259    @synchronize
260    def __getstate__(self):
261        state = super(SFTPFS,self).__getstate__()
262        del state["_tlocal"]
263        if self._owns_transport:
264            state['_transport'] = self._transport.getpeername()
265        return state
266
267    def __setstate__(self,state):
268        super(SFTPFS, self).__setstate__(state)
269        #for (k,v) in state.iteritems():
270        #    self.__dict__[k] = v
271        #self._lock = threading.RLock()
272        self._tlocal = thread_local()
273        if self._owns_transport:
274            self._transport = paramiko.Transport(self._transport)
275            self._transport.connect(**self._credentials)
276
277    @property
278    @synchronize
279    def client(self):
280        if self.closed:
281            return None
282        client = getattr(self._tlocal, 'client', None)
283        if client is None:
284            if self._transport is None:
285                return self._client
286            client = paramiko.SFTPClient.from_transport(self._transport)
287            self._tlocal.client = client
288        return client
289#        try:
290#            return self._tlocal.client
291#        except AttributeError:
292#            #if self._transport is None:
293#            #    return self._client
294#            client = paramiko.SFTPClient.from_transport(self._transport)
295#            self._tlocal.client = client
296#            return client
297
298    @synchronize
299    def close(self):
300        """Close the connection to the remote server."""
301        if not self.closed:
302            self._tlocal = None
303            #if self.client:
304            #    self.client.close()
305            if self._owns_transport and self._transport and self._transport.is_active:
306                self._transport.close()
307            self.closed = True
308
309    def _normpath(self, path):
310        if not isinstance(path, unicode):
311            path = path.decode(self.encoding)
312        npath = pathjoin(self.root_path, relpath(normpath(path)))
313        if not isprefix(self.root_path, npath):
314            raise PathError(path, msg="Path is outside root: %(path)s")
315        return npath
316
317    def getpathurl(self, path, allow_none=False):
318        path = self._normpath(path)
319        if self.hostname is None:
320            if allow_none:
321                return None
322            raise NoPathURLError(path=path)
323        username = self.credentials.get('username', '') or ''
324        password = self.credentials.get('password', '') or ''
325        credentials = ('%s:%s' % (username, password)).rstrip(':')
326        if credentials:
327            url = 'sftp://%s@%s%s' % (credentials, self.hostname.rstrip('/'), abspath(path))
328        else:
329            url = 'sftp://%s%s' % (self.hostname.rstrip('/'), abspath(path))
330        return url
331
332    @synchronize
333    @convert_os_errors
334    @iotools.filelike_to_stream
335    def open(self, path, mode='r', buffering=-1, encoding=None, errors=None, newline=None, line_buffering=False, bufsize=-1, **kwargs):
336        npath = self._normpath(path)
337        if self.isdir(path):
338            msg = "that's a directory: %(path)s"
339            raise ResourceInvalidError(path, msg=msg)
340        #  paramiko implements its own buffering and write-back logic,
341        #  so we don't need to use a RemoteFileBuffer here.
342        f = self.client.open(npath, mode, bufsize)
343        #  Unfortunately it has a broken truncate() method.
344        #  TODO: implement this as a wrapper
345        old_truncate = f.truncate
346
347        def new_truncate(size=None):
348            if size is None:
349                size = f.tell()
350            return old_truncate(size)
351        f.truncate = new_truncate
352        return f
353
354    @synchronize
355    def desc(self, path):
356        npath = self._normpath(path)
357        if self.hostname:
358            return u'sftp://%s%s' % (self.hostname, path)
359        else:
360            addr, port = self._transport.getpeername()
361            return u'sftp://%s:%i%s' % (addr, port, self.client.normalize(npath))
362
363    @synchronize
364    @convert_os_errors
365    def exists(self, path):
366        if path in ('', '/'):
367            return True
368        npath = self._normpath(path)
369        try:
370            self.client.stat(npath)
371        except IOError, e:
372            if getattr(e,"errno",None) == ENOENT:
373                return False
374            raise
375        return True
376
377    @synchronize
378    @convert_os_errors
379    def isdir(self,path):
380        if normpath(path) in ('', '/'):
381            return True
382        npath = self._normpath(path)
383        try:
384            stat = self.client.stat(npath)
385        except IOError, e:
386            if getattr(e,"errno",None) == ENOENT:
387                return False
388            raise
389        return statinfo.S_ISDIR(stat.st_mode) != 0
390
391    @synchronize
392    @convert_os_errors
393    def isfile(self,path):
394        npath = self._normpath(path)
395        try:
396            stat = self.client.stat(npath)
397        except IOError, e:
398            if getattr(e,"errno",None) == ENOENT:
399                return False
400            raise
401        return statinfo.S_ISREG(stat.st_mode) != 0
402
403    @synchronize
404    @convert_os_errors
405    def listdir(self,path="./",wildcard=None,full=False,absolute=False,dirs_only=False,files_only=False):
406        npath = self._normpath(path)
407        try:
408            attrs_map = None
409            if dirs_only or files_only:
410                attrs = self.client.listdir_attr(npath)
411                attrs_map = dict((a.filename, a) for a in attrs)
412                paths = list(attrs_map.iterkeys())
413            else:
414                paths = self.client.listdir(npath)
415        except IOError, e:
416            if getattr(e,"errno",None) == ENOENT:
417                if self.isfile(path):
418                    raise ResourceInvalidError(path,msg="Can't list directory contents of a file: %(path)s")
419                raise ResourceNotFoundError(path)
420            elif self.isfile(path):
421                raise ResourceInvalidError(path,msg="Can't list directory contents of a file: %(path)s")
422            raise
423
424        if attrs_map:
425            if dirs_only:
426                filter_paths = []
427                for apath, attr in attrs_map.iteritems():
428                    if isdir(self, path, attr.__dict__):
429                        filter_paths.append(apath)
430                paths = filter_paths
431            elif files_only:
432                filter_paths = []
433                for apath, attr in attrs_map.iteritems():
434                    if isfile(self, apath, attr.__dict__):
435                        filter_paths.append(apath)
436                paths = filter_paths
437
438        for (i,p) in enumerate(paths):
439            if not isinstance(p,unicode):
440                paths[i] = p.decode(self.encoding)
441
442        return self._listdir_helper(path, paths, wildcard, full, absolute, False, False)
443
444    @synchronize
445    @convert_os_errors
446    def listdirinfo(self,path="./",wildcard=None,full=False,absolute=False,dirs_only=False,files_only=False):
447        npath = self._normpath(path)
448        try:
449            attrs = self.client.listdir_attr(npath)
450            attrs_map = dict((a.filename, a) for a in attrs)
451            paths = attrs_map.keys()
452        except IOError, e:
453            if getattr(e,"errno",None) == ENOENT:
454                if self.isfile(path):
455                    raise ResourceInvalidError(path,msg="Can't list directory contents of a file: %(path)s")
456                raise ResourceNotFoundError(path)
457            elif self.isfile(path):
458                raise ResourceInvalidError(path,msg="Can't list directory contents of a file: %(path)s")
459            raise
460
461        if dirs_only:
462            filter_paths = []
463            for path, attr in attrs_map.iteritems():
464                if isdir(self, path, attr.__dict__):
465                    filter_paths.append(path)
466            paths = filter_paths
467        elif files_only:
468            filter_paths = []
469            for path, attr in attrs_map.iteritems():
470                if isfile(self, path, attr.__dict__):
471                    filter_paths.append(path)
472            paths = filter_paths
473
474        for (i, p) in enumerate(paths):
475            if not isinstance(p, unicode):
476                paths[i] = p.decode(self.encoding)
477
478        def getinfo(p):
479            resourcename = basename(p)
480            info = attrs_map.get(resourcename)
481            if info is None:
482                return self.getinfo(pathjoin(path, p))
483            return self._extract_info(info.__dict__)
484
485        return [(p, getinfo(p)) for p in
486                    self._listdir_helper(path, paths, wildcard, full, absolute, False, False)]
487
488    @synchronize
489    @convert_os_errors
490    def makedir(self,path,recursive=False,allow_recreate=False):
491        npath = self._normpath(path)
492        try:
493            self.client.mkdir(npath)
494        except IOError, _e:
495            # Error code is unreliable, try to figure out what went wrong
496            try:
497                stat = self.client.stat(npath)
498            except IOError:
499                if not self.isdir(dirname(path)):
500                    # Parent dir is missing
501                    if not recursive:
502                        raise ParentDirectoryMissingError(path)
503                    self.makedir(dirname(path),recursive=True)
504                    self.makedir(path,allow_recreate=allow_recreate)
505                else:
506                    # Undetermined error, let the decorator handle it
507                    raise
508            else:
509                # Destination exists
510                if statinfo.S_ISDIR(stat.st_mode):
511                    if not allow_recreate:
512                        raise DestinationExistsError(path,msg="Can't create a directory that already exists (try allow_recreate=True): %(path)s")
513                else:
514                    raise ResourceInvalidError(path,msg="Can't create directory, there's already a file of that name: %(path)s")
515
516    @synchronize
517    @convert_os_errors
518    def remove(self,path):
519        npath = self._normpath(path)
520        try:
521            self.client.remove(npath)
522        except IOError, e:
523            if getattr(e,"errno",None) == ENOENT:
524                raise ResourceNotFoundError(path)
525            elif self.isdir(path):
526                raise ResourceInvalidError(path,msg="Cannot use remove() on a directory: %(path)s")
527            raise
528
529    @synchronize
530    @convert_os_errors
531    def removedir(self,path,recursive=False,force=False):
532        npath = self._normpath(path)
533        if normpath(path) in ('', '/'):
534            raise RemoveRootError(path)
535        if force:
536            for path2 in self.listdir(path,absolute=True):
537                try:
538                    self.remove(path2)
539                except ResourceInvalidError:
540                    self.removedir(path2,force=True)
541        if not self.exists(path):
542            raise ResourceNotFoundError(path)
543        try:
544            self.client.rmdir(npath)
545        except IOError, e:
546            if getattr(e,"errno",None) == ENOENT:
547                if self.isfile(path):
548                    raise ResourceInvalidError(path,msg="Can't use removedir() on a file: %(path)s")
549                raise ResourceNotFoundError(path)
550
551            elif self.listdir(path):
552                raise DirectoryNotEmptyError(path)
553            raise
554        if recursive:
555            try:
556                if dirname(path) not in ('', '/'):
557                    self.removedir(dirname(path),recursive=True)
558            except DirectoryNotEmptyError:
559                pass
560
561    @synchronize
562    @convert_os_errors
563    def rename(self,src,dst):
564        nsrc = self._normpath(src)
565        ndst = self._normpath(dst)
566        try:
567            self.client.rename(nsrc,ndst)
568        except IOError, e:
569            if getattr(e,"errno",None) == ENOENT:
570                raise ResourceNotFoundError(src)
571            if not self.isdir(dirname(dst)):
572                raise ParentDirectoryMissingError(dst)
573            raise
574
575    @synchronize
576    @convert_os_errors
577    def move(self,src,dst,overwrite=False,chunk_size=16384):
578        nsrc = self._normpath(src)
579        ndst = self._normpath(dst)
580        if overwrite and self.isfile(dst):
581            self.remove(dst)
582        try:
583            self.client.rename(nsrc,ndst)
584        except IOError, e:
585            if getattr(e,"errno",None) == ENOENT:
586                raise ResourceNotFoundError(src)
587            if self.exists(dst):
588                raise DestinationExistsError(dst)
589            if not self.isdir(dirname(dst)):
590                raise ParentDirectoryMissingError(dst,msg="Destination directory does not exist: %(path)s")
591            raise
592
593    @synchronize
594    @convert_os_errors
595    def movedir(self,src,dst,overwrite=False,ignore_errors=False,chunk_size=16384):
596        nsrc = self._normpath(src)
597        ndst = self._normpath(dst)
598        if overwrite and self.isdir(dst):
599            self.removedir(dst)
600        try:
601            self.client.rename(nsrc,ndst)
602        except IOError, e:
603            if getattr(e,"errno",None) == ENOENT:
604                raise ResourceNotFoundError(src)
605            if self.exists(dst):
606                raise DestinationExistsError(dst)
607            if not self.isdir(dirname(dst)):
608                raise ParentDirectoryMissingError(dst,msg="Destination directory does not exist: %(path)s")
609            raise
610
611    _info_vars = frozenset('st_size st_uid st_gid st_mode st_atime st_mtime'.split())
612    @classmethod
613    def _extract_info(cls, stats):
614        fromtimestamp = datetime.datetime.fromtimestamp
615        info = dict((k, v) for k, v in stats.iteritems() if k in cls._info_vars and not k.startswith('_'))
616        info['size'] = info['st_size']
617        ct = info.get('st_ctime')
618        if ct is not None:
619            info['created_time'] = fromtimestamp(ct)
620        at = info.get('st_atime')
621        if at is not None:
622            info['accessed_time'] = fromtimestamp(at)
623        mt = info.get('st_mtime')
624        if mt is not None:
625            info['modified_time'] = fromtimestamp(mt)
626        return info
627
628    @synchronize
629    @convert_os_errors
630    def getinfo(self, path):
631        npath = self._normpath(path)
632        stats = self.client.stat(npath)
633        info = dict((k, getattr(stats, k)) for k in dir(stats) if not k.startswith('_'))
634        info['size'] = info['st_size']
635        ct = info.get('st_ctime', None)
636        if ct is not None:
637            info['created_time'] = datetime.datetime.fromtimestamp(ct)
638        at = info.get('st_atime', None)
639        if at is not None:
640            info['accessed_time'] = datetime.datetime.fromtimestamp(at)
641        mt = info.get('st_mtime', None)
642        if mt is not None:
643            info['modified_time'] = datetime.datetime.fromtimestamp(mt)
644        return info
645
646    @synchronize
647    @convert_os_errors
648    def getsize(self, path):
649        npath = self._normpath(path)
650        stats = self.client.stat(npath)
651        return stats.st_size
652