1# cython: boundscheck=False
2# cython: nonecheck=False
3# cython: initializedcheck=False
4# Tree handling (condensing, finding stable clusters) for hdbscan
5# Authors: Leland McInnes
6# License: 3-clause BSD
7
8import numpy as np
9cimport numpy as np
10
# Positive infinity as a C double. condense_tree uses it as the lambda value
# for merges whose distance is not strictly positive (1 / 0 is undefined).
cdef np.double_t INFTY = np.inf
12
13
cdef list bfs_from_hierarchy(np.ndarray[np.double_t, ndim=2] hierarchy,
                             np.intp_t bfs_root):
    """
    Return ``bfs_root`` and every node beneath it, level by level, for a
    tree in scipy hclust (linkage matrix) format.

    With ``hierarchy.shape[0] + 1`` points, ids below that count are points
    (leaves). Any other id refers to a merge row of ``hierarchy``, whose two
    children are read from columns 0 and 1.
    """
    cdef np.intp_t n_points = hierarchy.shape[0] + 1
    cdef list visited = []
    cdef list frontier = [bfs_root]
    cdef list merge_rows

    while len(frontier) > 0:
        visited.extend(frontier)
        # Only merge nodes have children; points end their branch.
        merge_rows = [node - n_points for node in frontier if node >= n_points]
        if len(merge_rows) == 0:
            break
        frontier = hierarchy[merge_rows, :2].flatten().astype(np.intp).tolist()

    return visited
41
42
cdef _condense_drop_runt(np.ndarray[np.double_t, ndim=2] hierarchy,
                         np.intp_t runt_root,
                         np.intp_t parent_label,
                         double lambda_value,
                         np.intp_t num_points,
                         list result_list,
                         np.ndarray[np.intp_t, ndim=1] ignore):
    """Helper for condense_tree: shed the points of a runt subtree.

    Every point under ``runt_root`` is recorded as leaving cluster
    ``parent_label`` at ``lambda_value`` (child_size 1). Every node of the
    subtree is flagged in ``ignore`` so the main loop skips it.
    """
    cdef np.intp_t sub_node

    for sub_node in bfs_from_hierarchy(hierarchy, runt_root):
        if sub_node < num_points:
            result_list.append((parent_label, sub_node, lambda_value, 1))
        ignore[sub_node] = 1


cpdef np.ndarray condense_tree(np.ndarray[np.double_t, ndim=2] hierarchy,
                               np.intp_t min_cluster_size=10):
    """Condense a tree according to a minimum cluster size. This is akin
    to the runt pruning procedure of Stuetzle. The result is a much simpler
    tree that is easier to visualize. We include extra information on the
    lambda value at which individual points depart clusters for later
    analysis and computation.

    Parameters
    ----------
    hierarchy : ndarray (n_samples, 4)
        A single linkage hierarchy in scipy.cluster.hierarchy format.

    min_cluster_size : int, optional (default 10)
        The minimum size of clusters to consider. Smaller "runt"
        clusters are pruned from the tree.

    Returns
    -------
    condensed_tree : numpy recarray
        Effectively an edgelist with a parent, child, lambda_val
        and child_size in each row providing a tree structure.
        Points keep their ids (0 .. n_points - 1). The root cluster gets id
        n_points, and every new cluster gets the next unused id, in
        breadth-first order.
    """

    cdef np.intp_t root
    cdef np.intp_t num_points
    cdef np.intp_t next_label
    cdef list node_list
    cdef list result_list

    cdef np.ndarray[np.intp_t, ndim=1] relabel
    cdef np.ndarray[np.intp_t, ndim=1] ignore
    cdef np.ndarray[np.double_t, ndim=1] children

    cdef np.intp_t node
    cdef np.intp_t left
    cdef np.intp_t right
    cdef double lambda_value
    cdef np.intp_t left_count
    cdef np.intp_t right_count
    cdef bint left_is_cluster
    cdef bint right_is_cluster

    root = 2 * hierarchy.shape[0]
    num_points = root // 2 + 1
    next_label = num_points + 1

    node_list = bfs_from_hierarchy(hierarchy, root)

    # relabel[node] is the condensed-tree cluster id that a linkage node
    # currently belongs to.
    relabel = np.empty(root + 1, dtype=np.intp)
    relabel[root] = num_points
    result_list = []
    # One "already absorbed into a runt" flag per linkage node id (0..root).
    ignore = np.zeros(root + 1, dtype=np.intp)

    for node in node_list:
        if ignore[node] or node < num_points:
            continue

        children = hierarchy[node - num_points]
        left = <np.intp_t> children[0]
        right = <np.intp_t> children[1]
        if children[2] > 0.0:
            lambda_value = 1.0 / children[2]
        else:
            lambda_value = INFTY

        if left >= num_points:
            left_count = <np.intp_t> hierarchy[left - num_points][3]
        else:
            left_count = 1

        if right >= num_points:
            right_count = <np.intp_t> hierarchy[right - num_points][3]
        else:
            right_count = 1

        left_is_cluster = left_count >= min_cluster_size
        right_is_cluster = right_count >= min_cluster_size

        if left_is_cluster and right_is_cluster:
            # Genuine split: both children become new clusters.
            relabel[left] = next_label
            next_label += 1
            result_list.append((relabel[node], relabel[left], lambda_value,
                                left_count))

            relabel[right] = next_label
            next_label += 1
            result_list.append((relabel[node], relabel[right], lambda_value,
                                right_count))
            continue

        # At most one side is large enough. That side (if any) continues
        # the parent cluster under the same label. Each runt side sheds its
        # points from the parent cluster at this lambda.
        if left_is_cluster:
            relabel[left] = relabel[node]
        elif right_is_cluster:
            relabel[right] = relabel[node]

        if not left_is_cluster:
            _condense_drop_runt(hierarchy, left, relabel[node], lambda_value,
                                num_points, result_list, ignore)
        if not right_is_cluster:
            _condense_drop_runt(hierarchy, right, relabel[node], lambda_value,
                                num_points, result_list, ignore)

    return np.array(result_list, dtype=[('parent', np.intp),
                                        ('child', np.intp),
                                        ('lambda_val', float),
                                        ('child_size', np.intp)])
