1#!/usr/bin/python
2""" Classes and functions for fitting the mean signal diffusion kurtosis
3model """
4
5import numpy as np
6import scipy.optimize as opt
7
8from dipy.core.gradients import (check_multi_b, unique_bvals_magnitude,
9                                 round_bvals)
10from dipy.reconst.base import ReconstModel
11from dipy.reconst.dti import MIN_POSITIVE_SIGNAL
12from dipy.core.ndindex import ndindex
13from dipy.core.onetime import auto_attr
14
15
def mean_signal_bvalue(data, gtab, bmag=None):
    """
    Computes the average signal across different diffusion directions
    for each unique b-value

    Parameters
    ----------
    data : ndarray ([X, Y, Z, ...], g)
        ndarray containing the data signals in its last dimension.
    gtab : a GradientTable class instance
        The gradient table containing diffusion acquisition parameters.
    bmag : int, optional
        The order of magnitude that the bvalues have to differ to be
        considered an unique b-value. Default: derive this value from the
        maximal b-value provided: $bmag=log_{10}(max(bvals)) - 1$.

    Returns
    -------
    msignal : ndarray ([X, Y, Z, ..., nub])
        Mean signal along all gradient directions for each unique b-value.
        The last dimension holds one mean per unique b-value; nub is the
        number of unique b-values.
    ng : ndarray(nub)
        Number of gradient directions averaged for each unique b-value
        (stored as floats).

    Notes
    -----
    This function assumes that directions are evenly sampled on the sphere or
    on the hemisphere
    """
    # Work on a private copy of the b-values before grouping them.
    bvals = gtab.bvals.copy()

    # Unique b-values plus the per-acquisition rounded b-values used to
    # decide which acquisitions belong to which group.
    unique_b, rounded_b = unique_bvals_magnitude(bvals, bmag=bmag,
                                                 rbvals=True)

    n_unique = unique_b.size
    msignal = np.zeros(data.shape[:-1] + (n_unique,))
    ng = np.zeros(n_unique)

    for col, bval in enumerate(unique_b):
        in_shell = rounded_b == bval
        msignal[..., col] = np.mean(data[..., in_shell], axis=-1)
        ng[col] = np.sum(in_shell)

    return msignal, ng
59
60
def msk_from_awf(f):
    r"""
    Computes mean signal kurtosis from axonal water fraction estimates of the
    SMT2 model

    Parameters
    ----------
    f : float or ndarray ([X, Y, Z, ...])
        Axonal volume fraction estimate(s).

    Returns
    -------
    msk : float or ndarray ([X, Y, Z, ...])
        Mean signal kurtosis (msk), with the same shape as ``f``.

    Notes
    -----
    Evaluates the rational function of equation 17 of [1]_:

    .. math::

        MSK = \frac{216 f - 504 f^2 + 504 f^3 - 180 f^4}
                   {135 - 360 f + 420 f^2 - 240 f^3 + 60 f^4}

    which gives MSK = 0 at f = 0 and MSK = 36 / 15 = 2.4 at f = 1.

    References
    ----------
    .. [1] Neto Henriques R, Jespersen SN, Shemesh N (2019). Microscopic
           anisotropy misestimation in spherical‐mean single diffusion
           encoding MRI. Magnetic Resonance in Medicine (In press).
           doi: 10.1002/mrm.27606
    """
    # Powers of f are shared by the numerator and the denominator.
    f2 = f ** 2
    f3 = f ** 3
    f4 = f ** 4

    numerator = 216*f - 504 * f2 + 504 * f3 - 180 * f4
    denominator = 135 - 360*f + 420 * f2 - 240 * f3 + 60 * f4
    return numerator / denominator
