import warnings

from dipy.reconst.mcsd import (mask_for_response_msmt,
                               response_from_mask_msmt,
                               auto_response_msmt)
from dipy.reconst.mcsd import MultiShellDeconvModel, multi_shell_fiber_response
from dipy.reconst import mcsd
import numpy as np
import numpy.testing as npt
import pytest

from dipy.sims.voxel import single_tensor, multi_tensor, add_noise
from dipy.reconst import shm
from dipy.reconst.dti import fractional_anisotropy, mean_diffusivity
from dipy.data import default_sphere, get_3shell_gtab, get_fnames
from dipy.core.gradients import GradientTable, gradient_table

from dipy.io.gradients import read_bvals_bvecs

from dipy.utils.optpkg import optional_package
cvx, have_cvxpy, _ = optional_package("cvxpy")

# pytest insists on an explicit ``reason`` when the skip condition is a plain
# boolean; without it, applying this marker raises instead of skipping.
needs_cvxpy = pytest.mark.skipif(not have_cvxpy, reason="Requires CVXPY")


# Each row holds three diffusion eigenvalues followed by the S0 signal.
wm_response = np.array([[1.7E-3, 0.4E-3, 0.4E-3, 25.],
                        [1.7E-3, 0.4E-3, 0.4E-3, 25.],
                        [1.7E-3, 0.4E-3, 0.4E-3, 25.]])
csf_response = np.array([[3.0E-3, 3.0E-3, 3.0E-3, 100.],
                         [3.0E-3, 3.0E-3, 3.0E-3, 100.],
                         [3.0E-3, 3.0E-3, 3.0E-3, 100.]])
gm_response = np.array([[4.0E-4, 4.0E-4, 4.0E-4, 40.],
                        [4.0E-4, 4.0E-4, 4.0E-4, 40.],
                        [4.0E-4, 4.0E-4, 4.0E-4, 40.]])

# b-values of the shells used to build the multi-shell fiber response.
_SHELL_BVALS = [0, 1000, 2000, 3500]

# Thresholds shared by every mask / automatic-response test below.
_MASK_KWARGS = dict(roi_center=None,
                    roi_radii=(1, 1, 0),
                    wm_fa_thr=0.7,
                    gm_fa_thr=0.3,
                    csf_fa_thr=0.15,
                    gm_md_thr=0.001,
                    csf_md_thr=0.0032)


def get_test_data():
    """Simulate a noisy 3x3x1 volume made of WM, GM and CSF voxels.

    Returns
    -------
    gtab : GradientTable
        The gradient table returned by ``get_3shell_gtab``.
    data : ndarray, shape (3, 3, 1, N)
        Simulated, noisy signal; N is the length of one simulated signal.
    masks : list of ndarray
        Ground-truth 0/1 masks for WM, GM and CSF, in that order.
    responses : list of ndarray
        Ground-truth responses (three eigenvalues followed by S0) for WM,
        GM and CSF, in that order.
    """
    gtab = get_3shell_gtab()
    evals_list = [np.array([1.7E-3, 0.4E-3, 0.4E-3]),
                  np.array([6.0E-4, 4.0E-4, 4.0E-4]),
                  np.array([3.0E-3, 3.0E-3, 3.0E-3])]
    s0 = [0.8, 1, 4]
    signals = [single_tensor(gtab, s, ev) for s, ev in zip(s0, evals_list)]
    tissues = [0, 0, 2, 0, 1, 0, 0, 1, 2]  # wm=0, gm=1, csf=2
    # NOTE(review): the WM S0 (s0[0]) is passed to add_noise for every
    # tissue, not the tissue's own S0 — confirm this is intended.
    data = [add_noise(signals[tissue], 80, s0[0]) for tissue in tissues]
    data = np.asarray(data).reshape((3, 3, 1, len(signals[0])))
    tissues = np.asarray(tissues).reshape((3, 3, 1))
    masks = [np.where(tissues == x, 1, 0) for x in range(3)]
    responses = [np.concatenate((ev, [s])) for ev, s in zip(evals_list, s0)]
    return (gtab, data, masks, responses)


