r"""Rules for classification and regression trees.

Tree visualisations usually need to show the rules of nodes, these classes make
merging these rules simple (otherwise you have repeating rules e.g. `age < 3`
and `age < 2` which can be merged into `age < 2`).

Subclasses of the `Rule` class should provide a nice interface to merge rules
together through the `merge_with` method. Of course, this should not be forced
where it doesn't make sense e.g. merging a discrete rule (e.g.
:math:`x \in \{red, blue, green\}`) and a continuous rule (e.g.
:math:`x \leq 5`).

"""
import warnings


class Rule:
    """The base Rule class for tree rules."""

    def merge_with(self, rule):
        """Merge the current rule with the given rule.

        Parameters
        ----------
        rule : Rule

        Returns
        -------
        Rule

        Raises
        ------
        NotImplementedError
            Always, in the base class; subclasses provide the implementation.

        """
        raise NotImplementedError()

    @property
    def description(self):
        """Text form of the rule; the base class returns ``str(self)``."""
        return str(self)


class DiscreteRule(Rule):
    """Discrete rule class for handling Indicator rules.

    Parameters
    ----------
    attr_name : str
    equals : bool
        Should indicate whether or not the rule equals the value or not.
    value : object

    Examples
    --------
    >>> print(DiscreteRule('age', True, 30))
    age = 30

    >>> print(DiscreteRule('name', False, 'John'))
    name ≠ John

    Notes
    -----
    .. note:: Merging discrete rules is currently not implemented, the new rule
        is simply returned and a warning is issued.

    """

    def __init__(self, attr_name, equals, value):
        self.attr_name = attr_name
        self.equals = equals
        self.value = value

    def merge_with(self, rule):
        """Return `rule` unchanged and issue a warning.

        Parameters
        ----------
        rule : Rule

        Returns
        -------
        Rule
            The rule that was passed in.

        """
        # It does not make sense to merge discrete rules, since they can only
        # be eq or not eq.
        warnings.warn('Merged two discrete rules `%s` and `%s`' % (self, rule))
        return rule

    @property
    def description(self):
        """The rule without the attribute name, e.g. ``= 30``."""
        return '{} {}'.format('=' if self.equals else '≠', self.value)

    def __str__(self):
        return '{} {} {}'.format(
            self.attr_name, '=' if self.equals else '≠', self.value)

    def __repr__(self):
        # `%r` keeps the quotes around string values, so the output reads as
        # the constructor call that created the rule.
        return "DiscreteRule(attr_name='%s', equals=%s, value=%r)" % (
            self.attr_name, self.equals, self.value)


class ContinuousRule(Rule):
    """Continuous rule class for handling numeric rules.

    Parameters
    ----------
    attr_name : str
    greater : bool
        Should indicate whether the variable must be greater than the value.
    value : int or float
    inclusive : bool, optional
        Should the variable range include the value or not
        (LT <> LTE | GT <> GTE). Default is False.

    Examples
    --------
    >>> print(ContinuousRule('age', False, 30, inclusive=True))
    age ≤ 30.000

    >>> print(ContinuousRule('age', True, 30))
    age > 30.000

    Notes
    -----
    .. note:: Continuous rules can currently only be merged with other
        continuous rules.

    .. note:: The text form of a single rule only distinguishes ``>`` from
        ``≤``; the `inclusive` flag is reflected in the brackets printed by
        `IntervalRule`.

    """

    def __init__(self, attr_name, greater, value, inclusive=False):
        self.attr_name = attr_name
        self.greater = greater
        self.value = value
        self.inclusive = inclusive

    def merge_with(self, rule):
        """Merge this rule with another continuous rule.

        Two rules pointing in the same direction collapse into the tighter of
        the two; two rules pointing in opposite directions form an interval.

        Parameters
        ----------
        rule : ContinuousRule

        Returns
        -------
        ContinuousRule or IntervalRule

        Raises
        ------
        NotImplementedError
            If `rule` is not a `ContinuousRule`.

        """
        if not isinstance(rule, ContinuousRule):
            raise NotImplementedError('Continuous rules can currently only be '
                                      'merged with other continuous rules')

        # Compare by truthiness so that flags which are truthy but not the
        # `True` singleton are still treated as "greater than".
        is_greater = bool(self.greater)

        # When they have different signs we need to return an interval rule
        if is_greater != bool(rule.greater):
            lt_rule, gt_rule = (rule, self) if is_greater else (self, rule)
            return IntervalRule(self.attr_name, gt_rule, lt_rule)

        # Both have the same sign: keep the tighter bound together with the
        # `inclusive` flag that belongs to it.
        if self.value == rule.value:
            # Equal thresholds: the exclusive bound is the stricter one.
            value = self.value
            inclusive = bool(self.inclusive and rule.inclusive)
        elif (self.value > rule.value) == is_greater:
            # GT keeps the larger value, LT keeps the smaller one.
            value, inclusive = self.value, self.inclusive
        else:
            value, inclusive = rule.value, rule.inclusive
        return ContinuousRule(self.attr_name, self.greater, value, inclusive)

    @property
    def description(self):
        """The rule without the attribute name, e.g. ``> 30.000``."""
        return '%s %.3f' % ('>' if self.greater else '≤', self.value)

    def __str__(self):
        return '%s %s %.3f' % (
            self.attr_name, '>' if self.greater else '≤', self.value)

    def __repr__(self):
        return "ContinuousRule(attr_name='%s', greater=%s, value=%s, " \
               "inclusive=%s)" % (self.attr_name, self.greater, self.value,
                                  self.inclusive)


