1import os
2from os.path import join as pjoin, isdir
3import getpass
4import time
5import struct
6import hashlib
7import warnings
8
9from ...tmpdirs import InTemporaryDirectory
10
11import unittest
12import pytest
13import numpy as np
14from numpy.testing import assert_allclose, assert_array_equal
15
16from .. import (read_geometry, read_morph_data, read_annot, read_label,
17                write_geometry, write_morph_data, write_annot)
18from ..io import _pack_rgb
19
20from ...tests.nibabel_data import get_nibabel_data, needs_nibabel_data
21from ...fileslice import strided_scalar
22from ...testing import clear_and_catch_warnings
23
DATA_SDIR = 'fsaverage'

# Locate the ``fsaverage`` subject directory.  A ``SUBJECTS_DIR`` environment
# variable takes precedence; otherwise the nibabel test data location is
# tried.  ``data_path`` is only defined when one of those sources is usable.
have_freesurfer = False
if 'SUBJECTS_DIR' not in os.environ:
    # May have nibabel test data submodule checked out
    nib_data = get_nibabel_data()
    if nib_data != '':
        # An empty string is treated as "no test data available"
        data_path = pjoin(nib_data, 'nitest-freesurfer', DATA_SDIR)
        have_freesurfer = isdir(data_path)
else:
    # May have Freesurfer installed with data
    data_path = pjoin(os.environ['SUBJECTS_DIR'], DATA_SDIR)
    have_freesurfer = isdir(data_path)

# Decorator that skips a test when the fsaverage directory was not found
freesurfer_test = unittest.skipUnless(
    have_freesurfer,
    f'cannot find freesurfer {DATA_SDIR} directory',
)
40
def _hash_file_content(fname):
    """Return the hexadecimal MD5 digest of the complete contents of `fname`.

    The file is opened in binary mode and read in a single call.
    """
    with open(fname, 'rb') as fobj:
        payload = fobj.read()
    return hashlib.md5(payload).hexdigest()
47
48
@freesurfer_test
def test_geometry():
    """Test IO of .surf

    Reads two surfaces from the fsaverage data, round-trips one of them
    through `write_geometry` / `read_geometry` (including the volume
    information and the creation stamp), and checks the warnings and errors
    raised for incomplete or invalid volume information.
    """
    surf_path = pjoin(data_path, "surf", "lh.inflated")
    coords, faces = read_geometry(surf_path)
    # Face entries index into ``coords``: zero-based and covering every row
    assert 0 == faces.min()
    assert coords.shape[0] == faces.max() + 1

    surf_path = pjoin(data_path, "surf", "lh.sphere")
    coords, faces, volume_info, create_stamp = read_geometry(
        surf_path, read_metadata=True, read_stamp=True)

    assert 0 == faces.min()
    assert coords.shape[0] == faces.max() + 1
    assert 9 == len(volume_info)
    assert np.array_equal([2, 0, 20], volume_info['head'])
    assert create_stamp == 'created by greve on Thu Jun  8 19:17:51 2006'

    # Test equivalence of freesurfer- and nibabel-generated triangular files
    # with respect to read_geometry()
    with InTemporaryDirectory():
        surf_path = 'test'
        create_stamp = f"created by {getpass.getuser()} on {time.ctime()}"
        volume_info['cras'] = [1., 2., 3.]
        write_geometry(surf_path, coords, faces, create_stamp, volume_info)

        coords2, faces2, volume_info2 = \
            read_geometry(surf_path, read_metadata=True)

        for key in ('xras', 'yras', 'zras', 'cras'):
            assert_allclose(volume_info2[key], volume_info[key],
                            rtol=1e-7, atol=1e-30)

        assert np.array_equal(volume_info2['cras'], volume_info['cras'])
        # Read the stamp straight from the file: skip the three leading
        # bytes, then take the first line
        with open(surf_path, 'rb') as fobj:
            np.fromfile(fobj, ">u1", 3)
            read_create_stamp = fobj.readline().decode().rstrip('\n')

        # now write an incomplete file
        write_geometry(surf_path, coords, faces)
        with pytest.warns(UserWarning) as w:
            read_geometry(surf_path, read_metadata=True)
        assert any('volume information contained' in str(ww.message) for ww in w)
        assert any('extension code' in str(ww.message) for ww in w)

        # A 'head' value other than the one read above must trigger a warning
        volume_info['head'] = [1, 2]
        with pytest.warns(UserWarning, match="Unknown extension"):
            write_geometry(surf_path, coords, faces, create_stamp, volume_info)

        # An extra key in the volume information must be rejected
        volume_info['a'] = 0
        with pytest.raises(ValueError):
            write_geometry(surf_path, coords, faces, create_stamp, volume_info)

    assert create_stamp == read_create_stamp

    assert np.array_equal(coords, coords2)
    assert np.array_equal(faces, faces2)

    # Validate byte ordering.  ``ndarray.newbyteorder()`` was removed in
    # NumPy 2.0, so swap the bytes and then reinterpret them through a dtype
    # of the opposite byte order; the values must compare equal to the
    # originals.
    coords_swapped = coords.byteswap()
    coords_swapped = coords_swapped.view(coords_swapped.dtype.newbyteorder())
    faces_swapped = faces.byteswap()
    faces_swapped = faces_swapped.view(faces_swapped.dtype.newbyteorder())
    assert np.array_equal(coords_swapped, coords)
    assert np.array_equal(faces_swapped, faces)
