1# Class for actually running tests.
2#
3# Copyright (c) 2020-2021 Virtuozzo International GmbH
4#
5# This program is free software; you can redistribute it and/or modify
6# it under the terms of the GNU General Public License as published by
7# the Free Software Foundation; either version 2 of the License, or
8# (at your option) any later version.
9#
10# This program is distributed in the hope that it will be useful,
11# but WITHOUT ANY WARRANTY; without even the implied warranty of
12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13# GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License
16# along with this program.  If not, see <http://www.gnu.org/licenses/>.
17#
18
19import os
20from pathlib import Path
21import datetime
22import time
23import difflib
24import subprocess
25import contextlib
26import json
27import termios
28import sys
29from contextlib import contextmanager
30from typing import List, Optional, Iterator, Any, Sequence, Dict, \
31        ContextManager
32
33from testenv import TestEnv
34
35
def silent_unlink(path: Path) -> None:
    """ Remove ``path``, silently ignoring any OSError (e.g. missing file) """
    with contextlib.suppress(OSError):
        path.unlink()
41
42
43def file_diff(file1: str, file2: str) -> List[str]:
44    with open(file1, encoding="utf-8") as f1, \
45         open(file2, encoding="utf-8") as f2:
46        # We want to ignore spaces at line ends. There are a lot of mess about
47        # it in iotests.
48        # TODO: fix all tests to not produce extra spaces, fix all .out files
49        # and use strict diff here!
50        seq1 = [line.rstrip() for line in f1]
51        seq2 = [line.rstrip() for line in f2]
52        res = [line.rstrip()
53               for line in difflib.unified_diff(seq1, seq2, file1, file2)]
54        return res
55
56
57# We want to save current tty settings during test run,
58# since an aborting qemu call may leave things screwed up.
@contextmanager
def savetty() -> Iterator[None]:
    """ Restore stdin's terminal attributes when the with-block exits

    Does nothing when stdin is not a tty. Attributes are restored with
    TCSADRAIN even if the body raises.
    """
    saved = None
    if sys.stdin.isatty():
        stdin_fd = sys.stdin.fileno()
        saved = (stdin_fd, termios.tcgetattr(stdin_fd))

    try:
        yield
    finally:
        if saved is not None:
            termios.tcsetattr(saved[0], termios.TCSADRAIN, saved[1])
71
72
class LastElapsedTime(ContextManager['LastElapsedTime']):
    """ Cache for elapsed time for tests, to show it during new test run

    The cache is a JSON file mapping test -> imgproto -> imgfmt -> seconds.
    A missing or unparsable cache file is treated as empty, and so is any
    level of the mapping that is not a JSON object.

    It is safe to use get() at any time.  To use update(), you must either
    use it inside with-block or use save() after update().
    """
    def __init__(self, cache_file: str, env: 'TestEnv') -> None:
        """
        :param cache_file: path of the JSON cache file (read here, written
            by save())
        :param env: test environment; its imgproto and imgfmt attributes
            select the cache entry used by get() and update()
        """
        self.env = env
        self.cache_file = cache_file
        self.cache: Dict[str, Dict[str, Dict[str, float]]]

        try:
            with open(cache_file, encoding="utf-8") as f:
                loaded = json.load(f)
        except (OSError, ValueError):
            loaded = {}

        # Valid JSON of the wrong shape (e.g. a list) must not make
        # update() crash later on dict-only operations.
        self.cache = loaded if isinstance(loaded, dict) else {}

    def get(self, test: str,
            default: Optional[float] = None) -> Optional[float]:
        """ Return the cached elapsed time of ``test`` for the current
        imgproto/imgfmt, or ``default`` if there is none """
        per_proto = self.cache.get(test)
        if not isinstance(per_proto, dict):
            return default

        per_fmt = per_proto.get(self.env.imgproto)
        if not isinstance(per_fmt, dict):
            return default

        return per_fmt.get(self.env.imgfmt, default)

    def update(self, test: str, elapsed: float) -> None:
        """ Record ``elapsed`` for ``test`` under the current
        imgproto/imgfmt (in memory only; see save()) """
        per_proto = self.cache.get(test)
        if not isinstance(per_proto, dict):
            # Missing or malformed entry: start a fresh one.
            per_proto = self.cache[test] = {}

        per_fmt = per_proto.get(self.env.imgproto)
        if not isinstance(per_fmt, dict):
            per_fmt = per_proto[self.env.imgproto] = {}

        per_fmt[self.env.imgfmt] = elapsed

    def save(self) -> None:
        """ Write the whole cache back to ``cache_file`` as JSON """
        with open(self.cache_file, 'w', encoding="utf-8") as f:
            json.dump(self.cache, f)

    def __enter__(self) -> 'LastElapsedTime':
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Saved unconditionally, also when the with-block raised.
        self.save()
