1import warnings
2
3from dipy.reconst.mcsd import (mask_for_response_msmt,
4                               response_from_mask_msmt,
5                               auto_response_msmt)
6from dipy.reconst.mcsd import MultiShellDeconvModel, multi_shell_fiber_response
7from dipy.reconst import mcsd
8import numpy as np
9import numpy.testing as npt
10import pytest
11
12from dipy.sims.voxel import single_tensor, multi_tensor, add_noise
13from dipy.reconst import shm
14from dipy.reconst.dti import fractional_anisotropy, mean_diffusivity
15from dipy.data import default_sphere, get_3shell_gtab, get_fnames
16from dipy.core.gradients import GradientTable, gradient_table
17
18from dipy.io.gradients import read_bvals_bvecs
19
20from dipy.utils.optpkg import optional_package
cvx, have_cvxpy, _ = optional_package("cvxpy")

# Skip marker for tests that require CVXPY. pytest rejects a *boolean*
# ``skipif`` condition that has no ``reason=`` (it fails the test at setup
# time regardless of the condition's value), so the reason is mandatory here.
needs_cvxpy = pytest.mark.skipif(not have_cvxpy, reason="Requires CVXPY")
24
25
# Tissue response arrays used throughout these tests. Each row holds three
# diffusivities (columns 0-2) followed by the signal scale S0 (column 3);
# the same row is repeated three times for every tissue.
wm_response = np.tile([1.7E-3, 0.4E-3, 0.4E-3, 25.], (3, 1))
csf_response = np.tile([3.0E-3, 3.0E-3, 3.0E-3, 100.], (3, 1))
gm_response = np.tile([4.0E-4, 4.0E-4, 4.0E-4, 40.], (3, 1))
35
36
def get_test_data():
    """Build a small noisy 3x3x1 phantom of WM, GM and CSF voxels.

    Each voxel holds a single-tensor signal simulated on the gradient table
    returned by ``get_3shell_gtab``, with noise added by ``add_noise`` at
    ``snr=80``.

    Returns
    -------
    gtab : GradientTable
        The 3-shell gradient table used for the simulation.
    data : ndarray, shape (3, 3, 1, N)
        Simulated signals; N is the number of gradient directions.
    masks : list of 3 ndarrays, each of shape (3, 3, 1)
        Ground-truth integer masks (1 inside, 0 outside) for WM, GM and CSF,
        in that order.
    responses : list of 3 ndarrays
        Ground-truth responses for WM, GM and CSF, each laid out as
        ``[ev1, ev2, ev3, S0]``.
    """
    gtab = get_3shell_gtab()
    evals_list = [np.array([1.7E-3, 0.4E-3, 0.4E-3]),  # WM
                  np.array([6.0E-4, 4.0E-4, 4.0E-4]),  # GM
                  np.array([3.0E-3, 3.0E-3, 3.0E-3])]  # CSF (isotropic)
    s0 = [0.8, 1, 4]
    signals = [single_tensor(gtab, s, ev) for s, ev in zip(s0, evals_list)]
    tissues = [0, 0, 2, 0, 1, 0, 0, 1, 2]  # wm=0, gm=1, csf=2
    # NOTE(review): every voxel passes s0[0] as add_noise's S0, so presumably
    # the absolute noise level is the same for all tissues regardless of
    # their own S0. No RNG seed is set here either, so the phantom likely
    # differs between runs -- confirm this is intended.
    data = [add_noise(signals[tissue], 80, s0[0]) for tissue in tissues]
    data = np.asarray(data).reshape((3, 3, 1, len(signals[0])))
    tissues = np.asarray(tissues).reshape((3, 3, 1))
    masks = [np.where(tissues == x, 1, 0) for x in range(3)]
    responses = [np.concatenate((ev, [s])) for ev, s in zip(evals_list, s0)]
    return (gtab, data, masks, responses)
53
54
def _expand(m, iso, coeff):
    """Scatter compact coefficients into a full-length coefficient vector.

    The first ``iso`` entries of ``coeff`` are kept as-is at the front. The
    remaining entries are placed, in order, at the positions where
    ``m == 0``; every other position is zero.

    Parameters
    ----------
    m : ndarray
        Order index of every coefficient in the full (non-isotropic) vector.
    iso : int
        Number of leading isotropic coefficients.
    coeff : ndarray
        Compact coefficients: ``iso`` isotropic values, then one value per
        ``m == 0`` position.

    Returns
    -------
    ndarray
        Array of length ``iso + len(m)``.
    """
    head, zonal_values = coeff[:iso], coeff[iso:]
    full = np.zeros(len(m))
    full[m == 0] = zonal_values
    return np.concatenate([head, full])
