1# -*- coding: utf-8 -*-
2# Author: Leland McInnes <leland.mcinnes@gmail.com>
3#
4# License: BSD 3 clause
5
6import numpy as np
7
8from scipy.cluster.hierarchy import dendrogram
9from sklearn.manifold import TSNE
10from sklearn.decomposition import PCA
11from sklearn.metrics import pairwise_distances
12from warnings import warn
13from ._hdbscan_tree import compute_stability, labelling_at_cut, recurse_leaf_dfs
14
# Positions inside the 4-element [left, right, bottom, top] bounds list that
# CondensedTree.get_plot_data records per cluster under 'cluster_bounds'.
CB_LEFT, CB_RIGHT, CB_BOTTOM, CB_TOP = range(4)
19
20
21def _bfs_from_cluster_tree(tree, bfs_root):
22    """
23    Perform a breadth first search on a tree in condensed tree format
24    """
25
26    result = []
27    to_process = [bfs_root]
28
29    while to_process:
30        result.extend(to_process)
31        to_process = tree['child'][np.in1d(tree['parent'], to_process)].tolist()
32
33    return result
34
35def _recurse_leaf_dfs(cluster_tree, current_node):
36    children = cluster_tree[cluster_tree['parent'] == current_node]['child']
37    if len(children) == 0:
38        return [current_node,]
39    else:
40        return sum([recurse_leaf_dfs(cluster_tree, child) for child in children], [])
41
def _get_leaves(condensed_tree):
    """Return the leaf clusters of a condensed tree.

    Only edges whose ``child_size`` exceeds one are kept. When no such edge
    exists, the smallest ``parent`` id (the root) is returned alone;
    otherwise the leaves are collected depth-first from the smallest
    ``parent`` id among the kept edges.
    """
    cluster_edges = condensed_tree[condensed_tree['child_size'] > 1]
    if not cluster_edges.shape[0]:
        # Return the only cluster, the root
        return [condensed_tree['parent'].min()]
    return _recurse_leaf_dfs(cluster_edges, cluster_edges['parent'].min())
