1# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
2# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
3#
4# This file is part of logilab-common.
5#
6# logilab-common is free software: you can redistribute it and/or modify it under
7# the terms of the GNU Lesser General Public License as published by the Free
8# Software Foundation, either version 2.1 of the License, or (at your option) any
9# later version.
10#
11# logilab-common is distributed in the hope that it will be useful, but WITHOUT
12# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
13# FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
14# details.
15#
16# You should have received a copy of the GNU Lesser General Public License along
17# with logilab-common.  If not, see <http://www.gnu.org/licenses/>.
18"""unit tests for logilab.common.shellutils"""
19
20from os.path import join, dirname, abspath
21from unittest.mock import patch
22
23from logilab.common.testlib import TestCase, unittest_main
24
25from logilab.common.shellutils import globfind, find, ProgressBar, RawInput
26from logilab.common.compat import StringIO
27
28
# Absolute path of the fixture tree walked by the find/globfind tests; it
# lives in ``data/find_test`` next to this module.
DATA_DIR = join(
    dirname(abspath(__file__)),
    "data",
    "find_test",
)
30
31
class FindTC(TestCase):
    """Tests for ``find`` and ``globfind`` over the ``find_test`` fixture tree."""

    # Python modules sitting directly in DATA_DIR (``sub`` holds one more).
    _TOP_LEVEL_PY = (
        "__init__.py",
        "module.py",
        "module2.py",
        "noendingnewline.py",
        "nonregr.py",
    )

    @staticmethod
    def _abs(*relative_names):
        """Return the set of absolute fixture paths for *relative_names*."""
        return {join(DATA_DIR, name) for name in relative_names}

    def test_include(self):
        """``find`` returns files with the given extension(s); ``blacklist`` skips ``sub``."""
        all_py = self._abs(*self._TOP_LEVEL_PY, join("sub", "momo.py"))
        self.assertSetEqual(set(find(DATA_DIR, ".py")), all_py)

        top_level_py = self._abs(*self._TOP_LEVEL_PY)
        self.assertSetEqual(
            set(find(DATA_DIR, (".py",), blacklist=("sub",))), top_level_py
        )

    def test_exclude(self):
        """With ``exclude=True`` only files NOT matching the extensions are returned."""
        expected = self._abs(
            "foo.txt",
            "newlines.txt",
            "normal_file.txt",
            "test.ini",
            "test1.msg",
            "test2.msg",
            "spam.txt",
            join("sub", "doc.txt"),
            "write_protected_file.txt",
        )
        self.assertSetEqual(
            set(find(DATA_DIR, (".py", ".pyc"), exclude=True)), expected
        )

    def test_globfind(self):
        """``globfind`` matches shell-style patterns; ``blacklist`` skips ``sub``."""
        self.assertSetEqual(
            set(globfind(DATA_DIR, "*.py")),
            self._abs(*self._TOP_LEVEL_PY, join("sub", "momo.py")),
        )
        self.assertSetEqual(
            set(globfind(DATA_DIR, "mo*.py")),
            self._abs("module.py", "module2.py", join("sub", "momo.py")),
        )
        self.assertSetEqual(
            set(globfind(DATA_DIR, "mo*.py", blacklist=("sub",))),
            self._abs("module.py", "module2.py"),
        )
