9""" Validate image proxy API
11Minimum array proxy API is:
13* read only ``shape`` property
14* read only ``is_proxy`` property set to True
15* returns array from ``np.asarray(prox)``
16* returns array slice from ``prox[<slice_spec>]`` where ``<slice_spec>`` is any
17  non-fancy slice specification.
* that modifying any object outside ``prox`` will not affect the result of
  ``np.asarray(prox)``.  Specifically:
  * Changes in position (``obj.tell()``) of any passed file-like objects
    will not affect the output from ``np.asarray(prox)``.
25  * if you pass a header into the __init__, then modifying the original
26    header will not affect the result of the array return.
28These last are to allow the proxy to be re-used with different images.
31from os.path import join as pjoin
32import warnings
33from itertools import product
34from io import BytesIO
36import numpy as np
38from ..volumeutils import apply_read_scaling
39from ..analyze import AnalyzeHeader
40from ..spm99analyze import Spm99AnalyzeHeader
41from ..spm2analyze import Spm2AnalyzeHeader
42from ..nifti1 import Nifti1Header
43from ..freesurfer.mghformat import MGHHeader
44from .. import minc1
45from ..externals.netcdf import netcdf_file
46from .. import minc2
47from .._h5py_compat import h5py, have_h5py
48from .. import ecat
49from .. import parrec
50from ..casting import have_binary128
52from ..arrayproxy import ArrayProxy, is_proxy
54import unittest
55import pytest
56from numpy.testing import assert_almost_equal, assert_array_equal, assert_allclose
58from ..testing import data_path as DATA_PATH, assert_dt_equal, clear_and_catch_warnings
59from ..deprecator import ExpiredDeprecationError
61from ..tmpdirs import InTemporaryDirectory
63from .test_api_validators import ValidateAPI
64from .test_parrec import EG_REC, VARY_REC
67def _some_slicers(shape):
68    ndim = len(shape)
69    slicers = np.eye(ndim).astype(int).astype(object)
70    slicers[slicers == 0] = slice(None)
71    for i in range(ndim):
72        if i % 2:
73            slicers[i, i] = -1
74        elif shape[i] < 2:  # some proxy examples have length 1 axes
75            slicers[i, i] = 0
76    # Add a newaxis to keep us on our toes
77    no_pos = ndim // 2
78    slicers = np.hstack((slicers[:, :no_pos],
79                         np.empty((ndim, 1)),
80                         slicers[:, no_pos:]))
81    slicers[:, no_pos] = None
82    return [tuple(s) for s in slicers]
class _TestProxyAPI(ValidateAPI):
    """ Base class for testing proxy APIs

    Assumes that real classes will provide an `obj_params` method which is a
    generator returning 2 tuples of (<proxy_maker>, <param_dict>).
    <proxy_maker> is a function returning a 3 tuple of (<proxy>, <fileobj>,
    <header>).  <param_dict> is a dictionary containing at least keys
    ``arr_out`` (expected output array from proxy), ``dtype_out`` (expected
    output dtype for array) and ``shape`` (shape of array).

    The <header> above should support at least "get_data_dtype",
    "set_data_dtype", "get_data_shape", "set_data_shape"
    """
    # Flag True if offset can be set into header of image
    settable_offset = False

    def validate_shape(self, pmaker, params):
        # Check shape
        prox, fio, hdr = pmaker()
        assert_array_equal(prox.shape, params['shape'])
        # Read only
        with pytest.raises(AttributeError):
            prox.shape = params['shape']

    def validate_ndim(self, pmaker, params):
        # Check number of dimensions matches length of expected shape
        prox, fio, hdr = pmaker()
        assert prox.ndim == len(params['shape'])
        # Read only
        with pytest.raises(AttributeError):
            prox.ndim = len(params['shape'])

    def validate_is_proxy(self, pmaker, params):
        # Check proxy is flagged as a proxy, and that a plain array is not
        prox, fio, hdr = pmaker()
        assert prox.is_proxy
        assert is_proxy(prox)
        assert not is_proxy(np.arange(10))
        # Read only
        with pytest.raises(AttributeError):
            prox.is_proxy = False

    def validate_asarray(self, pmaker, params):
        # Check proxy returns expected array from asarray
        prox, fio, hdr = pmaker()
        out = np.asarray(prox)
        assert_array_equal(out, params['arr_out'])
        assert_dt_equal(out.dtype, params['dtype_out'])
        # Shape matches expected shape
        assert out.shape == params['shape']

    def validate_array_interface_with_dtype(self, pmaker, params):
        # Check proxy can be coerced to every float / int / uint scalar type,
        # both by calling the type directly and through the numpy array
        # constructors with a ``dtype`` argument
        from contextlib import ExitStack

        prox, fio, hdr = pmaker()
        orig = np.array(prox, dtype=None)
        assert_array_equal(orig, params['arr_out'])
        assert_dt_equal(orig.dtype, params['dtype_out'])

        # The ExitStack restores the warnings state on the way out, including
        # when one of the assertions below fails; entering and exiting the
        # warnings context by hand would skip the exit on failure.
        with ExitStack() as stack:
            if np.issubdtype(orig.dtype, np.complexfloating):
                # Complex data is cast to real types below; silence the
                # resulting ComplexWarning for the duration of the checks
                stack.enter_context(clear_and_catch_warnings())
                warnings.simplefilter('ignore', np.ComplexWarning)

            for dtype in np.sctypes['float'] + np.sctypes['int'] + np.sctypes['uint']:
                # Directly coerce with a dtype
                direct = dtype(prox)
                # Half-precision is imprecise. Obviously. It's a bad idea, but
                # don't break the test over it.
                rtol = 1e-03 if dtype == np.float16 else 1e-05
                assert_allclose(direct, orig.astype(dtype), rtol=rtol, atol=1e-08)
                assert_dt_equal(direct.dtype, np.dtype(dtype))
                assert direct.shape == params['shape']
                # All three methods should produce equivalent results
                for arrmethod in (np.array, np.asarray, np.asanyarray):
                    out = arrmethod(prox, dtype=dtype)
                    assert_array_equal(out, direct)
                    assert_dt_equal(out.dtype, np.dtype(dtype))
                    # Shape matches expected shape
                    assert out.shape == params['shape']

    def validate_header_isolated(self, pmaker, params):
        # Confirm altering input header has no effect
        # Depends on header providing 'get_data_dtype', 'set_data_dtype',
        # 'get_data_shape', 'set_data_shape', 'set_data_offset'
        prox, fio, hdr = pmaker()
        assert_array_equal(prox, params['arr_out'])
        # Mess up header badly and hope for same correct result
        if hdr.get_data_dtype() == np.uint8:
            hdr.set_data_dtype(np.int16)
        else:
            hdr.set_data_dtype(np.uint8)
        hdr.set_data_shape(np.array(hdr.get_data_shape()) + 1)
        if self.settable_offset:
            hdr.set_data_offset(32)
        assert_array_equal(prox, params['arr_out'])

    def validate_fileobj_isolated(self, pmaker, params):
        # Check file position of read independent of file-like object
        prox, fio, hdr = pmaker()
        if isinstance(fio, str):
            # A filename has no file position to disturb
            return
        assert_array_equal(prox, params['arr_out'])
        fio.read()  # move to end of file
        assert_array_equal(prox, params['arr_out'])

    def validate_proxy_slicing(self, pmaker, params):
        # Confirm that proxy object can be sliced correctly
        arr = params['arr_out']
        shape = arr.shape
        prox, fio, hdr = pmaker()
        for sliceobj in _some_slicers(shape):
            assert_array_equal(arr[sliceobj], prox[sliceobj])
class TestAnalyzeProxyAPI(_TestProxyAPI):
    """ Array proxy API tests for Analyze-type headers

    On top of the general API, the analyze proxy has the read-only attributes
    ``slope, inter, offset``
    """
    proxy_class = ArrayProxy
    header_class = AnalyzeHeader
    shapes = ((2,), (2, 3), (2, 3, 4), (2, 3, 4, 5))
    data_dtypes = (np.uint8, np.int16, np.int32, np.float32, np.complex64, np.float64)
    array_order = 'F'
    has_slope = False
    has_inter = False
    # Offset cannot be set for Freesurfer; it can for the headers tested here
    settable_offset = True
    # '=' means use native byte order.  Freesurfer enforces big-endian
    data_endian = '='

    def obj_params(self):
        """ Generate (``proxy_creator``, ``proxy_params``) pairs

        Every pair is tested separately.

        ``proxy_creator`` takes no arguments and returns a tuple of (fresh
        proxy object, fileobj, header).  A creator function is handed out,
        rather than a proxy instance, so that each of the autogenerated
        ``validate_xxx`` test methods can build its own proxy and modify it
        without affecting the tests that run after it.
        """
        # Analyze and up wrap binary arrays, Fortran ordered, with given offset
        # and dtype and shape.
        if self.settable_offset:
            offsets = (0, 16)
        else:
            offsets = (self.header_class().get_data_offset(),)
        # Non-integral parameters are first rounded to float32 precision, so
        # that a later cast to float32 is lossless and checks can be exact,
        # then converted back to Python float for consistency
        slopes = (1.,)
        if self.has_slope:
            slopes = (1., 2., float(np.float32(3.1416)))
        inters = (0.,)
        if self.has_inter:
            inters = (0., 10., float(np.float32(2.7183)))
        combinations = product(self.shapes, self.data_dtypes, offsets, slopes, inters)
        for shape, dtype_spec, offset, slope, inter in combinations:
            dtype = np.dtype(dtype_spec).newbyteorder(self.data_endian)
            arr = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
            data = arr.tobytes(order=self.array_order)
            hdr = self.header_class()
            hdr.set_data_dtype(dtype)
            hdr.set_data_shape(shape)
            if self.settable_offset:
                hdr.set_data_offset(offset)
            if slope == 1 and inter == 0:
                # No scaling applied: output dtype is that of the array
                dtype_out = dtype
            else:
                # Scaling or offset applied: output dtype is whatever
                # apply_read_scaling gives for the slope and intercept the
                # header reports back
                hdr.set_slope_inter(slope, inter)
                hdr_slope, hdr_inter = hdr.get_slope_inter()
                if hdr_slope is None:
                    hdr_slope = 1.
                if hdr_inter is None:
                    hdr_inter = 0.
                scaled = apply_read_scaling(arr, hdr_slope, hdr_inter)
                dtype_out = scaled.dtype.type

            def sio_func():
                stream = BytesIO()
                stream.truncate(0)
                stream.seek(offset)
                stream.write(data)
                # Hand out a copy so test functions cannot change the
                # header shared by every call to this creator
                hdr_copy = hdr.copy()
                proxy = self.proxy_class(stream, hdr_copy)
                return proxy, stream, hdr_copy

            params = {
                'dtype': dtype,
                'dtype_out': dtype_out,
                'arr': arr.copy(),
                'arr_out': arr.astype(dtype_out) * slope + inter,
                'shape': shape,
                'offset': offset,
                'slope': slope,
                'inter': inter,
            }
            yield sio_func, params
            # Same again, this time passing a filename to the proxy
            with InTemporaryDirectory():
                fname = 'data.bin'

                def fname_func():
                    with open(fname, 'wb') as fio:
                        fio.seek(offset)
                        fio.write(data)
                    # Hand out a copy so test functions cannot change the
                    # header shared by every call to this creator
                    hdr_copy = hdr.copy()
                    proxy = self.proxy_class(fname, hdr_copy)
                    return proxy, fname, hdr_copy

                yield fname_func, params.copy()

    def validate_dtype(self, pmaker, params):
        # dtype attribute matches the on-disk dtype and cannot be assigned
        prox, fio, hdr = pmaker()
        assert_dt_equal(prox.dtype, params['dtype'])
        with pytest.raises(AttributeError):
            prox.dtype = np.dtype(prox.dtype)

    def validate_slope_inter_offset(self, pmaker, params):
        # slope, inter and offset have the expected values and are read only
        prox, fio, hdr = pmaker()
        for attr_name in ('slope', 'inter', 'offset'):
            expected = params[attr_name]
            actual = getattr(prox, attr_name)
            assert_array_equal(actual, expected)
            with pytest.raises(AttributeError):
                setattr(prox, attr_name, expected)

    def validate_deprecated_header(self, pmaker, params):
        # Accessing the ``header`` attribute raises ExpiredDeprecationError
        prox, fio, hdr = pmaker()
        with pytest.raises(ExpiredDeprecationError):
            prox.header
class TestSpm99AnalyzeProxyAPI(TestAnalyzeProxyAPI):
    """ Analyze proxy API tests using the SPM99 analyze header

    SPM-type analyze has slope scaling but not intercept
    """
    has_slope = True
    header_class = Spm99AnalyzeHeader
class TestSpm2AnalyzeProxyAPI(TestSpm99AnalyzeProxyAPI):
    """ Analyze proxy API tests using the SPM2 analyze header

    Only the header class differs from ``TestSpm99AnalyzeProxyAPI``; every
    other setting is inherited.
    """
    header_class = Spm2AnalyzeHeader
class TestNifti1ProxyAPI(TestSpm99AnalyzeProxyAPI):
    """ Analyze proxy API tests using the NIfTI1 header

    Turns on intercept values as well as the inherited slope values, and
    tests a wider set of data dtypes.
    """
    header_class = Nifti1Header
    has_inter = True
    data_dtypes = (
        np.uint8, np.int16, np.int32, np.float32, np.complex64, np.float64,
        np.int8, np.uint16, np.uint32, np.int64, np.uint64, np.complex128,
    )
    # Extended precision types are only tested when have_binary128() is true
    if have_binary128():
        data_dtypes = data_dtypes + (np.float128, np.complex256)
class TestMGHAPI(TestAnalyzeProxyAPI):
    """ Analyze-style proxy API tests using the MGH header """
    header_class = MGHHeader
    # MGH can only do >= 3D
    shapes = ((2, 3, 4), (2, 3, 4, 5))
    data_dtypes = (np.uint8, np.int16, np.int32, np.float32)
    # Data is written big-endian, with the offset left at the header default
    data_endian = '>'
    settable_offset = False
    # No scaling parameters are varied
    has_slope = False
    has_inter = False
class TestMinc1API(_TestProxyAPI):
    """ Proxy API tests for the MINC1 array proxy, using an example file """
    module = minc1
    file_class = minc1.Minc1File
    eg_fname = 'tiny.mnc'
    eg_shape = (10, 20, 20)

    @staticmethod
    def opener(f):
        # Open the example file read-only with the netcdf reader
        return netcdf_file(f, mode='r')

    def obj_params(self):
        """ Generate (``proxy_creator``, ``proxy_params``) pairs

        Every pair is tested separately.

        ``proxy_creator`` takes no arguments and returns a tuple of (fresh
        proxy object, fileobj, header).  A creator function is handed out,
        rather than a proxy instance, so that each of the autogenerated
        ``validate_xxx`` test methods can build its own proxy and modify it
        without affecting the tests that run after it.
        """
        eg_path = pjoin(DATA_PATH, self.eg_fname)
        # Expected output is the scaled data read through the file class
        arr_out = self.file_class(self.opener(eg_path)).get_scaled_data()

        def eg_func():
            minc_file = self.file_class(self.opener(eg_path))
            prox = minc1.MincImageArrayProxy(minc_file)
            img = self.module.load(eg_path)
            fobj = open(eg_path, 'rb')
            return prox, fobj, img.header

        params = {
            'shape': self.eg_shape,
            'dtype_out': np.float64,
            'arr_out': arr_out,
        }
        yield eg_func, params
if have_h5py:
    class TestMinc2API(TestMinc1API):
        """ MINC1 proxy API tests re-run against the MINC2 example file

        Only defined when ``have_h5py`` is true, because the file is opened
        with ``h5py``.
        """
        eg_fname = 'small.mnc'
        eg_shape = (18, 28, 29)
        module = minc2
        file_class = minc2.Minc2File

        @staticmethod
        def opener(f):
            # Open the example file read-only with h5py
            return h5py.File(f, mode='r')
class TestEcatAPI(_TestProxyAPI):
    """ Proxy API tests for the ECAT array proxy, using an example file """
    eg_fname = 'tinypet.v'
    eg_shape = (10, 10, 3, 1)

    def obj_params(self):
        """ Generate the (``proxy_creator``, ``proxy_params``) pair to test

        ``proxy_creator`` takes no arguments and returns a tuple of (fresh
        proxy object, fileobj, subheaders).
        """
        eg_path = pjoin(DATA_PATH, self.eg_fname)
        ref_img = ecat.load(eg_path)
        # Expected output is the floating point data of the loaded image
        arr_out = ref_img.get_fdata()

        def eg_func():
            loaded = ecat.load(eg_path)
            subheaders = loaded.get_subheaders()
            prox = ecat.EcatImageArrayProxy(subheaders)
            fobj = open(eg_path, 'rb')
            return prox, fobj, subheaders

        params = {
            'shape': self.eg_shape,
            'dtype_out': np.float64,
            'arr_out': arr_out,
        }
        yield eg_func, params

    def validate_header_isolated(self, pmaker, params):
        # The generic check needs to read the dtype from the header
        raise unittest.SkipTest('ECAT header does not support dtype get')
class TestPARRECAPI(_TestProxyAPI):
    """ Proxy API tests for the PARREC array proxy, using example files """

    def _func_dict(self, rec_name):
        """ Return (``proxy_creator``, ``proxy_params``) for file `rec_name`

        ``proxy_creator`` takes no arguments and returns a tuple of (fresh
        proxy object, fileobj, header).
        """
        ref_img = parrec.load(rec_name)
        # Expected output is the floating point data of the loaded image
        arr_out = ref_img.get_fdata()

        def eg_func():
            loaded = parrec.load(rec_name)
            prox = parrec.PARRECArrayProxy(rec_name, loaded.header, scaling='dv')
            fobj = open(rec_name, 'rb')
            return prox, fobj, loaded.header

        params = {
            'shape': ref_img.shape,
            'dtype_out': np.float64,
            'arr_out': arr_out,
        }
        return eg_func, params

    def obj_params(self):
        """ Generate one (``proxy_creator``, ``proxy_params``) pair per file """
        for rec_name in (EG_REC, VARY_REC):
            yield self._func_dict(rec_name)