92
93
def _msk_from_awf_error(f, msk):
    """ Helper function that calculates the residual between the mean signal
    kurtosis predicted from the SMT2 axonal water fraction and a measured
    mean signal kurtosis

    Parameters
    ----------
    f : float
        Axonal volume fraction estimate.
    msk : float
        Measured mean signal kurtosis.

    Returns
    -------
    error : float
        ``msk_from_awf(f) - msk``, i.e. the residual of equation 17 of [1]_.
        ``awf_from_msk`` passes this function to ``scipy.optimize.fsolve``
        to find the ``f`` at which it is zero.

    References
    ----------
    .. [1] Neto Henriques R, Jespersen SN, Shemesh N (2019). Microscopic
           anisotropy misestimation in spherical‐mean single diffusion
           encoding MRI. Magnetic Resonance in Medicine (In press).
           doi: 10.1002/mrm.27606
    """
    predicted = msk_from_awf(f)
    return predicted - msk
117
118
119def _diff_msk_from_awf(f, msk):
120    """
121    Helper function that calculates differential of function msk_from_awf
122
123    Parameters
124    ----------
125    f : ndarray ([X, Y, Z, ...])
126        ndarray containing the axonal volume fraction estimate.
127
128    Returns
129    -------
130    dkdf : ndarray(nub)
131        Mean signal kurtosis differential
132    msk : float
133        Measured mean signal kurtosis.
134
135    Notes
136    -----
137    This function corresponds to the differential of equations 17 of [1]_.
138    This function is applicable to both _msk_from_awf and _msk_from_awf_error.
139
140    References
141    ----------
142    .. [1] Neto Henriques R, Jespersen SN, Shemesh N (2019). Microscopic
143           anisotropy misestimation in spherical‐mean single diffusion
144           encoding MRI. Magnetic Resonance in Medicine (In press).
145           doi: 10.1002/mrm.27606
146    """
147    F = 216*f - 504 * f**2 + 504 * f**3 - 180 * f**4  # Numerator
148    G = 135 - 360*f + 420 * f**2 - 240 * f**3 + 60 * f**4  # Denominator
149
150    dF = 216 - 1008 * f + 1512 * f**2 - 720 * f**3  # Num. differential
151    dG = -360 + 840 * f - 720 * f**2 + 240 * f**3  # Den. differential
152
153    return (G * dF - F * dG) / (G ** 2)
154
155
def awf_from_msk(msk, mask=None):
    """
    Computes the axonal water fraction from the mean signal kurtosis
    assuming the 2-compartmental spherical mean technique model [1]_, [2]_

    Parameters
    ----------
    msk : array_like ([X, Y, Z, ...])
        Mean signal kurtosis (msk)
    mask : array_like, optional
        An array (converted to boolean) used to mark the coordinates in the
        data that should be analyzed. It must have the same shape as ``msk``.

    Returns
    -------
    smt2f : ndarray ([X, Y, Z, ...])
        ndarray containing the axonal volume fraction estimate. Voxels
        outside the mask are 0, NaN inputs give NaN, msk > 2.4 gives 1 and
        msk < 0 gives 0.

    Raises
    ------
    ValueError
        If ``mask`` does not have the same shape as ``msk``.

    Notes
    -----
    Computes the axonal water fraction from the mean signal kurtosis
    MSK using equation 17 of [1]_, by numerically inverting it voxel by
    voxel with ``scipy.optimize.fsolve``.

    References
    ----------
    .. [1] Neto Henriques R, Jespersen SN, Shemesh N (2019). Microscopic
           anisotropy misestimation in spherical‐mean single diffusion
           encoding MRI. Magnetic Resonance in Medicine (In press).
           doi: 10.1002/mrm.27606
    .. [2] Kaden E, Kelm ND, Carson RP, et al. (2016) Multi‐compartment
           microscopic diffusion imaging. Neuroimage 139:346–359.
    """
    msk = np.asarray(msk)
    awf = np.zeros(msk.shape)

    # Prepare mask. np.asarray converts lists and non-boolean masks; the
    # previous np.array(..., copy=False) raises under NumPy >= 2 whenever a
    # dtype conversion (e.g. uint8 -> bool) is required.
    if mask is None:
        mask = np.ones(msk.shape, dtype=bool)
    else:
        mask = np.asarray(mask, dtype=bool)
        if mask.shape != msk.shape:
            raise ValueError("Mask is not the same shape as data.")

    # msk_from_awf(1) = 36 / 15 = 2.4: measurements above the f = 1 value
    # are clipped to f = 1; it also scales the linear initial guess.
    msk_at_f1 = 2.4

    for v in ndindex(mask.shape):
        # Skip if out of mask
        if not mask[v]:
            continue

        mski = msk[v]
        if np.isnan(mski):
            awf[v] = np.nan
        elif mski > msk_at_f1:
            awf[v] = 1
        elif mski < 0:
            awf[v] = 0
        else:
            fini = mski / msk_at_f1  # Initial guess based on linear assumption
            sol = opt.fsolve(_msk_from_awf_error, fini, args=(mski,),
                             fprime=_diff_msk_from_awf, col_deriv=True)
            # fsolve returns a length-1 array; store its single element
            # instead of relying on size-1 array -> scalar conversion.
            awf[v] = sol[0]

    return awf
