1from distutils.version import LooseVersion
2import warnings
3
4import numpy as np
5import numbers
6from dipy.core import geometry as geo
7from dipy.core.gradients import (GradientTable, gradient_table,
8                                 unique_bvals_tolerance, get_bval_indices)
9from dipy.data import default_sphere
10from dipy.reconst import shm
11from dipy.reconst.csdeconv import response_from_mask_ssst
12from dipy.reconst.dti import (TensorModel, fractional_anisotropy,
13                              mean_diffusivity)
14from dipy.reconst.multi_voxel import multi_voxel_fit
15from dipy.reconst.utils import _roi_in_volume, _mask_from_roi
16from dipy.sims.voxel import single_tensor
17
18from dipy.utils.optpkg import optional_package
# cvxpy is obtained through `optional_package` rather than a plain import.
# NOTE(review): presumably `have_cvxpy` tells whether the real package was
# found and `cvxpy` is a placeholder otherwise -- confirm against
# dipy.utils.optpkg.
cvxpy, have_cvxpy, _ = optional_package("cvxpy")

# 0.5 / sqrt(pi), i.e. 1 / (2 * sqrt(pi)). Used in this module as the constant
# value of the isotropic basis columns and to rescale the leading coefficients
# into volume fractions.
SH_CONST = .5 / np.sqrt(np.pi)
22
23
def multi_tissue_basis(gtab, sh_order, iso_comp):
    """
    Assemble the design basis of the multi-shell multi-tissue CSD model.

    The basis is made of ``iso_comp`` constant columns (one per isotropic
    compartment, all equal to ``SH_CONST``) followed by the real spherical
    harmonics evaluated along the gradient directions of ``gtab``.

    Parameters
    ----------
    gtab : GradientTable
        Gradient table whose directions are used to sample the harmonics.
    sh_order : int
        Maximal spherical harmonics order.
    iso_comp : int
        Number of tissue compartments for running the MSMT-CSD. Minimum
        number of compartments required is 2.

    Returns
    -------
    B : ndarray
        Matrix of the spherical harmonics model used to fit the data, with
        the ``iso_comp`` constant columns placed first.
    m : ndarray of int, ``|m| <= n``
        The order of each harmonic.
    n : ndarray of int, ``>= 0``
        The degree of each harmonic.

    Raises
    ------
    ValueError
        If ``iso_comp`` is smaller than 2.
    """
    if iso_comp < 2:
        raise ValueError(
            "Multi-tissue CSD requires at least 2 tissue compartments")

    _, polar, azimuth = geo.cart2sphere(*gtab.gradients.T)
    m, n = shm.sph_harm_ind_list(sh_order)
    sh_part = shm.real_sh_descoteaux_from_index(m, n, polar[:, None],
                                                azimuth[:, None])
    # For the b0 measurements only the degree-0 column is kept.
    sh_part[np.ix_(gtab.b0s_mask, n > 0)] = 0.

    iso_part = np.full([sh_part.shape[0], iso_comp], SH_CONST)
    return np.concatenate([iso_part, sh_part], axis=1), m, n
