1import warnings
2
3from collections import namedtuple, deque, defaultdict
4from operator import attrgetter
5from itertools import count
6
7import heapq
8import numpy
9
10import scipy.cluster.hierarchy
11import scipy.spatial.distance
12
13from Orange.distance import Euclidean, PearsonR
14
# Names exported by a star-import of this module.
__all__ = ['HierarchicalClustering']

_undef = object()  # 'no value' sentinel

# Linkage method names. Functions in this module forward the chosen name as
# the ``method`` argument of ``scipy.cluster.hierarchy.linkage``.
SINGLE = "single"
AVERAGE = "average"
COMPLETE = "complete"
WEIGHTED = "weighted"
WARD = "ward"
24
25
def condensedform(X, mode="upper"):
    """
    Return one off-diagonal triangle of a square matrix as a flat array.

    :param X: A square ``(N, N)`` array-like.
    :param str mode:
        ``"upper"`` picks the entries strictly above the diagonal,
        ``"lower"`` the entries strictly below it. Entries come out in the
        order of ``numpy.triu_indices`` / ``numpy.tril_indices``.
    :raises ValueError: If `mode` is neither ``"upper"`` nor ``"lower"``.
    :rtype: numpy.ndarray
    """
    X = numpy.asarray(X)
    assert X.ndim == 2
    assert X.shape[0] == X.shape[1]

    size = X.shape[0]

    if mode == "upper":
        selector = numpy.triu_indices(size, k=1)
    elif mode == "lower":
        selector = numpy.tril_indices(size, k=-1)
    else:
        raise ValueError("invalid mode")
    # `selector` is a (rows, cols) pair of index arrays.
    return X[selector]
40
41
def squareform(X, mode="upper"):
    """
    Expand a condensed (flattened triangle) array into a symmetric matrix.

    This is the inverse of :func:`condensedform`: the values of `X` fill the
    triangle selected by `mode` and are mirrored onto the opposite triangle.
    The diagonal is left at zero.

    :param X:
        1-D array-like of length ``N * (N - 1) / 2`` for some integer ``N``.
    :param str mode:
        ``"upper"`` if `X` lists the entries above the diagonal, ``"lower"``
        if it lists the entries below it (in ``numpy.triu_indices`` /
        ``numpy.tril_indices`` order respectively).
    :raises ValueError: If `mode` is neither ``"upper"`` nor ``"lower"``.
    :return: An ``(N, N)`` array with the dtype of `X`.
    :rtype: numpy.ndarray
    """
    X = numpy.asarray(X)
    k = X.shape[0]
    # Recover N from k = N * (N - 1) / 2: sqrt(2k) lies in (N - 1, N).
    N = int(numpy.ceil(numpy.sqrt(k * 2)))
    assert N * (N - 1) // 2 == k
    matrix = numpy.zeros((N, N), dtype=X.dtype)
    if mode == "upper":
        i, j = numpy.triu_indices(N, k=1)
    elif mode == "lower":
        i, j = numpy.tril_indices(N, k=-1)
    else:
        # Previously an unknown mode fell through and returned all zeros.
        raise ValueError("invalid mode")
    matrix[i, j] = X
    # Mirror the filled triangle onto the other side of the diagonal.
    matrix[j, i] = X
    return matrix
59
60
def data_clustering(data, distance=Euclidean,
                    linkage=AVERAGE):
    """
    Cluster the rows of a dataset hierarchically.

    :param Orange.data.Table data: Dataset to cluster.
    :param Orange.distance.Distance distance:
        A distance measure; it is called as ``distance(data)`` to obtain
        the distance matrix.
    :param str linkage:
        Linkage method name, forwarded to :func:`dist_matrix_clustering`.
    :return: The result of :func:`dist_matrix_clustering` on that matrix.
    """
    return dist_matrix_clustering(distance(data), linkage=linkage)
