1# cython: boundscheck=False
2# cython: cdivision=True
3# cython: initializedcheck=False
4# cython: wraparound=False
5
6cimport cython
7cimport numpy as np
8
9cdef extern from "dpy_math.h" nogil:
10    int dpy_rint(double)
11
12from dipy.core.interpolation cimport trilinear_interpolate4d_c
13
14import numpy as np
15
cdef class StoppingCriterion:
    """Base class of all streamline stopping criteria.

    Subclasses implement ``check_point_c``; ``check_point`` is the
    Python-accessible entry point that validates the point and forwards
    to it.
    """

    cpdef StreamlineStatus check_point(self, double[::1] point):
        """Return the StreamlineStatus of a 3D ``point``.

        Raises
        ------
        ValueError
            If ``point`` does not hold exactly three coordinates.
        """
        # Only a 3-element point may be handed to the C-level check, which
        # reads point[0], point[1] and point[2] through a raw pointer.
        if point.shape[0] == 3:
            return self.check_point_c(&point[0])
        raise ValueError("Point has wrong shape")

    cdef StreamlineStatus check_point_c(self, double* point):
        # Placeholder for subclasses; the base implementation does nothing.
        pass
25
26
cdef class BinaryStoppingCriterion(StoppingCriterion):
    """
    Stopping criterion based on a binary mask.

    Each point is rounded (with ``dpy_rint``) to its nearest voxel. The
    point is
    - OUTSIDEIMAGE when that voxel lies outside the mask grid,
    - TRACKPOINT when the mask is non-zero at that voxel,
    - ENDPOINT otherwise.

    cdef:
        unsigned char[:, :, :] mask
    """

    def __cinit__(self, mask):
        self.interp_out_view = self.interp_out_double
        # np.asarray lets any array-like be used as a mask; a plain Python
        # list would otherwise raise a TypeError on ``mask > 0``.
        self.mask = (np.asarray(mask) > 0).astype('uint8')

    cdef StreamlineStatus check_point_c(self, double* point):
        cdef:
            unsigned char result
            int voxel[3]

        # dpy_rint is declared to return a C int, so it can be stored
        # directly without a Python ``int()`` wrapper.
        voxel[0] = dpy_rint(point[0])
        voxel[1] = dpy_rint(point[1])
        voxel[2] = dpy_rint(point[2])

        # Explicit bounds test: boundscheck/wraparound are disabled for
        # this module, so an out-of-range index would not be caught.
        if (voxel[0] < 0 or voxel[0] >= self.mask.shape[0]
                or voxel[1] < 0 or voxel[1] >= self.mask.shape[1]
                or voxel[2] < 0 or voxel[2] >= self.mask.shape[2]):
            return OUTSIDEIMAGE

        result = self.mask[voxel[0], voxel[1], voxel[2]]

        if result > 0:
            return TRACKPOINT
        else:
            return ENDPOINT
58
59
cdef class ThresholdStoppingCriterion(StoppingCriterion):
    """
    Stopping criterion that thresholds a scalar metric map.

    The metric map is trilinearly interpolated at each point. The point is
    - TRACKPOINT while the interpolated value is strictly greater than
      ``threshold``,
    - ENDPOINT otherwise,
    - OUTSIDEIMAGE when the interpolation reports error code -1.

    # Declarations from stopping_criterion.pxd below
    cdef:
        double threshold, interp_out_double[1]
        double[:]  interp_out_view
        double[:, :, :] metric_map
    """

    def __cinit__(self, metric_map, double threshold):
        self.threshold = threshold
        self.metric_map = np.asarray(metric_map, 'float64')
        # One-element buffer that receives each interpolated value.
        self.interp_out_view = self.interp_out_double

    cdef StreamlineStatus check_point_c(self, double* point):
        cdef:
            int code
            double value

        # The interpolator works on 4D data, so a trailing axis is added.
        code = trilinear_interpolate4d_c(self.metric_map[..., None], point,
                                         self.interp_out_view)
        if code == -1:
            return OUTSIDEIMAGE
        if code != 0:
            # Any other non-zero code is not expected from the interpolator.
            raise RuntimeError(
                "Unexpected interpolation error (code:%i)" % code)

        value = self.interp_out_view[0]
        return TRACKPOINT if value > self.threshold else ENDPOINT