219
220
def msdki_prediction(msdki_params, gtab, S0=1.0):
    """
    Predict the mean signal given the parameters of the mean signal DKI, an
    GradientTable object and S0 signal.

    Parameters
    ----------
    msdki_params : ndarray ([X, Y, Z, ...], 2)
        Array containing the mean signal diffusivity and mean signal kurtosis
        in its last axis
    gtab : a GradientTable class instance
        The gradient table for this prediction
    S0 : float or array_like (optional)
        The non diffusion-weighted signal in every voxel, or across all
        voxels. Default: 1

    Returns
    -------
    pred_sig : ndarray ([X, Y, Z, ...], g)
        Predicted mean signal for each of the g entries of ``gtab.bvals``
        (b-values are passed through ``round_bvals`` first).

    Notes
    -----
    The predicted signal is given by:
        $MS(b) = S_0 * exp(-bD + 1/6 b^{2} D^{2} K)$, where $D$ and $K$ are the
        mean signal diffusivity and mean signal kurtosis.

    References
    ----------
    .. [1] Henriques, R.N., 2018. Advanced Methods for Diffusion MRI Data
           Analysis and their Application to the Healthy Ageing Brain (Doctoral
           thesis). Downing College, University of Cambridge.
           https://doi.org/10.17863/CAM.29356
    """
    A = design_matrix(round_bvals(gtab.bvals))

    # Turn (MSD, MSK) into the regression coefficients (MSD, MSD^2 * MSK)
    # that multiply the -b and b^2/6 columns of the design matrix.
    params = msdki_params.copy()
    params[..., 1] = params[..., 1] * params[..., 0] ** 2

    signal = np.exp(np.dot(params, A[:, :2].T))

    S0 = np.asarray(S0)
    if S0.size == 1:
        # A single S0 value scales every voxel and b-value alike
        return S0 * signal

    # Per-voxel S0: add a trailing axis so it broadcasts over the b-values
    return S0[..., np.newaxis] * signal