162
163
cpdef dict compute_stability(np.ndarray condensed_tree):
    """Compute the excess-of-mass stability of every cluster.

    For a cluster C with birth lambda ``birth(C)``, the stability is the sum
    over edges ``C -> child`` of ``(lambda_val - birth(C)) * child_size``.
    ``birth(C)`` is the smallest ``lambda_val`` among edges whose child is C.
    The root (the smallest parent id) is born at 0.

    Parameters
    ----------
    condensed_tree : numpy recarray
        Condensed tree with ``parent``, ``child``, ``lambda_val`` and
        ``child_size`` fields, as produced by ``condense_tree``.

    Returns
    -------
    stability : dict
        Maps every integer cluster id from ``min(parent)`` to ``max(parent)``
        (inclusive) to its stability.
    """

    cdef np.ndarray[np.double_t, ndim=1] result_arr
    cdef np.ndarray sorted_child_data
    cdef np.ndarray[np.intp_t, ndim=1] sorted_children
    cdef np.ndarray[np.double_t, ndim=1] sorted_lambdas

    cdef np.ndarray[np.intp_t, ndim=1] parents
    cdef np.ndarray[np.intp_t, ndim=1] sizes
    cdef np.ndarray[np.double_t, ndim=1] lambdas

    cdef np.intp_t row
    cdef np.intp_t i
    cdef np.intp_t child
    cdef np.intp_t parent
    cdef np.intp_t child_size
    cdef np.intp_t result_index
    cdef np.intp_t current_child
    cdef np.float64_t lambda_
    cdef np.float64_t min_lambda

    cdef np.ndarray[np.double_t, ndim=1] births_arr
    cdef np.double_t *births

    cdef np.intp_t largest_child = condensed_tree['child'].max()
    cdef np.intp_t smallest_cluster = condensed_tree['parent'].min()
    cdef np.intp_t largest_cluster = condensed_tree['parent'].max()
    cdef np.intp_t num_clusters = largest_cluster - smallest_cluster + 1

    # ``births`` is read through a raw pointer (no bounds checks) at both
    # child ids and parent ids, so it must cover the largest id of either.
    if largest_child < largest_cluster:
        largest_child = largest_cluster

    # Sorting on (child, lambda_val) makes all edges into the same child
    # adjacent, so birth(child) can be found in a single pass.
    sorted_child_data = np.sort(condensed_tree[['child', 'lambda_val']],
                                axis=0)
    # Ids that are never a child (other than the root, fixed below) stay NaN.
    births_arr = np.full(largest_child + 1, np.nan, dtype=np.double)
    births = (<np.double_t *> births_arr.data)
    sorted_children = sorted_child_data['child'].copy()
    sorted_lambdas = sorted_child_data['lambda_val'].copy()

    parents = condensed_tree['parent']
    sizes = condensed_tree['child_size']
    lambdas = condensed_tree['lambda_val']

    current_child = -1
    min_lambda = 0

    for row in range(sorted_child_data.shape[0]):
        child = <np.intp_t> sorted_children[row]
        lambda_ = sorted_lambdas[row]

        if child == current_child:
            min_lambda = min(min_lambda, lambda_)
        elif current_child != -1:
            births[current_child] = min_lambda
            current_child = child
            min_lambda = lambda_
        else:
            # Initialize
            current_child = child
            min_lambda = lambda_

    if current_child != -1:
        births[current_child] = min_lambda
    # The root has no incoming edge; it is born at lambda 0.
    births[smallest_cluster] = 0.0

    result_arr = np.zeros(num_clusters, dtype=np.double)

    for i in range(condensed_tree.shape[0]):
        parent = parents[i]
        lambda_ = lambdas[i]
        child_size = sizes[i]
        result_index = parent - smallest_cluster

        result_arr[result_index] += (lambda_ - births[parent]) * child_size

    # Key by plain integer cluster ids. Stacking the ids with the float
    # results would promote the ids to float64.
    cluster_ids = np.arange(smallest_cluster, largest_cluster + 1)
    return dict(zip(cluster_ids.tolist(), result_arr))
