1# $Id$
2#
3# Copyright (C) 2001-2006  greg Landrum and Rational Discovery LLC
4#
5#   @@ All Rights Reserved @@
6#  This file is part of the RDKit.
7#  The contents are covered by the terms of the BSD license
8#  which is included in the file license.txt, found at the root
9#  of the RDKit source tree.
10#
11"""Cluster tree visualization using Sping
12
13"""
14
import collections

import numpy

try:
  from rdkit.sping import pid
  piddle = pid
except ImportError:
  from rdkit.piddle import piddle

from . import ClusterUtils
23
24
class VisOpts(object):
  """ stores visualization options for cluster viewing

    All options are class attributes; the drawing code in this module reads
    them straight from the class (e.g. `VisOpts.lineWidth`).

    **Class attributes**

      - x/yOffset: amount by which the drawing is offset from the edges of the canvas

      - lineColor: default color for drawing the cluster tree

      - lineWidth: the width of the lines used to draw the tree

      - hideColor: color of the lines drawn for hidden (collapsed) parts of the tree

      - hideWidth: width of the hidden lines drawn by _DrawClusterTree()

      - terminalColors: a palette of ten colors (not referenced by the drawing
        functions in this module)

      - nodeRad / nodeColor: size and fill color of the circles drawn for
        centroid nodes

      - highlightColor / highlightRad: highlighting options (not referenced by
        the drawing functions in this module)

  """
  # margins between the drawing and the canvas edges
  xOffset = 20
  yOffset = 20

  # regular tree lines
  lineColor = piddle.Color(0, 0, 0)
  lineWidth = 2

  # lines for the parts of the tree that are not drawn in full
  hideColor = piddle.Color(.8, .8, .8)
  hideWidth = 1.1

  terminalColors = [
    piddle.Color(*rgb)
    for rgb in ((1, 0, 0), (0, 0, 1), (1, 1, 0), (0, .5, .5), (0, .8, 0), (.5, .5, .5),
                (.8, .3, .3), (.3, .3, .8), (.8, .8, .3), (.3, .8, .8))
  ]

  # centroid node markers
  nodeRad = 15
  nodeColor = piddle.Color(1., .4, .4)

  highlightColor = piddle.Color(1., 1., .4)
  highlightRad = 10
51
52
def _scaleMetric(val, power=2, min=1e-4):
  """ maps a node metric onto a logarithmic drawing height

    **Arguments**

      - val: the metric value (converted with float())

      - power: exponent applied to _val_ before taking the log

      - min: threshold (and log denominator); the name shadows the builtin
        but is kept so keyword callers keep working

    **Returns**

      0.0 if val**power is below _min_, otherwise log(val**power / min);
      with a positive _min_ the result is therefore never negative

  """
  scaled = float(val) ** power
  return 0.0 if scaled < min else numpy.log(scaled / min)
