"""Base class for ensemble-based estimators."""

# Authors: Gilles Louppe
# License: BSD 3 clause

from abc import ABCMeta, abstractmethod
import numbers
from typing import List

import numpy as np

from joblib import effective_n_jobs

from ..base import clone
from ..base import is_classifier, is_regressor
from ..base import BaseEstimator
from ..base import MetaEstimatorMixin
from ..tree import DecisionTreeRegressor, ExtraTreeRegressor
from ..utils import Bunch, _print_elapsed_time
from ..utils import check_random_state
from ..utils.metaestimators import _BaseComposition


def _fit_single_estimator(
    estimator, X, y, sample_weight=None, message_clsname=None, message=None
):
    """Fit ``estimator`` on ``(X, y)`` within a job and return it.

    Parameters
    ----------
    estimator : estimator
        Object exposing ``fit``. It is fitted in place.

    X, y
        Training data forwarded unchanged to ``estimator.fit``.

    sample_weight : object, default=None
        When not None, forwarded to ``fit`` as the ``sample_weight`` keyword.

    message_clsname, message : default=None
        Forwarded unchanged to ``_print_elapsed_time``, which wraps the call
        to ``fit``.

    Returns
    -------
    estimator
        The same object that was passed in, after ``fit`` has been called.

    Raises
    ------
    TypeError
        With an explicit message if ``fit`` rejects the ``sample_weight``
        keyword. Any other ``TypeError`` is propagated untouched.
    """
    # Unweighted path: nothing to translate, just fit and hand back.
    if sample_weight is None:
        with _print_elapsed_time(message_clsname, message):
            estimator.fit(X, y)
        return estimator

    try:
        with _print_elapsed_time(message_clsname, message):
            estimator.fit(X, y, sample_weight=sample_weight)
    except TypeError as exc:
        # Only rewrite the error caused by the unsupported keyword; every
        # other TypeError comes from elsewhere and must surface as is.
        if "unexpected keyword argument 'sample_weight'" not in str(exc):
            raise
        raise TypeError(
            "Underlying estimator {} does not support sample weights.".format(
                estimator.__class__.__name__
            )
        ) from exc
    return estimator


def _set_random_states(estimator, random_state=None):
    """Set fixed random_state parameters for an estimator.

    Every parameter named ``random_state``, or whose name ends with
    ``__random_state``, is set to an integer drawn from ``random_state``.

    Parameters
    ----------
    estimator : estimator supporting get/set_params
        Estimator with potential randomness managed by random_state
        parameters.

    random_state : int, RandomState instance or None, default=None
        Pseudo-random number generator to control the generation of the random
        integers. Pass an int for reproducible output across multiple function
        calls.
        See :term:`Glossary <random_state>`.

    Notes
    -----
    This does not necessarily set *all* ``random_state`` attributes that
    control an estimator's randomness, only those accessible through
    ``estimator.get_params()``. ``random_state``s not controlled include
    those belonging to:

    * cross-validation splitters
    * ``scipy.stats`` rvs
    """
    rng = check_random_state(random_state)
    upper_bound = np.iinfo(np.int32).max

    # Parameter names are visited in sorted order so that the sequence of
    # draws, and therefore each assigned seed, is deterministic.
    seeds = {
        name: rng.randint(upper_bound)
        for name in sorted(estimator.get_params(deep=True))
        if name == "random_state" or name.endswith("__random_state")
    }

    if seeds:
        estimator.set_params(**seeds)


