1"""Nearest Neighbors graph functions"""
2
3# Author: Jake Vanderplas <vanderplas@astro.washington.edu>
4#         Tom Dupre la Tour
5#
6# License: BSD 3 clause (C) INRIA, University of Amsterdam
7from ._base import KNeighborsMixin, RadiusNeighborsMixin
8from ._base import NeighborsBase
9from ._unsupervised import NearestNeighbors
10from ..base import TransformerMixin
11from ..utils.validation import check_is_fitted
12
13
def _check_params(X, metric, p, metric_params):
    """Ensure graph-function arguments agree with an already fitted estimator.

    Parameters
    ----------
    X : estimator
        Object exposing ``get_params()`` whose result contains the keys
        ``"metric"``, ``"p"`` and ``"metric_params"``.
    metric, p, metric_params
        Values requested by the caller of the graph function.

    Raises
    ------
    ValueError
        On the first of ``metric``, ``p``, ``metric_params`` (checked in that
        order) that differs from the estimator's own setting.
    """
    fitted_params = X.get_params()
    requested = {"metric": metric, "p": p, "metric_params": metric_params}
    for name in ("metric", "p", "metric_params"):
        wanted = requested[name]
        actual = fitted_params[name]
        if wanted != actual:
            raise ValueError(
                "Got %s for %s, while the estimator has %s for the same parameter."
                % (wanted, name, actual)
            )
24
25
def _query_include_self(X, include_self, mode):
    """Pick the query points to hand to a fitted neighbors estimator.

    Parameters
    ----------
    X : array-like
        The training data of the fitted estimator.
    include_self : bool or "auto"
        Whether each sample should count as its own neighbor. ``"auto"``
        resolves to True for ``mode == "connectivity"`` and False otherwise.
    mode : str
        Graph mode, only consulted when ``include_self == "auto"``.

    Returns
    -------
    X or None
        ``X`` when samples should include themselves; ``None`` otherwise, so
        the estimator queries its own training set excluding each sample.
    """
    if include_self == "auto":
        include_self = mode == "connectivity"
    return X if include_self else None
36
37
def kneighbors_graph(
    X,
    n_neighbors,
    *,
    mode="connectivity",
    metric="minkowski",
    p=2,
    metric_params=None,
    include_self=False,
    n_jobs=None,
):
    """Compute the (weighted) graph of k-Neighbors for points in X.

    Read more in the :ref:`User Guide <unsupervised_neighbors>`.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features) or BallTree
        Sample data, in the form of a numpy array or a precomputed
        :class:`BallTree`. An already fitted estimator deriving from
        ``KNeighborsMixin`` (e.g. :class:`NearestNeighbors`) is also
        accepted and reused as-is; its ``metric``, ``p`` and
        ``metric_params`` must then equal the values passed here.

    n_neighbors : int
        Number of neighbors for each sample.

    mode : {'connectivity', 'distance'}, default='connectivity'
        Type of returned matrix: 'connectivity' will return the connectivity
        matrix with ones and zeros, and 'distance' will return the distances
        between neighbors according to the given metric.

    metric : str, default='minkowski'
        The distance metric to use for the tree. The default metric is
        minkowski, and with p=2 is equivalent to the standard Euclidean
        metric.
        For a list of available metrics, see the documentation of
        :class:`~sklearn.metrics.DistanceMetric`.

    p : int, default=2
        Power parameter for the Minkowski metric. When p = 1, this is
        equivalent to using manhattan_distance (l1), and euclidean_distance
        (l2) for p = 2. For arbitrary p, minkowski_distance (l_p) is used.

    metric_params : dict, default=None
        Additional keyword arguments for the metric function.

    include_self : bool or 'auto', default=False
        Whether or not to mark each sample as the first nearest neighbor to
        itself. If 'auto', then True is used for mode='connectivity' and False
        for mode='distance'.

    n_jobs : int, default=None
        The number of parallel jobs to run for neighbors search.
        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
        ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
        for more details. Ignored when X is an already fitted estimator.

    Returns
    -------
    A : sparse matrix of shape (n_samples, n_samples)
        Graph where A[i, j] is assigned the weight of edge that
        connects i to j. The matrix is of CSR format.

    Raises
    ------
    ValueError
        If X is a fitted estimator whose ``metric``, ``p`` or
        ``metric_params`` differ from the arguments.

    Examples
    --------
    >>> X = [[0], [3], [1]]
    >>> from sklearn.neighbors import kneighbors_graph
    >>> A = kneighbors_graph(X, 2, mode='connectivity', include_self=True)
    >>> A.toarray()
    array([[1., 0., 1.],
           [0., 1., 1.],
           [1., 0., 1.]])

    See Also
    --------
    radius_neighbors_graph
    """
    if isinstance(X, KNeighborsMixin):
        # Reuse the caller's estimator, but refuse silently conflicting
        # metric settings.
        _check_params(X, metric, p, metric_params)
        estimator = X
    else:
        estimator = NearestNeighbors(
            n_neighbors=n_neighbors,
            metric=metric,
            p=p,
            metric_params=metric_params,
            n_jobs=n_jobs,
        ).fit(X)

    # Query with the training data itself (self included) or with None
    # (self excluded), depending on include_self / mode.
    query = _query_include_self(estimator._fit_X, include_self, mode)
    return estimator.kneighbors_graph(X=query, n_neighbors=n_neighbors, mode=mode)