242
243
cdef list bfs_from_cluster_tree(np.ndarray tree, np.intp_t bfs_root):
    """Return ``bfs_root`` and all of its descendants in breadth-first order.

    ``tree`` is a structured array with ``parent`` and ``child`` fields. Each
    round collects the children of every node in the current frontier.
    """

    cdef list result
    cdef np.ndarray[np.intp_t, ndim=1] to_process
    cdef np.ndarray tree_children = tree['child']
    cdef np.ndarray tree_parents = tree['parent']

    result = []
    to_process = np.array([bfs_root], dtype=np.intp)

    while to_process.shape[0] > 0:
        result.extend(to_process.tolist())
        # np.isin replaces np.in1d, which is deprecated since NumPy 2.0.
        to_process = tree_children[np.isin(tree_parents, to_process)]

    return result
257
258
cdef max_lambdas(np.ndarray tree):
    """Return an array ``deaths`` indexed by parent id. Each entry holds the
    largest ``lambda_val`` over the edges leaving that parent. Ids that never
    occur as a parent keep 0.0.
    """
    cdef np.ndarray ordered
    cdef np.ndarray[np.intp_t, ndim=1] ordered_parents
    cdef np.ndarray[np.double_t, ndim=1] ordered_lambdas
    cdef np.ndarray[np.double_t, ndim=1] deaths_arr
    cdef np.double_t *deaths
    cdef np.intp_t idx
    cdef np.intp_t n_rows
    cdef np.intp_t group_parent
    cdef np.float64_t group_max
    cdef np.float64_t value

    deaths_arr = np.zeros(tree['parent'].max() + 1, dtype=np.double)
    deaths = (<np.double_t *> deaths_arr.data)

    # Sorting by (parent, lambda_val) places each parent's edges in one run.
    ordered = np.sort(tree[['parent', 'lambda_val']], axis=0)
    ordered_parents = ordered['parent']
    ordered_lambdas = ordered['lambda_val']
    n_rows = ordered.shape[0]

    idx = 0
    while idx < n_rows:
        # Start a new run, then absorb every following row with this parent.
        group_parent = <np.intp_t> ordered_parents[idx]
        group_max = ordered_lambdas[idx]
        idx += 1
        while idx < n_rows and ordered_parents[idx] == group_parent:
            value = ordered_lambdas[idx]
            group_max = max(group_max, value)
            idx += 1
        deaths[group_parent] = group_max

    return deaths_arr
