1#cython: cdivision=True
2#cython: boundscheck=False
3#cython: nonecheck=False
4#cython: wraparound=False
5from libc.float cimport DBL_MAX
6
7import numpy as np
8cimport numpy as cnp
9
10from ..util import regular_grid
11from .._shared.fused_numerics cimport np_floats
12
13cnp.import_array()
14
15
def _slic_cython(np_floats[:, :, :, ::1] image_zyx,
                 cnp.uint8_t[:, :, ::1] mask,
                 np_floats[:, ::1] segments,
                 float step,
                 Py_ssize_t max_num_iter,
                 np_floats[::1] spacing,
                 bint slic_zero,
                 Py_ssize_t start_label=1,
                 bint ignore_color=False):
    """Helper function for SLIC segmentation.

    Runs the SLIC k-means loop: every centroid claims the voxels inside a
    window of +/- 2 grid steps around it, then every centroid is moved to
    the mean position and color of the voxels it claimed.

    Parameters
    ----------
    image_zyx : 4D array of np_floats, shape (Z, Y, X, C)
        The input image.
    mask : 3D array of uint8, shape (Z, Y, X), or None
        Voxels where ``mask`` is zero are never assigned to a segment.
        ``None`` disables masking.
    segments : 2D array of np_floats, shape (N, 3 + C)
        The initial centroids obtained by SLIC as [Z, Y, X, C...]. The
        array is updated in place with the refined centroids.
    step : float
        The size of the step between two seeds in voxels. The spatial part
        of the distance is scaled by ``1 / step**2``.
    max_num_iter : int
        The maximum number of k-means iterations. The loop stops earlier
        once an assignment pass leaves every label unchanged.
    spacing : 1D array of np_floats, shape (3,)
        The voxel spacing along each image dimension. This parameter
        controls the weights of the distances along z, y, and x during
        k-means clustering.
    slic_zero : bool
        True to run SLIC-zero, False to run original SLIC.
    start_label : int
        The label indexing start value.
    ignore_color : bool
        True to assign voxels using the spatial distance only. The
        centroid colors are still recomputed.

    Returns
    -------
    nearest_segments : 3D array of int, shape (Z, Y, X)
        The label field/superpixels found by SLIC. Voxels that were never
        assigned (masked out, or outside every search window) hold
        ``start_label - 1``.

    Notes
    -----
    The image is considered to be in (z, y, x) order, which can be
    surprising. More commonly, the order (x, y, z) is used. However,
    in 3D image analysis, 'z' is usually the "special" dimension, with,
    for example, a different effective resolution than the other two
    axes. Therefore, x and y are often processed together, or viewed as
    a cut-plane through the volume. So, if the order was (x, y, z) and
    we wanted to look at the 5th cut plane, we would write::

        my_z_plane = img3d[:, :, 5]

    but, assuming a C-contiguous array, this would grab a discontiguous
    slice of memory, which is bad for performance. In contrast, if we
    see the image as (z, y, x) ordered, we would do::

        my_z_plane = img3d[5]

    and get back a contiguous block of memory. This is better both for
    performance and for readability.

    """

    if np_floats is cnp.float32_t:
        dtype = np.float32
    else:
        dtype = np.float64

    # initialize on grid
    cdef Py_ssize_t depth, height, width
    depth = image_zyx.shape[0]
    height = image_zyx.shape[1]
    width = image_zyx.shape[2]

    cdef Py_ssize_t n_segments = segments.shape[0]
    # number of features [Z, Y, X, C...]
    cdef Py_ssize_t n_features = segments.shape[1]

    # approximate grid size for desired n_segments; each centroid searches
    # a window of +/- 2 grid steps along every axis
    cdef Py_ssize_t step_z, step_y, step_x
    slices = regular_grid((depth, height, width), n_segments)
    step_z, step_y, step_x = [int(s.step if s.step is not None else 1)
                              for s in slices]

    # Add mask support
    cdef bint use_mask = mask is not None
    # Label of voxels that are masked out or not (yet) assigned; real
    # labels are always ``k + start_label`` and hence larger.
    cdef Py_ssize_t mask_label = start_label - 1

    cdef Py_ssize_t[:, :, ::1] nearest_segments \
        = np.full((depth, height, width), mask_label, dtype=np.intp)
    cdef np_floats[:, :, ::1] distance \
        = np.empty((depth, height, width), dtype=dtype)
    cdef Py_ssize_t[::1] n_segment_elems = np.empty(n_segments, dtype=np.intp)

    cdef Py_ssize_t i, c, k, x, y, z, x_min, x_max, y_min, y_max, z_min, z_max
    cdef Py_ssize_t new_label
    cdef bint change
    cdef np_floats dist_center, cx, cy, cz, dx, dy, dz, t

    cdef np_floats sz, sy, sx
    sz = spacing[0]
    sy = spacing[1]
    sx = spacing[2]

    # The colors are scaled before being passed to _slic_cython so
    # max_dist_color can be initialised as all ones
    cdef np_floats[::1] max_dist_color = np.ones(n_segments, dtype=dtype)
    cdef np_floats dist_color

    # The reference implementation (Achanta et al.) calls this invxywt
    cdef np_floats spatial_weight = 1.0 / (step * step)

    with nogil:
        for i in range(max_num_iter):
            change = False
            distance[:, :, :] = DBL_MAX

            # assign pixels to segments
            for k in range(n_segments):
                new_label = k + start_label

                # segment coordinate centers
                cz = segments[k, 0]
                cy = segments[k, 1]
                cx = segments[k, 2]

                # compute windows
                z_min = <Py_ssize_t>max(cz - 2 * step_z, 0)
                z_max = <Py_ssize_t>min(cz + 2 * step_z + 1, depth)
                y_min = <Py_ssize_t>max(cy - 2 * step_y, 0)
                y_max = <Py_ssize_t>min(cy + 2 * step_y + 1, height)
                x_min = <Py_ssize_t>max(cx - 2 * step_x, 0)
                x_max = <Py_ssize_t>min(cx + 2 * step_x + 1, width)

                for z in range(z_min, z_max):
                    dz = sz * (cz - z)
                    dz *= dz
                    for y in range(y_min, y_max):
                        dy = sy * (cy - y)
                        dy *= dy
                        for x in range(x_min, x_max):

                            if use_mask and not mask[z, y, x]:
                                continue

                            dx = sx * (cx - x)
                            dx *= dx
                            dist_center = (dz + dy + dx) * spatial_weight

                            if not ignore_color:
                                dist_color = 0
                                for c in range(3, n_features):
                                    t = (image_zyx[z, y, x, c - 3]
                                         - segments[k, c])
                                    dist_color += t * t

                                if slic_zero:
                                    dist_color /= max_dist_color[k]
                                dist_center += dist_color

                            if distance[z, y, x] > dist_center:
                                # ``distance`` is reset every iteration, so an
                                # improved distance alone says nothing about
                                # convergence; only a different label counts.
                                # Without this test the early exit below
                                # could never trigger.
                                if nearest_segments[z, y, x] != new_label:
                                    change = True
                                nearest_segments[z, y, x] = new_label
                                distance[z, y, x] = dist_center

            # stop if no pixel changed its segment: the centroids recomputed
            # from an unchanged labelling would be identical, so every
            # further iteration would reproduce the same result
            if not change:
                break

            # recompute segment centers

            # sum features for all segments
            n_segment_elems[:] = 0
            segments[:, :] = 0
            for z in range(depth):
                for y in range(height):
                    for x in range(width):

                        if use_mask and not mask[z, y, x]:
                            continue

                        # Unassigned voxels would map to k == -1, an
                        # out-of-bounds write with bounds checks disabled.
                        if nearest_segments[z, y, x] == mask_label:
                            continue

                        k = nearest_segments[z, y, x] - start_label
                        n_segment_elems[k] += 1
                        segments[k, 0] += z
                        segments[k, 1] += y
                        segments[k, 2] += x
                        for c in range(3, n_features):
                            segments[k, c] += image_zyx[z, y, x, c - 3]

            # divide by number of elements per segment to obtain mean
            for k in range(n_segments):
                # A segment that claimed no voxel keeps its all-zero sums
                # (the reference code clamps the divisor to 1). Dividing by
                # zero would yield NaN centroids whose integer window bounds
                # are undefined.
                if n_segment_elems[k] == 0:
                    continue
                for c in range(n_features):
                    segments[k, c] /= n_segment_elems[k]

            # If in SLICO mode, update the color distance maxima
            if slic_zero:
                for z in range(depth):
                    for y in range(height):
                        for x in range(width):

                            if use_mask and not mask[z, y, x]:
                                continue

                            if nearest_segments[z, y, x] == mask_label:
                                continue

                            k = nearest_segments[z, y, x] - start_label
                            dist_color = 0

                            for c in range(3, n_features):
                                t = image_zyx[z, y, x, c - 3] - segments[k, c]
                                dist_color += t * t

                            # The reference implementation seems to only change
                            # the color if it increases from previous iteration
                            if max_dist_color[k] < dist_color:
                                max_dist_color[k] = dist_color

    return np.asarray(nearest_segments)
