1"""
2===========================================================
3Hierarchical clustering: structured vs unstructured ward
4===========================================================
5
This example builds a swiss roll dataset and runs
hierarchical clustering on the position of the data points.
8
9For more information, see :ref:`hierarchical_clustering`.
10
11In a first step, the hierarchical clustering is performed without connectivity
12constraints on the structure and is solely based on distance, whereas in
13a second step the clustering is restricted to the k-Nearest Neighbors
14graph: it's a hierarchical clustering with structure prior.
15
16Some of the clusters learned without connectivity constraints do not
respect the structure of the swiss roll and extend across different folds of
the manifold. In contrast, when imposing connectivity constraints,
19the clusters form a nice parcellation of the swiss roll.
20
21"""
22
23# Authors : Vincent Michel, 2010
24#           Alexandre Gramfort, 2010
25#           Gael Varoquaux, 2010
26# License: BSD 3 clause
27
28import time as time
29import numpy as np
30import matplotlib.pyplot as plt
31import mpl_toolkits.mplot3d.axes3d as p3
32from sklearn.cluster import AgglomerativeClustering
33from sklearn.datasets import make_swiss_roll
34
# #############################################################################
# Generate the swiss-roll dataset
n_samples = 1500
noise = 0.05
X, _ = make_swiss_roll(n_samples, noise=noise)
# Thin the roll along its second axis so the folds are easier to see
X[:, 1] = 0.5 * X[:, 1]
42
# #############################################################################
# Compute clustering
# Ward agglomeration with no connectivity prior: merges are driven purely by
# distance in feature space.
print("Compute unstructured hierarchical clustering...")
t0 = time.time()
ward = AgglomerativeClustering(n_clusters=6, linkage="ward").fit(X)
elapsed_time = time.time() - t0
label = ward.labels_
print("Elapsed time: %.2fs" % elapsed_time)
print("Number of points: %i" % label.size)
52
# #############################################################################
# Plot result
# FIX: instantiating `p3.Axes3D(fig)` directly no longer attaches the axes to
# the figure (auto-registration was removed in Matplotlib 3.4+), which leaves
# a blank figure; create the 3D axes via add_subplot instead.
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.view_init(7, -80)
# One scatter call per cluster label so each cluster gets its own color.
for k in np.unique(label):
    ax.scatter(
        X[label == k, 0],
        X[label == k, 1],
        X[label == k, 2],
        color=plt.cm.jet(float(k) / np.max(label + 1)),
        s=20,
        edgecolor="k",
    )
plt.title("Without connectivity constraints (time %.2fs)" % elapsed_time)
68
69
# #############################################################################
# Define the structure A of the data. Here a 10 nearest neighbors
from sklearn.neighbors import kneighbors_graph

n_neighbors = 10
connectivity = kneighbors_graph(X, n_neighbors=n_neighbors, include_self=False)
75
# #############################################################################
# Compute clustering
# Same ward agglomeration, but merges are now restricted to the k-NN graph,
# i.e. hierarchical clustering with a structure prior.
print("Compute structured hierarchical clustering...")
start = time.time()
ward = AgglomerativeClustering(
    n_clusters=6, linkage="ward", connectivity=connectivity
).fit(X)
elapsed_time = time.time() - start
label = ward.labels_
print("Elapsed time: %.2fs" % elapsed_time)
print("Number of points: %i" % label.size)
87
# #############################################################################
# Plot result
# FIX: instantiating `p3.Axes3D(fig)` directly no longer attaches the axes to
# the figure (auto-registration was removed in Matplotlib 3.4+), which leaves
# a blank figure; create the 3D axes via add_subplot instead.
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.view_init(7, -80)
# One scatter call per cluster label so each cluster gets its own color.
for k in np.unique(label):
    ax.scatter(
        X[label == k, 0],
        X[label == k, 1],
        X[label == k, 2],
        color=plt.cm.jet(float(k) / np.max(label + 1)),
        s=20,
        edgecolor="k",
    )
plt.title("With connectivity constraints (time %.2fs)" % elapsed_time)

plt.show()
105