1import os
2from os.path import join as pjoin, isdir
3import getpass
4import time
5import struct
6import hashlib
7import warnings
9from ...tmpdirs import InTemporaryDirectory
11import unittest
12import pytest
13import numpy as np
14from numpy.testing import assert_allclose, assert_array_equal
16from .. import (read_geometry, read_morph_data, read_annot, read_label,
17                write_geometry, write_morph_data, write_annot)
18from ..io import _pack_rgb
20from ...tests.nibabel_data import get_nibabel_data, needs_nibabel_data
21from ...fileslice import strided_scalar
22from ...testing import clear_and_catch_warnings
DATA_SDIR = 'fsaverage'

# Locate the FreeSurfer ``fsaverage`` subject used by the data-dependent
# tests: prefer $SUBJECTS_DIR and only fall back to the nibabel test-data
# submodule when that variable is not set.  ``data_path`` stays undefined
# when neither source is configured, so tests that read it should be
# decorated with ``freesurfer_test`` (defined below) to skip rather than fail.
have_freesurfer = False
if 'SUBJECTS_DIR' in os.environ:
    # May have Freesurfer installed with data
    data_path = pjoin(os.environ["SUBJECTS_DIR"], DATA_SDIR)
    have_freesurfer = isdir(data_path)
else:
    # May have nibabel test data submodule checked out
    nib_data = get_nibabel_data()
    if nib_data != '':
        data_path = pjoin(nib_data, 'nitest-freesurfer', DATA_SDIR)
        have_freesurfer = isdir(data_path)

# Skip decorator for tests that need the fsaverage directory found above.
freesurfer_test = unittest.skipUnless(have_freesurfer,
                                      f'cannot find freesurfer {DATA_SDIR} directory')
def _hash_file_content(fname):
    """Return the hexadecimal MD5 digest of the file at *fname*.

    The file is streamed in fixed-size chunks, so the digest equals the
    MD5 of its full contents.
    """
    digest = hashlib.md5()
    with open(fname, 'rb') as stream:
        while True:
            chunk = stream.read(65536)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()
@freesurfer_test
def test_geometry():
    """Test IO of .surf files.

    Reads ``lh.inflated`` and ``lh.sphere`` (with metadata and creation
    stamp) from ``data_path`` and round-trips the sphere through
    ``write_geometry``/``read_geometry`` in a temporary directory.  It then
    checks the warnings/errors raised for incomplete or invalid volume info,
    and that byte-swapping an array together with its dtype preserves values.
    """
    surf_path = pjoin(data_path, "surf", "lh.inflated")
    coords, faces = read_geometry(surf_path)
    # Face indices must start at 0 and address every vertex.
    assert 0 == faces.min()
    assert coords.shape[0] == faces.max() + 1

    surf_path = pjoin(data_path, "surf", "lh.sphere")
    coords, faces, volume_info, create_stamp = read_geometry(
        surf_path, read_metadata=True, read_stamp=True)

    assert 0 == faces.min()
    assert coords.shape[0] == faces.max() + 1
    assert 9 == len(volume_info)
    assert np.array_equal([2, 0, 20], volume_info['head'])
    assert create_stamp == 'created by greve on Thu Jun  8 19:17:51 2006'

    # Test equivalence of freesurfer- and nibabel-generated triangular files
    # with respect to read_geometry()
    with InTemporaryDirectory():
        surf_path = 'test'
        create_stamp = f"created by {getpass.getuser()} on {time.ctime()}"
        volume_info['cras'] = [1., 2., 3.]
        write_geometry(surf_path, coords, faces, create_stamp, volume_info)

        coords2, faces2, volume_info2 = \
            read_geometry(surf_path, read_metadata=True)

        for key in ('xras', 'yras', 'zras', 'cras'):
            assert_allclose(volume_info2[key], volume_info[key],
                            rtol=1e-7, atol=1e-30)

        assert np.array_equal(volume_info2['cras'], volume_info['cras'])
        with open(surf_path, 'rb') as fobj:
            # Skip the leading 3 bytes (presumably the file-type magic
            # number); the creation stamp is the line that follows.
            np.fromfile(fobj, ">u1", 3)
            read_create_stamp = fobj.readline().decode().rstrip('\n')

        # now write an incomplete file
        write_geometry(surf_path, coords, faces)
        with pytest.warns(UserWarning) as w:
            read_geometry(surf_path, read_metadata=True)
        assert any('volume information contained' in str(ww.message) for ww in w)
        assert any('extension code' in str(ww.message) for ww in w)

        volume_info['head'] = [1, 2]
        with pytest.warns(UserWarning, match="Unknown extension"):
            write_geometry(surf_path, coords, faces, create_stamp, volume_info)

        volume_info['a'] = 0
        with pytest.raises(ValueError):
            write_geometry(surf_path, coords, faces, create_stamp, volume_info)

    assert create_stamp == read_create_stamp

    assert np.array_equal(coords, coords2)
    assert np.array_equal(faces, faces2)

    # Validate byte ordering: swapping the bytes *and* flipping the dtype's
    # byte order must leave the values unchanged.  ``ndarray.newbyteorder``
    # was removed in NumPy 2.0; going through the dtype works on 1.x and 2.x.
    coords_swapped = coords.byteswap().view(coords.dtype.newbyteorder())
    faces_swapped = faces.byteswap().view(faces.dtype.newbyteorder())
    assert np.array_equal(coords_swapped, coords)
    assert np.array_equal(faces_swapped, faces)