302
303
cdef class TreeUnionFind (object):
    """Union-find (disjoint set) structure over the ids ``0 .. size - 1``.

    Column 0 of ``_data`` stores each id's parent pointer and column 1 its
    rank, used for union by rank. ``find`` compresses paths.
    ``is_component`` flags ids that are still the root of their set.
    """

    cdef np.ndarray _data_arr
    cdef np.intp_t[:, ::1] _data
    cdef np.ndarray is_component

    def __init__(self, size):
        self._data_arr = np.zeros((size, 2), dtype=np.intp)
        self._data_arr.T[0] = np.arange(size)
        self._data = (<np.intp_t[:size, :2:1]> (
            <np.intp_t *> self._data_arr.data))
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        self.is_component = np.ones(size, dtype=bool)

    cdef union_(self, np.intp_t x, np.intp_t y):
        cdef np.intp_t x_root = self.find(x)
        cdef np.intp_t y_root = self.find(y)

        # Already in the same set: nothing to merge. Without this check the
        # tie branch below would inflate the root's rank.
        if x_root == y_root:
            return

        if self._data[x_root, 1] < self._data[y_root, 1]:
            self._data[x_root, 0] = y_root
            self.is_component[x_root] = False
        elif self._data[x_root, 1] > self._data[y_root, 1]:
            self._data[y_root, 0] = x_root
            self.is_component[y_root] = False
        else:
            # Equal ranks: x's root wins the tie and grows by one rank.
            self._data[y_root, 0] = x_root
            self._data[x_root, 1] += 1
            self.is_component[y_root] = False

        return

    cdef find(self, np.intp_t x):
        if self._data[x, 0] != x:
            self._data[x, 0] = self.find(self._data[x, 0])
            self.is_component[x] = False
        return self._data[x, 0]

    cdef np.ndarray[np.intp_t, ndim=1] components(self):
        return self.is_component.nonzero()[0]
339
340
cpdef np.ndarray[np.intp_t, ndim=1] labelling_at_cut(
        np.ndarray linkage,
        np.double_t cut,
        np.intp_t min_cluster_size):
    """Given a single linkage tree and a cut value, return the
    vector of cluster labels at that cut value. This is useful
    for Robust Single Linkage, and extracting DBSCAN results
    from a single HDBSCAN run.

    Parameters
    ----------
    linkage : ndarray (n_samples, 4)
        The single linkage tree in scipy.cluster.hierarchy format.

    cut : double
        The cut value at which to find clusters. Only merges whose
        distance is strictly below ``cut`` are applied.

    min_cluster_size : int
        The minimum cluster size; clusters below this size at
        the cut will be considered noise.

    Returns
    -------
    labels : ndarray (n_samples,)
        The cluster labels for each point in the data set;
        a label of -1 denotes a noise assignment. Surviving clusters are
        numbered 0, 1, 2, ... in ascending order of their union-find
        representative id.
    """

    cdef np.intp_t root
    cdef np.intp_t num_points
    cdef np.intp_t merge
    cdef np.intp_t n
    cdef np.ndarray[np.intp_t, ndim=1] result_arr
    cdef np.ndarray[np.intp_t, ndim=1] set_roots
    cdef TreeUnionFind union_find

    root = 2 * linkage.shape[0]
    num_points = root // 2 + 1

    union_find = TreeUnionFind(<np.intp_t> root + 1)

    # Row ``merge`` of the linkage creates node ``num_points + merge``. Join
    # both of its children to that node when the merge lies below the cut.
    for merge in range(linkage.shape[0]):
        if linkage[merge, 2] < cut:
            union_find.union_(<np.intp_t> linkage[merge, 0], num_points + merge)
            union_find.union_(<np.intp_t> linkage[merge, 1], num_points + merge)

    # Representative of every point's set after all merges.
    set_roots = np.empty(num_points, dtype=np.intp)
    for n in range(num_points):
        set_roots[n] = union_find.find(n)

    # np.unique returns the representatives in ascending order. Sets that
    # are large enough get consecutive labels in that order; the rest -> -1.
    root_ids, position, sizes = np.unique(set_roots, return_inverse=True,
                                          return_counts=True)
    keep = sizes >= min_cluster_size
    label_of_root = np.cumsum(keep) - 1
    label_of_root[~keep] = -1
    result_arr = label_of_root[position.ravel()].astype(np.intp)

    return result_arr
