1"""
2Test the datasets module
3"""
4# Author: Alexandre Abraham
5# License: simplified BSD
6
7import os
8import shutil
9import itertools
10from pathlib import Path
11import re
12
13import numpy as np
14import pandas as pd
15
16import nibabel
17import pytest
18
19from numpy.testing import assert_array_equal
20
21
22from nilearn.datasets import utils, atlas
23from nilearn.image import get_data
24from nilearn._utils.testing import serialize_niimg
25from nilearn.datasets._testing import dict_to_archive
26from nilearn._utils import data_gen


def test_get_dataset_dir(tmp_path):
    # Test folder creation under different environments, enforcing
    # a custom clean install each time.
    os.environ.pop('NILEARN_DATA', None)
    os.environ.pop('NILEARN_SHARED_DATA', None)

    expected_base_dir = os.path.expanduser('~/nilearn_data')
    data_dir = utils._get_dataset_dir('test', verbose=0)
    assert data_dir == os.path.join(expected_base_dir, 'test')
    assert os.path.exists(data_dir)
    shutil.rmtree(data_dir)

    expected_base_dir = str(tmp_path / 'test_nilearn_data')
    os.environ['NILEARN_DATA'] = expected_base_dir
    data_dir = utils._get_dataset_dir('test', verbose=0)
    assert data_dir == os.path.join(expected_base_dir, 'test')
    assert os.path.exists(data_dir)
    shutil.rmtree(data_dir)

    expected_base_dir = str(tmp_path / 'nilearn_shared_data')
    os.environ['NILEARN_SHARED_DATA'] = expected_base_dir
    data_dir = utils._get_dataset_dir('test', verbose=0)
    assert data_dir == os.path.join(expected_base_dir, 'test')
    assert os.path.exists(data_dir)
    shutil.rmtree(data_dir)

    expected_base_dir = str(tmp_path / 'env_data')
    expected_dataset_dir = os.path.join(expected_base_dir, 'test')
    data_dir = utils._get_dataset_dir(
        'test', default_paths=[expected_dataset_dir], verbose=0)
    assert data_dir == os.path.join(expected_base_dir, 'test')
    assert os.path.exists(data_dir)
    shutil.rmtree(data_dir)

    no_write = str(tmp_path / 'no_write')
    os.makedirs(no_write)
    os.chmod(no_write, 0o400)

    expected_base_dir = str(tmp_path / 'nilearn_shared_data')
    os.environ['NILEARN_SHARED_DATA'] = expected_base_dir
    data_dir = utils._get_dataset_dir('test',
                                      default_paths=[no_write],
                                      verbose=0)
    # The non-writeable dir is returned because the dataset may already
    # be in there.
    assert data_dir == no_write
    assert os.path.exists(data_dir)
    # Restore write permissions so the directory can be removed.
    os.chmod(no_write, 0o600)
    shutil.rmtree(data_dir)

    # Verify that an exception is raised when the path exists and is a file.
    test_file = str(tmp_path / 'some_file')
    with open(test_file, 'w') as out:
        out.write('abcfeg')
    with pytest.raises(OSError, match=('Nilearn tried to store the dataset '
                                       'in the following directories, but')
                       ):
        utils._get_dataset_dir('test', test_file, verbose=0)


def test_downloader(tmp_path, request_mocker):
    # Sandboxing test
    # ===============

    # When nilearn downloads a file, everything is first downloaded to a
    # temporary directory (the sandbox) and only moved to the "real" data
    # directory once all files are present. In case of error, the sandbox
    # is deleted.

    # To test this feature, we proceed as follows:
    # - create the data dir with a file that has specific content
    # - try to download the dataset, but make it fail on purpose
    #   (by requesting a file that is not in the archive)
    # - check that the previously created file is untouched:
    #   - if sandboxing is faulty, the file would be replaced by the file
    #     from the archive
    #   - if sandboxing works, the file must be untouched.
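
    # A minimal sketch of the sandboxing pattern under test (illustrative
    # only, not nilearn's actual implementation; `download` and
    # `expected_files` are hypothetical names):
    #
    #     with tempfile.TemporaryDirectory() as sandbox:
    #         for fname in expected_files:
    #             download(url, os.path.join(sandbox, fname))  # may raise
    #         # Reached only if every download succeeded:
    #         for fname in expected_files:
    #             shutil.move(os.path.join(sandbox, fname), data_dir)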

    local_archive = (Path(__file__).parent
                     / "data" / "craddock_2011_parcellations.tar.gz")
    url = "http://example.com/craddock_atlas"
    request_mocker.url_mapping["*craddock*"] = local_archive
    datasetdir = tmp_path / 'craddock_2012'
    datasetdir.mkdir()

    # Create a dummy file. If sandboxing is successful, it won't be
    # overwritten.
    dummy_file = datasetdir / "random_all.nii.gz"
    with dummy_file.open("w") as f:
        f.write('stuff')

    opts = {'uncompress': True}
    files = [
        ('random_all.nii.gz', url, opts),
        # The following file does not exist. It will cause the fetching
        # procedure to abort.
        ('bald.nii.gz', url, opts)
    ]

    pytest.raises(IOError, utils._fetch_files,
                  str(tmp_path / 'craddock_2012'), files,
                  verbose=0)
    with dummy_file.open("r") as f:
        stuff = f.read(5)
    assert stuff == 'stuff'

    # Downloading test
    # ================

    # Now, we use the regular downloading feature. This will overwrite the
    # dummy file created before.

    atlas.fetch_atlas_craddock_2012(data_dir=tmp_path)
    with dummy_file.open() as f:
        stuff = f.read()
    assert stuff == ''


def test_fetch_atlas_source(tmp_path, request_mocker):
    # Specify a non-existing atlas source.
    with pytest.raises(ValueError, match='Atlas source'):
        atlas._get_atlas_data_and_labels('new_source', 'not_inside')


def _write_to_xml(ho_dir, filename, is_symm):
    with open(os.path.join(ho_dir, filename + '.xml'), 'w') as dm:
        if not is_symm:
            dm.write("<?xml version='1.0' encoding='us-ascii'?>\n"
                     "<data>\n"
                     '<label index="0" x="48" y="94" z="35">R1</label>\n'
                     '<label index="1" x="25" y="70" z="32">R2</label>\n'
                     '<label index="2" x="33" y="73" z="63">R3</label>\n'
                     "</data>")
        else:
            dm.write("<?xml version='1.0' encoding='us-ascii'?>\n"
                     "<data>\n"
                     '<label index="0" x="63" y="86" z="49">Left R1</label>\n'
                     '<label index="1" x="21" y="86" z="33">Right R1</label>\n'
                     '<label index="2" x="64" y="69" z="32">Left R2</label>\n'
                     '<label index="3" x="26" y="70" z="32">Right R2</label>\n'
                     '<label index="4" x="47" y="75" z="66">Left R3</label>\n'
                     '<label index="5" x="43" y="80" z="61">Right R3</label>\n'
                     "</data>")


def _test_result_xml(res, is_symm):
    assert isinstance(res.maps, nibabel.Nifti1Image)
    assert isinstance(res.labels, list)
    if not is_symm:
        assert res.labels == ["Background", "R1", "R2", "R3"]
    else:
        assert res.labels == ["Background",
                              "Left R1", "Right R1",
                              "Left R2", "Right R2",
                              "Left R3", "Right R3"]


@pytest.fixture
def fsl_fetcher(name):
    # `name` is the parametrized test argument; pytest resolves it when
    # this fixture is requested by a parametrized test.
    if name == "Juelich":
        return atlas.fetch_atlas_juelich
    return atlas.fetch_atlas_harvard_oxford


@pytest.mark.parametrize('name,prob',
                         [("HarvardOxford", "cortl-prob-1mm"),
                          ("Juelich", "prob-1mm")])
def test_fetch_atlas_fsl_errors(name, prob, fsl_fetcher,
                                tmp_path, request_mocker):
    # Specify a non-existing atlas item.
    with pytest.raises(ValueError, match='Invalid atlas name'):
        fsl_fetcher('not_inside')
    # Choose a probabilistic atlas with symmetric split.
    with pytest.raises(ValueError, match='Region splitting'):
        fsl_fetcher(prob, data_dir=str(tmp_path), symmetric_split=True)


@pytest.fixture
def atlas_data():
    # Create a fake atlas.
    atlas_data = np.zeros((10, 10, 10), dtype=int)
    # Create an interhemispheric map.
    atlas_data[:, :2, :] = 1
    # Create a left map.
    atlas_data[5:, 7:9, :] = 3
    atlas_data[5:, 3, :] = 2
    # Create a right map, with one voxel on the left side.
    atlas_data[:5, 3:5, :] = 2
    atlas_data[:5, 8, :] = 3
    atlas_data[4, 7, 0] = 3
    return atlas_data


@pytest.mark.parametrize('name,label_fname,fname,is_symm,split',
                         [("HarvardOxford", "-Cortical",
                           "cort-prob-1mm", False, False),
                          ("HarvardOxford", "-Subcortical",
                           "sub-maxprob-thr0-1mm", False, True),
                          ("HarvardOxford", "-Cortical-Lateralized",
                           "cortl-maxprob-thr0-1mm", True, True),
                          ("Juelich", "", "prob-1mm", False, False),
                          ("Juelich", "", "maxprob-thr0-1mm", False, False),
                          ("Juelich", "", "maxprob-thr0-1mm", False, True)])
def test_fetch_atlas_fsl(name, label_fname, fname, is_symm, split,
                         atlas_data, fsl_fetcher, tmp_path, request_mocker):
    ho_dir = str(tmp_path / 'fsl' / 'data' / 'atlases')
    os.makedirs(ho_dir)
    nifti_dir = os.path.join(ho_dir, name)
    os.makedirs(nifti_dir)
    _write_to_xml(ho_dir, f"{name}{label_fname}", is_symm=is_symm)
    target_atlas_fname = f'{name}-{fname}.nii.gz'
    target_atlas_nii = os.path.join(nifti_dir, target_atlas_fname)
    nibabel.Nifti1Image(atlas_data, np.eye(4) * 3).to_filename(
        target_atlas_nii)
    result = fsl_fetcher(fname, data_dir=str(tmp_path), symmetric_split=split)
    _test_result_xml(result, is_symm=is_symm or split)


def test_fetch_atlas_craddock_2012(tmp_path, request_mocker):
    local_archive = (Path(__file__).parent
                     / "data" / "craddock_2011_parcellations.tar.gz")
    request_mocker.url_mapping["*craddock*"] = local_archive
    bunch = atlas.fetch_atlas_craddock_2012(data_dir=tmp_path,
                                            verbose=0)

    keys = ("scorr_mean", "tcorr_mean",
            "scorr_2level", "tcorr_2level",
            "random")
    filenames = [
        "scorr05_mean_all.nii.gz",
        "tcorr05_mean_all.nii.gz",
        "scorr05_2level_all.nii.gz",
        "tcorr05_2level_all.nii.gz",
        "random_all.nii.gz",
    ]
    assert request_mocker.url_count == 1
    for key, fn in zip(keys, filenames):
        assert bunch[key] == str(tmp_path / 'craddock_2012' / fn)
    assert bunch.description != ''


def test_fetch_atlas_smith_2009(tmp_path, request_mocker):
    bunch = atlas.fetch_atlas_smith_2009(data_dir=tmp_path, verbose=0)

    keys = ("rsn20", "rsn10", "rsn70",
            "bm20", "bm10", "bm70")
    filenames = [
        "rsn20.nii.gz",
        "PNAS_Smith09_rsn10.nii.gz",
        "rsn70.nii.gz",
        "bm20.nii.gz",
        "PNAS_Smith09_bm10.nii.gz",
        "bm70.nii.gz",
    ]

    assert request_mocker.url_count == 6
    for key, fn in zip(keys, filenames):
        assert bunch[key] == str(tmp_path / 'smith_2009' / fn)
    assert bunch.description != ''


def test_fetch_coords_power_2011(request_mocker):
    bunch = atlas.fetch_coords_power_2011()
    assert len(bunch.rois) == 264
    assert bunch.description != ''


def test_fetch_coords_seitzman_2018(request_mocker):
    bunch = atlas.fetch_coords_seitzman_2018()
    assert len(bunch.rois) == 300
    assert len(bunch.radius) == 300
    assert len(bunch.networks) == 300
    assert len(bunch.regions) == 300
    assert len(np.unique(bunch.networks)) == 14
    assert len(np.unique(bunch.regions)) == 8
    assert_array_equal(bunch.networks, np.sort(bunch.networks))
    assert bunch.description != ''

    assert bunch.regions[0] == "cortexL"

    bunch = atlas.fetch_coords_seitzman_2018(ordered_regions=False)
    assert np.any(bunch.networks != np.sort(bunch.networks))


def _destrieux_data():
    data = {"destrieux2009.rst": "readme"}
    for lat in ["_lateralized", ""]:
        lat_data = {
            "destrieux2009_rois_labels{}.csv".format(lat): "name,index",
            "destrieux2009_rois{}.nii.gz".format(lat): "",
        }
        data.update(lat_data)
    return dict_to_archive(data)
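
# `dict_to_archive` turns a {path: content} mapping like the one above into
# in-memory archive bytes that request_mocker can serve. A rough stdlib
# sketch of the tgz case (illustrative only, not the helper's actual
# implementation):
#
#     buf = io.BytesIO()
#     with tarfile.open(fileobj=buf, mode="w:gz") as tar:
#         for path, content in data.items():
#             payload = content.encode()
#             info = tarfile.TarInfo(str(path))
#             info.size = len(payload)
#             tar.addfile(info, io.BytesIO(payload))
#     archive_bytes = buf.getvalue()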


def test_fetch_atlas_destrieux_2009(tmp_path, request_mocker):
    request_mocker.url_mapping["*destrieux2009.tgz"] = _destrieux_data()
    bunch = atlas.fetch_atlas_destrieux_2009(data_dir=tmp_path,
                                             verbose=0)

    assert request_mocker.url_count == 1
    assert bunch['maps'] == str(tmp_path / 'destrieux_2009'
                                / 'destrieux2009_rois_lateralized.nii.gz')

    bunch = atlas.fetch_atlas_destrieux_2009(
        lateralized=False, data_dir=tmp_path, verbose=0)

    assert request_mocker.url_count == 1
    assert bunch['maps'] == str(tmp_path / 'destrieux_2009'
                                / 'destrieux2009_rois.nii.gz')


def test_fetch_atlas_msdl(tmp_path, request_mocker):
    labels = pd.DataFrame(
        {"x": [1.5, 1.2], "y": [1.5, 1.3],
         "z": [1.5, 1.4], "name": ["Aud", "DMN"], "net_name": ["Aud", "DMN"]})
    root = Path("MSDL_rois")
    archive = {root / "msdl_rois_labels.csv": labels.to_csv(index=False),
               root / "msdl_rois.nii": "",
               root / "README.txt": ""}
    request_mocker.url_mapping["*MSDL_rois.zip"] = dict_to_archive(
        archive, "zip")
    dataset = atlas.fetch_atlas_msdl(data_dir=tmp_path, verbose=0)
    assert isinstance(dataset.labels, list)
    assert isinstance(dataset.region_coords, list)
    assert isinstance(dataset.networks, list)
    assert isinstance(dataset.maps, str)
    assert request_mocker.url_count == 1
    assert dataset.description != ''


def test_fetch_atlas_yeo_2011(tmp_path, request_mocker):
    dataset = atlas.fetch_atlas_yeo_2011(data_dir=tmp_path, verbose=0)
    assert isinstance(dataset.anat, str)
    assert isinstance(dataset.colors_17, str)
    assert isinstance(dataset.colors_7, str)
    assert isinstance(dataset.thick_17, str)
    assert isinstance(dataset.thick_7, str)
    assert isinstance(dataset.thin_17, str)
    assert isinstance(dataset.thin_7, str)
    assert request_mocker.url_count == 1
    assert dataset.description != ''


def test_fetch_atlas_difumo(tmp_path, request_mocker):
    resolutions = [2, 3]  # Valid resolution values
    dimensions = [64, 128, 256, 512, 1024]  # Valid dimension values
    dimension_urls = ['pqu9r', 'wjvd5', '3vrct', '9b76y', '34792']
    url_mapping = dict(zip(dimensions, dimension_urls))
    url_count = 1

    for dim in dimensions:
        url_count += 1
        url = "*osf.io/{0}/*".format(url_mapping[dim])
        labels = pd.DataFrame(
            {"Component": list(range(1, dim + 1)),
             "Difumo_names": [""] * dim,
             "Yeo_networks7": [""] * dim,
             "Yeo_networks17": [""] * dim,
             "GM": [""] * dim,
             "WM": [""] * dim,
             "CSF": [""] * dim}
        )
        root = Path(str(dim))
        archive = {
            root / "labels_{0}_dictionary.csv".format(dim):
                labels.to_csv(index=False),
            root / "2mm" / "maps.nii.gz": "",
            root / "3mm" / "maps.nii.gz": "",
        }
        request_mocker.url_mapping[url] = dict_to_archive(archive, "zip")

        for res in resolutions:
            dataset = atlas.fetch_atlas_difumo(data_dir=tmp_path,
                                               dimension=dim,
                                               resolution_mm=res,
                                               verbose=0)
            assert len(dataset.keys()) == 3
            assert len(dataset.labels) == dim
            assert isinstance(dataset.maps, str)
            assert request_mocker.url_count == url_count
            assert dataset.description != ''

    # Invalid dimension.
    with pytest.raises(ValueError):
        atlas.fetch_atlas_difumo(data_dir=tmp_path,
                                 dimension=42, resolution_mm=3)
    # Invalid resolution_mm. This needs its own `raises` block: a second
    # call placed after the first one in the same block would never run.
    with pytest.raises(ValueError):
        atlas.fetch_atlas_difumo(data_dir=tmp_path,
                                 dimension=128, resolution_mm=3.14)


def test_fetch_atlas_aal(tmp_path, request_mocker):
    metadata = (b"<?xml version='1.0' encoding='us-ascii'?>"
                b"<metadata></metadata>")
    archive_root = Path("aal", "atlas")
    aal_data = dict_to_archive(
        {archive_root / "AAL.xml": metadata, archive_root / "AAL.nii": ""})

    request_mocker.url_mapping["*AAL_files*"] = aal_data
    dataset = atlas.fetch_atlas_aal(data_dir=tmp_path, verbose=0)
    assert isinstance(dataset.maps, str)
    assert isinstance(dataset.labels, list)
    assert isinstance(dataset.indices, list)
    assert request_mocker.url_count == 1

    with pytest.raises(ValueError,
                       match='The version of AAL requested "FLS33"'
                       ):
        atlas.fetch_atlas_aal(version="FLS33",
                              data_dir=tmp_path,
                              verbose=0)

    assert dataset.description != ''


def test_fetch_atlas_basc_multiscale_2015(tmp_path, request_mocker):
    # Default version='sym'.
    data_sym = atlas.fetch_atlas_basc_multiscale_2015(data_dir=tmp_path,
                                                      verbose=0)
    # version='asym'
    data_asym = atlas.fetch_atlas_basc_multiscale_2015(version='asym',
                                                       verbose=0,
                                                       data_dir=tmp_path)

    keys = ['scale007', 'scale012', 'scale020', 'scale036', 'scale064',
            'scale122', 'scale197', 'scale325', 'scale444']

    dataset_name = 'basc_multiscale_2015'
    name_sym = 'template_cambridge_basc_multiscale_nii_sym'
    basenames_sym = ['template_cambridge_basc_multiscale_sym_' +
                     key + '.nii.gz' for key in keys]
    for key, basename_sym in zip(keys, basenames_sym):
        assert data_sym[key] == str(tmp_path / dataset_name / name_sym
                                    / basename_sym)

    name_asym = 'template_cambridge_basc_multiscale_nii_asym'
    basenames_asym = ['template_cambridge_basc_multiscale_asym_' +
                      key + '.nii.gz' for key in keys]
    for key, basename_asym in zip(keys, basenames_asym):
        assert data_asym[key] == str(tmp_path / dataset_name / name_asym
                                     / basename_asym)

    assert len(data_sym) == 10
    with pytest.raises(
            ValueError,
            match='The version of Brain parcellations requested "aym"'):
        atlas.fetch_atlas_basc_multiscale_2015(version="aym",
                                               data_dir=tmp_path,
                                               verbose=0)

    assert request_mocker.url_count == 2
    assert data_sym.description != ''
    assert data_asym.description != ''


def test_fetch_coords_dosenbach_2010(request_mocker):
    bunch = atlas.fetch_coords_dosenbach_2010()
    assert len(bunch.rois) == 160
    assert len(bunch.labels) == 160
    assert len(np.unique(bunch.networks)) == 6
    assert bunch.description != ''
    assert_array_equal(bunch.networks, np.sort(bunch.networks))

    bunch = atlas.fetch_coords_dosenbach_2010(ordered_regions=False)
    assert np.any(bunch.networks != np.sort(bunch.networks))


def test_fetch_atlas_allen_2011(tmp_path, request_mocker):
    bunch = atlas.fetch_atlas_allen_2011(data_dir=tmp_path, verbose=0)
    keys = ("maps",
            "rsn28",
            "comps")

    filenames = ["ALL_HC_unthresholded_tmaps.nii.gz",
                 "RSN_HC_unthresholded_tmaps.nii.gz",
                 "rest_hcp_agg__component_ica_.nii.gz"]

    assert request_mocker.url_count == 1
    for key, fn in zip(keys, filenames):
        assert bunch[key] == str(tmp_path / 'allen_rsn_2011'
                                 / 'allen_rsn_2011' / fn)

    assert bunch.description != ''


def test_fetch_atlas_surf_destrieux(tmp_path, request_mocker):
    data_dir = str(tmp_path / 'destrieux_surface')
    os.mkdir(data_dir)
    # Create mock annots.
    for hemi in ('left', 'right'):
        nibabel.freesurfer.write_annot(
                os.path.join(data_dir,
                             '%s.aparc.a2009s.annot' % hemi),
                np.arange(4), np.zeros((4, 5)), 5 * ['a'],)

    bunch = atlas.fetch_atlas_surf_destrieux(data_dir=tmp_path, verbose=0)
    # Our mock annots have 4 labels.
    assert len(bunch.labels) == 4
    assert bunch.map_left.shape == (4, )
    assert bunch.map_right.shape == (4, )
    assert bunch.description != ''


def _get_small_fake_talairach():
    labels = ['*', 'b', 'a']
    all_labels = itertools.product(*(labels,) * 5)
    labels_txt = '\n'.join(map('.'.join, all_labels))
    extensions = nibabel.nifti1.Nifti1Extensions([
        nibabel.nifti1.Nifti1Extension(
            'afni', labels_txt.encode('utf-8'))
    ])
    img = nibabel.Nifti1Image(
        np.arange(243).reshape((3, 9, 9)),
        np.eye(4), nibabel.Nifti1Header(extensions=extensions))
    return serialize_niimg(img, gzipped=False)
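
# The fake image above packs all 3 ** 5 = 243 combinations of the five
# Talairach levels into an 'afni' NIfTI extension, one dot-joined string per
# voxel value (with '*' standing for "Background"). A sketch of the decoding
# that the test below relies on, assuming the level order (hemisphere, lobe,
# gyrus, tissue, ba); `labels_txt` refers to the helper's variable:
#
#     level_index = ["hemisphere", "lobe", "gyrus",
#                    "tissue", "ba"].index("hemisphere")
#     level_labels = [combo.split(".")[level_index]
#                     for combo in labels_txt.split("\n")]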


def test_fetch_atlas_talairach(tmp_path, request_mocker):
    request_mocker.url_mapping["*talairach.nii"] = _get_small_fake_talairach()
    level_values = np.ones((81, 3)) * [0, 1, 2]
    talairach = atlas.fetch_atlas_talairach('hemisphere',
                                            data_dir=tmp_path)
    assert_array_equal(get_data(talairach.maps).ravel(),
                       level_values.T.ravel())
    assert_array_equal(talairach.labels, ['Background', 'b', 'a'])
    talairach = atlas.fetch_atlas_talairach('ba', data_dir=tmp_path)
    assert_array_equal(get_data(talairach.maps).ravel(),
                       level_values.ravel())
    pytest.raises(ValueError, atlas.fetch_atlas_talairach, 'bad_level')


def test_fetch_atlas_pauli_2017(tmp_path, request_mocker):
    labels = pd.DataFrame(
        {"label": list(map("label_{}".format, range(16)))}).to_csv(
            sep="\t", header=False)
    det_atlas = data_gen.generate_labeled_regions((7, 6, 5), 16)
    prob_atlas, _ = data_gen.generate_maps((7, 6, 5), 16)
    request_mocker.url_mapping["*osf.io/6qrcb/*"] = labels
    request_mocker.url_mapping["*osf.io/5mqfx/*"] = det_atlas
    request_mocker.url_mapping["*osf.io/w8zq2/*"] = prob_atlas
    data_dir = str(tmp_path / 'pauli_2017')

    data = atlas.fetch_atlas_pauli_2017('det', data_dir)
    assert len(data.labels) == 16

    values = get_data(nibabel.load(data.maps))
    assert len(np.unique(values)) == 17

    data = atlas.fetch_atlas_pauli_2017('prob', data_dir)
    assert nibabel.load(data.maps).shape[-1] == 16

    with pytest.raises(NotImplementedError):
        atlas.fetch_atlas_pauli_2017('junk for testing', data_dir)


def _schaefer_labels(match, request):
    # request_mocker calls this with the regex match for the requested URL,
    # so the mock response can be built from the URL's captured groups.
    info = match.groupdict()
    label_names = ["{}Networks".format(info["network"])] * int(info["n_rois"])
    labels = pd.DataFrame({"label": label_names})
    return labels.to_csv(sep="\t", header=False).encode("utf-8")


def _schaefer_img(match, request):
    info = match.groupdict()
    shape = (15, 14, 13)
    affine = np.eye(4) * float(info["res"])
    affine[3, 3] = 1.
    img = data_gen.generate_labeled_regions(
        shape, int(info["n_rois"]), affine=affine)
    return serialize_niimg(img)


def test_fetch_atlas_schaefer_2018(tmp_path, request_mocker):
    labels_pattern = re.compile(
        r".*2018_(?P<n_rois>\d+)Parcels_(?P<network>\d+)Networks_order.txt")
    img_pattern = re.compile(
        r".*_(?P<n_rois>\d+)Parcels_(?P<network>\d+)"
        r"Networks_order_FSLMNI152_(?P<res>\d)mm.nii.gz")
    request_mocker.url_mapping[labels_pattern] = _schaefer_labels
    request_mocker.url_mapping[img_pattern] = _schaefer_img
    valid_n_rois = list(range(100, 1100, 100))
    valid_yeo_networks = [7, 17]
    valid_resolution_mm = [1, 2]

    pytest.raises(ValueError, atlas.fetch_atlas_schaefer_2018, n_rois=44)
    pytest.raises(ValueError, atlas.fetch_atlas_schaefer_2018,
                  yeo_networks=10)
    pytest.raises(ValueError, atlas.fetch_atlas_schaefer_2018,
                  resolution_mm=3)

    for n_rois, yeo_networks, resolution_mm in itertools.product(
            valid_n_rois, valid_yeo_networks, valid_resolution_mm):
        data = atlas.fetch_atlas_schaefer_2018(n_rois=n_rois,
                                               yeo_networks=yeo_networks,
                                               resolution_mm=resolution_mm,
                                               data_dir=tmp_path,
                                               verbose=0)
        assert data.description != ''
        assert isinstance(data.maps, str)
        assert isinstance(data.labels, np.ndarray)
        assert len(data.labels) == n_rois
        assert data.labels[0].astype(str).startswith(
            "{}Networks".format(yeo_networks))
        img = nibabel.load(data.maps)
        assert img.header.get_zooms()[0] == resolution_mm
        assert np.array_equal(np.unique(img.dataobj),
                              np.arange(n_rois + 1))