72
73
def feature_clustering(data, distance=PearsonR,
                       linkage=AVERAGE):
    """
    Cluster the columns of a dataset hierarchically.

    :param Orange.data.Table data: Dataset to cluster.
    :param Orange.distance.Distance distance:
        A distance measure; it is called as ``distance(data, axis=0)`` to
        obtain the distance matrix.
    :param str linkage:
        Linkage method name, forwarded to :func:`dist_matrix_clustering`.
    :return: The result of :func:`dist_matrix_clustering` on that matrix.
    """
    return dist_matrix_clustering(distance(data, axis=0), linkage=linkage)
85
86
def dist_matrix_linkage(matrix, linkage=AVERAGE):
    """
    Compute a scipy linkage matrix from a precomputed distance matrix.

    :param Orange.misc.DistMatrix matrix: Square distance matrix.
    :param str linkage:
        Linkage method name, passed as ``method`` to
        ``scipy.cluster.hierarchy.linkage``.
    :return: Whatever ``scipy.cluster.hierarchy.linkage`` returns.
    """
    # scipy is given the condensed form: the upper triangle, flattened.
    return scipy.cluster.hierarchy.linkage(
        condensedform(matrix), method=linkage)
97
98
def dist_matrix_clustering(matrix, linkage=AVERAGE):
    """
    Build the hierarchical clustering tree for a precomputed distance matrix.

    :param Orange.misc.DistMatrix matrix: Square distance matrix.
    :param str linkage:
        Linkage method name, forwarded to :func:`dist_matrix_linkage`.
    :return: The tree built by :func:`tree_from_linkage`.
    """
    return tree_from_linkage(dist_matrix_linkage(matrix, linkage=linkage))
108
109
def sample_clustering(X, linkage=AVERAGE, metric="euclidean"):
    """
    Cluster raw observations and return the clustering tree.

    :param X: A 2-D array; it is handed to
        ``scipy.cluster.hierarchy.linkage`` unchanged.
    :param str linkage: Linkage method name (scipy's ``method``).
    :param str metric: Distance metric name (scipy's ``metric``).
    :return: The tree built by :func:`tree_from_linkage`.
    """
    assert len(X.shape) == 2
    return tree_from_linkage(
        scipy.cluster.hierarchy.linkage(X, method=linkage, metric=metric))
114
115
class Tree(object):
    """
    A hashable tree node: a `value` plus a tuple of child nodes (`branches`).

    Both parts are stored in private slots and exposed through read-only
    properties. Iterating a node yields ``value`` then ``branches``;
    equality and ordering compare that pair.
    """
    __slots__ = ("__value", "__branches", "__hash")

    def __init__(self, value, branches=()):
        """
        :param value: Payload of the node (must be hashable).
        :param tuple branches: Child nodes.
        :raises TypeError: If `branches` is not a tuple.
        """
        if not isinstance(branches, tuple):
            raise TypeError()
        self.__value = value
        self.__branches = branches
        # Computed once here; __hash__ only hands back the stored number.
        self.__hash = hash((value, branches))

    @property
    def value(self):
        return self.__value

    @property
    def branches(self):
        return self.__branches

    @property
    def is_leaf(self):
        return not self.branches

    @property
    def left(self):
        return self.branches[0]

    @property
    def right(self):
        return self.branches[-1]

    def __iter__(self):
        yield self.__value
        yield self.__branches

    def __hash__(self):
        return self.__hash

    def __eq__(self, other):
        if not isinstance(other, Tree):
            return False
        return tuple(self) == tuple(other)

    def __lt__(self, other):
        if isinstance(other, Tree):
            return tuple(self) < tuple(other)
        return NotImplemented

    def __le__(self, other):
        if isinstance(other, Tree):
            return tuple(self) <= tuple(other)
        return NotImplemented

    def __getnewargs__(self):
        # Constructor arguments used when the node is pickled or copied.
        return tuple(self)

    def __repr__(self):
        template = "{0.__name__}(value={1!r}, branches={2!r})"
        return template.format(type(self), self.value, self.branches)
167
168
# Plain record types for the node payloads. The classes defined further down
# rebind these two names to subclasses that add `first` / `last`.
ClusterData = namedtuple("Cluster", ["range", "height"])
SingletonData = namedtuple("Singleton", ["range", "height", "index"])


