"""
====================================================================
Comparison of the K-Means and MiniBatchKMeans clustering algorithms
====================================================================

We want to compare the performance of MiniBatchKMeans and KMeans:
MiniBatchKMeans is faster, but gives slightly different results (see
:ref:`mini_batch_kmeans`).

We will cluster a set of data, first with KMeans and then with
MiniBatchKMeans, and plot the results.
We will also plot the points that are labelled differently by the two
algorithms.

"""

import time

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import linear_sum_assignment

from sklearn.cluster import MiniBatchKMeans, KMeans
from sklearn.metrics.pairwise import pairwise_distances, pairwise_distances_argmin
from sklearn.datasets import make_blobs

# #############################################################################
# Generate sample data
np.random.seed(0)

batch_size = 45
centers = [[1, 1], [-1, -1], [1, -1]]
n_clusters = len(centers)
X, labels_true = make_blobs(n_samples=3000, centers=centers, cluster_std=0.7)

# #############################################################################
# Compute clustering with KMeans

k_means = KMeans(init="k-means++", n_clusters=n_clusters, n_init=10)
# perf_counter is a monotonic, high-resolution clock intended for measuring
# elapsed time; time.time() can jump if the system clock is adjusted.
t0 = time.perf_counter()
k_means.fit(X)
t_batch = time.perf_counter() - t0

# #############################################################################
# Compute clustering with MiniBatchKMeans

mbk = MiniBatchKMeans(
    init="k-means++",
    n_clusters=n_clusters,
    batch_size=batch_size,
    n_init=10,
    max_no_improvement=10,
    verbose=0,
)
t0 = time.perf_counter()
mbk.fit(X)
t_mini_batch = time.perf_counter() - t0

# #############################################################################
# Plot result

fig = plt.figure(figsize=(8, 3))
fig.subplots_adjust(left=0.02, right=0.98, bottom=0.05, top=0.9)
colors = ["#4EACC5", "#FF9C34", "#4E9A06"]

# We want to have the same colors for the same cluster from the
# MiniBatchKMeans and the KMeans algorithm. Pair each KMeans center with a
# distinct MiniBatchKMeans center so that the total distance between paired
# centers is minimal. This is always a one-to-one matching, so two KMeans
# clusters can never be given the same MiniBatchKMeans center (and colour).
k_means_cluster_centers = k_means.cluster_centers_
center_distances = pairwise_distances(k_means_cluster_centers, mbk.cluster_centers_)
_, order = linear_sum_assignment(center_distances)
mbk_means_cluster_centers = mbk.cluster_centers_[order]

# Relabel every sample against the (reordered) centers so that label k means
# the same colour in both panels.
k_means_labels = pairwise_distances_argmin(X, k_means_cluster_centers)
mbk_means_labels = pairwise_distances_argmin(X, mbk_means_cluster_centers)

# KMeans
ax = fig.add_subplot(1, 3, 1)
for k, col in zip(range(n_clusters), colors):
    my_members = k_means_labels == k
    cluster_center = k_means_cluster_centers[k]
    ax.plot(X[my_members, 0], X[my_members, 1], "w", markerfacecolor=col, marker=".")
    ax.plot(
        cluster_center[0],
        cluster_center[1],
        "o",
        markerfacecolor=col,
        markeredgecolor="k",
        markersize=6,
    )
ax.set_title("KMeans")
ax.set_xticks(())
ax.set_yticks(())
ax.text(-3.5, 1.8, "train time: %.2fs\ninertia: %f" % (t_batch, k_means.inertia_))

# MiniBatchKMeans
ax = fig.add_subplot(1, 3, 2)
for k, col in zip(range(n_clusters), colors):
    my_members = mbk_means_labels == k
    cluster_center = mbk_means_cluster_centers[k]
    ax.plot(X[my_members, 0], X[my_members, 1], "w", markerfacecolor=col, marker=".")
    ax.plot(
        cluster_center[0],
        cluster_center[1],
        "o",
        markerfacecolor=col,
        markeredgecolor="k",
        markersize=6,
    )
ax.set_title("MiniBatchKMeans")
ax.set_xticks(())
ax.set_yticks(())
ax.text(-3.5, 1.8, "train time: %.2fs\ninertia: %f" % (t_mini_batch, mbk.inertia_))

# A sample is "different" when the two colour-matched clusterings put it in
# different clusters. Comparing the label arrays directly works for any
# number of clusters.
different = k_means_labels != mbk_means_labels
identic = np.logical_not(different)

ax = fig.add_subplot(1, 3, 3)
ax.plot(X[identic, 0], X[identic, 1], "w", markerfacecolor="#bbbbbb", marker=".")
ax.plot(X[different, 0], X[different, 1], "w", markerfacecolor="m", marker=".")
ax.set_title("Difference")
ax.set_xticks(())
ax.set_yticks(())

plt.show()