1""" Testing dataobj_images module
2"""
3
4import numpy as np
5
6from nibabel.filebasedimages import FileBasedHeader
7from nibabel.dataobj_images import DataobjImage
8
9from nibabel.tests.test_image_api import DataInterfaceMixin
10from nibabel.tests.test_filebasedimages import TestFBImageAPI as _TFI
11
12
class DoNumpyImage(DataobjImage):
    """ Minimal ``DataobjImage`` subclass stored as a single ``.npy`` file

    Used to exercise the generic image API tests.  The header is a bare
    ``FileBasedHeader`` and the data object is the array returned by
    :func:`numpy.load`.
    """
    header_class = FileBasedHeader
    valid_exts = ('.npy',)
    files_types = (('image', '.npy'),)

    @classmethod
    def from_file_map(klass, file_map, mmap=True, keep_file_open=None):
        """ Load image from the ``'image'`` entry of `file_map`

        Parameters
        ----------
        file_map : dict
            Mapping whose ``'image'`` value provides ``get_prepare_fileobj``.
        mmap : {True, False, 'c', 'r'}, optional
            Memory-map request passed to :func:`numpy.load` as ``mmap_mode``.
            ``True`` means ``'c'`` (copy-on-write); ``False`` means no memory
            mapping.  Mapping is best-effort: if numpy cannot map the opened
            file object, the array is read into memory instead.
        keep_file_open : object, optional
            Ignored; accepted for signature compatibility with other image
            classes.

        Returns
        -------
        img : DoNumpyImage
            New image wrapping the loaded array.

        Raises
        ------
        ValueError
            If `mmap` is not one of the accepted values.
        """
        if mmap not in (True, False, 'c', 'r'):
            raise ValueError("mmap should be one of {True, False, 'c', 'r'}")
        if mmap is True:
            mmap = 'c'
        elif mmap is False:
            mmap = None
        with file_map['image'].get_prepare_fileobj('rb') as fobj:
            try:
                # numpy's keyword is ``mmap_mode`` (not ``mmap``).  Memory
                # mapping generally needs a filename rather than an open file
                # object, so this attempt may legitimately fail.
                arr = np.load(fobj, mmap_mode=mmap)
            except Exception:
                # Best-effort fallback: read the whole array into memory.
                # Deliberately broad, but no longer swallows KeyboardInterrupt
                # / SystemExit as the previous bare ``except:`` did.
                arr = np.load(fobj)
        return klass(arr)

    def to_file_map(self, file_map=None):
        """ Save the data array with :func:`numpy.save`

        Parameters
        ----------
        file_map : dict, optional
            Mapping with an ``'image'`` entry to write to; defaults to
            ``self.file_map``.
        """
        file_map = self.file_map if file_map is None else file_map
        with file_map['image'].get_prepare_fileobj('wb') as fobj:
            np.save(fobj, self.dataobj)

    def get_data_dtype(self):
        """ Return the dtype of the data object """
        return self.dataobj.dtype

    def set_data_dtype(self, dtype):
        """ Replace the data object with a copy cast to `dtype` """
        self._dataobj = self._dataobj.astype(dtype)
43
44
class TestDataobjAPI(_TFI, DataInterfaceMixin):
    """ Run the generic file-based-image and data-interface API checks
    against the ``.npy``-backed ``DoNumpyImage`` test class
    """
    # Factory the shared API tests call as ``image_maker(data, header)``
    # to build each image under test.
    image_maker = DoNumpyImage
50