1# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
2# vi: set ft=python sts=4 ts=4 sw=4 et:
3### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
4#
5#   See COPYING file distributed along with the NiBabel package for the
6#   copyright and license terms.
7#
8### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
9""" Validate image proxy API
10
11Minimum array proxy API is:
12
* read-only ``shape`` property
* read-only ``ndim`` property
* read-only ``is_proxy`` property set to True
* returns array from ``np.asarray(prox)``
* returns array slice from ``prox[<slice_spec>]`` where ``<slice_spec>`` is any
  non-fancy slice specification.
18
19And:
20
* modifying objects outside ``prox`` will not affect the result of
  ``np.asarray(prox)``.  Specifically:

  * changes in position (``obj.tell()``) of any passed file-like object
    will not affect the output of ``np.asarray(prox)``;
  * if you pass a header into ``__init__``, then modifying the original
    header will not affect the returned array.

These last requirements allow the proxy to be re-used with different images.
29"""
30
31from os.path import join as pjoin
32import warnings
33from itertools import product
34from io import BytesIO
35
36import numpy as np
37
38from ..volumeutils import apply_read_scaling
39from ..analyze import AnalyzeHeader
40from ..spm99analyze import Spm99AnalyzeHeader
41from ..spm2analyze import Spm2AnalyzeHeader
42from ..nifti1 import Nifti1Header
43from ..freesurfer.mghformat import MGHHeader
44from .. import minc1
45from ..externals.netcdf import netcdf_file
46from .. import minc2
47from .._h5py_compat import h5py, have_h5py
48from .. import ecat
49from .. import parrec
50from ..casting import have_binary128
51
52from ..arrayproxy import ArrayProxy, is_proxy
53
54import unittest
55import pytest
56from numpy.testing import assert_almost_equal, assert_array_equal, assert_allclose
57
58from ..testing import data_path as DATA_PATH, assert_dt_equal, clear_and_catch_warnings
59from ..deprecator import ExpiredDeprecationError
60
61from ..tmpdirs import InTemporaryDirectory
62
63from .test_api_validators import ValidateAPI
64from .test_parrec import EG_REC, VARY_REC
65
66
67def _some_slicers(shape):
68    ndim = len(shape)
69    slicers = np.eye(ndim).astype(int).astype(object)
70    slicers[slicers == 0] = slice(None)
71    for i in range(ndim):
72        if i % 2:
73            slicers[i, i] = -1
74        elif shape[i] < 2:  # some proxy examples have length 1 axes
75            slicers[i, i] = 0
76    # Add a newaxis to keep us on our toes
77    no_pos = ndim // 2
78    slicers = np.hstack((slicers[:, :no_pos],
79                         np.empty((ndim, 1)),
80                         slicers[:, no_pos:]))
81    slicers[:, no_pos] = None
82    return [tuple(s) for s in slicers]
83
84
class _TestProxyAPI(ValidateAPI):
    """ Base class for testing proxy APIs

    Assumes that real classes will provide an `obj_params` method which is a
    generator returning 2 tuples of (<proxy_maker>, <param_dict>).
    <proxy_maker> is a function returning a 3 tuple of (<proxy>, <fileobj>,
    <header>).  <param_dict> is a dictionary containing at least keys
    ``arr_out`` (expected output array from proxy), ``dtype_out`` (expected
    output dtype for array) and ``shape`` (shape of array).

    The <header> above should support at least "get_data_dtype",
    "set_data_dtype", "get_data_shape", "set_data_shape"
    """
    # Flag True if offset can be set into header of image
    settable_offset = False

    def validate_shape(self, pmaker, params):
        # Check shape matches the expected shape
        prox, fio, hdr = pmaker()
        assert_array_equal(prox.shape, params['shape'])
        # Read only
        with pytest.raises(AttributeError):
            prox.shape = params['shape']

    def validate_ndim(self, pmaker, params):
        # Check number of dimensions agrees with the expected shape
        prox, fio, hdr = pmaker()
        assert prox.ndim == len(params['shape'])
        # Read only
        with pytest.raises(AttributeError):
            prox.ndim = len(params['shape'])

    def validate_is_proxy(self, pmaker, params):
        # Check the ``is_proxy`` attribute and the ``is_proxy`` helper
        prox, fio, hdr = pmaker()
        assert prox.is_proxy
        assert is_proxy(prox)
        assert not is_proxy(np.arange(10))
        # Read only
        with pytest.raises(AttributeError):
            prox.is_proxy = False

    def validate_asarray(self, pmaker, params):
        # Check proxy returns expected array from asarray
        prox, fio, hdr = pmaker()
        out = np.asarray(prox)
        assert_array_equal(out, params['arr_out'])
        assert_dt_equal(out.dtype, params['dtype_out'])
        # Shape matches expected shape
        assert out.shape == params['shape']

    def validate_array_interface_with_dtype(self, pmaker, params):
        # Check proxy returns expected array from the array interface, both
        # without and with an explicitly requested dtype
        prox, fio, hdr = pmaker()
        orig = np.array(prox, dtype=None)
        assert_array_equal(orig, params['arr_out'])
        assert_dt_equal(orig.dtype, params['dtype_out'])

        if np.issubdtype(orig.dtype, np.complexfloating):
            # Casting complex data to real types emits ComplexWarning.  Silence
            # it only inside a ``with`` block, so the warning filters are
            # restored even when an assertion fails.  ComplexWarning lives in
            # ``np.exceptions`` from numpy 1.25 on.
            complex_warning = getattr(np, 'exceptions', np).ComplexWarning
            with clear_and_catch_warnings():
                warnings.simplefilter('ignore', complex_warning)
                self._check_dtype_coercions(prox, orig, params)
        else:
            self._check_dtype_coercions(prox, orig, params)

    def _check_dtype_coercions(self, prox, orig, params):
        """ Check `prox` coerces to every numpy float / int / uint scalar type

        For each scalar type, compares direct coercion ``dtype(prox)`` with
        ``orig.astype(dtype)``.  It then checks that ``np.array``,
        ``np.asarray`` and ``np.asanyarray`` with ``dtype=`` all agree with it.
        """
        for dtype in np.sctypes['float'] + np.sctypes['int'] + np.sctypes['uint']:
            # Directly coerce with a dtype
            direct = dtype(prox)
            # Half-precision is imprecise. Obviously. It's a bad idea, but don't break
            # the test over it.
            rtol = 1e-03 if dtype == np.float16 else 1e-05
            assert_allclose(direct, orig.astype(dtype), rtol=rtol, atol=1e-08)
            assert_dt_equal(direct.dtype, np.dtype(dtype))
            assert direct.shape == params['shape']
            # All three methods should produce equivalent results
            for arrmethod in (np.array, np.asarray, np.asanyarray):
                out = arrmethod(prox, dtype=dtype)
                assert_array_equal(out, direct)
                assert_dt_equal(out.dtype, np.dtype(dtype))
                # Shape matches expected shape
                assert out.shape == params['shape']

    def validate_header_isolated(self, pmaker, params):
        # Confirm altering input header has no effect
        # Depends on header providing 'get_data_dtype', 'set_data_dtype',
        # 'get_data_shape', 'set_data_shape', 'set_data_offset'
        prox, fio, hdr = pmaker()
        assert_array_equal(prox, params['arr_out'])
        # Mess up header badly and hope for same correct result
        if hdr.get_data_dtype() == np.uint8:
            hdr.set_data_dtype(np.int16)
        else:
            hdr.set_data_dtype(np.uint8)
        hdr.set_data_shape(np.array(hdr.get_data_shape()) + 1)
        if self.settable_offset:
            hdr.set_data_offset(32)
        assert_array_equal(prox, params['arr_out'])

    def validate_fileobj_isolated(self, pmaker, params):
        # Check the array read does not depend on the file-like object's
        # current position (nothing to check when a filename was passed)
        prox, fio, hdr = pmaker()
        if isinstance(fio, str):
            return
        assert_array_equal(prox, params['arr_out'])
        fio.read()  # move to end of file
        assert_array_equal(prox, params['arr_out'])

    def validate_proxy_slicing(self, pmaker, params):
        # Confirm that proxy object can be sliced correctly
        arr = params['arr_out']
        shape = arr.shape
        prox, fio, hdr = pmaker()
        for sliceobj in _some_slicers(shape):
            assert_array_equal(arr[sliceobj], prox[sliceobj])
