"""Tests for ``dipy.segment.bundles.RecoBundles``.

``setup_module`` embeds a translated copy of part of the fornix tractogram
(``f2``) inside a larger tractogram (``f``).  Recognizing ``f2`` in ``f``
should therefore find, for every model streamline, an identical streamline,
i.e. a distance of exactly 0 from ``bundles_distances_mam``.
"""
import sys
import numpy as np
import pytest

from numpy.testing import (assert_equal,
                           assert_almost_equal,
                           run_module_suite)
from dipy.data import get_fnames
from dipy.io.streamline import load_tractogram
from dipy.segment.bundles import RecoBundles
from dipy.tracking.distances import bundles_distances_mam
from dipy.tracking.streamline import Streamlines
from dipy.segment.clustering import qbx_and_merge

# Every test below is skipped on big-endian hosts.
is_big_endian = 'big' in sys.byteorder.lower()


def setup_module():
    """Build the shared streamline fixtures used by every test.

    Module-level globals set here:

    * ``fornix`` -- streamlines of the 'fornix' sample tractogram.
    * ``f1`` -- an untouched copy of ``fornix``.
    * ``f2`` -- the first 20 streamlines, shifted by +50 in the first
      coordinate; used as the model bundle.
    * ``f3`` -- streamlines 200 onward, shifted by +100 in the first
      coordinate.
    * ``f`` -- ``fornix`` followed by ``f2`` and ``f3``.
    """
    global f, f1, f2, f3, fornix

    fname = get_fnames('fornix')
    fornix = load_tractogram(fname, 'same',
                             bbox_valid_check=False).streamlines

    f = Streamlines(fornix)
    f1 = f.copy()

    f2 = f1[:20].copy()
    f2._data += np.array([50, 0, 0])

    f3 = f1[200:].copy()
    f3._data += np.array([100, 0, 0])

    # Appending the shifted copies means f2 appears verbatim inside f, so
    # a correct recognition can match every model streamline exactly.
    f.extend(f2)
    f.extend(f3)


def _assert_every_row_has_exact_match(D):
    """Assert that every row of the distance matrix ``D`` has a zero minimum.

    Each row presumably holds the distances from one streamline of the first
    bundle given to ``bundles_distances_mam`` to every streamline of the
    second bundle. A zero minimum means that streamline has an identical
    counterpart.
    """
    for row in D:
        assert_equal(row.min(), 0)


@pytest.mark.skipif(is_big_endian,
                    reason="Little Endian architecture required")
def test_rb_check_defaults():
    """Recognize and refine ``f2`` inside ``f`` with default options."""
    rb = RecoBundles(f, greater_than=0, clust_thr=10)

    rec_trans, rec_labels = rb.recognize(model_bundle=f2,
                                         model_clust_thr=5.,
                                         reduction_thr=10)

    D = bundles_distances_mam(f2, f[rec_labels])

    # Check the recognition stage only when the number of recognized
    # streamlines equals the model size.
    if len(f2) == len(rec_labels):
        _assert_every_row_has_exact_match(D)

    refine_trans, refine_labels = rb.refine(model_bundle=f2,
                                            pruned_streamlines=rec_trans,
                                            model_clust_thr=5.,
                                            reduction_thr=10)

    D = bundles_distances_mam(f2, f[refine_labels])

    # After refinement every model streamline must be matched exactly.
    _assert_every_row_has_exact_match(D)


@pytest.mark.skipif(is_big_endian,
                    reason="Little Endian architecture required")
def test_rb_disable_slr():
    """Recognize and refine ``f2`` with SLR disabled during recognition."""
    rb = RecoBundles(f, greater_than=0, clust_thr=10)

    rec_trans, rec_labels = rb.recognize(model_bundle=f2,
                                         model_clust_thr=5.,
                                         reduction_thr=10,
                                         slr=False)

    D = bundles_distances_mam(f2, f[rec_labels])

    # Check the recognition stage only when the number of recognized
    # streamlines equals the model size.
    if len(f2) == len(rec_labels):
        _assert_every_row_has_exact_match(D)

    refine_trans, refine_labels = rb.refine(model_bundle=f2,
                                            pruned_streamlines=rec_trans,
                                            model_clust_thr=5.,
                                            reduction_thr=10)

    D = bundles_distances_mam(f2, f[refine_labels])

    # After refinement every model streamline must be matched exactly.
    _assert_every_row_has_exact_match(D)


@pytest.mark.skipif(is_big_endian,
                    reason="Little Endian architecture required")
def test_rb_slr_threads():
    """Multi-threaded and single-threaded SLR should give near-equal output.

    Both recognizers are seeded with ``RandomState(42)``. The only
    difference between the two runs is ``num_threads``.
    """
    rb_multi = RecoBundles(f, greater_than=0, clust_thr=10,
                           rng=np.random.RandomState(42))
    rec_trans_multi_threads, _ = rb_multi.recognize(model_bundle=f2,
                                                    model_clust_thr=5.,
                                                    reduction_thr=10,
                                                    slr=True,
                                                    num_threads=None)

    rb_single = RecoBundles(f, greater_than=0, clust_thr=10,
                            rng=np.random.RandomState(42))
    rec_trans_single_thread, _ = rb_single.recognize(model_bundle=f2,
                                                     model_clust_thr=5.,
                                                     reduction_thr=10,
                                                     slr=True,
                                                     num_threads=1)

    D = bundles_distances_mam(rec_trans_multi_threads, rec_trans_single_thread)

    # Each multi-threaded result should have a near-identical single-threaded
    # counterpart. Multi-threading prevents an exact match, so compare to 4
    # decimals.
    for row in D:
        assert_almost_equal(row.min(), 0, decimal=4)


