1# cython: boundscheck=False
2# cython: initializedcheck=False
3# cython: wraparound=False
4
5import numpy as np
6cimport numpy as cnp
7
8from dipy.core.geometry import cart2sphere
9from dipy.reconst import shm
10
11from dipy.core.interpolation cimport trilinear_interpolate4d_c
12
13
cdef class PmfGen:
    """Base class for probability-mass-function (PMF) generators.

    Subclasses interpolate ``self.data`` at a point, fill ``self.pmf`` and
    return it from ``get_pmf_c``. The returned buffer is owned by the
    generator and reused between calls.
    """

    def __init__(self,
                 double[:, :, :, :] data):
        # np.asarray with a matching float64 dtype wraps the buffer rather
        # than copying it.
        self.data = np.asarray(data, dtype=float)

    cpdef double[:] get_pmf(self, double[::1] point):
        """Return the PMF at ``point``.

        Parameters
        ----------
        point : double[::1]
            Position at which ``self.data`` is interpolated. Only the first
            three entries are passed on through the raw pointer.

        Returns
        -------
        double[:]
            The generator's internal PMF buffer.

        Raises
        ------
        ValueError
            If ``point`` has fewer than three coordinates.
        """
        # Bounds checking is disabled for this module, so a short point
        # would make the C-level interpolation read past the buffer.
        if point.shape[0] < 3:
            raise ValueError("point must have at least 3 coordinates.")
        return self.get_pmf_c(&point[0])

    cdef double[:] get_pmf_c(self, double* point):
        """Fill and return ``self.pmf`` for ``point``; subclasses override."""
        raise NotImplementedError(
            "get_pmf_c must be implemented by PmfGen subclasses.")

    cdef void __clear_pmf(self):
        """Set every entry of ``self.pmf`` to zero."""
        cdef:
            cnp.npy_intp len_pmf = self.pmf.shape[0]
            cnp.npy_intp i

        for i in range(len_pmf):
            self.pmf[i] = 0.0
33
34
cdef class SimplePmfGen(PmfGen):
    """PMF generator that interpolates a precomputed PMF volume.

    The last axis of ``pmf_array`` holds the PMF entries, so the output
    buffer has ``pmf_array.shape[3]`` elements. A volume containing any
    negative value is rejected with ``ValueError``.
    """

    def __init__(self,
                 double[:, :, :, :] pmf_array):
        PmfGen.__init__(self, pmf_array)
        n_entries = pmf_array.shape[3]
        self.pmf = np.empty(n_entries)
        lowest_value = np.min(pmf_array)
        if lowest_value < 0:
            raise ValueError("pmf should not have negative values.")

    cdef double[:] get_pmf_c(self, double* point):
        # A zero status means the interpolated PMF was written into
        # self.pmf; any other status (presumably an out-of-volume point)
        # yields an all-zero PMF instead.
        if trilinear_interpolate4d_c(self.data, point, self.pmf) == 0:
            return self.pmf
        PmfGen.__clear_pmf(self)
        return self.pmf
48
49
cdef class SHCoeffPmfGen(PmfGen):
    """PMF generator that evaluates interpolated spherical-harmonic coefficients.

    Parameters
    ----------
    shcoeff_array : double[:, :, :, :]
        Volume whose last axis holds the SH coefficients.
    sphere : object
        Sphere whose ``theta``/``phi`` give the directions of the PMF.
    basis_type : object
        Key into ``shm.sph_harm_lookup`` selecting the SH basis.

    Raises
    ------
    ValueError
        If ``basis_type`` is unknown, or if the basis built for the order
        inferred from the coefficient count does not have exactly that many
        coefficients.
    """

    def __init__(self,
                 double[:, :, :, :] shcoeff_array,
                 object sphere,
                 object basis_type):
        cdef:
            int sh_order
            cnp.npy_intp n_coef = shcoeff_array.shape[3]

        PmfGen.__init__(self, shcoeff_array)

        self.sphere = sphere
        sh_order = shm.order_from_ncoef(n_coef)
        try:
            basis = shm.sph_harm_lookup[basis_type]
        except KeyError:
            raise ValueError("%s is not a known basis type." % (basis_type,))
        # B has one row per sphere direction and one column per coefficient.
        self.B, _, _ = basis(sh_order, sphere.theta, sphere.phi)
        # get_pmf_c reads coeff[j] for every column j of B with bounds
        # checking disabled, so the sizes must agree exactly.
        if self.B.shape[1] != n_coef:
            raise ValueError(
                "%d SH coefficients do not match the %s basis of order %d "
                "(%d coefficients)."
                % (n_coef, basis_type, sh_order, self.B.shape[1]))
        self.coeff = np.empty(n_coef)
        self.pmf = np.empty(self.B.shape[0])

    cdef double[:] get_pmf_c(self, double* point):
        """Interpolate SH coefficients at ``point`` and project them onto B.

        Returns the internal PMF buffer. It is all zeros when interpolation
        reports a non-zero status.
        """
        cdef:
            cnp.npy_intp i, j
            cnp.npy_intp len_pmf = self.pmf.shape[0]
            cnp.npy_intp len_B = self.B.shape[1]
            double _sum

        if trilinear_interpolate4d_c(self.data, point, self.coeff) != 0:
            PmfGen.__clear_pmf(self)
        else:
            # pmf = B @ coeff, written out so it stays a C loop.
            for i in range(len_pmf):
                _sum = 0
                for j in range(len_B):
                    _sum += self.B[i, j] * self.coeff[j]
                self.pmf[i] = _sum
        return self.pmf