267
268
class MeanDiffusionKurtosisModel(ReconstModel):
    """ Mean signal Diffusion Kurtosis Model
    """

    def __init__(self, gtab, bmag=None, return_S0_hat=False, *args, **kwargs):
        """ Mean Signal Diffusion Kurtosis Model [1]_.

        Parameters
        ----------
        gtab : GradientTable class instance
            The gradient table containing diffusion acquisition parameters.

        bmag : int
            The order of magnitude that the bvalues have to differ to be
            considered an unique b-value. Default: derive this value from the
            maximal b-value provided: $bmag=log_{10}(max(bvals)) - 1$.

        return_S0_hat : bool
            If True, also return S0 values for the fit.

        args, kwargs : arguments and keyword arguments passed to the
            fit_method. See msdki.wls_fit_msdki for details. The keyword
            ``min_signal`` is consumed here: mean signals are clipped from
            below to it before fitting and it is forwarded to wls_fit_msdki
            as the threshold for skipping voxels without signal. It must be
            strictly positive, or None to use MIN_POSITIVE_SIGNAL.

        Raises
        ------
        ValueError
            If ``min_signal`` is not strictly positive, or if fewer than three
            b-values (b=0 included) are present in ``gtab``.

        References
        ----------
        .. [1] Henriques, R.N., 2018. Advanced Methods for Diffusion MRI Data
               Analysis and their Application to the Healthy Ageing Brain
               (Doctoral thesis). Downing College, University of Cambridge.
               https://doi.org/10.17863/CAM.29356
        """
        ReconstModel.__init__(self, gtab)

        self.return_S0_hat = return_S0_hat
        self.ubvals = unique_bvals_magnitude(gtab.bvals, bmag=bmag)
        self.design_matrix = design_matrix(self.ubvals)
        self.bmag = bmag
        self.args = args
        self.kwargs = kwargs
        self.min_signal = self.kwargs.pop('min_signal', MIN_POSITIVE_SIGNAL)
        if self.min_signal is not None and self.min_signal <= 0:
            e_s = "The `min_signal` key-word argument needs to be strictly"
            e_s += " positive."
            raise ValueError(e_s)

        # Check if at least three b-values are given
        enough_b = check_multi_b(self.gtab, 3, non_zero=False, bmag=bmag)
        if not enough_b:
            mes = "MSDKI requires at least 3 b-values (which can include b=0)"
            raise ValueError(mes)

    def fit(self, data, mask=None):
        """ Fit method of the MSDKI model class

        Parameters
        ----------
        data : ndarray ([X, Y, Z, ...], g)
            ndarray containing the data signals in its last dimension.

        mask : array, optional
            A boolean array used to mark the coordinates in the data that
            should be analyzed that has the shape data.shape[:-1]

        Returns
        -------
        MeanDiffusionKurtosisFit
            Fit object holding the estimated mean signal diffusivity and
            kurtosis (and S0 when ``return_S0_hat`` is True).
        """
        S0_params = None

        # Compute mean signal for each unique b-value
        mdata, ng = mean_signal_bvalue(data, self.gtab, bmag=self.bmag)

        # min_signal=None passes the check in __init__, but np.maximum cannot
        # compare against None; fall back to the module default instead.
        if self.min_signal is None:
            min_signal = MIN_POSITIVE_SIGNAL
        else:
            min_signal = self.min_signal

        # Remove mdata zeros so the log taken in the WLS fit stays finite
        mdata = np.maximum(mdata, min_signal)

        # Forward the same threshold so voxels whose mean signal was entirely
        # clipped are skipped, whichever min_signal the user chose.
        params = wls_fit_msdki(self.design_matrix, mdata, ng, mask=mask,
                               min_signal=min_signal,
                               return_S0_hat=self.return_S0_hat, *self.args,
                               **self.kwargs)
        if self.return_S0_hat:
            params, S0_params = params

        return MeanDiffusionKurtosisFit(self, params, model_S0=S0_params)

    def predict(self, msdki_params, S0=1.):
        """
        Predict a signal for this MeanDiffusionKurtosisModel class instance
        given parameters.

        Parameters
        ----------
        msdki_params : ndarray ([X, Y, Z, ...], 2)
            The parameters of the mean signal diffusion kurtosis model
            (mean signal diffusivity and mean signal kurtosis in the last
            axis)
        S0 : float or ndarray
            The non diffusion-weighted signal in every voxel, or across all
            voxels. Default: 1

        Returns
        --------
        S : (..., N) ndarray
            Simulated mean signal based on the mean signal diffusion kurtosis
            model

        Notes
        -----
        The predicted signal is given by:
            $MS(b) = S_0 * exp(-bD + 1/6 b^{2} D^{2} K)$, where $D$ and $K$ are
            the mean signal diffusivity and mean signal kurtosis.

        References
        ----------
        .. [1] Henriques, R.N., 2018. Advanced Methods for Diffusion MRI Data
               Analysis and their Application to the Healthy Ageing Brain
               (Doctoral thesis). Downing College, University of Cambridge.
               https://doi.org/10.17863/CAM.29356
        """
        return msdki_prediction(msdki_params, self.gtab, S0)