def _expand(m, iso, coeff):
    """Scatter compact coefficients onto the full coefficient layout.

    ``coeff`` holds ``iso`` leading isotropic coefficients followed by the
    values for the entries where ``m == 0``. The result keeps the isotropic
    part and places the remaining values at the ``m == 0`` positions of a
    zero vector of length ``len(m)``.
    """
    params = np.zeros(len(m))
    params[m == 0] = coeff[iso:]
    return np.concatenate([coeff[:iso], params])


def _fiber_response(sh_order=8, bvals=None):
    """Build the WM/GM/CSF multi-shell response from the module responses.

    ``bvals`` defaults to ``_SHELL_BVALS``; a fresh list is passed on every
    call so the shared module constant can never be modified by the callee.
    """
    if bvals is None:
        bvals = _SHELL_BVALS
    return multi_shell_fiber_response(sh_order, list(bvals),
                                      wm_response,
                                      gm_response,
                                      csf_response)


def _mixed_signal(gtab, vf):
    """Noise-free CSF + GM + WM mixture weighted by ``vf`` = (csf, gm, wm).

    The WM part is a two-fiber crossing (0 and 60 degrees, 30/70 split)
    built from ``wm_response``; GM and CSF are mono-exponential decays.
    """
    mevals = np.array([wm_response[0, :3], wm_response[0, :3]])
    angles = [(0, 0), (60, 0)]

    S_wm, _ = multi_tensor(gtab, mevals, wm_response[0, 3], angles=angles,
                           fractions=[30., 70.], snr=None)
    S_gm = gm_response[0, 3] * np.exp(-gtab.bvals * gm_response[0, 0])
    S_csf = csf_response[0, 3] * np.exp(-gtab.bvals * csf_response[0, 0])
    return sum(i * j for i, j in zip(vf, [S_csf, S_gm, S_wm]))


def _squash_whitespace(text):
    """Collapse every run of whitespace in ``text`` to a single space."""
    return " ".join(text.split())


@pytest.mark.skipif(not mcsd.have_cvxpy, reason="Requires CVXPY")
def test_mcsd_model_delta():
    """A WM delta must predict the WM response on every shell."""
    sh_order = 8
    gtab = get_3shell_gtab()
    response = _fiber_response(sh_order)
    model = MultiShellDeconvModel(gtab, response)
    iso = response.iso

    theta, phi = default_sphere.theta, default_sphere.phi
    B = shm.real_sh_descoteaux_from_index(
        response.m, response.n, theta[:, None], phi[:, None])

    wm_delta = model.delta.copy()
    # set isotropic components to zero
    wm_delta[:iso] = 0.
    wm_delta = _expand(model.m, iso, wm_delta)

    for i, s in enumerate(_SHELL_BVALS):
        g = GradientTable(default_sphere.vertices * s)
        signal = model.predict(wm_delta, g)
        expected = np.dot(response.response[i, iso:], B.T)
        npt.assert_array_almost_equal(signal, expected)

    # Refitting the prediction must leave the m != 0 coefficients at zero.
    signal = model.predict(wm_delta, gtab)
    fit = model.fit(signal)
    m = model.m
    npt.assert_array_almost_equal(fit.shm_coeff[m != 0], 0., 2)


@pytest.mark.skipif(not mcsd.have_cvxpy, reason="Requires CVXPY")
def test_MultiShellDeconvModel_response():
    """A response object and a raw response array must give the same model."""
    gtab = get_3shell_gtab()

    sh_order = 8
    response = _fiber_response(sh_order)
    model_1 = MultiShellDeconvModel(gtab, response, sh_order=sh_order)
    responses = np.array([wm_response, gm_response, csf_response])
    model_2 = MultiShellDeconvModel(gtab, responses, sh_order=sh_order)
    response_1 = model_1.response.response
    response_2 = model_2.response.response
    npt.assert_array_almost_equal(response_1, response_2, 0)

    # Malformed response arrays must be rejected.
    npt.assert_raises(ValueError, MultiShellDeconvModel,
                      gtab, np.ones((4, 3, 4)))
    npt.assert_raises(ValueError, MultiShellDeconvModel,
                      gtab, np.ones((3, 3, 4)), iso=3)