126
127
def radius_neighbors_graph(
    X,
    radius,
    *,
    mode="connectivity",
    metric="minkowski",
    p=2,
    metric_params=None,
    include_self=False,
    n_jobs=None,
):
    """Compute the (weighted) graph of Neighbors for points in X.

    Neighborhoods are restricted to the points at a distance lower than
    radius.

    Read more in the :ref:`User Guide <unsupervised_neighbors>`.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features) or BallTree
        Sample data, in the form of a numpy array or a precomputed
        :class:`BallTree`. An already fitted estimator deriving from
        ``RadiusNeighborsMixin`` (e.g. :class:`NearestNeighbors`) is also
        accepted and reused as-is; its ``metric``, ``p`` and
        ``metric_params`` must then equal the values passed here.

    radius : float
        Radius of neighborhoods.

    mode : {'connectivity', 'distance'}, default='connectivity'
        Type of returned matrix: 'connectivity' will return the connectivity
        matrix with ones and zeros, and 'distance' will return the distances
        between neighbors according to the given metric.

    metric : str, default='minkowski'
        The distance metric to use for the tree. The default metric is
        minkowski, and with p=2 is equivalent to the standard Euclidean
        metric.
        For a list of available metrics, see the documentation of
        :class:`~sklearn.metrics.DistanceMetric`.

    p : int, default=2
        Power parameter for the Minkowski metric. When p = 1, this is
        equivalent to using manhattan_distance (l1), and euclidean_distance
        (l2) for p = 2. For arbitrary p, minkowski_distance (l_p) is used.

    metric_params : dict, default=None
        Additional keyword arguments for the metric function.

    include_self : bool or 'auto', default=False
        Whether or not to mark each sample as the first nearest neighbor to
        itself. If 'auto', then True is used for mode='connectivity' and False
        for mode='distance'.

    n_jobs : int, default=None
        The number of parallel jobs to run for neighbors search.
        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
        ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
        for more details. Ignored when X is an already fitted estimator.

    Returns
    -------
    A : sparse matrix of shape (n_samples, n_samples)
        Graph where A[i, j] is assigned the weight of edge that connects
        i to j. The matrix is of CSR format.

    Raises
    ------
    ValueError
        If X is a fitted estimator whose ``metric``, ``p`` or
        ``metric_params`` differ from the arguments.

    Examples
    --------
    >>> X = [[0], [3], [1]]
    >>> from sklearn.neighbors import radius_neighbors_graph
    >>> A = radius_neighbors_graph(X, 1.5, mode='connectivity',
    ...                            include_self=True)
    >>> A.toarray()
    array([[1., 0., 1.],
           [0., 1., 0.],
           [1., 0., 1.]])

    See Also
    --------
    kneighbors_graph
    """
    if isinstance(X, RadiusNeighborsMixin):
        # Reuse the caller's estimator, but refuse silently conflicting
        # metric settings.
        _check_params(X, metric, p, metric_params)
        estimator = X
    else:
        estimator = NearestNeighbors(
            radius=radius,
            metric=metric,
            p=p,
            metric_params=metric_params,
            n_jobs=n_jobs,
        ).fit(X)

    # Query with the training data itself (self included) or with None
    # (self excluded), depending on include_self / mode.
    query = _query_include_self(estimator._fit_X, include_self, mode)
    return estimator.radius_neighbors_graph(query, radius, mode=mode)
