1""" Testing Mean Signal DKI (MSDKI) """
2
3import numpy as np
4import random
5from numpy.testing import (assert_array_almost_equal, assert_raises,
6                           assert_almost_equal, assert_)
7from dipy.sims.voxel import (single_tensor, multi_tensor_dki)
8from dipy.io.gradients import read_bvals_bvecs
9from dipy.core.gradients import (gradient_table, unique_bvals_magnitude,
10                                 round_bvals)
11from dipy.data import get_fnames
12import dipy.reconst.msdki as msdki
13from dipy.reconst.msdki import (msk_from_awf, awf_from_msk)
14
# Base acquisition from the 'small_64D' dataset (a single non-zero shell)
fimg, fbvals, fbvecs = get_fnames('small_64D')
bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs)
bvals = round_bvals(bvals)
gtab = gradient_table(bvals, bvecs)

# Three non-zero shells (1x, 1.5x and 2x the original b-values, same
# directions) for techniques that require multi-shell data
bvals_3s = np.concatenate((bvals, bvals * 1.5, bvals * 2), axis=0)
bvecs_3s = np.concatenate((bvecs, bvecs, bvecs), axis=0)
gtab_3s = gradient_table(bvals_3s, bvecs_3s)
24
# Simulation 1. Spherical kurtosis tensor - for spherical kurtosis tensors the
# MSK and MSD estimated by the MSDKI model should equal the MK and MD of the
# DKI tensor
Di = 0.00099
De = 0.00226
mevals_sph = np.array([[Di] * 3, [De] * 3])
f = 0.5
frac_sph = [100 * f, 100 * (1.0 - f)]
signal_sph, dt_sph, kt_sph = multi_tensor_dki(gtab_3s, mevals_sph, S0=100,
                                              fractions=frac_sph, snr=None)

# Analytical ground truth for a two-compartment mix of isotropic tensors
MDgt = f * Di + (1 - f) * De
MKgt = 3 * f * (1 - f) * ((Di - De) / MDgt) ** 2
params_single = np.array([MDgt, MKgt])

# Expected mean signal per unique b-value: one sample per shell (indices 0, 1,
# 100 and 180 fall in the b0 group and in the 1st, 2nd and 3rd shells of
# gtab_3s). Both compartments are isotropic, so the signal does not depend on
# direction within a shell.
msignal_sph = np.array(signal_sph[[0, 1, 100, 180]], dtype=float)
45
# Simulation 2. Multi-voxel simulations: voxels with k == 0 hold random
# two-compartment mixtures of the spherical tensors above, while voxels with
# k == 1 are left at zero.
DWI = np.zeros((2, 2, 2, len(gtab_3s.bvals)))
MDWI = np.zeros((2, 2, 2, 4))
MDgt_multi = np.zeros((2, 2, 2))
MKgt_multi = np.zeros((2, 2, 2))
S0gt_multi = np.zeros((2, 2, 2))
params_multi = np.zeros((2, 2, 2, 2))

# Seeded generator, so the simulated volume fractions (and hence the test
# outcomes) are reproducible from run to run
_rng = random.Random(0)
# One measurement index per unique b-value (b0 group and the three shells)
_shell_idx = [0, 1, 100, 180]

for i in range(2):
    for j in range(2):
        for k in range(1):  # Only one k to have some zero voxels
            # f_i (not f) so the module-level f of Simulation 1 is untouched
            f_i = _rng.uniform(0.0, 1)
            frac = [f_i * 100, (1.0 - f_i) * 100]
            signal_i, _, _ = multi_tensor_dki(gtab_3s, mevals_sph,
                                              S0=100, fractions=frac,
                                              snr=None)
            md_i = f_i * Di + (1 - f_i) * De
            mk_i = 3 * f_i * (1 - f_i) * ((Di - De) / md_i) ** 2
            DWI[i, j, k] = signal_i
            MDgt_multi[i, j, k] = md_i
            MKgt_multi[i, j, k] = mk_i
            S0gt_multi[i, j, k] = 100
            params_multi[i, j, k] = (md_i, mk_i)
            MDWI[i, j, k] = signal_i[_shell_idx]