416
417
cdef np.ndarray[np.intp_t, ndim=1] do_labelling(
        np.ndarray tree,
        set clusters,
        dict cluster_label_map,
        np.intp_t allow_single_cluster,
        np.double_t cluster_selection_epsilon,
        np.intp_t match_reference_implementation):
    """Assign a flat cluster label, or -1 for noise, to every point.

    Parameters
    ----------
    tree : numpy recarray
        Condensed tree with ``parent``, ``child`` and ``lambda_val`` fields.
        Points are the ids below the root id (the smallest parent).
    clusters : set
        Selected cluster ids. An edge whose child is selected is not merged
        into its parent's set, so points stay grouped under that cluster.
    cluster_label_map : dict
        Maps each selected cluster id to its output label.
    allow_single_cluster : int
        If non-zero and exactly one cluster is selected, a point whose set
        representative is the root can receive the root's label. This
        requires its lambda to be >= 1 / cluster_selection_epsilon when
        epsilon is non-zero, or else >= the largest lambda on the root's
        edges. Otherwise such points are noise.
    cluster_selection_epsilon : double
        See ``allow_single_cluster``.
    match_reference_implementation : int
        If non-zero, a point under a non-root cluster is labelled only when
        its lambda is strictly greater than the lambda of the edge into
        that cluster.

    Returns
    -------
    labels : ndarray of intp with one entry per point id below the root id.
    """

    cdef np.intp_t root_cluster
    cdef np.intp_t num_nodes
    cdef np.ndarray[np.intp_t, ndim=1] result_arr
    cdef np.ndarray[np.intp_t, ndim=1] parent_array
    cdef np.ndarray[np.intp_t, ndim=1] child_array
    cdef np.ndarray[np.double_t, ndim=1] lambda_array
    cdef np.ndarray[np.double_t, ndim=1] child_lambda
    cdef np.intp_t *result
    cdef TreeUnionFind union_find
    cdef np.intp_t parent
    cdef np.intp_t child
    cdef np.intp_t n
    cdef np.intp_t cluster
    cdef bint single_root_cluster
    cdef np.double_t root_threshold = 0.0

    child_array = tree['child']
    parent_array = tree['parent']
    lambda_array = tree['lambda_val']

    root_cluster = parent_array.min()
    result_arr = np.empty(root_cluster, dtype=np.intp)
    result = (<np.intp_t *> result_arr.data)

    union_find = TreeUnionFind(parent_array.max() + 1)

    for n in range(tree.shape[0]):
        child = child_array[n]
        parent = parent_array[n]
        if child not in clusters:
            union_find.union_(parent, child)

    # Lambda on the edge into each node, indexed by node id. In a tree each
    # non-root node is the child of one edge, so this table replaces a scan
    # of the whole tree per point (previously O(n^2) overall). Ids that
    # never appear as a child (the root) stay NaN, and NaN fails every
    # comparison below.
    num_nodes = max(child_array.max(), parent_array.max()) + 1
    child_lambda = np.full(num_nodes, np.nan, dtype=np.double)
    child_lambda[child_array] = lambda_array

    # The root-membership threshold does not depend on the point; compute it
    # once, and only when it can be used.
    single_root_cluster = len(clusters) == 1 and allow_single_cluster != 0
    if single_root_cluster:
        if cluster_selection_epsilon != 0.0:
            root_threshold = 1 / cluster_selection_epsilon
        else:
            root_threshold = lambda_array[parent_array == root_cluster].max()

    for n in range(root_cluster):
        cluster = union_find.find(n)
        if cluster < root_cluster:
            # Representative id is below the root id (not a cluster): noise.
            result[n] = -1
        elif cluster == root_cluster:
            if single_root_cluster and child_lambda[n] >= root_threshold:
                result[n] = cluster_label_map[cluster]
            else:
                result[n] = -1
        elif match_reference_implementation:
            if child_lambda[n] > child_lambda[cluster]:
                result[n] = cluster_label_map[cluster]
            else:
                result[n] = -1
        else:
            result[n] = cluster_label_map[cluster]

    return result_arr
484
485
cdef get_probabilities(np.ndarray tree, dict cluster_map, np.ndarray labels):
    """Membership strength in [0, 1] for every point.

    ``cluster_map`` maps an output label back to its condensed-tree cluster
    id. Let ``death`` be that cluster's entry in ``max_lambdas(tree)``. For
    a point leaving the tree at ``lambda_val``, the strength is
    ``min(lambda_val, death) / death``. It is 1.0 when ``death`` is 0 or
    ``lambda_val`` is not finite. Noise points (label -1) keep 0.0.
    """

    cdef np.ndarray[np.double_t, ndim=1] result
    cdef np.ndarray[np.double_t, ndim=1] deaths
    cdef np.ndarray[np.double_t, ndim=1] lambda_array
    cdef np.ndarray[np.intp_t, ndim=1] child_array
    cdef np.ndarray[np.intp_t, ndim=1] parent_array
    cdef np.intp_t root_cluster
    cdef np.intp_t edge
    cdef np.intp_t pt
    cdef np.intp_t label
    cdef np.intp_t cluster
    cdef np.double_t death
    cdef np.double_t point_lambda

    child_array = tree['child']
    parent_array = tree['parent']
    lambda_array = tree['lambda_val']

    result = np.zeros(labels.shape[0])
    deaths = max_lambdas(tree)
    root_cluster = parent_array.min()

    for edge in range(tree.shape[0]):
        pt = child_array[edge]
        # Edges into clusters carry no per-point information.
        if pt < root_cluster:
            label = labels[pt]
            if label != -1:
                cluster = cluster_map[label]
                death = deaths[cluster]
                point_lambda = lambda_array[edge]
                if death != 0.0 and np.isfinite(point_lambda):
                    result[pt] = min(point_lambda, death) / death
                else:
                    result[pt] = 1.0

    return result