60
61
@pytest.mark.skipif(not mcsd.have_cvxpy, reason="Requires CVXPY")
def test_mcsd_model_delta():
    """A WM-only delta must reproduce the WM response on every shell.

    Predictions from the expanded delta coefficients are compared with the
    WM response projected onto ``default_sphere``. Fitting the predicted
    signal on the 3-shell table must give near-zero ``m != 0`` coefficients.
    """
    shell_bvals = [0, 1000, 2000, 3500]
    gtab = get_3shell_gtab()
    response = multi_shell_fiber_response(8, shell_bvals, wm_response,
                                          gm_response, csf_response)
    model = MultiShellDeconvModel(gtab, response)
    n_iso = response.iso

    sh_basis = shm.real_sh_descoteaux_from_index(
        response.m, response.n,
        default_sphere.theta[:, None], default_sphere.phi[:, None])

    # Keep only the anisotropic part of the delta: zero the isotropic terms.
    delta = model.delta.copy()
    delta[:n_iso] = 0.
    delta = _expand(model.m, n_iso, delta)

    for shell_idx, bval in enumerate(shell_bvals):
        shell_gtab = GradientTable(default_sphere.vertices * bval)
        predicted = model.predict(delta, shell_gtab)
        wm_coeffs = response.response[shell_idx, n_iso:]
        npt.assert_array_almost_equal(predicted,
                                      np.dot(wm_coeffs, sh_basis.T))

    fit = model.fit(model.predict(delta, gtab))
    npt.assert_array_almost_equal(fit.shm_coeff[model.m != 0], 0., 2)
92
93
@pytest.mark.skipif(not mcsd.have_cvxpy, reason="Requires CVXPY")
def test_MultiShellDeconvModel_response():
    """A response object and a raw response array give the same model.

    Also checks that malformed response arrays raise ``ValueError``.
    """
    gtab = get_3shell_gtab()
    sh_order = 8

    precomputed = multi_shell_fiber_response(sh_order, [0, 1000, 2000, 3500],
                                             wm_response,
                                             gm_response,
                                             csf_response)
    from_object = MultiShellDeconvModel(gtab, precomputed, sh_order=sh_order)

    stacked = np.array([wm_response, gm_response, csf_response])
    from_array = MultiShellDeconvModel(gtab, stacked, sh_order=sh_order)

    npt.assert_array_almost_equal(from_object.response.response,
                                  from_array.response.response, 0)

    # Shapes/arguments the constructor must refuse.
    with npt.assert_raises(ValueError):
        MultiShellDeconvModel(gtab, np.ones((4, 3, 4)))
    with npt.assert_raises(ValueError):
        MultiShellDeconvModel(gtab, np.ones((3, 3, 4)), iso=3)
114
115
def _three_tissue_setup():
    """Simulate a WM/GM/CSF mixture signal and the model used to fit it.

    The WM part is ``multi_tensor`` with two identical tensors at angles
    (0, 0) and (60, 0), weighted 30/70, simulated with ``snr=None``. GM and
    CSF are mono-exponential decays built from the first row of their
    response arrays (column 3 as S0, column 0 as diffusivity).

    Returns
    -------
    model : MultiShellDeconvModel
        Model built from the three-tissue response for b-values
        0, 1000, 2000 and 3500.
    signal : ndarray
        ``vf[0] * S_csf + vf[1] * S_gm + vf[2] * S_wm``.
    vf : list of float
        Ground-truth volume fractions in (CSF, GM, WM) order.
    """
    gtab = get_3shell_gtab()

    mevals = np.array([wm_response[0, :3], wm_response[0, :3]])
    angles = [(0, 0), (60, 0)]

    S_wm, _ = multi_tensor(gtab, mevals, wm_response[0, 3], angles=angles,
                           fractions=[30., 70.], snr=None)
    S_gm = gm_response[0, 3] * np.exp(-gtab.bvals * gm_response[0, 0])
    S_csf = csf_response[0, 3] * np.exp(-gtab.bvals * csf_response[0, 0])

    sh_order = 8
    response = multi_shell_fiber_response(sh_order, [0, 1000, 2000, 3500],
                                          wm_response,
                                          gm_response,
                                          csf_response)
    model = MultiShellDeconvModel(gtab, response)
    vf = [0.325, 0.2, 0.475]
    signal = sum(i * j for i, j in zip(vf, [S_csf, S_gm, S_wm]))
    return model, signal, vf


