"""
===========================================================
Hierarchical clustering: structured vs unstructured ward
===========================================================

This example builds a swiss roll dataset and runs hierarchical clustering
on the positions of its points.

For more information, see :ref:`hierarchical_clustering`.

In a first step, the hierarchical clustering is performed without connectivity
constraints on the structure and is solely based on distance, whereas in
a second step the clustering is restricted to the k-Nearest Neighbors
graph: it's a hierarchical clustering with structure prior.

Some of the clusters learned without connectivity constraints do not
respect the structure of the swiss roll and extend across different folds of
the manifold. On the contrary, when imposing connectivity constraints,
the clusters form a nice parcellation of the swiss roll.

"""

# Authors : Vincent Michel, 2010
#           Alexandre Gramfort, 2010
#           Gael Varoquaux, 2010
# License: BSD 3 clause

import time

import matplotlib.pyplot as plt

# Imported for its side effect: on older Matplotlib versions this import is
# what registers the "3d" projection used by ``add_subplot`` below.
import mpl_toolkits.mplot3d.axes3d as p3  # noqa: F401
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import make_swiss_roll
from sklearn.neighbors import kneighbors_graph

# Number of clusters requested from both the unstructured and structured runs.
N_CLUSTERS = 6


def fit_timed_ward(X, connectivity=None, n_clusters=N_CLUSTERS):
    """Fit a ward-linkage AgglomerativeClustering on ``X`` and time the fit.

    Prints the elapsed time and the number of clustered points.

    Parameters
    ----------
    X : array-like
        Points to cluster (passed straight to ``AgglomerativeClustering.fit``).
    connectivity : sparse matrix or None, default=None
        Optional connectivity graph restricting which samples may be merged;
        ``None`` gives an unstructured (distance-only) clustering.
    n_clusters : int, default=N_CLUSTERS
        Number of clusters to find.

    Returns
    -------
    ward : AgglomerativeClustering
        The fitted estimator (cluster assignments in ``ward.labels_``).
    elapsed : float
        Wall-clock duration of the fit, in seconds.
    """
    start = time.perf_counter()  # monotonic clock meant for measuring intervals
    ward = AgglomerativeClustering(
        n_clusters=n_clusters, connectivity=connectivity, linkage="ward"
    ).fit(X)
    elapsed = time.perf_counter() - start
    print("Elapsed time: %.2fs" % elapsed)
    print("Number of points: %i" % ward.labels_.size)
    return ward, elapsed


def plot_clusters(X, labels, title):
    """Draw a 3-D scatter of ``X`` with one jet-colormap color per cluster.

    Parameters
    ----------
    X : ndarray
        Point coordinates; columns 0, 1 and 2 are used as the x, y, z axes.
    labels : ndarray of int
        Cluster label of each row of ``X``.
    title : str
        Title shown above the axes.

    Returns
    -------
    fig, ax
        The created figure and its 3-D axes.
    """
    fig = plt.figure()
    # ``Axes3D(fig)`` stopped attaching itself to the figure (deprecated in
    # Matplotlib 3.4, removed in 3.6), which produced an empty plot; creating
    # the axes through ``add_subplot`` works on every version.
    ax = fig.add_subplot(111, projection="3d")
    ax.view_init(7, -80)
    # Map label k to k / n_labels so every cluster gets a distinct color.
    n_labels = np.max(labels) + 1
    for cluster in np.unique(labels):
        members = labels == cluster
        ax.scatter(
            X[members, 0],
            X[members, 1],
            X[members, 2],
            color=plt.cm.jet(float(cluster) / n_labels),
            s=20,
            edgecolor="k",
        )
    ax.set_title(title)
    return fig, ax


# #############################################################################
# Generate data (swiss roll dataset)
n_samples = 1500
noise = 0.05
# Fixed seed so the rendered figures are reproducible between runs.
X, _ = make_swiss_roll(n_samples, noise=noise, random_state=0)
# Make it thinner
X[:, 1] *= 0.5

# #############################################################################
# Compute clustering
print("Compute unstructured hierarchical clustering...")
ward, elapsed_time = fit_timed_ward(X)
label = ward.labels_

# #############################################################################
# Plot result
plot_clusters(
    X, label, "Without connectivity constraints (time %.2fs)" % elapsed_time
)


# #############################################################################
# Define the structure A of the data. Here, a 10-nearest-neighbors graph.
connectivity = kneighbors_graph(X, n_neighbors=10, include_self=False)

# #############################################################################
# Compute clustering
print("Compute structured hierarchical clustering...")
ward, elapsed_time = fit_timed_ward(X, connectivity=connectivity)
label = ward.labels_

# #############################################################################
# Plot result
plot_clusters(
    X, label, "With connectivity constraints (time %.2fs)" % elapsed_time
)

plt.show()