1import numpy as np
2cimport numpy as cnp
3
4from warnings import warn
5
6from dipy.direction.peaks import peak_directions, default_sphere
7from dipy.direction.pmf cimport SimplePmfGen, SHCoeffPmfGen
8from dipy.reconst.shm import order_from_ncoef, sph_harm_lookup
9from dipy.tracking.direction_getter cimport DirectionGetter
10from dipy.utils.fast_numpy cimport copy_point, scalar_muliplication_point
11
12
cdef int closest_peak(np.ndarray[np.float_t, ndim=2] peak_dirs,
                      double* direction, double cos_similarity):
    """Update direction with the closest direction from peak_dirs.

    All directions should be unit vectors. Antipodal symmetry is assumed,
    i.e. direction x is the same as -x.

    Parameters
    ----------
    peak_dirs : array (N, 3)
        N unit vectors.
    direction : pointer to 3 doubles
        Previous direction. When a peak is accepted it is overwritten in
        place with that peak, negated if its dot product with the previous
        direction is at or below ``-cos_similarity``; otherwise it is left
        untouched.
    cos_similarity : float
        `cos(max_angle)` where `max_angle` is the maximum allowed angle between
        the previous direction and the returned direction.

    Returns
    -------
    0 : if ``direction`` is updated
    1 : if no new direction is found
    """
    cdef:
        cnp.npy_intp n_peaks = peak_dirs.shape[0]
        cnp.npy_intp i
        cnp.npy_intp closest_peak_i = -1
        double _dot
        double closest_peak_dot = 0
        double sign

    for i in range(n_peaks):
        _dot = (peak_dirs[i, 0] * direction[0]
                + peak_dirs[i, 1] * direction[1]
                + peak_dirs[i, 2] * direction[2])

        # The builtin ``abs`` on a C double is compiled to C ``fabs``; the
        # former ``np.abs`` made a Python call on every iteration.
        if abs(_dot) > abs(closest_peak_dot):
            closest_peak_dot = _dot
            closest_peak_i = i

    if closest_peak_i < 0:
        # Either there are no peaks or none has |dot| > 0.
        return 1

    # Same test order as before: the unflipped orientation is checked first.
    if closest_peak_dot >= cos_similarity:
        sign = 1.0
    elif closest_peak_dot <= -cos_similarity:
        # Antipodal symmetry: use -peak so the new direction stays within
        # the allowed cone around the previous one.
        sign = -1.0
    else:
        return 1

    # Element-wise copy through buffer indexing honours any row stride,
    # unlike reading 3 consecutive doubles from ``&peak_dirs[i, 0]``.
    direction[0] = sign * peak_dirs[closest_peak_i, 0]
    direction[1] = sign * peak_dirs[closest_peak_i, 1]
    direction[2] = sign * peak_dirs[closest_peak_i, 2]
    return 0
60
61
cdef class BasePmfDirectionGetter(DirectionGetter):
    """Shared machinery for direction getters driven by a pmf generator.

    Parameters
    ----------
    pmf_gen : pmf generator
        Object whose ``get_pmf_c(point)`` returns the pmf at a point
        (presumably sampled on the vertices of ``sphere``).
    max_angle : float
        Maximum allowed angle, in degrees, between successive directions.
        Only its cosine is kept, as ``cos_similarity``.
    sphere : Sphere
        Set of directions handed to ``peak_directions``.
    pmf_threshold : float, optional
        Fraction of the pmf maximum below which pmf values are zeroed.
        Must be >= 0.
    **kwargs
        Extra keyword arguments forwarded to ``peak_directions``.

    Raises
    ------
    ValueError
        If ``pmf_threshold`` is negative.
    """

    def __init__(self, pmf_gen, max_angle, sphere, pmf_threshold=.1, **kwargs):
        self.sphere = sphere
        self._pf_kwargs = kwargs
        self.pmf_gen = pmf_gen
        if pmf_threshold < 0:
            raise ValueError("pmf threshold must be >= 0.")
        self.pmf_threshold = pmf_threshold
        max_angle_rad = np.deg2rad(max_angle)
        self.cos_similarity = np.cos(max_angle_rad)

    def _get_peak_directions(self, blob):
        """Return the peak directions of ``blob`` on ``self.sphere``.

        ``blob`` may be any function sampled on ``self.sphere`` (an ODF, for
        instance). The keyword arguments given at construction are forwarded
        to ``peak_directions``; only the first element of its result is kept.
        """
        found = peak_directions(blob, self.sphere, **self._pf_kwargs)
        return found[0]

    cpdef np.ndarray[np.float_t, ndim=2] initial_direction(self,
                                                           double[::1] point):
        """Return the candidate starting directions at a seed location.

        Parameters
        ----------
        point : ndarray, shape (3,)
            Image coordinates at which the tracking directions are looked up.

        Returns
        -------
        directions : ndarray, shape (N, 3)
            Peak directions of the thresholded pmf at ``point``. ``N`` may be
            0; the directions are expected to be unique.

        """
        cdef double[:] seed_pmf = self._get_pmf(&point[0])
        seed_peaks = self._get_peak_directions(seed_pmf)
        return seed_peaks

    cdef _get_pmf(self, double* point):
        """Return the pmf at ``point`` with values below the relative
        threshold (``pmf_threshold`` times the pmf maximum) set to zero.

        The memoryview returned by ``pmf_gen.get_pmf_c`` is modified in place
        and returned.
        """
        cdef:
            cnp.npy_intp n_values, idx
            double[:] values
            double cutoff

        values = self.pmf_gen.get_pmf_c(point)
        n_values = values.shape[0]

        # The cutoff is relative to the largest value of this particular pmf.
        cutoff = self.pmf_threshold * np.max(values)
        for idx in range(n_values):
            if values[idx] < cutoff:
                values[idx] = 0.0
        return values