112
113
@freesurfer_test
@needs_nibabel_data('nitest-freesurfer')
def test_quad_geometry():
    """Round-trip a freesurfer quad file through write/read_geometry."""
    quad_fname = pjoin(get_nibabel_data(), 'nitest-freesurfer', 'subjects',
                       'bert', 'surf', 'lh.inflated.nofix')
    points, polys = read_geometry(quad_fname)
    # Face entries must be zero-based indices covering every coordinate row
    assert polys.min() == 0
    assert points.shape[0] == (polys.max() + 1)
    with InTemporaryDirectory():
        out_fname = 'test'
        write_geometry(out_fname, points, polys)
        points_rt, polys_rt = read_geometry(out_fname)
        assert np.array_equal(points, points_rt)
        assert np.array_equal(polys, polys_rt)
129
130
@freesurfer_test
def test_morph_data():
    """Check reading lh.curv and round-tripping it via write_morph_data."""
    curv_in = read_morph_data(pjoin(data_path, "surf", "lh.curv"))
    # The data must straddle zero and stay strictly inside (-1, 1)
    assert -1.0 < curv_in.min() < 0
    assert 0 < curv_in.max() < 1.0
    with InTemporaryDirectory():
        tmp_fname = 'test'
        write_morph_data(tmp_fname, curv_in)
        curv_rt = read_morph_data(tmp_fname)
        assert np.array_equal(curv_rt, curv_in)
143
144
def test_write_morph_data():
    """Test write_morph_data edge cases

    Covers accepted array shapes, a too-large third positional argument,
    a too-large number of values, and rejected array shapes.
    """
    values = np.arange(20, dtype='>f4')
    okay_shapes = [(20,), (20, 1), (20, 1, 1), (1, 20)]
    bad_shapes = [(10, 2), (1, 1, 20, 1, 1)]
    # One more than the largest signed 32-bit integer
    big_num = np.iinfo('i4').max + 1
    with InTemporaryDirectory():
        for shape in okay_shapes:
            write_morph_data('test.curv', values.reshape(shape))
            # Check ordering is preserved, regardless of shape
            assert np.array_equal(read_morph_data('test.curv'), values)

        # Use an explicitly chosen valid shape instead of relying on the loop
        # variable leaking out of the loop above, so the only invalid input
        # is ``big_num`` (presumably the face count argument - TODO confirm).
        with pytest.raises(ValueError):
            write_morph_data('test.curv', np.zeros(okay_shapes[-1]), big_num)
        # Windows 32-bit overflows Python int
        if np.dtype(int) != np.dtype(np.int32):
            with pytest.raises(ValueError):
                write_morph_data('test.curv', strided_scalar((big_num,)))
        for shape in bad_shapes:
            with pytest.raises(ValueError):
                write_morph_data('test.curv', values.reshape(shape))
