1"""
2====================================================================
3Comparison of the K-Means and MiniBatchKMeans clustering algorithms
4====================================================================
5
6We want to compare the performance of the MiniBatchKMeans and KMeans:
7the MiniBatchKMeans is faster, but gives slightly different results (see
8:ref:`mini_batch_kmeans`).
9
10We will cluster a set of data, first with KMeans and then with
11MiniBatchKMeans, and plot the results.
12We will also plot the points that are labelled differently between the two
13algorithms.
14
15"""
16
17import time
18
19import numpy as np
20import matplotlib.pyplot as plt
21
22from sklearn.cluster import MiniBatchKMeans, KMeans
23from sklearn.metrics.pairwise import pairwise_distances_argmin
24from sklearn.datasets import make_blobs
25
# #############################################################################
# Generate sample data

# Blob centers; the number of clusters is derived from them so the two can
# never disagree.
centers = [[1, 1], [-1, -1], [1, -1]]
n_clusters = len(centers)

# Mini-batch size handed to MiniBatchKMeans further down.
batch_size = 45

# Seed NumPy's global RNG right before sampling so the generated data is
# reproducible from run to run.
np.random.seed(0)
X, labels_true = make_blobs(
    n_samples=3000,
    centers=centers,
    cluster_std=0.7,
)
34
# #############################################################################
# Compute clustering with KMeans

# Use ``n_clusters`` (derived from ``centers``) instead of a literal so the
# model stays in sync with the generated data and the plotting loops.
k_means = KMeans(init="k-means++", n_clusters=n_clusters, n_init=10)

# ``perf_counter`` is the clock intended for measuring short elapsed
# intervals; unlike ``time.time`` it is not affected by system clock changes.
t0 = time.perf_counter()
k_means.fit(X)
# NOTE: despite its name, ``t_batch`` is the fit time of the full-batch KMeans.
t_batch = time.perf_counter() - t0
42
# #############################################################################
# Compute clustering with MiniBatchKMeans

# Use ``n_clusters`` (derived from ``centers``) instead of a literal so the
# model stays in sync with the generated data and the plotting loops.
mbk = MiniBatchKMeans(
    init="k-means++",
    n_clusters=n_clusters,
    batch_size=batch_size,
    n_init=10,
    max_no_improvement=10,
    verbose=0,
)

# ``perf_counter`` is the clock intended for measuring short elapsed
# intervals; unlike ``time.time`` it is not affected by system clock changes.
t0 = time.perf_counter()
mbk.fit(X)
t_mini_batch = time.perf_counter() - t0
57
# #############################################################################
# Plot result

# One colour per cluster, shared by the KMeans and MiniBatchKMeans panels.
colors = ["#4EACC5", "#FF9C34", "#4E9A06"]

fig = plt.figure(figsize=(8, 3))
fig.subplots_adjust(left=0.02, right=0.98, bottom=0.05, top=0.9)

# The two estimators number their clusters independently. To draw matching
# clusters in the same colour, look up for every KMeans center the index of
# the nearest MiniBatchKMeans center and reorder the latter accordingly.
k_means_cluster_centers = k_means.cluster_centers_
closest_mbk_center = pairwise_distances_argmin(
    k_means_cluster_centers, mbk.cluster_centers_
)
mbk_means_cluster_centers = mbk.cluster_centers_[closest_mbk_center]

# Assign each sample to its closest center under this common ordering.
k_means_labels = pairwise_distances_argmin(X, k_means_cluster_centers)
mbk_means_labels = pairwise_distances_argmin(X, mbk_means_cluster_centers)
74
def _plot_clustering(
    ax, X, labels, cluster_centers, colors, title, train_time, inertia
):
    """Draw one clustering result (samples, centers and fit statistics) on ``ax``.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    X : array indexed as ``X[mask, 0]`` / ``X[mask, 1]``
        Samples; the first two columns are used as plot coordinates.
    labels : array
        Cluster index of every sample in ``X``.
    cluster_centers : sequence of 2D points
        Center of each cluster; position ``k`` belongs to label ``k``.
    colors : sequence of colour specs
        Marker face colour for each cluster, in the same order as the centers.
    title : str
        Title of the panel.
    train_time : float
        Fit duration in seconds, shown in the panel.
    inertia : float
        Inertia of the fitted model, shown in the panel.
    """
    for k, (center, col) in enumerate(zip(cluster_centers, colors)):
        members = labels == k
        # The "w" format draws the connecting line in white; the visible part
        # of the plot is the markers filled with the cluster colour.
        ax.plot(X[members, 0], X[members, 1], "w", markerfacecolor=col, marker=".")
        ax.plot(
            center[0],
            center[1],
            "o",
            markerfacecolor=col,
            markeredgecolor="k",
            markersize=6,
        )
    ax.set_title(title)
    ax.set_xticks(())
    ax.set_yticks(())
    # Write on the given axes explicitly instead of relying on pyplot's
    # implicit "current axes" state.
    ax.text(-3.5, 1.8, "train time: %.2fs\ninertia: %f" % (train_time, inertia))


# KMeans
ax = fig.add_subplot(1, 3, 1)
_plot_clustering(
    ax,
    X,
    k_means_labels,
    k_means_cluster_centers,
    colors,
    "KMeans",
    t_batch,
    k_means.inertia_,
)

# MiniBatchKMeans
ax = fig.add_subplot(1, 3, 2)
_plot_clustering(
    ax,
    X,
    mbk_means_labels,
    mbk_means_cluster_centers,
    colors,
    "MiniBatchKMeans",
    t_mini_batch,
    mbk.inertia_,
)
112
# Flag the samples that the two algorithms assign to different clusters.
# Both label arrays index the same (paired) ordering of centers, so an
# element-wise comparison is sufficient; this avoids seeding an all-False mask
# with a magic label value that would break once that label actually exists.
different = k_means_labels != mbk_means_labels
identic = np.logical_not(different)

ax = fig.add_subplot(1, 3, 3)
# Samples on which both algorithms agree are drawn in grey, disagreements in
# magenta.
ax.plot(X[identic, 0], X[identic, 1], "w", markerfacecolor="#bbbbbb", marker=".")
ax.plot(X[different, 0], X[different, 1], "w", markerfacecolor="m", marker=".")
ax.set_title("Difference")
ax.set_xticks(())
ax.set_yticks(())

plt.show()
128