1"""
2Tests for the birch clustering algorithm.
3"""
4
5from scipy import sparse
6import numpy as np
7import pytest
8
9from sklearn.cluster.tests.common import generate_clustered_data
10from sklearn.cluster import Birch
11from sklearn.cluster import AgglomerativeClustering
12from sklearn.datasets import make_blobs
13from sklearn.exceptions import ConvergenceWarning
14from sklearn.linear_model import ElasticNet
15from sklearn.metrics import pairwise_distances_argmin, v_measure_score
16
17from sklearn.utils._testing import assert_almost_equal
18from sklearn.utils._testing import assert_array_equal
19from sklearn.utils._testing import assert_array_almost_equal
20
21
def test_n_samples_leaves_roots():
    # Sanity check: the subclusters of the root node, and the subclusters of
    # all leaves, must each account for every sample exactly once.
    # The data is seeded so that a failure can be reproduced exactly.
    X, _ = make_blobs(n_samples=10, random_state=0)
    brc = Birch()
    brc.fit(X)

    # Total sample count stored in the subclusters of the root node.
    n_samples_root = sum(sc.n_samples_ for sc in brc.root_.subclusters_)
    # Total sample count stored in the subclusters of every leaf node.
    n_samples_leaves = sum(
        sc.n_samples_ for leaf in brc._get_leaves() for sc in leaf.subclusters_
    )
    assert n_samples_leaves == X.shape[0]
    assert n_samples_root == X.shape[0]
33
34
def test_partial_fit():
    # Test that fit is equivalent to calling partial_fit multiple times.
    # The data is seeded so that a failure can be reproduced exactly.
    X, _ = make_blobs(n_samples=100, random_state=0)
    brc = Birch(n_clusters=3)
    brc.fit(X)

    # Feed the same samples, in the same order, in two batches. The global
    # clustering step is disabled here (n_clusters=None); only the subcluster
    # centers are compared at this stage.
    brc_partial = Birch(n_clusters=None)
    brc_partial.partial_fit(X[:50])
    brc_partial.partial_fit(X[50:])
    assert_array_almost_equal(brc_partial.subcluster_centers_, brc.subcluster_centers_)

    # Test that same global labels are obtained after calling partial_fit
    # with None, i.e. without passing any new samples.
    brc_partial.set_params(n_clusters=3)
    brc_partial.partial_fit(None)
    assert_array_equal(brc_partial.subcluster_labels_, brc.subcluster_labels_)
50
51
def test_birch_predict():
    # Test the predict method predicts the label of the nearest centroid.
    rng = np.random.RandomState(0)
    X = generate_clustered_data(n_clusters=3, n_features=3, n_samples_per_cluster=10)

    # Shuffle every row of X (n_clusters * n_samples_per_cluster rows were
    # requested); the row count is read from X instead of being hard-coded.
    shuffle_indices = np.arange(X.shape[0])
    rng.shuffle(shuffle_indices)
    X_shuffle = X[shuffle_indices, :]

    brc = Birch(n_clusters=4, threshold=1.0)
    brc.fit(X_shuffle)
    assert_array_equal(brc.labels_, brc.predict(X_shuffle))

    # pairwise_distances_argmin returns the index of the closest subcluster,
    # not a cluster label. Translate it through subcluster_labels_ so that the
    # comparison with labels_ stays valid even when several subclusters are
    # merged into a single global cluster.
    centroids = brc.subcluster_centers_
    nearest_subcluster = pairwise_distances_argmin(X_shuffle, centroids)
    nearest_centroid = brc.subcluster_labels_[nearest_subcluster]
    assert_almost_equal(v_measure_score(nearest_centroid, brc.labels_), 1.0)
67
68
def test_n_clusters():
    # Test that n_clusters param works properly.
    # The data is seeded so that a failure can be reproduced exactly.
    X, _ = make_blobs(n_samples=100, centers=10, random_state=0)
    brc1 = Birch(n_clusters=10)
    brc1.fit(X)
    # More subclusters than requested clusters are expected, so the global
    # clustering step has real work to do.
    assert len(brc1.subcluster_centers_) > 10
    assert len(np.unique(brc1.labels_)) == 10

    # Test that n_clusters = Agglomerative Clustering gives
    # the same results.
    gc = AgglomerativeClustering(n_clusters=10)
    brc2 = Birch(n_clusters=gc)
    brc2.fit(X)
    assert_array_equal(brc1.subcluster_labels_, brc2.subcluster_labels_)
    assert_array_equal(brc1.labels_, brc2.labels_)

    # Test that the wrong global clustering step raises an Error.
    # ElasticNet is a regressor, not a clusterer.
    clf = ElasticNet()
    brc3 = Birch(n_clusters=clf)
    with pytest.raises(ValueError):
        brc3.fit(X)

    # Test that a small number of clusters raises a warning: the huge
    # threshold is meant to leave fewer subclusters than n_clusters.
    brc4 = Birch(threshold=10000.0)
    with pytest.warns(ConvergenceWarning):
        brc4.fit(X)
