# cython: boundscheck=False
# cython: initializedcheck=False
# cython: wraparound=False

import numpy as np
cimport numpy as cnp

from dipy.core.geometry import cart2sphere
from dipy.reconst import shm

from dipy.core.interpolation cimport trilinear_interpolate4d_c


cdef class PmfGen:
    """Base class for probability mass function (PMF) generators.

    Subclasses override ``get_pmf_c`` to fill ``self.pmf`` for a given point
    and return it. The attribute declarations (``data``, ``pmf``, ...) are
    presumably in the matching ``.pxd`` file, which is not shown here.
    """

    def __init__(self,
                 double[:, :, :, :] data):
        # Keep the 4D volume as float64 for the interpolation routines.
        self.data = np.asarray(data, dtype=float)

    cpdef double[:] get_pmf(self, double[::1] point):
        """Return the PMF at ``point``.

        Raises
        ------
        ValueError
            If ``point`` has fewer than 3 coordinates (or is None). Bounds
            checking is disabled module-wide and ``get_pmf_c`` only receives
            a raw pointer, so a short buffer would otherwise be read past its
            end.
        """
        if point.shape[0] < 3:
            raise ValueError("point must have at least 3 coordinates.")
        return self.get_pmf_c(&point[0])

    cdef double[:] get_pmf_c(self, double* point):
        """Abstract hook: subclasses fill and return ``self.pmf``."""
        pass

    cdef void __clear_pmf(self):
        """Set every entry of ``self.pmf`` to 0.0 in place."""
        cdef:
            cnp.npy_intp len_pmf = self.pmf.shape[0]
            cnp.npy_intp i

        for i in range(len_pmf):
            self.pmf[i] = 0.0


cdef class SimplePmfGen(PmfGen):
    """PMF generator that interpolates a precomputed 4D PMF volume."""

    def __init__(self,
                 double[:, :, :, :] pmf_array):
        """Store ``pmf_array``; raise ValueError if it has negative values."""
        # Validate before storing anything, so rejected input leaves no
        # partially initialised state behind.
        if np.min(pmf_array) < 0:
            raise ValueError("pmf should not have negative values.")
        PmfGen.__init__(self, pmf_array)
        # One output slot per entry along the last axis of the volume.
        self.pmf = np.empty(pmf_array.shape[3])

    cdef double[:] get_pmf_c(self, double* point):
        """Interpolate the PMF volume at ``point``.

        A non-zero status from the interpolation (presumably the point lies
        outside the volume) yields an all-zero PMF instead.
        """
        if trilinear_interpolate4d_c(self.data, point, self.pmf) != 0:
            PmfGen.__clear_pmf(self)
        return self.pmf


cdef class SHCoeffPmfGen(PmfGen):
    """PMF generator that evaluates interpolated SH coefficients on a sphere."""

    def __init__(self,
                 double[:, :, :, :] shcoeff_array,
                 object sphere,
                 object basis_type):
        """Build the SH-to-sphere matrix for ``basis_type``.

        Raises
        ------
        ValueError
            If ``basis_type`` is not a key of ``shm.sph_harm_lookup``.
        """
        cdef:
            int sh_order

        PmfGen.__init__(self, shcoeff_array)

        self.sphere = sphere
        # The SH order is inferred from the number of coefficients per voxel.
        sh_order = shm.order_from_ncoef(shcoeff_array.shape[3])
        try:
            basis = shm.sph_harm_lookup[basis_type]
        except KeyError:
            # The internal KeyError adds nothing beyond this message. Wrap the
            # argument in a tuple so a tuple-valued basis_type still formats.
            raise ValueError("%s is not a known basis type."
                             % (basis_type,)) from None
        # self.B maps coefficients to PMF values: pmf = B @ coeff.
        # Rows index PMF entries and columns index SH coefficients.
        self.B, _, _ = basis(sh_order, sphere.theta, sphere.phi)
        self.coeff = np.empty(shcoeff_array.shape[3])
        self.pmf = np.empty(self.B.shape[0])

    cdef double[:] get_pmf_c(self, double* point):
        """Interpolate SH coefficients at ``point`` and project them.

        Computes ``pmf[i] = sum_j B[i, j] * coeff[j]``. A non-zero
        interpolation status yields an all-zero PMF.
        """
        cdef:
            cnp.npy_intp i, j
            cnp.npy_intp len_pmf = self.pmf.shape[0]
            cnp.npy_intp len_B = self.B.shape[1]
            double _sum

        if trilinear_interpolate4d_c(self.data, point, self.coeff) != 0:
            PmfGen.__clear_pmf(self)
        else:
            for i in range(len_pmf):
                _sum = 0
                for j in range(len_B):
                    _sum += self.B[i, j] * self.coeff[j]
                self.pmf[i] = _sum
        return self.pmf