@pytest.mark.skipif(not mcsd.have_cvxpy, reason="Requires CVXPY")
def test_MultiShellDeconvModel():
    """Fit-based and model-based predictions agree and match the signal."""
    model, signal, _ = _three_tissue_setup()
    fit = model.fit(signal)

    # Testing both ways to predict
    S_pred_fit = fit.predict()
    S_pred_model = model.predict(fit.all_shm_coeff)

    npt.assert_array_almost_equal(S_pred_fit, S_pred_model, 0)
    npt.assert_array_almost_equal(S_pred_fit, signal, 0)


@pytest.mark.skipif(not mcsd.have_cvxpy, reason="Requires CVXPY")
def test_MSDeconvFit():
    """The fit recovers the simulated volume fractions to one decimal."""
    model, signal, vf = _three_tissue_setup()
    fit = model.fit(signal)

    # Testing volume fractions
    npt.assert_array_almost_equal(fit.volume_fractions, vf, 1)
170
171
def test_multi_shell_fiber_response():
    """Check the response shape, and the warning issued when no b0 is given.

    The response has one row per requested b-value.
    """
    sh_order = 8
    response = multi_shell_fiber_response(sh_order, [0, 1000, 2000, 3500],
                                          wm_response,
                                          gm_response,
                                          csf_response)

    npt.assert_equal(response.response.shape, (4, 7))

    with warnings.catch_warnings(record=True) as w:
        # Record every warning, even ones already emitted earlier in the
        # session; otherwise the "default"/"once" filters may hide them.
        warnings.simplefilter("always")
        response = multi_shell_fiber_response(sh_order, [1000, 2000, 3500],
                                              wm_response,
                                              gm_response,
                                              csf_response)

    # Deprecation warnings from the legacy SH basis are raised alongside the
    # warning from multi_shell_fiber_response, so more than one is expected.
    npt.assert_(len(w) > 1)
    # Find the b0 warning by content rather than assuming it is the last one.
    b0_warnings = [record for record in w
                   if issubclass(record.category, UserWarning)
                   and "No b0 given. Proceeding either way."
                   in str(record.message)]
    npt.assert_(len(b0_warnings) > 0)
    npt.assert_equal(response.response.shape, (3, 7))
196
197
def test_mask_for_response_msmt():
    """Masks from mask_for_response_msmt match the phantom's tissue layout."""
    gtab, data, masks_gt, _ = get_test_data()

    wm_mask, gm_mask, csf_mask = mask_for_response_msmt(
        gtab, data, roi_center=None, roi_radii=(1, 1, 0),
        wm_fa_thr=0.7, gm_fa_thr=0.3, csf_fa_thr=0.15,
        gm_md_thr=0.001, csf_md_thr=0.0032)
    estimated = (wm_mask, gm_mask, csf_mask)

    # The three masks together must select at least one voxel.
    total_selected = int(sum(np.sum(mask) for mask in estimated))
    npt.assert_(total_selected != 0)

    # WM, GM and CSF masks, in order, against the ground truth.
    for expected, actual in zip(masks_gt, estimated):
        npt.assert_array_almost_equal(expected, actual)
217
218
def test_mask_for_response_msmt_nvoxels():
    """Check per-tissue voxel counts, then unsatisfiable thresholds.

    Thresholds that no voxel can meet must give empty masks and the expected
    sequence of warnings.
    """
    gtab, data, _, _ = get_test_data()

    wm_mask, gm_mask, csf_mask = mask_for_response_msmt(
        gtab, data, roi_center=None, roi_radii=(1, 1, 0),
        wm_fa_thr=0.7, gm_fa_thr=0.3, csf_fa_thr=0.15,
        gm_md_thr=0.001, csf_md_thr=0.0032)
    for mask, n_expected in zip((wm_mask, gm_mask, csf_mask), (5, 2, 2)):
        npt.assert_equal(np.sum(mask), n_expected)

    # Expected warnings, in emission order: the high-b-value notice first,
    # then one per threshold that selected no voxel.
    expected_messages = (
        "Some b-values are higher than 1200.",
        "No voxel with a FA higher than 1 were found",
        "No voxel with a FA lower than 0 were found",
        "No voxel with a MD lower than 0 were found",
        "No voxel with a FA lower than 0 were found",
        "No voxel with a MD lower than 0 were found",
    )
    with warnings.catch_warnings(record=True) as w:
        wm_mask, gm_mask, csf_mask = mask_for_response_msmt(
            gtab, data, roi_center=None, roi_radii=(1, 1, 0),
            wm_fa_thr=1, gm_fa_thr=0, csf_fa_thr=0,
            gm_md_thr=0, csf_md_thr=0)
        npt.assert_equal(len(w), len(expected_messages))
        npt.assert_(issubclass(w[0].category, UserWarning))
        for record, text in zip(w, expected_messages):
            npt.assert_(text in str(record.message))

    for mask in (wm_mask, gm_mask, csf_mask):
        npt.assert_equal(np.sum(mask), 0)
