# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of logilab-common.
#
# logilab-common is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option) any
# later version.
#
# logilab-common is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
"""unit tests for logilab.common.shellutils"""

from os.path import join, dirname, abspath
from unittest.mock import patch

from logilab.common.testlib import TestCase, unittest_main

from logilab.common.shellutils import globfind, find, ProgressBar, RawInput
from logilab.common.compat import StringIO


DATA_DIR = join(dirname(abspath(__file__)), "data", "find_test")


def _data_paths(*names):
    """Return the set of *names* joined onto DATA_DIR."""
    return {join(DATA_DIR, name) for name in names}


class FindTC(TestCase):
    """Tests for the ``find`` and ``globfind`` directory-walking helpers."""

    def test_include(self):
        """Only files with a listed extension are returned; blacklisted dirs are skipped."""
        files = set(find(DATA_DIR, ".py"))
        self.assertSetEqual(
            files,
            _data_paths(
                "__init__.py",
                "module.py",
                "module2.py",
                "noendingnewline.py",
                "nonregr.py",
                join("sub", "momo.py"),
            ),
        )
        files = set(find(DATA_DIR, (".py",), blacklist=("sub",)))
        self.assertSetEqual(
            files,
            _data_paths(
                "__init__.py",
                "module.py",
                "module2.py",
                "noendingnewline.py",
                "nonregr.py",
            ),
        )

    def test_exclude(self):
        """With ``exclude=True`` files matching the extensions are dropped instead."""
        files = set(find(DATA_DIR, (".py", ".pyc"), exclude=True))
        self.assertSetEqual(
            files,
            _data_paths(
                "foo.txt",
                "newlines.txt",
                "normal_file.txt",
                "test.ini",
                "test1.msg",
                "test2.msg",
                "spam.txt",
                join("sub", "doc.txt"),
                "write_protected_file.txt",
            ),
        )

    def test_globfind(self):
        """Files are selected by glob pattern, recursing unless blacklisted."""
        files = set(globfind(DATA_DIR, "*.py"))
        self.assertSetEqual(
            files,
            _data_paths(
                "__init__.py",
                "module.py",
                "module2.py",
                "noendingnewline.py",
                "nonregr.py",
                join("sub", "momo.py"),
            ),
        )
        files = set(globfind(DATA_DIR, "mo*.py"))
        self.assertSetEqual(
            files, _data_paths("module.py", "module2.py", join("sub", "momo.py"))
        )
        files = set(globfind(DATA_DIR, "mo*.py", blacklist=("sub",)))
        self.assertSetEqual(files, _data_paths("module.py", "module2.py"))


