1import sys
2import numpy as np
3import pytest
4
5from numpy.testing import (assert_equal,
6                           assert_almost_equal,
7                           run_module_suite)
8from dipy.data import get_fnames
9from dipy.io.streamline import load_tractogram
10from dipy.segment.bundles import RecoBundles
11from dipy.tracking.distances import bundles_distances_mam
12from dipy.tracking.streamline import Streamlines
13from dipy.segment.clustering import qbx_and_merge
14
# True on big-endian hosts; every RecoBundles test below is skipped when set.
is_big_endian = sys.byteorder == 'big'
16
17
def setup_module():
    """Create the module-level streamline fixtures shared by every test.

    Globals defined here:

    * ``fornix`` -- streamlines loaded from the ``'fornix'`` dataset
      returned by ``get_fnames``.
    * ``f1`` -- an untouched copy of the fornix streamlines.
    * ``f2`` -- the first 20 streamlines of ``f1`` shifted by +50 along x;
      the tests use it as the model bundle.
    * ``f3`` -- streamlines 200 onward of ``f1`` shifted by +100 along x.
    * ``f`` -- the fornix streamlines followed by ``f2`` and ``f3``, so each
      streamline of ``f2`` (and ``f3``) has an exact duplicate inside ``f``.
    """
    global f, f1, f2, f3, fornix

    tractogram = load_tractogram(get_fnames('fornix'), 'same',
                                 bbox_valid_check=False)
    fornix = tractogram.streamlines

    f = Streamlines(fornix)
    f1 = f.copy()

    def shifted(streamlines, dx):
        # Work on a copy so the in-place shift of ``_data`` cannot leak
        # back into ``f1``.
        moved = streamlines.copy()
        moved._data += np.array([dx, 0, 0])
        return moved

    f2 = shifted(f1[:20], 50)
    f3 = shifted(f1[200:], 100)

    # Append the shifted bundles so the model bundle is present in ``f``.
    for extra in (f2, f3):
        f.extend(extra)
36
37
@pytest.mark.skipif(is_big_endian,
                    reason="Little Endian architecture required")
def test_rb_check_defaults():
    """Recognize and then refine ``f2`` inside ``f`` with default options.

    ``f`` contains an exact copy of ``f2``, so every row of the MAM distance
    matrix between ``f2`` and the extracted streamlines must have a zero
    minimum. The post-``recognize`` check only runs when exactly
    ``len(f2)`` streamlines were recognized; the post-``refine`` check is
    unconditional.
    """
    rb = RecoBundles(f, greater_than=0, clust_thr=10)
    shared = dict(model_bundle=f2, model_clust_thr=5., reduction_thr=10)

    rec_trans, rec_labels = rb.recognize(**shared)
    rec_dist = bundles_distances_mam(f2, f[rec_labels])
    if len(rec_labels) == len(f2):
        for closest in rec_dist.min(axis=1):
            assert_equal(closest, 0)

    _, refine_labels = rb.refine(pruned_streamlines=rec_trans, **shared)
    refine_dist = bundles_distances_mam(f2, f[refine_labels])
    for closest in refine_dist.min(axis=1):
        assert_equal(closest, 0)
65
66
@pytest.mark.skipif(is_big_endian,
                    reason="Little Endian architecture required")
def test_rb_disable_slr():
    """Like the default check, but ``recognize`` runs with SLR disabled.

    Every row of the MAM distance matrix between ``f2`` and the extracted
    streamlines must have a zero minimum (``f`` contains a copy of ``f2``).
    The post-``recognize`` check only runs when exactly ``len(f2)``
    streamlines were recognized.
    """
    rb = RecoBundles(f, greater_than=0, clust_thr=10)
    shared = dict(model_bundle=f2, model_clust_thr=5., reduction_thr=10)

    rec_trans, rec_labels = rb.recognize(slr=False, **shared)
    rec_dist = bundles_distances_mam(f2, f[rec_labels])
    if len(rec_labels) == len(f2):
        for closest in rec_dist.min(axis=1):
            assert_equal(closest, 0)

    # refine() is called without the slr flag, exactly as before
    _, refine_labels = rb.refine(pruned_streamlines=rec_trans, **shared)
    refine_dist = bundles_distances_mam(f2, f[refine_labels])
    for closest in refine_dist.min(axis=1):
        assert_equal(closest, 0)
95
96
@pytest.mark.skipif(is_big_endian,
                    reason="Little Endian architecture required")
def test_rb_slr_threads():
    """SLR-based recognition agrees between ``num_threads=None`` and ``1``.

    Each RecoBundles instance receives its own ``RandomState(42)``, so both
    runs start from the same random state and differ only in
    ``num_threads``. Multi-threading prevents an exact match, so the two
    results are compared with a tolerance.
    """
    recognize_kwargs = dict(model_bundle=f2, model_clust_thr=5.,
                            reduction_thr=10, slr=True)

    rb_multi = RecoBundles(f, greater_than=0, clust_thr=10,
                           rng=np.random.RandomState(42))
    rec_trans_multi_threads, _ = rb_multi.recognize(num_threads=None,
                                                    **recognize_kwargs)

    rb_single = RecoBundles(f, greater_than=0, clust_thr=10,
                            rng=np.random.RandomState(42))
    rec_trans_single_thread, _ = rb_single.recognize(num_threads=1,
                                                     **recognize_kwargs)

    D = bundles_distances_mam(rec_trans_multi_threads,
                              rec_trans_single_thread)

    # Each row of D (presumably one per multi-threaded streamline) must
    # have a near-zero minimum, i.e. a matching single-threaded streamline.
    for row in D:
        assert_almost_equal(row.min(), 0, decimal=4)