50
class CondensedTree(object):
    """The condensed tree structure, which provides a simplified or smoothed version
    of the :class:`~hdbscan.plots.SingleLinkageTree`.

    Parameters
    ----------
    condensed_tree_array : numpy recarray from :class:`~hdbscan.HDBSCAN`
        The raw numpy rec array version of the condensed tree as produced
        internally by hdbscan.

    cluster_selection_method : string, optional (default 'eom')
        The method of selecting clusters. One of 'eom' or 'leaf'

    allow_single_cluster : Boolean, optional (default False)
        Whether to allow the root cluster as the only selected cluster

    """
    def __init__(self, condensed_tree_array, cluster_selection_method='eom',
                 allow_single_cluster=False):
        self._raw_tree = condensed_tree_array
        self.cluster_selection_method = cluster_selection_method
        self.allow_single_cluster = allow_single_cluster

    def get_plot_data(self,
                      leaf_separation=1,
                      log_size=False,
                      max_rectangle_per_icicle=20):
        """Generates data for use in plotting the 'icicle plot' or dendrogram
        plot of the condensed tree generated by HDBSCAN.

        Parameters
        ----------
        leaf_separation : float, optional
                          How far apart to space the final leaves of the
                          dendrogram. (default 1)

        log_size : boolean, optional
                   Use log scale for the 'size' of clusters (i.e. number of
                   points in the cluster at a given lambda value).
                   (default False)

        max_rectangle_per_icicle : int, optional
            To simplify the plot this method will only emit roughly
            ``max_rectangle_per_icicle`` bars per branch of the dendrogram.
            This ensures that we don't suffer from massive overplotting in
            cases with a lot of data points. (default 20)

        Returns
        -------
        plot_data : dict
                    Data associated to bars in a bar plot:
                        `bar_centers` x coordinate centers for bars
                        `bar_tops` heights of bars in lambda scale
                        `bar_bottoms` y coordinate of bottoms of bars
                        `bar_widths` widths of the bars (in x coord scale)
                    Data associated with cluster splits:
                        `line_xs` x coordinates for horizontal dendrogram lines
                        `line_ys` y coordinates for horizontal dendrogram lines
                    Per-cluster bounds:
                        `cluster_bounds` dict mapping each cluster id to a
                                         list [left, right, bottom, top]
                                         bounding that cluster's bars
        """
        leaves = _get_leaves(self._raw_tree)
        last_leaf = self._raw_tree['parent'].max()
        root = self._raw_tree['parent'].min()

        # We want to get the x and y coordinates for the start of each cluster
        # Initialize the leaves, since we know where they go, then iterate
        # through everything from the leaves back, setting coords as we go
        if isinstance(leaves, (int, np.integer)):
            cluster_x_coords = {leaves: leaf_separation}
        else:
            cluster_x_coords = dict(zip(leaves, [leaf_separation * x
                                                 for x in range(len(leaves))]))
        cluster_y_coords = {root: 0.0}

        # These do not depend on the loop variable, so compute them once.
        split_fields = self._raw_tree[['child', 'lambda_val']]
        is_cluster_edge = self._raw_tree['child_size'] > 1

        for cluster in range(last_leaf, root - 1, -1):
            split = split_fields[(self._raw_tree['parent'] == cluster) &
                                 is_cluster_edge]
            if len(split['child']) > 1:
                left_child, right_child = split['child']
                cluster_x_coords[cluster] = np.mean([cluster_x_coords[left_child],
                                                     cluster_x_coords[right_child]])
                cluster_y_coords[left_child] = split['lambda_val'][0]
                cluster_y_coords[right_child] = split['lambda_val'][1]

        # We use bars to plot the 'icicles', so we need to generate centers, tops,
        # bottoms and widths for each rectangle. We can go through each cluster
        # and do this for each in turn.
        bar_centers = []
        bar_tops = []
        bar_bottoms = []
        bar_widths = []

        cluster_bounds = {}

        scaling = np.sum(self._raw_tree[self._raw_tree['parent'] == root]['child_size'])

        if log_size:
            scaling = np.log(scaling)

        for c in range(last_leaf, root - 1, -1):

            cluster_bounds[c] = [0, 0, 0, 0]

            c_children = self._raw_tree[self._raw_tree['parent'] == c]
            current_size = np.sum(c_children['child_size'])
            current_lambda = cluster_y_coords[c]
            cluster_max_size = current_size
            cluster_max_lambda = c_children['lambda_val'].max()
            cluster_min_size = np.sum(
                c_children[c_children['lambda_val'] ==
                           cluster_max_lambda]['child_size'])

            if log_size:
                current_size = np.log(current_size)
                cluster_max_size = np.log(cluster_max_size)
                cluster_min_size = np.log(cluster_min_size)

            total_size_change = float(cluster_max_size - cluster_min_size)
            step_size_change = total_size_change / max_rectangle_per_icicle

            cluster_bounds[c][CB_LEFT] = cluster_x_coords[c] * scaling - (current_size / 2.0)
            cluster_bounds[c][CB_RIGHT] = cluster_x_coords[c] * scaling + (current_size / 2.0)
            cluster_bounds[c][CB_BOTTOM] = cluster_y_coords[c]
            cluster_bounds[c][CB_TOP] = np.max(c_children['lambda_val'])

            last_step_size = current_size
            last_step_lambda = current_lambda

            for i in np.argsort(c_children['lambda_val']):
                row = c_children[i]
                # Emit a bar once the size has shrunk by more than one step,
                # or when the final lambda value of this cluster is reached.
                if row['lambda_val'] != current_lambda and \
                        (last_step_size - current_size > step_size_change
                        or row['lambda_val'] == cluster_max_lambda):
                    bar_centers.append(cluster_x_coords[c] * scaling)
                    bar_tops.append(row['lambda_val'] - last_step_lambda)
                    bar_bottoms.append(last_step_lambda)
                    bar_widths.append(last_step_size)
                    last_step_size = current_size
                    last_step_lambda = current_lambda
                if log_size:
                    exp_size = np.exp(current_size) - row['child_size']
                    # Ensure we don't try to take log of zero
                    if exp_size > 0.01:
                        current_size = np.log(exp_size)
                    else:
                        current_size = 0.0
                else:
                    current_size -= row['child_size']
                current_lambda = row['lambda_val']

        # Finally we need the horizontal lines that occur at cluster splits.
        line_xs = []
        line_ys = []

        for row in self._raw_tree[self._raw_tree['child_size'] > 1]:
            parent = row['parent']
            child = row['child']
            child_size = row['child_size']
            if log_size:
                child_size = np.log(child_size)
            sign = np.sign(cluster_x_coords[child] - cluster_x_coords[parent])
            line_xs.append([
                cluster_x_coords[parent] * scaling,
                cluster_x_coords[child] * scaling + sign * (child_size / 2.0)
            ])
            line_ys.append([
                cluster_y_coords[child],
                cluster_y_coords[child]
            ])

        return {
            'bar_centers': bar_centers,
            'bar_tops': bar_tops,
            'bar_bottoms': bar_bottoms,
            'bar_widths': bar_widths,
            'line_xs': line_xs,
            'line_ys': line_ys,
            'cluster_bounds': cluster_bounds
        }

    def _select_clusters(self):
        """Return the clusters chosen by ``self.cluster_selection_method``.

        For 'eom', a sorted list of cluster ids is returned. The root is
        included among the candidates only when ``allow_single_cluster`` is
        set. For 'leaf', the leaf clusters are returned in depth-first order.

        Raises
        ------
        ValueError
            If ``cluster_selection_method`` is neither 'eom' nor 'leaf'.
        """
        if self.cluster_selection_method == 'eom':
            stability = compute_stability(self._raw_tree)
            if self.allow_single_cluster:
                node_list = sorted(stability.keys(), reverse=True)
            else:
                # Drop the smallest id (the root) from the candidates.
                node_list = sorted(stability.keys(), reverse=True)[:-1]
            cluster_tree = self._raw_tree[self._raw_tree['child_size'] > 1]
            is_cluster = {cluster: True for cluster in node_list}

            for node in node_list:
                child_selection = (cluster_tree['parent'] == node)
                subtree_stability = np.sum([stability[child] for
                                            child in cluster_tree['child'][child_selection]])

                if subtree_stability > stability[node]:
                    is_cluster[node] = False
                    stability[node] = subtree_stability
                else:
                    for sub_node in _bfs_from_cluster_tree(cluster_tree, node):
                        if sub_node != node:
                            is_cluster[sub_node] = False

            return sorted([cluster
                           for cluster in is_cluster
                           if is_cluster[cluster]])

        elif self.cluster_selection_method == 'leaf':
            return _get_leaves(self._raw_tree)
        else:
            raise ValueError('Invalid Cluster Selection Method: %s\n'
                             'Should be one of: "eom", "leaf"\n'
                             % (self.cluster_selection_method,))

    def plot(self, leaf_separation=1, cmap='viridis', select_clusters=False,
             label_clusters=False, selection_palette=None,
             axis=None, colorbar=True, log_size=False,
             max_rectangles_per_icicle=20):
        """Use matplotlib to plot an 'icicle plot' dendrogram of the condensed tree.

        Effectively this is a dendrogram where the width of each cluster bar is
        equal to the number of points (or log of the number of points) in the cluster
        at the given lambda value. Thus bars narrow as points progressively drop
        out of clusters. To make the effect more apparent the bars are also colored
        according to the number of points (or log of the number of points).

        Parameters
        ----------
        leaf_separation : float, optional (default 1)
                          How far apart to space the final leaves of the
                          dendrogram.

        cmap : string or matplotlib colormap, optional (default viridis)
               The matplotlib colormap to use to color the cluster bars.
               A value of 'none' results in black bars.

        select_clusters : boolean, optional (default False)
                          Whether to draw ovals highlighting which cluster
                          bar represent the clusters that were selected by
                          HDBSCAN as the final clusters.

        label_clusters : boolean, optional (default False)
                         If select_clusters is True then this determines
                         whether to draw text labels on the clusters.

        selection_palette : list of colors, optional (default None)
                            If not None, and at least as long as
                            the number of clusters, draw ovals
                            in colors iterating through this palette.
                            This can aid in cluster identification
                            when plotting.

        axis : matplotlib axis or None, optional (default None)
               The matplotlib axis to render to. If None then the current
               axis (``plt.gca()``) is used. The rendered axis will be returned.

        colorbar : boolean, optional (default True)
                   Whether to draw a matplotlib colorbar displaying the range
                   of cluster sizes as per the colormap. Ignored when
                   ``cmap`` is 'none', since the bars are then uncolored.

        log_size : boolean, optional (default False)
                   Use log scale for the 'size' of clusters (i.e. number of
                   points in the cluster at a given lambda value).

        max_rectangles_per_icicle : int, optional (default 20)
            To simplify the plot this method will only emit roughly
            ``max_rectangles_per_icicle`` bars per branch of the dendrogram.
            This ensures that we don't suffer from massive overplotting in
            cases with a lot of data points.

        Returns
        -------
        axis : matplotlib axis
               The axis on which the 'icicle plot' has been rendered.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            raise ImportError(
                'You must install the matplotlib library to plot the condensed tree.'
                'Use get_plot_data to calculate the relevant data without plotting.')

        plot_data = self.get_plot_data(leaf_separation=leaf_separation,
                                       log_size=log_size,
                                       max_rectangle_per_icicle=max_rectangles_per_icicle)

        use_cmap = cmap != 'none'
        if use_cmap:
            sm = plt.cm.ScalarMappable(cmap=cmap,
                                       norm=plt.Normalize(0, max(plot_data['bar_widths'])))
            sm.set_array(plot_data['bar_widths'])
            bar_colors = [sm.to_rgba(x) for x in plot_data['bar_widths']]
        else:
            bar_colors = 'black'

        if axis is None:
            axis = plt.gca()

        axis.bar(
            plot_data['bar_centers'],
            plot_data['bar_tops'],
            bottom=plot_data['bar_bottoms'],
            width=plot_data['bar_widths'],
            color=bar_colors,
            align='center',
            linewidth=0
        )

        # A single plot call with alternating x/y lists draws all split lines.
        drawlines = []
        for xs, ys in zip(plot_data['line_xs'], plot_data['line_ys']):
            drawlines.append(xs)
            drawlines.append(ys)
        axis.plot(*drawlines, color='black', linewidth=1)

        if select_clusters:
            try:
                from matplotlib.patches import Ellipse
            except ImportError:
                raise ImportError('You must have matplotlib.patches available to plot selected clusters.')

            chosen_clusters = self._select_clusters()

            # Extract the chosen cluster bounds. If enough duplicate data points exist in the
            # data the lambda value might be infinite. This breaks labeling and highlighting
            # the chosen clusters.
            chosen_bounds = np.array([plot_data['cluster_bounds'][c] for c in chosen_clusters])
            if not np.isfinite(chosen_bounds).all():
                warn('Infinite lambda values encountered in chosen clusters.'
                     ' This might be due to duplicates in the data.')

            # Extract the plot range of the y-axis and set default center and height values for ellipses.
            # Extremely dense clusters might result in near infinite lambda values. Setting max_height
            # based on the percentile should alleviate the impact on plotting.
            plot_range = np.hstack([plot_data['bar_tops'], plot_data['bar_bottoms']])
            plot_range = plot_range[np.isfinite(plot_range)]
            mean_y_center = np.mean([np.max(plot_range), np.min(plot_range)])
            # np.diff yields a length-1 array; unwrap it to a scalar height.
            max_height = float(np.diff(np.percentile(plot_range, q=[10, 90]))[0])

            for i, c in enumerate(chosen_clusters):
                c_bounds = plot_data['cluster_bounds'][c]
                width = (c_bounds[CB_RIGHT] - c_bounds[CB_LEFT])
                height = (c_bounds[CB_TOP] - c_bounds[CB_BOTTOM])
                center = (
                    np.mean([c_bounds[CB_LEFT], c_bounds[CB_RIGHT]]),
                    np.mean([c_bounds[CB_TOP], c_bounds[CB_BOTTOM]]),
                )

                # Set center and height to default values if necessary
                if not np.isfinite(center[1]):
                    center = (center[0], mean_y_center)
                if not np.isfinite(height):
                    height = max_height

                # Ensure the ellipse is visible
                min_height = 0.1 * max_height
                if height < min_height:
                    height = min_height

                if selection_palette is not None and \
                        len(selection_palette) >= len(chosen_clusters):
                    oval_color = selection_palette[i]
                else:
                    oval_color = 'r'

                box = Ellipse(
                    center,
                    2.0 * width,
                    1.2 * height,
                    facecolor='none',
                    edgecolor=oval_color,
                    linewidth=2
                )

                if label_clusters:
                    axis.annotate(str(i), xy=center,
                                  xytext=(center[0] - 4.0 * width, center[1] + 0.65 * height),
                                  horizontalalignment='left',
                                  verticalalignment='bottom')

                axis.add_artist(box)

        # `sm` only exists when a colormap is in use.
        if colorbar and use_cmap:
            cb = plt.colorbar(sm, ax=axis)
            if log_size:
                cb.ax.set_ylabel('log(Number of points)')
            else:
                cb.ax.set_ylabel('Number of points')

        axis.set_xticks([])
        for side in ('right', 'top', 'bottom'):
            axis.spines[side].set_visible(False)
        axis.invert_yaxis()
        # Raw string: '\l' is an invalid escape sequence in a normal literal.
        axis.set_ylabel(r'$\lambda$ value')

        return axis

    def to_numpy(self):
        """Return a numpy structured array representation of the condensed tree.
        """
        return self._raw_tree.copy()

    def to_pandas(self):
        """Return a pandas dataframe representation of the condensed tree.

        Each row of the dataframe corresponds to an edge in the tree.
        The columns of the dataframe are `parent`, `child`, `lambda_val`
        and `child_size`.

        The `parent` and `child` are the ids of the
        parent and child nodes in the tree. Node ids less than the number
        of points in the original dataset represent individual points, while
        ids greater than the number of points are clusters.

        The `lambda_val` value is the value (1/distance) at which the `child`
        node leaves the cluster.

        The `child_size` is the number of points in the `child` node.
        """
        try:
            from pandas import DataFrame
        except ImportError:
            raise ImportError('You must have pandas installed to export pandas DataFrames')

        result = DataFrame(self._raw_tree)

        return result

    def to_networkx(self):
        """Return a NetworkX DiGraph object representing the condensed tree.

        Edge weights in the graph are the lambda values at which child nodes
        'leave' the parent cluster.

        Nodes have a `size` attribute attached giving the number of points
        that are in the cluster (or 1 if it is a singleton point) at the
        point of cluster creation (fewer points may be in the cluster at
        larger lambda values).
        """
        try:
            from networkx import DiGraph, set_node_attributes
        except ImportError:
            raise ImportError('You must have networkx installed to export networkx graphs')

        result = DiGraph()
        for row in self._raw_tree:
            result.add_edge(row['parent'], row['child'], weight=row['lambda_val'])

        set_node_attributes(result, dict(self._raw_tree[['child', 'child_size']]), 'size')

        return result
506
507
508def _get_dendrogram_ordering(parent, linkage, root):
509
510    if parent < root:
511        return []
512
513    return _get_dendrogram_ordering(int(linkage[parent-root][0]), linkage, root) + \
514            _get_dendrogram_ordering(int(linkage[parent-root][1]), linkage, root) + [parent]
515
516def _calculate_linewidths(ordering, linkage, root):
517
518    linewidths = []
519
520    for x in ordering:
521        if linkage[x - root][0] >= root:
522            left_width = linkage[int(linkage[x - root][0]) - root][3]
523        else:
524            left_width = 1
525
526        if linkage[x - root][1] >= root:
527            right_width = linkage[int(linkage[x - root][1]) - root][3]
528        else:
529            right_width = 1
530
531        linewidths.append((left_width, right_width))
532
533    return linewidths
534
535
class SingleLinkageTree(object):
    """A single linkage format dendrogram tree, with plotting functionality
    and networkX support.

    Parameters
    ----------
    linkage : ndarray (n_samples, 4)
        The numpy array that holds the tree structure. As output by
        scipy.cluster.hierarchy, hdbscan, or fastcluster.

    """
    def __init__(self, linkage):
        self._linkage = linkage

    def plot(self, axis=None, truncate_mode=None, p=0, vary_line_width=True,
             cmap='viridis', colorbar=True):
        """Plot a dendrogram of the single linkage tree.

        Parameters
        ----------
        axis : matplotlib axis or None, optional
               The matplotlib axis to render to. If None then the current
               axis (``plt.gca()``) is used. (default None)

        truncate_mode : str, optional
                        The dendrogram can be hard to read when the original
                        observation matrix from which the linkage is derived
                        is large. Truncation is used to condense the dendrogram.
                        There are several modes:

        ``None/'none'``
                No truncation is performed (Default).

        ``'lastp'``
                The last p non-singleton formed in the linkage are the only
                non-leaf nodes in the linkage; they correspond to rows
                Z[n-p-2:end] in Z. All other non-singleton clusters are
                contracted into leaf nodes.

        ``'level'/'mtica'``
                No more than p levels of the dendrogram tree are displayed.
                This corresponds to Mathematica(TM) behavior.

        p : int, optional
            The ``p`` parameter for ``truncate_mode``.

        vary_line_width : boolean, optional
            Draw downward branches of the dendrogram with line thickness that
            varies depending on the size of the cluster.

        cmap : string or matplotlib colormap, optional
               The matplotlib colormap to use to color the cluster bars.
               A value of 'none' will result in black bars.
               (default 'viridis')

        colorbar : boolean, optional
                   Whether to draw a matplotlib colorbar displaying the range
                   of cluster sizes as per the colormap. Ignored when
                   ``cmap`` is 'none'. (default True)

        Returns
        -------
        axis : matplotlib axis
               The axis on which the dendrogram plot has been rendered.

        """
        dendrogram_data = dendrogram(self._linkage, p=p, truncate_mode=truncate_mode, no_plot=True)
        X = dendrogram_data['icoord']
        Y = dendrogram_data['dcoord']

        try:
            import matplotlib.pyplot as plt
        except ImportError:
            raise ImportError('You must install the matplotlib library to plot the single linkage tree.')

        if axis is None:
            axis = plt.gca()

        if vary_line_width:
            dendrogram_ordering = _get_dendrogram_ordering(2 * len(self._linkage), self._linkage, len(self._linkage) + 1)
            linewidths = _calculate_linewidths(dendrogram_ordering, self._linkage, len(self._linkage) + 1)
        else:
            linewidths = [(1.0, 1.0)] * len(Y)

        use_cmap = cmap != 'none'
        if use_cmap:
            color_array = np.log2(np.array(linewidths).flatten())
            sm = plt.cm.ScalarMappable(cmap=cmap,
                                       norm=plt.Normalize(0, color_array.max()))
            sm.set_array(color_array)

        line_style = dict(solid_joinstyle='miter', solid_capstyle='butt')

        for x, y, lw in zip(X, Y, linewidths):
            # Each dendrogram link is a 4-point path: left leg, top bar, right leg.
            left_x = x[:2]
            right_x = x[2:]
            left_y = y[:2]
            right_y = y[2:]
            horizontal_x = x[1:3]
            horizontal_y = y[1:3]

            if use_cmap:
                left_color = sm.to_rgba(np.log2(lw[0]))
                right_color = sm.to_rgba(np.log2(lw[1]))
            else:
                left_color = right_color = 'k'

            axis.plot(left_x, left_y, color=left_color,
                      linewidth=np.log2(1 + lw[0]), **line_style)
            axis.plot(right_x, right_y, color=right_color,
                      linewidth=np.log2(1 + lw[1]), **line_style)
            axis.plot(horizontal_x, horizontal_y, color='k', linewidth=1.0,
                      **line_style)

        # `sm` only exists when a colormap is in use.
        if colorbar and use_cmap:
            cb = plt.colorbar(sm, ax=axis)
            cb.ax.set_ylabel('log(Number of points)')

        axis.set_xticks([])
        for side in ('right', 'top', 'bottom'):
            axis.spines[side].set_visible(False)
        axis.set_ylabel('distance')

        return axis

    def to_numpy(self):
        """Return a numpy array representation of the single linkage tree.

        This representation conforms to the scipy.cluster.hierarchy notion
        of a single linkage tree, and can be used with all the associated
        scipy tools. Please see the scipy documentation for more details
        on the format.
        """
        return self._linkage.copy()

    def to_pandas(self):
        """Return a pandas dataframe representation of the single linkage tree.

        Each row of the dataframe corresponds to an edge in the tree.
        The columns of the dataframe are `parent`, `left_child`,
        `right_child`, `distance` and `size`.

        The `parent`, `left_child` and `right_child` are the ids of the
        parent and child nodes in the tree. Node ids less than the number
        of points in the original dataset represent individual points, while
        ids greater than the number of points are clusters.

        The `distance` value is the distance at which the child nodes merge
        to form the parent node.

        The `size` is the number of points in the `parent` node.
        """
        try:
            from pandas import DataFrame
        except ImportError:
            raise ImportError('You must have pandas installed to export pandas DataFrames')

        max_node = 2 * self._linkage.shape[0]
        # n merges join n + 1 points, so parent ids start at n + 1.
        num_points = max_node - (self._linkage.shape[0] - 1)

        parent_array = np.arange(num_points, max_node + 1)

        result = DataFrame({
            'parent': parent_array,
            'left_child': self._linkage.T[0],
            'right_child': self._linkage.T[1],
            'distance': self._linkage.T[2],
            'size': self._linkage.T[3]
        })[['parent', 'left_child', 'right_child', 'distance', 'size']]

        return result

    def to_networkx(self):
        """Return a NetworkX DiGraph object representing the single linkage tree.

        Edge weights in the graph are the distance values at which child nodes
        merge to form the parent cluster.

        Nodes have a `size` attribute attached giving the number of points
        that are in the cluster.
        """
        try:
            from networkx import DiGraph, set_node_attributes
        except ImportError:
            raise ImportError('You must have networkx installed to export networkx graphs')

        max_node = 2 * self._linkage.shape[0]
        num_points = max_node - (self._linkage.shape[0] - 1)

        result = DiGraph()
        for parent, row in enumerate(self._linkage, num_points):
            result.add_edge(parent, row[0], weight=row[2])
            result.add_edge(parent, row[1], weight=row[2])

        size_dict = {parent: row[3] for parent, row in enumerate(self._linkage, num_points)}
        set_node_attributes(result, size_dict, 'size')

        return result

    def get_clusters(self, cut_distance, min_cluster_size=5):
        """Return a flat clustering from the single linkage hierarchy.

        This represents the result of selecting a cut value for robust single linkage
        clustering. The `min_cluster_size` allows the flat clustering to declare noise
        points (and clusters smaller than `min_cluster_size`).

        Parameters
        ----------

        cut_distance : float
            The mutual reachability distance cut value to use to generate a flat clustering.

        min_cluster_size : int, optional
            Clusters smaller than this value will be called 'noise' and remain unclustered
            in the resulting flat clustering.

        Returns
        -------

        labels : array [n_samples]
            An array of cluster labels, one per datapoint. Unclustered points are assigned
            the label -1.
        """
        return labelling_at_cut(self._linkage, cut_distance, min_cluster_size)
758
759
class MinimumSpanningTree(object):
    """A minimum spanning tree over a dataset, with plotting and export helpers.

    Parameters
    ----------
    mst : numpy array of shape (n_edges, 3)
        The weighted edges of the tree. Each row is ``(from, to, distance)``;
        ``from`` and ``to`` are row indices into ``data`` (cast to int where
        they are used as indices), and ``distance`` is the edge weight
        (labelled as the mutual reachability distance in plots).

    data : numpy array of shape (n_samples, n_features)
        The data points that form the vertices of the tree.
    """

    def __init__(self, mst, data):
        self._mst = mst
        self._data = data

    @staticmethod
    def _line_widths(distances, edge_linewidth, vary_line_width):
        """Compute the line width(s) used to draw the tree edges.

        When ``vary_line_width`` is true, each edge gets
        ``edge_linewidth * (log(max_distance / distance) + 1)``, so shorter
        edges are drawn thicker. Zero-length edges would make that ratio
        infinite, so distances are floored at the smallest positive distance.
        If no distance is positive (empty tree or all-zero edges), a uniform
        width is used. When ``vary_line_width`` is false, the scalar
        ``edge_linewidth`` is returned unchanged.
        """
        if not vary_line_width:
            return edge_linewidth

        distances = np.asarray(distances, dtype=float)
        positive = distances[distances > 0]
        if positive.size == 0:
            # Nothing to scale by; fall back to uniform widths.
            return edge_linewidth

        # Identical to the plain formula when every distance is positive.
        safe_distances = np.maximum(distances, positive.min())
        return edge_linewidth * (np.log(safe_distances.max() / safe_distances) + 1.0)

    def plot(self, axis=None, node_size=40, node_color='k',
             node_alpha=0.8, edge_alpha=0.5, edge_cmap='viridis_r',
             edge_linewidth=2, vary_line_width=True, colorbar=True):
        """Plot the minimum spanning tree (as projected into 2D by t-SNE if required).

        Parameters
        ----------

        axis : matplotlib axis, optional
               The axis to render the plot to

        node_size : int, optional
                The size of nodes in the plot (default 40).

        node_color : matplotlib color spec, optional
                The color to render nodes (default black).

        node_alpha : float, optional
                The alpha value (between 0 and 1) to render nodes with
                (default 0.8).

        edge_cmap : matplotlib colormap, optional
                The colormap to color edges by (varying color by edge
                weight/distance). Can be a cmap object or a string
                recognised by matplotlib. (default `viridis_r`)

        edge_alpha : float, optional
                The alpha value (between 0 and 1) to render edges with
                (default 0.5).

        edge_linewidth : float, optional
                The linewidth to use for rendering edges (default 2).

        vary_line_width : bool, optional
                Edge width is proportional to (log of) the inverse of the
                mutual reachability distance. Zero-length edges are drawn
                as thick as the shortest non-zero edge. (default True)

        colorbar : bool, optional
                Whether to draw a colorbar. (default True)

        Returns
        -------

        axis : matplotlib axis
                The axis used to render the plot, or None if there are too
                many data points to render (a warning is issued).
        """
        try:
            import matplotlib.pyplot as plt
            from matplotlib.collections import LineCollection
        except ImportError as err:
            raise ImportError('You must install the matplotlib library to plot the minimum spanning tree.') from err

        # NOTE(review): the reason for this specific limit (int16 max) is not
        # visible here; presumably a rendering-safety cap.
        if self._data.shape[0] > 32767:
            warn('Too many data points for safe rendering of an minimal spanning tree!')
            return None

        if axis is None:
            axis = plt.gca()

        if self._data.shape[1] > 2:
            # Get a 2D projection; if we have a lot of dimensions use PCA first
            if self._data.shape[1] > 32:
                # Use PCA to get down to 32 dimension
                data_for_projection = PCA(n_components=32).fit_transform(self._data)
            else:
                data_for_projection = self._data.copy()

            projection = TSNE().fit_transform(data_for_projection)
        else:
            projection = self._data.copy()

        line_width = self._line_widths(self._mst[:, 2], edge_linewidth, vary_line_width)

        # Each edge becomes a pair of 2D endpoints: shape (n_edges, 2, 2).
        line_coords = projection[self._mst[:, :2].astype(int)]
        line_collection = LineCollection(line_coords, linewidth=line_width,
                                         cmap=edge_cmap, alpha=edge_alpha)
        # Color edges by their distance through the colormap.
        line_collection.set_array(self._mst[:, 2].T)

        axis.add_artist(line_collection)
        axis.scatter(projection.T[0], projection.T[1], c=node_color, alpha=node_alpha, s=node_size)
        axis.set_xticks([])
        axis.set_yticks([])

        if colorbar:
            cb = plt.colorbar(line_collection, ax=axis)
            cb.ax.set_ylabel('Mutual reachability distance')

        return axis

    def to_numpy(self):
        """Return a copy of the numpy array of weighted edges in the minimum spanning tree.

        Each row is ``(from, to, distance)``.
        """
        return self._mst.copy()

    def to_pandas(self):
        """Return a Pandas dataframe of the minimum spanning tree.

        Each row is an edge in the tree; the columns are `from`,
        `to`, and `distance` giving the two vertices of the edge
        which are indices into the dataset, and the distance
        between those datapoints.
        """
        try:
            from pandas import DataFrame
        except ImportError as err:
            raise ImportError('You must have pandas installed to export pandas DataFrames') from err

        result = DataFrame({'from': self._mst.T[0].astype(int),
                            'to': self._mst.T[1].astype(int),
                            'distance': self._mst.T[2]})
        return result

    def to_networkx(self):
        """Return a NetworkX Graph object representing the minimum spanning tree.

        Nodes are labelled by their integer row index into the dataset
        (matching the `from`/`to` columns of :meth:`to_pandas`). Edge weights
        in the graph are the distance between the nodes they connect.

        Nodes have a `data` attribute attached giving the data vector of the
        associated point.
        """
        try:
            from networkx import Graph, set_node_attributes
        except ImportError as err:
            raise ImportError('You must have networkx installed to export networkx graphs') from err

        result = Graph()
        for row in self._mst:
            # Vertex ids are stored as floats in the edge array; cast them so
            # node labels are ints, consistent with to_pandas.
            result.add_edge(int(row[0]), int(row[1]), weight=row[2])

        data_dict = {index: tuple(row) for index, row in enumerate(self._data)}
        set_node_attributes(result, data_dict, 'data')

        return result
901