class ProgressBarTC(TestCase):
    """Output tests for ProgressBar, compared against hand-built expected text."""

    @staticmethod
    def _render(dots, size):
        """Return the text expected for one redraw with *dots* of *size* cells filled."""
        return "\r[" + "=" * dots + " " * (size - dots) + "]"

    def test_refresh(self):
        """Nothing is written before ``refresh()``; then an empty bar is drawn."""
        pgb_stream = StringIO()
        pgb = ProgressBar(20, stream=pgb_stream)
        # nothing printed before refresh
        self.assertEqual(pgb_stream.getvalue(), "")
        pgb.refresh()
        self.assertEqual(pgb_stream.getvalue(), self._render(0, 20))

    def test_refresh_g_size(self):
        """An explicit size larger than nbops sets the bar width."""
        pgb_stream = StringIO()
        pgb = ProgressBar(20, 35, stream=pgb_stream)
        pgb.refresh()
        self.assertEqual(pgb_stream.getvalue(), self._render(0, 35))

    def test_refresh_l_size(self):
        """An explicit size smaller than nbops sets the bar width."""
        pgb_stream = StringIO()
        pgb = ProgressBar(20, 3, stream=pgb_stream)
        pgb.refresh()
        self.assertEqual(pgb_stream.getvalue(), self._render(0, 3))

    def _update_test(self, nbops, expected, size=None):
        """Call ``update()`` once per item of *expected* and check the stream.

        Each item of *expected* is either an int (the number of filled cells
        after that update; a redraw is expected only when it differs from the
        last drawn value) or a ``(dots, force_redraw)`` tuple, where a true
        *force_redraw* means a redraw is expected even if *dots* is unchanged.
        When *size* is None the bar is built without a size argument and a
        width of 20 cells is assumed.
        """
        pgb_stream = StringIO()
        if size is None:
            pgb = ProgressBar(nbops, stream=pgb_stream)
            size = 20
        else:
            pgb = ProgressBar(nbops, size, stream=pgb_stream)
        chunks = []
        last = 0
        for step in expected:
            if hasattr(step, "__int__"):
                dots, redraw = step, None
            else:
                dots, redraw = step
            pgb.update()
            if redraw or (redraw is None and dots != last):
                last = dots
                chunks.append(self._render(dots, size))
            self.assertEqual(pgb_stream.getvalue(), "".join(chunks))

    def test_default(self):
        self._update_test(20, range(1, 21))

    def test_nbops_gt_size(self):
        """Test the progress bar for nbops > size"""
        self._update_test(40, (counter // 2 for counter in range(1, 41)))

    def test_nbops_lt_size(self):
        """Test the progress bar for nbops < size"""
        self._update_test(10, (counter * 2 for counter in range(1, 11)))

    def test_nbops_nomul_size(self):
        """Test the progress bar for size % nbops !=0 (non int number of dots per update)"""
        self._update_test(3, (6, 13, 20))

    def test_overflow(self):
        """An update past the last operation still redraws the full bar."""
        self._update_test(5, (8, 16, 25, 33, 42, (42, True)), size=42)

    def test_update_exact(self):
        """``update(n, exact=True)`` moves the bar to absolute position *n*."""
        pgb_stream = StringIO()
        size = 20
        pgb = ProgressBar(100, size, stream=pgb_stream)
        chunks = []
        for position in range(10, 105, 15):
            pgb.update(position, exact=True)
            # 100 operations drawn on 20 cells: one cell per 5 operations.
            chunks.append(self._render(position // 5, size))
            self.assertEqual(pgb_stream.getvalue(), "".join(chunks))

    def test_update_relative(self):
        """``update(n, exact=False)`` advances the bar by *n* operations."""
        pgb_stream = StringIO()
        size = 20
        pgb = ProgressBar(100, size, stream=pgb_stream)
        chunks = []
        for done in range(5, 105, 5):
            pgb.update(5, exact=False)
            # `done` is the running total after this call; 5 operations per cell.
            chunks.append(self._render(done // 5, size))
            self.assertEqual(pgb_stream.getvalue(), "".join(chunks))


class RawInputTC(TestCase):
    """Tests for RawInput driven by a canned input function."""

    def auto_input(self, *args):
        """Fake input function: record the call arguments, return the canned answer."""
        self.input_args = args
        return self.input_answer

    def setUp(self):
        def null_printer(_msg):
            # Swallow anything RawInput tries to print during the tests.
            return None

        self.qa = RawInput(self.auto_input, null_printer)

    def test_ask_using_builtin_input(self):
        with patch("builtins.input", return_value="no"):
            qa = RawInput()
            answer = qa.ask("text", ("yes", "no"), "yes")
        self.assertEqual(answer, "no")

    def test_ask_default(self):
        """Empty or blank answers select the default choice."""
        for typed in ("", " "):
            self.input_answer = typed
            answer = self.qa.ask("text", ("yes", "no"), "yes")
            self.assertEqual(answer, "yes", msg=repr(typed))

    def test_ask_case(self):
        """Answers are matched case-insensitively and returned as declared."""
        cases = (
            ("no", "no"),
            ("No", "no"),
            ("NO", "no"),
            ("nO", "no"),
            ("YES", "yes"),
        )
        for typed, expected in cases:
            self.input_answer = typed
            answer = self.qa.ask("text", ("yes", "no"), "yes")
            self.assertEqual(answer, expected, msg=repr(typed))

    def test_ask_prompt(self):
        """The prompt lists choices, capitalising the default."""
        self.input_answer = ""
        cases = (
            (("yes", "no"), "yes", "text [Y(es)/n(o)]: "),
            (("y", "n"), "y", "text [Y/n]: "),
            (("n", "y"), "y", "text [n/Y]: "),
            (("yes", "no", "maybe", "1"), "yes", "text [Y(es)/n(o)/m(aybe)/1]: "),
        )
        for choices, default, prompt in cases:
            self.qa.ask("text", choices, default)
            self.assertEqual(self.input_args[0], prompt)

    def test_ask_ambiguous(self):
        """A prefix matching several choices is rejected."""
        self.input_answer = "y"
        # NOTE(review): the concrete exception type RawInput raises is not
        # visible here; narrow this once it is confirmed.
        with self.assertRaises(Exception):
            self.qa.ask("text", ("yes", "yep"), "yes")

    def test_confirm(self):
        """confirm() maps y/n to booleans and an empty answer to the default."""
        cases = (
            ("y", True, True),
            ("n", False, False),
            ("", True, False),
        )
        for typed, when_default_yes, when_default_no in cases:
            self.input_answer = typed
            message = "Say default" if typed == "" else "Say yes"
            self.assertEqual(self.qa.confirm(message), when_default_yes)
            self.assertEqual(
                self.qa.confirm(message, default_is_yes=False), when_default_no
            )


if __name__ == "__main__":
    unittest_main()