114
115
class TestResult:
    """ Outcome of running a single test

    Attributes mirror the constructor arguments:
      status      -- result status string (TestRunner uses 'pass', 'fail'
                     and 'not run')
      description -- short human-readable explanation
      elapsed     -- run time in seconds, or None if not measured
      diff        -- unified-diff lines against the reference output
      casenotrun  -- text about individual cases that were skipped
      interrupted -- True if the run was interrupted by the user
    """
    def __init__(self, status: str, description: str = '',
                 elapsed: Optional[float] = None, diff: Sequence[str] = (),
                 casenotrun: str = '', interrupted: bool = False) -> None:
        self.status, self.description = status, description
        self.elapsed, self.diff = elapsed, diff
        self.casenotrun, self.interrupted = casenotrun, interrupted
126
127
128class TestRunner(ContextManager['TestRunner']):
129    def __init__(self, env: TestEnv, makecheck: bool = False,
130                 color: str = 'auto') -> None:
131        self.env = env
132        self.test_run_env = self.env.get_env()
133        self.makecheck = makecheck
134        self.last_elapsed = LastElapsedTime('.last-elapsed-cache', env)
135
136        assert color in ('auto', 'on', 'off')
137        self.color = (color == 'on') or (color == 'auto' and
138                                         sys.stdout.isatty())
139
140        self._stack: contextlib.ExitStack
141
142    def __enter__(self) -> 'TestRunner':
143        self._stack = contextlib.ExitStack()
144        self._stack.enter_context(self.env)
145        self._stack.enter_context(self.last_elapsed)
146        self._stack.enter_context(savetty())
147        return self
148
149    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
150        self._stack.close()
151
152    def test_print_one_line(self, test: str, starttime: str,
153                            endtime: Optional[str] = None, status: str = '...',
154                            lasttime: Optional[float] = None,
155                            thistime: Optional[float] = None,
156                            description: str = '',
157                            test_field_width: Optional[int] = None,
158                            end: str = '\n') -> None:
159        """ Print short test info before/after test run """
160        test = os.path.basename(test)
161
162        if test_field_width is None:
163            test_field_width = 8
164
165        if self.makecheck and status != '...':
166            if status and status != 'pass':
167                status = f' [{status}]'
168            else:
169                status = ''
170
171            print(f'  TEST   iotest-{self.env.imgfmt}: {test}{status}')
172            return
173
174        if lasttime:
175            lasttime_s = f' (last: {lasttime:.1f}s)'
176        else:
177            lasttime_s = ''
178        if thistime:
179            thistime_s = f'{thistime:.1f}s'
180        else:
181            thistime_s = '...'
182
183        if endtime:
184            endtime = f'[{endtime}]'
185        else:
186            endtime = ''
187
188        if self.color:
189            if status == 'pass':
190                col = '\033[32m'
191            elif status == 'fail':
192                col = '\033[1m\033[31m'
193            elif status == 'not run':
194                col = '\033[33m'
195            else:
196                col = ''
197
198            col_end = '\033[0m'
199        else:
200            col = ''
201            col_end = ''
202
203        print(f'{test:{test_field_width}} {col}{status:10}{col_end} '
204              f'[{starttime}] {endtime:13}{thistime_s:5} {lasttime_s:14} '
205              f'{description}', end=end)
206
207    def find_reference(self, test: str) -> str:
208        if self.env.cachemode == 'none':
209            ref = f'{test}.out.nocache'
210            if os.path.isfile(ref):
211                return ref
212
213        ref = f'{test}.out.{self.env.imgfmt}'
214        if os.path.isfile(ref):
215            return ref
216
217        ref = f'{test}.{self.env.qemu_default_machine}.out'
218        if os.path.isfile(ref):
219            return ref
220
221        return f'{test}.out'
222
223    def do_run_test(self, test: str) -> TestResult:
224        f_test = Path(test)
225        f_bad = Path(f_test.name + '.out.bad')
226        f_notrun = Path(f_test.name + '.notrun')
227        f_casenotrun = Path(f_test.name + '.casenotrun')
228        f_reference = Path(self.find_reference(test))
229
230        if not f_test.exists():
231            return TestResult(status='fail',
232                              description=f'No such test file: {f_test}')
233
234        if not os.access(str(f_test), os.X_OK):
235            sys.exit(f'Not executable: {f_test}')
236
237        if not f_reference.exists():
238            return TestResult(status='not run',
239                              description='No qualified output '
240                                          f'(expected {f_reference})')
241
242        for p in (f_bad, f_notrun, f_casenotrun):
243            silent_unlink(p)
244
245        args = [str(f_test.resolve())]
246        if self.env.debug:
247            args.append('-d')
248
249        with f_test.open(encoding="utf-8") as f:
250            try:
251                if f.readline().rstrip() == '#!/usr/bin/env python3':
252                    args.insert(0, self.env.python)
253            except UnicodeDecodeError:  # binary test? for future.
254                pass
255
256        env = os.environ.copy()
257        env.update(self.test_run_env)
258
259        t0 = time.time()
260        with f_bad.open('w', encoding="utf-8") as f:
261            proc = subprocess.Popen(args, cwd=str(f_test.parent), env=env,
262                                    stdout=f, stderr=subprocess.STDOUT)
263            try:
264                proc.wait()
265            except KeyboardInterrupt:
266                proc.terminate()
267                proc.wait()
268                return TestResult(status='not run',
269                                  description='Interrupted by user',
270                                  interrupted=True)
271            ret = proc.returncode
272
273        elapsed = round(time.time() - t0, 1)
274
275        if ret != 0:
276            return TestResult(status='fail', elapsed=elapsed,
277                              description=f'failed, exit status {ret}',
278                              diff=file_diff(str(f_reference), str(f_bad)))
279
280        if f_notrun.exists():
281            return TestResult(status='not run',
282                              description=f_notrun.read_text().strip())
283
284        casenotrun = ''
285        if f_casenotrun.exists():
286            casenotrun = f_casenotrun.read_text()
287
288        diff = file_diff(str(f_reference), str(f_bad))
289        if diff:
290            return TestResult(status='fail', elapsed=elapsed,
291                              description=f'output mismatch (see {f_bad})',
292                              diff=diff, casenotrun=casenotrun)
293        else:
294            f_bad.unlink()
295            self.last_elapsed.update(test, elapsed)
296            return TestResult(status='pass', elapsed=elapsed,
297                              casenotrun=casenotrun)
298
299    def run_test(self, test: str,
300                 test_field_width: Optional[int] = None) -> TestResult:
301        last_el = self.last_elapsed.get(test)
302        start = datetime.datetime.now().strftime('%H:%M:%S')
303
304        if not self.makecheck:
305            self.test_print_one_line(test=test, starttime=start,
306                                     lasttime=last_el, end='\r',
307                                     test_field_width=test_field_width)
308
309        res = self.do_run_test(test)
310
311        end = datetime.datetime.now().strftime('%H:%M:%S')
312        self.test_print_one_line(test=test, status=res.status,
313                                 starttime=start, endtime=end,
314                                 lasttime=last_el, thistime=res.elapsed,
315                                 description=res.description,
316                                 test_field_width=test_field_width)
317
318        if res.casenotrun:
319            print(res.casenotrun)
320
321        return res
322
323    def run_tests(self, tests: List[str]) -> bool:
324        n_run = 0
325        failed = []
326        notrun = []
327        casenotrun = []
328
329        if not self.makecheck:
330            self.env.print_env()
331            print()
332
333        test_field_width = max(len(os.path.basename(t)) for t in tests) + 2
334
335        for t in tests:
336            name = os.path.basename(t)
337            res = self.run_test(t, test_field_width=test_field_width)
338
339            assert res.status in ('pass', 'fail', 'not run')
340
341            if res.casenotrun:
342                casenotrun.append(t)
343
344            if res.status != 'not run':
345                n_run += 1
346
347            if res.status == 'fail':
348                failed.append(name)
349                if self.makecheck:
350                    self.env.print_env()
351                if res.diff:
352                    print('\n'.join(res.diff))
353            elif res.status == 'not run':
354                notrun.append(name)
355
356            if res.interrupted:
357                break
358
359        if notrun:
360            print('Not run:', ' '.join(notrun))
361
362        if casenotrun:
363            print('Some cases not run in:', ' '.join(casenotrun))
364
365        if failed:
366            print('Failures:', ' '.join(failed))
367            print(f'Failed {len(failed)} of {n_run} iotests')
368            return False
369        else:
370            print(f'Passed all {n_run} iotests')
371            return True
372