166
@freesurfer_test
def test_annot():
    """Test IO of .annot against freesurfer example data."""
    # Expected (original id, vertex count) per known md5 of lh.aparc.annot,
    # to handle different versions of fsaverage
    known_aparc_counts = {
        'bf0b488994657435cdddac5f107d21e8': (0, 13887),
        'd4f5b7cbc2ed363ac6fcf89e19353504': (1639705, 13327),
    }
    for stem in ('aparc', 'aparc.a2005s'):
        src_path = pjoin(data_path, "label", f"lh.{stem}.annot")
        digest = _hash_file_content(src_path)

        labels, ctab, names = read_annot(src_path)
        assert labels.shape == (163842,)
        assert ctab.shape == (len(names), 5)

        labels_orig = None
        if stem == 'aparc':
            labels_orig, _, _ = read_annot(src_path, orig_ids=True)
            # Label -1 must line up exactly with original id 0
            assert_array_equal(labels == -1, labels_orig == 0)
            if digest not in known_aparc_counts:
                raise RuntimeError("Unknown freesurfer file. Please report "
                                   "the problem to the maintainer of nibabel.")
            orig_id, n_expected = known_aparc_counts[digest]
            assert np.sum(labels_orig == orig_id) == n_expected

        # Test equivalence of freesurfer- and nibabel-generated annot files
        # with respect to read_annot()
        with InTemporaryDirectory():
            tmp_path = 'test'
            write_annot(tmp_path, labels, ctab, names)

            labels2, ctab2, names2 = read_annot(tmp_path)
            if labels_orig is not None:
                labels_orig_2, _, _ = read_annot(tmp_path, orig_ids=True)

        assert np.array_equal(labels, labels2)
        if labels_orig is not None:
            assert np.array_equal(labels_orig, labels_orig_2)
        assert np.array_equal(ctab, ctab2)
        assert names == names2
207
208
def test_read_write_annot():
    """Write a generated .annot file and check it reads back unchanged."""
    # The file stores a lookup table (LUT) of 3 colours for a mesh made of
    # 10 vertices.
    nvertices = 10
    nlabels = 3
    names = [f'label {idx}' for idx in range(1, nlabels + 1)]
    # Every label value appears at least once; the remaining vertices get a
    # random label.  Label values are indices into the LUT built below.
    filler = np.random.randint(0, nlabels, nvertices - nlabels)
    labels = np.array(list(range(nlabels)) + list(filler), dtype=np.int32)
    np.random.shuffle(labels)
    # Random colours for the LUT; the fifth column holds annotation values
    rgbal = np.zeros((nlabels, 5), dtype=np.int32)
    rgbal[:, :4] = np.random.randint(0, 255, (nlabels, 4))
    # Pin one alpha channel to 255.  The annotation values computed below
    # use only the first three (R, G, B) columns.
    rgbal[0, 3] = 255
    red, green, blue = rgbal[:, 0], rgbal[:, 1], rgbal[:, 2]
    rgbal[:, 4] = red + green * (2 ** 8) + blue * (2 ** 16)
    annot_path = 'c.annot'
    with InTemporaryDirectory():
        write_annot(annot_path, labels, rgbal, names, fill_ctab=False)
        rt_labels, rt_rgbal, rt_names = read_annot(annot_path)
        # Names come back as bytes; decode before comparing
        rt_names = [raw.decode('ascii') for raw in rt_names]
        assert np.all(np.isclose(rt_rgbal, rgbal))
        assert np.all(np.isclose(rt_labels, labels))
        assert rt_names == names
243
244
def test_write_annot_fill_ctab():
    """Exercise the `fill_ctab` parameter of :func:`.write_annot`. """
    nvertices = 10
    nlabels = 3
    names = [f'label {idx}' for idx in range(1, nlabels + 1)]
    filler = np.random.randint(0, nlabels, nvertices - nlabels)
    labels = np.array(list(range(nlabels)) + list(filler), dtype=np.int32)
    np.random.shuffle(labels)
    rgba = np.array(np.random.randint(0, 255, (nlabels, 4)), dtype=np.int32)
    annot_path = 'c.annot'
    incorrect_msg = f'Annotation values in {annot_path} will be incorrect'

    def _assert_read_back(expected_labels, **read_kwargs):
        # Read the file again and compare colours, labels and names
        rt_labels, rt_ctab, rt_names = read_annot(annot_path, **read_kwargs)
        rt_names = [raw.decode('ascii') for raw in rt_names]
        assert np.all(np.isclose(rt_ctab[:, :4], rgba))
        assert np.all(np.isclose(rt_labels, expected_labels))
        assert rt_names == names

    with InTemporaryDirectory():
        write_annot(annot_path, labels, rgba, names, fill_ctab=True)
        _assert_read_back(labels)
        # make sure a warning is emitted if fill_ctab is False, and the
        # annotation values are wrong. Use orig_ids=True so we get those bad
        # values back.
        badannot = (10 * np.arange(nlabels, dtype=np.int32)).reshape(-1, 1)
        rgbal = np.hstack((rgba, badannot))
        with pytest.warns(UserWarning, match=incorrect_msg):
            write_annot(annot_path, labels, rgbal, names, fill_ctab=False)
        _assert_read_back(badannot[labels].squeeze(), orig_ids=True)
        # make sure a warning is *not* emitted if fill_ctab is False, but the
        # annotation values are correct.
        rgbal = np.hstack((rgba, np.zeros((nlabels, 1), dtype=np.int32)))
        red, green, blue = rgbal[:, 0], rgbal[:, 1], rgbal[:, 2]
        rgbal[:, 4] = red + green * (2 ** 8) + blue * (2 ** 16)
        with clear_and_catch_warnings() as w:
            write_annot(annot_path, labels, rgbal, names, fill_ctab=False)
        assert not any(str(ww.message) == incorrect_msg for ww in w)
        _assert_read_back(labels)