124
125
@pytest.mark.skipif(is_big_endian,
                    reason="Little Endian architecture required")
def test_rb_no_verbose_and_mam():
    """Quiet RecoBundles with SLR and MAM pruning still finds ``f2``.

    Every row of the MAM distance matrix between ``f2`` and the extracted
    streamlines must have a zero minimum. The post-``recognize`` check only
    runs when exactly ``len(f2)`` streamlines were recognized.
    """
    rb = RecoBundles(f, greater_than=0, clust_thr=10, verbose=False)
    shared = dict(model_bundle=f2, model_clust_thr=5., reduction_thr=10)

    rec_trans, rec_labels = rb.recognize(slr=True, pruning_distance='mam',
                                         **shared)
    rec_dist = bundles_distances_mam(f2, f[rec_labels])
    if len(rec_labels) == len(f2):
        for closest in rec_dist.min(axis=1):
            assert_equal(closest, 0)

    _, refine_labels = rb.refine(pruned_streamlines=rec_trans, **shared)
    refine_dist = bundles_distances_mam(f2, f[refine_labels])
    for closest in refine_dist.min(axis=1):
        assert_equal(closest, 0)
155
156
@pytest.mark.skipif(is_big_endian,
                    reason="Little Endian architecture required")
def test_rb_clustermap():
    """RecoBundles built from a precomputed QuickBundlesX cluster map.

    The cluster map is computed up front with ``qbx_and_merge`` and handed
    to RecoBundles. Every row of the MAM distance matrix between ``f2`` and
    the extracted streamlines must have a zero minimum; the
    post-``recognize`` check only runs when exactly ``len(f2)`` streamlines
    were recognized.
    """
    precomputed = qbx_and_merge(f, thresholds=[40, 25, 20, 10])

    rb = RecoBundles(f, greater_than=0, less_than=1000000,
                     cluster_map=precomputed, clust_thr=10)
    shared = dict(model_bundle=f2, model_clust_thr=5., reduction_thr=10)

    rec_trans, rec_labels = rb.recognize(**shared)
    rec_dist = bundles_distances_mam(f2, f[rec_labels])
    if len(rec_labels) == len(f2):
        for closest in rec_dist.min(axis=1):
            assert_equal(closest, 0)

    _, refine_labels = rb.refine(pruned_streamlines=rec_trans, **shared)
    refine_dist = bundles_distances_mam(f2, f[refine_labels])
    for closest in refine_dist.min(axis=1):
        assert_equal(closest, 0)
186
187
@pytest.mark.skipif(is_big_endian,
                    reason="Little Endian architecture required")
def test_rb_no_neighb():
    """A model bundle with no neighbours in the target is not recognized.

    The model is the first 20 fornix streamlines shifted by +100 along x.
    The target holds the fornix plus those same 20 streamlines shifted by
    +300, so no target streamline coincides with the model. If
    ``recognize`` still returns something, ``refine`` must discard it.
    """
    target = Streamlines(fornix)
    pristine = target.copy()

    model = pristine[:20].copy()
    model._data += np.array([100, 0, 0])

    far_copy = pristine[:20].copy()
    far_copy._data += np.array([300, 0, 0])

    target.extend(far_copy)

    rb = RecoBundles(target, greater_than=0, clust_thr=10)

    rec_trans, rec_labels = rb.recognize(model_bundle=model,
                                         model_clust_thr=5.,
                                         reduction_thr=10)

    if len(rec_trans) == 0:
        # Nothing recognized: both outputs must be empty.
        assert_equal(len(rec_labels), 0)
        assert_equal(len(rec_trans), 0)
        return

    # Something slipped through recognition; refinement must remove it.
    refine_trans, refine_labels = rb.refine(model_bundle=model,
                                            pruned_streamlines=rec_trans,
                                            model_clust_thr=5.,
                                            reduction_thr=10)

    assert_equal(len(refine_labels), 0)
    assert_equal(len(refine_trans), 0)
222
223
@pytest.mark.skipif(is_big_endian,
                    reason="Little Endian architecture required")
def test_rb_reduction_mam():
    """Recognize with MAM reduction, asymmetric SLR and MAM pruning.

    Every row of the MAM distance matrix between ``f2`` and the extracted
    streamlines must have a zero minimum. The post-``recognize`` check only
    runs when exactly ``len(f2)`` streamlines were recognized; ``refine``
    uses only the shared options.
    """
    rb = RecoBundles(f, greater_than=0, clust_thr=10, verbose=True)
    shared = dict(model_bundle=f2, model_clust_thr=5., reduction_thr=10)

    rec_trans, rec_labels = rb.recognize(reduction_distance='mam',
                                         slr=True,
                                         slr_metric='asymmetric',
                                         pruning_distance='mam',
                                         **shared)
    rec_dist = bundles_distances_mam(f2, f[rec_labels])
    if len(rec_labels) == len(f2):
        for closest in rec_dist.min(axis=1):
            assert_equal(closest, 0)

    _, refine_labels = rb.refine(pruned_streamlines=rec_trans, **shared)
    refine_dist = bundles_distances_mam(f2, f[refine_labels])
    for closest in refine_dist.min(axis=1):
        assert_equal(closest, 0)
255
256
if __name__ == '__main__':
    # Run through pytest rather than numpy's nose-based run_module_suite so
    # the pytest.mark.skipif markers on the tests are honoured, and
    # propagate pytest's exit status to the shell.
    sys.exit(pytest.main([__file__]))
260