1# -*- coding: utf-8 -*-
2
3########################################################################
4#
5# License: BSD
6# Created: 2005-05-24
7# Author: Ivan Vilata i Balaguer - ivan@selidor.net
8#
9# $Id$
10#
11########################################################################
12
13"""Utilities for PyTables' test suites."""
14
15import os
16import re
17import sys
18import time
19import locale
20import platform
21import tempfile
22import warnings
23from distutils.version import LooseVersion
24
25from pkg_resources import resource_filename
26
27import unittest
28
29import numpy
30import numexpr
31
32import tables
33from tables.utils import detect_number_of_cores
34from tables.req_versions import min_blosc_bitshuffle_version
35
hdf5_version = LooseVersion(tables.hdf5_version)

# ``which_lib_version`` returns None for a library that is not available,
# so Blosc's version is only parsed when Blosc is present; otherwise
# ``blosc_version`` is None instead of the import of this module failing.
_blosc_info = tables.which_lib_version("blosc")
blosc_version = (LooseVersion(_blosc_info[1]) if _blosc_info is not None
                 else None)


verbose = False
"""Show detailed output of the testing process."""

heavy = False
"""Run all tests even when they take long to complete."""

show_memory = False
"""Show the progress of memory consumption."""
48
49
def parse_argv(argv):
    """Consume the test-suite flags found in *argv*.

    The first occurrence of each recognised flag is removed from *argv*
    in place and applied to the module-level switches:

    * ``verbose``  -> ``verbose = True``
    * ``silent``   -> ``verbose = False`` (legacy flag, checked after
      ``verbose`` so it wins when both are given)
    * ``--heavy``  -> ``heavy = True``

    Returns the same (possibly shortened) *argv* object.
    """
    global verbose, heavy

    def _take(flag):
        # Drop the first occurrence of *flag*; report whether it was there.
        found = flag in argv
        if found:
            argv.remove(flag)
        return found

    if _take('verbose'):
        verbose = True
    if _take('silent'):
        verbose = False
    if _take('--heavy'):
        heavy = True

    return argv
66
67
# Availability flags for the optional compression libraries:
# ``which_lib_version`` yields None for a library that is not available.
zlib_avail, lzo_avail, bzip2_avail, blosc_avail = (
    tables.which_lib_version(_libname) is not None
    for _libname in ("zlib", "lzo", "bzip2", "blosc"))
72
73
def print_heavy(heavy):
    """Announce whether the complete or the light test suite will run.

    The announcement is followed by a ``-=`` separator line.
    """
    if not heavy:
        banner = """\
Performing only a light (yet comprehensive) subset of the test suite.
If you want a more complete test, try passing the --heavy flag to this script
(or set the 'heavy' parameter in case you are using tables.test() call).
The whole suite will take more than 4 hours to complete on a relatively
modern CPU and around 512 MB of main memory."""
    else:
        banner = "Performing the complete test suite!"
    print(banner)
    print('-=' * 38)