60
61
class ClusterRenderer(object):
  """ draws a cluster tree on a Sping/piddle canvas

    Construct the renderer with a canvas and drawing options, then call
    DrawTree() with the root of the cluster tree.  The computed layout is
    stored on the tree nodes themselves, as an (x, y) tuple in a `_drawPos`
    attribute.

    **Arguments** (constructor)

      - canvas: the canvas to draw on

      - size: (width, height) of the drawing area

      - ptColors: (optional) colors indexed by each child's GetData() value,
        used for the vertical line leading down to that child

      - lineWidth: (optional) width of the tree lines; VisOpts.lineWidth is
        used when this is None

      - showIndices: if nonzero, terminal nodes are labelled with their name
        (falling back to their index)

      - showNodes: stored on the renderer but not used by it

      - stopAtCentroids: not implemented; DrawTree() raises
        NotImplementedError when this is set

      - logScale: if > 0, node metrics are passed through _scaleMetric(),
        with this value as the power, before being used as heights

      - tooClose: a node whose first and last children are no more than this
        far apart horizontally is not expanded; a single line in
        VisOpts.hideColor is drawn from it down to the baseline instead.  The
        default (-1) expands every node.

  """

  def __init__(self, canvas, size, ptColors=None, lineWidth=None, showIndices=0, showNodes=1,
               stopAtCentroids=0, logScale=0, tooClose=-1):
    self.canvas = canvas
    self.size = size
    # None means "no per-child colors"; avoids a shared mutable default
    self.ptColors = ptColors if ptColors is not None else []
    self.lineWidth = lineWidth
    self.showIndices = showIndices
    self.showNodes = showNodes
    self.stopAtCentroids = stopAtCentroids
    self.logScale = logScale
    self.tooClose = tooClose

  def _nodeHeight(self, node):
    """ returns the metric of _node_ as a drawing height (log-scaled if requested) """
    if self.logScale > 0:
      return _scaleMetric(node.GetMetric(), self.logScale)
    return float(node.GetMetric())

  def _AssignPointLocations(self, cluster, terminalOffset=4):
    """ sets _drawPos on every terminal point of _cluster_

      Points are spread evenly along x in the order returned by
      cluster.GetPoints(); their y position comes from their height, shifted
      by _terminalOffset_.  Requires self.ySpace, which DrawTree() sets.
    """
    self.pts = cluster.GetPoints()
    self.nPts = len(self.pts)
    if self.nPts > 1:
      self.xSpace = float(self.size[0] - 2 * VisOpts.xOffset) / float(self.nPts - 1)
    else:
      # zero or one point: there is nothing to spread out, and nPts - 1
      # would otherwise cause a division by zero
      self.xSpace = 0.0
    ySize = self.size[1]
    for i, pt in enumerate(self.pts):
      v = self._nodeHeight(pt)
      pt._drawPos = (VisOpts.xOffset + i * self.xSpace,
                     ySize - (v * self.ySpace + VisOpts.yOffset) + terminalOffset)

  def _AssignClusterLocations(self, cluster):
    """ sets _drawPos on the root and every internal node of _cluster_

      Each node is centered horizontally over its children and placed
      vertically according to its height.  Terminal points must already have
      been placed by _AssignPointLocations().
    """
    # breadth-first walk collecting the root and every node that has children
    toDo = [cluster]
    examine = collections.deque(cluster.GetChildren())
    while examine:
      node = examine.popleft()
      children = node.GetChildren()
      if children:
        toDo.append(node)
        examine.extend(child for child in children if not child.IsTerminal())
    # reversed breadth-first order handles deeper nodes before shallower ones,
    # so every child is positioned before its parent
    for node in reversed(toDo):
      childLocs = [child._drawPos[0] for child in node.GetChildren()]
      if childLocs:
        xp = sum(childLocs) / float(len(childLocs))
        yp = self.size[1] - (self._nodeHeight(node) * self.ySpace + VisOpts.yOffset)
        node._drawPos = (xp, yp)

  def _DrawToLimit(self, cluster):
    """ draws the tree below _cluster_, top-down

      Assumes the _drawPos attributes have already been assigned (see
      DrawTree()).  A node whose outermost children are more than
      self.tooClose apart is drawn as a horizontal bar over its children plus
      a vertical line down to each child; otherwise a single line in
      VisOpts.hideColor is drawn from the node to the baseline and its subtree
      is not expanded.
    """
    lineWidth = VisOpts.lineWidth if self.lineWidth is None else self.lineWidth
    canvas = self.canvas

    examine = collections.deque([cluster])
    while examine:
      node = examine.popleft()
      children = node.GetChildren()
      if not children:
        # a childless root (single-point tree) has nothing to connect
        continue
      xp, yp = node._drawPos
      firstX = children[0]._drawPos[0]
      lastX = children[-1]._drawPos[0]
      if abs(lastX - firstX) > self.tooClose:
        # the horizontal bar spanning the children
        canvas.drawLine(firstX, yp, lastX, yp, VisOpts.lineColor, lineWidth)
        # and a vertical line down to each child
        for child in children:
          if self.ptColors and child.GetData() is not None:
            drawColor = self.ptColors[child.GetData()]
          else:
            drawColor = VisOpts.lineColor
          cxp, cyp = child._drawPos
          canvas.drawLine(cxp, yp, cxp, cyp, drawColor, lineWidth)
          if not child.IsTerminal():
            examine.append(child)
          elif self.showIndices and not self.stopAtCentroids:
            try:
              txt = str(child.GetName())
            except Exception:
              # fall back to the index when no usable name is available
              txt = str(child.GetIndex())
            canvas.drawString(txt, cxp - canvas.stringWidth(txt) / 2, cyp)
      else:
        # children too close together to resolve: draw a "hidden" line to the bottom
        canvas.drawLine(xp, yp, xp, self.size[1] - VisOpts.yOffset, VisOpts.hideColor,
                        lineWidth)

  def DrawTree(self, cluster, minHeight=2.0):
    """ lays out and draws the tree rooted at _cluster_

      **Arguments**

        - cluster: root of the cluster tree

        - minHeight: height used for the root when its (possibly log-scaled)
          metric is <= 0

      **Raises**

        - NotImplementedError if stopAtCentroids was set on the renderer

    """
    v = self._nodeHeight(cluster)
    if v <= 0:
      # a non-positive root height would give an infinite or inverted y scale
      v = minHeight
    self.ySpace = float(self.size[1] - 2 * VisOpts.yOffset) / v

    self._AssignPointLocations(cluster)
    self._AssignClusterLocations(cluster)
    if self.stopAtCentroids:
      raise NotImplementedError('stopAtCentroids drawing not yet implemented')
    self._DrawToLimit(cluster)