96
97
cdef class AnatomicalStoppingCriterion(StoppingCriterion):
    r"""
    Abstract class that takes as input included and excluded tissue maps.
    The 'include_map' defines when the streamline reached a 'valid' stopping
    region (e.g. gray matter partial volume estimation (PVE) map) and the
    'exclude_map' defines when the streamline reached an 'invalid' stopping
    region (e.g. cerebrospinal fluid PVE map). The background of the
    anatomical image should be added to the 'include_map' to keep streamlines
    exiting the brain (e.g. through the brain stem).

    cdef:
        double interp_out_double[1]
        double[:]  interp_out_view
        double[:, :, :] include_map, exclude_map

    """
    def __cinit__(self, include_map, exclude_map, *args, **kw):
        # Extra arguments are accepted because subclasses (e.g. the CMC
        # criterion) take more constructor arguments, and Cython forwards
        # those same arguments to this base __cinit__.
        self.interp_out_view = self.interp_out_double
        self.include_map = np.asarray(include_map, 'float64')
        self.exclude_map = np.asarray(exclude_map, 'float64')

    @classmethod
    def from_pve(klass, wm_map, gm_map, csf_map, **kw):
        """AnatomicalStoppingCriterion from partial volume fraction (PVE)
        maps.

        Parameters
        ----------
        wm_map : array
            The partial volume fraction of white matter at each voxel.
        gm_map : array
            The partial volume fraction of gray matter at each voxel.
        csf_map : array
            The partial volume fraction of cerebrospinal fluid at each
            voxel.
        **kw
            Extra keyword arguments forwarded to the class constructor.

        Returns
        -------
        criterion : instance of ``klass``
            Built with include map = gray matter + image background (voxels
            where all three fractions sum to 0) and exclude map = CSF.

        """
        # Convert first: with plain Python lists ``+`` would concatenate
        # instead of adding element-wise, and the background mask below
        # would silently select nothing.
        wm_map = np.asarray(wm_map)
        gm_map = np.asarray(gm_map)
        csf_map = np.asarray(csf_map)
        # include map = gray matter + image background
        include_map = np.copy(gm_map)
        include_map[(wm_map + gm_map + csf_map) == 0] = 1
        # exclude map = csf
        exclude_map = np.copy(csf_map)
        return klass(include_map, exclude_map, **kw)

    cpdef double get_exclude(self, double[::1] point):
        """Return the exclude map interpolated at a 3D ``point``.

        Returns 0 when the interpolation reports any error code.
        """
        if point.shape[0] != 3:
            raise ValueError("Point has wrong shape")

        return self.get_exclude_c(&point[0])

    cdef get_exclude_c(self, double* point):
        cdef int err
        err = trilinear_interpolate4d_c(self.exclude_map[..., None],
                                        point, self.interp_out_view)
        # Any non-zero code (e.g. -1, which the subclasses treat as outside
        # the image) yields 0.
        if err != 0:
            return 0
        return self.interp_out_view[0]

    cpdef double get_include(self, double[::1] point):
        """Return the include map interpolated at a 3D ``point``.

        Returns 0 when the interpolation reports any error code.
        """
        if point.shape[0] != 3:
            raise ValueError("Point has wrong shape")

        return self.get_include_c(&point[0])

    cdef get_include_c(self, double* point):
        cdef int include_err
        include_err = trilinear_interpolate4d_c(self.include_map[..., None],
                                                point, self.interp_out_view)
        if include_err != 0:
            return 0
        return self.interp_out_view[0]
167
168
cdef class ActStoppingCriterion(AnatomicalStoppingCriterion):
    r"""
    Anatomically-Constrained Tractography (ACT) stopping criterion from [1]_.
    This implements the use of partial volume fraction (PVE) maps to
    determine when the tracking stops. The proposed ([1]_) method that
    cuts streamlines going through subcortical gray matter regions is
    not implemented here. The backtracking technique for
    streamlines reaching INVALIDPOINT is not implemented either.

    A point is
    - OUTSIDEIMAGE when either map cannot be interpolated (error code -1),
    - ENDPOINT when the interpolated include value is > 0.5,
    - INVALIDPOINT when the interpolated exclude value is > 0.5,
    - TRACKPOINT otherwise.

    cdef:
        double interp_out_double[1]
        double[:]  interp_out_view
        double[:, :, :] include_map, exclude_map

    References
    ----------
    .. [1] Smith, R. E., Tournier, J.-D., Calamante, F., & Connelly, A.
    "Anatomically-constrained tractography: Improved diffusion MRI
    streamlines tractography through effective use of anatomical
    information." NeuroImage, 63(3), 1924-1938, 2012.
    """

    def __cinit__(self, include_map, exclude_map):
        # AnatomicalStoppingCriterion.__cinit__ runs first with these same
        # arguments and already stores float64 views of both maps. Repeating
        # the conversion here would only reconvert (and possibly re-copy)
        # the full maps. This method is kept to fix the signature to
        # exactly two maps.
        pass

    cdef StreamlineStatus check_point_c(self, double* point):
        cdef:
            double include_result, exclude_result
            int include_err, exclude_err

        # Both maps share the one-element output buffer, so each result is
        # read immediately after its interpolation.
        include_err = trilinear_interpolate4d_c(
            self.include_map[..., None],
            point,
            self.interp_out_view)
        include_result = self.interp_out_view[0]

        exclude_err = trilinear_interpolate4d_c(
            self.exclude_map[..., None],
            point,
            self.interp_out_view)
        exclude_result = self.interp_out_view[0]

        # NOTE(review): whether the RuntimeErrors below reach the caller
        # depends on the exception clause declared for check_point_c in
        # the .pxd - confirm.
        if include_err == -1 or exclude_err == -1:
            return OUTSIDEIMAGE
        elif include_err != 0:
            # This should never happen
            raise RuntimeError("Unexpected interpolation error " +
                               "(include_map - code:%i)" % include_err)
        elif exclude_err != 0:
            # This should never happen
            raise RuntimeError("Unexpected interpolation error " +
                               "(exclude_map - code:%i)" % exclude_err)

        if include_result > 0.5:
            return ENDPOINT
        elif exclude_result > 0.5:
            return INVALIDPOINT
        else:
            return TRACKPOINT