85
86
def print_versions():
    """Print all the versions of software that PyTables relies on.

    Besides PyTables itself this reports HDF5, NumPy, Numexpr (and whether
    it uses Intel's VML/MKL), every optional compression library that
    ``tables.which_lib_version`` reports as available, Cython (only when it
    can be imported), Python and some platform and encoding details.  The
    report is framed by ``-=`` separator lines and stdout is flushed
    afterwards.
    """

    print('-=' * 38)
    print("PyTables version:    %s" % tables.__version__)
    print("HDF5 version:        %s" % tables.which_lib_version("hdf5")[1])
    print("NumPy version:       %s" % numpy.__version__)

    if numexpr.use_vml:
        # Get only the main version number and strip out all the rest;
        # keep the raw string if it contains no version number at all.
        vml_version = numexpr.get_vml_version()
        match = re.search(r"[0-9.]+", vml_version)
        if match is not None:
            vml_version = match.group(0)
        vml_avail = "using VML/MKL %s" % vml_version
    else:
        vml_avail = "not using Intel's VML/MKL"
    print("Numexpr version:     %s (%s)" % (numexpr.__version__, vml_avail))

    tinfo = tables.which_lib_version("zlib")
    if tinfo is not None:
        print("Zlib version:        %s (%s)" % (tinfo[1],
                                                "in Python interpreter"))
    tinfo = tables.which_lib_version("lzo")
    if tinfo is not None:
        print("LZO version:         %s (%s)" % (tinfo[1], tinfo[2]))
    tinfo = tables.which_lib_version("bzip2")
    if tinfo is not None:
        print("BZIP2 version:       %s (%s)" % (tinfo[1], tinfo[2]))
    tinfo = tables.which_lib_version("blosc")
    if tinfo is not None:
        # The second whitespace-separated word of tinfo[2] is shown as the
        # Blosc release date.
        blosc_date = tinfo[2].split()[1]
        print("Blosc version:       %s (%s)" % (tinfo[1], blosc_date))
        blosc_cinfo = tables.blosc_get_complib_info()
        blosc_cinfo = [
            "%s (%s)" % (k, v[1]) for k, v in sorted(blosc_cinfo.items())
        ]
        print("Blosc compressors:   %s" % ', '.join(blosc_cinfo))
        blosc_finfo = ['shuffle']
        # Compare parsed versions: a plain string comparison would rank
        # e.g. "1.10.0" below "1.8.0" and hide bitshuffle support.
        if LooseVersion(tinfo[1]) >= min_blosc_bitshuffle_version:
            blosc_finfo.append('bitshuffle')
        print("Blosc filters:       %s" % ', '.join(blosc_finfo))
    try:
        from Cython import __version__ as cython_version
    except ImportError:
        pass  # Cython is optional: simply omit its line from the report
    else:
        print('Cython version:      %s' % cython_version)
    print('Python version:      %s' % sys.version)
    print('Platform:            %s' % platform.platform())
    print('Byte-ordering:       %s' % sys.byteorder)
    print('Detected cores:      %s' % detect_number_of_cores())
    print('Default encoding:    %s' % sys.getdefaultencoding())
    print('Default FS encoding: %s' % sys.getfilesystemencoding())
    print('Default locale:      (%s, %s)' % locale.getdefaultlocale())
    print('-=' * 38)

    # This should improve readability when tests are run by CI tools
    sys.stdout.flush()
144
145
def test_filename(filename):
    """Return a filesystem path for *filename* bundled in ``tables.tests``.

    The lookup is delegated to ``pkg_resources.resource_filename``.
    """
    package = 'tables.tests'
    path = resource_filename(package, filename)
    return path
148
149
def verbosePrint(string, nonl=False):
    """Echo *string* to stdout, but only while verbose mode is on.

    When *nonl* is true the trailing newline is replaced by a single space,
    so the next output continues on the same line.
    """
    if verbose:
        print(string, end=' ' if nonl else '\n')
158
159
def allequal(a, b, flavor="numpy"):
    """Checks if two numerical objects are equal.

    Scalars are compared with ``==`` and that result is returned unchanged.
    This applies when *b* has no ``shape``, or when both objects have an
    empty ``shape`` (an object without ``shape`` counts as empty).

    Otherwise both must have the same shape.  When *b* has a ``dtype``,
    the dtypes must also match, ignoring byte order.  The elements are
    then compared element-wise.

    Parameters
    ----------
    a, b
        The objects to compare.
    flavor
        Unused; accepted only for backward compatibility.

    Returns
    -------
    A true value when the objects are equal, a false one otherwise.  If
    the module-level ``verbose`` flag is set, the reason for a mismatch is
    printed.
    """
    if not hasattr(b, "shape"):
        # Scalar case
        return a == b

    # An object without a shape is treated as a scalar (empty shape).
    a_shape = getattr(a, "shape", ())

    if a_shape == () and b.shape == ():
        return a == b

    if a_shape != b.shape:
        if verbose:
            print("Shape is not equal:", a_shape, "!=", b.shape)
        return False

    # Way to check the type equality without byteorder considerations
    if hasattr(b, "dtype") and a.dtype.str[1:] != b.dtype.str[1:]:
        if verbose:
            print("dtype is not equal:", a.dtype, "!=", b.dtype)
        return False

    # Null arrays: the shapes are equal, so both are empty together
    # (len(a) is not correct for generic shapes, hence ``size``).
    if a.size == 0:
        return True

    # Multidimensional case
    result = numpy.all(a == b)
    if not result and verbose:
        print("Some of the elements in arrays are not equal")

    return result
