1# This file is part of h5py, a Python interface to the HDF5 library.
2#
3# http://www.h5py.org
4#
5# Copyright 2008-2013 Andrew Collette and contributors
6#
7# License:  Standard 3-clause BSD; see "license.txt" for full license terms
8#           and contributor agreement.
9
10import sys
11import os
12import shutil
13import inspect
14import tempfile
15import subprocess
16from contextlib import contextmanager
17from functools import wraps
18
19import numpy as np
20import h5py
21
22import unittest as ut
23
24
# Probe whether the platform can create files with non-ASCII names.
# Actually trying to create one is evidently the most reliable check;
# see also h5py issue #263 and ipython #466.
# To exercise the "unsupported" path, run the testsuite with LC_ALL=C.
UNICODE_FILENAMES = True
try:
    # Greek small letter eta, used as the suffix of the temporary file name.
    probe_fd, probe_path = tempfile.mkstemp(suffix="\u03b7")
except UnicodeError:
    UNICODE_FILENAMES = False
else:
    # The probe file itself is not needed; remove it and the helper names.
    os.close(probe_fd)
    os.unlink(probe_path)
    del probe_fd, probe_path
39
40
class TestCase(ut.TestCase):

    """
        Base class for unit tests.

        setUpClass creates a scratch directory (``cls.tempdir``) that is
        removed again in tearDownClass; setUp opens a new writable HDF5 file
        inside it as ``self.f``, which tearDown closes.
    """

    @classmethod
    def setUpClass(cls):
        """ Create the scratch directory shared by the tests of this class. """
        cls.tempdir = tempfile.mkdtemp(prefix='h5py-test_')

    @classmethod
    def tearDownClass(cls):
        """ Remove the scratch directory and everything in it. """
        shutil.rmtree(cls.tempdir)

    def mktemp(self, suffix='.hdf5', prefix='', dir=None):
        """ Return a fresh file name (the file is not created).

            The name lives in `dir`, defaulting to the class scratch directory.
        """
        if dir is None:
            dir = self.tempdir
        return tempfile.mktemp(suffix, prefix, dir=dir)

    def mktemp_mpi(self, comm=None, suffix='.hdf5', prefix='', dir=None):
        """ Like mktemp(), but every rank of `comm` gets the same name.

            Rank 0 picks the name and broadcasts it.  `comm` defaults to
            MPI.COMM_WORLD; mpi4py is only imported in that case.
        """
        if comm is None:
            from mpi4py import MPI
            comm = MPI.COMM_WORLD
        fname = None
        if comm.Get_rank() == 0:
            fname = self.mktemp(suffix, prefix, dir)
        fname = comm.bcast(fname, 0)
        return fname

    def setUp(self):
        """ Open a new writable file for the test as ``self.f``. """
        self.f = h5py.File(self.mktemp(), 'w')

    def tearDown(self):
        """ Close ``self.f`` on a best-effort basis. """
        try:
            if self.f:
                self.f.close()
        except Exception:
            # Deliberately best-effort: a subclass may not have set self.f,
            # or the test may already have closed/invalidated the file.
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            pass

    def assertSameElements(self, a, b):
        """ Fail unless every item of `a` equals some item of `b` and
            vice versa (order and multiplicity are ignored).
        """
        for x in a:
            if not any(x == y for y in b):
                # (x,) keeps tuple items from being unpacked by "%".
                raise AssertionError("Item '%s' appears in a but not b" % (x,))

        for x in b:
            if not any(x == y for y in a):
                raise AssertionError("Item '%s' appears in b but not a" % (x,))

    def assertArrayEqual(self, dset, arr, message=None, precision=None):
        """ Make sure dset and arr have the same shape, dtype and contents, to
            within the given precision.

            Note that dset may be a NumPy array or an HDF5 dataset.

            `message` is appended (in parentheses) to any failure message;
            `precision` is the absolute tolerance and defaults to 1e-5.
            Compound dtypes are compared field by field; dtypes other than
            integer/float kinds are compared for exact equality.
        """
        if precision is None:
            precision = 1e-5
        if message is None:
            message = ''
        else:
            message = ' (%s)' % message

        if np.isscalar(dset) or np.isscalar(arr):
            assert np.isscalar(dset) and np.isscalar(arr), \
                'Scalar/array mismatch ("%r" vs "%r")%s' % (dset, arr, message)
            # Compare the magnitude of the difference; a signed difference
            # would let any dset far below arr pass.
            assert abs(dset - arr) < precision, \
                "Scalars differ by more than %g%s" % (precision, message)
            return

        assert dset.shape == arr.shape, \
            "Shape mismatch (%s vs %s)%s" % (dset.shape, arr.shape, message)
        assert dset.dtype == arr.dtype, \
            "Dtype mismatch (%s vs %s)%s" % (dset.dtype, arr.dtype, message)

        if arr.dtype.names is not None:
            for n in arr.dtype.names:
                # Build the label per field so that prefixes of earlier
                # fields do not pile up in later failure messages.
                field_message = '[FIELD %s] %s' % (n, message)
                self.assertArrayEqual(dset[n], arr[n], message=field_message,
                                      precision=precision)
        elif arr.dtype.kind in ('i', 'f'):
            assert np.all(np.abs(dset[...] - arr[...]) < precision), \
                "Arrays differ by more than %g%s" % (precision, message)
        else:
            assert np.all(dset[...] == arr[...]), \
                "Arrays are not equal (dtype %s) %s" % (arr.dtype.str, message)

    def assertNumpyBehavior(self, dset, arr, s, skip_fast_reader=False):
        """ Apply slicing arguments "s" to both dset and arr.

        Succeeds if the results of the slicing are identical, or the
        exception raised is of the same type for both.

        "arr" must be a Numpy array; "dset" may be a NumPy array or dataset.

        Unless "skip_fast_reader" is true, the same check is repeated through
        ``dset._fast_reader.read``, which is given "s" as a tuple.
        """
        # Record what NumPy does: either a result or the exception type.
        exc = None
        try:
            arr_result = arr[s]
        except Exception as e:
            exc = type(e)

        s_fast = s if isinstance(s, tuple) else (s,)

        if exc is None:
            self.assertArrayEqual(dset[s], arr_result)

            if not skip_fast_reader:
                self.assertArrayEqual(
                    dset._fast_reader.read(s_fast),
                    arr_result,
                )
        else:
            with self.assertRaises(exc):
                dset[s]

            if not skip_fast_reader:
                with self.assertRaises(exc):
                    dset._fast_reader.read(s_fast)
