1r"""Rules for classification and regression trees.
2
Tree visualisations usually need to show the rules of nodes; these classes make
merging those rules simple (otherwise you get redundant rules, e.g. `age < 3`
and `age < 2`, which can be merged into `age < 2`).
6
Subclasses of the `Rule` class should provide a nice interface to merge rules
together through the `merge_with` method. Of course, this should not be forced
where it doesn't make sense, e.g. when merging a discrete rule (e.g.
:math:`x \in \{red, blue, green\}`) with a continuous rule (e.g.
:math:`x \leq 5`).
12
13"""
14import warnings
15
16
class Rule:
    """Common base class for the rules attached to tree nodes.

    Subclasses are expected to override :meth:`merge_with`; the default
    :attr:`description` is simply the string form of the rule.
    """

    def merge_with(self, rule):
        """Combine this rule with ``rule`` into a single rule.

        Parameters
        ----------
        rule : Rule

        Returns
        -------
        Rule

        Raises
        ------
        NotImplementedError
            Always, on the base class; subclasses must provide the merge.

        """
        raise NotImplementedError

    @property
    def description(self):
        """str: Human-readable form of the rule; defaults to ``str(self)``."""
        text = str(self)
        return text
37
38
class DiscreteRule(Rule):
    """Discrete rule class for handling indicator rules.

    Parameters
    ----------
    attr_name : str
    equals : bool
        Whether the attribute must equal ``value`` (``True``) or must differ
        from it (``False``).
    value : object

    Examples
    --------
    >>> print(DiscreteRule('age', True, 30))
    age = 30

    >>> print(DiscreteRule('name', False, 'John'))
    name ≠ John

    Notes
    -----
    .. note:: Merging discrete rules is currently not implemented, the new rule
        is simply returned and a warning is issued.

    """

    def __init__(self, attr_name, equals, value):
        self.attr_name = attr_name
        self.equals = equals
        self.value = value

    def merge_with(self, rule):
        """Return ``rule`` unchanged and emit a warning.

        A discrete rule can only express equality or inequality, so there is
        no meaningful combined rule; the given rule replaces this one.

        Parameters
        ----------
        rule : Rule

        Returns
        -------
        Rule
            ``rule`` itself.

        """
        # stacklevel=2 attributes the warning to the code that requested the
        # merge rather than to this method.
        warnings.warn('Merged two discrete rules `%s` and `%s`' % (self, rule),
                      stacklevel=2)
        return rule

    @property
    def description(self):
        """str: The rule without the attribute name, e.g. ``'= 30'``."""
        return '{} {}'.format('=' if self.equals else '≠', self.value)

    def __str__(self):
        return '{} {}'.format(self.attr_name, self.description)

    def __repr__(self):
        # %r quotes string values (e.g. value='John') so the output reads as a
        # valid constructor call.
        return 'DiscreteRule(attr_name=%r, equals=%r, value=%r)' % (
            self.attr_name, self.equals, self.value)
86
87
class ContinuousRule(Rule):
    """Continuous rule class for handling numeric rules.

    Parameters
    ----------
    attr_name : str
    greater : bool
        Should indicate whether the variable must be greater than the value.
    value : int or float
    inclusive : bool, optional
        Should the variable range include the value or not
        (LT <> LTE | GT <> GTE). Default is False.

    Examples
    --------
    >>> print(ContinuousRule('age', False, 30, inclusive=True))
    age ≤ 30.000

    >>> print(ContinuousRule('age', True, 30))
    age > 30.000

    Notes
    -----
    .. note:: Continuous rules can currently only be merged with other
        continuous rules.

    .. note:: The textual form always uses ``>`` for lower bounds and ``≤``
        for upper bounds, regardless of ``inclusive``. The flag is still
        preserved when merging and is shown by :class:`IntervalRule`.

    """

    def __init__(self, attr_name, greater, value, inclusive=False):
        self.attr_name = attr_name
        self.greater = greater
        self.value = value
        self.inclusive = inclusive

    def merge_with(self, rule):
        """Intersect this rule with another continuous rule.

        Two bounds in the same direction collapse to the tighter one, keeping
        its inclusiveness. A lower and an upper bound combine into an
        :class:`IntervalRule`.

        Parameters
        ----------
        rule : ContinuousRule

        Returns
        -------
        ContinuousRule or IntervalRule

        Raises
        ------
        NotImplementedError
            If ``rule`` is not a :class:`ContinuousRule`.

        """
        if not isinstance(rule, ContinuousRule):
            raise NotImplementedError('Continuous rules can currently only be '
                                      'merged with other continuous rules')
        # Opposite directions: one lower bound (GT) and one upper bound (LT)
        if self.greater != rule.greater:
            lt_rule, gt_rule = (rule, self) if self.greater else (self, rule)
            return IntervalRule(self.attr_name, gt_rule, lt_rule)

        if self.value == rule.value:
            # Same threshold: e.g. ``x > 5`` and ``x >= 5`` intersect to
            # ``x > 5``, so the bound is inclusive only if both include it.
            value = self.value
            inclusive = self.inclusive and rule.inclusive
        else:
            # GT keeps the larger threshold, LT the smaller one; the tighter
            # rule also decides whether the bound is inclusive.
            if self.greater:
                tighter = self if self.value > rule.value else rule
            else:
                tighter = self if self.value < rule.value else rule
            value, inclusive = tighter.value, tighter.inclusive
        return ContinuousRule(self.attr_name, self.greater, value, inclusive)

    @property
    def description(self):
        """str: The rule without the attribute name, e.g. ``'> 30.000'``."""
        return '%s %.3f' % ('>' if self.greater else '≤', self.value)

    def __str__(self):
        return '%s %s' % (self.attr_name, self.description)

    def __repr__(self):
        return "ContinuousRule(attr_name='%s', greater=%s, value=%s, " \
               "inclusive=%s)" % (self.attr_name, self.greater, self.value,
                                  self.inclusive)