211
212
def areArraysEqual(arr1, arr2):
    """Are both `arr1` and `arr2` equal arrays?

    Arguments can be regular NumPy arrays, chararray arrays or
    structured arrays (including structured record arrays).

    The arguments must first be type-compatible.  That holds when both
    have a ``dtype`` and the dtypes are equal, or when the type of one
    argument is a subclass of the other's type.  Note that two plain
    ndarrays of different dtypes still go on to the value comparison.

    When both arguments have a ``shape``, the shapes must match.  This
    stops NumPy broadcasting from making e.g. ``[1]`` and ``[1, 1]``
    compare equal.  Finally all elements must compare equal.
    """

    t1 = type(arr1)
    t2 = type(arr2)

    # Only compare dtypes when *both* sides have one (comparing a dtype
    # against None would be coerced by NumPy instead of failing).
    same_dtype = (hasattr(arr1, 'dtype') and hasattr(arr2, 'dtype') and
                  arr1.dtype == arr2.dtype)
    if not (same_dtype or issubclass(t1, t2) or issubclass(t2, t1)):
        return False

    if (hasattr(arr1, 'shape') and hasattr(arr2, 'shape') and
            arr1.shape != arr2.shape):
        return False

    return numpy.all(arr1 == arr2)
230
231
# COMPATIBILITY: assertWarns is new in Python 3.2
# Code copied from the standard unittest.case module (Python 3.4)
# The helper classes below are only defined when unittest.TestCase has no
# assertWarns of its own; PyTablesTestCase.assertWarns relies on
# _AssertWarnsContext in that case.
if not hasattr(unittest.TestCase, 'assertWarns'):
    class _BaseTestCaseContext:
        # Common base of the assertion context managers: keeps a reference
        # to the running TestCase so failures use its message formatting
        # and its failureException.

        def __init__(self, test_case):
            self.test_case = test_case

        def _raiseFailure(self, standardMsg):
            # Merge the optional user message (self.msg, set by subclasses)
            # with standardMsg and raise the test case's failureException.
            msg = self.test_case._formatMessage(self.msg, standardMsg)
            raise self.test_case.failureException(msg)

    class _AssertRaisesBaseContext(_BaseTestCaseContext):
        # Shared state for assertRaises/assertWarns-style context managers.

        def __init__(self, expected, test_case, callable_obj=None,
                     expected_regex=None):
            _BaseTestCaseContext.__init__(self, test_case)
            self.expected = expected
            self.test_case = test_case
            # Name of the callable under test, used in failure messages.
            if callable_obj is not None:
                try:
                    self.obj_name = callable_obj.__name__
                except AttributeError:
                    self.obj_name = str(callable_obj)
            else:
                self.obj_name = None
            # Optional pattern the warning text must match.
            if expected_regex is not None:
                expected_regex = re.compile(expected_regex)
            self.expected_regex = expected_regex
            self.msg = None

        def handle(self, name, callable_obj, args, kwargs):
            """
            If callable_obj is None, assertRaises/Warns is being used as a
            context manager, so check for a 'msg' kwarg and return self.
            If callable_obj is not None, call it passing args and kwargs.
            """
            if callable_obj is None:
                self.msg = kwargs.pop('msg', None)
                return self
            with self:
                callable_obj(*args, **kwargs)

    class _AssertWarnsContext(_AssertRaisesBaseContext):
        # Context manager that fails unless a warning of type self.expected
        # (optionally matching self.expected_regex) is issued inside it.

        def __enter__(self):
            # Clear every module's __warningregistry__ so that warnings
            # which were already issued once are reported again.
            for v in list(sys.modules.values()):
                if getattr(v, '__warningregistry__', None):
                    v.__warningregistry__ = {}
            # Record all warnings and never filter out the expected category.
            self.warnings_manager = warnings.catch_warnings(record=True)
            self.warnings = self.warnings_manager.__enter__()
            warnings.simplefilter("always", self.expected)
            return self

        def __exit__(self, exc_type, exc_value, tb):
            self.warnings_manager.__exit__(exc_type, exc_value, tb)
            if exc_type is not None:
                # let unexpected exceptions pass through
                return
            try:
                exc_name = self.expected.__name__
            except AttributeError:
                exc_name = str(self.expected)
            # Look for the first recorded warning of the expected type that
            # also matches the regex (when one was given).
            first_matching = None
            for m in self.warnings:
                w = m.message
                if not isinstance(w, self.expected):
                    continue
                if first_matching is None:
                    first_matching = w
                if (self.expected_regex is not None and
                        not self.expected_regex.search(str(w))):
                    continue
                # store warning for later retrieval
                self.warning = w
                self.filename = m.filename
                self.lineno = m.lineno
                return
            # Now we simply try to choose a helpful failure message.
            # Reaching here with first_matching set means a warning of the
            # right type was seen, so a regex was given and did not match.
            if first_matching is not None:
                self._raiseFailure(
                    '"{0}" does not match "{1}"'.format(
                        self.expected_regex.pattern, str(first_matching)))
            if self.obj_name:
                self._raiseFailure("{0} not triggered by {1}".format(
                                   exc_name, self.obj_name))
            else:
                self._raiseFailure("{0} not triggered".format(exc_name))