cdef class BootPmfGen(PmfGen):
    """PMF generator producing ODFs from SH bootstrap samples of DWI data.

    Only single-shell data is supported.
    """

    def __init__(self,
                 double[:, :, :, :] dwi_array,
                 object model,
                 object sphere,
                 int sh_order=0,
                 double tol=1e-2):
        """Prepare the bootstrap matrices for ``model``'s gradient table.

        Parameters
        ----------
        dwi_array : 4D diffusion data to interpolate.
        model : reconstruction model exposing ``gtab`` and ``fit``.
        sphere : sphere on which ODFs are evaluated.
        sh_order : SH order of the bootstrap basis. 0 means "use
            ``model.sh_order`` if present, else 4".
        tol : maximum relative spread of the non-b0 gradient norms
            (presumably the b-values) for the data to count as one shell.

        Raises
        ------
        ValueError
            If the gradient table has no non-b0 volume, or the data are not
            single shell.
        """
        cdef:
            double b_range
            np.ndarray x, y, z, r
            double[:] theta, phi
            double[:, :] B

        PmfGen.__init__(self, dwi_array)
        self.sh_order = sh_order
        if self.sh_order == 0:
            if hasattr(model, "sh_order"):
                self.sh_order = model.sh_order
            else:
                self.sh_order = 4  # DEFAULT Value

        self.dwi_mask = model.gtab.b0s_mask == 0
        # Without any non-b0 volume, r below is empty and r.max() fails with
        # an unrelated numpy message; report the real problem instead.
        if not np.any(self.dwi_mask):
            raise ValueError("BootPmfGen requires at least one "
                             "diffusion-weighted (non-b0) volume")
        x, y, z = model.gtab.gradients[self.dwi_mask].T
        r, theta, phi = shm.cart2sphere(x, y, z)
        # Relative spread of gradient norms; a large spread means more than
        # one shell.
        b_range = (r.max() - r.min()) / r.min()
        if b_range > tol:
            raise ValueError("BootPmfGen only supports single shell data")
        B, _, _ = shm.real_sh_descoteaux(self.sh_order, theta, phi)
        self.H = shm.hat(B)
        self.R = shm.lcr_matrix(self.H)
        self.vox_data = np.empty(dwi_array.shape[3])

        self.model = model
        self.sphere = sphere
        self.pmf = np.empty(len(sphere.theta))

    cdef double[:] get_pmf_c(self, double* point):
        """Produces an ODF from a SH bootstrap sample"""
        if trilinear_interpolate4d_c(self.data, point, self.vox_data) != 0:
            self.__clear_pmf()
        else:
            # Resample only the non-b0 entries of the interpolated signal.
            self.vox_data[self.dwi_mask] = shm.bootstrap_data_voxel(
                self.vox_data[self.dwi_mask], self.H, self.R)
            # Rebinds self.pmf to the freshly computed ODF array.
            self.pmf = self.model.fit(self.vox_data).odf(self.sphere)
        return self.pmf

    cpdef double[:] get_pmf_no_boot(self, double[::1] point):
        """Return the ODF at ``point`` fitted without bootstrapping.

        Raises
        ------
        ValueError
            If ``point`` has fewer than 3 coordinates (or is None). Bounds
            checking is disabled and only a raw pointer is forwarded.
        """
        if point.shape[0] < 3:
            raise ValueError("point must have at least 3 coordinates.")
        return self.get_pmf_no_boot_c(&point[0])

    cdef double[:] get_pmf_no_boot_c(self, double* point):
        """Fit the model to the raw interpolated signal at ``point``.

        A non-zero interpolation status yields an all-zero PMF.
        """
        if trilinear_interpolate4d_c(self.data, point, self.vox_data) != 0:
            self.__clear_pmf()
        else:
            self.pmf = self.model.fit(self.vox_data).odf(self.sphere)
        return self.pmf