@pytest.mark.skipif(not mcsd.have_cvxpy, reason="Requires CVXPY")
def test_MultiShellDeconvModel():
    """Fit and model predictions must agree and reproduce the signal."""
    gtab = get_3shell_gtab()
    model = MultiShellDeconvModel(gtab, _fiber_response(8))
    vf = [0.325, 0.2, 0.475]
    signal = _mixed_signal(gtab, vf)
    fit = model.fit(signal)

    # Testing both ways to predict
    S_pred_fit = fit.predict()
    S_pred_model = model.predict(fit.all_shm_coeff)

    npt.assert_array_almost_equal(S_pred_fit, S_pred_model, 0)
    npt.assert_array_almost_equal(S_pred_fit, signal, 0)


@pytest.mark.skipif(not mcsd.have_cvxpy, reason="Requires CVXPY")
def test_MSDeconvFit():
    """The fitted volume fractions must recover the simulated ones."""
    gtab = get_3shell_gtab()
    model = MultiShellDeconvModel(gtab, _fiber_response(8))
    vf = [0.325, 0.2, 0.475]
    fit = model.fit(_mixed_signal(gtab, vf))

    # Testing volume fractions
    npt.assert_array_almost_equal(fit.volume_fractions, vf, 1)


def test_multi_shell_fiber_response():
    """Response shape with and without a b0 shell, plus the no-b0 warning."""
    sh_order = 8
    response = _fiber_response(sh_order)

    npt.assert_equal(response.response.shape, (4, 7))

    with warnings.catch_warnings(record=True) as w:
        # Record UserWarnings regardless of any outer "ignore"/"error" filter.
        warnings.simplefilter("always", UserWarning)
        response = _fiber_response(sh_order, [1000, 2000, 3500])
        # Test that the number of warnings raised is greater than 1, with
        # deprecation warnings being raised from using legacy SH bases as well
        # as a warning from multi_shell_fiber_response
        npt.assert_(len(w) > 1)
        # The last warning in list is the one from multi_shell_fiber_response
        npt.assert_(issubclass(w[-1].category, UserWarning))
        npt.assert_("""No b0 given. Proceeding either way.""" in
                    str(w[-1].message))
        npt.assert_equal(response.response.shape, (3, 7))


def test_mask_for_response_msmt():
    """Masks computed from the simulated volume must match ground truth."""
    gtab, data, masks_gt, _ = get_test_data()

    wm_mask, gm_mask, csf_mask = mask_for_response_msmt(gtab, data,
                                                        **_MASK_KWARGS)

    # Verifies that masks are not empty:
    masks_sum = int(np.sum(wm_mask) + np.sum(gm_mask) + np.sum(csf_mask))
    npt.assert_equal(masks_sum != 0, True)

    npt.assert_array_almost_equal(masks_gt[0], wm_mask)
    npt.assert_array_almost_equal(masks_gt[1], gm_mask)
    npt.assert_array_almost_equal(masks_gt[2], csf_mask)


def test_mask_for_response_msmt_nvoxels():
    """Voxel counts per mask, and warnings when thresholds select nothing."""
    gtab, data, _, _ = get_test_data()

    wm_mask, gm_mask, csf_mask = mask_for_response_msmt(gtab, data,
                                                        **_MASK_KWARGS)

    wm_nvoxels = np.sum(wm_mask)
    gm_nvoxels = np.sum(gm_mask)
    csf_nvoxels = np.sum(csf_mask)
    npt.assert_equal(wm_nvoxels, 5)
    npt.assert_equal(gm_nvoxels, 2)
    npt.assert_equal(csf_nvoxels, 2)

    with warnings.catch_warnings(record=True) as w:
        # Record UserWarnings regardless of any outer "ignore"/"error" filter.
        warnings.simplefilter("always", UserWarning)
        wm_mask, gm_mask, csf_mask = mask_for_response_msmt(gtab, data,
                                                            roi_center=None,
                                                            roi_radii=(1, 1, 0),
                                                            wm_fa_thr=1,
                                                            gm_fa_thr=0,
                                                            csf_fa_thr=0,
                                                            gm_md_thr=0,
                                                            csf_md_thr=0)
        npt.assert_equal(len(w), 6)
        npt.assert_(issubclass(w[0].category, UserWarning))
        npt.assert_("""Some b-values are higher than 1200.""" in
                    str(w[0].message))
        npt.assert_("No voxel with a FA higher than 1 were found" in
                    str(w[1].message))
        npt.assert_("No voxel with a FA lower than 0 were found" in
                    str(w[2].message))
        npt.assert_("No voxel with a MD lower than 0 were found" in
                    str(w[3].message))
        npt.assert_("No voxel with a FA lower than 0 were found" in
                    str(w[4].message))
        npt.assert_("No voxel with a MD lower than 0 were found" in
                    str(w[5].message))

    wm_nvoxels = np.sum(wm_mask)
    gm_nvoxels = np.sum(gm_mask)
    csf_nvoxels = np.sum(csf_mask)
    npt.assert_equal(wm_nvoxels, 0)
    npt.assert_equal(gm_nvoxels, 0)
    npt.assert_equal(csf_nvoxels, 0)