379
380
class MeanDiffusionKurtosisFit(object):
    """ Results of fitting a MeanDiffusionKurtosisModel.

    Holds the per-voxel (MSD, MSK) parameters and, optionally, the fitted S0,
    and exposes derived SMT2 metrics as attributes (computed via
    ``auto_attr``).
    """

    def __init__(self, model, model_params, model_S0=None):
        """ Initialize a MeanDiffusionKurtosisFit class instance.

        Parameters
        ----------
        model : MeanDiffusionKurtosisModel
            The model that produced this fit.
        model_params : ndarray ([X, Y, Z, ...], 2)
            Mean signal diffusivity and mean signal kurtosis in the last axis.
        model_S0 : ndarray ([X, Y, Z, ...]), optional
            Fitted non diffusion-weighted signal, or None if S0 was not
            estimated.
        """
        self.model = model
        self.model_params = model_params
        self.model_S0 = model_S0

    def __getitem__(self, index):
        """ Return a new fit restricted to ``index`` over the voxel axes.

        The last (parameter) axis of ``model_params`` is always kept whole,
        and ``model_S0`` (when present) is indexed the same way over the
        voxel axes.

        Raises
        ------
        IndexError
            If a tuple index has as many or more entries than
            ``model_params.ndim``.
        """
        model_params = self.model_params
        model_S0 = self.model_S0
        N = model_params.ndim
        if type(index) is not tuple:
            index = (index,)
        elif len(index) >= model_params.ndim:
            raise IndexError("IndexError: invalid index")
        # Pad with full slices so the parameter axis is never indexed
        index = index + (slice(None),) * (N - len(index))
        if model_S0 is not None:
            # model_S0 has no parameter axis: drop the trailing full slice
            model_S0 = model_S0[index[:-1]]
        return MeanDiffusionKurtosisFit(self.model, model_params[index],
                                        model_S0=model_S0)

    @property
    def S0_hat(self):
        """ Fitted S0 values, or None if S0 was not estimated. """
        return self.model_S0

    @auto_attr
    def msd(self):
        r"""
        Mean signal diffusivity (MSD) calculated from the mean signal
        Diffusion Kurtosis Model.

        Returns
        -------
        msd : ndarray
            Calculated signal mean diffusivity.

        References
        ----------
        .. [1] Henriques, R.N., 2018. Advanced Methods for Diffusion MRI Data
               Analysis and their Application to the Healthy Ageing Brain
               (Doctoral thesis). Downing College, University of Cambridge.
               https://doi.org/10.17863/CAM.29356
        """
        return self.model_params[..., 0]

    @auto_attr
    def msk(self):
        r"""
        Mean signal kurtosis (MSK) calculated from the mean signal
        Diffusion Kurtosis Model.

        Returns
        -------
        msk : ndarray
            Calculated signal mean kurtosis.

        References
        ----------
        .. [1] Henriques, R.N., 2018. Advanced Methods for Diffusion MRI Data
               Analysis and their Application to the Healthy Ageing Brain
               (Doctoral thesis). Downing College, University of Cambridge.
               https://doi.org/10.17863/CAM.29356
        """
        return self.model_params[..., 1]

    @auto_attr
    def smt2f(self):
        r"""
        Computes the axonal water fraction from the mean signal kurtosis
        assuming the 2-compartmental spherical mean technique model [1]_, [2]_

        Returns
        -------
        smt2f : ndarray
            Axonal volume fraction calculated from MSK.

        Notes
        -----
        Computes the axonal water fraction from the mean signal kurtosis
        MSK using equation 17 of [1]_

        References
        ----------
        .. [1] Neto Henriques R, Jespersen SN, Shemesh N (2019). Microscopic
               anisotropy misestimation in spherical‐mean single diffusion
               encoding MRI. Magnetic Resonance in Medicine (In press).
               doi: 10.1002/mrm.27606
        .. [2] Kaden E, Kelm ND, Carson RP, et al. (2016) Multi‐compartment
               microscopic diffusion imaging. Neuroimage 139:346–359.
        """
        return awf_from_msk(self.msk)

    @auto_attr
    def smt2di(self):
        r"""
        Computes the intrinsic diffusivity from the mean signal diffusional
        kurtosis parameters assuming the 2-compartmental spherical mean
        technique model [1]_, [2]_

        Returns
        -------
        smt2di : ndarray
            Intrinsic diffusivity computed by converting MSDKI to SMT2.

        Notes
        -----
        Computes the intrinsic diffusivity using equation 16 of [1]_

        References
        ----------
        .. [1] Neto Henriques R, Jespersen SN, Shemesh N (2019). Microscopic
               anisotropy misestimation in spherical‐mean single diffusion
               encoding MRI. Magnetic Resonance in Medicine (In press).
               doi: 10.1002/mrm.27606
        .. [2] Kaden E, Kelm ND, Carson RP, et al. (2016) Multi‐compartment
               microscopic diffusion imaging. Neuroimage 139:346–359.
        """
        return 3 * self.msd / (1 + 2 * (1 - self.smt2f)**2)

    @auto_attr
    def smt2uFA(self):
        r"""
        Computes the microscopic fractional anisotropy from the mean signal
        diffusional kurtosis parameters assuming the 2-compartmental spherical
        mean technique model [1]_, [2]_

        Returns
        -------
        smt2uFA : ndarray
            Microscopic fractional anisotropy computed by converting MSDKI to
            SMT2.

        Notes
        -----
        Computes the microscopic fractional anisotropy using equation 10 of
        [1]_, with the extra-axonal fraction taken as ``1 - smt2f``.

        References
        ----------
        .. [1] Neto Henriques R, Jespersen SN, Shemesh N (2019). Microscopic
               anisotropy misestimation in spherical‐mean single diffusion
               encoding MRI. Magnetic Resonance in Medicine (In press).
               doi: 10.1002/mrm.27606
        .. [2] Kaden E, Kelm ND, Carson RP, et al. (2016) Multi‐compartment
               microscopic diffusion imaging. Neuroimage 139:346–359.
        """
        fe = (1 - self.smt2f)
        num = 3 * (1 - 2 * fe ** 2 + fe ** 3)
        den = 3 + 2 * fe ** 3 + 4 * fe ** 4
        return np.sqrt(num/den)

    def predict(self, gtab, S0=1.):
        r"""
        Given a mean signal diffusion kurtosis model fit, predict the mean
        signal for each acquisition of a gradient table

        Parameters
        ----------
        gtab : a GradientTable class instance
            The b-values of this table determine where the mean signal is
            predicted.

        S0 : float or ndarray
           The mean non-diffusion weighted signal in each voxel, or one value
           for all voxels. Default: 1. The fitted S0 (``S0_hat``) is NOT used
           automatically; pass ``fit.S0_hat`` explicitly to predict with it.

        Returns
        -------
        S : (..., N) ndarray
            Simulated mean signal based on the mean signal kurtosis model

        Notes
        -----
        The predicted signal is given by:
        $MS(b) = S_0 * exp(-bD + 1/6 b^{2} D^{2} K)$, where $D$ and $K$ are the
        mean signal diffusivity and mean signal kurtosis.

        References
        ----------
        .. [1] Henriques, R.N., 2018. Advanced Methods for Diffusion MRI Data
               Analysis and their Application to the Healthy Ageing Brain
               (Doctoral thesis). Downing College, University of Cambridge.
               https://doi.org/10.17863/CAM.29356
        """
        return msdki_prediction(self.model_params, gtab, S0=S0)