201
202
class TestAnalyzeProxyAPI(_TestProxyAPI):
    """ Specific Analyze-type array proxy API test

    The analyze proxy extends the general API by adding read-only attributes
    ``slope, inter, offset``
    """
    proxy_class = ArrayProxy
    header_class = AnalyzeHeader
    shapes = ((2,), (2, 3), (2, 3, 4), (2, 3, 4, 5))
    has_slope = False
    has_inter = False
    data_dtypes = (np.uint8, np.int16, np.int32, np.float32, np.complex64, np.float64)
    array_order = 'F'
    # Cannot set offset for Freesurfer
    settable_offset = True
    # Freesurfer enforces big-endian. '=' means use native
    data_endian = '='

    def obj_params(self):
        """ Iterator returning (``proxy_creator``, ``proxy_params``) pairs

        Each pair will be tested separately.

        ``proxy_creator`` is a function taking no arguments and returning (fresh
        proxy object, fileobj, header).  We need to pass this function rather
        than a proxy instance so we can recreate the proxy objects fresh for
        each of multiple tests run from the ``validate_xxx`` autogenerated test
        methods.  This allows the tests to modify the proxy instance without
        having an effect on the later tests in the same function.
        """
        # Analyze and up wrap binary arrays, Fortran ordered, with given offset
        # and dtype and shape.
        if not self.settable_offset:
            offsets = (self.header_class().get_data_offset(),)
        else:
            offsets = (0, 16)
        # For non-integral parameters, cast to float32 value can be losslessly cast
        # later, enabling exact checks, then back to float for consistency
        slopes = (1., 2., float(np.float32(3.1416))) if self.has_slope else (1.,)
        inters = (0., 10., float(np.float32(2.7183))) if self.has_inter else (0.,)
        for shape, dtype, offset, slope, inter in product(self.shapes,
                                                          self.data_dtypes,
                                                          offsets,
                                                          slopes,
                                                          inters):
            n_els = np.prod(shape)
            dtype = np.dtype(dtype).newbyteorder(self.data_endian)
            arr = np.arange(n_els, dtype=dtype).reshape(shape)
            data = arr.tobytes(order=self.array_order)
            hdr = self.header_class()
            hdr.set_data_dtype(dtype)
            hdr.set_data_shape(shape)
            if self.settable_offset:
                hdr.set_data_offset(offset)
            if (slope, inter) == (1, 0):  # No scaling applied
                # dtype from array
                dtype_out = dtype
            else:  # scaling or offset applied
                # out dtype predictable from apply_read_scaling
                # and datatypes of slope, inter
                hdr.set_slope_inter(slope, inter)
                s, i = hdr.get_slope_inter()
                tmp = apply_read_scaling(arr,
                                         1. if s is None else s,
                                         0. if i is None else i)
                dtype_out = tmp.dtype.type

            # The makers below bind this iteration's ``offset``, ``data`` and
            # ``hdr`` as default arguments.  Plain closures look those names up
            # only when called, so a maker called after the generator has
            # advanced would silently use a later iteration's values.
            def sio_func(offset=offset, data=data, hdr=hdr):
                fio = BytesIO()
                fio.seek(offset)
                fio.write(data)
                # Use a copy of the header to avoid changing
                # global header in test functions.
                new_hdr = hdr.copy()
                return (self.proxy_class(fio, new_hdr),
                        fio,
                        new_hdr)

            params = dict(
                dtype=dtype,
                dtype_out=dtype_out,
                arr=arr.copy(),
                arr_out=arr.astype(dtype_out) * slope + inter,
                shape=shape,
                offset=offset,
                slope=slope,
                inter=inter)
            yield sio_func, params
            # Same with filenames.  ``fname`` is relative, so this maker has to
            # be called while the generator is suspended inside the temporary
            # directory.
            with InTemporaryDirectory():
                fname = 'data.bin'

                def fname_func(offset=offset, data=data, hdr=hdr):
                    with open(fname, 'wb') as fio:
                        fio.seek(offset)
                        fio.write(data)
                    # Use a copy of the header to avoid changing
                    # global header in test functions.
                    new_hdr = hdr.copy()
                    return (self.proxy_class(fname, new_hdr),
                            fname,
                            new_hdr)
                params = params.copy()
                yield fname_func, params

    def validate_dtype(self, pmaker, params):
        # Read-only dtype attribute matching the on-disk dtype
        prox, fio, hdr = pmaker()
        assert_dt_equal(prox.dtype, params['dtype'])
        with pytest.raises(AttributeError):
            prox.dtype = np.dtype(prox.dtype)

    def validate_slope_inter_offset(self, pmaker, params):
        # Check slope, inter, offset match the parameters and are read-only
        prox, fio, hdr = pmaker()
        for attr_name in ('slope', 'inter', 'offset'):
            expected = params[attr_name]
            assert_array_equal(getattr(prox, attr_name), expected)
            # Read only
            with pytest.raises(AttributeError):
                setattr(prox, attr_name, expected)

    def validate_deprecated_header(self, pmaker, params):
        # Accessing ``prox.header`` must raise ExpiredDeprecationError
        prox, fio, hdr = pmaker()
        with pytest.raises(ExpiredDeprecationError):
            prox.header