164
# (major, minor) of the installed NumPy, as a tuple of ints.
NUMPY_RELEASE_VERSION = tuple(
    int(part) for part in np.__version__.split(".")[:2]
)
166
@contextmanager
def closed_tempfile(suffix='', text=None):
    """
    Context manager which yields the path to a closed temporary file with the
    suffix `suffix`. The file will be deleted on exiting the context, whether
    the body finishes normally or raises. An additional argument `text` can be
    provided to have the file contain `text`.
    """
    # delete=False: the file must outlive the handle so the caller can reopen
    # it by name; leaving the "with" block closes (and thereby flushes) it.
    with tempfile.NamedTemporaryFile(
        'w+t', suffix=suffix, delete=False
    ) as test_file:
        file_name = test_file.name
        if text is not None:
            test_file.write(text)
    try:
        yield file_name
    finally:
        # The path is a regular file, so it has to be unlinked;
        # shutil.rmtree() only removes directories and would leave it behind.
        try:
            os.unlink(file_name)
        except OSError:
            # Best-effort cleanup: the caller may already have removed it.
            pass
183
184
def insubprocess(f):
    """Runs a test in its own subprocess

    The wrapped test must take the pytest ``request`` fixture as its first
    argument. An environment variable derived from the test's source file and
    node name tells the two roles apart: when it is set, the test body runs
    directly; otherwise pytest is launched on just this test in a child
    process and the wrapper asserts that it exited with status 0, attaching
    the child's combined output to the failure message.
    """
    @wraps(f)
    def wrapper(request, *args, **kwargs):
        test_id = inspect.getsourcefile(f) + "::" + request.node.name
        # Flatten path and node separators so the id can serve as a variable name.
        marker = ("IN_SUBPROCESS_" + test_id).translate(
            str.maketrans("/\\,:.", "_____")
        )

        if os.environ.get(marker, None):
            # Marker is set: run the real test body.
            return f(request, *args, **kwargs)

        # Note that the marker is also set in the current process's environment.
        os.environ[marker] = '1'
        child_env = os.environ.copy()
        child_env[marker] = '1'
        # Extra variables attached to the test via @subproc_env, if any.
        child_env.update(getattr(f, 'subproc_env', {}))

        command = [sys.executable, '-m', 'pytest', test_id]
        with closed_tempfile() as log_path:
            with open(log_path, 'w+t') as log:
                status = subprocess.call(command, stdout=log, stderr=log,
                                         env=child_env)
            with open(log_path, 'rt') as log:
                captured = log.read()

        assert status == 0, "\n" + captured
    return wrapper
213
214
def subproc_env(d):
    """Set environment variables for the @insubprocess decorator

    Returns a decorator that stores the mapping `d` on the decorated function
    as its ``subproc_env`` attribute and hands the function back unchanged.
    """
    def attach(func):
        setattr(func, 'subproc_env', d)
        return func

    return attach
222