"""
Tests for the birch clustering algorithm.
"""

from scipy import sparse
import numpy as np
import pytest

from sklearn.cluster.tests.common import generate_clustered_data
from sklearn.cluster import Birch
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import make_blobs
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import ElasticNet
from sklearn.metrics import pairwise_distances_argmin, v_measure_score

from sklearn.utils._testing import assert_almost_equal
from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import assert_array_almost_equal


def test_n_samples_leaves_roots():
    # Sanity check for the number of samples in leaves and roots
    # (seeded so that a failure is reproducible).
    X, _ = make_blobs(n_samples=10, random_state=0)
    brc = Birch()
    brc.fit(X)
    n_samples_root = sum(sc.n_samples_ for sc in brc.root_.subclusters_)
    n_samples_leaves = sum(
        sc.n_samples_ for leaf in brc._get_leaves() for sc in leaf.subclusters_
    )
    assert n_samples_leaves == X.shape[0]
    assert n_samples_root == X.shape[0]


def test_partial_fit():
    # Test that fit is equivalent to calling partial_fit multiple times
    X, _ = make_blobs(n_samples=100, random_state=0)
    brc = Birch(n_clusters=3)
    brc.fit(X)
    brc_partial = Birch(n_clusters=None)
    brc_partial.partial_fit(X[:50])
    brc_partial.partial_fit(X[50:])
    assert_array_almost_equal(brc_partial.subcluster_centers_, brc.subcluster_centers_)

    # Test that same global labels are obtained after calling partial_fit
    # with None
    brc_partial.set_params(n_clusters=3)
    brc_partial.partial_fit(None)
    assert_array_equal(brc_partial.subcluster_labels_, brc.subcluster_labels_)


def test_birch_predict():
    # Test the predict method predicts the nearest centroid.
    rng = np.random.RandomState(0)
    X = generate_clustered_data(n_clusters=3, n_features=3, n_samples_per_cluster=10)

    # Shuffle every row of X; the row count is taken from the data itself
    # rather than hard-coded.
    shuffle_indices = np.arange(X.shape[0])
    rng.shuffle(shuffle_indices)
    X_shuffle = X[shuffle_indices, :]
    brc = Birch(n_clusters=4, threshold=1.0)
    brc.fit(X_shuffle)
    centroids = brc.subcluster_centers_
    assert_array_equal(brc.labels_, brc.predict(X_shuffle))
    nearest_centroid = pairwise_distances_argmin(X_shuffle, centroids)
    assert_almost_equal(v_measure_score(nearest_centroid, brc.labels_), 1.0)


def test_n_clusters():
    # Test that n_clusters param works properly
    X, _ = make_blobs(n_samples=100, centers=10, random_state=0)
    brc1 = Birch(n_clusters=10)
    brc1.fit(X)
    assert len(brc1.subcluster_centers_) > 10
    assert len(np.unique(brc1.labels_)) == 10

    # Test that n_clusters = Agglomerative Clustering gives
    # the same results.
    gc = AgglomerativeClustering(n_clusters=10)
    brc2 = Birch(n_clusters=gc)
    brc2.fit(X)
    assert_array_equal(brc1.subcluster_labels_, brc2.subcluster_labels_)
    assert_array_equal(brc1.labels_, brc2.labels_)

    # Test that the wrong global clustering step raises an Error.
    clf = ElasticNet()
    brc3 = Birch(n_clusters=clf)
    with pytest.raises(ValueError):
        brc3.fit(X)

    # Test that a small number of clusters raises a warning.
    brc4 = Birch(threshold=10000.0)
    with pytest.warns(ConvergenceWarning):
        brc4.fit(X)


def test_sparse_X():
    # Test that sparse and dense data give same results
    X, _ = make_blobs(n_samples=100, centers=10, random_state=0)
    brc = Birch(n_clusters=10)
    brc.fit(X)

    csr = sparse.csr_matrix(X)
    brc_sparse = Birch(n_clusters=10)
    brc_sparse.fit(csr)

    assert_array_equal(brc.labels_, brc_sparse.labels_)
    assert_array_almost_equal(brc.subcluster_centers_, brc_sparse.subcluster_centers_)


def test_partial_fit_second_call_error_checks():
    # second partial fit calls will error when n_features is not consistent
    # with the first call
    X, y = make_blobs(n_samples=100, random_state=0)
    brc = Birch(n_clusters=3)
    brc.partial_fit(X, y)

    msg = "X has 1 features, but Birch is expecting 2 features"
    with pytest.raises(ValueError, match=msg):
        brc.partial_fit(X[:, [0]], y)


def check_branching_factor(node, branching_factor):
    """Assert that `node` and all of its descendants respect `branching_factor`.

    Recurses through the ``child_`` attribute of every subcluster and checks
    that no visited node holds more than `branching_factor` subclusters.
    """
    subclusters = node.subclusters_
    assert branching_factor >= len(subclusters)
    for cluster in subclusters:
        if cluster.child_:
            check_branching_factor(cluster.child_, branching_factor)


def test_branching_factor():
    # Test that nodes have at max branching_factor number of subclusters
    X, _ = make_blobs(random_state=0)
    branching_factor = 9

    # Purposefully set a low threshold to maximize the subclusters.
    brc = Birch(n_clusters=None, branching_factor=branching_factor, threshold=0.01)
    brc.fit(X)
    check_branching_factor(brc.root_, branching_factor)
    brc = Birch(n_clusters=3, branching_factor=branching_factor, threshold=0.01)
    brc.fit(X)
    check_branching_factor(brc.root_, branching_factor)

    # Raises error when branching_factor is set to one.
    brc = Birch(n_clusters=None, branching_factor=1, threshold=0.01)
    with pytest.raises(ValueError):
        brc.fit(X)


def check_threshold(birch_instance, threshold):
    """Assert that no leaf subcluster has a radius larger than `threshold`.

    Uses the leaf linked list for traversal: starts at
    ``dummy_leaf_.next_leaf_`` and follows ``next_leaf_`` until it is falsy.
    """
    current_leaf = birch_instance.dummy_leaf_.next_leaf_
    while current_leaf:
        subclusters = current_leaf.subclusters_
        for sc in subclusters:
            assert threshold >= sc.radius
        current_leaf = current_leaf.next_leaf_


def test_threshold():
    # Test that the leaf subclusters have a radius no greater than threshold
    X, _ = make_blobs(n_samples=80, centers=4, random_state=0)
    brc = Birch(threshold=0.5, n_clusters=None)
    brc.fit(X)
    check_threshold(brc, 0.5)

    brc = Birch(threshold=5.0, n_clusters=None)
    brc.fit(X)
    check_threshold(brc, 5.0)


def test_birch_n_clusters_long_int():
    # Check that birch supports n_clusters with np.int64 dtype, for instance
    # coming from np.arange. #16484
    X, _ = make_blobs(random_state=0)
    n_clusters = np.int64(5)
    Birch(n_clusters=n_clusters).fit(X)


# TODO: Remove in 1.2
@pytest.mark.parametrize("attribute", ["fit_", "partial_fit_"])
def test_birch_fit_attributes_deprecated(attribute):
    """Test that fit_ and partial_fit_ attributes are deprecated."""
    msg = f"`{attribute}` is deprecated in 1.0 and will be removed in 1.2"
    X, y = make_blobs(n_samples=10, random_state=0)
    brc = Birch().fit(X, y)

    with pytest.warns(FutureWarning, match=msg):
        getattr(brc, attribute)