228
229
cdef class CmcStoppingCriterion(AnatomicalStoppingCriterion):
    r"""
    Continuous map criterion (CMC) stopping criterion from [1]_.
    This implements the use of partial volume fraction (PVE) maps to
    determine when the tracking stops.

    Let :math:`i` and :math:`e` be the include and exclude maps
    interpolated at a point, and :math:`c` = ``step_size /
    average_voxel_size``.

    - If :math:`i + e \le 0` the point is a TRACKPOINT.
    - Otherwise tracking continues with probability
      :math:`(n / (n + i + e))^{c}`, where :math:`n = \max(0, 1 - i - e)`.
    - If tracking does not continue, the point is an ENDPOINT with
      probability :math:`i / (i + e)`, and an INVALIDPOINT otherwise.

    Random draws come from ``np.random.random``.

    Parameters
    ----------
    include_map, exclude_map : array
        Include / exclude tissue maps (converted to float64).
    step_size : float
        Tracking step size.
    average_voxel_size : float
        Average voxel size; the ratio ``step_size / average_voxel_size`` is
        the exponent that corrects the continuation probability.

    cdef:
        double interp_out_double[1]
        double[:]  interp_out_view
        double[:, :, :] include_map, exclude_map
        double step_size
        double average_voxel_size
        double correction_factor

    References
    ----------
    .. [1] Girard, G., Whittingstall, K., Deriche, R., & Descoteaux, M.
    "Towards quantitative connectivity analysis: reducing tractography biases."
    NeuroImage, 98, 266-278, 2014.
    """

    def __cinit__(self, include_map, exclude_map, step_size, average_voxel_size):
        # The maps themselves are stored by AnatomicalStoppingCriterion's
        # __cinit__, which receives these same arguments.
        self.step_size = step_size
        self.average_voxel_size = average_voxel_size
        self.correction_factor = step_size / average_voxel_size

    cdef StreamlineStatus check_point_c(self, double* point):
        cdef:
            double include_result, exclude_result
            # Typed so the per-step probability computation stays in C
            # instead of creating Python float objects on every call.
            double num, den, p
            int include_err, exclude_err

        # Both maps share the one-element output buffer, so each result is
        # read immediately after its interpolation.
        include_err = trilinear_interpolate4d_c(self.include_map[..., None],
                                                point, self.interp_out_view)
        include_result = self.interp_out_view[0]

        exclude_err = trilinear_interpolate4d_c(self.exclude_map[..., None],
                                                point, self.interp_out_view)
        exclude_result = self.interp_out_view[0]

        if include_err == -1 or exclude_err == -1:
            return OUTSIDEIMAGE
        elif include_err == -2 or exclude_err == -2:
            raise ValueError("Point has wrong shape")
        elif include_err != 0:
            # This should never happen
            raise RuntimeError("Unexpected interpolation error " +
                               "(include_map - code:%i)" % include_err)
        elif exclude_err != 0:
            # This should never happen
            raise RuntimeError("Unexpected interpolation error " +
                               "(exclude_map - code:%i)" % exclude_err)

        # test if the tracking continues
        if include_result + exclude_result <= 0:
            return TRACKPOINT
        # den > 0 here because include_result + exclude_result > 0 and
        # num >= 0, so the (cdivision) divisions below are safe.
        num = max(0, (1 - include_result - exclude_result))
        den = num + include_result + exclude_result
        p = (num / den) ** self.correction_factor
        if np.random.random() < p:
            return TRACKPOINT

        # test if the tracking stopped in the include tissue map
        p = (include_result / (include_result + exclude_result))
        if np.random.random() < p:
            return ENDPOINT

        # the tracking stopped in the exclude tissue map
        return INVALIDPOINT
298