173
174
def DrawClusterTree(cluster, canvas, size, ptColors=None, lineWidth=None, showIndices=0,
                    showNodes=1, stopAtCentroids=0, logScale=0, tooClose=-1):
  """ handles the work of drawing a cluster tree on a Sping canvas

    **Arguments**

      - cluster: the cluster tree to be drawn

      - canvas:  the Sping canvas on which to draw

      - size: the size of _canvas_

      - ptColors: if this is specified, the _colors_ (indexed by each child's
        _GetData()_ value) will be used to color the lines leading to the
        nodes of the cluster tree.  (color == _pid.Color_)

      - lineWidth: if specified, it will be used for the widths of the lines
        used to draw the tree

      - showIndices: if nonzero, the terminal nodes are labelled

      - showNodes: passed through to ClusterRenderer

      - stopAtCentroids: not supported by ClusterRenderer; its DrawTree()
        raises NotImplementedError when this is set

      - logScale: if > 0, node heights are log-scaled with _scaleMetric()

      - tooClose: nodes whose outermost children are no more than this far
        apart horizontally are collapsed into a single line drawn in
        VisOpts.hideColor

   **Notes**

     - _Canvas_ is neither _save_d nor _flush_ed at the end of this

     - if _ptColors_ is the wrong length for the number of possible terminal
       node types, this will throw an IndexError

     - terminal node types are determined using their _GetData()_ methods

  """
  if ptColors is None:
    # avoid a shared mutable default; an empty list means "no colors"
    ptColors = []
  renderer = ClusterRenderer(canvas, size, ptColors, lineWidth, showIndices, showNodes,
                             stopAtCentroids, logScale, tooClose)
  renderer.DrawTree(cluster)