95
96
def test_sparse_X():
    # Test that sparse and dense data give same results.
    # The data is seeded so that a failure can be reproduced exactly.
    X, _ = make_blobs(n_samples=100, centers=10, random_state=0)
    brc = Birch(n_clusters=10)
    brc.fit(X)

    # Same values, stored as a CSR matrix.
    csr = sparse.csr_matrix(X)
    brc_sparse = Birch(n_clusters=10)
    brc_sparse.fit(csr)

    assert_array_equal(brc.labels_, brc_sparse.labels_)
    assert_array_almost_equal(brc.subcluster_centers_, brc_sparse.subcluster_centers_)
109
110
def test_partial_fit_second_call_error_checks():
    # A later partial_fit call must be rejected when its number of features
    # differs from the one seen during the first call.
    X, y = make_blobs(n_samples=100)
    estimator = Birch(n_clusters=3)
    estimator.partial_fit(X, y)

    # Keep only the first column: one feature instead of two.
    single_feature_X = X[:, [0]]
    expected_message = "X has 1 features, but Birch is expecting 2 features"
    with pytest.raises(ValueError, match=expected_message):
        estimator.partial_fit(single_feature_X, y)
121
122
def check_branching_factor(node, branching_factor):
    """Assert every node reachable from ``node`` respects ``branching_factor``.

    Each visited node must hold at most ``branching_factor`` subclusters. The
    walk descends into the ``child_`` of every subcluster that has a truthy
    one, visiting nodes in depth-first pre-order.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        entries = current.subclusters_
        assert branching_factor >= len(entries)
        children = [entry.child_ for entry in entries if entry.child_]
        # Push in reverse so the first child is popped (and checked) first.
        pending.extend(reversed(children))
129
130
def test_branching_factor():
    # Test that nodes have at max branching_factor number of subclusters.
    # The data is seeded so that a failure can be reproduced exactly.
    X, _ = make_blobs(random_state=0)
    branching_factor = 9

    # Purposefully set a low threshold to maximize the subclusters. The check
    # is run both without and with the global clustering step.
    for n_clusters in (None, 3):
        brc = Birch(
            n_clusters=n_clusters, branching_factor=branching_factor, threshold=0.01
        )
        brc.fit(X)
        check_branching_factor(brc.root_, branching_factor)

    # Raises error when branching_factor is set to one.
    brc = Birch(n_clusters=None, branching_factor=1, threshold=0.01)
    with pytest.raises(ValueError):
        brc.fit(X)
148
149
def check_threshold(birch_instance, threshold):
    """Assert no leaf subcluster has a radius above ``threshold``.

    The leaves are visited by following the ``next_leaf_`` links, starting
    from the leaf that comes after ``birch_instance.dummy_leaf_``.
    """
    leaf = birch_instance.dummy_leaf_.next_leaf_
    while leaf:
        assert all(threshold >= sc.radius for sc in leaf.subclusters_)
        leaf = leaf.next_leaf_
158
159
def test_threshold():
    # Test that every leaf subcluster has a radius no larger than the
    # threshold, for both a small and a large threshold.
    # The data is seeded so that a failure can be reproduced exactly.
    X, _ = make_blobs(n_samples=80, centers=4, random_state=0)
    for threshold in (0.5, 5.0):
        brc = Birch(threshold=threshold, n_clusters=None)
        brc.fit(X)
        check_threshold(brc, threshold)
170
171
def test_birch_n_clusters_long_int():
    # Non-regression check for #16484: fitting must work when n_clusters is a
    # np.int64 (as obtained, for instance, from np.arange) rather than a
    # built-in int.
    n_clusters = np.int64(5)
    X, _ = make_blobs(random_state=0)
    estimator = Birch(n_clusters=n_clusters)
    estimator.fit(X)
178
179
180# TODO: Remove in 1.2
@pytest.mark.parametrize("attribute", ["fit_", "partial_fit_"])
def test_birch_fit_attributes_deprecated(attribute):
    """Check that reading ``fit_`` or ``partial_fit_`` emits a FutureWarning."""
    X, y = make_blobs(n_samples=10)
    fitted = Birch().fit(X, y)
    msg = f"`{attribute}` is deprecated in 1.0 and will be removed in 1.2"

    # Merely accessing the attribute on a fitted estimator must warn.
    with pytest.warns(FutureWarning, match=msg):
        getattr(fitted, attribute)
190