74
75
def test_msdki_predict():
    """MSDKI predictions must reproduce the simulated signals.

    Covers the model's ``predict`` (from ground-truth parameters) and the fit
    object's ``predict`` with a scalar S0, without S0, and with an S0 volume.
    """
    dkiM = msdki.MeanDiffusionKurtosisModel(gtab_3s)

    # Model.predict from ground-truth parameters - single voxel
    pred = dkiM.predict(params_single, S0=100)
    assert_array_almost_equal(pred, signal_sph)

    # Model.predict from ground-truth parameters - multi-voxel (only the k = 0
    # slice of the simulation holds non-zero data)
    pred = dkiM.predict(params_multi, S0=100)
    assert_array_almost_equal(pred[:, :, 0, :], DWI[:, :, 0, :])

    # Fit each dataset once and reuse the fit objects for every check of the
    # fit object's predict method below
    fit_single = dkiM.fit(signal_sph)
    fit_multi = dkiM.fit(DWI)

    # Scalar S0
    pred_single = fit_single.predict(gtab_3s, S0=100)
    assert_array_almost_equal(pred_single, signal_sph)
    pred_multi = fit_multi.predict(gtab_3s, S0=100)
    assert_array_almost_equal(pred_multi[:, :, 0, :], DWI[:, :, 0, :])

    # No S0: the prediction is expected to equal the data divided by the
    # simulated S0 of 100
    pred_single = fit_single.predict(gtab_3s)
    assert_array_almost_equal(100 * pred_single, signal_sph)
    pred_multi = fit_multi.predict(gtab_3s)
    assert_array_almost_equal(100 * pred_multi[:, :, 0, :], DWI[:, :, 0, :])

    # S0 volume (one S0 value per voxel)
    pred_multi = fit_multi.predict(gtab_3s, 100 * np.ones(DWI.shape[:-1]))
    assert_array_almost_equal(pred_multi[:, :, 0, :], DWI[:, :, 0, :])
107
108
def test_errors():
    """Invalid inputs to the MSDKI model and helpers must raise."""
    # 1) The model cannot be built from data with only one non-zero b-value
    assert_raises(ValueError, msdki.MeanDiffusionKurtosisModel, gtab)

    # 2) A negative min_signal is rejected
    assert_raises(ValueError, msdki.MeanDiffusionKurtosisModel, gtab_3s,
                  min_signal=-1)

    model = msdki.MeanDiffusionKurtosisModel(gtab_3s)

    # 3) fit rejects a mask whose shape does not match the data
    bad_mask = np.ones((2, 3, 1))
    assert_raises(ValueError, model.fit, DWI, mask=bad_mask)

    # 4) Indexing a fit object with more indices than the data has spatial
    #    dimensions raises IndexError
    def get_msk(fit_obj, index):
        # helper so the indexing + attribute access can go through
        # assert_raises
        return fit_obj[index].msk

    mfit = model.fit(DWI)
    assert_raises(IndexError, get_msk, mfit, (0, 0, 0, 0))
    # ... while a valid index works and returns the expected MSK
    assert_array_almost_equal(MKgt_multi[0, 0, 0], get_msk(mfit, (0, 0, 0)))

    # 5) awf_from_msk also rejects a mask of the wrong shape
    assert_raises(ValueError, awf_from_msk, MKgt_multi, mask=bad_mask)
140
141
def test_design_matrix():
    """The MSDKI design matrix has columns [-b, b**2 / 6, 1]."""
    ub = unique_bvals_magnitude(bvals_3s)
    computed = msdki.design_matrix(ub)
    expected = np.column_stack((-ub, 1.0 / 6 * ub ** 2, np.ones(4)))
    assert_array_almost_equal(computed, expected)
149
150
def test_msignal():
    """mean_signal_bvalue returns per-shell mean signals and direction counts."""
    expected_counts = np.array([3, 64, 64, 64])

    # Multi-voxel input
    mean_sig, counts = msdki.mean_signal_bvalue(DWI, gtab_3s)
    assert_array_almost_equal(mean_sig, MDWI)
    assert_array_almost_equal(counts, expected_counts)

    # Single-voxel input
    mean_sig, counts = msdki.mean_signal_bvalue(signal_sph, gtab_3s)
    assert_array_almost_equal(counts, expected_counts)
    assert_array_almost_equal(mean_sig, msignal_sph)
161
162
def test_msdki_statistics():
    """MSK and MSD must match the ground truth of spherical-tensor simulations."""

    def check(fit_obj, mk_expected, md_expected):
        # compare a fit's kurtosis and diffusivity with their ground truth
        assert_array_almost_equal(mk_expected, fit_obj.msk)
        assert_array_almost_equal(md_expected, fit_obj.msd)

    # Multi-voxel data through the low-level API (design matrix + WLS fit)
    ub = unique_bvals_magnitude(bvals_3s)
    dmat = msdki.design_matrix(ub)
    msignal, ng = msdki.mean_signal_bvalue(DWI, gtab_3s, bmag=None)
    params = msdki.wls_fit_msdki(dmat, msignal, ng)
    assert_array_almost_equal(params[..., 1], MKgt_multi)
    assert_array_almost_equal(params[..., 0], MDgt_multi)

    # Multi-voxel data through the model class
    model = msdki.MeanDiffusionKurtosisModel(gtab_3s)
    check(model.fit(DWI), MKgt_multi, MDgt_multi)

    # Single-voxel data
    check(model.fit(signal_sph), MKgt, MDgt)

    # Multi-voxel data with an explicit mask, plus fit-object indexing
    mask = np.ones(DWI.shape[:-1])
    mask[1, 1, 1] = 0
    voxel = (0, 0, 0)
    masked_fit = model.fit(DWI, mask=mask)
    check(masked_fit, MKgt_multi, MDgt_multi)
    check(masked_fit[voxel], MKgt_multi[voxel], MDgt_multi[voxel])  # tuple
    check(masked_fit[0], MKgt_multi[0], MDgt_multi[0])  # non-tuple index

    # Estimated S0 returned when requested
    s0_model = msdki.MeanDiffusionKurtosisModel(gtab_3s, return_S0_hat=True)
    s0_fit = s0_model.fit(DWI)
    assert_array_almost_equal(S0gt_multi, s0_fit.S0_hat)
    assert_array_almost_equal(MKgt_multi[voxel], s0_fit[voxel].msk)
