1import logging
2import os
3import numpy.testing as npt
4from dipy.data import get_fnames
5from dipy.io.image import load_nifti
6from dipy.testing import assert_true
7from dipy.data.fetcher import dipy_home
8from dipy.workflows.io import IoInfoFlow, FetchFlow, SplitFlow
9from nibabel.tmpdirs import TemporaryDirectory
10from os.path import join as pjoin
11from tempfile import mkstemp
# Log to a dedicated temporary file so test_io_info can read back what the
# workflows emitted.  mkstemp() hands back an already-open OS-level file
# descriptor together with the path; only the path is used here (logging
# reopens the file by name), so close the descriptor instead of leaking it.
_log_fd, fname_log = mkstemp()
os.close(_log_fd)

# Route INFO-and-above records to that file, truncating it first.  Note that
# basicConfig() does nothing when the root logger already has handlers
# (e.g. under some IDE/test-runner setups), which is why test_io_info
# tolerates a log that is too short.
logging.basicConfig(level=logging.INFO,
                    format='%(levelname)s %(message)s',
                    filename=fname_log,
                    filemode='w')
18
19
def test_io_info():
    """Run IoInfoFlow on two small datasets and check the logged summary.

    The flow is run on 'small_101D', then on 'small_25' twice (the second
    time with ``b0_threshold=20`` and ``bvecs_tol=0.001``).  The module-level
    log file is then read back, and its third-to-last line is expected to
    report 25 unit bvectors.  If the log is too short (logging may be
    disabled or redirected), the check is skipped.
    """
    fimg, fbvals, fbvecs = get_fnames('small_101D')
    io_info_flow = IoInfoFlow()
    io_info_flow.run([fimg, fbvals, fbvecs])

    fimg, fbvals, fbvecs = get_fnames('small_25')
    io_info_flow = IoInfoFlow()
    io_info_flow.run([fimg, fbvals, fbvecs])

    io_info_flow = IoInfoFlow()
    io_info_flow.run([fimg, fbvals, fbvecs], b0_threshold=20, bvecs_tol=0.001)

    # Read the whole log and close it before asserting, so that a failing
    # assertion can no longer leak the open file handle.
    with open(fname_log, 'r') as log_file:
        lines = log_file.readlines()

    # Only the indexing can raise IndexError; keep the try body to that.
    try:
        summary_line = lines[-3]
    except IndexError:  # logging may be disabled (e.g. in an IDE setting)
        pass
    else:
        npt.assert_equal(summary_line,
                         'INFO Total number of unit bvectors 25\n')
39
40
def test_io_fetch():
    """Fetch 'bundle_fa_hcp' to the default location and to a custom one.

    The first run passes no ``out_dir``; the test expects the dataset folder
    to then exist under ``dipy_home``.  The second run passes a temporary
    ``out_dir`` and expects the folder to appear there instead.
    NOTE(review): this presumably needs network access or a warm cache.
    """
    dataset = 'bundle_fa_hcp'
    fetch_flow = FetchFlow()
    with TemporaryDirectory() as out_dir:
        # (extra keyword arguments for run, directory expected to hold data)
        scenarios = (({}, dipy_home),
                     ({'out_dir': out_dir}, out_dir))
        for run_kwargs, parent_dir in scenarios:
            fetch_flow.run([dataset], **run_kwargs)
            fetched_dir = pjoin(parent_dir, dataset)
            npt.assert_equal(os.path.isdir(fetched_dir), True)
54
55
def test_io_fetch_fetcher_datanames():
    """Check that FetchFlow exposes exactly the expected dataset fetchers.

    ``FetchFlow.get_fetcher_datanames()`` returns a mapping keyed by dataset
    name.  Its keys must match ``dataset_names`` exactly: nothing missing
    and nothing extra.
    """
    available_data = FetchFlow.get_fetcher_datanames()

    dataset_names = ['bundle_atlas_hcp842', 'bundle_fa_hcp',
                     'bundles_2_subjects', 'cenir_multib', 'cfin_multib',
                     'file_formats', 'fury_surface',
                     'gold_standard_io', 'isbi2013_2shell',
                     'ivim', 'mni_template', 'qtdMRI_test_retest_2subjects',
                     'scil_b0', 'sherbrooke_3shell', 'stanford_hardi',
                     'stanford_labels', 'stanford_pve_maps', 'stanford_t1',
                     'syn_data', 'taiwan_ntu_dsi', 'target_tractogram_hcp',
                     'tissue_data']

    # Set equality is equivalent to "same count and every expected name is
    # present", because mapping keys and the expected names are both unique.
    # Unlike asserting all(...) == True, it reports which names differ.
    expected = set(dataset_names)
    available = set(available_data)
    missing = sorted(expected - available)
    unexpected = sorted(available - expected)
    npt.assert_equal(missing, [],
                     err_msg='No fetcher for datasets: {}'.format(missing))
    npt.assert_equal(unexpected, [],
                     err_msg='Unexpected fetchers: {}'.format(unexpected))
73
74
def test_split_flow():
    """Split a volume out of the default dataset with SplitFlow.

    A first run with default arguments must produce an output file.  A second
    run, with overwriting enabled and ``vol_idx=0``, must write a file whose
    data shape equals ``input[..., 0].shape`` and whose affine matches the
    input affine.
    """
    with TemporaryDirectory() as out_dir:
        split_flow = SplitFlow()
        data_path, _, _ = get_fnames()
        input_data, input_affine = load_nifti(data_path)

        # Phase 1: default arguments -- only check that something was written.
        split_flow.run(data_path, out_dir=out_dir)
        first_output = split_flow.last_generated_outputs['out_split']
        assert_true(os.path.isfile(first_output))

        # Phase 2: explicit vol_idx=0 into the same out_dir; overwriting is
        # presumably required because phase 1 already wrote there.
        split_flow._force_overwrite = True
        split_flow.run(data_path, vol_idx=0, out_dir=out_dir)
        second_output = split_flow.last_generated_outputs['out_split']
        assert_true(os.path.isfile(second_output))

        out_data, out_affine = load_nifti(second_output)
        npt.assert_equal(out_data.shape, input_data[..., 0].shape)
        npt.assert_array_almost_equal(out_affine, input_affine)
90
91
if __name__ == '__main__':
    # When executed directly as a script, run every test of this module
    # sequentially, fetch tests first.
    for _test_func in (test_io_fetch,
                       test_io_fetch_fetcher_datanames,
                       test_io_info,
                       test_split_flow):
        _test_func()
97