# $Id$
#
# Copyright (C) 2001-2008 greg Landrum
#
#   @@ All Rights Reserved @@
#  This file is part of the RDKit.
#  The contents are covered by the terms of the BSD license
#  which is included in the file license.txt, found at the root
#  of the RDKit source tree.
#
""" contains the Cluster class for representing hierarchical cluster trees

"""


def cmp(t1, t2):
    """Python 2-style three-way comparison.

    Returns -1 if ``t1 < t2``, 1 if ``t1 > t2`` and 0 otherwise.  Any
    exception raised by ``<``/``>`` (or by the truth test of their result)
    propagates to the caller.
    """
    return (t1 < t2) * -1 or (t1 > t2) * 1


# Absolute tolerance used when comparing cluster metrics.
CMPTOL = 1e-6


class Cluster(object):
    """a class for storing clusters/data

    **General Remarks**

     - It is assumed that the bottom of any cluster hierarchy tree is composed of
       the individual data points which were clustered.

     - Cluster objects store the following pieces of data, most of which are
       accessible via standard Setters/Getters:

       - Children: *Not Settable*, the list of children.  You can add children
         with the _AddChild()_ and _AddChildren()_ methods.

         **Note** this can be of arbitrary length,
         but the current algorithms I have only produce trees with two children
         per cluster

       - Metric: the metric for this cluster (i.e. how far apart its children are)

       - Index: the order in which this cluster was generated

       - Points: *Not Settable*, the list of original points in this cluster
         (calculated recursively from the children)

       - PointsPositions: *Not Settable*, the list of positions of the original
         points in this cluster (calculated recursively from the children)

       - Position: the location of the cluster **Note** for a cluster this
         probably means the location of the average of all the Points which are
         its children.

       - Data: a data field.  This is used with the original points to store their
         data value (i.e. the value we're using to classify)

       - Name: the name of this cluster

    """

    def __init__(self, metric=0.0, children=None, position=None, index=-1, name=None, data=None):
        """Constructor

        **Arguments**

          - metric: the metric for this cluster

          - children: (optional) list of child Clusters.  The list object is
            stored as-is (it is not copied).

          - position: (optional) the cluster's position; defaults to an empty list

          - index: the order in which this cluster was generated (defaults to -1)

          - name: (optional) the cluster's name; when None, _GetName()_ builds
            one from the index

          - data: (optional) a data value

        See the class documentation for the meanings of these fields.

        """
        if children is None:
            children = []
        if position is None:
            position = []
        self.metric = metric
        self.children = children
        self._UpdateLength()
        self.pos = position
        self.index = index
        self.name = name
        # caches filled lazily (or by the child-editing methods) via _GenPoints()
        self._points = None
        self._pointsPositions = None
        self.data = data

    def SetMetric(self, metric):
        self.metric = metric

    def GetMetric(self):
        return self.metric

    def SetIndex(self, index):
        self.index = index

    def GetIndex(self):
        return self.index

    def SetPosition(self, pos):
        self.pos = pos

    def GetPosition(self):
        return self.pos

    def GetPointsPositions(self):
        """Returns the positions of the points in this cluster (cached)."""
        if self._pointsPositions is None:
            self._GenPoints()
        return self._pointsPositions

    def GetPoints(self):
        """Returns the points (leaf Clusters) in this cluster (cached)."""
        if self._points is None:
            self._GenPoints()
        return self._points

    def FindSubtree(self, index):
        """ finds and returns the subtree with a particular index

        Searches depth-first, visiting children in their current list order.

        **Returns**

          the first Cluster whose index matches, or None if there is none

        """
        if index == self.index:
            return self
        for child in self.children:
            res = child.FindSubtree(index)
            if res is not None:
                return res
        return None

    def _GenPoints(self):
        """ Generates (and caches) the _Points_ and _PointsPositions_ lists

        A cluster with no children is its own single point.  Otherwise the
        points are collected recursively from the children, larger children
        first.

        *intended for internal use*

        **Returns**

          the list of points

        """
        if not self.children:
            self._points = [self]
        else:
            children = self.GetChildren()
            # NOTE: GetChildren() returns self.children itself, so this re-sorts
            # it in place.  The sort is stable, so equal-sized children keep the
            # metric order GetChildren() just gave them.
            children.sort(key=len, reverse=True)
            res = []
            for child in children:
                res.extend(child.GetPoints())
            self._points = res
        self._pointsPositions = [x.GetPosition() for x in self._points]
        return self._points

    def AddChild(self, child):
        """Adds a child to our list

        **Arguments**

          - child: a Cluster

        **Note** only this cluster's cached points are refreshed; caches
        already computed by its ancestors are not.

        """
        self.children.append(child)
        # update the length first so it is consistent with the children when
        # the points are regenerated (a former leaf must not cache [self])
        self._UpdateLength()
        self._GenPoints()

    def AddChildren(self, children):
        """Adds a bunch of children to our list

        **Arguments**

          - children: a list of Clusters

        **Note** only this cluster's cached points are refreshed; caches
        already computed by its ancestors are not.

        """
        self.children += children
        self._UpdateLength()
        self._GenPoints()

    def RemoveChild(self, child):
        """Removes a child from our list

        **Arguments**

          - child: a Cluster

        **Note** only this cluster's cached points are refreshed; caches
        already computed by its ancestors are not.

        """
        self.children.remove(child)
        self._UpdateLength()
        # refresh the cache so the removed child's points are not reported
        self._GenPoints()

    def GetChildren(self):
        """Returns the children list, sorted in place by metric."""
        self.children.sort(key=lambda x: x.GetMetric())
        return self.children

    def SetData(self, data):
        self.data = data

    def GetData(self):
        return self.data

    def SetName(self, name):
        self.name = name

    def GetName(self):
        """Returns the name, or 'Cluster(<index>)' if no name was set."""
        if self.name is None:
            return 'Cluster(%d)' % (self.GetIndex())
        return self.name

    def Print(self, level=0, showData=0, offset='\t'):
        """Prints this cluster and, recursively, its children (one per line,
        indented by depth)."""
        if not showData or self.GetData() is None:
            print('%s%s%s Metric: %f' % (' ' * level, self.GetName(), offset, self.GetMetric()))
        else:
            print('%s%s%s Data: %f\t Metric: %f' %
                  (' ' * level, self.GetName(), offset, self.GetData(), self.GetMetric()))

        for child in self.GetChildren():
            child.Print(level=level + 1, showData=showData, offset=offset)

    def Compare(self, other, ignoreExtras=1):
        """ not as choosy as self==other

        Three-way comparison on type, length, (optionally) metric, name,
        position and, recursively, children.

        **Arguments**

          - other: the object to compare against

          - ignoreExtras: if false, metrics differing by more than CMPTOL also
            count as a difference

        **Returns**

          0 when equivalent, otherwise a nonzero value whose sign gives the
          ordering

        """
        tv = cmp(str(type(self)), str(type(other)))
        if tv:
            return tv
        tv = cmp(len(self), len(other))
        if tv:
            return tv

        if not ignoreExtras:
            m1, m2 = self.GetMetric(), other.GetMetric()
            if abs(m1 - m2) > CMPTOL:
                return cmp(m1, m2)

        if cmp(self.GetName(), other.GetName()):
            return cmp(self.GetName(), other.GetName())

        sP = self.GetPosition()
        oP = other.GetPosition()
        try:
            r = cmp(len(sP), len(oP))
        except Exception:
            # positions without a length (e.g. scalars): deliberately skip
            pass
        else:
            if r:
                return r

        try:
            r = cmp(sP, oP)
        except Exception:
            # element-wise types (e.g. arrays) cannot be ordered directly;
            # fall back to the summed difference
            r = sum(sP - oP)
        if r:
            return r

        c1, c2 = self.GetChildren(), other.GetChildren()
        if cmp(len(c1), len(c2)):
            return cmp(len(c1), len(c2))
        for a, b in zip(c1, c2):
            t = a.Compare(b, ignoreExtras=ignoreExtras)
            if t:
                return t

        return 0

    def _UpdateLength(self):
        """ updates our length (this node plus the lengths of all children)

        *intended for internal use*

        """
        self._len = sum(len(c) for c in self.children) + 1

    def IsTerminal(self):
        return self._len <= 1

    def __len__(self):
        """ allows _len(cluster)_ to work

        """
        return self._len

    def __cmp__(self, other):
        """ three-way comparison (-1, 0 or 1) on type, metric, name and children

        **Note** Python 3 never calls __cmp__ implicitly, so this does not make
        _cluster1 == cluster2_ work; call it explicitly instead.

        """
        tv = cmp(str(type(self)), str(type(other)))
        if tv:
            return tv

        m1, m2 = self.GetMetric(), other.GetMetric()
        if abs(m1 - m2) > CMPTOL:
            return cmp(m1, m2)

        if cmp(self.GetName(), other.GetName()):
            return cmp(self.GetName(), other.GetName())

        # lexicographic comparison of the children (Python 2 list semantics):
        # the first differing pair decides, then the shorter list is smaller
        c1, c2 = self.GetChildren(), other.GetChildren()
        for a, b in zip(c1, c2):
            tv = a.__cmp__(b)
            if tv:
                return tv
        return cmp(len(c1), len(c2))


if __name__ == '__main__':  # pragma: nocover
    from rdkit.ML.Cluster import ClusterUtils
    root = Cluster(index=1, metric=1000)
    c1 = Cluster(index=10, metric=100)
    c1.AddChild(Cluster(index=30, metric=10))
    c1.AddChild(Cluster(index=31, metric=10))
    c1.AddChild(Cluster(index=32, metric=10))

    c2 = Cluster(index=11, metric=100)
    c2.AddChild(Cluster(index=40, metric=10))
    c2.AddChild(Cluster(index=41, metric=10))

    root.AddChild(c1)
    root.AddChild(c2)

    nodes = ClusterUtils.GetNodeList(root)

    indices = [x.GetIndex() for x in nodes]
    print('XXX:', indices)