58
59
class MultiShellResponse(object):

    def __init__(self, response, sh_order, shells, S0=None):
        """ Estimate Multi Shell response function for multiple tissues and
        multiple shells.

        The method `multi_shell_fiber_response` allows to create a multi-shell
        fiber response with the right format, for a three compartments model.
        It can be referred to in order to understand the inputs of this class.

        Parameters
        ----------
        response : array_like, 2D
            Multi-shell fiber response, one row per shell. The first ``iso``
            columns hold the isotropic compartments and the remaining
            ``sh_order // 2 + 1`` columns hold the coefficients of the
            anisotropic response (one per even degree). The ordering of the
            responses should follow the same logic as S0.
        sh_order : int
            Maximal spherical harmonics order.
        shells : array_like, 1D
            The b-value of each shell, in the same order as the rows of
            `response`. It is stored as an ndarray because it is later
            compared, with broadcasting, against the b-values of a gradient
            table.
        S0 : array (3,)
            Signal with no diffusion weighting for each tissue compartments, in
            the same tissue order as `response`. This S0 can be used for
            predicting from a fit model later on.

        Raises
        ------
        ValueError
            If `response` does not have enough columns for `sh_order` plus at
            least one isotropic compartment.
        """
        self.S0 = S0
        # Accept any array-like input; ndarrays are stored unchanged.
        self.response = np.asarray(response)
        self.sh_order = sh_order
        # Only even degrees with m = 0 describe an axially symmetric response.
        self.n = np.arange(0, sh_order + 1, 2)
        self.m = np.zeros_like(self.n)
        self.shells = np.asarray(shells)
        if self.iso < 1:
            raise ValueError("sh_order and shape of response do not agree")

    @property
    def iso(self):
        """Number of isotropic compartments, i.e. the number of columns of
        `response` left once the ``sh_order // 2 + 1`` anisotropic columns
        are accounted for."""
        return self.response.shape[1] - (self.sh_order // 2) - 1
96
97
98def _inflate_response(response, gtab, n, delta):
99    """Used to inflate the response for the `multiplier_matrix` in the
100    `MultiShellDeconvModel`.
101    Parameters
102    ----------
103    response : MultiShellResponse object
104    gtab : GradientTable
105    n : int ``>= 0``
106        The degree of the harmonic.
107    delta : Delta generated from `_basic_delta`
108    """
109    if any((n % 2) != 0) or (n.max() // 2) >= response.sh_order:
110        raise ValueError("Response and n do not match")
111
112    iso = response.iso
113    n_idx = np.empty(len(n) + iso, dtype=int)
114    n_idx[:iso] = np.arange(0, iso)
115    n_idx[iso:] = n // 2 + iso
116    diff = abs(response.shells[:, None] - gtab.bvals)
117    b_idx = np.argmin(diff, axis=0)
118    kernal = response.response / delta
119
120    return kernal[np.ix_(b_idx, n_idx)]
121
122
def _basic_delta(iso, m, n, theta, phi):
    """Build the coefficients of a simple delta function.

    The result is made of ``iso`` entries equal to ``SH_CONST`` followed by
    the coefficients returned by `shm.gen_dirac`.

    Parameters
    ----------
    iso : int
        Number of leading constant entries, one per isotropic compartment.
    m : int ``|m| <= n``
        The order of the harmonic.
    n : int ``>= 0``
        The degree of the harmonic.
    theta : array_like
       inclination or polar angle
    phi : array_like
       azimuth angle

    Returns
    -------
    ndarray
        Concatenation of the constant entries and the dirac coefficients.
    """
    constant_part = [SH_CONST for _ in range(iso)]
    dirac_part = shm.gen_dirac(m, n, theta, phi)
    return np.concatenate([constant_part, dirac_part])
143
144
class MultiShellDeconvModel(shm.SphHarmModel):

    def __init__(self, gtab, response, reg_sphere=default_sphere,
                 sh_order=8, iso=2):
        r"""
        Multi-Shell Multi-Tissue Constrained Spherical Deconvolution
        (MSMT-CSD) [1]_. This method extends the CSD model proposed in [2]_ by
        the estimation of multiple response functions as a function of multiple
        b-values and multiple tissue types.

        Spherical deconvolution computes a fiber orientation distribution
        (FOD), also called fiber ODF (fODF) [2]_. The fODF is derived from
        different tissue types and thus overcomes the overestimation of WM in
        GM and CSF areas.

        The response function is based on the different tissue types
        and is provided as input to the MultiShellDeconvModel.
        It will be used as deconvolution kernel, as described in [2]_.

        Parameters
        ----------
        gtab : GradientTable
        response : ndarray or MultiShellResponse object
            Pre-computed multi-shell fiber response function in the form of a
            MultiShellResponse object, or simple response function as a
            ndarray. The latter must be of shape (3, len(bvals)-1, 4), where
            `bvals` are the unique b-values selected from `gtab` with
            `unique_bvals_tolerance` (tolerance of 20). It is converted into a
            MultiShellResponse object via `multi_shell_fiber_response`, using
            ``response[0]``, ``response[1]`` and ``response[2]`` as the WM,
            GM and CSF response functions respectively. For each b-value, a
            response function holds the three eigenvalues followed by the
            signal value without diffusion weighting (S0). Note that in order
            to use more than three compartments, one must create a
            MultiShellResponse object on the side.
        reg_sphere : Sphere (optional)
            sphere used to build the regularization B matrix.
            Default: `dipy.data.default_sphere`.
        sh_order : int (optional)
            maximal spherical harmonics order. Default: 8
        iso: int (optional)
            Number of tissue compartments for running the MSMT-CSD. Minimum
            number of compartments required is 2. It must agree with the
            `iso` attribute of the response.
            Default: 2

        Raises
        ------
        ValueError
            If `iso` is smaller than 2, if an ndarray `response` is given
            with `iso` greater than 2 or with the wrong shape, or if `iso`
            differs from the `iso` attribute of the response.

        References
        ----------
        .. [1] Jeurissen, B., et al. NeuroImage 2014. Multi-tissue constrained
               spherical deconvolution for improved analysis of multi-shell
               diffusion MRI data
        .. [2] Tournier, J.D., et al. NeuroImage 2007. Robust determination of
               the fibre orientation distribution in diffusion MRI:
               Non-negativity constrained super-resolved spherical
               deconvolution
        .. [3] Tournier, J.D, et al. Imaging Systems and Technology
               2012. MRtrix: Diffusion Tractography in Crossing Fiber Regions
        """
        if not iso >= 2:
            msg = ("Multi-tissue CSD requires at least 2 tissue compartments")
            raise ValueError(msg)

        super(MultiShellDeconvModel, self).__init__(gtab)

        if not isinstance(response, MultiShellResponse):
            bvals = unique_bvals_tolerance(gtab.bvals, tol=20)
            if iso > 2:
                msg = """Too many compartments for this kind of response
                input. It must be two tissue compartments."""
                raise ValueError(msg)
            # Accept any array-like (e.g. nested lists) for the raw response.
            response = np.asarray(response)
            if response.shape != (3, len(bvals)-1, 4):
                msg = """Response must be of shape (3, len(bvals)-1, 4) or be a
                MultiShellResponse object."""
                raise ValueError(msg)
            response = multi_shell_fiber_response(sh_order,
                                                  bvals=bvals,
                                                  wm_rf=response[0],
                                                  gm_rf=response[1],
                                                  csf_rf=response[2])

        # The basis is built with `iso` constant columns while the kernel is
        # built with `response.iso` of them; a mismatch would otherwise only
        # surface below as an obscure broadcasting error.
        if response.iso != iso:
            msg = ("iso ({0}) does not match the number of isotropic "
                   "compartments of the response ({1})")
            raise ValueError(msg.format(iso, response.iso))

        B, m, n = multi_tissue_basis(gtab, sh_order, iso)

        delta = _basic_delta(response.iso, response.m, response.n, 0., 0.)
        self.delta = delta
        multiplier_matrix = _inflate_response(response, gtab, n, delta)

        # Regularization matrix: identity on the isotropic coefficients and
        # the SH basis sampled on `reg_sphere` on the anisotropic ones.
        r, theta, phi = geo.cart2sphere(*reg_sphere.vertices.T)
        odf_reg, _, _ = shm.real_sh_descoteaux(sh_order, theta, phi)
        reg = np.zeros([i + iso for i in odf_reg.shape])
        reg[:iso, :iso] = np.eye(iso)
        reg[iso:, iso:] = odf_reg

        X = B * multiplier_matrix

        self.fitter = QpFitter(X, reg)
        self.sh_order = sh_order
        self._X = X
        self.sphere = reg_sphere
        self.gtab = gtab
        self.B_dwi = B
        self.m = m
        self.n = n
        self.response = response

    def predict(self, params, gtab=None, S0=None):
        """Compute a signal prediction given spherical harmonic coefficients
        for the provided GradientTable class instance.

        Parameters
        ----------
        params : ndarray
            The spherical harmonic representation of the FOD from which to make
            the signal prediction.
        gtab : GradientTable
            The gradients for which the signal will be predicted. Use the
            model's gradient table by default.
        S0 : ndarray or float
            The non diffusion-weighted signal value. Only values that request
            no rescaling (None, 0 or 1) are currently supported.

        Returns
        -------
        pred_sig : ndarray
            The predicted signal, ``np.dot(params, X.T)`` where ``X`` is the
            design matrix for `gtab`.

        Raises
        ------
        NotImplementedError
            If `S0` requests an actual rescaling of the prediction.
        """
        if gtab is None or gtab is self.gtab:
            gtab = self.gtab
            X = self._X
        else:
            # Rebuild the design matrix for the requested gradient table.
            iso = self.response.iso
            B, m, n = multi_tissue_basis(gtab, self.sh_order, iso)
            multiplier_matrix = _inflate_response(self.response, gtab, n,
                                                  self.delta)
            X = B * multiplier_matrix

        scaling = 1.
        if S0 is not None:
            S0 = np.asarray(S0)
            # The S0=1. case comes from fit.predict(). Arrays are checked
            # element-wise so that an ndarray S0 does not trigger an
            # ambiguous truth-value error.
            if S0.any() and np.any(S0 != 1.):
                # Not implemented yet because it would require access to
                # the volume fractions (vf) of the fit: the scaling would be
                # S0 divided by the vf-weighted sum of `self.response.S0`
                # (where that sum is usable), broadcast over the
                # measurements of `gtab`. It can also be computed externally
                # and applied to the returned prediction.
                raise NotImplementedError

        pred_sig = scaling * np.dot(params, X.T)
        return pred_sig

    @multi_voxel_fit
    def fit(self, data, verbose=True):
        """Fits the model to diffusion data and returns the model fit.

        Sometimes the solving process of some voxels can end in a SolverError
        from cvxpy. This might be attributed to the response functions not
        being tuned properly, as the solving process is very sensitive to it.
        The method will fill the problematic voxels with a NaN value, so that
        it is traceable. The user should check for the number of NaN values and
        could then fill the problematic voxels with zeros, for example.
        Running a fit again only on those problematic voxels can also work.

        Parameters
        ----------
        data : ndarray
            The diffusion data to fit the model on.
        verbose : bool (optional)
            Whether to show warnings when a SolverError appears or not.
            Default: True

        Returns
        -------
        MSDeconvFit
            The fit holding the estimated coefficients.
        """
        coeff = self.fitter(data)
        if verbose:
            # A NaN first coefficient marks a voxel the solver gave up on.
            if np.isnan(coeff[..., 0]):
                msg = """Voxel could not be solved properly and ended up with a
                SolverError. Proceeding to fill it with NaN values.
                """
                warnings.warn(msg, UserWarning)

        return MSDeconvFit(self, coeff, None)
319
320
class MSDeconvFit(shm.SphHarmFit):

    def __init__(self, model, coeff, mask):
        """
        Container for the result of a MultiShellDeconvModel fit.

        Derives from SphHarmFit. The stored coefficients hold the isotropic
        compartments first, followed by the spherical harmonic coefficients.

        Parameters
        ----------
        model: object
            MultiShellDeconvModel
        coeff : array
            Fitted coefficients (isotropic compartments followed by the
            spherical harmonic coefficients for the ODF).
        mask: ndarray
            Mask for fitting
        """
        self._shm_coef = coeff
        self.mask = mask
        self.model = model

    @property
    def shm_coeff(self):
        """Coefficients located after the isotropic compartments."""
        first_sh = self.model.response.iso
        return self._shm_coef[..., first_sh:]

    @property
    def all_shm_coeff(self):
        """All stored coefficients, isotropic compartments included."""
        return self._shm_coef

    @property
    def volume_fractions(self):
        """The first ``iso + 1`` coefficients, divided by ``SH_CONST``."""
        n_leading = 1 + self.model.response.iso
        leading = self._shm_coef[..., :n_leading]
        return leading / SH_CONST
354
355
def solve_qp(P, Q, G, H):
    r"""
    Helper function to set up and solve the Quadratic Program (QP) in CVXPY.
    A QP problem has the following form:
    minimize      1/2 x' P x + Q' x
    subject to    G x <= H

    The problem is handed to ``cvxpy.Problem.solve`` without selecting a
    solver explicitly, so CVXPY picks its default solver for the problem
    (presumably OSQP for a QP -- confirm against the installed CVXPY).

    Parameters
    ----------
    P : ndarray
        n x n matrix for the primal QP objective function.
    Q : ndarray
        n x 1 matrix for the primal QP objective function.
    G : ndarray
        m x n matrix for the inequality constraint.
    H : ndarray
        m x 1 matrix for the inequality constraint.

    Returns
    -------
    x : array
        Optimal solution to the QP problem. It is filled with NaN when the
        solver raises a SolverError or ends without providing a solution.
    """
    n_coef = Q.shape[0]
    x = cvxpy.Variable(n_coef)
    P = cvxpy.Constant(P)
    # Older cvxpy releases spell the matrix product `*`; from 1.1 on `@` is
    # used instead.
    if LooseVersion(cvxpy.__version__) < LooseVersion('1.1'):
        objective = cvxpy.Minimize(0.5 * cvxpy.quad_form(x, P) + Q * x)
        constraints = [G * x <= H]
    else:
        objective = cvxpy.Minimize(0.5 * cvxpy.quad_form(x, P) + Q @ x)
        constraints = [G @ x <= H]

    # setting up the problem
    prob = cvxpy.Problem(objective, constraints)
    try:
        prob.solve()
    except cvxpy.error.SolverError:
        # `np.nan` rather than the `np.NaN` alias, which NumPy 2.0 removed.
        return np.full((n_coef,), np.nan)

    if x.value is None:
        # The solver stopped without a usable solution (no exception is
        # raised in that case); report it the same way as a SolverError.
        return np.full((n_coef,), np.nan)

    return np.array(x.value).reshape((n_coef,))
400
401
class QpFitter(object):

    def __init__(self, X, reg):
        r"""
        Callable that fits a signal through the quadratic programming solver
        `solve_qp`. The matrices that do not depend on the signal are
        prepared once here and reused for every call.

        Parameters
        ----------
        X : ndarray
            Matrix to be fit by the QP solver calculated in
            `MultiShellDeconvModel`
        reg : ndarray
            the regularization B matrix calculated in `MultiShellDeconvModel`
        """
        gram = np.dot(X.T, X)

        self._X = X
        self._reg = reg
        self._P = gram

        # Arrays handed to `solve_qp`: quadratic term, negated constraint
        # matrix and the constraint bound.
        self._P_mat = np.array(gram)
        self._reg_mat = np.array(-reg)
        self._h_mat = np.array([0])

    def __call__(self, signal):
        """Solve the QP for `signal` and return the estimated
        coefficients."""
        projected = np.dot(self._X.T, signal)
        linear_term = np.array(-projected)
        return solve_qp(self._P_mat, linear_term, self._reg_mat, self._h_mat)
431
432
def multi_shell_fiber_response(sh_order, bvals, wm_rf, gm_rf, csf_rf,
                               sphere=None, tol=20):
    """Fiber response function estimation for multi-shell data.

    Parameters
    ----------
    sh_order : int
         Maximum spherical harmonics order.
    bvals : ndarray
        Array containing the b-values. Must be unique b-values, like outputed
        by `dipy.core.gradients.unique_bvals_tolerance`.
    wm_rf : ndarray
        Response function of the WM tissue. It is read as ``wm_rf[i, :3]``
        (eigenvalues) and ``wm_rf[i, 3]`` (S0), with one row per b-value
        other than the b0.
    gm_rf : ndarray
        Response function of the GM tissue, laid out like `wm_rf`. Only the
        first eigenvalue and the S0 of each row are used.
    csf_rf : ndarray
        Response function of the CSF tissue, laid out like `gm_rf`.
    sphere : `dipy.core.Sphere` instance, optional
        Sphere where the signal will be evaluated (after one subdivision).
        The default sphere of `dipy.data` is used when None.
    tol : int, optional
        The first b-value is considered a b0 when it is below `tol`.

    Returns
    -------
    MultiShellResponse
        MultiShellResponse object.
    """
    numpy_1_14_plus = LooseVersion(np.__version__) >= LooseVersion('1.14.0')
    rcond_value = None if numpy_1_14_plus else -1

    bvals = np.array(bvals, copy=True)
    # Principal axis along z; the two remaining axes are x and y.
    evecs = np.array([[0., 1., 0.],
                      [0., 0., 1.],
                      [1., 0., 0.]])

    n = np.arange(0, sh_order + 1, 2)
    m = np.zeros_like(n)

    if sphere is None:
        sphere = default_sphere

    big_sphere = sphere.subdivide()
    theta, phi = big_sphere.theta, big_sphere.phi

    B = shm.real_sh_descoteaux_from_index(m, n, theta[:, None], phi[:, None])
    A = shm.real_sh_descoteaux_from_index(0, 0, 0, 0)

    response = np.empty([len(bvals), len(n) + 2])

    def _wm_coefficients(bvalue, rf_row):
        # Simulate the WM signal of row `rf_row` at `bvalue` on the sphere
        # and project it on the m = 0 harmonics.
        shell_gtab = GradientTable(big_sphere.vertices * bvalue)
        signal = single_tensor(shell_gtab, wm_rf[rf_row, 3],
                               wm_rf[rf_row, :3], evecs, snr=None)
        return np.linalg.lstsq(B, signal, rcond=rcond_value)[0]

    if bvals[0] < tol:
        # The b0 row reuses the S0 of the first row of each response.
        response[0, 2:] = _wm_coefficients(0, 0)
        response[0, 1] = gm_rf[0, 3] / A
        response[0, 0] = csf_rf[0, 3] / A
        first_row, shells = 1, bvals[1:]
    else:
        warnings.warn("""No b0 given. Proceeding either way.""", UserWarning)
        first_row, shells = 0, bvals

    for i, bvalue in enumerate(shells):
        row = first_row + i
        response[row, 2:] = _wm_coefficients(bvalue, i)
        response[row, 1] = gm_rf[i, 3] * np.exp(-bvalue * gm_rf[i, 0]) / A
        response[row, 0] = csf_rf[i, 3] * np.exp(-bvalue * csf_rf[i, 0]) / A

    S0 = [csf_rf[0, 3], gm_rf[0, 3], wm_rf[0, 3]]

    return MultiShellResponse(response, sh_order, bvals, S0=S0)
517
518
def mask_for_response_msmt(gtab, data, roi_center=None, roi_radii=10,
                           wm_fa_thr=0.7, gm_fa_thr=0.2, csf_fa_thr=0.1,
                           gm_md_thr=0.0007, csf_md_thr=0.002):
    """ Computation of masks for multi-shell multi-tissue (msmt) response
        function using FA and MD.

    Parameters
    ----------
    gtab : GradientTable
    data : ndarray
        diffusion data (4D)
    roi_center : array-like, (3,)
        Center of ROI in data. If center is None, it is assumed that it is
        the center of the volume with shape `data.shape[:3]`.
    roi_radii : int or array-like, (3,)
        radii of cuboid ROI
    wm_fa_thr : float
        FA threshold for WM.
    gm_fa_thr : float
        FA threshold for GM.
    csf_fa_thr : float
        FA threshold for CSF.
    gm_md_thr : float
        MD threshold for GM.
    csf_md_thr : float
        MD threshold for CSF.

    Returns
    -------
    mask_wm : ndarray
        Mask of voxels within the ROI and with FA above the FA threshold
        for WM.
    mask_gm : ndarray
        Mask of voxels within the ROI and with (non-zero) FA below the FA
        threshold for GM and with MD below the MD threshold for GM.
    mask_csf : ndarray
        Mask of voxels within the ROI and with (non-zero) FA below the FA
        threshold for CSF and with (non-zero) MD below the MD threshold for
        CSF.

    Raises
    ------
    ValueError
        If `data` has fewer than 4 dimensions.

    Notes
    -----
    In msmt-CSD there is an important pre-processing step: the estimation of
    every tissue's response function. In order to do this, we look for voxels
    corresponding to WM, GM and CSF. This function aims to accomplish that by
    returning, for each tissue, a mask of the voxels of a ROI that respect
    some threshold constraints. More precisely, the WM mask must have a FA
    value above a given threshold. The GM mask and CSF mask must have a FA
    below given thresholds and a MD below other thresholds. To get the FA and
    MD, a Tensor model is fit to the data inside the ROI. A warning is issued
    for every tissue whose mask ends up empty.
    """

    if len(data.shape) < 4:
        msg = """Data must be 4D (3D image + directions). To use a 2D image,
        please reshape it into a (N, N, 1, ndirs) array."""
        raise ValueError(msg)

    if roi_center is None:
        roi_center = np.array(data.shape[:3]) // 2

    if isinstance(roi_radii, numbers.Number):
        roi_radii = (roi_radii,) * 3

    roi_radii = _roi_in_volume(data.shape, np.asarray(roi_center),
                               np.asarray(roi_radii))

    roi_mask = _mask_from_roi(data.shape[:3], roi_center, roi_radii)

    list_bvals = unique_bvals_tolerance(gtab.bvals)
    if not np.all(list_bvals <= 1200):
        msg_bvals = """Some b-values are higher than 1200.
        The DTI fit might be affected."""
        warnings.warn(msg_bvals, UserWarning)

    # FA and MD maps from a tensor fit restricted to the ROI; undefined
    # values are treated as 0.
    tenfit = TensorModel(gtab).fit(data, mask=roi_mask)
    fa = fractional_anisotropy(tenfit.evals)
    fa[np.isnan(fa)] = 0
    md = mean_diffusivity(tenfit.evals)
    md[np.isnan(md)] = 0

    mask_wm = (fa > wm_fa_thr).astype(np.int64)
    mask_wm *= roi_mask

    gm_by_md = (md < gm_md_thr).astype(np.int64)
    gm_by_fa = ((fa < gm_fa_thr) & (fa > 0)).astype(np.int64)
    mask_gm = gm_by_md * gm_by_fa
    mask_gm *= roi_mask

    csf_by_md = ((md < csf_md_thr) & (md > 0)).astype(np.int64)
    csf_by_fa = ((fa < csf_fa_thr) & (fa > 0)).astype(np.int64)
    mask_csf = csf_by_md * csf_by_fa
    mask_csf *= roi_mask

    msg = """No voxel with a {0} than {1} were found.
    Try a larger roi or a {2} threshold for {3}."""

    if np.sum(mask_wm) == 0:
        warnings.warn(
            msg.format('FA higher', str(wm_fa_thr), 'lower FA', 'WM'),
            UserWarning)

    # GM and CSF are both selected with an upper bound on FA and on MD, so
    # an empty mask is reported once for each criterion.
    low_fa_tissues = (('GM', mask_gm, gm_fa_thr, gm_md_thr),
                      ('CSF', mask_csf, csf_fa_thr, csf_md_thr))
    for tissue, tissue_mask, fa_thr, md_thr in low_fa_tissues:
        if np.sum(tissue_mask) == 0:
            msg_fa = msg.format('FA lower', str(fa_thr), 'higher FA', tissue)
            msg_md = msg.format('MD lower', str(md_thr), 'higher MD', tissue)
            warnings.warn(msg_fa, UserWarning)
            warnings.warn(msg_md, UserWarning)

    return mask_wm, mask_gm, mask_csf
641
642
def response_from_mask_msmt(gtab, data, mask_wm, mask_gm, mask_csf, tol=20):
    """ Computation of multi-shell multi-tissue (msmt) response
        functions from given tissues masks.

    Parameters
    ----------
    gtab : GradientTable
    data : ndarray
        diffusion data
    mask_wm : ndarray
        mask from where to compute the WM response function.
    mask_gm : ndarray
        mask from where to compute the GM response function.
    mask_csf : ndarray
        mask from where to compute the CSF response function.
    tol : int
        tolerance gap for b-values clustering. (Default = 20)

    Returns
    -------
    response_wm : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for WM for each unique bvalues (except b0).
    response_gm : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for GM for each unique bvalues (except b0).
    response_csf : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for CSF for each unique bvalues (except b0).

    Notes
    -----
    In msmt-CSD there is an important pre-processing step: the estimation of
    every tissue's response function. In order to do this, we look for voxels
    corresponding to WM, GM and CSF. This information can be obtained by using
    mcsd.mask_for_response_msmt() through masks of selected voxels. The present
    function uses such masks to compute the msmt response functions.

    The computation relies on csdeconv.response_from_mask_ssst() (see that
    function for the details of the single-shell response). For each tissue
    mask, and for each b-value cluster other than the first one, a
    single-shell dataset is assembled from the averaged volumes of the first
    cluster (used as b0) and the volumes of that shell, and its response is
    computed within the mask. The per-shell responses of a tissue are stacked
    row by row.
    """
    bvals, bvecs, btens = gtab.bvals, gtab.bvecs, gtab.btens

    list_bvals = unique_bvals_tolerance(bvals, tol)

    # Volumes of the first b-value cluster are averaged into a single b0.
    b0_indices = get_bval_indices(bvals, list_bvals[0], tol)
    b0_map = np.mean(data[..., b0_indices], axis=-1)[..., np.newaxis]

    def _single_shell_response(mask, bval):
        # Response (`evals` followed by `S0`) of one shell within one mask.
        indices = get_bval_indices(bvals, bval, tol)

        bvecs_sub = np.concatenate([[bvecs[b0_indices[0]]],
                                   bvecs[indices]])
        bvals_sub = np.concatenate([[0], bvals[indices]])
        btens_sub = None
        if btens is not None:
            btens_b0 = btens[b0_indices[0]].reshape((1, 3, 3))
            btens_sub = np.concatenate([btens_b0, btens[indices]])

        data_conc = np.concatenate([b0_map, data[..., indices]], axis=3)

        shell_gtab = gradient_table(bvals_sub, bvecs_sub, btens=btens_sub)
        response, _ = response_from_mask_ssst(shell_gtab, data_conc, mask)

        return list(np.concatenate([response[0], [response[1]]]))

    shells = list_bvals[1:]
    tissue_responses = [
        [_single_shell_response(mask, bval) for bval in shells]
        for mask in (mask_wm, mask_gm, mask_csf)
    ]

    wm_response, gm_response, csf_response = (
        np.asarray(tissue) for tissue in tissue_responses)
    return wm_response, gm_response, csf_response
725
726
def auto_response_msmt(gtab, data, tol=20, roi_center=None, roi_radii=10,
                       wm_fa_thr=0.7, gm_fa_thr=0.3, csf_fa_thr=0.15,
                       gm_md_thr=0.001, csf_md_thr=0.0032):
    """ Automatic estimation of multi-shell multi-tissue (msmt) response
        functions using FA and MD.

    Parameters
    ----------
    gtab : GradientTable
    data : ndarray
        diffusion data
    tol : int
        tolerance gap for b-values clustering, forwarded to
        `response_from_mask_msmt`. (Default = 20)
    roi_center : array-like, (3,)
        Center of ROI in data. If center is None, it is assumed that it is
        the center of the volume with shape `data.shape[:3]`.
    roi_radii : int or array-like, (3,)
        radii of cuboid ROI
    wm_fa_thr : float
        FA threshold for WM.
    gm_fa_thr : float
        FA threshold for GM.
    csf_fa_thr : float
        FA threshold for CSF.
    gm_md_thr : float
        MD threshold for GM.
    csf_md_thr : float
        MD threshold for CSF.

    Returns
    -------
    response_wm : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for WM for each unique bvalues (except b0).
    response_gm : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for GM for each unique bvalues (except b0).
    response_csf : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for CSF for each unique bvalues (except b0).

    Notes
    -----
    In msmt-CSD there is an important pre-processing step: the estimation of
    every tissue's response function. This function chains the two steps
    needed for it: mcsd.mask_for_response_msmt() selects the voxels of each
    tissue (WM, GM and CSF) and mcsd.response_from_mask_msmt() computes the
    `response` of each tissue from those masks. More details are available
    in the description of these two functions.

    A warning is issued when some of the unique b-values are above 1200.
    """

    list_bvals = unique_bvals_tolerance(gtab.bvals)
    if not np.all(list_bvals <= 1200):
        msg_bvals = """Some b-values are higher than 1200.
        The DTI fit might be affected. It is advised to use
        mask_for_response_msmt with bvalues lower than 1200, followed by
        response_from_mask_msmt with all bvalues to overcome this."""
        warnings.warn(msg_bvals, UserWarning)

    # Step 1: tissue masks from the FA/MD thresholds.
    mask_wm, mask_gm, mask_csf = mask_for_response_msmt(
        gtab, data,
        roi_center=roi_center, roi_radii=roi_radii,
        wm_fa_thr=wm_fa_thr, gm_fa_thr=gm_fa_thr, csf_fa_thr=csf_fa_thr,
        gm_md_thr=gm_md_thr, csf_md_thr=csf_md_thr)

    # Step 2: one response per tissue, computed within its mask.
    responses = response_from_mask_msmt(gtab, data, mask_wm, mask_gm,
                                        mask_csf, tol)
    response_wm, response_gm, response_csf = responses

    return response_wm, response_gm, response_csf
797