268
269
def test_response_from_mask_msmt():
    """Responses estimated from ground-truth masks have the expected shape.

    Checks that CSF diffuses more than GM, that GM/CSF are near-isotropic,
    that WM is axially symmetric and elongated, and that the first row
    matches the simulated responses.
    """
    gtab, data, masks_gt, responses_gt = get_test_data()
    wm_mask, gm_mask, csf_mask = masks_gt

    response_wm, response_gm, response_csf = response_from_mask_msmt(
        gtab, data, wm_mask, gm_mask, csf_mask, tol=20)

    # Summed diffusivities (columns 0-2): CSF must exceed GM.
    npt.assert_(np.sum(response_csf[:, :3]) > np.sum(response_gm[:, :3]))

    # CSF and GM: columns 1 and 2 agree, and column 0 is within rtol=1.
    for iso_response in (response_csf, response_gm):
        npt.assert_almost_equal(iso_response[:, 1], iso_response[:, 2])
        npt.assert_allclose(iso_response[:, 0], iso_response[:, 1],
                            rtol=1, atol=0)

    # WM: columns 1 and 2 agree, and column 0 exceeds 2.5x column 1.
    npt.assert_almost_equal(response_wm[:, 1], response_wm[:, 2])
    npt.assert_(np.all(response_wm[:, 0] > 2.5 * response_wm[:, 1]))

    # Row 0 (first b-value) against the ground truth, to one decimal.
    estimated = (response_wm, response_gm, response_csf)
    for response, truth in zip(estimated, responses_gt):
        npt.assert_array_almost_equal(response[0], truth, 1)
293
294
def test_auto_response_msmt():
    """auto_response_msmt equals the two-step masks-then-responses pipeline.

    The two-step pipeline is mask_for_response_msmt followed by
    response_from_mask_msmt, run with the same parameters. Also checks the
    high-b-value warning that auto_response_msmt emits first.
    """
    gtab, data, _, _ = get_test_data()

    # Full text of the expected high-b-value warning. Both sides are compared
    # with whitespace runs collapsed, so the check does not depend on how the
    # message's continuation lines are indented in the library source.
    expected_msg = ("Some b-values are higher than 1200. "
                    "The DTI fit might be affected. It is advised to use "
                    "mask_for_response_msmt with bvalues lower than 1200, "
                    "followed by response_from_mask_msmt with all bvalues "
                    "to overcome this.")

    def _collapse_ws(text):
        return " ".join(text.split())

    with warnings.catch_warnings(record=True) as w:
        response_auto_wm, response_auto_gm, response_auto_csf = \
            auto_response_msmt(gtab, data, tol=20,
                               roi_center=None, roi_radii=(1, 1, 0),
                               wm_fa_thr=0.7, gm_fa_thr=0.3, csf_fa_thr=0.15,
                               gm_md_thr=0.001, csf_md_thr=0.0032)

        npt.assert_(issubclass(w[0].category, UserWarning))
        npt.assert_(_collapse_ws(expected_msg)
                    in _collapse_ws(str(w[0].message)))

        mask_wm, mask_gm, mask_csf = mask_for_response_msmt(gtab, data,
                                                            roi_center=None,
                                                            roi_radii=(1, 1, 0),
                                                            wm_fa_thr=0.7,
                                                            gm_fa_thr=0.3,
                                                            csf_fa_thr=0.15,
                                                            gm_md_thr=0.001,
                                                            csf_md_thr=0.0032)

        response_from_mask_wm, response_from_mask_gm, response_from_mask_csf = \
            response_from_mask_msmt(gtab, data,
                                    mask_wm, mask_gm, mask_csf,
                                    tol=20)

        npt.assert_array_equal(response_auto_wm, response_from_mask_wm)
        npt.assert_array_equal(response_auto_gm, response_from_mask_gm)
        npt.assert_array_equal(response_auto_csf, response_from_mask_csf)
329