1"""Hierarchical Agglomerative Clustering
2
3These routines perform some hierarchical agglomerative clustering of some
4input data.
5
6Authors : Vincent Michel, Bertrand Thirion, Alexandre Gramfort,
7          Gael Varoquaux
8License: BSD 3 clause
9"""
import warnings
from heapq import heapify, heappop, heappush, heappushpop

import numpy as np
from scipy import sparse
from scipy.sparse.csgraph import connected_components

from ..base import BaseEstimator, ClusterMixin
from ..metrics.pairwise import paired_distances
from ..metrics import DistanceMetric
from ..metrics._dist_metrics import METRIC_MAPPING
from ..utils import check_array
from ..utils._fast_dict import IntFloatDict
from ..utils.fixes import _astype_copy_false
from ..utils.graph import _fix_connected_components
from ..utils.validation import check_memory

# mypy error: Module 'sklearn.cluster' has no attribute '_hierarchical_fast'
from . import _hierarchical_fast as _hierarchical  # type: ignore
from ._feature_agglomeration import AgglomerationTransform

###############################################################################
# For non fully-connected graphs


def _fix_connectivity(X, connectivity, affinity):
    """
    Fixes the connectivity matrix.

    The different steps are:

    - copies it
    - makes it symmetric
    - converts it to LIL if necessary
    - completes it if necessary.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Feature matrix representing `n_samples` samples to be clustered.

    connectivity : sparse matrix, default=None
        Connectivity matrix. Defines for each sample the neighboring samples
        following a given structure of the data. The matrix is assumed to
        be symmetric and only the upper triangular half is used.
        Default is `None`, i.e., the Ward algorithm is unstructured.

    affinity : {"euclidean", "precomputed"}, default="euclidean"
        Which affinity to use. At the moment `precomputed` and
        ``euclidean`` are supported. ``euclidean`` uses the
        Euclidean distance between points.

    Returns
    -------
    connectivity : sparse matrix
        The fixed connectivity matrix.

    n_connected_components : int
        The number of connected components in the graph.
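
    Examples
    --------
    A minimal sketch (the bridging edges added for disconnected graphs are
    chosen by :func:`sklearn.utils.graph._fix_connected_components`, so the
    exact repaired matrix depends on the data)::

        import numpy as np
        from scipy import sparse

        X = np.array([[0.0], [1.0], [10.0]])
        conn = sparse.lil_matrix((3, 3))
        conn[0, 1] = 1  # sample 2 starts out disconnected
        conn, n_comp = _fix_connectivity(X, conn, affinity="euclidean")
        # n_comp == 2: the count found *before* the completion step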
69    """
70    n_samples = X.shape[0]
71    if connectivity.shape[0] != n_samples or connectivity.shape[1] != n_samples:
72        raise ValueError(
73            "Wrong shape for connectivity matrix: %s when X is %s"
74            % (connectivity.shape, X.shape)
75        )
76
77    # Make the connectivity matrix symmetric:
78    connectivity = connectivity + connectivity.T
79
80    # Convert connectivity matrix to LIL
81    if not sparse.isspmatrix_lil(connectivity):
82        if not sparse.isspmatrix(connectivity):
83            connectivity = sparse.lil_matrix(connectivity)
84        else:
85            connectivity = connectivity.tolil()
86
87    # Compute the number of nodes
88    n_connected_components, labels = connected_components(connectivity)
89
90    if n_connected_components > 1:
91        warnings.warn(
92            "the number of connected components of the "
93            "connectivity matrix is %d > 1. Completing it to avoid "
94            "stopping the tree early." % n_connected_components,
95            stacklevel=2,
96        )
97        # XXX: Can we do without completing the matrix?
98        connectivity = _fix_connected_components(
99            X=X,
100            graph=connectivity,
101            n_connected_components=n_connected_components,
102            component_labels=labels,
103            metric=affinity,
104            mode="connectivity",
105        )
106
107    return connectivity, n_connected_components
108
109
def _single_linkage_tree(
    connectivity,
    n_samples,
    n_nodes,
    n_clusters,
    n_connected_components,
    return_distance,
):
    """
    Perform single linkage clustering on sparse data via the minimum
    spanning tree from scipy.sparse.csgraph, then label the result with
    union-find. The parent array is then generated by walking through
    the tree.
    """
    from scipy.sparse.csgraph import minimum_spanning_tree

    # explicitly cast connectivity to ensure safety
    connectivity = connectivity.astype("float64", **_astype_copy_false(connectivity))

    # Ensure zero distances aren't ignored by setting them to "epsilon"
    epsilon_value = np.finfo(dtype=connectivity.data.dtype).eps
    connectivity.data[connectivity.data == 0] = epsilon_value

    # Use scipy.sparse.csgraph to generate a minimum spanning tree
    mst = minimum_spanning_tree(connectivity.tocsr())

    # Convert the graph to scipy.cluster.hierarchy array format
    mst = mst.tocoo()

    # Undo the epsilon values
    mst.data[mst.data == epsilon_value] = 0

    mst_array = np.vstack([mst.row, mst.col, mst.data]).T

    # Sort edges of the min_spanning_tree by weight
    mst_array = mst_array[np.argsort(mst_array.T[2], kind="mergesort"), :]

    # Convert edge list into standard hierarchical clustering format
    single_linkage_tree = _hierarchical._single_linkage_label(mst_array)
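    # Each row of single_linkage_tree follows the scipy.cluster.hierarchy
    # linkage convention, [left_child, right_child, distance, new_size];
    # only the first three columns are consumed below.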
    children_ = single_linkage_tree[:, :2].astype(int)

    # Compute parents
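    # Merge i (counting from n_samples) becomes the parent of its two
    # children; when the tree is cut early (n_clusters given), merges whose
    # node index falls at or beyond n_nodes are skipped, so `parent`
    # describes the truncated forest.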
    parent = np.arange(n_nodes, dtype=np.intp)
    for i, (left, right) in enumerate(children_, n_samples):
        if n_clusters is not None and i >= n_nodes:
            break
        if left < n_nodes:
            parent[left] = i
        if right < n_nodes:
            parent[right] = i

    if return_distance:
        distances = single_linkage_tree[:, 2]
        return children_, n_connected_components, n_samples, parent, distances
    return children_, n_connected_components, n_samples, parent


###############################################################################
# Hierarchical tree building functions


def ward_tree(X, *, connectivity=None, n_clusters=None, return_distance=False):
    """Ward clustering based on a Feature matrix.

    Recursively merges the pair of clusters that minimally increases
    within-cluster variance.

    The inertia matrix uses a Heapq-based representation.

    This is the structured version, that takes into account some topological
    structure between samples.

    Read more in the :ref:`User Guide <hierarchical_clustering>`.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Feature matrix representing `n_samples` samples to be clustered.

    connectivity : sparse matrix, default=None
        Connectivity matrix. Defines for each sample the neighboring samples
        following a given structure of the data. The matrix is assumed to
        be symmetric and only the upper triangular half is used.
        Default is None, i.e., the Ward algorithm is unstructured.

    n_clusters : int, default=None
        `n_clusters` should be less than `n_samples`. Stop early the
        construction of the tree at `n_clusters`. This is useful to decrease
        computation time if the number of clusters is not small compared to the
        number of samples. In this case, the complete tree is not computed, thus
        the 'children' output is of limited use, and the 'parents' output should
        rather be used. This option is valid only when specifying a connectivity
        matrix.

    return_distance : bool, default=False
        If `True`, return the distance between the clusters.

    Returns
    -------
    children : ndarray of shape (n_nodes-1, 2)
        The children of each non-leaf node. Values less than `n_samples`
        correspond to leaves of the tree which are the original samples.
        A node `i` greater than or equal to `n_samples` is a non-leaf
        node and has children `children_[i - n_samples]`. Alternatively
        at the i-th iteration, children[i][0] and children[i][1]
        are merged to form node `n_samples + i`.

    n_connected_components : int
        The number of connected components in the graph.

    n_leaves : int
        The number of leaves in the tree.

    parents : ndarray of shape (n_nodes,) or None
        The parent of each node. Only returned when a connectivity matrix
        is specified, otherwise `None` is returned.

    distances : ndarray of shape (n_nodes-1,)
        Only returned if `return_distance` is set to `True` (for compatibility).
        The distances between the centers of the nodes. `distances[i]`
        corresponds to a weighted Euclidean distance between
        the nodes `children[i, 0]` and `children[i, 1]`. If the nodes refer to
        leaves of the tree, then `distances[i]` is their unweighted Euclidean
        distance. Distances are updated in the following way
        (from scipy.cluster.hierarchy.linkage):

        The new entry :math:`d(u,v)` is computed as follows,

        .. math::

           d(u,v) = \\sqrt{\\frac{|v|+|s|}
                               {T}d(v,s)^2
                        + \\frac{|v|+|t|}
                               {T}d(v,t)^2
                        - \\frac{|v|}
                               {T}d(s,t)^2}

        where :math:`u` is the newly joined cluster consisting of
        clusters :math:`s` and :math:`t`, :math:`v` is an unused
        cluster in the forest, :math:`T=|v|+|s|+|t|`, and
        :math:`|*|` is the cardinality of its argument. This is also
        known as the incremental algorithm.
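
    Examples
    --------
    A minimal unstructured sketch (the merge order depends on the data, so
    no output is shown)::

        import numpy as np
        from sklearn.cluster import ward_tree

        X = np.array([[0.0], [0.1], [5.0], [5.1]])
        children, n_cc, n_leaves, parents = ward_tree(X)
        # children has shape (n_samples - 1, 2); parents is None in the
        # unstructured case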
251    """
252    X = np.asarray(X)
253    if X.ndim == 1:
254        X = np.reshape(X, (-1, 1))
255    n_samples, n_features = X.shape
256
257    if connectivity is None:
258        from scipy.cluster import hierarchy  # imports PIL
259
260        if n_clusters is not None:
261            warnings.warn(
262                "Partial build of the tree is implemented "
263                "only for structured clustering (i.e. with "
264                "explicit connectivity). The algorithm "
265                "will build the full tree and only "
266                "retain the lower branches required "
267                "for the specified number of clusters",
268                stacklevel=2,
269            )
270        X = np.require(X, requirements="W")
271        out = hierarchy.ward(X)
272        children_ = out[:, :2].astype(np.intp)
273
274        if return_distance:
275            distances = out[:, 2]
276            return children_, 1, n_samples, None, distances
277        else:
278            return children_, 1, n_samples, None
279
280    connectivity, n_connected_components = _fix_connectivity(
281        X, connectivity, affinity="euclidean"
282    )
283    if n_clusters is None:
284        n_nodes = 2 * n_samples - 1
285    else:
286        if n_clusters > n_samples:
287            raise ValueError(
288                "Cannot provide more clusters than samples. "
289                "%i n_clusters was asked, and there are %i "
290                "samples." % (n_clusters, n_samples)
291            )
292        n_nodes = 2 * n_samples - n_clusters
293
294    # create inertia matrix
295    coord_row = []
296    coord_col = []
297    A = []
298    for ind, row in enumerate(connectivity.rows):
299        A.append(row)
300        # We keep only the upper triangular for the moments
301        # Generator expressions are faster than arrays on the following
302        row = [i for i in row if i < ind]
303        coord_row.extend(
304            len(row)
305            * [
306                ind,
307            ]
308        )
309        coord_col.extend(row)
310
311    coord_row = np.array(coord_row, dtype=np.intp, order="C")
312    coord_col = np.array(coord_col, dtype=np.intp, order="C")
313
314    # build moments as a list
315    moments_1 = np.zeros(n_nodes, order="C")
316    moments_1[:n_samples] = 1
317    moments_2 = np.zeros((n_nodes, n_features), order="C")
318    moments_2[:n_samples] = X
319    inertia = np.empty(len(coord_row), dtype=np.float64, order="C")
320    _hierarchical.compute_ward_dist(moments_1, moments_2, coord_row, coord_col, inertia)
321    inertia = list(zip(inertia, coord_row, coord_col))
322    heapify(inertia)
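    # The heap now holds one candidate merge per connectivity edge, keyed by
    # the Ward cost (variance increase) of joining the two endpoints.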

    # prepare the main fields
    parent = np.arange(n_nodes, dtype=np.intp)
    used_node = np.ones(n_nodes, dtype=bool)
    children = []
    if return_distance:
        distances = np.empty(n_nodes - n_samples)

    not_visited = np.empty(n_nodes, dtype=np.int8, order="C")

    # recursive merge loop
    for k in range(n_samples, n_nodes):
        # identify the merge
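        # Lazy deletion: the heap can hold stale edges whose endpoints were
        # already merged away; pop until both endpoints are still live.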
        while True:
            inert, i, j = heappop(inertia)
            if used_node[i] and used_node[j]:
                break
        parent[i], parent[j] = k, k
        children.append((i, j))
        used_node[i] = used_node[j] = False
        if return_distance:  # store inertia value
            distances[k - n_samples] = inert

        # update the moments
        moments_1[k] = moments_1[i] + moments_1[j]
        moments_2[k] = moments_2[i] + moments_2[j]

        # update the structure matrix A and the inertia matrix
        coord_col = []
        not_visited.fill(1)
        not_visited[k] = 0
        _hierarchical._get_parents(A[i], coord_col, parent, not_visited)
        _hierarchical._get_parents(A[j], coord_col, parent, not_visited)
        # List comprehension is faster than a for loop
        [A[col].append(k) for col in coord_col]
        A.append(coord_col)
        coord_col = np.array(coord_col, dtype=np.intp, order="C")
        coord_row = np.empty(coord_col.shape, dtype=np.intp, order="C")
        coord_row.fill(k)
        n_additions = len(coord_row)
        ini = np.empty(n_additions, dtype=np.float64, order="C")

        _hierarchical.compute_ward_dist(moments_1, moments_2, coord_row, coord_col, ini)

        # List comprehension is faster than a for loop
        [heappush(inertia, (ini[idx], k, coord_col[idx])) for idx in range(n_additions)]

    # Separate leaves in children (empty lists up to now)
    n_leaves = n_samples
    # sort children to get consistent output with unstructured version
    children = [c[::-1] for c in children]
    children = np.array(children)  # return numpy array for efficient caching

    if return_distance:
        # 2 is scaling factor to compare w/ unstructured version
        distances = np.sqrt(2.0 * distances)
        return children, n_connected_components, n_leaves, parent, distances
    else:
        return children, n_connected_components, n_leaves, parent


# single, average and complete linkage
def linkage_tree(
    X,
    connectivity=None,
    n_clusters=None,
    linkage="complete",
    affinity="euclidean",
    return_distance=False,
):
    """Linkage agglomerative clustering based on a Feature matrix.

    The inertia matrix uses a Heapq-based representation.

    This is the structured version, that takes into account some topological
    structure between samples.

    Read more in the :ref:`User Guide <hierarchical_clustering>`.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Feature matrix representing `n_samples` samples to be clustered.

    connectivity : sparse matrix, default=None
        Connectivity matrix. Defines for each sample the neighboring samples
        following a given structure of the data. The matrix is assumed to
        be symmetric and only the upper triangular half is used.
        Default is `None`, i.e., the hierarchical clustering algorithm is
        unstructured.

    n_clusters : int, default=None
        Stop early the construction of the tree at `n_clusters`. This is
        useful to decrease computation time if the number of clusters is
        not small compared to the number of samples. In this case, the
        complete tree is not computed, thus the 'children' output is of
        limited use, and the 'parents' output should rather be used.
        This option is valid only when specifying a connectivity matrix.

    linkage : {"average", "complete", "single"}, default="complete"
        Which linkage criterion to use. The linkage criterion determines
        which distance to use between sets of observations.
            - "average" uses the average of the distances of each observation
              of the two sets.
            - "complete" or maximum linkage uses the maximum distances between
              all observations of the two sets.
            - "single" uses the minimum of the distances between all
              observations of the two sets.

    affinity : str or callable, default='euclidean'
        Which metric to use. Can be 'euclidean', 'manhattan', or any
        distance known to `paired_distances` (see `sklearn.metrics.pairwise`).

    return_distance : bool, default=False
        Whether or not to return the distances between the clusters.

    Returns
    -------
    children : ndarray of shape (n_nodes-1, 2)
        The children of each non-leaf node. Values less than `n_samples`
        correspond to leaves of the tree which are the original samples.
        A node `i` greater than or equal to `n_samples` is a non-leaf
        node and has children `children_[i - n_samples]`. Alternatively
        at the i-th iteration, children[i][0] and children[i][1]
        are merged to form node `n_samples + i`.

    n_connected_components : int
        The number of connected components in the graph.

    n_leaves : int
        The number of leaves in the tree.

    parents : ndarray of shape (n_nodes,) or None
        The parent of each node. Only returned when a connectivity matrix
        is specified, otherwise `None` is returned.

    distances : ndarray of shape (n_nodes-1,)
        Returned when `return_distance` is set to `True`.

        distances[i] refers to the distance between children[i][0] and
        children[i][1] when they are merged.

    See Also
    --------
    ward_tree : Hierarchical clustering with ward linkage.
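
    Examples
    --------
    A sketch of an unstructured average-linkage call (this helper lives in
    a private module, so the import path below is internal and may change;
    the merge order depends on the data)::

        import numpy as np
        from sklearn.cluster._agglomerative import linkage_tree

        X = np.array([[0.0], [0.2], [4.0]])
        children, n_cc, n_leaves, parents = linkage_tree(
            X, linkage="average", affinity="euclidean"
        )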
467    """
468    X = np.asarray(X)
469    if X.ndim == 1:
470        X = np.reshape(X, (-1, 1))
471    n_samples, n_features = X.shape
472
473    linkage_choices = {
474        "complete": _hierarchical.max_merge,
475        "average": _hierarchical.average_merge,
476        "single": None,
477    }  # Single linkage is handled differently
478    try:
479        join_func = linkage_choices[linkage]
480    except KeyError as e:
481        raise ValueError(
482            "Unknown linkage option, linkage should be one of %s, but %s was given"
483            % (linkage_choices.keys(), linkage)
484        ) from e
485
486    if affinity == "cosine" and np.any(~np.any(X, axis=1)):
487        raise ValueError("Cosine affinity cannot be used when X contains zero vectors")
488
489    if connectivity is None:
490        from scipy.cluster import hierarchy  # imports PIL
491
492        if n_clusters is not None:
493            warnings.warn(
494                "Partial build of the tree is implemented "
495                "only for structured clustering (i.e. with "
496                "explicit connectivity). The algorithm "
497                "will build the full tree and only "
498                "retain the lower branches required "
499                "for the specified number of clusters",
500                stacklevel=2,
501            )
502
503        if affinity == "precomputed":
504            # for the linkage function of hierarchy to work on precomputed
505            # data, provide as first argument an ndarray of the shape returned
506            # by sklearn.metrics.pairwise_distances.
507            if X.shape[0] != X.shape[1]:
508                raise ValueError(
509                    f"Distance matrix should be square, got matrix of shape {X.shape}"
510                )
511            i, j = np.triu_indices(X.shape[0], k=1)
512            X = X[i, j]
513        elif affinity == "l2":
514            # Translate to something understood by scipy
515            affinity = "euclidean"
516        elif affinity in ("l1", "manhattan"):
517            affinity = "cityblock"
518        elif callable(affinity):
519            X = affinity(X)
520            i, j = np.triu_indices(X.shape[0], k=1)
521            X = X[i, j]
522        if (
523            linkage == "single"
524            and affinity != "precomputed"
525            and not callable(affinity)
526            and affinity in METRIC_MAPPING
527        ):
528
529            # We need the fast cythonized metric from neighbors
530            dist_metric = DistanceMetric.get_metric(affinity)
531
532            # The Cython routines used require contiguous arrays
533            X = np.ascontiguousarray(X, dtype=np.double)
534
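            # Build the MST directly in metric space; each returned row
            # encodes one edge, with the edge weight in the third column
            # (which the mergesort below orders on).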
            mst = _hierarchical.mst_linkage_core(X, dist_metric)
            # Sort edges of the min_spanning_tree by weight
            mst = mst[np.argsort(mst.T[2], kind="mergesort"), :]

            # Convert edge list into standard hierarchical clustering format
            out = _hierarchical.single_linkage_label(mst)
        else:
            out = hierarchy.linkage(X, method=linkage, metric=affinity)
        children_ = out[:, :2].astype(int, copy=False)

        if return_distance:
            distances = out[:, 2]
            return children_, 1, n_samples, None, distances
        return children_, 1, n_samples, None

    connectivity, n_connected_components = _fix_connectivity(
        X, connectivity, affinity=affinity
    )
    connectivity = connectivity.tocoo()
    # Remove the diagonal entries
    diag_mask = connectivity.row != connectivity.col
    connectivity.row = connectivity.row[diag_mask]
    connectivity.col = connectivity.col[diag_mask]
    connectivity.data = connectivity.data[diag_mask]
    del diag_mask

    if affinity == "precomputed":
        distances = X[connectivity.row, connectivity.col].astype(
            "float64", **_astype_copy_false(X)
        )
    else:
        # FIXME We compute all the distances, while we could have only computed
        # the "interesting" distances
        distances = paired_distances(
            X[connectivity.row], X[connectivity.col], metric=affinity
        )
    connectivity.data = distances

    if n_clusters is None:
        n_nodes = 2 * n_samples - 1
    else:
        assert n_clusters <= n_samples
        n_nodes = 2 * n_samples - n_clusters

    if linkage == "single":
        return _single_linkage_tree(
            connectivity,
            n_samples,
            n_nodes,
            n_clusters,
            n_connected_components,
            return_distance,
        )

    if return_distance:
        distances = np.empty(n_nodes - n_samples)
    # create inertia heap and connection matrix
    A = np.empty(n_nodes, dtype=object)
    inertia = list()

    # LIL seems to be the best format to access the rows quickly,
    # without the numpy overhead of slicing CSR indices and data.
    connectivity = connectivity.tolil()
    # We are storing the graph in a list of IntFloatDict
    for ind, (data, row) in enumerate(zip(connectivity.data, connectivity.rows)):
        A[ind] = IntFloatDict(
            np.asarray(row, dtype=np.intp), np.asarray(data, dtype=np.float64)
        )
        # We keep only the upper triangular for the heap
        # Generator expressions are faster than arrays on the following
        inertia.extend(
            _hierarchical.WeightedEdge(d, ind, r) for r, d in zip(row, data) if r < ind
        )
    del connectivity

    heapify(inertia)

    # prepare the main fields
    parent = np.arange(n_nodes, dtype=np.intp)
    used_node = np.ones(n_nodes, dtype=np.intp)
    children = []

    # recursive merge loop
    for k in range(n_samples, n_nodes):
        # identify the merge
        while True:
            edge = heappop(inertia)
            if used_node[edge.a] and used_node[edge.b]:
                break
        i = edge.a
        j = edge.b

        if return_distance:
            # store distances
            distances[k - n_samples] = edge.weight

        parent[i] = parent[j] = k
        children.append((i, j))
        # Keep track of the number of elements per cluster
        n_i = used_node[i]
        n_j = used_node[j]
        used_node[k] = n_i + n_j
        used_node[i] = used_node[j] = False

        # update the structure matrix A and the inertia matrix
        # a clever 'min' or 'max' operation between A[i] and A[j]
        coord_col = join_func(A[i], A[j], used_node, n_i, n_j)
        for col, d in coord_col:
            A[col].append(k, d)
            # Here we use the information from coord_col (containing the
            # distances) to update the heap
            heappush(inertia, _hierarchical.WeightedEdge(d, k, col))
        A[k] = coord_col
        # Clear A[i] and A[j] to save memory
        A[i] = A[j] = 0

    # Separate leaves in children (empty lists up to now)
    n_leaves = n_samples

    # return numpy array for efficient caching
    children = np.array(children)[:, ::-1]

    if return_distance:
        return children, n_connected_components, n_leaves, parent, distances
    return children, n_connected_components, n_leaves, parent


# Matching names to tree-building strategies
def _complete_linkage(*args, **kwargs):
    kwargs["linkage"] = "complete"
    return linkage_tree(*args, **kwargs)


def _average_linkage(*args, **kwargs):
    kwargs["linkage"] = "average"
    return linkage_tree(*args, **kwargs)


def _single_linkage(*args, **kwargs):
    kwargs["linkage"] = "single"
    return linkage_tree(*args, **kwargs)


_TREE_BUILDERS = dict(
    ward=ward_tree,
    complete=_complete_linkage,
    average=_average_linkage,
    single=_single_linkage,
)

###############################################################################
# Functions for cutting hierarchical clustering tree


def _hc_cut(n_clusters, children, n_leaves):
    """Function cutting the ward tree for a given number of clusters.

    Parameters
    ----------
    n_clusters : int or ndarray
        The number of clusters to form.

    children : ndarray of shape (n_nodes-1, 2)
        The children of each non-leaf node. Values less than `n_samples`
        correspond to leaves of the tree which are the original samples.
        A node `i` greater than or equal to `n_samples` is a non-leaf
        node and has children `children_[i - n_samples]`. Alternatively
        at the i-th iteration, children[i][0] and children[i][1]
        are merged to form node `n_samples + i`.

    n_leaves : int
        Number of leaves of the tree.

    Returns
    -------
    labels : ndarray of shape (n_samples,)
        Cluster labels for each point.
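
    Examples
    --------
    A sketch on a three-leaf tree (children rows follow the
    `scipy.cluster.hierarchy` convention; node 3 is the first merge)::

        import numpy as np

        children = np.array([[0, 1], [2, 3]])  # node 3 = {0, 1}, node 4 = root
        labels = _hc_cut(2, children, n_leaves=3)
        # expected grouping: leaves 0 and 1 share a label, leaf 2 is alone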
712    """
713    if n_clusters > n_leaves:
714        raise ValueError(
715            "Cannot extract more clusters than samples: "
716            "%s clusters where given for a tree with %s leaves."
717            % (n_clusters, n_leaves)
718        )
719    # In this function, we store nodes as a heap to avoid recomputing
720    # the max of the nodes: the first element is always the smallest
721    # We use negated indices as heaps work on smallest elements, and we
722    # are interested in largest elements
723    # children[-1] is the root of the tree
724    nodes = [-(max(children[-1]) + 1)]
725    for _ in range(n_clusters - 1):
726        # As we have a heap, nodes[0] is the smallest element
727        these_children = children[-nodes[0] - n_leaves]
728        # Insert the 2 children and remove the largest node
729        heappush(nodes, -these_children[0])
730        heappushpop(nodes, -these_children[1])
731    label = np.zeros(n_leaves, dtype=np.intp)
732    for i, node in enumerate(nodes):
733        label[_hierarchical._hc_get_descendent(-node, children, n_leaves)] = i
734    return label


###############################################################################


class AgglomerativeClustering(ClusterMixin, BaseEstimator):
    """
    Agglomerative Clustering.

    Recursively merges pairs of clusters of sample data, using linkage
    distance.

    Read more in the :ref:`User Guide <hierarchical_clustering>`.

    Parameters
    ----------
    n_clusters : int or None, default=2
        The number of clusters to find. It must be ``None`` if
        ``distance_threshold`` is not ``None``.

    affinity : str or callable, default='euclidean'
        Metric used to compute the linkage. Can be "euclidean", "l1", "l2",
        "manhattan", "cosine", or "precomputed".
        If linkage is "ward", only "euclidean" is accepted.
        If "precomputed", a distance matrix (instead of a similarity matrix)
        is needed as input for the fit method.

    memory : str or object with the joblib.Memory interface, default=None
        Used to cache the output of the computation of the tree.
        By default, no caching is done. If a string is given, it is the
        path to the caching directory.

    connectivity : array-like or callable, default=None
        Connectivity matrix. Defines for each sample the neighboring
        samples following a given structure of the data.
        This can be a connectivity matrix itself or a callable that transforms
        the data into a connectivity matrix, such as derived from
        `kneighbors_graph`. Default is ``None``, i.e., the
        hierarchical clustering algorithm is unstructured.

    compute_full_tree : 'auto' or bool, default='auto'
        Stop early the construction of the tree at ``n_clusters``. This is
        useful to decrease computation time if the number of clusters is not
        small compared to the number of samples. This option is useful only
        when specifying a connectivity matrix. Note also that when varying the
        number of clusters and using caching, it may be advantageous to compute
        the full tree. It must be ``True`` if ``distance_threshold`` is not
        ``None``. By default `compute_full_tree` is "auto", which is equivalent
        to `True` when `distance_threshold` is not `None` or when `n_clusters`
        is less than the maximum of 100 and `0.02 * n_samples`.
        Otherwise, "auto" is equivalent to `False`.

    linkage : {'ward', 'complete', 'average', 'single'}, default='ward'
        Which linkage criterion to use. The linkage criterion determines which
        distance to use between sets of observations. The algorithm will merge
        the pairs of clusters that minimize this criterion.

        - 'ward' minimizes the variance of the clusters being merged.
        - 'average' uses the average of the distances of each observation of
          the two sets.
        - 'complete' or 'maximum' linkage uses the maximum distances between
          all observations of the two sets.
        - 'single' uses the minimum of the distances between all observations
          of the two sets.

        .. versionadded:: 0.20
            Added the 'single' option

    distance_threshold : float, default=None
        The linkage distance threshold above which clusters will not be
        merged. If not ``None``, ``n_clusters`` must be ``None`` and
        ``compute_full_tree`` must be ``True``.

        .. versionadded:: 0.21

    compute_distances : bool, default=False
        Computes distances between clusters even if `distance_threshold` is not
        used. This can be used to make dendrogram visualizations, but introduces
        a computational and memory overhead.

        .. versionadded:: 0.24

    Attributes
    ----------
    n_clusters_ : int
        The number of clusters found by the algorithm. If
        ``distance_threshold=None``, it will be equal to the given
        ``n_clusters``.

    labels_ : ndarray of shape (n_samples,)
        Cluster labels for each point.

    n_leaves_ : int
        Number of leaves in the hierarchical tree.

    n_connected_components_ : int
        The estimated number of connected components in the graph.

        .. versionadded:: 0.21
            ``n_connected_components_`` was added to replace ``n_components_``.

    n_features_in_ : int
        Number of features seen during :term:`fit`.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.

        .. versionadded:: 1.0

    children_ : array-like of shape (n_samples-1, 2)
        The children of each non-leaf node. Values less than `n_samples`
        correspond to leaves of the tree which are the original samples.
        A node `i` greater than or equal to `n_samples` is a non-leaf
        node and has children `children_[i - n_samples]`. Alternatively
        at the i-th iteration, children[i][0] and children[i][1]
        are merged to form node `n_samples + i`.

    distances_ : array-like of shape (n_nodes-1,)
        Distances between nodes in the corresponding place in `children_`.
        Only computed if `distance_threshold` is used or `compute_distances`
        is set to `True`.

    See Also
    --------
    FeatureAgglomeration : Agglomerative clustering but for features instead of
        samples.
    ward_tree : Hierarchical clustering with ward linkage.

    Examples
    --------
    >>> from sklearn.cluster import AgglomerativeClustering
    >>> import numpy as np
    >>> X = np.array([[1, 2], [1, 4], [1, 0],
    ...               [4, 2], [4, 4], [4, 0]])
    >>> clustering = AgglomerativeClustering().fit(X)
    >>> clustering
    AgglomerativeClustering()
    >>> clustering.labels_
    array([1, 1, 1, 0, 0, 0])
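
    A distance-threshold sketch (a hypothetical threshold; the number of
    clusters found depends on the data)::

        clustering = AgglomerativeClustering(
            n_clusters=None, distance_threshold=3.0
        ).fit(X)
        # clustering.n_clusters_ holds the number of clusters found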
876    """
877
878    def __init__(
879        self,
880        n_clusters=2,
881        *,
882        affinity="euclidean",
883        memory=None,
884        connectivity=None,
885        compute_full_tree="auto",
886        linkage="ward",
887        distance_threshold=None,
888        compute_distances=False,
889    ):
890        self.n_clusters = n_clusters
891        self.distance_threshold = distance_threshold
892        self.memory = memory
893        self.connectivity = connectivity
894        self.compute_full_tree = compute_full_tree
895        self.linkage = linkage
896        self.affinity = affinity
897        self.compute_distances = compute_distances
898
899    def fit(self, X, y=None):
900        """Fit the hierarchical clustering from features, or distance matrix.
901
902        Parameters
903        ----------
904        X : array-like, shape (n_samples, n_features) or \
905                (n_samples, n_samples)
906            Training instances to cluster, or distances between instances if
907            ``affinity='precomputed'``.
908
909        y : Ignored
910            Not used, present here for API consistency by convention.
911
912        Returns
913        -------
914        self : object
915            Returns the fitted instance.
916        """
917        X = self._validate_data(X, ensure_min_samples=2, estimator=self)
918        return self._fit(X)
919
920    def _fit(self, X):
921        """Fit without validation
922
923        Parameters
924        ----------
925        X : ndarray of shape (n_samples, n_features) or (n_samples, n_samples)
926            Training instances to cluster, or distances between instances if
927            ``affinity='precomputed'``.
928
929        Returns
930        -------
931        self : object
932            Returns the fitted instance.
933        """
934        memory = check_memory(self.memory)
935
936        if self.n_clusters is not None and self.n_clusters <= 0:
937            raise ValueError(
938                "n_clusters should be an integer greater than 0. %s was provided."
939                % str(self.n_clusters)
940            )
941
942        if not ((self.n_clusters is None) ^ (self.distance_threshold is None)):
943            raise ValueError(
944                "Exactly one of n_clusters and "
945                "distance_threshold has to be set, and the other "
946                "needs to be None."
947            )
948
949        if self.distance_threshold is not None and not self.compute_full_tree:
950            raise ValueError(
951                "compute_full_tree must be True if distance_threshold is set."
952            )
953
954        if self.linkage == "ward" and self.affinity != "euclidean":
955            raise ValueError(
956                "%s was provided as affinity. Ward can only "
957                "work with euclidean distances." % (self.affinity,)
958            )
959
960        if self.linkage not in _TREE_BUILDERS:
961            raise ValueError(
962                "Unknown linkage type %s. Valid options are %s"
963                % (self.linkage, _TREE_BUILDERS.keys())
964            )
965        tree_builder = _TREE_BUILDERS[self.linkage]
966
967        connectivity = self.connectivity
968        if self.connectivity is not None:
969            if callable(self.connectivity):
970                connectivity = self.connectivity(X)
971            connectivity = check_array(
972                connectivity, accept_sparse=["csr", "coo", "lil"]
973            )
974
975        n_samples = len(X)
976        compute_full_tree = self.compute_full_tree
977        if self.connectivity is None:
978            compute_full_tree = True
979        if compute_full_tree == "auto":
980            if self.distance_threshold is not None:
981                compute_full_tree = True
982            else:
983                # Early stopping is likely to give a speed up only for
984                # a large number of clusters. The actual threshold
985                # implemented here is heuristic
986                compute_full_tree = self.n_clusters < max(100, 0.02 * n_samples)
987        n_clusters = self.n_clusters
988        if compute_full_tree:
989            n_clusters = None
990
991        # Construct the tree
992        kwargs = {}
993        if self.linkage != "ward":
994            kwargs["linkage"] = self.linkage
995            kwargs["affinity"] = self.affinity
996
997        distance_threshold = self.distance_threshold
998
999        return_distance = (distance_threshold is not None) or self.compute_distances
1000
1001        out = memory.cache(tree_builder)(
1002            X,
1003            connectivity=connectivity,
1004            n_clusters=n_clusters,
1005            return_distance=return_distance,
1006            **kwargs,
1007        )
1008        (self.children_, self.n_connected_components_, self.n_leaves_, parents) = out[
1009            :4
1010        ]
1011
1012        if return_distance:
1013            self.distances_ = out[-1]
1014
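        # Each merge whose distance is >= distance_threshold is effectively
        # undone; undoing m of the n_samples - 1 merges of a full tree
        # leaves m + 1 clusters.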
        if self.distance_threshold is not None:  # distance_threshold is used
            self.n_clusters_ = (
                np.count_nonzero(self.distances_ >= distance_threshold) + 1
            )
        else:  # n_clusters is used
            self.n_clusters_ = self.n_clusters

        # Cut the tree
        if compute_full_tree:
            self.labels_ = _hc_cut(self.n_clusters_, self.children_, self.n_leaves_)
        else:
            labels = _hierarchical.hc_get_heads(parents, copy=False)
            # copy to avoid holding a reference on the original array
            labels = np.copy(labels[:n_samples])
            # Reassign cluster numbers
            self.labels_ = np.searchsorted(np.unique(labels), labels)
        return self

    def fit_predict(self, X, y=None):
        """Fit and return the result of each sample's clustering assignment.

        In addition to fitting, this method also returns the result of the
        clustering assignment for each sample in the training set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features) or \
                (n_samples, n_samples)
            Training instances to cluster, or distances between instances if
            ``affinity='precomputed'``.

        y : Ignored
            Not used, present here for API consistency by convention.

        Returns
        -------
        labels : ndarray of shape (n_samples,)
            Cluster labels.
        """
        return super().fit_predict(X, y)


class FeatureAgglomeration(AgglomerativeClustering, AgglomerationTransform):
    """Agglomerate features.

    Recursively merges pairs of clusters of features.

    Read more in the :ref:`User Guide <hierarchical_clustering>`.

    Parameters
    ----------
    n_clusters : int, default=2
        The number of clusters to find. It must be ``None`` if
        ``distance_threshold`` is not ``None``.

    affinity : str or callable, default='euclidean'
        Metric used to compute the linkage. Can be "euclidean", "l1", "l2",
        "manhattan", "cosine", or "precomputed".
        If linkage is "ward", only "euclidean" is accepted.

    memory : str or object with the joblib.Memory interface, default=None
        Used to cache the output of the computation of the tree.
        By default, no caching is done. If a string is given, it is the
        path to the caching directory.

    connectivity : array-like or callable, default=None
        Connectivity matrix. Defines for each feature the neighboring
        features following a given structure of the data.
        This can be a connectivity matrix itself or a callable that transforms
        the data into a connectivity matrix, such as derived from
        `kneighbors_graph`. Default is `None`, i.e., the
        hierarchical clustering algorithm is unstructured.

    compute_full_tree : 'auto' or bool, default='auto'
        Stop early the construction of the tree at `n_clusters`. This is useful
        to decrease computation time if the number of clusters is not small
        compared to the number of features. This option is useful only when
        specifying a connectivity matrix. Note also that when varying the
        number of clusters and using caching, it may be advantageous to compute
        the full tree. It must be ``True`` if ``distance_threshold`` is not
        ``None``. By default `compute_full_tree` is "auto", which is equivalent
        to `True` when `distance_threshold` is not `None` or when `n_clusters`
        is less than the maximum of 100 and `0.02 * n_samples`.
        Otherwise, "auto" is equivalent to `False`.

    linkage : {"ward", "complete", "average", "single"}, default="ward"
        Which linkage criterion to use. The linkage criterion determines which
        distance to use between sets of features. The algorithm will merge
        the pairs of clusters that minimize this criterion.

        - "ward" minimizes the variance of the clusters being merged.
        - "complete" or maximum linkage uses the maximum distances between
          all features of the two sets.
        - "average" uses the average of the distances of each feature of
          the two sets.
        - "single" uses the minimum of the distances between all features
          of the two sets.

    pooling_func : callable, default=np.mean
        This combines the values of agglomerated features into a single
        value, and should accept an array of shape (M, N) and the keyword
        argument `axis=1`, and reduce it to an array of shape (M,).

    distance_threshold : float, default=None
        The linkage distance threshold above which clusters will not be
        merged. If not ``None``, ``n_clusters`` must be ``None`` and
        ``compute_full_tree`` must be ``True``.

        .. versionadded:: 0.21

    compute_distances : bool, default=False
        Computes distances between clusters even if `distance_threshold` is not
        used. This can be used to make dendrogram visualizations, but introduces
        a computational and memory overhead.

        .. versionadded:: 0.24

    Attributes
    ----------
    n_clusters_ : int
        The number of clusters found by the algorithm. If
        ``distance_threshold=None``, it will be equal to the given
        ``n_clusters``.

    labels_ : array-like of shape (n_features,)
        Cluster labels for each feature.

    n_leaves_ : int
        Number of leaves in the hierarchical tree.

    n_connected_components_ : int
        The estimated number of connected components in the graph.

        .. versionadded:: 0.21
            ``n_connected_components_`` was added to replace ``n_components_``.

    n_features_in_ : int
        Number of features seen during :term:`fit`.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.

        .. versionadded:: 1.0

    children_ : array-like of shape (n_nodes-1, 2)
        The children of each non-leaf node. Values less than `n_features`
        correspond to leaves of the tree which are the original features.
        A node `i` greater than or equal to `n_features` is a non-leaf
        node and has children `children_[i - n_features]`. Alternatively
        at the i-th iteration, children[i][0] and children[i][1]
        are merged to form node `n_features + i`.

    distances_ : array-like of shape (n_nodes-1,)
        Distances between nodes in the corresponding place in `children_`.
        Only computed if `distance_threshold` is used or `compute_distances`
        is set to `True`.

    See Also
    --------
    AgglomerativeClustering : Agglomerative clustering of samples instead of
        features.
    ward_tree : Hierarchical clustering with ward linkage.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn import datasets, cluster
    >>> digits = datasets.load_digits()
    >>> images = digits.images
    >>> X = np.reshape(images, (len(images), -1))
    >>> agglo = cluster.FeatureAgglomeration(n_clusters=32)
    >>> agglo.fit(X)
    FeatureAgglomeration(n_clusters=32)
    >>> X_reduced = agglo.transform(X)
    >>> X_reduced.shape
    (1797, 32)
    """

    def __init__(
        self,
        n_clusters=2,
        *,
        affinity="euclidean",
        memory=None,
        connectivity=None,
        compute_full_tree="auto",
        linkage="ward",
        pooling_func=np.mean,
        distance_threshold=None,
        compute_distances=False,
    ):
        super().__init__(
            n_clusters=n_clusters,
            memory=memory,
            connectivity=connectivity,
            compute_full_tree=compute_full_tree,
            linkage=linkage,
            affinity=affinity,
            distance_threshold=distance_threshold,
            compute_distances=compute_distances,
        )
        self.pooling_func = pooling_func

    def fit(self, X, y=None):
        """Fit the hierarchical clustering on the data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The data.

        y : Ignored
            Not used, present here for API consistency by convention.

        Returns
        -------
        self : object
            Returns the transformer.
        """
        X = self._validate_data(X, ensure_min_features=2, estimator=self)
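        # Feature agglomeration clusters the columns of X, so transpose to
        # reuse the sample-oriented machinery of AgglomerativeClustering.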
        super()._fit(X.T)
        return self

    @property
    def fit_predict(self):
        """Fit and return the result of each sample's clustering assignment."""
        raise AttributeError