567
568
def wls_fit_msdki(design_matrix, msignal, ng, mask=None,
                  min_signal=MIN_POSITIVE_SIGNAL, return_S0_hat=False):
    r"""
    Fits the mean signal diffusion kurtosis imaging based on a weighted
    least square solution [1]_.

    Parameters
    ----------
    design_matrix : array (nub, 3)
        Design matrix holding the covariants used to solve for the regression
        coefficients of the mean signal diffusion kurtosis model. Note that
        nub is the number of unique b-values
    msignal : ndarray ([X, Y, Z, ..., nub])
        Mean signal along all gradient directions for each unique b-value
        Note that the last dimension should contain the signal means and nub
        is the number of unique b-values.
    ng : ndarray(nub)
        Number of gradient directions used to compute the mean signal for
        all unique b-values
    mask : array_like, optional
        An array (converted to boolean) used to mark the coordinates in the
        data that should be analyzed; it must have the shape
        msignal.shape[:-1].
    min_signal : float, optional
        Voxels whose mean signal (averaged over the unique b-values) is lower
        than or equal to this value are not processed; their parameters stay
        zero. Default: MIN_POSITIVE_SIGNAL (from dipy.reconst.dti).
    return_S0_hat : bool
        If True, also return S0 values for the fit.

    Returns
    -------
    params : array (..., 2)
        Containing the mean signal diffusivity and mean signal kurtosis
    S0 : array (...)
        Fitted non diffusion-weighted signal. Only returned when
        ``return_S0_hat`` is True.

    Raises
    ------
    ValueError
        If ``mask`` does not have the shape msignal.shape[:-1].

    References
    ----------
    .. [1] Henriques, R.N., 2018. Advanced Methods for Diffusion MRI Data
           Analysis and their Application to the Healthy Ageing Brain
           (Doctoral thesis). Downing College, University of Cambridge.
           https://doi.org/10.17863/CAM.29356
    """
    params = np.zeros(msignal.shape[:-1] + (3,))

    # Prepare mask. np.asarray converts lists and non-boolean masks; the
    # previous np.array(..., copy=False) raises under NumPy >= 2 whenever a
    # dtype conversion (e.g. uint8 -> bool) is required.
    if mask is None:
        mask = np.ones(msignal.shape[:-1], dtype=bool)
    else:
        mask = np.asarray(mask, dtype=bool)
        if mask.shape != msignal.shape[:-1]:
            raise ValueError("Mask is not the same shape as data.")

    for v in ndindex(mask.shape):
        # Skip if out of mask
        if not mask[v]:
            continue
        # Skip if no signal is present
        if np.mean(msignal[v]) <= min_signal:
            continue
        # Define weights as diag(ng * yn**2)
        W = np.diag(ng * msignal[v]**2)

        # WLS fitting of log(signal) = design_matrix @ p
        BTW = np.dot(design_matrix.T, W)
        inv_BT_W_B = np.linalg.pinv(np.dot(BTW, design_matrix))
        invBTWB_BTW = np.dot(inv_BT_W_B, BTW)
        p = np.dot(invBTWB_BTW, np.log(msignal[v]))

        # Convert regression coefficients (D, D^2 K, log S0) to (D, K, S0)
        p[1] = p[1] / (p[0]**2)
        p[2] = np.exp(p[2])
        params[v] = p

    if return_S0_hat:
        return params[..., :2], params[..., 2]
    else:
        return params[..., :2]
645
646
def design_matrix(ubvals):
    """  Constructs design matrix for the mean signal diffusion kurtosis model

    Parameters
    ----------
    ubvals : array_like
        1-D sequence containing the unique b-values of the data. Lists and
        integer arrays are accepted and converted to float.

    Returns
    -------
    design_matrix : array (nb, 3)
        Design matrix or B matrix for the mean signal diffusion kurtosis
        model, with one row ``(-b, b**2 / 6, 1)`` per b-value. Solving
        ``log(MS) = B @ p`` gives the coefficients
        ``p = (msd, msd**2 * msk, log(S0))``.
    """
    ubvals = np.asarray(ubvals, dtype=float)
    nb = ubvals.shape
    B = np.zeros(nb + (3,))
    B[:, 0] = -ubvals
    B[:, 1] = 1.0/6.0 * ubvals**2
    B[:, 2] = np.ones(nb)
    return B
668