import itertools
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union

import numpy


if TYPE_CHECKING:
    import sklearn.tree


class _FanovaTree(object):
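    """Tree wrapper for the fANOVA variance decomposition.

    Wraps a fitted `sklearn.tree._tree.Tree` together with per-feature search spaces
    (an array of shape `(n_features, 2)` holding `[low, high]` bounds) and precomputes
    the statistics needed to evaluate the total and the marginal variances of the
    tree's piecewise-constant prediction.

    A minimal usage sketch (the regressor, `X`, `y` and `search_spaces` below are
    assumptions for illustration, not part of this module):

        regressor = sklearn.ensemble.RandomForestRegressor().fit(X, y)
        fanova_tree = _FanovaTree(regressor.estimators_[0].tree_, search_spaces)
        total_variance = fanova_tree.variance
        marginal_variance = fanova_tree.get_marginal_variance(numpy.array([0]))
    """
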
    def __init__(self, tree: "sklearn.tree._tree.Tree", search_spaces: numpy.ndarray) -> None:
        assert search_spaces.shape[0] == tree.n_features
        assert search_spaces.shape[1] == 2

        self._tree = tree
        self._search_spaces = search_spaces

        statistics = self._precompute_statistics()
        split_midpoints, split_sizes = self._precompute_split_midpoints_and_sizes()
        subtree_active_features = self._precompute_subtree_active_features()

        self._statistics = statistics
        self._split_midpoints = split_midpoints
        self._split_sizes = split_sizes
        self._subtree_active_features = subtree_active_features
        self._variance: Optional[float] = None  # Computed lazily and requires `self._statistics`.

    @property
    def variance(self) -> float:
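        """Total variance of the tree's prediction over the search space.

        Computed lazily as the weighted variance of the leaf values, where each leaf
        is weighted by the cardinality (volume) of its hyper-rectangle.
        """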
        if self._variance is None:
            leaf_node_indices = numpy.array(
                [
                    node_index
                    for node_index in range(self._n_nodes)
                    if self._is_node_leaf(node_index)
                ]
            )
            statistics = self._statistics[leaf_node_indices]
            values = statistics[:, 0]
            weights = statistics[:, 1]
            average_values = numpy.average(values, weights=weights)
            variance = numpy.average((values - average_values) ** 2, weights=weights)

            self._variance = variance

        assert self._variance is not None
        return self._variance

    def get_marginal_variance(self, features: numpy.ndarray) -> float:
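        """Variance of the prediction marginalized onto the given features.

        `features` is an array of feature indices. The marginal prediction is
        evaluated on the Cartesian product of the split midpoints of those features,
        and its weighted variance over that grid is returned.
        """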
        assert features.size > 0

        # For each midpoint along the given dimensions, traverse this tree to compute the
        # marginal predictions.
        midpoints = [self._split_midpoints[f] for f in features]
        sizes = [self._split_sizes[f] for f in features]

        product_midpoints = itertools.product(*midpoints)
        product_sizes = itertools.product(*sizes)

        sample = numpy.full(self._n_features, fill_value=numpy.nan, dtype=numpy.float64)

        values: Union[List[float], numpy.ndarray] = []
        weights: Union[List[float], numpy.ndarray] = []

        for midpoints, sizes in zip(product_midpoints, product_sizes):
            sample[features] = numpy.array(midpoints)

            value, weight = self._get_marginalized_statistics(sample)
            weight *= float(numpy.prod(sizes))

            values = numpy.append(values, value)
            weights = numpy.append(weights, weight)

        weights = numpy.asarray(weights)
        values = numpy.asarray(values)
        average_values = numpy.average(values, weights=weights)
        variance = numpy.average((values - average_values) ** 2, weights=weights)

        assert variance >= 0.0
        return variance

    def _get_marginalized_statistics(self, feature_vector: numpy.ndarray) -> Tuple[float, float]:
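        """Return the marginalized prediction and weight for `feature_vector`.

        NaN entries mark features to marginalize over; the remaining (active) entries
        are fixed. The tree is traversed from the root: a split on an active feature
        follows the single matching child, a split on a marginalized feature whose
        subtree still splits on active features descends into both children, and the
        traversal stops at any other node, whose precomputed statistics are then
        aggregated.
        """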
        assert feature_vector.size == self._n_features

        marginalized_features = numpy.isnan(feature_vector)
        active_features = ~marginalized_features

        # Reduce search space cardinalities to 1 for non-active features.
        search_spaces = self._search_spaces.copy()
        search_spaces[marginalized_features] = [0.0, 1.0]

        # Start from the root and traverse towards the leaves.
        active_nodes = [0]
        active_search_spaces = [search_spaces]

        node_indices = []
        active_features_cardinalities = []

        while len(active_nodes) > 0:
            node_index = active_nodes.pop()
            search_spaces = active_search_spaces.pop()

            feature = self._get_node_split_feature(node_index)
            if feature >= 0:  # Not leaf. Avoid unnecessary call to `_is_node_leaf`.
                # If node splits on an active feature, push the child node that we end up in.
                response = feature_vector[feature]
                if not numpy.isnan(response):
                    if response <= self._get_node_split_threshold(node_index):
                        next_node_index = self._get_node_left_child(node_index)
                        next_subspace = self._get_node_left_child_subspaces(
                            node_index, search_spaces
                        )
                    else:
                        next_node_index = self._get_node_right_child(node_index)
                        next_subspace = self._get_node_right_child_subspaces(
                            node_index, search_spaces
                        )

                    active_nodes.append(next_node_index)
                    active_search_spaces.append(next_subspace)
                    continue

                # If subtree starting from node splits on an active feature, push both child nodes.
                if (active_features & self._subtree_active_features[node_index]).any():
                    for child_node_index in self._get_node_children(node_index):
                        active_nodes.append(child_node_index)
                        active_search_spaces.append(search_spaces)
                    continue

            # If node is a leaf or the subtree does not split on any of the active features.
            node_indices.append(node_index)
            active_features_cardinalities.append(_get_cardinality(search_spaces))

        statistics = self._statistics[node_indices]
        values = statistics[:, 0]
        weights = statistics[:, 1]
        weights = weights / active_features_cardinalities

        value = numpy.average(values, weights=weights)
        weight = weights.sum()

        return value, weight

    def _precompute_statistics(self) -> numpy.ndarray:
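        """Compute per-node `[weighted average value, sum of weights]` statistics.

        A top-down pass assigns every node its hyper-rectangular subspace and fills in
        the leaves, then a bottom-up pass aggregates the children into the internal
        nodes (children have larger indices than their parent in scikit-learn trees,
        so iterating in reversed index order suffices).
        """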
        n_nodes = self._n_nodes

        # Holds, for each node, its weighted average value and the sum of weights.
        statistics = numpy.empty((n_nodes, 2), dtype=numpy.float64)

        subspaces = numpy.array([None for _ in range(n_nodes)])
        subspaces[0] = self._search_spaces

        # Compute marginals for leaf nodes.
        for node_index in range(n_nodes):
            subspace = subspaces[node_index]

            if self._is_node_leaf(node_index):
                value = self._get_node_value(node_index)
                weight = _get_cardinality(subspace)
                statistics[node_index] = [value, weight]
            else:
                for child_node_index, child_subspace in zip(
                    self._get_node_children(node_index),
                    self._get_node_children_subspaces(node_index, subspace),
                ):
                    assert subspaces[child_node_index] is None
                    subspaces[child_node_index] = child_subspace

        # Compute marginals for internal nodes.
        for node_index in reversed(range(n_nodes)):
            if not self._is_node_leaf(node_index):
                child_values = []
                child_weights = []
                for child_node_index in self._get_node_children(node_index):
                    child_values.append(statistics[child_node_index, 0])
                    child_weights.append(statistics[child_node_index, 1])
                value = numpy.average(child_values, weights=child_weights)
                weight = float(numpy.sum(child_weights))
                statistics[node_index] = [value, weight]

        return statistics

    def _precompute_split_midpoints_and_sizes(
        self,
    ) -> Tuple[List[numpy.ndarray], List[numpy.ndarray]]:
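        """Return, per feature, the midpoints and widths of the intervals between splits.

        The interval boundaries are the feature's sorted split thresholds plus the
        search-space bounds, so the midpoints form the evaluation grid used by
        `get_marginal_variance`.
        """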
        midpoints = []
        sizes = []

        search_spaces = self._search_spaces
        for feature, feature_split_values in enumerate(self._compute_features_split_values()):
            feature_split_values = numpy.concatenate(
                (
                    numpy.atleast_1d(search_spaces[feature, 0]),
                    feature_split_values,
                    numpy.atleast_1d(search_spaces[feature, 1]),
                )
            )
            midpoint = 0.5 * (feature_split_values[1:] + feature_split_values[:-1])
            size = feature_split_values[1:] - feature_split_values[:-1]

            midpoints.append(midpoint)
            sizes.append(size)

        return midpoints, sizes

    def _compute_features_split_values(self) -> List[numpy.ndarray]:
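        """Collect, for each feature, the sorted unique thresholds it is split on."""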
        all_split_values: List[Set[float]] = [set() for _ in range(self._n_features)]

        for node_index in range(self._n_nodes):
            feature = self._get_node_split_feature(node_index)
            if feature >= 0:  # Not leaf. Avoid unnecessary call to `_is_node_leaf`.
                threshold = self._get_node_split_threshold(node_index)
                all_split_values[feature].add(threshold)

        sorted_all_split_values: List[numpy.ndarray] = []

        for split_values in all_split_values:
            split_values_array = numpy.array(list(split_values), dtype=numpy.float64)
            split_values_array.sort()
            sorted_all_split_values.append(split_values_array)

        return sorted_all_split_values

    def _precompute_subtree_active_features(self) -> numpy.ndarray:
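        """Compute a boolean mask of shape `(n_nodes, n_features)`.

        Entry `(i, j)` is True when the subtree rooted at node `i` splits on feature
        `j`; it is filled bottom-up by OR-ing each node's own split feature with its
        children's masks.
        """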
        subtree_active_features = numpy.full((self._n_nodes, self._n_features), fill_value=False)

        for node_index in reversed(range(self._n_nodes)):
            feature = self._get_node_split_feature(node_index)
            if feature >= 0:  # Not leaf. Avoid unnecessary call to `_is_node_leaf`.
                subtree_active_features[node_index, feature] = True
                for child_node_index in self._get_node_children(node_index):
                    subtree_active_features[node_index] |= subtree_active_features[
                        child_node_index
                    ]

        return subtree_active_features

    @property
    def _n_features(self) -> int:
        return len(self._search_spaces)

    @property
    def _n_nodes(self) -> int:
        return self._tree.node_count

    def _is_node_leaf(self, node_index: int) -> bool:
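        # scikit-learn marks leaf nodes with a negative split feature
        # (`sklearn.tree._tree.TREE_UNDEFINED`, i.e. -2).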
        return self._tree.feature[node_index] < 0

    def _get_node_left_child(self, node_index: int) -> int:
        return self._tree.children_left[node_index]

    def _get_node_right_child(self, node_index: int) -> int:
        return self._tree.children_right[node_index]

    def _get_node_children(self, node_index: int) -> Tuple[int, int]:
        return self._get_node_left_child(node_index), self._get_node_right_child(node_index)

    def _get_node_value(self, node_index: int) -> float:
        # `tree.value[node_index]` is a size-1 array for single-output regression trees;
        # collapse it to a plain float to match the annotation.
        return float(self._tree.value[node_index])

    def _get_node_split_threshold(self, node_index: int) -> float:
        return self._tree.threshold[node_index]

    def _get_node_split_feature(self, node_index: int) -> int:
        return self._tree.feature[node_index]

    def _get_node_left_child_subspaces(
        self, node_index: int, search_spaces: numpy.ndarray
    ) -> numpy.ndarray:
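        # The left child covers `feature <= threshold`, so the upper bound (column 1)
        # is clipped down to the threshold.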
        return _get_subspaces(
            search_spaces,
            search_spaces_column=1,
            feature=self._get_node_split_feature(node_index),
            threshold=self._get_node_split_threshold(node_index),
        )

    def _get_node_right_child_subspaces(
        self, node_index: int, search_spaces: numpy.ndarray
    ) -> numpy.ndarray:
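        # The right child covers `feature > threshold`, so the lower bound (column 0)
        # is raised to the threshold.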
        return _get_subspaces(
            search_spaces,
            search_spaces_column=0,
            feature=self._get_node_split_feature(node_index),
            threshold=self._get_node_split_threshold(node_index),
        )

    def _get_node_children_subspaces(
        self, node_index: int, search_spaces: numpy.ndarray
    ) -> Tuple[numpy.ndarray, numpy.ndarray]:
        return (
            self._get_node_left_child_subspaces(node_index, search_spaces),
            self._get_node_right_child_subspaces(node_index, search_spaces),
        )


def _get_cardinality(search_spaces: numpy.ndarray) -> float:
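    # The cardinality of a box is its volume: the product of the per-feature interval widths.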
    return float(numpy.prod(search_spaces[:, 1] - search_spaces[:, 0]))


def _get_subspaces(
    search_spaces: numpy.ndarray, *, search_spaces_column: int, feature: int, threshold: float
) -> numpy.ndarray:
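    # Return a copy of `search_spaces` with one bound of the split feature replaced by the
    # threshold: column 1 (upper) for a left child, column 0 (lower) for a right child.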
    search_spaces_subspace = numpy.copy(search_spaces)
    search_spaces_subspace[feature, search_spaces_column] = threshold
    return search_spaces_subspace