import itertools
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union

import numpy


if TYPE_CHECKING:
    import sklearn.tree


class _FanovaTree(object):
    """fANOVA helper around a single fitted regression tree.

    The tree partitions the box given by ``search_spaces`` into axis-aligned cells, one
    per leaf. For every node, this class precomputes the prediction averaged over the
    node's cell, weighted by cell cardinality (volume). It uses these statistics to compute:

    - the total variance of the tree's predictions (``variance``);
    - the variance of its predictions marginalized onto a subset of features
      (``get_marginal_variance``).

    Args:
        tree:
            A fitted ``sklearn.tree._tree.Tree``. Only ``n_features``, ``node_count``,
            ``feature``, ``threshold``, ``children_left``, ``children_right`` and ``value``
            are read. The code assumes every child has a larger node index than its parent
            (presumably true for trees built by scikit-learn; confirm for other sources).
        search_spaces:
            Array of shape ``(n_features, 2)`` where row ``i`` is ``[low, high]`` for
            feature ``i``.
    """

    def __init__(self, tree: "sklearn.tree._tree.Tree", search_spaces: numpy.ndarray) -> None:
        assert search_spaces.shape[0] == tree.n_features
        assert search_spaces.shape[1] == 2

        self._tree = tree
        self._search_spaces = search_spaces

        # The precomputations below only read `self._tree` and `self._search_spaces`.
        statistics = self._precompute_statistics()
        split_midpoints, split_sizes = self._precompute_split_midpoints_and_sizes()
        subtree_active_features = self._precompute_subtree_active_features()

        self._statistics = statistics
        self._split_midpoints = split_midpoints
        self._split_sizes = split_sizes
        self._subtree_active_features = subtree_active_features
        # Computed lazily by the `variance` property; requires `self._statistics`.
        self._variance: Optional[float] = None

    @property
    def variance(self) -> float:
        """Variance of the leaf predictions, weighted by each leaf's cell cardinality."""
        if self._variance is None:
            leaf_node_indices = numpy.array(
                [
                    node_index
                    for node_index in range(self._n_nodes)
                    if self._is_node_leaf(node_index)
                ]
            )
            statistics = self._statistics[leaf_node_indices]
            values = statistics[:, 0]
            weights = statistics[:, 1]
            average_values = numpy.average(values, weights=weights)
            self._variance = float(
                numpy.average((values - average_values) ** 2, weights=weights)
            )

        assert self._variance is not None
        return self._variance

    def get_marginal_variance(self, features: numpy.ndarray) -> float:
        """Return the variance of the predictions marginalized onto ``features``.

        All features not listed in ``features`` are marginalized out. The split
        thresholds of ``features`` define a grid of cells. For each cell, the tree is
        evaluated at the cell midpoint, and the result is weighted by the cell's size.

        Args:
            features: Non-empty integer array of feature indices.

        Returns:
            The non-negative weighted variance of the marginal predictions.
        """
        assert features.size > 0

        per_feature_midpoints = [self._split_midpoints[f] for f in features]
        per_feature_sizes = [self._split_sizes[f] for f in features]

        # NaN marks a feature as marginalized for `_get_marginalized_statistics`.
        sample = numpy.full(self._n_features, fill_value=numpy.nan, dtype=numpy.float64)

        # Accumulate in lists and convert once at the end. Calling `numpy.append` inside the
        # loop reallocated both arrays per cell, which is quadratic in the grid size.
        value_list: List[float] = []
        weight_list: List[float] = []

        for cell_midpoints, cell_sizes in zip(
            itertools.product(*per_feature_midpoints),
            itertools.product(*per_feature_sizes),
        ):
            sample[features] = numpy.array(cell_midpoints)

            value, weight = self._get_marginalized_statistics(sample)
            weight *= float(numpy.prod(cell_sizes))

            value_list.append(value)
            weight_list.append(weight)

        values = numpy.asarray(value_list, dtype=numpy.float64)
        weights = numpy.asarray(weight_list, dtype=numpy.float64)
        average_values = numpy.average(values, weights=weights)
        variance = numpy.average((values - average_values) ** 2, weights=weights)

        assert variance >= 0.0
        return variance

    def _get_marginalized_statistics(self, feature_vector: numpy.ndarray) -> Tuple[float, float]:
        """Return ``(value, weight)`` of the tree at a partially specified point.

        Features whose entry in ``feature_vector`` is NaN are marginalized out. The
        remaining (active) features are fixed to the given values.

        Returns:
            A tuple ``(value, weight)``:

            - ``value``: the weighted average of the precomputed values of the reached
              nodes.
            - ``weight``: the sum of the weights of those nodes. Each node's weight is
              divided by the cardinality of the subspace in which the node was reached.
        """
        assert feature_vector.size == self._n_features

        marginalized_features = numpy.isnan(feature_vector)
        active_features = ~marginalized_features

        # Reduce search space cardinalities to 1 for non-active features, so that only the
        # active features contribute to `_get_cardinality` below.
        search_spaces = self._search_spaces.copy()
        search_spaces[marginalized_features] = [0.0, 1.0]

        # Start from the root and traverse towards the leaves, using explicit stacks of
        # nodes and the subspaces they were reached in.
        active_nodes = [0]
        active_search_spaces = [search_spaces]

        node_indices = []
        active_features_cardinalities = []

        while active_nodes:
            node_index = active_nodes.pop()
            search_spaces = active_search_spaces.pop()

            feature = self._get_node_split_feature(node_index)
            if feature >= 0:  # Not leaf. Avoid unnecessary call to `_is_node_leaf`.
                # If node splits on an active feature, push the child node that we end up in.
                response = feature_vector[feature]
                if not numpy.isnan(response):
                    if response <= self._get_node_split_threshold(node_index):
                        next_node_index = self._get_node_left_child(node_index)
                        next_subspace = self._get_node_left_child_subspaces(
                            node_index, search_spaces
                        )
                    else:
                        next_node_index = self._get_node_right_child(node_index)
                        next_subspace = self._get_node_right_child_subspaces(
                            node_index, search_spaces
                        )

                    active_nodes.append(next_node_index)
                    active_search_spaces.append(next_subspace)
                    continue

                # If the subtree starting from this node splits on an active feature, push
                # both children. The subspace is not narrowed, since this node's split
                # feature is marginalized.
                if (active_features & self._subtree_active_features[node_index]).any():
                    for child_node_index in self._get_node_children(node_index):
                        active_nodes.append(child_node_index)
                        active_search_spaces.append(search_spaces)
                    continue

            # Node is a leaf, or its subtree does not split on any active feature. Its
            # precomputed statistics already summarize the whole subtree.
            node_indices.append(node_index)
            active_features_cardinalities.append(_get_cardinality(search_spaces))

        statistics = self._statistics[node_indices]
        values = statistics[:, 0]
        weights = statistics[:, 1] / numpy.asarray(active_features_cardinalities)

        value = numpy.average(values, weights=weights)
        weight = weights.sum()

        return value, weight

    def _precompute_statistics(self) -> numpy.ndarray:
        """Return an ``(n_nodes, 2)`` array of ``[weighted average value, total weight]``.

        A leaf's value is its prediction, and its weight is the cardinality of the cell
        it covers. An internal node combines its two children: its value is their
        weight-averaged value, and its weight is the sum of their weights.
        """
        n_nodes = self._n_nodes

        # Holds for each node, its weighted average value and the sum of weights.
        statistics = numpy.empty((n_nodes, 2), dtype=numpy.float64)

        # Cell covered by each node, filled top-down from the root.
        subspaces: List[Optional[numpy.ndarray]] = [None] * n_nodes
        subspaces[0] = self._search_spaces

        # Top-down pass: propagate cells to children and compute statistics for leaves.
        for node_index in range(n_nodes):
            subspace = subspaces[node_index]
            # Fails if a child precedes its parent in node order.
            assert subspace is not None

            if self._is_node_leaf(node_index):
                value = self._get_node_value(node_index)
                weight = _get_cardinality(subspace)
                statistics[node_index] = [value, weight]
            else:
                for child_node_index, child_subspace in zip(
                    self._get_node_children(node_index),
                    self._get_node_children_subspaces(node_index, subspace),
                ):
                    assert subspaces[child_node_index] is None
                    subspaces[child_node_index] = child_subspace

        # Bottom-up pass: children have larger indices, so iterating in reverse finishes
        # both children before their parent.
        for node_index in reversed(range(n_nodes)):
            if self._is_node_leaf(node_index):
                continue
            children = self._get_node_children(node_index)
            child_values = [statistics[child, 0] for child in children]
            child_weights = [statistics[child, 1] for child in children]
            value = numpy.average(child_values, weights=child_weights)
            weight = float(numpy.sum(child_weights))
            statistics[node_index] = [value, weight]

        return statistics

    def _precompute_split_midpoints_and_sizes(
        self,
    ) -> Tuple[List[numpy.ndarray], List[numpy.ndarray]]:
        """For every feature, return the midpoints and widths of the intervals between splits.

        Each feature's range ``[low, high]`` is cut at every threshold the tree uses for
        that feature.
        """
        midpoints: List[numpy.ndarray] = []
        sizes: List[numpy.ndarray] = []

        search_spaces = self._search_spaces
        for feature, feature_split_values in enumerate(self._compute_features_split_values()):
            # Interval boundaries: lower bound, sorted split thresholds, upper bound.
            boundaries = numpy.concatenate(
                (
                    numpy.atleast_1d(search_spaces[feature, 0]),
                    feature_split_values,
                    numpy.atleast_1d(search_spaces[feature, 1]),
                )
            )
            midpoints.append(0.5 * (boundaries[1:] + boundaries[:-1]))
            sizes.append(boundaries[1:] - boundaries[:-1])

        return midpoints, sizes

    def _compute_features_split_values(self) -> List[numpy.ndarray]:
        """Return, per feature, the sorted unique thresholds the tree splits that feature on."""
        all_split_values: List[Set[float]] = [set() for _ in range(self._n_features)]

        for node_index in range(self._n_nodes):
            feature = self._get_node_split_feature(node_index)
            if feature >= 0:  # Not leaf. Avoid unnecessary call to `_is_node_leaf`.
                threshold = self._get_node_split_threshold(node_index)
                all_split_values[feature].add(threshold)

        sorted_all_split_values: List[numpy.ndarray] = []

        for split_values in all_split_values:
            split_values_array = numpy.array(list(split_values), dtype=numpy.float64)
            split_values_array.sort()
            sorted_all_split_values.append(split_values_array)

        return sorted_all_split_values

    def _precompute_subtree_active_features(self) -> numpy.ndarray:
        """Return an ``(n_nodes, n_features)`` boolean mask of the features split on in each subtree."""
        subtree_active_features = numpy.full(
            (self._n_nodes, self._n_features), fill_value=False, dtype=bool
        )

        # Reverse order visits children before parents, so their masks are already final.
        for node_index in reversed(range(self._n_nodes)):
            feature = self._get_node_split_feature(node_index)
            if feature >= 0:  # Not leaf. Avoid unnecessary call to `_is_node_leaf`.
                subtree_active_features[node_index, feature] = True
                for child_node_index in self._get_node_children(node_index):
                    subtree_active_features[node_index] |= subtree_active_features[
                        child_node_index
                    ]

        return subtree_active_features

    @property
    def _n_features(self) -> int:
        return len(self._search_spaces)

    @property
    def _n_nodes(self) -> int:
        return self._tree.node_count

    def _is_node_leaf(self, node_index: int) -> bool:
        return self._tree.feature[node_index] < 0

    def _get_node_left_child(self, node_index: int) -> int:
        return self._tree.children_left[node_index]

    def _get_node_right_child(self, node_index: int) -> int:
        return self._tree.children_right[node_index]

    def _get_node_children(self, node_index: int) -> Tuple[int, int]:
        return self._get_node_left_child(node_index), self._get_node_right_child(node_index)

    def _get_node_value(self, node_index: int) -> float:
        """Return the scalar prediction stored at ``node_index``.

        scikit-learn documents ``Tree.value`` as shaped
        ``(node_count, n_outputs, max_n_classes)``, so a single node's entry is an array
        (``(1, 1)`` for a single-output regressor). ``.item()`` extracts the scalar, and
        raises ``ValueError`` if the node holds more than one value (multi-output trees
        are unsupported).
        """
        return float(numpy.asarray(self._tree.value[node_index]).item())

    def _get_node_split_threshold(self, node_index: int) -> float:
        return self._tree.threshold[node_index]

    def _get_node_split_feature(self, node_index: int) -> int:
        return self._tree.feature[node_index]

    def _get_node_left_child_subspaces(
        self, node_index: int, search_spaces: numpy.ndarray
    ) -> numpy.ndarray:
        # The left child holds `x <= threshold`, so the split sets the upper bound.
        return _get_subspaces(
            search_spaces,
            search_spaces_column=1,
            feature=self._get_node_split_feature(node_index),
            threshold=self._get_node_split_threshold(node_index),
        )

    def _get_node_right_child_subspaces(
        self, node_index: int, search_spaces: numpy.ndarray
    ) -> numpy.ndarray:
        # The right child holds `x > threshold`, so the split sets the lower bound.
        return _get_subspaces(
            search_spaces,
            search_spaces_column=0,
            feature=self._get_node_split_feature(node_index),
            threshold=self._get_node_split_threshold(node_index),
        )

    def _get_node_children_subspaces(
        self, node_index: int, search_spaces: numpy.ndarray
    ) -> Tuple[numpy.ndarray, numpy.ndarray]:
        return (
            self._get_node_left_child_subspaces(node_index, search_spaces),
            self._get_node_right_child_subspaces(node_index, search_spaces),
        )


def _get_cardinality(search_spaces: numpy.ndarray) -> float:
    """Return the volume of the box ``search_spaces``: the product of ``high - low``."""
    return numpy.prod(search_spaces[:, 1] - search_spaces[:, 0])


def _get_subspaces(
    search_spaces: numpy.ndarray, *, search_spaces_column: int, feature: int, threshold: float
) -> numpy.ndarray:
    """Return a copy of ``search_spaces`` with one bound of ``feature`` set to ``threshold``.

    ``search_spaces_column`` selects which bound to replace: 0 for the lower bound, 1 for
    the upper bound.
    """
    search_spaces_subspace = numpy.copy(search_spaces)
    search_spaces_subspace[feature, search_spaces_column] = threshold
    return search_spaces_subspace