107
108
109class ProgressBarTC(TestCase):
110    def test_refresh(self):
111        pgb_stream = StringIO()
112        expected_stream = StringIO()
113        pgb = ProgressBar(20, stream=pgb_stream)
114        self.assertEqual(
115            pgb_stream.getvalue(), expected_stream.getvalue()
116        )  # nothing print before refresh
117        pgb.refresh()
118        expected_stream.write("\r[" + " " * 20 + "]")
119        self.assertEqual(pgb_stream.getvalue(), expected_stream.getvalue())
120
121    def test_refresh_g_size(self):
122        pgb_stream = StringIO()
123        expected_stream = StringIO()
124        pgb = ProgressBar(20, 35, stream=pgb_stream)
125        pgb.refresh()
126        expected_stream.write("\r[" + " " * 35 + "]")
127        self.assertEqual(pgb_stream.getvalue(), expected_stream.getvalue())
128
129    def test_refresh_l_size(self):
130        pgb_stream = StringIO()
131        expected_stream = StringIO()
132        pgb = ProgressBar(20, 3, stream=pgb_stream)
133        pgb.refresh()
134        expected_stream.write("\r[" + " " * 3 + "]")
135        self.assertEqual(pgb_stream.getvalue(), expected_stream.getvalue())
136
137    def _update_test(self, nbops, expected, size=None):
138        pgb_stream = StringIO()
139        expected_stream = StringIO()
140        if size is None:
141            pgb = ProgressBar(nbops, stream=pgb_stream)
142            size = 20
143        else:
144            pgb = ProgressBar(nbops, size, stream=pgb_stream)
145        last = 0
146        for round in expected:
147            if not hasattr(round, "__int__"):
148                dots, update = round
149            else:
150                dots, update = round, None
151            pgb.update()
152            if update or (update is None and dots != last):
153                last = dots
154                expected_stream.write("\r[" + ("=" * dots) + (" " * (size - dots)) + "]")
155            self.assertEqual(pgb_stream.getvalue(), expected_stream.getvalue())
156
157    def test_default(self):
158        self._update_test(20, range(1, 21))
159
160    def test_nbops_gt_size(self):
161        """Test the progress bar for nbops > size"""
162
163        def half(total):
164            for counter in range(1, total + 1):
165                yield counter // 2
166
167        self._update_test(40, half(40))
168
169    def test_nbops_lt_size(self):
170        """Test the progress bar for nbops < size"""
171
172        def double(total):
173            for counter in range(1, total + 1):
174                yield counter * 2
175
176        self._update_test(10, double(10))
177
178    def test_nbops_nomul_size(self):
179        """Test the progress bar for size % nbops !=0 (non int number of dots per update)"""
180        self._update_test(3, (6, 13, 20))
181
182    def test_overflow(self):
183        self._update_test(5, (8, 16, 25, 33, 42, (42, True)), size=42)
184
185    def test_update_exact(self):
186        pgb_stream = StringIO()
187        expected_stream = StringIO()
188        size = 20
189        pgb = ProgressBar(100, size, stream=pgb_stream)
190        for dots in range(10, 105, 15):
191            pgb.update(dots, exact=True)
192            dots //= 5
193            expected_stream.write("\r[" + ("=" * dots) + (" " * (size - dots)) + "]")
194            self.assertEqual(pgb_stream.getvalue(), expected_stream.getvalue())
195
196    def test_update_relative(self):
197        pgb_stream = StringIO()
198        expected_stream = StringIO()
199        size = 20
200        pgb = ProgressBar(100, size, stream=pgb_stream)
201        for dots in range(5, 105, 5):
202            pgb.update(5, exact=False)
203            dots //= 5
204            expected_stream.write("\r[" + ("=" * dots) + (" " * (size - dots)) + "]")
205            self.assertEqual(pgb_stream.getvalue(), expected_stream.getvalue())
206
207
208class RawInputTC(TestCase):
209    def auto_input(self, *args):
210        self.input_args = args
211        return self.input_answer
212
213    def setUp(self):
214        null_printer = lambda x: None
215        self.qa = RawInput(self.auto_input, null_printer)
216
217    def test_ask_using_builtin_input(self):
218        with patch("builtins.input", return_value="no"):
219            qa = RawInput()
220            answer = qa.ask("text", ("yes", "no"), "yes")
221        self.assertEqual(answer, "no")
222
223    def test_ask_default(self):
224        self.input_answer = ""
225        answer = self.qa.ask("text", ("yes", "no"), "yes")
226        self.assertEqual(answer, "yes")
227        self.input_answer = "  "
228        answer = self.qa.ask("text", ("yes", "no"), "yes")
229        self.assertEqual(answer, "yes")
230
231    def test_ask_case(self):
232        self.input_answer = "no"
233        answer = self.qa.ask("text", ("yes", "no"), "yes")
234        self.assertEqual(answer, "no")
235        self.input_answer = "No"
236        answer = self.qa.ask("text", ("yes", "no"), "yes")
237        self.assertEqual(answer, "no")
238        self.input_answer = "NO"
239        answer = self.qa.ask("text", ("yes", "no"), "yes")
240        self.assertEqual(answer, "no")
241        self.input_answer = "nO"
242        answer = self.qa.ask("text", ("yes", "no"), "yes")
243        self.assertEqual(answer, "no")
244        self.input_answer = "YES"
245        answer = self.qa.ask("text", ("yes", "no"), "yes")
246        self.assertEqual(answer, "yes")
247
248    def test_ask_prompt(self):
249        self.input_answer = ""
250        self.qa.ask("text", ("yes", "no"), "yes")
251        self.assertEqual(self.input_args[0], "text [Y(es)/n(o)]: ")
252        self.qa.ask("text", ("y", "n"), "y")
253        self.assertEqual(self.input_args[0], "text [Y/n]: ")
254        self.qa.ask("text", ("n", "y"), "y")
255        self.assertEqual(self.input_args[0], "text [n/Y]: ")
256        self.qa.ask("text", ("yes", "no", "maybe", "1"), "yes")
257        self.assertEqual(self.input_args[0], "text [Y(es)/n(o)/m(aybe)/1]: ")
258
259    def test_ask_ambiguous(self):
260        self.input_answer = "y"
261        self.assertRaises(Exception, self.qa.ask, "text", ("yes", "yep"), "yes")
262
263    def test_confirm(self):
264        self.input_answer = "y"
265        self.assertEqual(self.qa.confirm("Say yes"), True)
266        self.assertEqual(self.qa.confirm("Say yes", default_is_yes=False), True)
267        self.input_answer = "n"
268        self.assertEqual(self.qa.confirm("Say yes"), False)
269        self.assertEqual(self.qa.confirm("Say yes", default_is_yes=False), False)
270        self.input_answer = ""
271        self.assertEqual(self.qa.confirm("Say default"), True)
272        self.assertEqual(self.qa.confirm("Say default", default_is_yes=False), False)
273
274
if __name__ == "__main__":
    # Allow running this module directly as a script, via the
    # ``unittest_main`` entry point imported from logilab.common.testlib.
    unittest_main()
277