317
318
class PyTablesTestCase(unittest.TestCase):
    """Base class for PyTables test cases.

    Adds helpers to print verbose headers and to compare groups and leaves.
    After each test it drops the references the instance holds.
    """

    def tearDown(self):
        """Reset every instance attribute to None, except bound methods.

        Runs after the parent ``tearDown``.  Afterwards the finished test
        case no longer references the objects stored on ``self``.
        """
        super(PyTablesTestCase, self).tearDown()
        # Bound methods are named 'instancemethod' on Python 2 and 'method'
        # on Python 3.  This must be a tuple: with a bare string, ``not in``
        # would become a substring test and spare e.g. a class 'stance'.
        keep = ('instancemethod', 'method')
        for key, value in list(self.__dict__.items()):
            if value.__class__.__name__ not in keep:
                self.__dict__[key] = None

    def _getName(self):
        """Get the name of this test case (the class part of ``id()``)."""
        return self.id().split('.')[-2]

    def _getMethodName(self):
        """Get the name of the method currently running in the test case."""
        return self.id().split('.')[-1]

    def _verboseHeader(self):
        """Print a nice header for the current test method if verbose."""

        if verbose:
            name = self._getName()
            methodName = self._getMethodName()

            title = "Running %s.%s" % (name, methodName)
            print('%s\n%s' % (title, '-' * len(title)))

    # COMPATIBILITY: assertWarns is new in Python 3.2
    if not hasattr(unittest.TestCase, 'assertWarns'):
        def assertWarns(self, expected_warning, callable_obj=None,
                        *args, **kwargs):
            context = _AssertWarnsContext(expected_warning, self, callable_obj)
            return context.handle('assertWarns', callable_obj, args, kwargs)

    def _checkPathnames(self, node1, node2, hardlink):
        """Assert the expected pathname relation between two nodes.

        With *hardlink* the two pathnames must differ; otherwise they must
        be equal.
        """
        if hardlink:
            self.assertNotEqual(
                node1._v_pathname, node2._v_pathname,
                "node1 and node2 have the same pathnames.")
        else:
            self.assertEqual(
                node1._v_pathname, node2._v_pathname,
                "node1 and node2 does not have the same pathnames.")

    def _checkEqualityGroup(self, node1, node2, hardlink=False):
        """Assert that two groups match in pathname and children."""
        if verbose:
            print("Group 1:", node1)
            print("Group 2:", node2)
        self._checkPathnames(node1, node2, hardlink)
        self.assertEqual(
            node1._v_children, node2._v_children,
            "node1 and node2 does not have the same children.")

    def _checkEqualityLeaf(self, node1, node2, hardlink=False):
        """Assert that two leaves match in pathname and stored values."""
        if verbose:
            print("Leaf 1:", node1)
            print("Leaf 2:", node2)
        self._checkPathnames(node1, node2, hardlink)
        self.assertTrue(
            areArraysEqual(node1[:], node2[:]),
            "node1 and node2 does not have the same values.")
382
383
class TestFileMixin(object):
    """Mixin that keeps the HDF5 file ``h5fname`` open during each test.

    Subclasses set ``h5fname`` and may pass extra keyword arguments to
    ``tables.open_file`` through ``open_kwargs``.  ``_getName``, whose value
    becomes the file title, must come from another base class such as
    ``PyTablesTestCase``.
    """

    # Path of the HDF5 file to open; None here, subclasses are expected to
    # set it.
    h5fname = None
    # Extra keyword arguments forwarded to ``tables.open_file``.
    open_kwargs = {}

    def setUp(self):
        """Open ``h5fname`` and keep the handle as ``self.h5file``."""
        super(TestFileMixin, self).setUp()
        file_title = self._getName()
        self.h5file = tables.open_file(self.h5fname, title=file_title,
                                       **self.open_kwargs)

    def tearDown(self):
        """Close ``h5file``, then continue with the next ``tearDown``."""
        handle = self.h5file
        handle.close()
        super(TestFileMixin, self).tearDown()