class _Ranged:
    """Mixin exposing the two end points of a ``range`` attribute."""
    # The mixin needs empty slots of its own: without them every subclass
    # instance still gets a ``__dict__`` and ``__slots__ = ()`` on the
    # subclasses below has no effect.
    __slots__ = ()

    @property
    def first(self):
        """First element of ``self.range``."""
        return self.range[0]

    @property
    def last(self):
        """Last element of ``self.range``."""
        return self.range[-1]


class ClusterData(ClusterData, _Ranged):
    """Payload of an internal node: ``(range, height)``."""
    __slots__ = ()


class SingletonData(SingletonData, _Ranged):
    """Payload of a leaf node: ``(range, height, index)``."""
    __slots__ = ()
190
191
def tree_from_linkage(linkage):
    """
    Build a :class:`Tree` from a clustering encoded as a linkage matrix.

    Leaves carry a :class:`SingletonData` value (``height`` 0.0 and the
    observation ``index``); merges carry a :class:`ClusterData` value with
    the merge ``height``. In the returned tree every node's ``range`` is the
    half-open span of leaf positions it covers, numbered in post-order.

    :param numpy.ndarray linkage:
        Linkage matrix; validated with
        ``scipy.cluster.hierarchy.is_valid_linkage(..., throw=True)``.
    :rtype: Tree

    .. seealso:: scipy.cluster.hierarchy.linkage

    """
    scipy.cluster.hierarchy.is_valid_linkage(
        linkage, throw=True, name="linkage")
    n_merges, _ = linkage.shape
    n_leaves = n_merges + 1

    # First pass: assemble the tree row by row. Leaf ranges assigned here
    # follow the order in which observations show up in `linkage`; they are
    # only provisional and are renumbered in the second pass.
    clusters = {}
    seen = []

    def subtree(cid):
        # Ids below `n_leaves` name observations; anything else refers to a
        # cluster formed by an earlier row.
        if cid < n_leaves:
            pos = len(seen)
            seen.append(cid)
            return Tree(SingletonData(range=(pos, pos + 1),
                                      height=0.0, index=int(cid)),
                        ())
        return clusters[cid]

    for row, (c1, c2, dist, _) in enumerate(linkage):
        left = subtree(c1)
        right = subtree(c2)
        span = (left.value.first, right.value.last)
        clusters[n_leaves + row] = Tree(
            ClusterData(range=span, height=dist), (left, right))

    root = clusters[2 * n_leaves - 2]

    # Second pass: rebuild bottom-up so that leaf positions run 0, 1, 2, ...
    # in traversal order and every cluster spans exactly its own leaves.
    rebuilt = {}
    position = 0
    for node in postorder(root):
        if node.is_leaf:
            value = node.value._replace(range=(position, position + 1))
            rebuilt[node] = Tree(value, ())
            position += 1
            continue

        left, right = rebuilt[node.left].value, rebuilt[node.right].value
        assert left.first < right.first

        merged = Tree(
            node.value._replace(range=(left.range[0], right.range[1])),
            tuple(rebuilt[child] for child in node.branches)
        )
        assert merged.value.range[0] <= merged.value.range[-1]
        assert (left.first == merged.value.first and
                right.last == merged.value.last)
        assert merged.value.first < right.first
        assert merged.value.last > left.last
        rebuilt[node] = merged

    return rebuilt[root]
251
252
def linkage_from_tree(tree: Tree) -> numpy.ndarray:
    """
    Encode a binary clustering tree as a scipy-style linkage matrix.

    Internal nodes are written in post-order, one row each:
    ``[left id, right id, height, count]``. A leaf's id is its
    ``value.index``; internal nodes get consecutive ids starting at the
    number of leaves, in the order their rows are written. ``count`` is the
    number of leaves under the merged node.

    :param Tree tree: Root of a binary clustering tree.
    :raises ValueError:
        If an internal node does not have exactly two branches.
    :return: A ``(number of leaves - 1, 4)`` float array.
    :rtype: numpy.ndarray
    """
    n_leaves = sum(1 for node in preorder(tree) if node.is_leaf)

    Z = numpy.zeros((n_leaves - 1, 4), float)
    node_id = {}
    node_size = {}
    row = 0
    for node in postorder(tree):
        if node.is_leaf:
            node_id[node] = node.value.index
            node_size[node] = 1
            continue
        if len(node.branches) != 2:
            raise ValueError(
                "linkage_from_tree requires a binary tree; found a node "
                "with {} branches".format(len(node.branches)))
        # Post-order guarantees both children were handled already.
        left, right = node.branches
        size = node_size[left] + node_size[right]
        # The count column used to be left at 0, which scipy helpers that
        # read it (e.g. `to_tree`) reject.
        Z[row] = [node_id[left], node_id[right], node.value.height, size]
        node_id[node] = n_leaves + row
        node_size[node] = size
        row += 1
    assert row == Z.shape[0]
    return Z