330
331
class TestSpm99AnalyzeProxyAPI(TestAnalyzeProxyAPI):
    """ Proxy API tests for SPM99 Analyze headers

    SPM-type Analyze has slope scaling but no intercept, so only ``has_slope``
    is switched on relative to the plain Analyze tests.
    """
    has_slope = True
    header_class = Spm99AnalyzeHeader
336
337
class TestSpm2AnalyzeProxyAPI(TestSpm99AnalyzeProxyAPI):
    """ Proxy API tests for SPM2 Analyze headers

    Reuses the SPM99 configuration unchanged; only the header class differs.
    """
    header_class = Spm2AnalyzeHeader
340
341
class TestNifti1ProxyAPI(TestSpm99AnalyzeProxyAPI):
    """ Proxy API tests for NIfTI-1 headers

    NIfTI-1 adds an intercept on top of the SPM99 slope, and is tested with a
    wider set of data types.  ``float128`` / ``complex256`` are only included
    when ``have_binary128()`` returns True.
    """
    has_inter = True
    header_class = Nifti1Header
    data_dtypes = (
        np.uint8, np.int16, np.int32, np.float32, np.complex64, np.float64,
        np.int8, np.uint16, np.uint32, np.int64, np.uint64, np.complex128,
    ) + ((np.float128, np.complex256) if have_binary128() else ())