291
292
def test_read_annot_old_format():
    """Check that a hand-built old-style .annot file is read correctly."""
    def gen_old_annot_file(fpath, nverts, labels, rgba, names):
        # Assemble the file as a list of big-endian chunks, then write once
        int_fmt = '>i'
        vertex_table = np.zeros((nverts, 2), dtype=int_fmt)
        vertex_table[:, 0] = np.arange(nverts)
        vertex_table[:, [1]] = _pack_rgb(rgba[labels, :3])
        n_entries = rgba.shape[0]
        chunks = [
            # number of vertices
            struct.pack(int_fmt, nverts),
            # vertices + annotation values
            vertex_table.astype(int_fmt).tobytes(),
            # is there a colour table?
            struct.pack(int_fmt, 1),
            # number of entries in colour table
            struct.pack(int_fmt, n_entries),
            # length of orig_tab string, followed by the string itself
            struct.pack(int_fmt, 5),
            b'abcd\x00',
        ]
        for idx in range(n_entries):
            entry_name = names[idx]
            # length of entry name (+1 for terminating byte)
            chunks.append(struct.pack(int_fmt, len(entry_name) + 1))
            chunks.append(entry_name.encode('ascii') + b'\x00')
            chunks.append(rgba[idx, :].astype(int_fmt).tobytes())
        with open(fpath, 'wb') as fobj:
            fobj.write(b''.join(chunks))

    with InTemporaryDirectory():
        nverts = 10
        nlabels = 3
        names = [f'Label {idx}' for idx in range(nlabels)]
        # Every label occurs at least once; the rest are drawn at random
        random_part = np.random.randint(0, nlabels, nverts - nlabels)
        labels = np.concatenate((np.arange(nlabels), random_part))
        np.random.shuffle(labels)
        rgba = np.random.randint(0, 255, (nlabels, 4))
        # write an old .annot file
        gen_old_annot_file('blah.annot', nverts, labels, rgba, names)
        # read it back
        rlabels, rrgba, rnames = read_annot('blah.annot')
        rnames = [raw.decode('ascii') for raw in rnames]
        assert np.all(np.isclose(labels, rlabels))
        assert np.all(np.isclose(rgba, rrgba[:, :4]))
        assert names == rnames
335
336
@freesurfer_test
def test_label():
    """Check reading lh.cortex.label, with and without scalars."""
    cortex_path = pjoin(data_path, "label", "lh.cortex.label")
    vertex_ids = read_label(cortex_path)
    # XXX : test more
    assert vertex_ids.min() >= 0
    assert vertex_ids.max() <= 163841
    assert vertex_ids.shape[0] <= 163842

    # Second positional argument requests a (labels, scalars) pair
    ids_again, scalars = read_label(cortex_path, True)
    assert np.all(ids_again == vertex_ids)
    assert len(ids_again) == len(scalars)
350
351
def test_write_annot_maxstruct():
    """Round-trip an ANNOT file whose vertices all share one label."""
    with InTemporaryDirectory():
        nlabels = 3
        names = [f'label {idx}' for idx in range(1, nlabels + 1)]
        # max label < n_labels
        labels = np.array([1, 1, 1], dtype=np.int32)
        rgba = np.array(np.random.randint(0, 255, (nlabels, 4)), dtype=np.int32)
        annot_path = 'c.annot'

        write_annot(annot_path, labels, rgba, names)
        # Validate the file can be read
        rt_labels, rt_ctab, rt_names = read_annot(annot_path)
        # Check round-trip; names come back as bytes and are decoded first
        decoded_names = [raw.decode('ascii') for raw in rt_names]
        assert np.array_equal(labels, rt_labels)
        assert np.array_equal(rgba, rt_ctab[:, :4])
        assert names == decoded_names
369