@needs_nibabel_data('nitest-freesurfer')
def test_quad_geometry():
    """Test IO of freesurfer quad files.

    Reads ``subjects/bert/surf/lh.inflated.nofix`` from the
    ``nitest-freesurfer`` test data.  It checks that the faces index every
    vertex and that a nibabel-written copy reads back identically.
    """
    new_quad = pjoin(get_nibabel_data(), 'nitest-freesurfer', 'subjects',
                     'bert', 'surf', 'lh.inflated.nofix')
    coords, faces = read_geometry(new_quad)
    assert 0 == faces.min()
    assert coords.shape[0] == (faces.max() + 1)
    with InTemporaryDirectory():
        new_path = 'test'
        write_geometry(new_path, coords, faces)
        coords2, faces2 = read_geometry(new_path)
        assert np.array_equal(coords, coords2)
        assert np.array_equal(faces, faces2)
@freesurfer_test
def test_morph_data():
    """Test IO of morphometry data file (eg. curvature).

    Reads ``surf/lh.curv`` under ``data_path`` and sanity-checks its value
    range.  It then checks that ``write_morph_data``/``read_morph_data``
    round-trip it exactly.
    """
    curv_path = pjoin(data_path, "surf", "lh.curv")
    curv = read_morph_data(curv_path)
    assert -1.0 < curv.min() < 0
    assert 0 < curv.max() < 1.0
    with InTemporaryDirectory():
        new_path = 'test'
        write_morph_data(new_path, curv)
        curv2 = read_morph_data(new_path)
        assert np.array_equal(curv2, curv)
def test_write_morph_data():
    """Test write_morph_data edge cases.

    Values must round-trip, with their order preserved, for each shape in
    ``okay_shapes``.  ``ValueError`` must be raised for each shape in
    ``bad_shapes``, for an out-of-range third argument, and for an array
    whose length exceeds the largest signed 32-bit integer.
    """
    values = np.arange(20, dtype='>f4')
    okay_shapes = [(20,), (20, 1), (20, 1, 1), (1, 20)]
    bad_shapes = [(10, 2), (1, 1, 20, 1, 1)]
    # One more than the largest signed 32-bit integer.
    big_num = np.iinfo('i4').max + 1
    with InTemporaryDirectory():
        for shape in okay_shapes:
            write_morph_data('test.curv', values.reshape(shape))
            # Check ordering is preserved, regardless of shape
            assert np.array_equal(read_morph_data('test.curv'), values)

        # Out-of-range third argument.  Pass the known-good `values`
        # explicitly rather than relying on `shape` leaking out of the loop
        # above, so only the third argument can trigger the error.
        with pytest.raises(ValueError):
            write_morph_data('test.curv', values, big_num)
        # Windows 32-bit overflows Python int: only build an array of length
        # big_num where NumPy's default int is not 32-bit.
        if np.dtype(int) != np.dtype(np.int32):
            with pytest.raises(ValueError):
                write_morph_data('test.curv', strided_scalar((big_num,)))
        for shape in bad_shapes:
            with pytest.raises(ValueError):
                write_morph_data('test.curv', values.reshape(shape))