87
88
cdef class BootPmfGen(PmfGen):
    """PMF generator that fits ``model`` to (optionally bootstrapped) signal.

    Parameters
    ----------
    dwi_array : double[:, :, :, :]
        Diffusion volume; its last axis is indexed with the non-b0 mask
        derived from ``model.gtab``.
    model : object
        Reconstruction model. ``model.gtab`` supplies the gradients and b0
        mask, and ``model.fit(signal).odf(sphere)`` produces the PMF.
    sphere : object
        Sphere on which the ODF is evaluated; ``len(sphere.theta)`` sets the
        PMF length.
    sh_order : int, optional
        SH order of the bootstrap basis. ``0`` selects ``model.sh_order``
        when the model has one, otherwise 4.
    tol : double, optional
        Largest accepted relative spread of the non-b0 gradient norms.

    Raises
    ------
    ValueError
        If the gradient norms spread by more than ``tol`` (multi-shell data).
    """

    def __init__(self,
                 double[:, :, :, :] dwi_array,
                 object model,
                 object sphere,
                 int sh_order=0,
                 double tol=1e-2):
        cdef:
            double b_range
            # cnp (the cimported numpy), not the Python-level np, provides
            # the C type.
            cnp.ndarray x, y, z, r
            double[:] theta, phi
            double[:, :] B

        PmfGen.__init__(self, dwi_array)
        self.sh_order = sh_order
        if self.sh_order == 0:
            # 0 means "not given": prefer the model's own order.
            if hasattr(model, "sh_order"):
                self.sh_order = model.sh_order
            else:
                self.sh_order = 4  # fallback default order

        # Only the non-b0 measurements are resampled.
        self.dwi_mask = model.gtab.b0s_mask == 0
        x, y, z = model.gtab.gradients[self.dwi_mask].T
        r, theta, phi = cart2sphere(x, y, z)
        # Relative spread of the gradient norms; a large spread indicates
        # more than one shell.
        b_range = (r.max() - r.min()) / r.min()
        if b_range > tol:
            raise ValueError("BootPmfGen only supports single shell data")
        B, _, _ = shm.real_sh_descoteaux(self.sh_order, theta, phi)
        self.H = shm.hat(B)
        self.R = shm.lcr_matrix(self.H)
        self.vox_data = np.empty(dwi_array.shape[3])

        self.model = model
        self.sphere = sphere
        self.pmf = np.empty(len(sphere.theta))

    cdef double[:] get_pmf_c(self, double* point):
        """Produces an ODF from a SH bootstrap sample"""
        cdef object signal
        if trilinear_interpolate4d_c(self.data, point, self.vox_data) != 0:
            PmfGen.__clear_pmf(self)
        else:
            # Boolean-mask indexing needs an ndarray. np.asarray wraps the
            # vox_data buffer without copying, so the write lands in it.
            signal = np.asarray(self.vox_data)
            signal[self.dwi_mask] = shm.bootstrap_data_voxel(
                signal[self.dwi_mask], self.H, self.R)
            self.pmf = self.model.fit(signal).odf(self.sphere)
        return self.pmf

    cpdef double[:] get_pmf_no_boot(self, double[::1] point):
        """Return the ODF of the interpolated signal at ``point``, unresampled.

        Raises
        ------
        ValueError
            If ``point`` has fewer than three coordinates.
        """
        # Bounds checking is disabled, so guard the raw-pointer access.
        if point.shape[0] < 3:
            raise ValueError("point must have at least 3 coordinates.")
        return self.get_pmf_no_boot_c(&point[0])

    cdef double[:] get_pmf_no_boot_c(self, double* point):
        """Fit the model to the interpolated signal without bootstrapping."""
        if trilinear_interpolate4d_c(self.data, point, self.vox_data) != 0:
            PmfGen.__clear_pmf(self)
        else:
            self.pmf = self.model.fit(
                np.asarray(self.vox_data)).odf(self.sphere)
        return self.pmf
148