1# $Id$
2#
3#  Copyright (C) 2001, 2003  greg Landrum and Rational Discovery LLC
4#   All Rights Reserved
5#
6""" Defines the class _QuantTreeNode_, used to represent decision trees with automatic
7 quantization bounds
8
9  _QuantTreeNode_ is derived from _DecTree.DecTreeNode_
10
11"""
12from rdkit.ML.DecTree import DecTree, Tree
13
14
class QuantTreeNode(DecTree.DecTreeNode):
  """ A decision tree node that quantizes its split variable on the fly.

    In addition to the data stored by _DecTree.DecTreeNode_, each node keeps a
    list of quantization bounds (_qBounds_).  During classification, the value
    of the variable this node splits on is mapped to a child index:

      - with bounds: the index of the first bound the value is strictly less
        than, or len(qBounds) if the value is not below any bound

      - without bounds: int(value)

  """

  def __init__(self, *args, **kwargs):
    DecTree.DecTreeNode.__init__(self, *args, **kwargs)
    # quantization bounds for this node's variable (set via SetQuantBounds)
    self.qBounds = []
    # cached len(self.qBounds); SetQuantBounds keeps the two in sync
    self.nBounds = 0

  def ClassifyExample(self, example, appendExamples=0):
    """ Recursively classify an example by running it through the tree

      **Arguments**

        - example: the example to be classified; it is indexed by this
          node's _label_ to get the value of the split variable

        - appendExamples: if this is nonzero then this node (and all children)
          will store the example

      **Returns**

        the classification of _example_ (the _label_ of the terminal node
        that is reached)

      **NOTE:**
        In the interest of speed, I don't use accessor functions
        here.  So if you subclass QuantTreeNode for your own trees, you'll
        have to either include ClassifyExample or avoid changing the names
        of the instance variables this needs.

    """
    if appendExamples:
      self.examples.append(example)
    if self.terminalNode:
      return self.label

    val = example[self.label]
    # nodes built without __init__ (presumably unpickled from versions that
    # predate nBounds) may lack the cached count, so recreate it lazily
    if not hasattr(self, 'nBounds'):
      self.nBounds = len(self.qBounds)
    if self.nBounds:
      # linear scan (not bisect) so that unsorted bounds keep behaving as
      # they always have: pick the first bound that val falls below
      for i, bound in enumerate(self.qBounds):
        if val < bound:
          val = i
          break
      else:
        # not below any bound -> last bin; len() instead of i + 1 so a stale
        # nBounds with an empty qBounds cannot hit an unbound loop variable
        val = len(self.qBounds)
    else:
      val = int(val)
    return self.children[val].ClassifyExample(example, appendExamples=appendExamples)

  def SetQuantBounds(self, qBounds):
    """ Stores a (shallow) copy of _qBounds_ and updates the cached count. """
    self.qBounds = qBounds[:]
    self.nBounds = len(self.qBounds)

  def GetQuantBounds(self):
    """ Returns this node's quantization bounds (the stored list, not a copy). """
    return self.qBounds

  def __cmp__(self, other):
    # Python 2 three-way comparison hook; Python 3 never calls __cmp__ and
    # uses __lt__/__eq__ below instead
    return (self < other) * -1 or (other < self) * 1

  def __lt__(self, other):
    """ Orders nodes by class, then by quantization bounds, then by the
      generic _Tree.TreeNode_ ordering (which recurses into children).

    """
    selfType = str(type(self))
    otherType = str(type(other))
    if selfType != otherType:
      # different classes are ordered by class only; this also avoids reading
      # other.qBounds on objects that do not have it
      return selfType < otherType
    if self.qBounds != other.qBounds:
      return self.qBounds < other.qBounds
    return bool(Tree.TreeNode.__lt__(self, other))

  def __eq__(self, other):
    """ Two nodes are equal when they are of the same class and neither one
      sorts before the other.

    """
    if type(self) is not type(other):
      return False
    return not self < other and not other < self

  def __str__(self):
    """ returns a string representation of the tree

      Each node contributes one line: indentation proportional to its level,
      its name, and its quantization bounds.

      **Note**

        this works recursively

    """
    pieces = ['%s%s %s\n' % ('  ' * self.level, self.name, str(self.qBounds))]
    pieces.extend(str(child) for child in self.children)
    return ''.join(pieces)
99