import numpy as np
cimport numpy as cnp
from libc.math cimport fabs

from warnings import warn

from dipy.direction.peaks import peak_directions, default_sphere
from dipy.direction.pmf cimport SimplePmfGen, SHCoeffPmfGen
from dipy.reconst.shm import order_from_ncoef, sph_harm_lookup
from dipy.tracking.direction_getter cimport DirectionGetter
from dipy.utils.fast_numpy cimport copy_point, scalar_muliplication_point


cdef int closest_peak(np.ndarray[np.float_t, ndim=2] peak_dirs,
                      double* direction, double cos_similarity):
    """Update direction with the closest direction from peak_dirs.

    All directions should be unit vectors. Antipodal symmetry is assumed,
    i.e. direction x is the same as -x.

    Parameters
    ----------
    peak_dirs : array (N, 3)
        N unit vectors.
    direction : double pointer to 3 values
        Previous direction. The new direction is saved here.
    cos_similarity : float
        `cos(max_angle)` where `max_angle` is the maximum allowed angle
        between the previous ``direction`` and the returned direction.

    Returns
    -------
    0 : if ``direction`` is updated
    1 : if no new direction is found
    """
    cdef:
        cnp.npy_intp _len = len(peak_dirs)
        cnp.npy_intp i
        int closest_peak_i = -1
        double _dot
        double closest_peak_dot = 0

    for i in range(_len):
        _dot = (peak_dirs[i, 0] * direction[0]
                + peak_dirs[i, 1] * direction[1]
                + peak_dirs[i, 2] * direction[2])

        # libc ``fabs`` keeps this per-peak comparison in C instead of
        # dispatching to NumPy once per peak inside the tracking loop.
        if fabs(_dot) > fabs(closest_peak_dot):
            closest_peak_dot = _dot
            closest_peak_i = i

    if closest_peak_i >= 0:
        # NOTE(review): copy_point reads from &peak_dirs[i, 0]; this assumes
        # each row's 3 values are contiguous in memory -- TODO confirm for
        # all callers of closest_peak.
        if closest_peak_dot >= cos_similarity:
            copy_point(&peak_dirs[closest_peak_i, 0], direction)
            return 0
        if closest_peak_dot <= -cos_similarity:
            # The best peak points against ``direction``; by antipodal
            # symmetry its negation is used instead.
            copy_point(&peak_dirs[closest_peak_i, 0], direction)
            scalar_muliplication_point(direction, -1)
            return 0
    return 1


cdef class BasePmfDirectionGetter(DirectionGetter):
    """A base class for dynamic direction getters

    Parameters
    ----------
    pmf_gen : PmfGen
        Object whose ``get_pmf_c(point)`` is called to obtain the pmf at a
        point.
    max_angle : float
        Maximum allowed angle, in degrees, between the incoming direction and
        the new direction. Stored as ``cos_similarity = cos(max_angle)``.
    sphere : Sphere
        The set of directions used when extracting peaks from the pmf.
    pmf_threshold : float, optional
        Relative threshold: pmf values below ``pmf_threshold * max(pmf)`` are
        set to zero. Must be >= 0 (a ValueError is raised otherwise).
    **kwargs
        Extra keyword arguments forwarded to
        ``dipy.direction.peaks.peak_directions``.
    """

    def __init__(self, pmf_gen, max_angle, sphere, pmf_threshold=.1,
                 **kwargs):
        self.sphere = sphere
        self._pf_kwargs = kwargs
        self.pmf_gen = pmf_gen
        if pmf_threshold < 0:
            raise ValueError("pmf threshold must be >= 0.")
        self.pmf_threshold = pmf_threshold
        self.cos_similarity = np.cos(np.deg2rad(max_angle))

    def _get_peak_directions(self, blob):
        """Gets directions using parameters provided at init.

        Blob can be any function defined on ``self.sphere``, i.e. an ODF.
        """
        return peak_directions(blob, self.sphere, **self._pf_kwargs)[0]

    cpdef np.ndarray[np.float_t, ndim=2] initial_direction(self,
                                                           double[::1] point):
        """Returns best directions at seed location to start tracking.

        Parameters
        ----------
        point : ndarray, shape (3,)
            The point in an image at which to lookup tracking directions.

        Returns
        -------
        directions : ndarray, shape (N, 3)
            Possible tracking directions from point. ``N`` may be 0, all
            directions should be unique.

        """
        cdef double[:] pmf = self._get_pmf(&point[0])
        return self._get_peak_directions(pmf)

    cdef _get_pmf(self, double* point):
        """Return the thresholded pmf at ``point``.

        Values below ``self.pmf_threshold * max(pmf)`` are set to 0.0. The
        thresholding is done in place on the array returned by
        ``self.pmf_gen.get_pmf_c``.
        """
        cdef:
            cnp.npy_intp _len, i
            double[:] pmf
            double absolute_pmf_threshold

        pmf = self.pmf_gen.get_pmf_c(point)
        _len = pmf.shape[0]

        # The threshold is relative to the largest pmf value at this point.
        absolute_pmf_threshold = self.pmf_threshold * np.max(pmf)
        for i in range(_len):
            if pmf[i] < absolute_pmf_threshold:
                pmf[i] = 0.0
        return pmf