528
529
cpdef np.ndarray[np.double_t, ndim=1] outlier_scores(np.ndarray tree):
    """Generate GLOSH outlier scores from a condensed tree.

    A point leaves cluster C at ``lambda_val``. Its score is
    ``(lambda_max(C) - lambda_val) / lambda_max(C)``, where ``lambda_max(C)``
    is the largest lambda reached anywhere in C's subtree (C and all of its
    descendant clusters). The score is 0.0 when ``lambda_max(C)`` is 0 or
    the point's lambda is not finite.

    Parameters
    ----------
    tree : numpy recarray
        The condensed tree to generate GLOSH outlier scores from

    Returns
    -------
    outlier_scores : ndarray (n_samples,)
        Outlier scores for each sample point. The larger the score
        the more outlying the point.
    """

    cdef np.ndarray[np.double_t, ndim=1] result
    cdef np.ndarray[np.double_t, ndim=1] deaths
    cdef np.ndarray[np.double_t, ndim=1] lambda_array
    cdef np.ndarray[np.intp_t, ndim=1] child_array
    cdef np.ndarray[np.intp_t, ndim=1] parent_array
    cdef np.intp_t root_cluster
    cdef np.intp_t point
    cdef np.intp_t parent
    cdef np.intp_t cluster
    cdef np.double_t lambda_max

    child_array = tree['child']
    parent_array = tree['parent']
    lambda_array = tree['lambda_val']

    # deaths[c] starts as the largest lambda on edges leaving c directly.
    deaths = max_lambdas(tree)
    root_cluster = parent_array.min()
    result = np.zeros(root_cluster, dtype=np.double)

    # Propagate deaths up the cluster tree so deaths[c] becomes the maximum
    # over c's whole subtree. This assumes a child cluster id is always
    # larger than its parent's id, as produced by condense_tree. Under that
    # assumption, visiting edges in descending child order (a) handles every
    # cluster after all of its descendants and (b) puts every cluster edge
    # (child >= root_cluster) before every point edge, so we can stop at
    # the first point. The previous ascending-parent order usually hit a
    # point edge of the root first and stopped before propagating anything.
    topological_sort_order = np.argsort(child_array)[::-1]

    for n in topological_sort_order:
        cluster = child_array[n]
        if cluster < root_cluster:
            break

        parent = parent_array[n]
        if deaths[cluster] > deaths[parent]:
            deaths[parent] = deaths[cluster]

    for n in range(tree.shape[0]):
        point = child_array[n]
        if point >= root_cluster:
            continue

        cluster = parent_array[n]
        lambda_max = deaths[cluster]

        if lambda_max == 0.0 or not np.isfinite(lambda_array[n]):
            result[point] = 0.0
        else:
            result[point] = (lambda_max - lambda_array[n]) / lambda_max

    return result
591
592
cpdef np.ndarray get_stability_scores(np.ndarray labels, set clusters,
                                      dict stability, np.double_t max_lambda):
    """Normalised stability score for each selected cluster.

    Clusters are taken in ascending id order; the n-th corresponds to the
    output label ``n``. Its score is ``stability[c] / (size * max_lambda)``,
    where ``size`` is the number of points labelled ``n``. The score is
    1.0 when ``max_lambda`` is infinite or zero, or no point carries label
    ``n``.
    """

    cdef np.intp_t cluster_size
    cdef np.intp_t n

    ordered_clusters = sorted(clusters)
    result = np.empty(len(ordered_clusters), dtype=np.double)
    for n in range(len(ordered_clusters)):
        cluster_size = np.sum(labels == n)
        if cluster_size == 0 or max_lambda == 0.0 or np.isinf(max_lambda):
            result[n] = 1.0
        else:
            result[n] = (stability[ordered_clusters[n]]
                         / (cluster_size * max_lambda))

    return result
