1# $Id$
2#
3# Copyright (C) 2001-2008  greg Landrum
4#
5#   @@ All Rights Reserved @@
6#  This file is part of the RDKit.
7#  The contents are covered by the terms of the BSD license
8#  which is included in the file license.txt, found at the root
9#  of the RDKit source tree.
10#
11""" contains the Cluster class for representing hierarchical cluster trees
12
13"""
14
15
def cmp(t1, t2):
    """Three-way comparison helper (stand-in for the Python 2 builtin).

    Returns -1 if _t1_ < _t2_, 1 if _t1_ > _t2_ and 0 otherwise.
    """
    if t1 < t2:
        return -1
    if t1 > t2:
        return 1
    return 0
18
19
# Absolute tolerance used when comparing cluster metrics: two metrics whose
# difference does not exceed this value are treated as equal by
# Cluster.Compare() and Cluster.__cmp__().
CMPTOL = 1e-6
21
22
class Cluster(object):
    """a class for storing clusters/data

      **General Remarks**

       - It is assumed that the bottom of any cluster hierarchy tree is composed of
         the individual data points which were clustered.

       - Cluster objects store the following pieces of data, most are
          accessible via standard Setters/Getters:

         - Children: *Not Settable*, the list of children.  You can add children
           with the _AddChild()_ and _AddChildren()_ methods.

           **Note** this can be of arbitrary length,
           but the current algorithms I have only produce trees with two children
           per cluster

         - Metric: the metric for this cluster (i.e. how far apart its children are)

         - Index: the order in which this cluster was generated

         - Points: *Not Settable*, the list of original points in this cluster
              (calculated recursively from the children)

         - PointsPositions: *Not Settable*, the list of positions of the original
            points in this cluster (calculated recursively from the children)

         - Position: the location of the cluster **Note** for a cluster this
           probably means the location of the average of all the Points which are
           its children.

         - Data: a data field.  This is used with the original points to store their
           data value (i.e. the value we're using to classify)

         - Name: the name of this cluster

       - The length of a cluster (_len(cluster)_) is 1 plus the summed lengths of
         its children.  The length and the Points/PointsPositions lists are cached
         on each node and are only refreshed on the node whose _AddChild()_,
         _AddChildren()_ or _RemoveChild()_ method is called; a cluster holds no
         reference to its parent, so build trees bottom-up.

    """

    def __init__(self, metric=0.0, children=None, position=None, index=-1, name=None, data=None):
        """Constructor

          **Arguments**

            - metric: the metric for this cluster

            - children: (optional) list of child Clusters; the list is stored
              as-is (not copied).  Defaults to a new empty list.

            - position: (optional) the location of the cluster.  Defaults to a
              new empty list.

            - index: the index of this cluster

            - name: (optional) the name of this cluster; if not provided,
              _GetName()_ builds one from the index

            - data: (optional) the data value for this cluster

        """
        if children is None:
            children = []
        if position is None:
            position = []
        self.metric = metric
        self.children = children
        self._UpdateLength()
        self.pos = position
        self.index = index
        self.name = name
        # caches filled in lazily by _GenPoints()
        self._points = None
        self._pointsPositions = None
        self.data = data

    def SetMetric(self, metric):
        self.metric = metric

    def GetMetric(self):
        return self.metric

    def SetIndex(self, index):
        self.index = index

    def GetIndex(self):
        return self.index

    def SetPosition(self, pos):
        self.pos = pos

    def GetPosition(self):
        return self.pos

    def GetPointsPositions(self):
        """ returns the (cached) list of positions of the points in this cluster
        """
        if self._pointsPositions is None:
            self._GenPoints()
        return self._pointsPositions

    def GetPoints(self):
        """ returns the (cached) list of terminal clusters below this cluster
        """
        if self._points is None:
            self._GenPoints()
        return self._points

    def FindSubtree(self, index):
        """ finds and returns the subtree with a particular index

          **Returns**

            the first cluster found (depth first, starting with this one) whose
            index equals _index_, or None if there is no such cluster.

        """
        if index == self.index:
            return self
        for child in self.children:
            res = child.FindSubtree(index)
            # explicit None test: do not rely on the truthiness of a Cluster,
            # which is derived from __len__
            if res is not None:
                return res
        return None

    def _GenPoints(self):
        """ Generates the _Points_ and _PointsPositions_ lists

         *intended for internal use*

         **Notes**

           - relies on the cached length being current, so _UpdateLength()
             must be called first whenever the children have changed.

           - as a side effect the list of children is reordered in place:
             largest child first, ties ordered by increasing metric.

         **Returns**

           the list of points

        """
        if len(self) == 1:
            # a terminal cluster is its own (only) point
            self._points = [self]
            self._pointsPositions = [self.GetPosition()]
        else:
            # GetChildren() sorts by metric; the stable sort on size that
            # follows keeps that order among children of equal size.
            children = self.GetChildren()
            children.sort(key=len, reverse=True)
            res = []
            for child in children:
                res += child.GetPoints()
            self._points = res
            self._pointsPositions = [x.GetPosition() for x in res]
        return self._points

    def AddChild(self, child):
        """Adds a child to our list

          **Arguments**

            - child: a Cluster

        """
        self.children.append(child)
        # the length has to be refreshed *before* the points are regenerated:
        # _GenPoints() uses it to decide whether this cluster is terminal, and
        # a former leaf would otherwise keep reporting itself as its only point.
        self._UpdateLength()
        self._GenPoints()

    def AddChildren(self, children):
        """Adds a bunch of children to our list

          **Arguments**

            - children: a list of Clusters

        """
        self.children += children
        # see AddChild() for why the length is updated first
        self._UpdateLength()
        self._GenPoints()

    def RemoveChild(self, child):
        """Removes a child from our list

          **Arguments**

            - child: a Cluster

          **Notes**

            - raises ValueError (from list.remove) if _child_ is not one of
              our children

        """
        self.children.remove(child)
        self._UpdateLength()
        # drop the cached point lists so that they are rebuilt, without the
        # removed child's points, the next time they are requested
        self._points = None
        self._pointsPositions = None

    def GetChildren(self):
        """ returns the list of children, sorted (in place) by increasing metric
        """
        self.children.sort(key=lambda x: x.GetMetric())
        return self.children

    def SetData(self, data):
        self.data = data

    def GetData(self):
        return self.data

    def SetName(self, name):
        self.name = name

    def GetName(self):
        """ returns the name; if none was set, one is built from the index
        """
        if self.name is None:
            return 'Cluster(%d)' % (self.GetIndex())
        else:
            return self.name

    def Print(self, level=0, showData=0, offset='\t'):
        """ prints this cluster and, recursively, its children to stdout

          **Arguments**

            - level: the nesting depth; each level indents by two spaces

            - showData: if true, the data value is printed as well (for clusters
              which have one)

            - offset: the string separating the name from the values

        """
        if not showData or self.GetData() is None:
            print('%s%s%s Metric: %f' % ('  ' * level, self.GetName(), offset, self.GetMetric()))
        else:
            print('%s%s%s Data: %f\t Metric: %f' %
                  ('  ' * level, self.GetName(), offset, self.GetData(), self.GetMetric()))

        for child in self.GetChildren():
            child.Print(level=level + 1, showData=showData, offset=offset)

    def Compare(self, other, ignoreExtras=1):
        """ not as choosy as self==other

          **Arguments**

            - other: the Cluster to compare to

            - ignoreExtras: if true (the default) only the types, the lengths
              and the children (recursively) are compared; otherwise metrics
              (within CMPTOL), names and positions are compared as well.

          **Returns**

            0 if the clusters match, a nonzero value otherwise

        """
        tv = cmp(str(type(self)), str(type(other)))
        if tv:
            return tv
        tv = cmp(len(self), len(other))
        if tv:
            return tv

        if not ignoreExtras:
            m1, m2 = self.GetMetric(), other.GetMetric()
            if abs(m1 - m2) > CMPTOL:
                return cmp(m1, m2)

            tv = cmp(self.GetName(), other.GetName())
            if tv:
                return tv

            sP = self.GetPosition()
            oP = other.GetPosition()
            # positions without a length are simply not compared by size
            try:
                r = cmp(len(sP), len(oP))
            except Exception:
                pass
            else:
                if r:
                    return r

            # fall back to the summed difference for positions which cannot be
            # ordered directly (presumably array-like positions - TODO confirm)
            try:
                r = cmp(sP, oP)
            except Exception:
                r = sum(sP - oP)
            if r:
                return r

        c1, c2 = self.GetChildren(), other.GetChildren()
        tv = cmp(len(c1), len(c2))
        if tv:
            return tv
        for mine, theirs in zip(c1, c2):
            tv = mine.Compare(theirs, ignoreExtras=ignoreExtras)
            if tv:
                return tv

        return 0

    def _UpdateLength(self):
        """ updates our length

         *intended for internal use*

         the length is 1 (for this cluster) plus the lengths of all children

        """
        self._len = sum(len(c) for c in self.children) + 1

    def IsTerminal(self):
        """ returns whether or not this cluster has no children (length <= 1)
        """
        return self._len <= 1

    def __len__(self):
        """ allows _len(cluster)_ to work

        """
        return self._len

    def __cmp__(self, other):
        """ three-way comparison of two clusters: type, metric (within CMPTOL),
        name and then, recursively, the children

          **Note** Python 3 does not use __cmp__ for _cluster1 == cluster2_;
          there this method only has an effect when it is called explicitly.

        """
        # type objects and lists of Clusters cannot be ordered with < and >
        # under Python 3, so compare the type names and walk the children.
        tv = cmp(str(type(self)), str(type(other)))
        if tv:
            return tv

        m1, m2 = self.GetMetric(), other.GetMetric()
        if abs(m1 - m2) > CMPTOL:
            return cmp(m1, m2)

        tv = cmp(self.GetName(), other.GetName())
        if tv:
            return tv

        c1, c2 = self.GetChildren(), other.GetChildren()
        # same semantics as comparing the two lists: the first differing pair
        # decides, otherwise the shorter list sorts first
        for mine, theirs in zip(c1, c2):
            tv = mine.__cmp__(theirs)
            if tv:
                return tv
        return cmp(len(c1), len(c2))
295
296
if __name__ == '__main__':  # pragma: nocover
    from rdkit.ML.Cluster import ClusterUtils

    # Demo: assemble a small two-level tree and print the indices of its
    # nodes in the order ClusterUtils.GetNodeList() returns them.
    root = Cluster(index=1, metric=1000)
    branches = []
    for branchIdx, leafIdxs in ((10, (30, 31, 32)), (11, (40, 41))):
        branch = Cluster(index=branchIdx, metric=100)
        for leafIdx in leafIdxs:
            branch.AddChild(Cluster(index=leafIdx, metric=10))
        branches.append(branch)

    # attach the branches only once they are complete
    for branch in branches:
        root.AddChild(branch)

    indices = [node.GetIndex() for node in ClusterUtils.GetNodeList(root)]
    print('XXX:', indices)
316