import numbers
import re
import warnings

import numpy as np

from dipy.core import geometry as geo
from dipy.core.gradients import (GradientTable, gradient_table,
                                 unique_bvals_tolerance, get_bval_indices)
from dipy.data import default_sphere
from dipy.reconst import shm
from dipy.reconst.csdeconv import response_from_mask_ssst
from dipy.reconst.dti import (TensorModel, fractional_anisotropy,
                              mean_diffusivity)
from dipy.reconst.multi_voxel import multi_voxel_fit
from dipy.reconst.utils import _roi_in_volume, _mask_from_roi
from dipy.sims.voxel import single_tensor

from dipy.utils.optpkg import optional_package
cvxpy, have_cvxpy, _ = optional_package("cvxpy")

# Constant value of the degree-0 real spherical harmonic, 1 / (2 sqrt(pi)).
SH_CONST = .5 / np.sqrt(np.pi)


def _version_at_least(version, minimum):
    """Return whether a dotted version string is at least ``minimum``.

    Only the leading numeric components are compared: parsing stops at the
    first component that is not purely numeric (its numeric prefix is kept),
    so suffixes such as ``rc1`` or ``+git`` are ignored. This replaces
    ``distutils.version.LooseVersion``; ``distutils`` was removed from the
    standard library in Python 3.12.

    Parameters
    ----------
    version : str
        Version string, e.g. ``np.__version__``.
    minimum : tuple of int
        Minimal required version, e.g. ``(1, 14)``.

    Returns
    -------
    bool
        True if ``version`` >= ``minimum`` (tuple comparison).
    """
    parts = []
    for token in str(version).split('.'):
        match = re.match(r'\d+', token)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.end() != len(token):
            # Component with a non-numeric suffix: keep its number, stop.
            break
    return tuple(parts) >= tuple(minimum)


def multi_tissue_basis(gtab, sh_order, iso_comp):
    """
    Builds a basis for multi-shell multi-tissue CSD model.

    The returned matrix has ``iso_comp`` constant columns (one per
    isotropic compartment, all equal to ``SH_CONST``) followed by the real
    SH basis evaluated at the gradient directions. For b0 measurements the
    SH columns of degree ``n > 0`` are set to zero.

    Parameters
    ----------
    gtab : GradientTable
    sh_order : int
        Maximal spherical harmonics order.
    iso_comp: int
        Number of tissue compartments for running the MSMT-CSD. Minimum
        number of compartments required is 2.

    Returns
    -------
    B : ndarray
        Matrix of the spherical harmonics model used to fit the data
    m : int ``|m| <= n``
        The order of the harmonic.
    n : int ``>= 0``
        The degree of the harmonic.

    Raises
    ------
    ValueError
        If ``iso_comp`` is smaller than 2.
    """
    if iso_comp < 2:
        msg = ("Multi-tissue CSD requires at least 2 tissue compartments")
        raise ValueError(msg)
    r, theta, phi = geo.cart2sphere(*gtab.gradients.T)
    m, n = shm.sph_harm_ind_list(sh_order)
    B = shm.real_sh_descoteaux_from_index(m, n, theta[:, None], phi[:, None])
    # b0 volumes carry no angular information.
    B[np.ix_(gtab.b0s_mask, n > 0)] = 0.

    iso = np.empty([B.shape[0], iso_comp])
    iso[:] = SH_CONST

    B = np.concatenate([iso, B], axis=1)
    return B, m, n