608
cpdef list recurse_leaf_dfs(np.ndarray cluster_tree, np.intp_t current_node):
    """Return the leaves under ``current_node`` in depth-first pre-order.

    A leaf is a node with no ``child`` entry whose ``parent`` is that node.
    If ``current_node`` has no children, the result is ``[current_node]``.
    An explicit stack is used instead of recursion plus ``sum(lists, [])``.
    The old approach was quadratic in the number of leaves and consumed
    stack depth proportional to tree depth. The output order is unchanged.
    """
    leaves = []
    stack = [current_node]
    while stack:
        node = stack.pop()
        children = cluster_tree['child'][cluster_tree['parent'] == node]
        if len(children) == 0:
            leaves.append(node)
        else:
            # Push in reverse so the first child is explored first.
            stack.extend(children[::-1].tolist())
    return leaves
615
616
cpdef list get_cluster_tree_leaves(np.ndarray cluster_tree):
    """Return the leaf cluster ids of ``cluster_tree``, found by a
    depth-first walk from its root (the smallest parent id). An empty tree
    yields an empty list.
    """
    if not cluster_tree.shape[0]:
        return []
    return recurse_leaf_dfs(cluster_tree, cluster_tree['parent'].min())
622
cpdef np.intp_t traverse_upwards(np.ndarray cluster_tree, np.double_t cluster_selection_epsilon, np.intp_t leaf, np.intp_t allow_single_cluster):
    """Walk up from ``leaf`` to the cluster that replaces it under epsilon.

    Starting at ``leaf``, repeatedly look at the current node's parent:

    * if the parent is the root, return the root when
      ``allow_single_cluster`` is set, otherwise the current node (the
      node just below the root);
    * if the parent's birth distance (1 / its lambda_val) exceeds
      ``cluster_selection_epsilon``, return the parent;
    * otherwise continue from the parent.

    Scalars are extracted explicitly with ``[0]``. Implicitly converting
    1-element arrays to scalars is deprecated in NumPy >= 1.25.
    """
    root = cluster_tree['parent'].min()
    node = leaf

    while True:
        parent = cluster_tree['parent'][cluster_tree['child'] == node][0]
        if parent == root:
            if allow_single_cluster:
                return parent
            else:
                return node  # return node closest to root

        parent_eps = 1 / cluster_tree['lambda_val'][cluster_tree['child'] == parent][0]
        if parent_eps > cluster_selection_epsilon:
            return parent
        node = parent
638
cpdef set epsilon_search(set leaves, np.ndarray cluster_tree, np.double_t cluster_selection_epsilon, np.intp_t allow_single_cluster):
    """Apply the ``cluster_selection_epsilon`` threshold to candidate leaves.

    A leaf whose birth distance (1 / its lambda_val) is not below the
    threshold is kept. Otherwise ``traverse_upwards`` picks an ancestor to
    use instead. All descendants of that ancestor are then recorded, so
    later leaves beneath it are not traversed again.

    Returns the set of selected cluster ids.
    """
    chosen = []
    covered = set()

    for candidate in leaves:
        birth_eps = 1/cluster_tree['lambda_val'][cluster_tree['child'] == candidate][0]
        if not (birth_eps < cluster_selection_epsilon):
            chosen.append(candidate)
        elif candidate not in covered:
            ancestor = traverse_upwards(cluster_tree, cluster_selection_epsilon, candidate, allow_single_cluster)
            chosen.append(ancestor)
            for descendant in bfs_from_cluster_tree(cluster_tree, ancestor):
                if descendant != ancestor:
                    covered.add(descendant)

    return set(chosen)
