from os.path import join
from os.path import join as pjoin

import nibabel as nib
import numpy as np
import numpy.testing as npt
from nibabel.tmpdirs import TemporaryDirectory

from dipy.align.streamlinear import BundleMinDistanceMetric
from dipy.data import get_fnames
from dipy.io.image import load_nifti_data
from dipy.io.stateful_tractogram import Space, StatefulTractogram
from dipy.io.streamline import load_tractogram, save_tractogram
from dipy.segment.mask import median_otsu
from dipy.tracking.streamline import Streamlines
from dipy.tracking.streamline import set_number_of_points
from dipy.workflows.segment import MedianOtsuFlow
from dipy.workflows.segment import RecoBundlesFlow, LabelsBundlesFlow


def test_median_otsu_flow():
    """Check that MedianOtsuFlow reproduces a direct ``median_otsu`` call.

    The workflow is run on the ``small_25`` dataset, and its saved mask and
    masked volume are compared against ``median_otsu`` called with the same
    parameters on the same data.
    """
    with TemporaryDirectory() as out_dir:
        data_path, _, _ = get_fnames('small_25')
        volume = load_nifti_data(data_path)
        save_masked = True
        median_radius = 3
        numpass = 3
        autocrop = False
        vol_idx = [0]
        dilate = 0

        mo_flow = MedianOtsuFlow()
        mo_flow.run(data_path, out_dir=out_dir, save_masked=save_masked,
                    median_radius=median_radius, numpass=numpass,
                    autocrop=autocrop, vol_idx=vol_idx, dilate=dilate)

        mask_name = mo_flow.last_generated_outputs['out_mask']
        masked_name = mo_flow.last_generated_outputs['out_masked']

        # Reference result computed in-process with identical parameters.
        masked, mask = median_otsu(volume,
                                   vol_idx=vol_idx,
                                   median_radius=median_radius,
                                   numpass=numpass,
                                   autocrop=autocrop, dilate=dilate)

        result_mask_data = load_nifti_data(join(out_dir, mask_name))
        npt.assert_array_equal(result_mask_data.astype(np.uint8), mask)

        result_masked = nib.load(join(out_dir, masked_name))
        result_masked_data = np.asanyarray(result_masked.dataobj)

        # Rounded because the saved volume may differ from the in-memory
        # result by on-disk numeric conversion.
        npt.assert_array_equal(np.round(result_masked_data), masked)


def test_recobundles_flow():
    """Check RecoBundlesFlow recognition and LabelsBundlesFlow extraction.

    A model bundle is built from the first 15 fornix streamlines, shifted
    by 40 mm along x. It is appended to the fornix to form the whole-brain
    tractogram. RecoBundles must recognize exactly as many streamlines as
    the model contains. The streamlines extracted by the returned labels
    must then be close to the model under the bundle minimum-distance
    metric at the identity transform.
    """
    with TemporaryDirectory() as out_dir:
        data_path = get_fnames('fornix')

        fornix = load_tractogram(data_path, 'same',
                                 bbox_valid_check=False).streamlines

        f = Streamlines(fornix)
        f1 = f.copy()

        # Model bundle: a translated copy of a subset of the fornix, so it
        # is clearly separated from the original streamlines.
        f2 = f1[:15].copy()
        f2._data += np.array([40, 0, 0])

        # The "whole brain" tractogram contains the model bundle as well.
        f.extend(f2)

        f2_path = pjoin(out_dir, "f2.trk")
        sft = StatefulTractogram(f2, data_path, Space.RASMM)
        save_tractogram(sft, f2_path, bbox_valid_check=False)

        f1_path = pjoin(out_dir, "f1.trk")
        sft = StatefulTractogram(f, data_path, Space.RASMM)
        save_tractogram(sft, f1_path, bbox_valid_check=False)

        rb_flow = RecoBundlesFlow(force=True)
        rb_flow.run(f1_path, f2_path, greater_than=0, clust_thr=10,
                    model_clust_thr=5., reduction_thr=10, out_dir=out_dir)

        labels = rb_flow.last_generated_outputs['out_recognized_labels']
        recog_trk = rb_flow.last_generated_outputs['out_recognized_transf']

        rec_bundle = load_tractogram(recog_trk, 'same',
                                     bbox_valid_check=False).streamlines
        # Compare the counts directly so a failure reports both values.
        npt.assert_equal(len(rec_bundle), len(f2))

        label_flow = LabelsBundlesFlow(force=True)
        # Pass out_dir explicitly; otherwise the extracted bundle is written
        # to the current working directory instead of the temporary one.
        label_flow.run(f1_path, labels, out_dir=out_dir)

        recog_bundle = label_flow.last_generated_outputs['out_bundle']
        rec_bundle_org = load_tractogram(recog_bundle, 'same',
                                         bbox_valid_check=False).streamlines

        BMD = BundleMinDistanceMetric()
        nb_pts = 20
        static = set_number_of_points(f2, nb_pts)
        moving = set_number_of_points(rec_bundle_org, nb_pts)

        BMD.setup(static, moving)
        # Identity 12-parameter affine: zero translation, zero rotation,
        # unit scaling, zero shear.
        x0 = np.array([0, 0, 0, 0, 0, 0, 1., 1., 1, 0, 0, 0])
        bmd_value = BMD.distance(x0.tolist())

        # assert_array_less reports the offending value on failure.
        npt.assert_array_less(bmd_value, 1)


if __name__ == '__main__':
    npt.run_module_suite()