1from collections.abc import Iterable
2from warnings import warn
3
4import numpy as np
5from numpy import random
6from scipy.cluster.vq import kmeans2
7from scipy.spatial.distance import pdist, squareform
8
9from .._shared import utils
10from .._shared.filters import gaussian
11from ..color import rgb2lab
12from ..util import img_as_float, regular_grid
13from ._slic import _enforce_label_connectivity_cython, _slic_cython
14
15
def _get_mask_centroids(mask, n_centroids, multichannel):
    """Place roughly evenly spread seed points inside a mask.

    Candidate seeds are drawn at random (with a fixed seed) from the nonzero
    mask positions and then refined with a few k-means iterations.

    Parameters
    ----------
    mask : 3D ndarray
        Mask whose nonzero elements delimit where centroids may be placed.
    n_centroids : int
        The requested number of centroids. At most one centroid per nonzero
        mask element is drawn.
    multichannel : bool
        If True, ``mask.ndim - 1`` (instead of ``mask.ndim``) is used as the
        number of spatial dimensions when sizing the subset of points that
        is handed to k-means.

    Returns
    -------
    centroids : 2D ndarray
        The centroid coordinates, one row per centroid and one column per
        mask axis.
    steps : 1D ndarray
        Per-axis mean absolute offset between each centroid and its nearest
        neighbouring centroid.

    """
    # Coordinates of every nonzero mask element, one row per element.
    points = np.array(np.nonzero(mask), dtype=float).T
    n_points = len(points)
    all_indices = np.arange(n_points, dtype=int)

    # Fixed old-style RandomState: expected test results depend on the exact
    # sequence of draws, so the order of the two ``choice`` calls below
    # (seeds first, dense subset second) must not change.
    rng = random.RandomState(123)

    # Initial seeds: random distinct mask positions.
    n_seeds = min(n_centroids, n_points)
    seed_idx = np.sort(rng.choice(all_indices, n_seeds, replace=False))

    # k-means does not need every mask point when there are far more points
    # than centroids. Use a subset that is ``dense_factor`` times denser
    # along each spatial axis than the seeds (the factor is arbitrary).
    dense_factor = 10
    if multichannel:
        ndim_spatial = mask.ndim - 1
    else:
        ndim_spatial = mask.ndim
    n_dense = int((dense_factor ** ndim_spatial) * n_centroids)

    if n_points > n_dense:
        dense_idx = np.sort(rng.choice(all_indices, n_dense, replace=False))
        samples = points[dense_idx]
    else:
        # Few enough points: cluster on all of them.
        samples = points
    centroids, _ = kmeans2(samples, points[seed_idx], iter=5)

    # For each centroid, find its nearest other centroid (the diagonal is
    # set to infinity so a centroid is never matched with itself).
    pairwise = squareform(pdist(centroids))
    np.fill_diagonal(pairwise, np.inf)
    nearest = np.argmin(pairwise, axis=-1)
    steps = np.abs(centroids - centroids[nearest]).mean(axis=0)

    return centroids, steps
72
73
def _get_grid_centroids(image, n_centroids):
    """Find regularly spaced centroids on the image.

    Parameters
    ----------
    image : ndarray with at least 3 dimensions
        Input image. Only its first three axes (depth, height, width) are
        used; any further axes (e.g. channels) are ignored.
    n_centroids : int
        The (approximate) number of centroids to be returned.

    Returns
    -------
    centroids : 2D ndarray
        The coordinates of the centroids with shape (~n_centroids, 3),
        ordered as (z, y, x).
    steps : 1D ndarray
        The grid step along each of the three axes (1.0 where the grid
        slice has no explicit step).

    """
    d, h, w = image.shape[:3]
    slices = regular_grid((d, h, w), n_centroids)

    # Subsample each axis on its own and only then combine the selected
    # positions. This yields the same coordinates, in the same order, as
    # slicing dense ``np.mgrid`` index grids, but avoids allocating three
    # image-sized integer arrays just to pick a handful of seed points.
    axes = [np.arange(size)[sl] for size, sl in zip((d, h, w), slices)]
    grid_z, grid_y, grid_x = np.meshgrid(*axes, indexing='ij')

    centroids = np.stack(
        [grid_z.ravel(), grid_y.ravel(), grid_x.ravel()], axis=-1
    )

    steps = np.asarray([float(s.step) if s.step is not None else 1.0
                        for s in slices])
    return centroids, steps