273
274
def postorder(tree, branches=attrgetter("branches")):
    """
    Iterate over the nodes of a tree in post-order (children before parent).

    The traversal is iterative, so deep trees do not hit the recursion
    limit. Nodes need not be hashable, and equal subtrees occurring more
    than once are each traversed in full.

    :param tree: Root node.
    :param branches:
        Callable returning a node's children as a reversible sequence; a
        falsy result marks the node as having no children.
    :return: A generator over the nodes.
    """
    # Each entry is (node, expanded). `expanded` is True once the node's
    # children have been pushed, i.e. the node is met again on the way up.
    stack = [(tree, False)]
    while stack:
        node, expanded = stack.pop()
        if expanded:
            yield node
            continue
        children = branches(node)
        if not children:
            yield node
            continue
        stack.append((node, True))
        # Push in reverse so the first child is popped (visited) first.
        stack.extend((child, False) for child in reversed(children))
295
296
def preorder(tree, branches=attrgetter("branches")):
    """
    Iterate over the nodes of a tree in pre-order (parent before children).

    :param tree: Root node.
    :param branches:
        Callable returning a node's children as a reversible sequence; a
        falsy result marks the node as having no children.
    :return: A generator over the nodes.
    """
    pending = [tree]
    while pending:
        node = pending.pop()
        yield node
        children = branches(node)
        if not children:
            continue
        # Reversed push keeps the first child on top of the stack.
        pending.extend(reversed(children))
305
306
def leaves(tree, branches=attrgetter("branches")):
    """
    Iterate over the leaf nodes of a tree, in post-order.

    `branches` only steers the traversal (it is passed on to
    :func:`postorder`); whether a node is a leaf is decided by the node's
    own ``is_leaf`` attribute.
    """
    for node in postorder(tree, branches):
        if node.is_leaf:
            yield node
313
314
def prune(cluster, level=None, height=None, condition=None):
    """
    Return a pruned copy of the clustering rooted at ``cluster``.

    A node selected by any of the criteria loses its branches: a leaf is
    kept as is, an internal node is replaced by a branchless :class:`Tree`
    with the same value. All other nodes are rebuilt on top of their
    (possibly pruned) children.

    :param Tree cluster: Cluster root node to prune.
    :param int level:
        If not `None`, prune every node whose depth (root is 0, see
        :func:`cluster_depths`) is at least `level`.
    :param float height:
        If not `None`, prune every node whose height is lower than or
        equal to `height`.
    :param function condition:
        If not `None`, `condition` must be a `Tree -> bool` function
        evaluating to `True` if the cluster should be pruned.
    :raises ValueError:
        If none of `level`, `height` and `condition` is supplied.
    :rtype: Tree

    """
    if level is None and height is None and condition is None:
        raise ValueError("At least one pruning argument must be supplied")

    # Active criteria, tried in the order level, height, condition.
    criteria = []

    if level is not None:
        depth_of = cluster_depths(cluster)
        criteria.append(lambda cl: depth_of[cl] >= level)

    if height is not None:
        criteria.append(lambda cl: cl.value.height <= height)

    if condition is not None:
        criteria.append(condition)

    pruned = {}

    # Post-order: children are rebuilt before the parent that reuses them.
    for node in postorder(cluster):
        if any(test(node) for test in criteria):
            pruned[node] = node if node.is_leaf else Tree(node.value, ())
        else:
            kept = tuple(pruned[child] for child in node.branches)
            pruned[node] = Tree(node.value, kept)
    return pruned[cluster]
