1from __future__ import print_function
2
3from io import BytesIO
4import os
5import paramiko
6import random
7import shutil
8import sys
9from scp import SCPClient, SCPException, put, get
10import tempfile
11try:
12    import unittest2 as unittest
13    sys.modules['unittest'] = unittest
14except ImportError:
15    import unittest
16
17
# Connection parameters for the SSH server under test.  Every field can be
# overridden through the matching SCPPY_* environment variable; the port is
# coerced to int because environment values are always strings.
ssh_info = dict(
    hostname=os.environ.get('SCPPY_HOSTNAME', '127.0.0.1'),
    port=int(os.environ.get('SCPPY_PORT', 22)),
    username=os.environ.get('SCPPY_USERNAME', None),
)


# Environment info
PY3 = sys.version_info[0] >= 3
WINDOWS = 'nt' == os.name
MACOS = 'darwin' == sys.platform
29
30
if not MACOS:
    # Other platforms keep names exactly as written, so the comparison set
    # is just the names themselves.
    normalize_paths = set
else:
    import unicodedata

    def normalize_paths(names):
        """Return the set of *names* as NFC-normalized UTF-8 bytes.

        HFS (on Mac OS X) may rewrite file names into a different Unicode
        normalization form, so both sides of a comparison are folded to NFC.
        ``bytes`` entries are decoded as UTF-8 first; text entries are used
        as-is.
        """
        def _to_nfc_bytes(name):
            text = name.decode('utf-8') if isinstance(name, bytes) else name
            return unicodedata.normalize('NFC', text).encode('utf-8')

        return {_to_nfc_bytes(name) for name in names}
48
49
def unique_names():
    """Yield an endless stream of random 10-byte names.

    Every name is built from lowercase ASCII letters and digits, so it is
    safe to append to a remote path without quoting.
    """
    alphabet = b"abcdefghijklmnopqrstuvwxyz0123456789"
    # Slice (rather than index) so each element stays a length-1 ``bytes``
    # object on both Python 2 and Python 3.
    pool = [alphabet[pos:pos + 1] for pos in range(len(alphabet))]
    generator = random.Random()
    while True:
        yield b''.join(generator.choice(pool) for _ in range(10))