206
207
def _DrawClusterTree(cluster, canvas, size, ptColors=None, lineWidth=None, showIndices=0,
                     showNodes=1, stopAtCentroids=0, logScale=0, tooClose=-1, minHeight=2.0):
  """ handles the work of drawing a cluster tree on a Sping canvas

    This variant walks the flat node lists produced by ClusterUtils and can
    restrict full drawing to the part of the tree above the centroids.

    **Arguments**

      - cluster: the cluster tree to be drawn

      - canvas:  the Sping canvas on which to draw

      - size: the size of _canvas_

      - ptColors: if this is specified, the _colors_ will be used to color
        the terminal nodes of the cluster tree.  (color == _pid.Color_)

      - lineWidth: if specified, it will be used for the widths of the lines
        used to draw the tree

      - showIndices: if nonzero, nodes and points are labelled with their
        indices

      - showNodes: if nonzero, nodes with an _isCentroid attribute are drawn
        as filled circles labelled with their _clustID

      - stopAtCentroids: if nonzero, the node list comes from
        ClusterUtils.GetNodesDownToCentroids() and the children of nodes
        whose _aboveCentroid is not positive are drawn in VisOpts.hideColor

      - logScale: if > 0, heights are log-scaled with _scaleMetric()

      - tooClose: accepted for signature compatibility; not used here

      - minHeight: height used for the root when its (possibly scaled)
        metric is <= 0, instead of dividing by zero

   **Notes**

     - _Canvas_ is neither _save_d nor _flush_ed at the end of this

     - nothing is drawn if _cluster_ has fewer than two points

     - if _ptColors_ is the wrong length for the number of possible terminal
       node types, this will throw an IndexError

     - terminal node types are determined using their _GetData()_ methods

  """
  if lineWidth is None:
    lineWidth = VisOpts.lineWidth

  def height(node):
    # the node's metric, log-scaled if requested
    if logScale > 0:
      return _scaleMetric(node.GetMetric(), logScale)
    return float(node.GetMetric())

  pts = cluster.GetPoints()
  nPts = len(pts)
  if nPts <= 1:
    return
  xSpace = float(size[0] - 2 * VisOpts.xOffset) / float(nPts - 1)
  rootHeight = height(cluster)
  if rootHeight <= 0:
    # a non-positive root height would divide by zero (or flip the tree)
    rootHeight = minHeight
  ySpace = float(size[1] - 2 * VisOpts.yOffset) / rootHeight

  def yPos(v):
    # canvas y coordinate for a height of v
    return size[1] - (v * ySpace + VisOpts.yOffset)

  # terminal points are spread evenly along x
  for i, pt in enumerate(pts):
    pt._drawPos = (VisOpts.xOffset + i * xSpace, yPos(height(pt)))

  if stopAtCentroids:
    allNodes = ClusterUtils.GetNodesDownToCentroids(cluster)
  else:
    allNodes = ClusterUtils.GetNodeList(cluster)

  centroidTxtColor = piddle.Color(1, .2, .2)
  plainTxtColor = piddle.Color(0, 0, 0)
  # NOTE(review): each parent's x is averaged from its children's _drawPos,
  # so this relies on the ClusterUtils lists putting children before parents
  for node in allNodes:
    children = node.GetChildren()
    if children:
      childLocs = [child._drawPos[0] for child in children]
      node._drawPos = (sum(childLocs) / float(len(childLocs)), yPos(height(node)))
      parentY = node._drawPos[1]
      firstX = children[0]._drawPos[0]
      lastX = children[-1]._drawPos[0]
      if not stopAtCentroids or node._aboveCentroid > 0:
        for child in children:
          if ptColors and child.GetData() is not None:
            drawColor = ptColors[child.GetData()]
          else:
            drawColor = VisOpts.lineColor
          cx, cy = child._drawPos
          if showNodes and hasattr(child, '_isCentroid'):
            # start the line at the edge of the centroid's circle, not its center
            cy = cy - VisOpts.nodeRad / 2
          canvas.drawLine(cx, cy, cx, parentY, drawColor, lineWidth)
        canvas.drawLine(firstX, parentY, lastX, parentY, VisOpts.lineColor, lineWidth)
      else:
        for child in children:
          cx, cy = child._drawPos
          canvas.drawLine(cx, cy, cx, parentY, VisOpts.hideColor, VisOpts.hideWidth)
        canvas.drawLine(firstX, parentY, lastX, parentY, VisOpts.hideColor, VisOpts.hideWidth)

    if showIndices and (not stopAtCentroids or node._aboveCentroid >= 0):
      txt = str(node.GetIndex())
      txtColor = centroidTxtColor if hasattr(node, '_isCentroid') else plainTxtColor
      canvas.drawString(txt, node._drawPos[0] - canvas.stringWidth(txt) / 2,
                        node._drawPos[1] + canvas.fontHeight() / 4, color=txtColor)

    if showNodes and hasattr(node, '_isCentroid'):
      rad = VisOpts.nodeRad
      x, y = node._drawPos
      canvas.drawEllipse(x - rad / 2, y - rad / 2, x + rad / 2, y + rad / 2, piddle.transparent,
                         fillColor=VisOpts.nodeColor)
      txt = str(node._clustID)
      canvas.drawString(txt, x - canvas.stringWidth(txt) / 2, y + canvas.fontHeight() / 4,
                        color=plainTxtColor)

  if showIndices and not stopAtCentroids:
    for pt in pts:
      txt = str(pt.GetIndex())
      canvas.drawString(txt, pt._drawPos[0] - canvas.stringWidth(txt) / 2, pt._drawPos[1])