362
363
def cluster_depths(cluster):
    """
    Map every node under `cluster` to its depth in the tree.

    The root has depth 0 and each branch is one level deeper than its
    parent.

    :param Tree cluster: Root cluster
    :rtype: :class:`dict`

    """
    depths = {cluster: 0}
    # Pre-order visits a parent before its children, so the parent's depth
    # is always known by the time its branches are recorded.
    for node in preorder(cluster):
        child_depth = depths[node] + 1
        for child in node.branches:
            depths[child] = child_depth
    return depths
378
379
def top_clusters(tree, k):
    """
    Return the `k` topmost clusters of a hierarchical clustering.

    Starting from the root, the non-leaf cluster with the greatest height
    is repeatedly replaced by its left and right branches until `k`
    clusters are held or only leaves remain. The result is in heap order,
    not sorted.

    :param Tree tree: Root cluster.
    :param int k: Number of top clusters.

    :rtype: list of :class:`Tree` instances
    """
    def entry(node):
        # Non-leaves sort before leaves; among them, greater height first.
        return (node.is_leaf, -node.value.height), node

    heap = [entry(tree)]

    while len(heap) < k:
        candidate = heap[0][1]  # peek at the smallest entry
        if candidate.is_leaf:
            # A leaf on top means nothing is left to split.
            assert all(node.is_leaf for _, node in heap)
            break
        _, split = heapq.heappop(heap)
        for child in (split.left, split.right):
            heapq.heappush(heap, entry(child))

    return [node for _, node in heap]
405
406
def optimal_leaf_ordering(
        tree: Tree, distances: numpy.ndarray, progress_callback=_undef
) -> Tree:
    """
    Reorder the leaves of a clustering tree.

    The tree is converted to a linkage matrix, reordered by
    ``scipy.cluster.hierarchy.optimal_leaf_ordering`` and converted back
    into a new tree.

    :param Tree tree:
        Binary hierarchical clustering tree.
    :param numpy.ndarray distances:
        A (N, N) numpy.ndarray of distances that were used to compute
        the clustering.
    :param progress_callback:
        Deprecated and ignored; passing any value emits a `FutureWarning`.
    :rtype: Tree

    .. seealso:: scipy.cluster.hierarchy.optimal_leaf_ordering
    """
    callback_given = progress_callback is not _undef
    if callback_given:
        warnings.warn(
            "'progress_callback' parameter is deprecated and ignored. "
            "Passing it will raise an error in the future.",
            FutureWarning, stacklevel=2
        )
    linkage = linkage_from_tree(tree)
    # scipy takes the distances in condensed (flattened upper triangle) form.
    condensed = condensedform(numpy.asarray(distances))
    return tree_from_linkage(
        scipy.cluster.hierarchy.optimal_leaf_ordering(linkage, condensed))
431
432
class HierarchicalClustering:
    """
    Hierarchical clustering of a precomputed distance matrix, cut into flat
    clusters.

    After :meth:`fit`, ``tree`` holds the clustering tree and ``labels``
    an array with one cluster number per clustered item.
    """

    def __init__(self, n_clusters=2, linkage=AVERAGE):
        """
        :param int n_clusters: Number of top clusters to cut the tree into.
        :param str linkage: Linkage method name.
        """
        self.n_clusters = n_clusters
        self.linkage = linkage

    def fit(self, X):
        """
        Cluster the distance matrix `X`; store ``tree`` and ``labels``.

        Each item is labelled with the position of its cluster in the list
        returned by :func:`top_clusters`.
        """
        self.tree = dist_matrix_clustering(X, linkage=self.linkage)
        top = top_clusters(self.tree, self.n_clusters)
        # One entry per leaf; the root's range ends at the number of leaves.
        assignment = numpy.zeros(self.tree.value.last)

        for label, subtree in enumerate(top):
            members = [leaf.value.index for leaf in leaves(subtree)]
            assignment[members] = label

        self.labels = assignment

    def fit_predict(self, X, y=None):
        """Run :meth:`fit` on `X` and return ``labels``; `y` is unused."""
        self.fit(X)
        return self.labels
452