220
221
class KNeighborsTransformer(KNeighborsMixin, TransformerMixin, NeighborsBase):
    """Transform X into a (weighted) graph of k nearest neighbors.

    The transformed data is a sparse graph as returned by kneighbors_graph.

    Read more in the :ref:`User Guide <neighbors_transformer>`.

    .. versionadded:: 0.22

    Parameters
    ----------
    mode : {'distance', 'connectivity'}, default='distance'
        Type of returned matrix: 'connectivity' will return the connectivity
        matrix with ones and zeros, and 'distance' will return the distances
        between neighbors according to the given metric.

    n_neighbors : int, default=5
        Number of neighbors for each sample in the transformed sparse graph.
        For compatibility reasons, as each sample is considered as its own
        neighbor, one extra neighbor will be computed when mode == 'distance'.
        In this case, the sparse graph contains (n_neighbors + 1) neighbors.

    algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, default='auto'
        Algorithm used to compute the nearest neighbors:

        - 'ball_tree' will use :class:`BallTree`
        - 'kd_tree' will use :class:`KDTree`
        - 'brute' will use a brute-force search.
        - 'auto' will attempt to decide the most appropriate algorithm
          based on the values passed to :meth:`fit` method.

        Note: fitting on sparse input will override the setting of
        this parameter, using brute force.

    leaf_size : int, default=30
        Leaf size passed to BallTree or KDTree.  This can affect the
        speed of the construction and query, as well as the memory
        required to store the tree.  The optimal value depends on the
        nature of the problem.

    metric : str or callable, default='minkowski'
        Metric to use for distance computation. Any metric from scikit-learn
        or scipy.spatial.distance can be used.

        If metric is a callable function, it is called on each
        pair of instances (rows) and the resulting value recorded. The callable
        should take two arrays as input and return one value indicating the
        distance between them. This works for Scipy's metrics, but is less
        efficient than passing the metric name as a string.

        Distance matrices are not supported.

        Valid values for metric are:

        - from scikit-learn: ['cityblock', 'cosine', 'euclidean', 'l1', 'l2',
          'manhattan']

        - from scipy.spatial.distance: ['braycurtis', 'canberra', 'chebyshev',
          'correlation', 'dice', 'hamming', 'jaccard', 'kulsinski',
          'mahalanobis', 'minkowski', 'rogerstanimoto', 'russellrao',
          'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean',
          'yule']

        See the documentation for scipy.spatial.distance for details on these
        metrics.

    p : int, default=2
        Parameter for the Minkowski metric from
        sklearn.metrics.pairwise.pairwise_distances. When p = 1, this is
        equivalent to using manhattan_distance (l1), and euclidean_distance
        (l2) for p = 2. For arbitrary p, minkowski_distance (l_p) is used.

    metric_params : dict, default=None
        Additional keyword arguments for the metric function.

    n_jobs : int, default=1
        The number of parallel jobs to run for neighbors search.
        If ``-1``, then the number of jobs is set to the number of CPU cores.

    Attributes
    ----------
    effective_metric_ : str or callable
        The distance metric used. It will be same as the `metric` parameter
        or a synonym of it, e.g. 'euclidean' if the `metric` parameter set to
        'minkowski' and `p` parameter set to 2.

    effective_metric_params_ : dict
        Additional keyword arguments for the metric function. For most metrics
        will be same with `metric_params` parameter, but may also contain the
        `p` parameter value if the `effective_metric_` attribute is set to
        'minkowski'.

    n_features_in_ : int
        Number of features seen during :term:`fit`.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.

        .. versionadded:: 1.0

    n_samples_fit_ : int
        Number of samples in the fitted data.

    See Also
    --------
    kneighbors_graph : Compute the weighted graph of k-neighbors for
        points in X.
    RadiusNeighborsTransformer : Transform X into a weighted graph of
        neighbors nearer than a radius.

    Examples
    --------
    >>> from sklearn.datasets import load_wine
    >>> from sklearn.neighbors import KNeighborsTransformer
    >>> X, _ = load_wine(return_X_y=True)
    >>> X.shape
    (178, 13)
    >>> transformer = KNeighborsTransformer(n_neighbors=5, mode='distance')
    >>> X_dist_graph = transformer.fit_transform(X)
    >>> X_dist_graph.shape
    (178, 178)
    """

    def __init__(
        self,
        *,
        mode="distance",
        n_neighbors=5,
        algorithm="auto",
        leaf_size=30,
        metric="minkowski",
        p=2,
        metric_params=None,
        n_jobs=1,
    ):
        # A k-neighbors transformer never uses a radius; the shared base
        # class still takes one, so pass None explicitly.
        super().__init__(
            n_neighbors=n_neighbors,
            radius=None,
            algorithm=algorithm,
            leaf_size=leaf_size,
            metric=metric,
            p=p,
            metric_params=metric_params,
            n_jobs=n_jobs,
        )
        self.mode = mode

    def fit(self, X, y=None):
        """Fit the k-nearest neighbors transformer from the training dataset.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features) or \
                (n_samples, n_samples) if metric='precomputed'
            Training data.
        y : Ignored
            Not used, present for API consistency by convention.

        Returns
        -------
        self : KNeighborsTransformer
            The fitted k-nearest neighbors transformer.
        """
        return self._fit(X)

    def transform(self, X):
        """Compute the (weighted) graph of Neighbors for points in X.

        Parameters
        ----------
        X : array-like of shape (n_samples_transform, n_features)
            Sample data.

        Returns
        -------
        Xt : sparse matrix of shape (n_samples_transform, n_samples_fit)
            Xt[i, j] is assigned the weight of edge that connects i to j.
            Only the neighbors have an explicit value.
            The diagonal is always explicit.
            The matrix is of CSR format.
        """
        check_is_fitted(self)
        # In 'distance' mode each sample is counted as its own neighbor, so
        # one extra neighbor is requested (see the ``n_neighbors`` docs).
        extra_neighbor = 1 if self.mode == "distance" else 0
        return self.kneighbors_graph(
            X, mode=self.mode, n_neighbors=self.n_neighbors + extra_neighbor
        )

    def fit_transform(self, X, y=None):
        """Fit to data, then transform it.

        Equivalent to ``self.fit(X).transform(X)``: the returned graph
        connects the training samples to their own nearest neighbors.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features) or \
                (n_samples, n_samples) if metric='precomputed'
            Training set.

        y : Ignored
            Not used, present for API consistency by convention.

        Returns
        -------
        Xt : sparse matrix of shape (n_samples, n_samples)
            Xt[i, j] is assigned the weight of edge that connects i to j.
            Only the neighbors have an explicit value.
            The diagonal is always explicit.
            The matrix is of CSR format.
        """
        return self.fit(X).transform(X)

    def _more_tags(self):
        return {
            "_xfail_checks": {
                "check_methods_sample_order_invariance": "check is not applicable."
            }
        }
