1"""Base class for ensemble-based estimators."""
2
3# Authors: Gilles Louppe
4# License: BSD 3 clause
5
6from abc import ABCMeta, abstractmethod
7import numbers
8from typing import List
9
10import numpy as np
11
12from joblib import effective_n_jobs
13
14from ..base import clone
15from ..base import is_classifier, is_regressor
16from ..base import BaseEstimator
17from ..base import MetaEstimatorMixin
18from ..tree import DecisionTreeRegressor, ExtraTreeRegressor
19from ..utils import Bunch, _print_elapsed_time
20from ..utils import check_random_state
21from ..utils.metaestimators import _BaseComposition
22
23
24def _fit_single_estimator(
25    estimator, X, y, sample_weight=None, message_clsname=None, message=None
26):
27    """Private function used to fit an estimator within a job."""
28    if sample_weight is not None:
29        try:
30            with _print_elapsed_time(message_clsname, message):
31                estimator.fit(X, y, sample_weight=sample_weight)
32        except TypeError as exc:
33            if "unexpected keyword argument 'sample_weight'" in str(exc):
34                raise TypeError(
35                    "Underlying estimator {} does not support sample weights.".format(
36                        estimator.__class__.__name__
37                    )
38                ) from exc
39            raise
40    else:
41        with _print_elapsed_time(message_clsname, message):
42            estimator.fit(X, y)
43    return estimator
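
# Illustrative sketch (not part of the module): how a derived ensemble might
# call ``_fit_single_estimator`` for a single member; the toy data and the
# ``DecisionTreeClassifier`` choice below are placeholders.
#
#   >>> from sklearn.tree import DecisionTreeClassifier
#   >>> X, y = [[0], [1], [2], [3]], [0, 0, 1, 1]
#   >>> fitted = _fit_single_estimator(
#   ...     DecisionTreeClassifier(), X, y, sample_weight=[1.0, 1.0, 2.0, 2.0]
#   ... )
#   >>> hasattr(fitted, "classes_")
#   True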


def _set_random_states(estimator, random_state=None):
    """Set fixed random_state parameters for an estimator.

    Finds all parameters ending in ``random_state`` and sets them to integers
    derived from ``random_state``.

    Parameters
    ----------
    estimator : estimator supporting get/set_params
        Estimator with potential randomness managed by random_state
        parameters.

    random_state : int, RandomState instance or None, default=None
        Pseudo-random number generator to control the generation of the random
        integers. Pass an int for reproducible output across multiple function
        calls.
        See :term:`Glossary <random_state>`.

    Notes
    -----
    This does not necessarily set *all* ``random_state`` attributes that
    control an estimator's randomness, only those accessible through
    ``estimator.get_params()``. ``random_state``s not controlled include
    those belonging to:

        * cross-validation splitters
        * ``scipy.stats`` rvs
    """
    random_state = check_random_state(random_state)
    to_set = {}
    for key in sorted(estimator.get_params(deep=True)):
        if key == "random_state" or key.endswith("__random_state"):
            to_set[key] = random_state.randint(np.iinfo(np.int32).max)

    if to_set:
        estimator.set_params(**to_set)
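
# Illustrative sketch (not part of the module): seeding every ``random_state``
# parameter reachable through ``get_params``; the tree estimator below is only
# a placeholder.
#
#   >>> from sklearn.tree import DecisionTreeClassifier
#   >>> tree = DecisionTreeClassifier()
#   >>> _set_random_states(tree, random_state=0)
#   >>> isinstance(tree.random_state, (int, np.integer))
#   True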


class BaseEnsemble(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta):
    """Base class for all ensemble classes.

    Warning: This class should not be used directly. Use derived classes
    instead.

    Parameters
    ----------
    base_estimator : object
        The base estimator from which the ensemble is built.

    n_estimators : int, default=10
        The number of estimators in the ensemble.

    estimator_params : list of str, default=tuple()
        The list of attributes to use as parameters when instantiating a
        new base estimator. If none are given, default parameters are used.

    Attributes
    ----------
    base_estimator_ : estimator
        The base estimator from which the ensemble is grown.

    estimators_ : list of estimators
        The collection of fitted base estimators.
    """

    # overwrite _required_parameters from MetaEstimatorMixin
    _required_parameters: List[str] = []

    @abstractmethod
    def __init__(self, base_estimator, *, n_estimators=10, estimator_params=tuple()):
        # Set parameters
        self.base_estimator = base_estimator
        self.n_estimators = n_estimators
        self.estimator_params = estimator_params

        # Don't instantiate estimators now! Parameters of base_estimator might
        # still change. E.g., when grid-searching with the nested object syntax.
        # self.estimators_ needs to be filled by the derived classes in fit.

    def _validate_estimator(self, default=None):
        """Check the base estimator and the n_estimators attribute.

        Sets the `base_estimator_` attribute.
        """
        if not isinstance(self.n_estimators, numbers.Integral):
            raise ValueError(
                "n_estimators must be an integer, got {0}.".format(
                    type(self.n_estimators)
                )
            )

        if self.n_estimators <= 0:
            raise ValueError(
                "n_estimators must be greater than zero, got {0}.".format(
                    self.n_estimators
                )
            )

        if self.base_estimator is not None:
            self.base_estimator_ = self.base_estimator
        else:
            self.base_estimator_ = default

        if self.base_estimator_ is None:
            raise ValueError("base_estimator cannot be None")

    def _make_estimator(self, append=True, random_state=None):
        """Make and configure a copy of the `base_estimator_` attribute.

        Warning: This method should be used to properly instantiate new
        sub-estimators.
        """
        estimator = clone(self.base_estimator_)
        estimator.set_params(**{p: getattr(self, p) for p in self.estimator_params})

        # TODO: Remove in v1.2
        # criterion "mse" and "mae" would cause warnings in every call to
        # DecisionTreeRegressor.fit(..)
        if isinstance(estimator, (DecisionTreeRegressor, ExtraTreeRegressor)):
            if getattr(estimator, "criterion", None) == "mse":
                estimator.set_params(criterion="squared_error")
            elif getattr(estimator, "criterion", None) == "mae":
                estimator.set_params(criterion="absolute_error")

        if random_state is not None:
            _set_random_states(estimator, random_state)

        if append:
            self.estimators_.append(estimator)

        return estimator

    def __len__(self):
        """Return the number of estimators in the ensemble."""
        return len(self.estimators_)

    def __getitem__(self, index):
        """Return the index'th estimator in the ensemble."""
        return self.estimators_[index]

    def __iter__(self):
        """Return iterator over estimators in the ensemble."""
        return iter(self.estimators_)
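
# Illustrative sketch (hypothetical subclass, not shipped with scikit-learn):
# derived classes are expected to call ``_validate_estimator`` once and then
# build ``estimators_`` with ``_make_estimator`` inside ``fit``, roughly:
#
#   class ToyEnsembleRegressor(BaseEnsemble):
#       def __init__(self, base_estimator=None, n_estimators=10):
#           super().__init__(base_estimator, n_estimators=n_estimators)
#
#       def fit(self, X, y):
#           self._validate_estimator(default=DecisionTreeRegressor())
#           self.estimators_ = []
#           rng = check_random_state(0)
#           for _ in range(self.n_estimators):
#               est = self._make_estimator(append=True, random_state=rng)
#               est.fit(X, y)
#           return self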


def _partition_estimators(n_estimators, n_jobs):
    """Private function used to partition estimators between jobs."""
    # Compute the number of jobs
    n_jobs = min(effective_n_jobs(n_jobs), n_estimators)

    # Partition estimators between jobs
    n_estimators_per_job = np.full(n_jobs, n_estimators // n_jobs, dtype=int)
    n_estimators_per_job[: n_estimators % n_jobs] += 1
    starts = np.cumsum(n_estimators_per_job)

    return n_jobs, n_estimators_per_job.tolist(), [0] + starts.tolist()
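
# Illustrative example of the partitioning, assuming ``effective_n_jobs(3)``
# resolves to 3 on the host machine:
#
#   >>> _partition_estimators(10, n_jobs=3)
#   (3, [4, 3, 3], [0, 4, 7, 10])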


class _BaseHeterogeneousEnsemble(
    MetaEstimatorMixin, _BaseComposition, metaclass=ABCMeta
):
    """Base class for heterogeneous ensembles of learners.

    Parameters
    ----------
    estimators : list of (str, estimator) tuples
        The estimators that make up the ensemble. Each element of the list is
        a tuple of a string (the name of the estimator) and an estimator
        instance. An estimator can be set to `'drop'` using `set_params`.

    Attributes
    ----------
    estimators_ : list of estimators
        The elements of the estimators parameter, having been fitted on the
        training data. If an estimator has been set to `'drop'`, it will not
        appear in `estimators_`.
    """

    _required_parameters = ["estimators"]

    @property
    def named_estimators(self):
        """Dictionary to access any fitted sub-estimators by name.

        Returns
        -------
        :class:`~sklearn.utils.Bunch`
        """
        return Bunch(**dict(self.estimators))

    @abstractmethod
    def __init__(self, estimators):
        self.estimators = estimators

    def _validate_estimators(self):
        if self.estimators is None or len(self.estimators) == 0:
            raise ValueError(
                "Invalid 'estimators' attribute, 'estimators' should be a list"
                " of (string, estimator) tuples."
            )
        names, estimators = zip(*self.estimators)
        # defined by _BaseComposition
        self._validate_names(names)

        has_estimator = any(est != "drop" for est in estimators)
        if not has_estimator:
            raise ValueError(
                "All estimators are dropped. At least one is required "
                "to be an estimator."
            )

        is_estimator_type = is_classifier if is_classifier(self) else is_regressor

        for est in estimators:
            if est != "drop" and not is_estimator_type(est):
                raise ValueError(
                    "The estimator {} should be a {}.".format(
                        est.__class__.__name__, is_estimator_type.__name__[3:]
                    )
                )

        return names, estimators

    def set_params(self, **params):
        """
        Set the parameters of an estimator from the ensemble.

        Valid parameter keys can be listed with `get_params()`. Note that you
        can directly set the parameters of the estimators contained in
        `estimators`.

        Parameters
        ----------
        **params : keyword arguments
            Specific parameters using e.g.
            `set_params(parameter_name=new_value)`. In addition to setting the
            parameters of the ensemble, the individual estimators contained in
            `estimators` can also be set, or removed by setting them to
            'drop'.

        Returns
        -------
        self : object
            Estimator instance.
        """
        super()._set_params("estimators", **params)
        return self

    def get_params(self, deep=True):
        """
        Get the parameters of an estimator from the ensemble.

        Returns the parameters given in the constructor as well as the
        estimators contained within the `estimators` parameter.

        Parameters
        ----------
        deep : bool, default=True
            If True, also return the individual estimators in `estimators`
            and their own parameters.

        Returns
        -------
        params : dict
            Parameter names (and, when `deep=True`, estimator names and nested
            parameter names) mapped to their values.
        """
        return super()._get_params("estimators", deep=deep)
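
# Illustrative sketch of the nested parameter interface that concrete
# heterogeneous ensembles inherit from this base class; ``VotingClassifier``
# is used here only as an example subclass.
#
#   >>> from sklearn.ensemble import VotingClassifier
#   >>> from sklearn.linear_model import LogisticRegression
#   >>> from sklearn.tree import DecisionTreeClassifier
#   >>> eclf = VotingClassifier(estimators=[
#   ...     ("lr", LogisticRegression()), ("dt", DecisionTreeClassifier())])
#   >>> _ = eclf.set_params(dt__max_depth=3)  # reach into a named estimator
#   >>> _ = eclf.set_params(lr="drop")        # drop an estimator entirely
#   >>> "dt__max_depth" in eclf.get_params(deep=True)
#   True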