349
350
class TestMGHAPI(TestAnalyzeProxyAPI):
    """ Proxy API tests for FreeSurfer MGH headers

    MGH has no slope / intercept scaling, a data offset that cannot be set,
    big-endian data, and only supports shapes with 3 or more dimensions.
    """
    header_class = MGHHeader
    has_slope = has_inter = False
    settable_offset = False
    data_endian = '>'
    # MGH can only do >= 3D
    shapes = ((2, 3, 4), (2, 3, 4, 5))
    data_dtypes = (np.uint8, np.int16, np.int32, np.float32)
359
360
class TestMinc1API(_TestProxyAPI):
    """ Proxy API tests for MINC1 image array proxies

    ``module``, ``file_class``, ``opener``, ``eg_fname`` and ``eg_shape`` are
    class-level hooks, so subclasses can reuse this machinery for other MINC
    variants.
    """
    module = minc1
    file_class = minc1.Minc1File
    eg_fname = 'tiny.mnc'
    eg_shape = (10, 20, 20)

    @staticmethod
    def opener(f):
        return netcdf_file(f, mode='r')

    def obj_params(self):
        """ Iterator returning (``proxy_creator``, ``proxy_params``) pairs

        Each pair will be tested separately.

        ``proxy_creator`` is a function taking no arguments and returning (fresh
        proxy object, fileobj, header).  We need to pass this function rather
        than a proxy instance so we can recreate the proxy objects fresh for
        each of multiple tests run from the ``validate_xxx`` autogenerated test
        methods.  This allows the tests to modify the proxy instance without
        having an effect on the later tests in the same function.
        """
        path = pjoin(DATA_PATH, self.eg_fname)
        expected = self.file_class(self.opener(path)).get_scaled_data()

        def make_proxy():
            minc_file = self.file_class(self.opener(path))
            proxy = minc1.MincImageArrayProxy(minc_file)
            image = self.module.load(path)
            # NOTE(review): this file object is handed to the validators and
            # never explicitly closed here.
            fileobj = open(path, 'rb')
            return proxy, fileobj, image.header

        yield make_proxy, {'shape': self.eg_shape,
                           'dtype_out': np.float64,
                           'arr_out': expected}
