1import numpy.testing as npt
2from os.path import join
3import nibabel as nib
4import numpy as np
5from nibabel.tmpdirs import TemporaryDirectory
6from dipy.data import get_fnames
7from dipy.segment.mask import median_otsu
8from dipy.tracking.streamline import Streamlines
9from dipy.workflows.segment import MedianOtsuFlow
10from dipy.workflows.segment import RecoBundlesFlow, LabelsBundlesFlow
11from dipy.io.stateful_tractogram import Space, StatefulTractogram
12from dipy.io.streamline import load_tractogram, save_tractogram
13from dipy.io.image import load_nifti_data
14from os.path import join as pjoin
15from dipy.tracking.streamline import set_number_of_points
16from dipy.align.streamlinear import BundleMinDistanceMetric
17
18
def test_median_otsu_flow():
    """Check that MedianOtsuFlow writes the same mask/masked volume as
    calling ``median_otsu`` directly with identical parameters.

    The workflow is run on the ``small_25`` data into a temporary
    directory; the mask file it writes is compared (as uint8) with the
    mask returned by ``median_otsu``, and the masked-volume file it writes
    is rounded and compared with the returned masked volume.
    """
    with TemporaryDirectory() as out_dir:
        data_path, _, _ = get_fnames('small_25')
        volume = load_nifti_data(data_path)

        # Parameters shared by the workflow run and the direct call, so the
        # two are guaranteed to be configured identically.
        otsu_kwargs = dict(median_radius=3, numpass=3, autocrop=False,
                           vol_idx=[0], dilate=0)

        flow = MedianOtsuFlow()
        flow.run(data_path, out_dir=out_dir, save_masked=True,
                 **otsu_kwargs)
        outputs = flow.last_generated_outputs

        expected_masked, expected_mask = median_otsu(volume, **otsu_kwargs)

        written_mask = load_nifti_data(join(out_dir, outputs['out_mask']))
        npt.assert_array_equal(written_mask.astype(np.uint8), expected_mask)

        masked_img = nib.load(join(out_dir, outputs['out_masked']))
        written_masked = np.asanyarray(masked_img.dataobj)
        # Rounded before comparing, as the written volume may not hold
        # exactly the in-memory values.
        npt.assert_array_equal(np.round(written_masked), expected_masked)
51
52
def test_recobundles_flow():
    """Check RecoBundlesFlow / LabelsBundlesFlow on a synthetic bundle.

    The first 15 fornix streamlines are copied and shifted by [40, 0, 0];
    that shifted copy (``f2``) is saved on its own, and also appended to
    the full fornix (``f``). Both tractograms are fed to RecoBundlesFlow
    (presumably as whole tractogram and model bundle respectively). The
    test checks that the recognized bundle has as many streamlines as
    ``f2``, and that the bundle extracted by LabelsBundlesFlow from the
    recognized labels lies close to ``f2`` under BundleMinDistanceMetric.
    All outputs are written into a temporary directory.
    """
    with TemporaryDirectory() as out_dir:
        data_path = get_fnames('fornix')

        fornix = load_tractogram(data_path, 'same',
                                 bbox_valid_check=False).streamlines

        f = Streamlines(fornix)
        f1 = f.copy()

        # Shifted copy of the first 15 streamlines, so they are spatially
        # distinct from the originals they were copied from.
        f2 = f1[:15].copy()
        f2._data += np.array([40, 0, 0])

        # Full fornix plus the shifted copy.
        f.extend(f2)

        f2_path = pjoin(out_dir, "f2.trk")
        sft = StatefulTractogram(f2, data_path, Space.RASMM)
        save_tractogram(sft, f2_path, bbox_valid_check=False)

        # Note: f1_path stores the extended tractogram ``f``, not ``f1``.
        f1_path = pjoin(out_dir, "f1.trk")
        sft = StatefulTractogram(f, data_path, Space.RASMM)
        save_tractogram(sft, f1_path, bbox_valid_check=False)

        rb_flow = RecoBundlesFlow(force=True)
        rb_flow.run(f1_path, f2_path, greater_than=0, clust_thr=10,
                    model_clust_thr=5., reduction_thr=10, out_dir=out_dir)

        labels = rb_flow.last_generated_outputs['out_recognized_labels']
        recog_trk = rb_flow.last_generated_outputs['out_recognized_transf']

        rec_bundle = load_tractogram(recog_trk, 'same',
                                     bbox_valid_check=False).streamlines
        # Compare the counts directly so a failure reports both values.
        npt.assert_equal(len(rec_bundle), len(f2))

        # Pass out_dir (as for rb_flow above) so the extracted bundle is
        # written inside the temporary directory rather than relying on the
        # workflow's default output location.
        label_flow = LabelsBundlesFlow(force=True)
        label_flow.run(f1_path, labels, out_dir=out_dir)

        recog_bundle = label_flow.last_generated_outputs['out_bundle']
        rec_bundle_org = load_tractogram(recog_bundle, 'same',
                                         bbox_valid_check=False).streamlines

        # Resample both bundles to the same number of points, then evaluate
        # the bundle minimum distance at the parameter vector x0.
        BMD = BundleMinDistanceMetric()
        nb_pts = 20
        static = set_number_of_points(f2, nb_pts)
        moving = set_number_of_points(rec_bundle_org, nb_pts)

        BMD.setup(static, moving)
        # 12 affine parameters; presumably the identity transform
        # (zero translation/rotation/shear, unit scaling) -- verify against
        # BundleMinDistanceMetric's parameter convention.
        x0 = np.array([0, 0, 0, 0, 0, 0, 1., 1., 1, 0, 0, 0])
        bmd_value = BMD.distance(x0.tolist())

        # Reports the actual distance on failure instead of "False != True".
        npt.assert_array_less(bmd_value, 1)
104
105
if __name__ == '__main__':
    # Call the tests directly: ``npt.run_module_suite`` delegates to nose,
    # which is unmaintained and does not work on recent Python releases.
    test_median_otsu_flow()
    test_recobundles_flow()
108