153
154
class IntervalRule(Rule):
    """Interval rule class for ranges of continuous values.

    Parameters
    ----------
    attr_name : str
    left_rule : ContinuousRule
        The smaller (left) part of the interval.
    right_rule : ContinuousRule
        The larger (right) part of the interval.

    Raises
    ------
    AttributeError
        If ``left_rule`` or ``right_rule`` is not a :class:`ContinuousRule`.

    Examples
    --------
    >>> print(IntervalRule('Rule',
    ...                    ContinuousRule('Rule', True, 1, inclusive=True),
    ...                    ContinuousRule('Rule', False, 3)))
    Rule ∈ [1.000, 3.000)

    Notes
    -----
    .. note:: Currently, only cases which appear in classification and
        regression trees are implemented. An interval can not be made up of two
        parts (e.g. (-∞, -1) ∪ (1, ∞)).

    """

    def __init__(self, attr_name, left_rule, right_rule):
        # NOTE(review): TypeError would be the conventional exception for a
        # wrong argument type; AttributeError is kept so existing callers
        # that catch it keep working.
        if not isinstance(left_rule, ContinuousRule):
            raise AttributeError(
                'The left rule must be an instance of the `ContinuousRule` '
                'class.')
        if not isinstance(right_rule, ContinuousRule):
            raise AttributeError(
                'The right rule must be an instance of the `ContinuousRule` '
                'class.')

        self.attr_name = attr_name
        self.left_rule = left_rule
        self.right_rule = right_rule

    def merge_with(self, rule):
        """Narrow this interval with a continuous or interval rule.

        A lower-bound (``greater``) continuous rule is merged into the left
        end, an upper-bound one into the right end. An interval rule is merged
        end by end.

        Parameters
        ----------
        rule : ContinuousRule or IntervalRule

        Returns
        -------
        IntervalRule

        Raises
        ------
        NotImplementedError
            If ``rule`` is neither a :class:`ContinuousRule` nor an
            :class:`IntervalRule`.

        """
        if isinstance(rule, ContinuousRule):
            if rule.greater:
                return IntervalRule(
                    self.attr_name, self.left_rule.merge_with(rule),
                    self.right_rule)
            return IntervalRule(
                self.attr_name, self.left_rule,
                self.right_rule.merge_with(rule))

        if isinstance(rule, IntervalRule):
            return IntervalRule(
                self.attr_name,
                self.left_rule.merge_with(rule.left_rule),
                self.right_rule.merge_with(rule.right_rule))

        # Fail loudly instead of implicitly returning None.
        raise NotImplementedError('Interval rules can currently only be '
                                  'merged with continuous or interval rules')

    @property
    def description(self):
        """str: The interval without the attribute name.

        Square brackets mark inclusive ends and parentheses exclusive ones,
        e.g. ``'∈ [1.000, 3.000)'``.
        """
        return '∈ %s%.3f, %.3f%s' % (
            '[' if self.left_rule.inclusive else '(',
            self.left_rule.value,
            self.right_rule.value,
            ']' if self.right_rule.inclusive else ')'
        )

    def __str__(self):
        return '%s %s' % (self.attr_name, self.description)

    def __repr__(self):
        return "IntervalRule(attr_name='%s', left_rule=%s, right_rule=%s)" % (
            self.attr_name, repr(self.left_rule), repr(self.right_rule))
233