658
cpdef tuple get_clusters(np.ndarray tree, dict stability,
                         cluster_selection_method='eom',
                         allow_single_cluster=False,
                         match_reference_implementation=False,
                         cluster_selection_epsilon=0.0):
    """Given a tree and stability dict, produce the cluster labels
    (and probabilities) for a flat clustering based on the chosen
    cluster selection method.

    Parameters
    ----------
    tree : numpy recarray
        The condensed tree to extract flat clusters from

    stability : dict
        A dictionary mapping cluster_ids to stability values. With 'eom' it
        is updated in place: a cluster whose children are jointly more
        stable than itself gets the sum of its children's stabilities.

    cluster_selection_method : string, optional (default 'eom')
        The method of selecting clusters. The default is the
        Excess of Mass algorithm specified by 'eom'. The alternate
        option is 'leaf'.

    allow_single_cluster : boolean, optional (default False)
        Whether the root may be selected as the only cluster. With 'leaf',
        this is also what allows the root to be used when no cluster ever
        splits off it.

    match_reference_implementation : boolean, optional (default False)
        Whether to match the reference implementation in how to handle
        certain edge cases.

    cluster_selection_epsilon: float, optional (default 0.0)
        A distance threshold for cluster splits.

    Returns
    -------
    labels : ndarray (n_samples,)
        An integer array of cluster labels, with -1 denoting noise.

    probabilities : ndarray (n_samples,)
        The cluster membership strength of each sample.

    stabilities : ndarray (n_clusters,)
        The cluster coherence strengths of each cluster.

    Raises
    ------
    ValueError
        If ``cluster_selection_method`` is neither 'eom' nor 'leaf'.
    """
    cdef list node_list
    cdef np.ndarray cluster_tree
    cdef np.ndarray child_selection
    cdef dict is_cluster
    # Double precision: stabilities are doubles, and a float32 sum would be
    # truncated before being compared with and stored into them.
    cdef np.double_t subtree_stability
    cdef np.intp_t node
    cdef np.intp_t sub_node
    cdef np.intp_t cluster
    cdef np.ndarray labels
    cdef np.double_t max_lambda

    # Assume clusters are ordered by numeric id equivalent to
    # a topological sort of the tree; This is valid given the
    # current implementation above, so don't change that ... or
    # if you do, change this accordingly!
    node_list = sorted(stability.keys(), reverse=True)
    if not allow_single_cluster:
        node_list = node_list[:-1]  # exclude the root (smallest id)

    cluster_tree = tree[tree['child_size'] > 1]
    is_cluster = {cluster: True for cluster in node_list}
    max_lambda = np.max(tree['lambda_val'])

    if cluster_selection_method == 'eom':
        # Children are visited before parents (descending id). Keep a
        # cluster only if it is at least as stable as its children combined.
        for node in node_list:
            child_selection = (cluster_tree['parent'] == node)
            subtree_stability = np.sum([
                stability[child] for
                child in cluster_tree['child'][child_selection]])
            if subtree_stability > stability[node]:
                is_cluster[node] = False
                stability[node] = subtree_stability
            else:
                for sub_node in bfs_from_cluster_tree(cluster_tree, node):
                    if sub_node != node:
                        is_cluster[sub_node] = False

        if cluster_selection_epsilon != 0.0 and cluster_tree.shape[0] > 0:
            eom_clusters = [c for c in is_cluster if is_cluster[c]]
            selected_clusters = []
            # If EOM selected only the root, skip the epsilon search.
            if (len(eom_clusters) == 1 and
                    eom_clusters[0] == cluster_tree['parent'].min()):
                if allow_single_cluster:
                    selected_clusters = eom_clusters
            else:
                selected_clusters = epsilon_search(
                    set(eom_clusters), cluster_tree,
                    cluster_selection_epsilon, allow_single_cluster)
            for c in is_cluster:
                is_cluster[c] = c in selected_clusters

    elif cluster_selection_method == 'leaf':
        leaves = set(get_cluster_tree_leaves(cluster_tree))
        if len(leaves) == 0:
            # No cluster ever splits off the root, so the root is the only
            # candidate. Use it only when a single cluster is allowed.
            # (Previously this fallback was overwritten by the loop below.)
            if allow_single_cluster:
                selected_clusters = {tree['parent'].min()}
            else:
                selected_clusters = set()
        elif cluster_selection_epsilon != 0.0:
            selected_clusters = epsilon_search(
                leaves, cluster_tree, cluster_selection_epsilon,
                allow_single_cluster)
        else:
            selected_clusters = leaves

        for c in is_cluster:
            is_cluster[c] = c in selected_clusters
    else:
        raise ValueError('Invalid Cluster Selection Method: %s\n'
                         'Should be one of: "eom", "leaf"\n'
                         % cluster_selection_method)

    clusters = {c for c in is_cluster if is_cluster[c]}
    cluster_map = {c: n for n, c in enumerate(sorted(clusters))}
    reverse_cluster_map = {n: c for c, n in cluster_map.items()}

    labels = do_labelling(tree, clusters, cluster_map,
                          allow_single_cluster, cluster_selection_epsilon,
                          match_reference_implementation)
    probs = get_probabilities(tree, reverse_cluster_map, labels)
    stabilities = get_stability_scores(labels, clusters, stability, max_lambda)

    return (labels, probs, stabilities)
792