114
115
cdef class PmfGenDirectionGetter(BasePmfDirectionGetter):
    """Base class for direction getters that sample a pmf.

    Provides two alternate constructors. ``from_pmf`` wraps a 4d pmf array in
    a ``SimplePmfGen``, and ``from_shcoeff`` wraps spherical-harmonic
    coefficients in a ``SHCoeffPmfGen``. Both then call ``klass`` with the
    generator.
    """

    @classmethod
    def from_pmf(klass, pmf, max_angle, sphere=default_sphere,
                 pmf_threshold=0.1, **kwargs):
        """Constructor for making a DirectionGetter from an array of Pmfs

        Parameters
        ----------
        pmf : array_like, 4d
            The pmf to be used for tracking at each voxel. It is converted to
            a float array before validation, so array-likes (e.g. nested
            sequences) are accepted as well as ndarrays. The size of the last
            axis must equal the number of points in ``sphere``.
        max_angle : float, [0, 90]
            The maximum allowed angle between incoming direction and new
            direction.
        sphere : Sphere
            The set of directions to be used for tracking.
        pmf_threshold : float [0., 1.]
            Used to remove direction from the probability mass function for
            selecting the tracking direction.
        relative_peak_threshold : float in [0., 1.]
            Used for extracting initial tracking directions. Passed to
            peak_directions.
        min_separation_angle : float in [0, 90]
            Used for extracting initial tracking directions. Passed to
            peak_directions.

        Returns
        -------
        direction_getter : instance of ``klass``

        Raises
        ------
        ValueError
            If ``pmf`` is not 4d, or if its last dimension does not match the
            number of points in ``sphere``.

        See also
        --------
        dipy.direction.peaks.peak_directions

        """
        # Convert first so that inputs without ``ndim``/``shape`` attributes
        # are validated rather than failing with an AttributeError. This is a
        # no-op for float64 ndarrays.
        pmf = np.asarray(pmf, dtype=float)
        if pmf.ndim != 4:
            raise ValueError("pmf should be a 4d array.")
        if pmf.shape[3] != len(sphere.theta):
            msg = ("The last dimension of pmf should match the number of "
                   "points in sphere.")
            raise ValueError(msg)

        pmf_gen = SimplePmfGen(pmf)
        return klass(pmf_gen, max_angle, sphere, pmf_threshold, **kwargs)

    @classmethod
    def from_shcoeff(klass, shcoeff, max_angle, sphere=default_sphere,
                     pmf_threshold=0.1, basis_type=None, **kwargs):
        """Probabilistic direction getter from a distribution of directions
        on the sphere

        Parameters
        ----------
        shcoeff : array
            The distribution of tracking directions at each voxel represented
            as a function on the sphere using the real spherical harmonic
            basis. For example the FOD of the Constrained Spherical
            Deconvolution model can be used this way. This distribution will
            be discretized using ``sphere`` and tracking directions will be
            chosen from the vertices of ``sphere`` based on the distribution.
        max_angle : float, [0, 90]
            The maximum allowed angle between incoming direction and new
            direction.
        sphere : Sphere
            The set of directions to be used for tracking.
        pmf_threshold : float [0., 1.]
            Used to remove direction from the probability mass function for
            selecting the tracking direction.
        basis_type : name of basis
            The basis that ``shcoeff`` are associated with.
            ``dipy.reconst.shm.real_sh_descoteaux`` is used by default.
        relative_peak_threshold : float in [0., 1.]
            Used for extracting initial tracking directions. Passed to
            peak_directions.
        min_separation_angle : float in [0, 90]
            Used for extracting initial tracking directions. Passed to
            peak_directions.

        Returns
        -------
        direction_getter : instance of ``klass``

        See also
        --------
        dipy.direction.peaks.peak_directions

        """
        pmf_gen = SHCoeffPmfGen(np.asarray(shcoeff, dtype=float), sphere,
                                basis_type)
        return klass(pmf_gen, max_angle, sphere, pmf_threshold, **kwargs)
199
200
cdef class ClosestPeakDirectionGetter(PmfGenDirectionGetter):
    """Direction getter that follows the pmf peak nearest to the previous
    tracking direction.

    At each step the thresholded pmf is reduced to its peak directions, and
    ``closest_peak`` selects the one best aligned (up to sign) with the
    current direction, subject to ``cos_similarity``.
    """

    cdef int get_direction_c(self, double* point, double* direction):
        """Replace ``direction`` with the closest admissible peak at ``point``.

        Returns
        -------
        0 : if ``direction`` is updated
        1 : if no new direction is found
        """
        cdef:
            double[:] pmf_values
            np.ndarray[np.float_t, ndim=2] candidates

        pmf_values = self._get_pmf(point)
        candidates = self._get_peak_directions(pmf_values)

        if len(candidates) > 0:
            return closest_peak(candidates, direction, self.cos_similarity)
        # No peak at this point: report that no new direction was found.
        return 1
223