# Replace the generator function with a single shared generator instance;
# callers draw names with ``next(unique_names)``.
unique_names = unique_names()
61
62
class TestDownload(unittest.TestCase):
    """Download files with non-ASCII names from the server and check them."""

    @classmethod
    def setUpClass(cls):
        # Server connection
        cls.ssh = paramiko.SSHClient()
        cls.ssh.load_system_host_keys()
        cls.ssh.set_missing_host_key_policy(paramiko.WarningPolicy())
        cls.ssh.connect(**ssh_info)

        # Makes some files on the server.  The commands are idempotent
        # (``mkdir -p``, ``touch``), so rerunning the suite neither fails nor
        # needs a guard.  The previous ``test -d`` guard checked whether the
        # regular file /tmp/r\xC3\xA9mi was a *directory*, so it never
        # short-circuited and ``mkdir`` failed on every rerun.
        chan = cls.ssh.get_transport().open_session()
        chan.exec_command(
            # Directory
            b'echo -ne "/tmp/bien rang\\xC3\\xA9" | xargs -0 mkdir -p && '
            # Files
            b'echo -ne "'
            b'/tmp/r\\xC3\\xA9mi\\x00'
            b'/tmp/bien rang\\xC3\\xA9/file\\x00'
            b'/tmp/bien rang\\xC3\\xA9/b\\xC3\\xA8te\\x00'
            b'/tmp/p\\xE9t\\xE9'  # invalid UTF-8 here
            b'" | xargs -0 touch')
        assert chan.recv_exit_status() == 0

        print("Running tests on %s with %s" % (
              "Windows" if WINDOWS else
              "Mac OS X" if MACOS else
              "POSIX",
              "Python 3" if PY3 else "Python 2"))

    @classmethod
    def tearDownClass(cls):
        # Release the SSH connection opened in setUpClass.
        cls.ssh.close()

    def download_test(self, filename, recursive, destination=None,
                      expected_win=None, expected_posix=None):
        """Download *filename* into a fresh local directory and check it.

        :param filename: remote path, or list of remote paths, as bytes or
            unicode; passed straight to ``SCPClient.get``.
        :param recursive: passed through to ``SCPClient.get``.
        :param destination: local target relative to the scratch directory;
            ``None`` downloads into the scratch directory itself (``u'.'``).
        :param expected_win: relative paths (unicode) expected on Windows.
        :param expected_posix: relative paths (bytes) expected elsewhere.
        """
        if expected_win is None:
            expected_win = []
        if expected_posix is None:
            expected_posix = []
        # Make a temporary directory
        temp = tempfile.mkdtemp(prefix='scp-py_test_')
        # Add some unicode in the path
        if WINDOWS:
            if isinstance(temp, bytes):
                temp = temp.decode(sys.getfilesystemencoding())
            temp_in = os.path.join(temp, u'cl\xE9')
        else:
            if not isinstance(temp, bytes):
                temp = temp.encode('utf-8')
            temp_in = os.path.join(temp, b'cl\xC3\xA9')
        previous = os.getcwd()
        cb3 = lambda filename, size, sent: None
        try:
            # Inside the try so the temporary tree is removed even if
            # creating or entering the scratch directory fails.
            os.mkdir(temp_in)
            os.chdir(temp_in)
            with SCPClient(self.ssh.get_transport(), progress=cb3) as scp:
                scp.get(filename,
                        destination if destination is not None else u'.',
                        preserve_times=True, recursive=recursive)
            actual = []

            def listdir(path, fpath):
                """Append every entry under *fpath* to ``actual``.

                *path* is the prefix relative to the scratch directory and
                *fpath* is the real path that is listed.
                """
                for name in os.listdir(fpath):
                    fname = os.path.join(fpath, name)
                    relative = os.path.join(path, name)
                    actual.append(relative)
                    if os.path.isdir(fname):
                        # Recurse with the full relative prefix so entries
                        # nested deeper than one level keep every component.
                        listdir(relative, fname)
            listdir(u'' if WINDOWS else b'',
                    u'.' if WINDOWS else b'.')
            self.assertEqual(normalize_paths(actual),
                             set(expected_win if WINDOWS else expected_posix))
        finally:
            os.chdir(previous)
            shutil.rmtree(temp)

    def test_get_bytes(self):
        self.download_test(b'/tmp/r\xC3\xA9mi', False, b'target',
                           [u'target'], [b'target'])
        self.download_test(b'/tmp/r\xC3\xA9mi', False, u'target',
                           [u'target'], [b'target'])
        self.download_test(b'/tmp/r\xC3\xA9mi', False, None,
                           [u'r\xE9mi'], [b'r\xC3\xA9mi'])
        self.download_test([b'/tmp/bien rang\xC3\xA9/file',
                            b'/tmp/bien rang\xC3\xA9/b\xC3\xA8te'],
                           False, None,
                           [u'file', u'b\xE8te'], [b'file', b'b\xC3\xA8te'])

    def test_get_unicode(self):
        self.download_test(u'/tmp/r\xE9mi', False, b'target',
                           [u'target'], [b'target'])
        self.download_test(u'/tmp/r\xE9mi', False, u'target',
                           [u'target'], [b'target'])
        self.download_test(u'/tmp/r\xE9mi', False, None,
                           [u'r\xE9mi'], [b'r\xC3\xA9mi'])
        self.download_test([u'/tmp/bien rang\xE9/file',
                            u'/tmp/bien rang\xE9/b\xE8te'],
                           False, None,
                           [u'file', u'b\xE8te'], [b'file', b'b\xC3\xA8te'])

    def test_get_folder(self):
        self.download_test(b'/tmp/bien rang\xC3\xA9', True, None,
                           [u'bien rang\xE9', u'bien rang\xE9\\file',
                            u'bien rang\xE9\\b\xE8te'],
                           [b'bien rang\xC3\xA9', b'bien rang\xC3\xA9/file',
                            b'bien rang\xC3\xA9/b\xC3\xA8te'])
        self.download_test(b'/tmp/bien rang\xC3\xA9', True, b'target',
                           [u'target', u'target\\file',
                            u'target\\b\xE8te'],
                           [b'target', b'target/file',
                            b'target/b\xC3\xA8te'])

    def test_get_invalid_unicode(self):
        self.download_test(b'/tmp/p\xE9t\xE9', False, u'target',
                           [u'target'], [b'target'])
        if WINDOWS:
            # Windows cannot represent the invalid-UTF-8 remote name.
            with self.assertRaises(SCPException):
                self.download_test(b'/tmp/p\xE9t\xE9', False, None,
                                   [], [])
        elif MACOS:
            # HFS percent-escapes bytes that are not valid UTF-8.
            self.download_test(b'/tmp/p\xE9t\xE9', False, None,
                               [u'not windows'], [b'p%E9t%E9'])
        else:
            self.download_test(b'/tmp/p\xE9t\xE9', False, None,
                               [u'not windows'], [b'p\xE9t\xE9'])