class MultiShellResponse(object):

    def __init__(self, response, sh_order, shells, S0=None):
        """ Estimate Multi Shell response function for multiple tissues and
        multiple shells.

        The method `multi_shell_fiber_response` allows to create a multi-shell
        fiber response with the right format, for a three compartments model.
        It can be referred to in order to understand the inputs of this class.

        Parameters
        ----------
        response : ndarray
            Multi-shell fiber response, with one row per shell and
            ``iso + sh_order // 2 + 1`` columns: the ``iso`` isotropic
            compartments first, then one coefficient per even SH degree of
            the fiber response. The ordering of the responses should
            follow the same logic as S0.
        sh_order : int
            Maximal spherical harmonics order.
        shells : array-like
            b-value of each shell, one per row of `response`.
        S0 : array (3,), optional
            Signal with no diffusion weighting for each tissue compartments, in
            the same tissue order as `response`. This S0 can be used for
            predicting from a fit model later on.

        Raises
        ------
        ValueError
            If `sh_order` and the shape of `response` leave no room for at
            least one isotropic compartment.
        """
        self.S0 = S0
        self.response = np.asarray(response)
        self.sh_order = sh_order
        self.n = np.arange(0, sh_order + 1, 2)
        self.m = np.zeros_like(self.n)
        # Array form is required by `_inflate_response` (shells[:, None]).
        self.shells = np.asarray(shells)
        if self.iso < 1:
            raise ValueError("sh_order and shape of response do not agree")

    @property
    def iso(self):
        """Number of isotropic compartments implied by the response shape."""
        return self.response.shape[1] - (self.sh_order // 2) - 1


def _inflate_response(response, gtab, n, delta):
    """Expand a response into the `multiplier_matrix` of
    `MultiShellDeconvModel`.

    Each measurement of `gtab` is matched to the response shell with the
    closest b-value, and each basis column (isotropic compartments first,
    then SH degrees `n`) to the corresponding response column.

    Parameters
    ----------
    response : MultiShellResponse object
    gtab : GradientTable
    n : int ``>= 0``
        The degree of the harmonic.
    delta : Delta generated from `_basic_delta`

    Returns
    -------
    ndarray, shape (len(gtab.bvals), response.iso + len(n))

    Raises
    ------
    ValueError
        If `n` contains odd degrees or degrees above ``response.sh_order``.
    """
    # The response holds columns for degrees 0, 2, ..., sh_order only.
    if np.any((n % 2) != 0) or n.max() > response.sh_order:
        raise ValueError("Response and n do not match")

    iso = response.iso
    n_idx = np.empty(len(n) + iso, dtype=int)
    n_idx[:iso] = np.arange(0, iso)
    n_idx[iso:] = n // 2 + iso
    diff = abs(response.shells[:, None] - gtab.bvals)
    b_idx = np.argmin(diff, axis=0)
    kernel = response.response / delta

    return kernel[np.ix_(b_idx, n_idx)]


def _basic_delta(iso, m, n, theta, phi):
    """Simple delta function: isotropic constants followed by a SH dirac.

    Parameters
    ----------
    iso: int
        Number of tissue compartments for running the MSMT-CSD. Minimum
        number of compartments required is 2.
    m : int ``|m| <= n``
        The order of the harmonic.
    n : int ``>= 0``
        The degree of the harmonic.
    theta : array_like
        inclination or polar angle
    phi : array_like
        azimuth angle

    Returns
    -------
    ndarray
        ``iso`` copies of ``SH_CONST`` followed by ``shm.gen_dirac``
        evaluated for (m, n, theta, phi).
    """
    wm_d = shm.gen_dirac(m, n, theta, phi)
    iso_d = [SH_CONST] * iso
    return np.concatenate([iso_d, wm_d])


class MultiShellDeconvModel(shm.SphHarmModel):

    def __init__(self, gtab, response, reg_sphere=default_sphere,
                 sh_order=8, iso=2):
        r"""
        Multi-Shell Multi-Tissue Constrained Spherical Deconvolution
        (MSMT-CSD) [1]_. This method extends the CSD model proposed in [2]_ by
        the estimation of multiple response functions as a function of multiple
        b-values and multiple tissue types.

        Spherical deconvolution computes a fiber orientation distribution
        (FOD), also called fiber ODF (fODF) [2]_. The fODF is derived from
        different tissue types and thus overcomes the overestimation of WM in
        GM and CSF areas.

        The response function is based on the different tissue types
        and is provided as input to the MultiShellDeconvModel.
        It will be used as deconvolution kernel, as described in [2]_.

        Parameters
        ----------
        gtab : GradientTable
        response : array-like or MultiShellResponse object
            Pre-computed multi-shell fiber response function in the form of a
            MultiShellResponse object, or simple response function as an
            array-like. The latter must be of shape (3, len(bvals)-1, 4)
            (e.g. the stacked WM, GM and CSF outputs of
            `auto_response_msmt`), because it will be converted into a
            MultiShellResponse object via the `multi_shell_fiber_response`
            method (important note: the function `unique_bvals_tolerance` is
            used here to select unique bvalues from gtab as input). Each (4,)
            entry holds the three eigenvalues followed by the signal value
            without diffusion weighting (S0) for one tissue and one shell.
            Note that in order to use more than three compartments, one must
            create a MultiShellResponse object on the side.
        reg_sphere : Sphere (optional)
            sphere used to build the regularization B matrix.
            Default: ``dipy.data.default_sphere``.
        sh_order : int (optional)
            maximal spherical harmonics order. Default: 8
        iso: int (optional)
            Number of tissue compartments for running the MSMT-CSD. Minimum
            number of compartments required is 2.
            Default: 2

        Raises
        ------
        ValueError
            If ``iso < 2``, or if a raw response is given with ``iso > 2`` or
            with the wrong shape.

        References
        ----------
        .. [1] Jeurissen, B., et al. NeuroImage 2014. Multi-tissue constrained
               spherical deconvolution for improved analysis of multi-shell
               diffusion MRI data
        .. [2] Tournier, J.D., et al. NeuroImage 2007. Robust determination of
               the fibre orientation distribution in diffusion MRI:
               Non-negativity constrained super-resolved spherical
               deconvolution
        .. [3] Tournier, J.D, et al. Imaging Systems and Technology
               2012. MRtrix: Diffusion Tractography in Crossing Fiber Regions
        """
        if iso < 2:
            msg = ("Multi-tissue CSD requires at least 2 tissue compartments")
            raise ValueError(msg)

        super(MultiShellDeconvModel, self).__init__(gtab)

        if not isinstance(response, MultiShellResponse):
            bvals = unique_bvals_tolerance(gtab.bvals, tol=20)
            if iso > 2:
                msg = """Too many compartments for this kind of response
                input. It must be two tissue compartments."""
                raise ValueError(msg)
            # Accept any array-like, e.g. the (wm, gm, csf) tuple returned
            # by `auto_response_msmt`, not only an ndarray.
            response = np.asarray(response)
            if response.shape != (3, len(bvals)-1, 4):
                msg = """Response must be of shape (3, len(bvals)-1, 4) or be a
                MultiShellResponse object."""
                raise ValueError(msg)
            response = multi_shell_fiber_response(sh_order,
                                                  bvals=bvals,
                                                  wm_rf=response[0],
                                                  gm_rf=response[1],
                                                  csf_rf=response[2])

        B, m, n = multi_tissue_basis(gtab, sh_order, iso)

        delta = _basic_delta(response.iso, response.m, response.n, 0., 0.)
        self.delta = delta
        multiplier_matrix = _inflate_response(response, gtab, n, delta)

        # Regularization matrix: identity on the isotropic compartments and
        # the SH basis evaluated on the sphere for the fODF part.
        r, theta, phi = geo.cart2sphere(*reg_sphere.vertices.T)
        odf_reg, _, _ = shm.real_sh_descoteaux(sh_order, theta, phi)
        reg = np.zeros([i + iso for i in odf_reg.shape])
        reg[:iso, :iso] = np.eye(iso)
        reg[iso:, iso:] = odf_reg

        X = B * multiplier_matrix

        self.fitter = QpFitter(X, reg)
        self.sh_order = sh_order
        self._X = X
        self.sphere = reg_sphere
        self.gtab = gtab
        self.B_dwi = B
        self.m = m
        self.n = n
        self.response = response

    def predict(self, params, gtab=None, S0=None):
        """Compute a signal prediction given spherical harmonic coefficients
        for the provided GradientTable class instance.

        Parameters
        ----------
        params : ndarray
            The spherical harmonic representation of the FOD from which to make
            the signal prediction.
        gtab : GradientTable
            The gradients for which the signal will be predicted. Use the
            model's gradient table by default.
        S0 : ndarray or float
            The non diffusion-weighted signal value. Only None, 0 (or
            all-zero) and 1 (or all-ones) are supported; they leave the
            prediction unscaled.

        Raises
        ------
        NotImplementedError
            For any other S0, since rescaling requires the volume fractions
            of the fit.
        """
        if gtab is None or gtab is self.gtab:
            gtab = self.gtab
            X = self._X
        else:
            iso = self.response.iso
            B, m, n = multi_tissue_basis(gtab, self.sh_order, iso)
            multiplier_matrix = _inflate_response(self.response, gtab, n,
                                                  self.delta)
            X = B * multiplier_matrix

        scaling = 1.
        if S0 is not None:
            S0 = np.asarray(S0)
            # The S0=1. case comes from fit.predict(). Using ``np.any`` keeps
            # the scalar semantics of the former ``if S0 and S0 != 1.`` test
            # while also accepting the documented ndarray input.
            if np.any(S0) and np.any(S0 != 1.):
                # Supporting this needs the fit's volume fractions (vf):
                # scaling = S0 / sum_k(vf[..., k] * self.response.S0[k])
                # where that denominator is > 1 (0 elsewhere), broadcast over
                # the gradient axis. It can also be applied externally.
                raise NotImplementedError

        pred_sig = scaling * np.dot(params, X.T)
        return pred_sig

    @multi_voxel_fit
    def fit(self, data, verbose=True):
        """Fits the model to diffusion data and returns the model fit.

        Sometimes the solving process of some voxels can end in a SolverError
        from cvxpy. This might be attributed to the response functions not
        being tuned properly, as the solving process is very sensitive to it.
        The method will fill the problematic voxels with a NaN value, so that
        it is traceable. The user should check for the number of NaN values and
        could then fill the problematic voxels with zeros, for example.
        Running a fit again only on those problematic voxels can also work.

        Parameters
        ----------
        data : ndarray
            The diffusion data to fit the model on.
        verbose : bool (optional)
            Whether to show warnings when a SolverError appears or not.
            Default: True

        Returns
        -------
        MSDeconvFit
        """
        coeff = self.fitter(data)
        if verbose and np.isnan(coeff[..., 0]):
            msg = """Voxel could not be solved properly and ended up with a
            SolverError. Proceeding to fill it with NaN values.
            """
            warnings.warn(msg, UserWarning)

        return MSDeconvFit(self, coeff, None)


class MSDeconvFit(shm.SphHarmFit):

    def __init__(self, model, coeff, mask):
        """
        Abstract class which holds the fit result of MultiShellDeconvModel.
        Inherits the SphHarmFit which fits the diffusion data to a spherical
        harmonic model.

        Parameters
        ----------
        model: object
            MultiShellDeconvModel
        coeff : array
            Fitted coefficients: the isotropic compartments first, then the
            spherical harmonic coefficients of the fODF.
        mask: ndarray
            Mask for fitting
        """
        self._shm_coef = coeff
        self.mask = mask
        self.model = model

    @property
    def shm_coeff(self):
        """SH coefficients of the fODF (isotropic compartments removed)."""
        return self._shm_coef[..., self.model.response.iso:]

    @property
    def all_shm_coeff(self):
        """All fitted coefficients, isotropic compartments included."""
        return self._shm_coef

    @property
    def volume_fractions(self):
        """Isotropic coefficients and the degree-0 fODF coefficient, each
        divided by ``SH_CONST``."""
        tissue_classes = self.model.response.iso + 1
        return self._shm_coef[..., :tissue_classes] / SH_CONST


def solve_qp(P, Q, G, H):
    r"""
    Helper function to set up and solve the Quadratic Program (QP) in CVXPY.
    A QP problem has the following form:
    minimize      1/2 x' P x + Q' x
    subject to    G x <= H

    Here the QP solver is based on CVXPY and uses OSQP.

    Parameters
    ----------
    P : ndarray
        n x n matrix for the primal QP objective function.
    Q : ndarray
        n x 1 matrix for the primal QP objective function.
    G : ndarray
        m x n matrix for the inequality constraint.
    H : ndarray
        m x 1 matrix for the inequality constraint.

    Returns
    -------
    x : array
        Optimal solution to the QP problem. Filled with NaN when the solver
        raises a SolverError or finishes without producing a solution.
    """
    n_params = Q.shape[0]
    x = cvxpy.Variable(n_params)
    P = cvxpy.Constant(P)
    if _version_at_least(cvxpy.__version__, (1, 1)):
        objective = cvxpy.Minimize(0.5 * cvxpy.quad_form(x, P) + Q @ x)
        constraints = [G @ x <= H]
    else:
        # cvxpy < 1.1 expresses matrix products with ``*``.
        objective = cvxpy.Minimize(0.5 * cvxpy.quad_form(x, P) + Q * x)
        constraints = [G * x <= H]

    # setting up the problem
    prob = cvxpy.Problem(objective, constraints)
    opt = np.full((n_params,), np.nan)
    try:
        prob.solve()
    except cvxpy.error.SolverError:
        return opt

    # ``x.value`` stays None when the solver returns without a solution
    # (e.g. infeasible status); keep the NaN marker in that case too.
    if x.value is not None:
        opt = np.array(x.value).reshape((n_params,))
    return opt


class QpFitter(object):

    def __init__(self, X, reg):
        r"""
        Makes use of the quadratic programming solver `solve_qp` to fit the
        model. The initialization for the model is done using the warm-start by
        default in `CVXPY`.

        Parameters
        ----------
        X : ndarray
            Matrix to be fit by the QP solver calculated in
            `MultiShellDeconvModel`
        reg : ndarray
            the regularization B matrix calculated in `MultiShellDeconvModel`
        """
        self._P = P = np.dot(X.T, X)
        self._X = X

        self._reg = reg
        self._P_mat = np.array(P)
        # Constraint G x <= H with G = -reg and H = 0, i.e. reg x >= 0.
        self._reg_mat = np.array(-reg)
        self._h_mat = np.array([0])

    def __call__(self, signal):
        """Solve the least-squares QP for one signal and return the
        coefficients (NaN-filled if the solver fails)."""
        Q = np.dot(self._X.T, signal)
        Q_mat = np.array(-Q)
        fodf_sh = solve_qp(self._P_mat, Q_mat, self._reg_mat, self._h_mat)
        return fodf_sh


def multi_shell_fiber_response(sh_order, bvals, wm_rf, gm_rf, csf_rf,
                               sphere=None, tol=20):
    """Fiber response function estimation for multi-shell data.

    Parameters
    ----------
    sh_order : int
         Maximum spherical harmonics order.
    bvals : ndarray
        Array containing the b-values. Must be unique b-values, like outputed
        by `dipy.core.gradients.unique_bvals_tolerance`.
    wm_rf : (len(bvals) - 1, 4) or (len(bvals), 4) ndarray
        Response function of the WM tissue: one row per non-b0 shell (or per
        shell when `bvals` has no b0), holding the three eigenvalues followed
        by S0.
    gm_rf : (len(bvals) - 1, 4) or (len(bvals), 4) ndarray
        Response function of the GM tissue, same layout as `wm_rf`. Only the
        first eigenvalue (used as diffusivity) and S0 are read.
    csf_rf : (len(bvals) - 1, 4) or (len(bvals), 4) ndarray
        Response function of the CSF tissue, same layout as `gm_rf`.
    sphere : `dipy.core.Sphere` instance, optional
        Sphere where the signal will be evaluated.
    tol : int, optional
        A first b-value below `tol` is treated as the b0 shell. Default: 20.

    Returns
    -------
    MultiShellResponse
        MultiShellResponse object.
    """
    rcond_value = None if _version_at_least(np.__version__, (1, 14)) else -1

    bvals = np.array(bvals, copy=True)
    evecs = np.zeros((3, 3))
    z = np.array([0, 0, 1.])
    evecs[:, 0] = z
    evecs[:2, 1:] = np.eye(2)

    n = np.arange(0, sh_order + 1, 2)
    m = np.zeros_like(n)

    if sphere is None:
        sphere = default_sphere

    big_sphere = sphere.subdivide()
    theta, phi = big_sphere.theta, big_sphere.phi

    B = shm.real_sh_descoteaux_from_index(m, n, theta[:, None], phi[:, None])
    A = shm.real_sh_descoteaux_from_index(0, 0, 0, 0)

    # Columns: CSF, GM, then the WM SH coefficients.
    response = np.empty([len(bvals), len(n) + 2])

    def _wm_coefficients(bvalue, i):
        # SH fit of a single-tensor WM signal (row i of wm_rf) at `bvalue`.
        gtab = GradientTable(big_sphere.vertices * bvalue)
        wm_response = single_tensor(gtab, wm_rf[i, 3], wm_rf[i, :3], evecs,
                                    snr=None)
        return np.linalg.lstsq(B, wm_response, rcond=rcond_value)[0]

    if bvals[0] < tol:
        # b0 shell: no attenuation, use the S0 of the first response rows.
        response[0, 2:] = _wm_coefficients(0, 0)
        response[0, 1] = gm_rf[0, 3] / A
        response[0, 0] = csf_rf[0, 3] / A
        first_row = 1
    else:
        warnings.warn("No b0 given. Proceeding either way.", UserWarning)
        first_row = 0

    # Row i of the *_rf arrays describes shell bvals[first_row + i].
    for i, bvalue in enumerate(bvals[first_row:]):
        row = first_row + i
        response[row, 2:] = _wm_coefficients(bvalue, i)
        response[row, 1] = gm_rf[i, 3] * np.exp(-bvalue * gm_rf[i, 0]) / A
        response[row, 0] = csf_rf[i, 3] * np.exp(-bvalue * csf_rf[i, 0]) / A

    S0 = [csf_rf[0, 3], gm_rf[0, 3], wm_rf[0, 3]]

    return MultiShellResponse(response, sh_order, bvals, S0=S0)


def mask_for_response_msmt(gtab, data, roi_center=None, roi_radii=10,
                           wm_fa_thr=0.7, gm_fa_thr=0.2, csf_fa_thr=0.1,
                           gm_md_thr=0.0007, csf_md_thr=0.002):
    """ Computation of masks for multi-shell multi-tissue (msmt) response
        function using FA and MD.

    Parameters
    ----------
    gtab : GradientTable
    data : ndarray
        diffusion data (4D)
    roi_center : array-like, (3,)
        Center of ROI in data. If center is None, it is assumed that it is
        the center of the volume with shape `data.shape[:3]`.
    roi_radii : int or array-like, (3,)
        radii of cuboid ROI
    wm_fa_thr : float
        FA threshold for WM.
    gm_fa_thr : float
        FA threshold for GM.
    csf_fa_thr : float
        FA threshold for CSF.
    gm_md_thr : float
        MD threshold for GM.
    csf_md_thr : float
        MD threshold for CSF.

    Returns
    -------
    mask_wm : ndarray
        Mask of voxels within the ROI and with FA above the FA threshold
        for WM.
    mask_gm : ndarray
        Mask of voxels within the ROI and with FA below the FA threshold
        for GM and with MD below the MD threshold for GM.
    mask_csf : ndarray
        Mask of voxels within the ROI and with FA below the FA threshold
        for CSF and with MD below the MD threshold for CSF.

    Raises
    ------
    ValueError
        If `data` has fewer than 4 dimensions.

    Notes
    -----
    In msmt-CSD there is an important pre-processing step: the estimation of
    every tissue's response function. In order to do this, we look for voxels
    corresponding to WM, GM and CSF. This function aims to accomplish that by
    returning a mask of voxels within a ROI and who respect some threshold
    constraints, for each tissue. More precisely, the WM mask must have a FA
    value above a given threshold. The GM mask and CSF mask must have a FA
    below given thresholds and a MD below other thresholds. To get the FA and
    MD, we need to fit a Tensor model to the datasets.
    """

    if len(data.shape) < 4:
        msg = """Data must be 4D (3D image + directions). To use a 2D image,
        please reshape it into a (N, N, 1, ndirs) array."""
        raise ValueError(msg)

    if isinstance(roi_radii, numbers.Number):
        roi_radii = (roi_radii, roi_radii, roi_radii)

    if roi_center is None:
        roi_center = np.array(data.shape[:3]) // 2

    roi_radii = _roi_in_volume(data.shape, np.asarray(roi_center),
                               np.asarray(roi_radii))

    roi_mask = _mask_from_roi(data.shape[:3], roi_center, roi_radii)

    list_bvals = unique_bvals_tolerance(gtab.bvals)
    if not np.all(list_bvals <= 1200):
        msg_bvals = """Some b-values are higher than 1200.
        The DTI fit might be affected."""
        warnings.warn(msg_bvals, UserWarning)

    ten = TensorModel(gtab)
    tenfit = ten.fit(data, mask=roi_mask)
    fa = fractional_anisotropy(tenfit.evals)
    fa[np.isnan(fa)] = 0
    md = mean_diffusivity(tenfit.evals)
    md[np.isnan(md)] = 0

    mask_wm = np.zeros(fa.shape, dtype=np.int64)
    mask_wm[fa > wm_fa_thr] = 1
    mask_wm *= roi_mask

    md_mask_gm = np.zeros(md.shape, dtype=np.int64)
    md_mask_gm[(md < gm_md_thr)] = 1

    # fa > 0 excludes voxels left at 0 (outside ROI or NaN fits).
    fa_mask_gm = np.zeros(fa.shape, dtype=np.int64)
    fa_mask_gm[(fa < gm_fa_thr) & (fa > 0)] = 1

    mask_gm = md_mask_gm * fa_mask_gm
    mask_gm *= roi_mask

    md_mask_csf = np.zeros(md.shape, dtype=np.int64)
    md_mask_csf[(md < csf_md_thr) & (md > 0)] = 1

    fa_mask_csf = np.zeros(fa.shape, dtype=np.int64)
    fa_mask_csf[(fa < csf_fa_thr) & (fa > 0)] = 1

    mask_csf = md_mask_csf * fa_mask_csf
    mask_csf *= roi_mask

    msg = """No voxel with a {0} than {1} were found.
    Try a larger roi or a {2} threshold for {3}."""

    if np.sum(mask_wm) == 0:
        msg_fa = msg.format('FA higher', str(wm_fa_thr), 'lower FA', 'WM')
        warnings.warn(msg_fa, UserWarning)

    if np.sum(mask_gm) == 0:
        msg_fa = msg.format('FA lower', str(gm_fa_thr), 'higher FA', 'GM')
        msg_md = msg.format('MD lower', str(gm_md_thr), 'higher MD', 'GM')
        warnings.warn(msg_fa, UserWarning)
        warnings.warn(msg_md, UserWarning)

    if np.sum(mask_csf) == 0:
        msg_fa = msg.format('FA lower', str(csf_fa_thr), 'higher FA', 'CSF')
        msg_md = msg.format('MD lower', str(csf_md_thr), 'higher MD', 'CSF')
        warnings.warn(msg_fa, UserWarning)
        warnings.warn(msg_md, UserWarning)

    return mask_wm, mask_gm, mask_csf


def response_from_mask_msmt(gtab, data, mask_wm, mask_gm, mask_csf, tol=20):
    """ Computation of multi-shell multi-tissue (msmt) response
        functions from given tissues masks.

    Parameters
    ----------
    gtab : GradientTable
    data : ndarray
        diffusion data
    mask_wm : ndarray
        mask from where to compute the WM response function.
    mask_gm : ndarray
        mask from where to compute the GM response function.
    mask_csf : ndarray
        mask from where to compute the CSF response function.
    tol : int
        tolerance gap for b-values clustering. (Default = 20)

    Returns
    -------
    response_wm : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for WM for each unique bvalues (except b0).
    response_gm : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for GM for each unique bvalues (except b0).
    response_csf : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for CSF for each unique bvalues (except b0).

    Notes
    -----
    In msmt-CSD there is an important pre-processing step: the estimation of
    every tissue's response function. In order to do this, we look for voxels
    corresponding to WM, GM and CSF. This information can be obtained by using
    mcsd.mask_for_response_msmt() through masks of selected voxels. The present
    function uses such masks to compute the msmt response functions.

    For the responses, we base our approach on the function
    csdeconv.response_from_mask_ssst(), with the added layers of multishell and
    multi-tissue (see the ssst function for more information about the
    computation of the ssst response function). This means that for each tissue
    we use the previously found masks and loop on them. For each mask, we loop
    on the b-values (clustered using the tolerance gap) to get many responses
    and then average them to get one response per tissue.
    """

    bvals = gtab.bvals
    bvecs = gtab.bvecs
    btens = gtab.btens

    list_bvals = unique_bvals_tolerance(bvals, tol)

    # The first unique b-value is treated as the b0 shell; its volumes are
    # averaged into a single reference volume.
    b0_indices = get_bval_indices(bvals, list_bvals[0], tol)
    b0_map = np.mean(data[..., b0_indices], axis=-1)[..., np.newaxis]

    masks = [mask_wm, mask_gm, mask_csf]
    tissue_responses = []
    for mask in masks:
        responses = []
        for bval in list_bvals[1:]:
            indices = get_bval_indices(bvals, bval, tol)

            # Single-shell sub-acquisition: one b0 plus this shell.
            bvecs_sub = np.concatenate([[bvecs[b0_indices[0]]],
                                        bvecs[indices]])
            bvals_sub = np.concatenate([[0], bvals[indices]])
            if btens is not None:
                btens_b0 = btens[b0_indices[0]].reshape((1, 3, 3))
                btens_sub = np.concatenate([btens_b0, btens[indices]])
            else:
                btens_sub = None

            data_conc = np.concatenate([b0_map, data[..., indices]], axis=3)

            gtab_sub = gradient_table(bvals_sub, bvecs_sub, btens=btens_sub)
            response, _ = response_from_mask_ssst(gtab_sub, data_conc, mask)

            # Flatten (evals, S0) into a single (4,) row.
            responses.append(list(np.concatenate([response[0],
                                                  [response[1]]])))

        tissue_responses.append(list(responses))

    wm_response = np.asarray(tissue_responses[0])
    gm_response = np.asarray(tissue_responses[1])
    csf_response = np.asarray(tissue_responses[2])
    return wm_response, gm_response, csf_response


def auto_response_msmt(gtab, data, tol=20, roi_center=None, roi_radii=10,
                       wm_fa_thr=0.7, gm_fa_thr=0.3, csf_fa_thr=0.15,
                       gm_md_thr=0.001, csf_md_thr=0.0032):
    """ Automatic estimation of multi-shell multi-tissue (msmt) response
        functions using FA and MD.

    Parameters
    ----------
    gtab : GradientTable
    data : ndarray
        diffusion data
    tol : int, optional
        tolerance gap for b-values clustering, passed to
        `response_from_mask_msmt`. Default: 20.
    roi_center : array-like, (3,)
        Center of ROI in data. If center is None, it is assumed that it is
        the center of the volume with shape `data.shape[:3]`.
    roi_radii : int or array-like, (3,)
        radii of cuboid ROI
    wm_fa_thr : float
        FA threshold for WM.
    gm_fa_thr : float
        FA threshold for GM.
    csf_fa_thr : float
        FA threshold for CSF.
    gm_md_thr : float
        MD threshold for GM.
    csf_md_thr : float
        MD threshold for CSF.

    Returns
    -------
    response_wm : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for WM for each unique bvalues (except b0).
    response_gm : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for GM for each unique bvalues (except b0).
    response_csf : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for CSF for each unique bvalues (except b0).

    Notes
    -----
    In msmt-CSD there is an important pre-processing step: the estimation of
    every tissue's response function. In order to do this, we look for voxels
    corresponding to WM, GM and CSF. We get this information from
    mcsd.mask_for_response_msmt(), which returns masks of selected voxels
    (more details are available in the description of the function).

    With the masks, we compute the response functions by using
    mcsd.response_from_mask_msmt(), which returns the `response` for each
    tissue (more details are available in the description of the function).
    """

    list_bvals = unique_bvals_tolerance(gtab.bvals)
    if not np.all(list_bvals <= 1200):
        msg_bvals = """Some b-values are higher than 1200.
        The DTI fit might be affected. It is advised to use
        mask_for_response_msmt with bvalues lower than 1200, followed by
        response_from_mask_msmt with all bvalues to overcome this."""
        warnings.warn(msg_bvals, UserWarning)
    mask_wm, mask_gm, mask_csf = mask_for_response_msmt(gtab, data,
                                                        roi_center,
                                                        roi_radii,
                                                        wm_fa_thr,
                                                        gm_fa_thr,
                                                        csf_fa_thr,
                                                        gm_md_thr,
                                                        csf_md_thr)
    response_wm, response_gm, response_csf = response_from_mask_msmt(
                                                        gtab, data,
                                                        mask_wm, mask_gm,
                                                        mask_csf, tol)

    return response_wm, response_gm, response_csf