108
109
@utils.channel_as_last_axis(multichannel_output=False)
@utils.deprecate_multichannel_kwarg(multichannel_position=6)
@utils.deprecate_kwarg({'max_iter': 'max_num_iter'}, removed_version="1.0",
                       deprecated_version="0.19")
def slic(image, n_segments=100, compactness=10., max_num_iter=10, sigma=0,
         spacing=None, multichannel=True, convert2lab=None,
         enforce_connectivity=True, min_size_factor=0.5, max_size_factor=3,
         slic_zero=False, start_label=1, mask=None, *,
         channel_axis=-1):
    """Segments image using k-means clustering in Color-(x,y,z) space.

    Parameters
    ----------
    image : 2D, 3D or 4D ndarray
        Input image, which can be 2D or 3D, and grayscale or multichannel
        (see `channel_axis` parameter).
        Input image must either be NaN-free or the NaN's must be masked out
    n_segments : int, optional
        The (approximate) number of labels in the segmented output image.
    compactness : float, optional
        Balances color proximity and space proximity. Higher values give
        more weight to space proximity, making superpixel shapes more
        square/cubic. In SLICO mode, this is the initial compactness.
        This parameter depends strongly on image contrast and on the
        shapes of objects in the image. We recommend exploring possible
        values on a log scale, e.g., 0.01, 0.1, 1, 10, 100, before
        refining around a chosen value.
    max_num_iter : int, optional
        Maximum number of iterations of k-means.
    sigma : float or array-like of floats, optional
        Width of Gaussian smoothing kernel for pre-processing for each
        dimension of the image. The same sigma is applied to each dimension in
        case of a scalar value. Zero means no smoothing.
        Note that `sigma` is automatically scaled if it is scalar and
        if a manual voxel spacing is provided (see Notes section). If
        sigma is array-like, its size must match ``image``'s number
        of spatial dimensions.
    spacing : array-like of floats, optional
        The voxel spacing along each spatial dimension. By default,
        `slic` assumes uniform spacing (same voxel resolution along
        each spatial dimension).
        This parameter controls the weights of the distances along the
        spatial dimensions during k-means clustering.
    multichannel : bool, optional
        Whether the last axis of the image is to be interpreted as multiple
        channels or another spatial dimension. This argument is deprecated:
        specify `channel_axis` instead.
    convert2lab : bool, optional
        Whether the input should be converted to Lab colorspace prior to
        segmentation. The input image *must* be RGB. Highly recommended.
        This option defaults to ``True`` when ``channel_axis`` is not None
        *and* ``image.shape[-1] == 3``.
    enforce_connectivity : bool, optional
        Whether the generated segments are connected or not
    min_size_factor : float, optional
        Proportion of the minimum segment size to be removed with respect
        to the supposed segment size ```depth*width*height/n_segments```
    max_size_factor : float, optional
        Proportion of the maximum connected segment size. A value of 3 works
        in most of the cases.
    slic_zero : bool, optional
        Run SLIC-zero, the zero-parameter mode of SLIC. [2]_
    start_label : int, optional
        The labels' index start. Should be 0 or 1.

        .. versionadded:: 0.17
           ``start_label`` was introduced in 0.17
    mask : ndarray, optional
        If provided, superpixels are computed only where mask is True,
        and seed points are homogeneously distributed over the mask
        using a k-means clustering strategy. Mask number of dimensions
        must be equal to image number of spatial dimensions.

        .. versionadded:: 0.17
           ``mask`` was introduced in 0.17
    channel_axis : int or None, optional
        If None, the image is assumed to be a grayscale (single channel) image.
        Otherwise, this parameter indicates which axis of the array corresponds
        to channels.

        .. versionadded:: 0.19
           ``channel_axis`` was added in 0.19.

    Returns
    -------
    labels : 2D or 3D array
        Integer mask indicating segment labels.

    Raises
    ------
    ValueError
        If ``convert2lab`` is set to ``True`` but the last array
        dimension is not of length 3.
    ValueError
        If ``start_label`` is not 0 or 1.
    ValueError
        If ``mask`` does not have the same shape as the spatial dimensions
        of ``image``.
    ValueError
        If ``spacing`` or an array-like ``sigma`` does not have one element
        per spatial dimension (3 elements are still tolerated for 2D images,
        with a ``FutureWarning``).
    TypeError
        If ``spacing`` is neither None nor iterable, or if ``sigma`` is
        neither a scalar nor iterable.

    Notes
    -----
    * If `sigma > 0`, the image is smoothed using a Gaussian kernel prior to
      segmentation.

    * If `sigma` is scalar and `spacing` is provided, the kernel width is
      divided along each dimension by the spacing. For example, if ``sigma=1``
      and ``spacing=[5, 1, 1]``, the effective `sigma` is ``[0.2, 1, 1]``. This
      ensures sensible smoothing for anisotropic images.

    * The image is rescaled to be in [0, 1] prior to processing.

    * Images of shape (M, N, 3) are interpreted as 2D RGB images by default. To
      interpret them as 3D with the last dimension having length 3, use
      `channel_axis=None`.

    * `start_label` is introduced to handle the issue [4]_. Label indexing
      starts at 1 by default.

    References
    ----------
    .. [1] Radhakrishna Achanta, Appu Shaji, Kevin Smith, Aurelien Lucchi,
        Pascal Fua, and Sabine Süsstrunk, SLIC Superpixels Compared to
        State-of-the-art Superpixel Methods, TPAMI, May 2012.
        :DOI:`10.1109/TPAMI.2012.120`
    .. [2] https://www.epfl.ch/labs/ivrl/research/slic-superpixels/#SLICO
    .. [3] Irving, Benjamin. "maskSLIC: regional superpixel generation with
           application to local pathology characterisation in medical images.",
           2016, :arXiv:`1606.09518`
    .. [4] https://github.com/scikit-image/scikit-image/issues/3722

    Examples
    --------
    >>> from skimage.segmentation import slic
    >>> from skimage.data import astronaut
    >>> img = astronaut()
    >>> segments = slic(img, n_segments=100, compactness=10)

    Increasing the compactness parameter yields more square regions:

    >>> segments = slic(img, n_segments=100, compactness=20)

    """

    image = img_as_float(image)
    float_dtype = utils._supported_float_type(image.dtype)
    # copy=True so subsequent in-place operations do not modify the
    # function input
    image = image.astype(float_dtype, copy=True)

    # Rescale image to [0, 1] to make choice of compactness insensitive to
    # input image scale.
    image -= image.min()
    imax = image.max()
    if imax != 0:
        image /= imax

    use_mask = mask is not None
    dtype = image.dtype

    is_2d = False

    # Internally the image is always handled as (depth, height, width,
    # channels); 2D and/or single-channel inputs get singleton axes added.
    multichannel = channel_axis is not None
    if image.ndim == 2:
        # 2D grayscale image
        image = image[np.newaxis, ..., np.newaxis]
        is_2d = True
    elif image.ndim == 3 and multichannel:
        # Make 2D multichannel image 3D with depth = 1
        image = image[np.newaxis, ...]
        is_2d = True
    elif image.ndim == 3 and not multichannel:
        # Add channel as single last dimension
        image = image[..., np.newaxis]

    # convert2lab=None means "convert if the image looks like RGB";
    # convert2lab=True insists on it and fails for non-3-channel input.
    if multichannel and (convert2lab or convert2lab is None):
        if image.shape[channel_axis] != 3 and convert2lab:
            raise ValueError("Lab colorspace conversion requires a RGB image.")
        elif image.shape[channel_axis] == 3:
            image = rgb2lab(image)

    if start_label not in [0, 1]:
        raise ValueError("start_label should be 0 or 1.")

    # initialize cluster centroids for desired number of segments
    update_centroids = False
    if use_mask:
        mask = np.ascontiguousarray(mask, dtype=bool).view('uint8')
        if mask.ndim == 2:
            mask = np.ascontiguousarray(mask[np.newaxis, ...])
        if mask.shape != image.shape[:3]:
            raise ValueError("image and mask should have the same shape.")
        centroids, steps = _get_mask_centroids(mask, n_segments, multichannel)
        update_centroids = True
    else:
        centroids, steps = _get_grid_centroids(image, n_segments)

    # Normalise ``spacing`` to a contiguous 3-element array (z, y, x).
    if spacing is None:
        spacing = np.ones(3, dtype=dtype)
    elif isinstance(spacing, Iterable):
        spacing = np.asarray(spacing, dtype=dtype)
        if is_2d:
            if spacing.size != 2:
                if spacing.size == 3:
                    warn("Input image is 2D: spacing number of "
                         "elements must be 2. In the future, a ValueError "
                         "will be raised.", FutureWarning, stacklevel=2)
                else:
                    raise ValueError(f"Input image is 2D, but spacing has "
                                     f"{spacing.size} elements (expected 2).")
            else:
                # Prepend unit spacing for the artificial depth axis.
                spacing = np.insert(spacing, 0, 1)
        elif spacing.size != 3:
            raise ValueError(f"Input image is 3D, but spacing has "
                             f"{spacing.size} elements (expected 3).")
        spacing = np.ascontiguousarray(spacing, dtype=dtype)
    else:
        raise TypeError("spacing must be None or iterable.")

    # Normalise ``sigma`` to a 3-element array (z, y, x).
    if np.isscalar(sigma):
        sigma = np.array([sigma, sigma, sigma], dtype=dtype)
        # Scalar sigma is expressed in physical units: scale per axis.
        sigma /= spacing
    elif isinstance(sigma, Iterable):
        sigma = np.asarray(sigma, dtype=dtype)
        if is_2d:
            if sigma.size != 2:
                # Check sigma's own size here. ``spacing`` always has 3
                # elements at this point, so testing it instead would let
                # any wrong-length sigma through with only a warning.
                if sigma.size == 3:
                    warn("Input image is 2D: sigma number of "
                         "elements must be 2. In the future, a ValueError "
                         "will be raised.", FutureWarning, stacklevel=2)
                else:
                    raise ValueError(f"Input image is 2D, but sigma has "
                                     f"{sigma.size} elements (expected 2).")
            else:
                # No smoothing along the artificial depth axis.
                sigma = np.insert(sigma, 0, 0)
        elif sigma.size != 3:
            raise ValueError(f"Input image is 3D, but sigma has "
                             f"{sigma.size} elements (expected 3).")
    else:
        raise TypeError("sigma must be a scalar or iterable.")

    if (sigma > 0).any():
        # add zero smoothing for channel dimension
        sigma = list(sigma) + [0]
        image = gaussian(image, sigma, mode='reflect')

    # Each segment row holds the centroid position followed by one (initially
    # zero) entry per image channel.
    n_centroids = centroids.shape[0]
    segments = np.ascontiguousarray(np.concatenate(
        [centroids, np.zeros((n_centroids, image.shape[3]))],
        axis=-1), dtype=dtype)

    # Scaling of ratio in the same way as in the SLIC paper so the
    # values have the same meaning
    step = max(steps)
    ratio = 1.0 / compactness

    image = np.ascontiguousarray(image * ratio, dtype=dtype)

    if update_centroids:
        # Step 2 of the algorithm [3]_
        _slic_cython(image, mask, segments, step, max_num_iter, spacing,
                     slic_zero, ignore_color=True,
                     start_label=start_label)

    labels = _slic_cython(image, mask, segments, step, max_num_iter,
                          spacing, slic_zero, ignore_color=False,
                          start_label=start_label)

    if enforce_connectivity:
        # Expected segment size: the covered volume split evenly among the
        # centroids.
        if use_mask:
            segment_size = mask.sum() / n_centroids
        else:
            segment_size = np.prod(image.shape[:3]) / n_centroids
        min_size = int(min_size_factor * segment_size)
        max_size = int(max_size_factor * segment_size)
        labels = _enforce_label_connectivity_cython(
            labels, min_size, max_size, start_label=start_label)

    if is_2d:
        # Drop the artificial depth axis added for 2D input.
        labels = labels[0]

    return labels
386