238
239
def _enforce_label_connectivity_cython(Py_ssize_t[:, :, ::1] segments,
                                       Py_ssize_t min_size,
                                       Py_ssize_t max_size,
                                       Py_ssize_t start_label=1):
    """Helper function to remove small disconnected regions from the labels.

    Every 6-connected component of an input label receives its own
    consecutive label, starting at ``start_label``. Components smaller than
    ``min_size`` are merged into a neighbouring, already relabelled
    component when one exists.

    Parameters
    ----------
    segments : 3D array of int, shape (Z, Y, X)
        The label field/superpixels found by SLIC. Voxels equal to
        ``start_label - 1`` are treated as masked and left untouched.
    min_size : int
        The minimum size of the segment.
    max_size : int
        The maximum size of the segment. This is done for performance reasons,
        to pre-allocate a sufficiently large array for the breadth first
        search. A search stops once it has collected ``max_size`` voxels;
        the rest of that component starts a new search later.
    start_label : int
        The label indexing start value.

    Returns
    -------
    connected_segments : 3D array of int, shape (Z, Y, X)
        A label field with connected labels starting at ``start_label``.
        Masked voxels keep ``start_label - 1``.

    Raises
    ------
    ValueError
        If ``max_size`` is smaller than 1 (the search buffer could not hold
        even the seed voxel).
    """
    if max_size < 1:
        raise ValueError("max_size must be at least 1")

    # get image dimensions
    cdef Py_ssize_t depth, height, width
    depth = segments.shape[0]
    height = segments.shape[1]
    width = segments.shape[2]

    # 6-connected neighborhood offsets
    cdef Py_ssize_t[::1] ddx = np.array((1, -1, 0, 0, 0, 0), dtype=np.intp)
    cdef Py_ssize_t[::1] ddy = np.array((0, 0, 1, -1, 0, 0), dtype=np.intp)
    cdef Py_ssize_t[::1] ddz = np.array((0, 0, 0, 0, 1, -1), dtype=np.intp)

    # new object with connected segments initialized to mask_label
    cdef Py_ssize_t mask_label = start_label - 1

    cdef Py_ssize_t[:, :, ::1] connected_segments \
        = np.full_like(segments, mask_label, dtype=np.intp)

    cdef Py_ssize_t current_new_label = start_label
    cdef Py_ssize_t label = start_label

    # variables for the breadth first search
    cdef Py_ssize_t current_segment_size = 1
    cdef Py_ssize_t bfs_visited = 0
    cdef Py_ssize_t adjacent
    cdef Py_ssize_t seed_index
    cdef bint revisit

    cdef Py_ssize_t z, y, x, i, zz, yy, xx

    cdef Py_ssize_t[:, ::1] coord_list = np.empty((max_size, 3), dtype=np.intp)

    with nogil:
        for z in range(depth):
            for y in range(height):
                for x in range(width):

                    if segments[z, y, x] == mask_label:
                        continue

                    if connected_segments[z, y, x] > mask_label:
                        continue

                    # find the component size; ``adjacent`` stays at
                    # mask_label unless a relabelled neighbour is found (a
                    # hard-coded 0 would be a real label when start_label=0)
                    adjacent = mask_label
                    label = segments[z, y, x]
                    connected_segments[z, y, x] = current_new_label
                    current_segment_size = 1
                    bfs_visited = 0
                    coord_list[bfs_visited, 0] = z
                    coord_list[bfs_visited, 1] = y
                    coord_list[bfs_visited, 2] = x

                    # perform a breadth first search to find
                    # the size of the connected component
                    while bfs_visited < current_segment_size < max_size:
                        for i in range(6):
                            zz = coord_list[bfs_visited, 0] + ddz[i]
                            yy = coord_list[bfs_visited, 1] + ddy[i]
                            xx = coord_list[bfs_visited, 2] + ddx[i]
                            if (0 <= xx < width and
                                0 <= yy < height and
                                0 <= zz < depth):
                                if (segments[zz, yy, xx] == label and
                                    connected_segments[zz, yy, xx] == mask_label):
                                    connected_segments[zz, yy, xx] = \
                                        current_new_label
                                    coord_list[current_segment_size, 0] = zz
                                    coord_list[current_segment_size, 1] = yy
                                    coord_list[current_segment_size, 2] = xx
                                    current_segment_size += 1
                                    if current_segment_size >= max_size:
                                        break
                                elif (connected_segments[zz, yy, xx] > mask_label and
                                      connected_segments[zz, yy, xx] != current_new_label):
                                    adjacent = connected_segments[zz, yy, xx]
                        bfs_visited += 1

                    if current_segment_size >= min_size:
                        current_new_label += 1
                    elif adjacent != mask_label:
                        # change to an adjacent one, like in the original paper
                        for i in range(current_segment_size):
                            connected_segments[coord_list[i, 0],
                                               coord_list[i, 1],
                                               coord_list[i, 2]] = adjacent
                    else:
                        # Too small, but no neighbour is relabelled yet. If a
                        # voxel of this component lies later in raster order,
                        # reset it so that voxel re-seeds the search once more
                        # neighbours are labelled. Otherwise this is the last
                        # chance: keep the component under its own label
                        # instead of leaving unmasked voxels unlabelled.
                        seed_index = (z * height + y) * width + x
                        revisit = False
                        for i in range(current_segment_size):
                            if ((coord_list[i, 0] * height + coord_list[i, 1])
                                    * width + coord_list[i, 2] > seed_index):
                                revisit = True
                                break
                        if revisit:
                            for i in range(current_segment_size):
                                connected_segments[coord_list[i, 0],
                                                   coord_list[i, 1],
                                                   coord_list[i, 2]] = mask_label
                        else:
                            current_new_label += 1

    return np.asarray(connected_segments)
349