181
182
class TestUpload(unittest.TestCase):
    """Upload local files with non-ASCII names to the server and check them."""

    @classmethod
    def setUpClass(cls):
        # Server connection
        cls.ssh = paramiko.SSHClient()
        cls.ssh.load_system_host_keys()
        cls.ssh.set_missing_host_key_policy(paramiko.WarningPolicy())
        cls.ssh.connect(**ssh_info)

        # Makes some files locally: a directory tree with Latin-1-range names
        # under a fresh temporary directory (kept as text on every platform).
        cls._temp = tempfile.mkdtemp(prefix='scp_py_test_')
        if isinstance(cls._temp, bytes):
            cls._temp = cls._temp.decode(sys.getfilesystemencoding())
        inner = os.path.join(cls._temp, u'cl\xE9')
        os.mkdir(inner)
        os.mkdir(os.path.join(inner, u'dossi\xE9'))
        os.mkdir(os.path.join(inner, u'dossi\xE9', u'bien rang\xE9'))
        open(os.path.join(inner, u'dossi\xE9', u'bien rang\xE9', u'test'),
             'w').close()
        open(os.path.join(inner, u'r\xE9mi'), 'w').close()

    @classmethod
    def tearDownClass(cls):
        try:
            shutil.rmtree(cls._temp)
        finally:
            # Release the SSH connection opened in setUpClass.
            cls.ssh.close()

    def upload_test(self, filenames, recursive, expected=None, fl=None):
        """Upload into a fresh remote directory and compare what arrived.

        :param filenames: local path(s) relative to the fixture directory,
            or, when *fl* is given, the remote file name to write *fl* to.
        :param recursive: passed through to ``SCPClient.put``.
        :param expected: paths (bytes) expected under the remote directory.
        :param fl: optional file-like object uploaded with ``putfo``; it is
            closed after the upload.
        """
        if expected is None:
            expected = []
        destination = b'/tmp/upp\xC3\xA9' + next(unique_names)
        chan = self.ssh.get_transport().open_session()
        chan.exec_command(b'mkdir ' + destination)
        assert chan.recv_exit_status() == 0
        previous = os.getcwd()
        cb4 = lambda filename, size, sent, peername: None
        try:
            os.chdir(self._temp)
            with SCPClient(self.ssh.get_transport(), progress4=cb4) as scp:
                if fl is None:
                    scp.put(filenames, destination, recursive)
                else:
                    # ``destination`` holds UTF-8 bytes for the *remote*
                    # host, so decode it as UTF-8 rather than with the local
                    # filesystem encoding; a text path is re-encoded as UTF-8
                    # on its way to the server.
                    prefix = destination.decode('utf-8')
                    remote_path = '%s/%s' % (prefix, filenames)
                    scp.putfo(fl, remote_path)
                    fl.close()

            # List the remote directory.  The non-ASCII path is passed to the
            # shell as ``\xHH`` escapes that ``echo -ne`` turns back into bytes.
            chan = self.ssh.get_transport().open_session()
            chan.exec_command(
                b'echo -ne "' +
                destination.decode('iso-8859-1')
                    .encode('ascii', 'backslashreplace') +
                b'" | xargs find')
            chunks = []
            while True:
                data = chan.recv(1024)
                if not data:
                    break
                chunks.append(data)
            output = b''.join(chunks)
            # Strip "<destination>/" and drop the destination line itself.
            prefix = len(destination) + 1
            out_list = [line[prefix:] for line in output.splitlines()
                        if len(line) > prefix]
            self.assertEqual(normalize_paths(out_list), set(expected))
        finally:
            os.chdir(previous)
            chan = self.ssh.get_transport().open_session()
            chan.exec_command(b'rm -Rf ' + destination)
            assert chan.recv_exit_status() == 0

    @unittest.skipIf(WINDOWS, "Use unicode paths on Windows")
    def test_put_bytes(self):
        self.upload_test(b'cl\xC3\xA9/r\xC3\xA9mi', False, [b'r\xC3\xA9mi'])
        self.upload_test(b'cl\xC3\xA9/dossi\xC3\xA9/bien rang\xC3\xA9/test',
                         False,
                         [b'test'])
        self.upload_test(b'cl\xC3\xA9/dossi\xC3\xA9', True,
                         [b'dossi\xC3\xA9',
                          b'dossi\xC3\xA9/bien rang\xC3\xA9',
                          b'dossi\xC3\xA9/bien rang\xC3\xA9/test'])

    def test_put_unicode(self):
        self.upload_test(u'cl\xE9/r\xE9mi', False, [b'r\xC3\xA9mi'])
        self.upload_test(u'cl\xE9/dossi\xE9/bien rang\xE9/test', False,
                         [b'test'])
        self.upload_test(u'cl\xE9/dossi\xE9', True,
                         [b'dossi\xC3\xA9',
                          b'dossi\xC3\xA9/bien rang\xC3\xA9',
                          b'dossi\xC3\xA9/bien rang\xC3\xA9/test'])
        self.upload_test([u'cl\xE9/dossi\xE9/bien rang\xE9',
                          u'cl\xE9/r\xE9mi'], True,
                         [b'bien rang\xC3\xA9',
                          b'bien rang\xC3\xA9/test',
                          b'r\xC3\xA9mi'])
        self.upload_test([u'cl\xE9/dossi\xE9',
                          u'cl\xE9/r\xE9mi'], True,
                         [b'dossi\xC3\xA9',
                          b'dossi\xC3\xA9/bien rang\xC3\xA9',
                          b'dossi\xC3\xA9/bien rang\xC3\xA9/test',
                          b'r\xC3\xA9mi'])

    def test_putfo(self):
        fl = BytesIO()
        fl.write(b'r\xC3\xA9mi')
        fl.seek(0)
        self.upload_test(u'putfo-test', False, [b'putfo-test'], fl)