397
398
if have_h5py:
    class TestMinc2API(TestMinc1API):
        """ Proxy API tests for MINC2 images, read through h5py

        Only defined when h5py is available.  Reuses the MINC1 test machinery,
        swapping in the MINC2 module, file class, example file and opener.
        """
        eg_fname = 'small.mnc'
        eg_shape = (18, 28, 29)
        module = minc2
        file_class = minc2.Minc2File

        @staticmethod
        def opener(f):
            return h5py.File(f, mode='r')
409
410
class TestEcatAPI(_TestProxyAPI):
    """ Proxy API tests for ECAT image array proxies """
    eg_fname = 'tinypet.v'
    eg_shape = (10, 10, 3, 1)

    def obj_params(self):
        """ Yield a single (proxy maker, params) pair for the example file """
        path = pjoin(DATA_PATH, self.eg_fname)
        reference = ecat.load(path)
        expected = reference.get_fdata()

        def make_proxy():
            image = ecat.load(path)
            subheaders = image.get_subheaders()
            proxy = ecat.EcatImageArrayProxy(subheaders)
            fileobj = open(path, 'rb')
            # The subheaders object stands in for the "header" element
            return proxy, fileobj, subheaders

        yield make_proxy, {'shape': self.eg_shape,
                           'dtype_out': np.float64,
                           'arr_out': expected}

    def validate_header_isolated(self, pmaker, params):
        # The header-isolation check needs a dtype getter on the header object
        raise unittest.SkipTest('ECAT header does not support dtype get')
433
434
class TestPARRECAPI(_TestProxyAPI):
    """ Proxy API tests for PAR/REC image array proxies """

    def _func_dict(self, rec_name):
        """ Return a (proxy maker, params) pair for the REC file `rec_name` """
        reference = parrec.load(rec_name)
        expected = reference.get_fdata()

        def make_proxy():
            image = parrec.load(rec_name)
            proxy = parrec.PARRECArrayProxy(rec_name,
                                            image.header,
                                            scaling='dv')
            fileobj = open(rec_name, 'rb')
            return proxy, fileobj, image.header

        params = {'shape': reference.shape,
                  'dtype_out': np.float64,
                  'arr_out': expected}
        return make_proxy, params

    def obj_params(self):
        for rec_name in (EG_REC, VARY_REC):
            yield self._func_dict(rec_name)
456