@freesurfer_test
def test_annot():
    """Test IO of .annot against freesurfer example data.

    For each parcellation in ``annots``, the test:

    - checks the label and colour-table shapes;
    - for ``aparc``, checks the ``orig_ids`` view against counts known for
      each fsaverage version (identified by the file's MD5);
    - round-trips the data through ``write_annot``/``read_annot`` and
      requires identical results.
    """
    annots = ['aparc', 'aparc.a2005s']
    for annot_name in annots:
        annot_path = pjoin(data_path, "label", f"lh.{annot_name}.annot")
        hash_ = _hash_file_content(annot_path)

        labels, ctab, names = read_annot(annot_path)
        assert labels.shape == (163842, )
        assert ctab.shape == (len(names), 5)

        labels_orig = None
        if annot_name == 'aparc':
            labels_orig, _, _ = read_annot(annot_path, orig_ids=True)
            # Index -1 must correspond exactly to raw annotation value 0.
            assert_array_equal(labels == -1, labels_orig == 0)
            # Handle different version of fsaverage
            if hash_ == 'bf0b488994657435cdddac5f107d21e8':
                assert np.sum(labels_orig == 0) == 13887
            elif hash_ == 'd4f5b7cbc2ed363ac6fcf89e19353504':
                assert np.sum(labels_orig == 1639705) == 13327
            else:
                raise RuntimeError("Unknown freesurfer file. Please report "
                                   "the problem to the maintainer of nibabel.")

        # Test equivalence of freesurfer- and nibabel-generated annot files
        # with respect to read_annot()
        with InTemporaryDirectory():
            rt_path = 'test'
            write_annot(rt_path, labels, ctab, names)

            labels2, ctab2, names2 = read_annot(rt_path)
            if labels_orig is not None:
                labels_orig_2, _, _ = read_annot(rt_path, orig_ids=True)

        assert np.array_equal(labels, labels2)
        if labels_orig is not None:
            assert np.array_equal(labels_orig, labels_orig_2)
        assert np.array_equal(ctab, ctab2)
        assert names == names2
def test_read_write_annot():
    """Test generating .annot file and reading it back.

    A random 10-vertex, 3-label annotation is written with
    ``fill_ctab=False``.  The colour table (including the annotation-value
    column), the labels and the names must read back exactly.
    """
    # This annot file will store a LUT for a mesh made of 10 vertices, with
    # 3 colours in the LUT.
    nvertices = 10
    nlabels = 3
    names = [f'label {l}' for l in range(1, nlabels + 1)]
    # randomly generate a label for each vertex, making sure
    # that at least one of each label value is present. Label
    # values are in the range (0, nlabels-1) - they are used
    # as indices into the lookup table (generated below).
    labels = list(range(nlabels)) + \
        list(np.random.randint(0, nlabels, nvertices - nlabels))
    labels = np.array(labels, dtype=np.int32)
    np.random.shuffle(labels)
    # Generate some random colours for the LUT
    rgbal = np.zeros((nlabels, 5), dtype=np.int32)
    rgbal[:, :4] = np.random.randint(0, 255, (nlabels, 4))
    # But make sure we have at least one large alpha, to make sure that when
    # it is packed into a signed 32 bit int, it results in a negative value
    # for the annotation value.
    # NOTE(review): alpha (column 3) is not part of the packed value
    # computed below, so this may no longer exercise a negative value.
    rgbal[0, 3] = 255
    # Generate the annotation values for each LUT entry
    rgbal[:, 4] = (rgbal[:, 0] +
                   rgbal[:, 1] * (2 ** 8) +
                   rgbal[:, 2] * (2 ** 16))
    annot_path = 'c.annot'
    with InTemporaryDirectory():
        write_annot(annot_path, labels, rgbal, names, fill_ctab=False)
        labels2, rgbal2, names2 = read_annot(annot_path)
        names2 = [n.decode('ascii') for n in names2]
        # Compare exactly: the annotation values in column 4 reach ~2**24,
        # where np.isclose's default relative tolerance (1e-5) would accept
        # errors of more than 100.
        assert_array_equal(rgbal2, rgbal)
        assert_array_equal(labels2, labels)
        assert names2 == names