208
209
def test_kurtosis_to_smt2_convertion():
    """Conversions between MSK and the SMT2 axonal water fraction (awf)."""
    # 1. msk_from_awf: awf = 0 must give a kurtosis of 0, and awf = 1 must
    #    give a kurtosis of 2.4
    for awf, k_expected in ((0, 0), (1, 2.4)):
        assert_almost_equal(msk_from_awf(awf), k_expected)

    # 2. awf_from_msk is the inverse of msk_from_awf over [0, 1]
    awf_grid = np.linspace(0, 1, 100)
    assert_array_almost_equal(awf_from_msk(msk_from_awf(awf_grid)), awf_grid)

    # 3. Under the SMT2 assumptions MSK is never below 0 nor above 2.4. These
    #    assumptions are commonly not met, so MSK can fall outside that range:
    #    MSK below 0 gives awf = 0 (no negative fractions) and MSK above 2.4
    #    gives the maximum awf of 1.
    assert_array_almost_equal(awf_from_msk(np.array([-0.1, 2.5])),
                              np.array([0., 1.]))

    # 4. A NaN kurtosis yields a NaN awf
    assert_(np.isnan(awf_from_msk(np.array(np.nan))))
241
242
def test_smt2_metrics():
    """SMT2 parameters must be retrievable from the MSDKI fit object."""
    # Ground truth SMT2 parameters for the multi-voxel simulations
    awf_gt = awf_from_msk(MKgt_multi)
    di_gt = 3 * MDgt_multi / (1 + 2 * (1 - awf_gt) ** 2)

    # General microscopic anisotropy (Eq 8 of Henriques et al. MRM 2019);
    # rde follows from the tortuosity assumption
    rde = di_gt * (1 - awf_gt)
    var_d = 2/9 * (awf_gt * di_gt ** 2 + (1 - awf_gt) * (di_gt - rde) ** 2)
    md = (awf_gt * di_gt + (1 - awf_gt) * (di_gt + 2 * rde)) / 3
    valid = md > 0  # uFA is only defined where the mean diffusivity is > 0
    ufa_gt = np.sqrt(3 / 2 * var_d[valid] / (var_d[valid] + md[valid] ** 2))

    mfit = msdki.MeanDiffusionKurtosisModel(gtab_3s).fit(DWI)
    assert_array_almost_equal(mfit.smt2f, awf_gt)
    assert_array_almost_equal(mfit.smt2di, di_gt)
    assert_array_almost_equal(mfit.smt2uFA[valid], ufa_gt)

    # awf_from_msk gives the same fractions when a mask is supplied
    awf_masked = awf_from_msk(MKgt_multi, MKgt_multi > 0)
    assert_array_almost_equal(awf_masked, awf_gt)
265
266
def test_smt2_specific_cases():
    """SMT2 estimates in limiting cases with known ground truth."""
    model = msdki.MeanDiffusionKurtosisModel(gtab_3s)

    # 1) Single isotropic Gaussian compartment (awf = 0): MSK is zero and the
    #    intrinsic diffusivity equals the MSD
    sig_gaussian = single_tensor(gtab_3s, evals=np.array([2e-3, 2e-3, 2e-3]))
    gaussian_fit = model.fit(sig_gaussian)
    assert_almost_equal(gaussian_fit.msk, 0.0)
    assert_almost_equal(gaussian_fit.msd, 2.0e-3)
    assert_almost_equal(gaussian_fit.smt2f, 0)
    assert_almost_equal(gaussian_fit.smt2di, 2.0e-3)

    # 2) Single powder-averaged stick compartment (awf = 1): MSD equals Da / 3,
    #    i.e. the intrinsic diffusivity equals 3 * MSD
    Da = 2.0e-3
    n_dirs = 64
    stick_evals = np.zeros((n_dirs, 3))
    stick_evals[:, 0] = Da
    stick_fracs = np.ones(n_dirs) * 100 / n_dirs
    # only the signal is needed; the returned DT/KT are discarded so the
    # module-level dt_sph / kt_sph names are not shadowed
    signal_pa = multi_tensor_dki(gtab_3s, stick_evals, angles=bvecs[1:, :],
                                 fractions=stick_fracs, snr=None)[0]
    stick_fit = model.fit(signal_pa)
    # decimal=1 because the powder average is computed from a finite number
    # of directions
    assert_almost_equal(stick_fit.msk, 2.4, decimal=1)
    assert_almost_equal(stick_fit.msd * 1000, Da / 3 * 1000, decimal=1)
    assert_almost_equal(stick_fit.smt2f, 1, decimal=1)
    assert_almost_equal(stick_fit.smt2di, stick_fit.msd * 3, decimal=1)
296