class BaseEnsemble(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta):
    """Base class for all ensemble classes.

    Warning: This class should not be used directly. Use derived classes
    instead.

    Parameters
    ----------
    base_estimator : object
        The base estimator from which the ensemble is built.

    n_estimators : int, default=10
        The number of estimators in the ensemble.

    estimator_params : list of str, default=tuple()
        The list of attributes to use as parameters when instantiating a
        new base estimator. If none are given, default parameters are used.

    Attributes
    ----------
    base_estimator_ : estimator
        The base estimator from which the ensemble is grown.

    estimators_ : list of estimators
        The collection of fitted base estimators.
    """

    # Replaces the _required_parameters inherited from MetaEstimatorMixin:
    # no constructor argument is mandatory here.
    _required_parameters: List[str] = []

    @abstractmethod
    def __init__(self, base_estimator, *, n_estimators=10, estimator_params=tuple()):
        # Only record the parameters. Sub-estimators must NOT be created
        # here because the parameters of base_estimator may still change
        # (e.g. when grid-searching with the nested object syntax).
        # Derived classes are responsible for filling self.estimators_
        # in fit.
        self.base_estimator = base_estimator
        self.n_estimators = n_estimators
        self.estimator_params = estimator_params

    def _validate_estimator(self, default=None):
        """Check ``n_estimators`` and resolve the base estimator.

        Sets the ``base_estimator_`` attribute to ``self.base_estimator``,
        or to ``default`` when ``self.base_estimator`` is None.

        Parameters
        ----------
        default : object, default=None
            Fallback used when ``self.base_estimator`` is None.

        Raises
        ------
        ValueError
            If ``n_estimators`` is not an integer, is not strictly positive,
            or if no base estimator could be resolved.
        """
        n_estimators = self.n_estimators

        if not isinstance(n_estimators, numbers.Integral):
            raise ValueError(
                "n_estimators must be an integer, got {0}.".format(type(n_estimators))
            )

        if n_estimators <= 0:
            raise ValueError(
                "n_estimators must be greater than zero, got {0}.".format(n_estimators)
            )

        resolved = self.base_estimator
        if resolved is None:
            resolved = default

        # The attribute is assigned before the final check, so it is set
        # (to None) even when the error below is raised.
        self.base_estimator_ = resolved
        if resolved is None:
            raise ValueError("base_estimator cannot be None")

    def _make_estimator(self, append=True, random_state=None):
        """Make and configure a copy of the `base_estimator_` attribute.

        Warning: This method should be used to properly instantiate new
        sub-estimators.

        Parameters
        ----------
        append : bool, default=True
            Whether to append the new estimator to ``self.estimators_``.

        random_state : int, RandomState instance or None, default=None
            When not None, used to seed the ``random_state`` parameters of
            the new estimator.

        Returns
        -------
        estimator
            The freshly cloned and configured sub-estimator.
        """
        estimator = clone(self.base_estimator_)

        # Copy the ensemble-level attributes listed in estimator_params onto
        # the clone.
        overrides = {name: getattr(self, name) for name in self.estimator_params}
        estimator.set_params(**overrides)

        # TODO: Remove in v1.2
        # criterion "mse" and "mae" would cause warnings in every call to
        # DecisionTreeRegressor.fit(..)
        if isinstance(estimator, (DecisionTreeRegressor, ExtraTreeRegressor)):
            criterion = getattr(estimator, "criterion", None)
            if criterion == "mse":
                estimator.set_params(criterion="squared_error")
            elif criterion == "mae":
                estimator.set_params(criterion="absolute_error")

        if random_state is not None:
            _set_random_states(estimator, random_state)

        if append:
            self.estimators_.append(estimator)

        return estimator

    def __len__(self):
        """Return the number of estimators in the ensemble."""
        return len(self.estimators_)

    def __getitem__(self, index):
        """Return the index'th estimator in the ensemble."""
        return self.estimators_[index]

    def __iter__(self):
        """Return iterator over estimators in the ensemble."""
        return iter(self.estimators_)


def _partition_estimators(n_estimators, n_jobs):
    """Private function used to partition estimators between jobs.

    Parameters
    ----------
    n_estimators : int
        Total number of estimators to distribute.

    n_jobs : int or None
        Requested number of jobs, resolved through ``effective_n_jobs``.

    Returns
    -------
    n_jobs : int
        Number of jobs actually used, never more than ``n_estimators``.

    n_estimators_per_job : list of int
        Number of estimators assigned to each job.

    starts : list of int
        Cumulative boundaries, beginning with 0, so that job ``i`` covers
        ``starts[i]`` up to ``starts[i + 1]``.
    """
    # Never use more jobs than there are estimators to hand out.
    n_jobs = min(effective_n_jobs(n_jobs), n_estimators)

    # Each job gets the same share; the remainder is spread one by one
    # over the first jobs.
    even_share = n_estimators // n_jobs
    remainder = n_estimators % n_jobs
    per_job = np.full(n_jobs, even_share, dtype=int)
    per_job[:remainder] += 1

    boundaries = np.cumsum(per_job)

    return n_jobs, per_job.tolist(), [0] + boundaries.tolist()