def test_response_from_mask_msmt():
    """Responses estimated from ground-truth masks must be plausible."""
    gtab, data, masks_gt, responses_gt = get_test_data()

    response_wm, response_gm, response_csf \
        = response_from_mask_msmt(gtab, data, masks_gt[0],
                                  masks_gt[1], masks_gt[2], tol=20)

    # Verifying that csf's response is greater than gm's
    npt.assert_equal(np.sum(response_csf[:, :3]) > np.sum(response_gm[:, :3]),
                     True)
    # Verifying that csf and gm are described by spheres
    npt.assert_almost_equal(response_csf[:, 1], response_csf[:, 2])
    npt.assert_allclose(response_csf[:, 0], response_csf[:, 1], rtol=1, atol=0)
    npt.assert_almost_equal(response_gm[:, 1], response_gm[:, 2])
    npt.assert_allclose(response_gm[:, 0], response_gm[:, 1], rtol=1, atol=0)
    # Verifying that wm is anisotropic in one direction
    npt.assert_almost_equal(response_wm[:, 1], response_wm[:, 2])
    npt.assert_equal(response_wm[:, 0] > 2.5 * response_wm[:, 1], True)

    # Verifying with ground truth for the first bvalue
    npt.assert_array_almost_equal(response_wm[0], responses_gt[0], 1)
    npt.assert_array_almost_equal(response_gm[0], responses_gt[1], 1)
    npt.assert_array_almost_equal(response_csf[0], responses_gt[2], 1)


def test_auto_response_msmt():
    """auto_response_msmt must equal mask_for_response + response_from_mask."""
    gtab, data, _, _ = get_test_data()

    with warnings.catch_warnings(record=True) as w:
        # Record UserWarnings regardless of any outer "ignore"/"error" filter.
        warnings.simplefilter("always", UserWarning)
        response_auto_wm, response_auto_gm, response_auto_csf = \
            auto_response_msmt(gtab, data, tol=20, **_MASK_KWARGS)

        npt.assert_(issubclass(w[0].category, UserWarning))
        # Compare with whitespace collapsed so the check does not depend on
        # how the multi-line warning text happens to be indented.
        expected_msg = ("Some b-values are higher than 1200. "
                        "The DTI fit might be affected. It is advised to use "
                        "mask_for_response_msmt with bvalues lower than "
                        "1200, followed by response_from_mask_msmt with all "
                        "bvalues to overcome this.")
        npt.assert_(expected_msg in _squash_whitespace(str(w[0].message)))

    mask_wm, mask_gm, mask_csf = mask_for_response_msmt(gtab, data,
                                                        **_MASK_KWARGS)

    response_from_mask_wm, response_from_mask_gm, response_from_mask_csf = \
        response_from_mask_msmt(gtab, data,
                                mask_wm, mask_gm, mask_csf,
                                tol=20)

    npt.assert_array_equal(response_auto_wm, response_from_mask_wm)
    npt.assert_array_equal(response_auto_gm, response_from_mask_gm)
    npt.assert_array_equal(response_auto_csf, response_from_mask_csf)