442
443
class RadiusNeighborsTransformer(RadiusNeighborsMixin, TransformerMixin, NeighborsBase):
    """Transform X into a (weighted) graph of neighbors nearer than a radius.

    The transformed data is a sparse graph as returned by
    `radius_neighbors_graph`.

    Read more in the :ref:`User Guide <neighbors_transformer>`.

    .. versionadded:: 0.22

    Parameters
    ----------
    mode : {'distance', 'connectivity'}, default='distance'
        Type of returned matrix: 'connectivity' will return the connectivity
        matrix with ones and zeros, and 'distance' will return the distances
        between neighbors according to the given metric.

    radius : float, default=1.0
        Radius of neighborhood in the transformed sparse graph.

    algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, default='auto'
        Algorithm used to compute the nearest neighbors:

        - 'ball_tree' will use :class:`BallTree`
        - 'kd_tree' will use :class:`KDTree`
        - 'brute' will use a brute-force search.
        - 'auto' will attempt to decide the most appropriate algorithm
          based on the values passed to :meth:`fit` method.

        Note: fitting on sparse input will override the setting of
        this parameter, using brute force.

    leaf_size : int, default=30
        Leaf size passed to BallTree or KDTree.  This can affect the
        speed of the construction and query, as well as the memory
        required to store the tree.  The optimal value depends on the
        nature of the problem.

    metric : str or callable, default='minkowski'
        Metric to use for distance computation. Any metric from scikit-learn
        or scipy.spatial.distance can be used.

        If metric is a callable function, it is called on each
        pair of instances (rows) and the resulting value recorded. The callable
        should take two arrays as input and return one value indicating the
        distance between them. This works for Scipy's metrics, but is less
        efficient than passing the metric name as a string.

        Distance matrices are not supported.

        Valid values for metric are:

        - from scikit-learn: ['cityblock', 'cosine', 'euclidean', 'l1', 'l2',
          'manhattan']

        - from scipy.spatial.distance: ['braycurtis', 'canberra', 'chebyshev',
          'correlation', 'dice', 'hamming', 'jaccard', 'kulsinski',
          'mahalanobis', 'minkowski', 'rogerstanimoto', 'russellrao',
          'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean',
          'yule']

        See the documentation for scipy.spatial.distance for details on these
        metrics.

    p : int, default=2
        Parameter for the Minkowski metric from
        sklearn.metrics.pairwise.pairwise_distances. When p = 1, this is
        equivalent to using manhattan_distance (l1), and euclidean_distance
        (l2) for p = 2. For arbitrary p, minkowski_distance (l_p) is used.

    metric_params : dict, default=None
        Additional keyword arguments for the metric function.

    n_jobs : int, default=1
        The number of parallel jobs to run for neighbors search.
        If ``-1``, then the number of jobs is set to the number of CPU cores.

    Attributes
    ----------
    effective_metric_ : str or callable
        The distance metric used. It will be same as the `metric` parameter
        or a synonym of it, e.g. 'euclidean' if the `metric` parameter set to
        'minkowski' and `p` parameter set to 2.

    effective_metric_params_ : dict
        Additional keyword arguments for the metric function. For most metrics
        will be same with `metric_params` parameter, but may also contain the
        `p` parameter value if the `effective_metric_` attribute is set to
        'minkowski'.

    n_features_in_ : int
        Number of features seen during :term:`fit`.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.

        .. versionadded:: 1.0

    n_samples_fit_ : int
        Number of samples in the fitted data.

    See Also
    --------
    radius_neighbors_graph : Compute the (weighted) graph of neighbors
        within a radius for points in X.
    kneighbors_graph : Compute the weighted graph of k-neighbors for
        points in X.
    KNeighborsTransformer : Transform X into a weighted graph of k
        nearest neighbors.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.datasets import load_wine
    >>> from sklearn.cluster import DBSCAN
    >>> from sklearn.neighbors import RadiusNeighborsTransformer
    >>> from sklearn.pipeline import make_pipeline
    >>> X, _ = load_wine(return_X_y=True)
    >>> estimator = make_pipeline(
    ...     RadiusNeighborsTransformer(radius=42.0, mode='distance'),
    ...     DBSCAN(eps=25.0, metric='precomputed'))
    >>> X_clustered = estimator.fit_predict(X)
    >>> clusters, counts = np.unique(X_clustered, return_counts=True)
    >>> print(counts)
    [ 29  15 111  11  12]
    """

    def __init__(
        self,
        *,
        mode="distance",
        radius=1.0,
        algorithm="auto",
        leaf_size=30,
        metric="minkowski",
        p=2,
        metric_params=None,
        n_jobs=1,
    ):
        # A radius transformer never uses n_neighbors; the shared base class
        # still takes one, so pass None explicitly.
        super().__init__(
            n_neighbors=None,
            radius=radius,
            algorithm=algorithm,
            leaf_size=leaf_size,
            metric=metric,
            p=p,
            metric_params=metric_params,
            n_jobs=n_jobs,
        )
        self.mode = mode

    def fit(self, X, y=None):
        """Fit the radius neighbors transformer from the training dataset.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features) or \
                (n_samples, n_samples) if metric='precomputed'
            Training data.

        y : Ignored
            Not used, present for API consistency by convention.

        Returns
        -------
        self : RadiusNeighborsTransformer
            The fitted radius neighbors transformer.
        """
        return self._fit(X)

    def transform(self, X):
        """Compute the (weighted) graph of Neighbors for points in X.

        Parameters
        ----------
        X : array-like of shape (n_samples_transform, n_features)
            Sample data.

        Returns
        -------
        Xt : sparse matrix of shape (n_samples_transform, n_samples_fit)
            Xt[i, j] is assigned the weight of edge that connects i to j.
            Only the neighbors have an explicit value.
            The diagonal is always explicit.
            The matrix is of CSR format.
        """
        check_is_fitted(self)
        # sort_results=True requests each row's neighbors ordered by
        # distance (presumably for downstream precomputed-graph consumers).
        return self.radius_neighbors_graph(X, mode=self.mode, sort_results=True)

    def fit_transform(self, X, y=None):
        """Fit to data, then transform it.

        Equivalent to ``self.fit(X).transform(X)``: the returned graph
        connects the training samples to their own radius neighbors.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features) or \
                (n_samples, n_samples) if metric='precomputed'
            Training set.

        y : Ignored
            Not used, present for API consistency by convention.

        Returns
        -------
        Xt : sparse matrix of shape (n_samples, n_samples)
            Xt[i, j] is assigned the weight of edge that connects i to j.
            Only the neighbors have an explicit value.
            The diagonal is always explicit.
            The matrix is of CSR format.
        """
        return self.fit(X).transform(X)

    def _more_tags(self):
        return {
            "_xfail_checks": {
                "check_methods_sample_order_invariance": "check is not applicable."
            }
        }
664