cdef class PmfGenDirectionGetter(BasePmfDirectionGetter):
    """A base class for direction getter using a pmf"""

    @classmethod
    def from_pmf(klass, pmf, max_angle, sphere=default_sphere,
                 pmf_threshold=0.1, **kwargs):
        """Constructor for making a DirectionGetter from an array of Pmfs

        Parameters
        ----------
        pmf : array, 4d
            The pmf to be used for tracking at each voxel.
        max_angle : float, [0, 90]
            The maximum allowed angle between incoming direction and new
            direction.
        sphere : Sphere
            The set of directions to be used for tracking.
        pmf_threshold : float [0., 1.]
            Used to remove direction from the probability mass function for
            selecting the tracking direction.
        relative_peak_threshold : float in [0., 1.]
            Used for extracting initial tracking directions. Passed to
            peak_directions.
        min_separation_angle : float in [0, 90]
            Used for extracting initial tracking directions. Passed to
            peak_directions.

        Raises
        ------
        ValueError
            If ``pmf`` is not 4d, or its last dimension does not match the
            number of points in ``sphere``.

        See also
        --------
        dipy.direction.peaks.peak_directions

        """
        if pmf.ndim != 4:
            raise ValueError("pmf should be a 4d array.")
        if pmf.shape[3] != len(sphere.theta):
            msg = ("The last dimension of pmf should match the number of "
                   "points in sphere.")
            raise ValueError(msg)

        pmf_gen = SimplePmfGen(np.asarray(pmf, dtype=float))
        return klass(pmf_gen, max_angle, sphere, pmf_threshold, **kwargs)

    @classmethod
    def from_shcoeff(klass, shcoeff, max_angle, sphere=default_sphere,
                     pmf_threshold=0.1, basis_type=None, **kwargs):
        """Probabilistic direction getter from a distribution of directions
        on the sphere

        Parameters
        ----------
        shcoeff : array
            The distribution of tracking directions at each voxel represented
            as a function on the sphere using the real spherical harmonic
            basis. For example the FOD of the Constrained Spherical
            Deconvolution model can be used this way. This distribution will
            be discretized using ``sphere`` and tracking directions will be
            chosen from the vertices of ``sphere`` based on the distribution.
        max_angle : float, [0, 90]
            The maximum allowed angle between incoming direction and new
            direction.
        sphere : Sphere
            The set of directions to be used for tracking.
        pmf_threshold : float [0., 1.]
            Used to remove direction from the probability mass function for
            selecting the tracking direction.
        basis_type : name of basis
            The basis that ``shcoeff`` are associated with.
            ``dipy.reconst.shm.real_sh_descoteaux`` is used by default.
        relative_peak_threshold : float in [0., 1.]
            Used for extracting initial tracking directions. Passed to
            peak_directions.
        min_separation_angle : float in [0, 90]
            Used for extracting initial tracking directions. Passed to
            peak_directions.

        See also
        --------
        dipy.direction.peaks.peak_directions

        """
        pmf_gen = SHCoeffPmfGen(np.asarray(shcoeff, dtype=float), sphere,
                                basis_type)
        return klass(pmf_gen, max_angle, sphere, pmf_threshold, **kwargs)


cdef class ClosestPeakDirectionGetter(PmfGenDirectionGetter):
    """A direction getter that returns the closest odf peak to previous tracking
    direction.
    """

    cdef int get_direction_c(self, double* point, double* direction):
        """Replace ``direction`` with the closest acceptable peak at ``point``.

        Returns
        -------
        0 : if ``direction`` is updated
        1 : if no new direction is found
        """
        cdef:
            double[:] pmf
            np.ndarray[np.float_t, ndim=2] peaks

        pmf = self._get_pmf(point)

        peaks = self._get_peak_directions(pmf)
        # No peaks at this point: nothing to step towards.
        if len(peaks) == 0:
            return 1
        return closest_peak(peaks, direction, self.cos_similarity)