398
399
class TempFileMixin(object):
    """Mixin that gives each test a brand-new temporary HDF5 file.

    ``setUp`` opens the file as ``h5file``, using mode ``open_mode`` and the
    extra keyword arguments ``open_kwargs``.  ``tearDown`` closes and
    deletes it.  ``_getName`` must be provided by another base class, such
    as ``PyTablesTestCase``.
    """

    # Mode passed to ``tables.open_file`` in ``setUp``.
    open_mode = 'w'
    # Extra keyword arguments for ``tables.open_file``.  The dict is only
    # read, so sharing it is harmless; subclasses should override it rather
    # than mutate it.
    open_kwargs = {}

    def _getTempFileName(self):
        """Return a currently unused ``.h5`` path prefixed by the test name."""
        # NOTE(review): tempfile.mktemp only returns a name and is race
        # prone.  mkstemp would pre-create an empty file, which may not suit
        # every ``open_mode``; verify before switching.
        return tempfile.mktemp(prefix=self._getName(), suffix='.h5')

    def setUp(self):
        """Set ``h5file`` and ``h5fname`` instance attributes.

        * ``h5fname``: the name of the temporary HDF5 file.
        * ``h5file``: the writable, empty, temporary HDF5 file.

        """

        super(TempFileMixin, self).setUp()
        self.h5fname = self._getTempFileName()
        self.h5file = tables.open_file(
            self.h5fname, self.open_mode, title=self._getName(),
            **self.open_kwargs)

    def tearDown(self):
        """Close ``h5file`` and remove ``h5fname``.

        The file removal and the parent ``tearDown`` run even when closing
        the file raises.  A failing close therefore neither leaks the
        temporary file nor skips the base-class cleanup.
        """
        try:
            self.h5file.close()
            self.h5file = None
        finally:
            try:
                os.remove(self.h5fname)  # comment this for debugging purposes only
            finally:
                super(TempFileMixin, self).tearDown()

    def _reopen(self, mode='r', **kwargs):
        """Reopen ``h5file`` in the specified ``mode``.

        The current ``h5file`` is closed and ``h5fname`` is opened again
        with *mode* and any extra keyword arguments for
        ``tables.open_file``.  Always returns True.

        """

        self.h5file.close()
        self.h5file = tables.open_file(self.h5fname, mode, **kwargs)
        return True
440
441
class ShowMemTime(PyTablesTestCase):
    """Test for showing memory and time consumption."""

    # Reference time, taken once when this class body is executed (i.e. at
    # module import); test00 reports the wall-clock time elapsed since then.
    tref = time.time()

    # /proc/self/status fields reported by test00.
    _vm_fields = ("VmSize", "VmRSS", "VmData", "VmStk", "VmExe", "VmLib")

    def _getVmStatus(self):
        """Return the ``_vm_fields`` counters of /proc/self/status as a dict.

        Keys are the field names without the colon (e.g. ``'VmRSS'``).
        Values are the first number on that line.  A field that is missing
        from the file is simply absent from the dict.  The test is skipped
        when the file cannot be read (procfs is a Linux feature).
        """
        try:
            with open("/proc/self/status") as status:
                lines = status.readlines()
        except (IOError, OSError):
            self.skipTest("/proc/self/status is not available")
        counters = {}
        for line in lines:
            name, sep, rest = line.partition(":")
            fields = rest.split()
            if sep and name in self._vm_fields and fields:
                counters[name] = int(fields[0])
        return counters

    def test00(self):
        """Showing memory and time consumption."""

        # Obtain memory info (only on Linux, through procfs)
        vm = self._getVmStatus()
        print("\nWallClock time:", time.time() - self.tref)
        print("Memory usage: ******* %s *******" % self._getName())
        print("VmSize: %7s kB\tVmRSS: %7s kB" % (vm.get("VmSize"),
                                                 vm.get("VmRSS")))
        print("VmData: %7s kB\tVmStk: %7s kB" % (vm.get("VmData"),
                                                 vm.get("VmStk")))
        print("VmExe:  %7s kB\tVmLib: %7s kB" % (vm.get("VmExe"),
                                                 vm.get("VmLib")))
468
469
470## Local Variables:
471## mode: python
472## py-indent-offset: 4
473## tab-width: 4
474## fill-column: 72
475## End:
476