class IntervalRule(Rule):
    """Interval rule class for ranges of continuous values.

    Parameters
    ----------
    attr_name : str
    left_rule : ContinuousRule
        The smaller (left) part of the interval.
    right_rule : ContinuousRule
        The larger (right) part of the interval.

    Raises
    ------
    AttributeError
        If either rule is not a `ContinuousRule` instance.

    Examples
    --------
    >>> print(IntervalRule('Rule',
    ...                    ContinuousRule('Rule', True, 1, inclusive=True),
    ...                    ContinuousRule('Rule', False, 3)))
    Rule ∈ [1.000, 3.000)

    Notes
    -----
    .. note:: Currently, only cases which appear in classification and
        regression trees are implemented. An interval can not be made up of two
        parts (e.g. (-∞, -1) ∪ (1, ∞)).

    """

    def __init__(self, attr_name, left_rule, right_rule):
        if not isinstance(left_rule, ContinuousRule):
            raise AttributeError(
                'The left rule must be an instance of the `ContinuousRule` '
                'class.')
        if not isinstance(right_rule, ContinuousRule):
            raise AttributeError(
                'The right rule must be an instance of the `ContinuousRule` '
                'class.')

        self.attr_name = attr_name
        self.left_rule = left_rule
        self.right_rule = right_rule

    def merge_with(self, rule):
        """Narrow the interval with a continuous or another interval rule.

        Parameters
        ----------
        rule : ContinuousRule or IntervalRule

        Returns
        -------
        IntervalRule

        Raises
        ------
        NotImplementedError
            If `rule` is neither a `ContinuousRule` nor an `IntervalRule`.

        """
        if isinstance(rule, ContinuousRule):
            # A "greater than" rule can only tighten the left bound, any other
            # continuous rule can only tighten the right bound.
            if rule.greater:
                return IntervalRule(
                    self.attr_name, self.left_rule.merge_with(rule),
                    self.right_rule)
            return IntervalRule(
                self.attr_name, self.left_rule,
                self.right_rule.merge_with(rule))

        if isinstance(rule, IntervalRule):
            return IntervalRule(
                self.attr_name,
                self.left_rule.merge_with(rule.left_rule),
                self.right_rule.merge_with(rule.right_rule))

        # Fail loudly instead of handing `None` back to the caller.
        raise NotImplementedError('Interval rules can currently only be '
                                  'merged with continuous or interval rules')

    @property
    def description(self):
        """The rule without the attribute name, e.g. ``∈ [1.000, 3.000)``."""
        return '∈ %s%.3f, %.3f%s' % (
            '[' if self.left_rule.inclusive else '(',
            self.left_rule.value,
            self.right_rule.value,
            ']' if self.right_rule.inclusive else ')'
        )

    def __str__(self):
        return '%s ∈ %s%.3f, %.3f%s' % (
            self.attr_name,
            '[' if self.left_rule.inclusive else '(',
            self.left_rule.value,
            self.right_rule.value,
            ']' if self.right_rule.inclusive else ')'
        )

    def __repr__(self):
        return "IntervalRule(attr_name='%s', left_rule=%s, right_rule=%s)" % (
            self.attr_name, repr(self.left_rule), repr(self.right_rule))