class _BaseHeterogeneousEnsemble(
    MetaEstimatorMixin, _BaseComposition, metaclass=ABCMeta
):
    """Base class for heterogeneous ensemble of learners.

    Parameters
    ----------
    estimators : list of (str, estimator) tuples
        The ensemble of estimators to use in the ensemble. Each element of the
        list is defined as a tuple of string (i.e. name of the estimator) and
        an estimator instance. An estimator can be set to `'drop'` using
        `set_params`.

    Attributes
    ----------
    estimators_ : list of estimators
        The elements of the estimators parameter, having been fitted on the
        training data. If an estimator has been set to `'drop'`, it will not
        appear in `estimators_`.
    """

    _required_parameters = ["estimators"]

    @property
    def named_estimators(self):
        """Dictionary to access the sub-estimators of `estimators` by name.

        Returns
        -------
        :class:`~sklearn.utils.Bunch`
        """
        by_name = dict(self.estimators)
        return Bunch(**by_name)

    @abstractmethod
    def __init__(self, estimators):
        self.estimators = estimators

    def _validate_estimators(self):
        """Check the ``estimators`` parameter and split it into two tuples.

        Returns
        -------
        names : tuple of str
            The first element of every ``(name, estimator)`` pair.

        estimators : tuple
            The second element of every pair (estimator or ``'drop'``).

        Raises
        ------
        ValueError
            If ``estimators`` is None or empty, if every entry is ``'drop'``,
            or if an entry is not of the same kind (classifier or regressor)
            as the ensemble.
        """
        if self.estimators is None or len(self.estimators) == 0:
            raise ValueError(
                "Invalid 'estimators' attribute, 'estimators' should be a list"
                " of (string, estimator) tuples."
            )

        names, estimators = zip(*self.estimators)
        # Name validation is provided by an inherited helper.
        self._validate_names(names)

        # At least one entry has to be something other than 'drop'.
        for est in estimators:
            if est != "drop":
                break
        else:
            raise ValueError(
                "All estimators are dropped. At least one is required "
                "to be an estimator."
            )

        # Sub-estimators must be of the same kind as the ensemble itself.
        if is_classifier(self):
            is_estimator_type = is_classifier
        else:
            is_estimator_type = is_regressor

        for est in estimators:
            if est != "drop" and not is_estimator_type(est):
                # Stripping the leading "is_" from the checker's name yields
                # the expected kind for the message.
                raise ValueError(
                    "The estimator {} should be a {}.".format(
                        est.__class__.__name__, is_estimator_type.__name__[3:]
                    )
                )

        return names, estimators

    def set_params(self, **params):
        """
        Set the parameters of an estimator from the ensemble.

        Valid parameter keys can be listed with `get_params()`. Note that you
        can directly set the parameters of the estimators contained in
        `estimators`.

        Parameters
        ----------
        **params : keyword arguments
            Specific parameters using e.g.
            `set_params(parameter_name=new_value)`. In addition to setting the
            parameters of the estimator, the individual estimators contained
            in `estimators` can also be set, or can be removed by setting
            them to 'drop'.

        Returns
        -------
        self : object
            Estimator instance.
        """
        super()._set_params("estimators", **params)
        return self

    def get_params(self, deep=True):
        """
        Get the parameters of an estimator from the ensemble.

        Returns the parameters given in the constructor as well as the
        estimators contained within the `estimators` parameter.

        Parameters
        ----------
        deep : bool, default=True
            Setting it to True gets the various estimators and the parameters
            of the estimators as well.

        Returns
        -------
        params : dict
            Parameter and estimator names mapped to their values or parameter
            names mapped to their values.
        """
        return super()._get_params("estimators", deep=deep)