def test_write_annot_fill_ctab():
    """Test the `fill_ctab` parameter to :func:`.write_annot`.

    With ``fill_ctab=True`` the annotation-value column is generated.  With
    ``fill_ctab=False`` a warning is expected only when the supplied
    annotation values disagree with the packed RGB colours.
    """
    nvertices = 10
    nlabels = 3
    names = [f'label {l}' for l in range(1, nlabels + 1)]
    labels = list(range(nlabels)) + \
        list(np.random.randint(0, nlabels, nvertices - nlabels))
    labels = np.array(labels, dtype=np.int32)
    np.random.shuffle(labels)
    rgba = np.array(np.random.randint(0, 255, (nlabels, 4)), dtype=np.int32)
    annot_path = 'c.annot'
    bad_annot_msg = f'Annotation values in {annot_path} will be incorrect'
    with InTemporaryDirectory():
        write_annot(annot_path, labels, rgba, names, fill_ctab=True)
        labels2, rgbal2, names2 = read_annot(annot_path)
        names2 = [n.decode('ascii') for n in names2]
        assert_array_equal(rgbal2[:, :4], rgba)
        assert_array_equal(labels2, labels)
        assert names2 == names
        # make sure a warning is emitted if fill_ctab is False, and the
        # annotation values are wrong. Use orig_ids=True so we get those bad
        # values back.
        badannot = (10 * np.arange(nlabels, dtype=np.int32)).reshape(-1, 1)
        rgbal = np.hstack((rgba, badannot))
        with pytest.warns(UserWarning, match=bad_annot_msg):
            write_annot(annot_path, labels, rgbal, names, fill_ctab=False)
        labels2, rgbal2, names2 = read_annot(annot_path, orig_ids=True)
        names2 = [n.decode('ascii') for n in names2]
        assert_array_equal(rgbal2[:, :4], rgba)
        assert_array_equal(labels2, badannot[labels].squeeze())
        assert names2 == names
        # make sure a warning is *not* emitted if fill_ctab is False, but the
        # annotation values are correct.
        rgbal = np.hstack((rgba, np.zeros((nlabels, 1), dtype=np.int32)))
        rgbal[:, 4] = (rgbal[:, 0] +
                       rgbal[:, 1] * (2 ** 8) +
                       rgbal[:, 2] * (2 ** 16))
        with clear_and_catch_warnings() as w:
            # Record every warning regardless of the active filters, so the
            # absence check below cannot pass merely because the warning
            # was filtered out.
            warnings.simplefilter('always')
            write_annot(annot_path, labels, rgbal, names, fill_ctab=False)
        assert not any(bad_annot_msg in str(ww.message) for ww in w)
        labels2, rgbal2, names2 = read_annot(annot_path)
        names2 = [n.decode('ascii') for n in names2]
        assert_array_equal(rgbal2[:, :4], rgba)
        assert_array_equal(labels2, labels)
        assert names2 == names