284
285
class TestUpAndDown(unittest.TestCase):
    """Round-trip a file through ``put`` and ``get`` on one connection."""

    @classmethod
    def setUpClass(cls):
        # Server connection
        cls.ssh = paramiko.SSHClient()
        cls.ssh.load_system_host_keys()
        cls.ssh.set_missing_host_key_policy(paramiko.WarningPolicy())
        cls.ssh.connect(**ssh_info)

        # Makes some files locally
        cls._temp = tempfile.mkdtemp(prefix='scp_py_test_')
        if isinstance(cls._temp, bytes):
            cls._temp = cls._temp.decode(sys.getfilesystemencoding())

    @classmethod
    def tearDownClass(cls):
        try:
            shutil.rmtree(cls._temp)
        finally:
            # Release the SSH connection opened in setUpClass.
            cls.ssh.close()

    def test_up_and_down(self):
        '''send and receive files with the same client'''
        previous = os.getcwd()
        testfile = os.path.join(self._temp, 'testfile')
        # The "remote" and received paths both point into the local temp
        # directory; this relies on the server being this machine.
        testfile_sent = os.path.join(self._temp, 'testfile_sent')
        testfile_rcvd = os.path.join(self._temp, 'testfile_rcvd')
        try:
            os.chdir(self._temp)
            with open(testfile, 'w') as f:
                f.write("TESTING\n")
            put(self.ssh.get_transport(), testfile, testfile_sent)
            get(self.ssh.get_transport(), testfile_sent, testfile_rcvd)

            with open(testfile_rcvd) as f:
                self.assertEqual(f.read(), 'TESTING\n')
        finally:
            os.chdir(previous)
321
322
if __name__ == '__main__':
    # Executed as a script: let unittest's command-line runner collect and
    # run every TestCase defined in this module.
    unittest.main(module='__main__')
325