@pytest.mark.skipif(is_big_endian,
                    reason="Little Endian architecture required")
def test_rb_no_verbose_and_mam():
    """Recognize with ``verbose=False`` and MAM-based pruning, then refine."""
    rb = RecoBundles(f, greater_than=0, clust_thr=10, verbose=False)

    rec_trans, rec_labels = rb.recognize(model_bundle=f2,
                                         model_clust_thr=5.,
                                         reduction_thr=10,
                                         slr=True,
                                         pruning_distance='mam')

    D = bundles_distances_mam(f2, f[rec_labels])

    # Check the recognition stage only when the number of recognized
    # streamlines equals the model size.
    if len(f2) == len(rec_labels):
        _assert_every_row_has_exact_match(D)

    refine_trans, refine_labels = rb.refine(model_bundle=f2,
                                            pruned_streamlines=rec_trans,
                                            model_clust_thr=5.,
                                            reduction_thr=10)

    D = bundles_distances_mam(f2, f[refine_labels])

    # After refinement every model streamline must be matched exactly.
    _assert_every_row_has_exact_match(D)


@pytest.mark.skipif(is_big_endian,
                    reason="Little Endian architecture required")
def test_rb_clustermap():
    """Recognize and refine using a precomputed ``qbx_and_merge`` cluster map."""
    cluster_map = qbx_and_merge(f, thresholds=[40, 25, 20, 10])

    rb = RecoBundles(f, greater_than=0, less_than=1000000,
                     cluster_map=cluster_map, clust_thr=10)
    rec_trans, rec_labels = rb.recognize(model_bundle=f2,
                                         model_clust_thr=5.,
                                         reduction_thr=10)

    D = bundles_distances_mam(f2, f[rec_labels])

    # Check the recognition stage only when the number of recognized
    # streamlines equals the model size.
    if len(f2) == len(rec_labels):
        _assert_every_row_has_exact_match(D)

    refine_trans, refine_labels = rb.refine(model_bundle=f2,
                                            pruned_streamlines=rec_trans,
                                            model_clust_thr=5.,
                                            reduction_thr=10)

    D = bundles_distances_mam(f2, f[refine_labels])

    # After refinement every model streamline must be matched exactly.
    _assert_every_row_has_exact_match(D)


@pytest.mark.skipif(is_big_endian,
                    reason="Little Endian architecture required")
def test_rb_no_neighb():
    """A model bundle with no nearby streamlines should yield no recognition.

    The model ``b2`` is shifted by +100. The searched tractogram ``b`` only
    holds the unshifted fornix plus a copy shifted by +300, so ``b2`` itself
    is never part of ``b``.
    """
    b = Streamlines(fornix)
    b1 = b.copy()

    b2 = b1[:20].copy()
    b2._data += np.array([100, 0, 0])

    b3 = b1[:20].copy()
    b3._data += np.array([300, 0, 0])

    b.extend(b3)

    rb = RecoBundles(b, greater_than=0, clust_thr=10)

    rec_trans, rec_labels = rb.recognize(model_bundle=b2,
                                         model_clust_thr=5.,
                                         reduction_thr=10)

    if len(rec_trans) > 0:
        # Recognition returned candidates; refinement must discard them all.
        refine_trans, refine_labels = rb.refine(model_bundle=b2,
                                                pruned_streamlines=rec_trans,
                                                model_clust_thr=5.,
                                                reduction_thr=10)

        assert_equal(len(refine_labels), 0)
        assert_equal(len(refine_trans), 0)

    else:
        assert_equal(len(rec_labels), 0)
        assert_equal(len(rec_trans), 0)


@pytest.mark.skipif(is_big_endian,
                    reason="Little Endian architecture required")
def test_rb_reduction_mam():
    """Recognize with MAM reduction, asymmetric SLR and MAM pruning, then refine."""
    rb = RecoBundles(f, greater_than=0, clust_thr=10, verbose=True)

    rec_trans, rec_labels = rb.recognize(model_bundle=f2,
                                         model_clust_thr=5.,
                                         reduction_thr=10,
                                         reduction_distance='mam',
                                         slr=True,
                                         slr_metric='asymmetric',
                                         pruning_distance='mam')

    D = bundles_distances_mam(f2, f[rec_labels])

    # Check the recognition stage only when the number of recognized
    # streamlines equals the model size.
    if len(f2) == len(rec_labels):
        _assert_every_row_has_exact_match(D)

    refine_trans, refine_labels = rb.refine(model_bundle=f2,
                                            pruned_streamlines=rec_trans,
                                            model_clust_thr=5.,
                                            reduction_thr=10)

    D = bundles_distances_mam(f2, f[refine_labels])

    # After refinement every model streamline must be matched exactly.
    _assert_every_row_has_exact_match(D)


if __name__ == '__main__':

    run_module_suite()