def test_read_annot_old_format():
    """Test reading an old-style .annot file.

    A file in the legacy layout is assembled byte-by-byte.  ``read_annot``
    must then recover the per-vertex labels, the RGBA colours and the label
    names.
    """
    def write_legacy_annot(path, n_vertices, vertex_labels, colours,
                           label_names):
        """Write an old-style .annot file built from the given arrays."""
        int_be = '>i'  # every integer field is big-endian
        n_entries = colours.shape[0]
        # One (vertex index, packed RGB annotation value) row per vertex.
        vertex_block = np.zeros((n_vertices, 2), dtype=int_be)
        vertex_block[:, 0] = np.arange(n_vertices)
        vertex_block[:, [1]] = _pack_rgb(colours[vertex_labels, :3])
        pieces = [
            struct.pack(int_be, n_vertices),        # number of vertices
            vertex_block.astype(int_be).tobytes(),  # vertices + annotation values
            struct.pack(int_be, 1),                 # is there a colour table? (yes)
            struct.pack(int_be, n_entries),         # number of colour-table entries
            struct.pack(int_be, 5),                 # length of orig_tab string
            b'abcd\x00',                            # the orig_tab string itself
        ]
        for idx, colour in enumerate(colours):
            entry_name = label_names[idx]
            # length of entry name (+1 for terminating byte)
            pieces.append(struct.pack(int_be, len(entry_name) + 1))
            pieces.append(entry_name.encode('ascii') + b'\x00')
            pieces.append(colour.astype(int_be).tobytes())
        with open(path, 'wb') as handle:
            handle.write(b''.join(pieces))

    n_vertices = 10
    n_labels = 3
    label_names = [f'Label {idx}' for idx in range(n_labels)]
    # Every label value appears at least once; the remaining vertices are random.
    vertex_labels = np.concatenate((
        np.arange(n_labels),
        np.random.randint(0, n_labels, n_vertices - n_labels)))
    np.random.shuffle(vertex_labels)
    colours = np.random.randint(0, 255, (n_labels, 4))
    with InTemporaryDirectory():
        # write an old .annot file, then read it back
        write_legacy_annot('blah.annot', n_vertices, vertex_labels, colours,
                           label_names)
        read_labels, read_ctab, read_names = read_annot('blah.annot')
    read_names = [name.decode('ascii') for name in read_names]
    assert np.isclose(vertex_labels, read_labels).all()
    assert np.isclose(colours, read_ctab[:, :4]).all()
    assert label_names == read_names
@freesurfer_test
def test_label():
    """Test IO of .label files against ``label/lh.cortex.label``.

    Checks that the vertex indices fall in ``[0, 163841]`` and that there
    are at most 163842 of them.  It also checks that reading with the second
    argument set (presumably ``read_scalars``) returns the same indices plus
    one scalar per index.
    """
    label_path = pjoin(data_path, "label", "lh.cortex.label")
    label = read_label(label_path)
    # XXX : test more
    assert label.min() >= 0
    assert label.max() <= 163841
    assert label.shape[0] <= 163842

    labels, scalars = read_label(label_path, True)
    # assert_array_equal also checks shapes, unlike np.all(a == b), which
    # would silently broadcast.
    assert_array_equal(labels, label)
    assert len(labels) == len(scalars)
def test_write_annot_maxstruct():
    """Round-trip an .annot whose labels use only part of its colour table.

    Every vertex carries label 1, so the largest label present is smaller
    than the number of colour-table entries.  The file must still read back
    with identical labels, colours and names.
    """
    n_labels = 3
    label_names = [f'label {idx}' for idx in range(1, n_labels + 1)]
    # max label < n_labels
    vertex_labels = np.array([1, 1, 1], dtype=np.int32)
    colours = np.array(np.random.randint(0, 255, (n_labels, 4)), dtype=np.int32)
    out_path = 'c.annot'
    with InTemporaryDirectory():
        write_annot(out_path, vertex_labels, colours, label_names)
        # Validate the file can be read
        read_labels, read_ctab, read_names = read_annot(out_path)
    # Check round-trip
    assert np.array_equal(vertex_labels, read_labels)
    assert np.array_equal(colours, read_ctab[:, :4])
    assert label_names == [name.decode('ascii') for name in read_names]