322
323
def _DrawAndSave(canvas, cluster, fileName, size, ptColors, lineWidth, showIndices,
                 stopAtCentroids, logScale):
  """ shared tail of the ClusterTo*() functions

    Draws _cluster_ on _canvas_ with DrawClusterTree(), saves the canvas if
    _fileName_ is non-empty, and returns the canvas.
  """
  if lineWidth is None:
    lineWidth = VisOpts.lineWidth
  if ptColors is None:
    ptColors = []
  DrawClusterTree(cluster, canvas, size, ptColors=ptColors, lineWidth=lineWidth,
                  showIndices=showIndices, stopAtCentroids=stopAtCentroids, logScale=logScale)
  if fileName:
    canvas.save()
  return canvas


def ClusterToPDF(cluster, fileName, size=(300, 300), ptColors=None, lineWidth=None, showIndices=0,
                 stopAtCentroids=0, logScale=0):
  """ handles the work of drawing a cluster tree to a PDF file

    **Arguments**

      - cluster: the cluster tree to be drawn

      - fileName: the name of the file to be created; if empty, the canvas
        is not saved

      - size: the size of output canvas

      - ptColors: if this is specified, the _colors_ will be used to color
        the terminal nodes of the cluster tree.  (color == _pid.Color_)

      - lineWidth: if specified, it will be used for the widths of the lines
        used to draw the tree

      - showIndices, stopAtCentroids, logScale: passed to DrawClusterTree()

    **Returns**

      the PDF canvas

   **Notes**

     - if _ptColors_ is the wrong length for the number of possible terminal
       node types, this will throw an IndexError

     - terminal node types are determined using their _GetData()_ methods

  """
  try:
    from rdkit.sping.PDF import pidPDF
  except ImportError:
    from rdkit.piddle import piddlePDF as pidPDF

  canvas = pidPDF.PDFCanvas(size, fileName)
  return _DrawAndSave(canvas, cluster, fileName, size, ptColors, lineWidth, showIndices,
                      stopAtCentroids, logScale)


def ClusterToSVG(cluster, fileName, size=(300, 300), ptColors=None, lineWidth=None, showIndices=0,
                 stopAtCentroids=0, logScale=0):
  """ handles the work of drawing a cluster tree to an SVG file

    **Arguments**

      - cluster: the cluster tree to be drawn

      - fileName: the name of the file to be created; if empty, the canvas
        is not saved

      - size: the size of output canvas

      - ptColors: if this is specified, the _colors_ will be used to color
        the terminal nodes of the cluster tree.  (color == _pid.Color_)

      - lineWidth: if specified, it will be used for the widths of the lines
        used to draw the tree

      - showIndices, stopAtCentroids, logScale: passed to DrawClusterTree()

    **Returns**

      the SVG canvas

   **Notes**

     - if _ptColors_ is the wrong length for the number of possible terminal
       node types, this will throw an IndexError

     - terminal node types are determined using their _GetData()_ methods

  """
  try:
    from rdkit.sping.SVG import pidSVG
  except ImportError:
    from rdkit.piddle.piddleSVG import piddleSVG as pidSVG

  canvas = pidSVG.SVGCanvas(size, fileName)
  return _DrawAndSave(canvas, cluster, fileName, size, ptColors, lineWidth, showIndices,
                      stopAtCentroids, logScale)


def ClusterToImg(cluster, fileName, size=(300, 300), ptColors=None, lineWidth=None, showIndices=0,
                 stopAtCentroids=0, logScale=0):
  """ handles the work of drawing a cluster tree to an image file

    **Arguments**

      - cluster: the cluster tree to be drawn

      - fileName: the name of the file to be created; if empty, the canvas
        is not saved

      - size: the size of output canvas

      - ptColors: if this is specified, the _colors_ will be used to color
        the terminal nodes of the cluster tree.  (color == _pid.Color_)

      - lineWidth: if specified, it will be used for the widths of the lines
        used to draw the tree

      - showIndices, stopAtCentroids, logScale: passed to DrawClusterTree()

    **Returns**

      the PIL canvas

   **Notes**

     - The extension on  _fileName_ determines the type of image file created.
       All formats supported by PIL can be used.

     - if _ptColors_ is the wrong length for the number of possible terminal
       node types, this will throw an IndexError

     - terminal node types are determined using their _GetData()_ methods

  """
  try:
    from rdkit.sping.PIL import pidPIL
  except ImportError:
    from rdkit.piddle import piddlePIL as pidPIL

  canvas = pidPIL.PILCanvas(size, fileName)
  return _DrawAndSave(canvas, cluster, fileName, size, ptColors, lineWidth, showIndices,
                      stopAtCentroids, logScale)
451