1"""Gradient Boosted Regression Trees.
2
3This module contains methods for fitting gradient boosted regression trees for
4both classification and regression.
5
6The module structure is the following:
7
8- The ``BaseGradientBoosting`` base class implements a common ``fit`` method
9  for all the estimators in the module. Regression and classification
10  only differ in the concrete ``LossFunction`` used.
11
12- ``GradientBoostingClassifier`` implements gradient boosting for
13  classification problems.
14
15- ``GradientBoostingRegressor`` implements gradient boosting for
16  regression problems.
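
A minimal usage sketch (synthetic data; the predictions are illustrative only)::

    >>> import numpy as np
    >>> from sklearn.ensemble import GradientBoostingRegressor
    >>> rng = np.random.RandomState(0)
    >>> X = rng.rand(100, 4)
    >>> y = X[:, 0] + 2 * X[:, 1]
    >>> est = GradientBoostingRegressor(n_estimators=50, random_state=0).fit(X, y)
    >>> est.predict(X[:2]).shape
    (2,)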
17"""
18
19# Authors: Peter Prettenhofer, Scott White, Gilles Louppe, Emanuele Olivetti,
20#          Arnaud Joly, Jacob Schreiber
21# License: BSD 3 clause
22
23from abc import ABCMeta
24from abc import abstractmethod
25import warnings
26
27from ._base import BaseEnsemble
28from ..base import ClassifierMixin
29from ..base import RegressorMixin
30from ..base import BaseEstimator
31from ..base import is_classifier
32from ..utils import deprecated
33
34from ._gradient_boosting import predict_stages
35from ._gradient_boosting import predict_stage
36from ._gradient_boosting import _random_sample_mask
37
38import numbers
39import numpy as np
40
41from scipy.sparse import csc_matrix
42from scipy.sparse import csr_matrix
43from scipy.sparse import issparse
44
45from time import time
46from ..model_selection import train_test_split
47from ..tree import DecisionTreeRegressor
48from ..tree._tree import DTYPE, DOUBLE
49from . import _gb_losses
50
51from ..utils import check_random_state
52from ..utils import check_array
53from ..utils import column_or_1d
54from ..utils.validation import check_is_fitted, _check_sample_weight
55from ..utils.multiclass import check_classification_targets
56from ..exceptions import NotFittedError
57
58
59class VerboseReporter:
60    """Reports verbose output to stdout.
61
62    Parameters
63    ----------
64    verbose : int
        Verbosity level. If ``verbose==1``, output is printed once in a while
        (when iteration mod verbose_mod is zero); if larger than 1, output is
        printed for each update.
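
        A sketch of how this is typically triggered from the public API
        (output omitted because timings vary)::

            >>> from sklearn.datasets import make_regression
            >>> from sklearn.ensemble import GradientBoostingRegressor
            >>> X, y = make_regression(random_state=0)
            >>> est = GradientBoostingRegressor(n_estimators=10, verbose=1)
            >>> est = est.fit(X, y)  # doctest: +SKIP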
68    """
69
70    def __init__(self, verbose):
71        self.verbose = verbose
72
73    def init(self, est, begin_at_stage=0):
74        """Initialize reporter
75
76        Parameters
77        ----------
78        est : Estimator
79            The estimator
80
81        begin_at_stage : int, default=0
82            stage at which to begin reporting
83        """
84        # header fields and line format str
85        header_fields = ["Iter", "Train Loss"]
86        verbose_fmt = ["{iter:>10d}", "{train_score:>16.4f}"]
87        # do oob?
88        if est.subsample < 1:
89            header_fields.append("OOB Improve")
90            verbose_fmt.append("{oob_impr:>16.4f}")
91        header_fields.append("Remaining Time")
92        verbose_fmt.append("{remaining_time:>16s}")
93
94        # print the header line
95        print(("%10s " + "%16s " * (len(header_fields) - 1)) % tuple(header_fields))
96
97        self.verbose_fmt = " ".join(verbose_fmt)
98        # plot verbose info each time i % verbose_mod == 0
99        self.verbose_mod = 1
100        self.start_time = time()
101        self.begin_at_stage = begin_at_stage
102
103    def update(self, j, est):
104        """Update reporter with new iteration.
105
106        Parameters
107        ----------
108        j : int
109            The new iteration.
110        est : Estimator
111            The estimator.
112        """
113        do_oob = est.subsample < 1
114        # we need to take into account if we fit additional estimators.
115        i = j - self.begin_at_stage  # iteration relative to the start iter
116        if (i + 1) % self.verbose_mod == 0:
117            oob_impr = est.oob_improvement_[j] if do_oob else 0
118            remaining_time = (
119                (est.n_estimators - (j + 1)) * (time() - self.start_time) / float(i + 1)
120            )
121            if remaining_time > 60:
122                remaining_time = "{0:.2f}m".format(remaining_time / 60.0)
123            else:
124                remaining_time = "{0:.2f}s".format(remaining_time)
125            print(
126                self.verbose_fmt.format(
127                    iter=j + 1,
128                    train_score=est.train_score_[j],
129                    oob_impr=oob_impr,
130                    remaining_time=remaining_time,
131                )
132            )
133            if self.verbose == 1 and ((i + 1) // (self.verbose_mod * 10) > 0):
134                # adjust verbose frequency (powers of 10)
135                self.verbose_mod *= 10
136
137
138class BaseGradientBoosting(BaseEnsemble, metaclass=ABCMeta):
139    """Abstract base class for Gradient Boosting."""
140
141    @abstractmethod
142    def __init__(
143        self,
144        *,
145        loss,
146        learning_rate,
147        n_estimators,
148        criterion,
149        min_samples_split,
150        min_samples_leaf,
151        min_weight_fraction_leaf,
152        max_depth,
153        min_impurity_decrease,
154        init,
155        subsample,
156        max_features,
157        ccp_alpha,
158        random_state,
159        alpha=0.9,
160        verbose=0,
161        max_leaf_nodes=None,
162        warm_start=False,
163        validation_fraction=0.1,
164        n_iter_no_change=None,
165        tol=1e-4,
166    ):
167
168        self.n_estimators = n_estimators
169        self.learning_rate = learning_rate
170        self.loss = loss
171        self.criterion = criterion
172        self.min_samples_split = min_samples_split
173        self.min_samples_leaf = min_samples_leaf
174        self.min_weight_fraction_leaf = min_weight_fraction_leaf
175        self.subsample = subsample
176        self.max_features = max_features
177        self.max_depth = max_depth
178        self.min_impurity_decrease = min_impurity_decrease
179        self.ccp_alpha = ccp_alpha
180        self.init = init
181        self.random_state = random_state
182        self.alpha = alpha
183        self.verbose = verbose
184        self.max_leaf_nodes = max_leaf_nodes
185        self.warm_start = warm_start
186        self.validation_fraction = validation_fraction
187        self.n_iter_no_change = n_iter_no_change
188        self.tol = tol
189
190    @abstractmethod
191    def _validate_y(self, y, sample_weight=None):
192        """Called by fit to validate y."""
193
194    def _fit_stage(
195        self,
196        i,
197        X,
198        y,
199        raw_predictions,
200        sample_weight,
201        sample_mask,
202        random_state,
203        X_csc=None,
204        X_csr=None,
205    ):
206        """Fit another stage of ``_n_classes`` trees to the boosting model."""
207
208        assert sample_mask.dtype == bool
209        loss = self.loss_
210        original_y = y
211
212        # Need to pass a copy of raw_predictions to negative_gradient()
213        # because raw_predictions is partially updated at the end of the loop
214        # in update_terminal_regions(), and gradients need to be evaluated at
215        # iteration i - 1.
216        raw_predictions_copy = raw_predictions.copy()
217
218        for k in range(loss.K):
219            if loss.is_multi_class:
220                y = np.array(original_y == k, dtype=np.float64)
221
222            residual = loss.negative_gradient(
223                y, raw_predictions_copy, k=k, sample_weight=sample_weight
224            )
225
226            # induce regression tree on residuals
227            tree = DecisionTreeRegressor(
228                criterion=self.criterion,
229                splitter="best",
230                max_depth=self.max_depth,
231                min_samples_split=self.min_samples_split,
232                min_samples_leaf=self.min_samples_leaf,
233                min_weight_fraction_leaf=self.min_weight_fraction_leaf,
234                min_impurity_decrease=self.min_impurity_decrease,
235                max_features=self.max_features,
236                max_leaf_nodes=self.max_leaf_nodes,
237                random_state=random_state,
238                ccp_alpha=self.ccp_alpha,
239            )
240
241            if self.subsample < 1.0:
242                # no inplace multiplication!
243                sample_weight = sample_weight * sample_mask.astype(np.float64)
244
245            X = X_csr if X_csr is not None else X
246            tree.fit(X, residual, sample_weight=sample_weight, check_input=False)
247
248            # update tree leaves
249            loss.update_terminal_regions(
250                tree.tree_,
251                X,
252                y,
253                residual,
254                raw_predictions,
255                sample_weight,
256                sample_mask,
257                learning_rate=self.learning_rate,
258                k=k,
259            )
260
261            # add tree to ensemble
262            self.estimators_[i, k] = tree
263
264        return raw_predictions
265
266    def _check_params(self):
267        """Check validity of parameters and raise ValueError if not valid."""
268        if self.n_estimators <= 0:
269            raise ValueError(
270                "n_estimators must be greater than 0 but was %r" % self.n_estimators
271            )
272
273        if self.learning_rate <= 0.0:
274            raise ValueError(
275                "learning_rate must be greater than 0 but was %r" % self.learning_rate
276            )
277
278        if (
279            self.loss not in self._SUPPORTED_LOSS
280            or self.loss not in _gb_losses.LOSS_FUNCTIONS
281        ):
282            raise ValueError("Loss '{0:s}' not supported. ".format(self.loss))
283
284        # TODO: Remove in v1.2
285        if self.loss == "ls":
286            warnings.warn(
287                "The loss 'ls' was deprecated in v1.0 and "
288                "will be removed in version 1.2. Use 'squared_error'"
289                " which is equivalent.",
290                FutureWarning,
291            )
292        elif self.loss == "lad":
293            warnings.warn(
294                "The loss 'lad' was deprecated in v1.0 and "
295                "will be removed in version 1.2. Use "
296                "'absolute_error' which is equivalent.",
297                FutureWarning,
298            )
299
300        if self.loss == "deviance":
301            loss_class = (
302                _gb_losses.MultinomialDeviance
303                if len(self.classes_) > 2
304                else _gb_losses.BinomialDeviance
305            )
306        else:
307            loss_class = _gb_losses.LOSS_FUNCTIONS[self.loss]
308
309        if is_classifier(self):
310            self.loss_ = loss_class(self.n_classes_)
311        elif self.loss in ("huber", "quantile"):
312            self.loss_ = loss_class(self.alpha)
313        else:
314            self.loss_ = loss_class()
315
316        if not (0.0 < self.subsample <= 1.0):
317            raise ValueError("subsample must be in (0,1] but was %r" % self.subsample)
318
319        if self.init is not None:
320            # init must be an estimator or 'zero'
321            if isinstance(self.init, BaseEstimator):
322                self.loss_.check_init_estimator(self.init)
323            elif not (isinstance(self.init, str) and self.init == "zero"):
324                raise ValueError(
325                    "The init parameter must be an estimator or 'zero'. "
326                    "Got init={}".format(self.init)
327                )
328
329        if not (0.0 < self.alpha < 1.0):
330            raise ValueError("alpha must be in (0.0, 1.0) but was %r" % self.alpha)
331
332        if isinstance(self.max_features, str):
333            if self.max_features == "auto":
334                if is_classifier(self):
335                    max_features = max(1, int(np.sqrt(self.n_features_in_)))
336                else:
337                    max_features = self.n_features_in_
338            elif self.max_features == "sqrt":
339                max_features = max(1, int(np.sqrt(self.n_features_in_)))
340            elif self.max_features == "log2":
341                max_features = max(1, int(np.log2(self.n_features_in_)))
342            else:
343                raise ValueError(
344                    "Invalid value for max_features: %r. "
345                    "Allowed string values are 'auto', 'sqrt' "
346                    "or 'log2'."
347                    % self.max_features
348                )
349        elif self.max_features is None:
350            max_features = self.n_features_in_
351        elif isinstance(self.max_features, numbers.Integral):
352            max_features = self.max_features
353        else:  # float
354            if 0.0 < self.max_features <= 1.0:
355                max_features = max(int(self.max_features * self.n_features_in_), 1)
356            else:
357                raise ValueError("max_features must be in (0, n_features]")
358
359        self.max_features_ = max_features
360
361        if not isinstance(self.n_iter_no_change, (numbers.Integral, type(None))):
362            raise ValueError(
363                "n_iter_no_change should either be None or an integer. %r was passed"
364                % self.n_iter_no_change
365            )
366
367    def _init_state(self):
368        """Initialize model state and allocate model state data structures."""
369
370        self.init_ = self.init
371        if self.init_ is None:
372            self.init_ = self.loss_.init_estimator()
373
374        self.estimators_ = np.empty((self.n_estimators, self.loss_.K), dtype=object)
375        self.train_score_ = np.zeros((self.n_estimators,), dtype=np.float64)
376        # do oob?
377        if self.subsample < 1.0:
378            self.oob_improvement_ = np.zeros((self.n_estimators), dtype=np.float64)
379
380    def _clear_state(self):
381        """Clear the state of the gradient boosting model."""
382        if hasattr(self, "estimators_"):
383            self.estimators_ = np.empty((0, 0), dtype=object)
384        if hasattr(self, "train_score_"):
385            del self.train_score_
386        if hasattr(self, "oob_improvement_"):
387            del self.oob_improvement_
388        if hasattr(self, "init_"):
389            del self.init_
390        if hasattr(self, "_rng"):
391            del self._rng
392
393    def _resize_state(self):
394        """Add additional ``n_estimators`` entries to all attributes."""
        # self.n_estimators is the new total number of estimators to fit
396        total_n_estimators = self.n_estimators
397        if total_n_estimators < self.estimators_.shape[0]:
398            raise ValueError(
                "resize with smaller n_estimators %d < %d"
                % (total_n_estimators, self.estimators_.shape[0])
401            )
402
403        self.estimators_ = np.resize(
404            self.estimators_, (total_n_estimators, self.loss_.K)
405        )
406        self.train_score_ = np.resize(self.train_score_, total_n_estimators)
407        if self.subsample < 1 or hasattr(self, "oob_improvement_"):
408            # if do oob resize arrays or create new if not available
409            if hasattr(self, "oob_improvement_"):
410                self.oob_improvement_ = np.resize(
411                    self.oob_improvement_, total_n_estimators
412                )
413            else:
414                self.oob_improvement_ = np.zeros(
415                    (total_n_estimators,), dtype=np.float64
416                )
417
418    def _is_initialized(self):
419        return len(getattr(self, "estimators_", [])) > 0
420
421    def _check_initialized(self):
422        """Check that the estimator is initialized, raising an error if not."""
423        check_is_fitted(self)
424
425    @abstractmethod
426    def _warn_mae_for_criterion(self):
427        pass
428
429    def fit(self, X, y, sample_weight=None, monitor=None):
430        """Fit the gradient boosting model.
431
432        Parameters
433        ----------
434        X : {array-like, sparse matrix} of shape (n_samples, n_features)
435            The input samples. Internally, it will be converted to
436            ``dtype=np.float32`` and if a sparse matrix is provided
437            to a sparse ``csr_matrix``.
438
439        y : array-like of shape (n_samples,)
440            Target values (strings or integers in classification, real numbers
441            in regression)
442            For classification, labels must correspond to classes.
443
444        sample_weight : array-like of shape (n_samples,), default=None
445            Sample weights. If None, then samples are equally weighted. Splits
446            that would create child nodes with net zero or negative weight are
447            ignored while searching for a split in each node. In the case of
448            classification, splits are also ignored if they would result in any
449            single class carrying a negative weight in either child node.
450
451        monitor : callable, default=None
452            The monitor is called after each iteration with the current
453            iteration, a reference to the estimator and the local variables of
454            ``_fit_stages`` as keyword arguments ``callable(i, self,
455            locals())``. If the callable returns ``True`` the fitting procedure
456            is stopped. The monitor can be used for various things such as
            computing held-out estimates, early stopping, model introspection,
            and snapshotting.
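
            A hypothetical monitor that stops boosting after ten stages might
            look like this (sketch only; the arguments are passed
            positionally)::

                def stop_after_ten(i, est, locals_dict):
                    # i is the zero-based index of the stage just fitted
                    return i >= 9

                est.fit(X, y, monitor=stop_after_ten)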
459
460        Returns
461        -------
462        self : object
463            Fitted estimator.
464        """
465        if self.criterion in ("absolute_error", "mae"):
466            # TODO: This should raise an error from 1.1
467            self._warn_mae_for_criterion()
468
469        if self.criterion == "mse":
470            # TODO: Remove in v1.2. By then it should raise an error.
471            warnings.warn(
472                "Criterion 'mse' was deprecated in v1.0 and will be "
473                "removed in version 1.2. Use `criterion='squared_error'` "
474                "which is equivalent.",
475                FutureWarning,
476            )
477
478        # if not warmstart - clear the estimator state
479        if not self.warm_start:
480            self._clear_state()
481
482        # Check input
        # check_array converts both X and y to the same dtype, but the trees
        # use different types for X and y, so check them separately.
485
486        X, y = self._validate_data(
487            X, y, accept_sparse=["csr", "csc", "coo"], dtype=DTYPE, multi_output=True
488        )
489
490        sample_weight_is_none = sample_weight is None
491
492        sample_weight = _check_sample_weight(sample_weight, X)
493
494        y = column_or_1d(y, warn=True)
495
496        if is_classifier(self):
497            y = self._validate_y(y, sample_weight)
498        else:
499            y = self._validate_y(y)
500
501        if self.n_iter_no_change is not None:
502            stratify = y if is_classifier(self) else None
503            X, X_val, y, y_val, sample_weight, sample_weight_val = train_test_split(
504                X,
505                y,
506                sample_weight,
507                random_state=self.random_state,
508                test_size=self.validation_fraction,
509                stratify=stratify,
510            )
511            if is_classifier(self):
512                if self._n_classes != np.unique(y).shape[0]:
513                    # We choose to error here. The problem is that the init
514                    # estimator would be trained on y, which has some missing
515                    # classes now, so its predictions would not have the
516                    # correct shape.
517                    raise ValueError(
518                        "The training data after the early stopping split "
519                        "is missing some classes. Try using another random "
520                        "seed."
521                    )
522        else:
523            X_val = y_val = sample_weight_val = None
524
525        self._check_params()
526
527        if not self._is_initialized():
528            # init state
529            self._init_state()
530
531            # fit initial model and initialize raw predictions
532            if self.init_ == "zero":
533                raw_predictions = np.zeros(
534                    shape=(X.shape[0], self.loss_.K), dtype=np.float64
535                )
536            else:
537                # XXX clean this once we have a support_sample_weight tag
538                if sample_weight_is_none:
539                    self.init_.fit(X, y)
540                else:
541                    msg = (
542                        "The initial estimator {} does not support sample "
543                        "weights.".format(self.init_.__class__.__name__)
544                    )
545                    try:
546                        self.init_.fit(X, y, sample_weight=sample_weight)
547                    except TypeError as e:
548                        # regular estimator without SW support
549                        raise ValueError(msg) from e
550                    except ValueError as e:
551                        if (
552                            "pass parameters to specific steps of "
553                            "your pipeline using the "
554                            "stepname__parameter"
555                            in str(e)
556                        ):  # pipeline
557                            raise ValueError(msg) from e
558                        else:  # regular estimator whose input checking failed
559                            raise
560
561                raw_predictions = self.loss_.get_init_raw_predictions(X, self.init_)
562
563            begin_at_stage = 0
564
565            # The rng state must be preserved if warm_start is True
566            self._rng = check_random_state(self.random_state)
567
568        else:
569            # add more estimators to fitted model
570            # invariant: warm_start = True
571            if self.n_estimators < self.estimators_.shape[0]:
572                raise ValueError(
573                    "n_estimators=%d must be larger or equal to "
574                    "estimators_.shape[0]=%d when "
575                    "warm_start==True" % (self.n_estimators, self.estimators_.shape[0])
576                )
577            begin_at_stage = self.estimators_.shape[0]
            # The requirements of _raw_predict (called two lines below) are
            # more constrained than fit. It accepts only CSR matrices.
581            X = check_array(X, dtype=DTYPE, order="C", accept_sparse="csr")
582            raw_predictions = self._raw_predict(X)
583            self._resize_state()
584
585        # fit the boosting stages
586        n_stages = self._fit_stages(
587            X,
588            y,
589            raw_predictions,
590            sample_weight,
591            self._rng,
592            X_val,
593            y_val,
594            sample_weight_val,
595            begin_at_stage,
596            monitor,
597        )
598
599        # change shape of arrays after fit (early-stopping or additional ests)
600        if n_stages != self.estimators_.shape[0]:
601            self.estimators_ = self.estimators_[:n_stages]
602            self.train_score_ = self.train_score_[:n_stages]
603            if hasattr(self, "oob_improvement_"):
604                self.oob_improvement_ = self.oob_improvement_[:n_stages]
605
606        self.n_estimators_ = n_stages
607        return self
608
609    def _fit_stages(
610        self,
611        X,
612        y,
613        raw_predictions,
614        sample_weight,
615        random_state,
616        X_val,
617        y_val,
618        sample_weight_val,
619        begin_at_stage=0,
620        monitor=None,
621    ):
622        """Iteratively fits the stages.
623
624        For each stage it computes the progress (OOB, train score)
625        and delegates to ``_fit_stage``.
626        Returns the number of stages fit; might differ from ``n_estimators``
627        due to early stopping.
628        """
629        n_samples = X.shape[0]
630        do_oob = self.subsample < 1.0
631        sample_mask = np.ones((n_samples,), dtype=bool)
632        n_inbag = max(1, int(self.subsample * n_samples))
633        loss_ = self.loss_
634
635        if self.verbose:
636            verbose_reporter = VerboseReporter(verbose=self.verbose)
637            verbose_reporter.init(self, begin_at_stage)
638
639        X_csc = csc_matrix(X) if issparse(X) else None
640        X_csr = csr_matrix(X) if issparse(X) else None
641
642        if self.n_iter_no_change is not None:
643            loss_history = np.full(self.n_iter_no_change, np.inf)
644            # We create a generator to get the predictions for X_val after
645            # the addition of each successive stage
646            y_val_pred_iter = self._staged_raw_predict(X_val, check_input=False)
647
648        # perform boosting iterations
649        i = begin_at_stage
650        for i in range(begin_at_stage, self.n_estimators):
651
652            # subsampling
653            if do_oob:
654                sample_mask = _random_sample_mask(n_samples, n_inbag, random_state)
655                # OOB score before adding this stage
656                old_oob_score = loss_(
657                    y[~sample_mask],
658                    raw_predictions[~sample_mask],
659                    sample_weight[~sample_mask],
660                )
661
662            # fit next stage of trees
663            raw_predictions = self._fit_stage(
664                i,
665                X,
666                y,
667                raw_predictions,
668                sample_weight,
669                sample_mask,
670                random_state,
671                X_csc,
672                X_csr,
673            )
674
675            # track deviance (= loss)
676            if do_oob:
677                self.train_score_[i] = loss_(
678                    y[sample_mask],
679                    raw_predictions[sample_mask],
680                    sample_weight[sample_mask],
681                )
682                self.oob_improvement_[i] = old_oob_score - loss_(
683                    y[~sample_mask],
684                    raw_predictions[~sample_mask],
685                    sample_weight[~sample_mask],
686                )
687            else:
688                # no need to fancy index w/ no subsampling
689                self.train_score_[i] = loss_(y, raw_predictions, sample_weight)
690
691            if self.verbose > 0:
692                verbose_reporter.update(i, self)
693
694            if monitor is not None:
695                early_stopping = monitor(i, self, locals())
696                if early_stopping:
697                    break
698
699            # We also provide an early stopping based on the score from
700            # validation set (X_val, y_val), if n_iter_no_change is set
701            if self.n_iter_no_change is not None:
702                # By calling next(y_val_pred_iter), we get the predictions
703                # for X_val after the addition of the current stage
704                validation_loss = loss_(y_val, next(y_val_pred_iter), sample_weight_val)
705
706                # Require validation_score to be better (less) than at least
707                # one of the last n_iter_no_change evaluations
708                if np.any(validation_loss + self.tol < loss_history):
709                    loss_history[i % len(loss_history)] = validation_loss
710                else:
711                    break
712
713        return i + 1
714
715    def _make_estimator(self, append=True):
716        # we don't need _make_estimator
717        raise NotImplementedError()
718
719    def _raw_predict_init(self, X):
720        """Check input and compute raw predictions of the init estimator."""
721        self._check_initialized()
722        X = self.estimators_[0, 0]._validate_X_predict(X, check_input=True)
723        if self.init_ == "zero":
724            raw_predictions = np.zeros(
725                shape=(X.shape[0], self.loss_.K), dtype=np.float64
726            )
727        else:
728            raw_predictions = self.loss_.get_init_raw_predictions(X, self.init_).astype(
729                np.float64
730            )
731        return raw_predictions
732
733    def _raw_predict(self, X):
734        """Return the sum of the trees raw predictions (+ init estimator)."""
735        raw_predictions = self._raw_predict_init(X)
736        predict_stages(self.estimators_, X, self.learning_rate, raw_predictions)
737        return raw_predictions
738
739    def _staged_raw_predict(self, X, check_input=True):
740        """Compute raw predictions of ``X`` for each iteration.
741
        This method allows monitoring (i.e. determining the error on a test
        set) after each stage.
744
745        Parameters
746        ----------
747        X : {array-like, sparse matrix} of shape (n_samples, n_features)
748            The input samples. Internally, it will be converted to
749            ``dtype=np.float32`` and if a sparse matrix is provided
750            to a sparse ``csr_matrix``.
751
752        check_input : bool, default=True
753            If False, the input arrays X will not be checked.
754
755        Returns
756        -------
757        raw_predictions : generator of ndarray of shape (n_samples, k)
758            The raw predictions of the input samples. The order of the
759            classes corresponds to that in the attribute :term:`classes_`.
760            Regression and binary classification are special cases with
761            ``k == 1``, otherwise ``k==n_classes``.
762        """
763        if check_input:
764            X = self._validate_data(
765                X, dtype=DTYPE, order="C", accept_sparse="csr", reset=False
766            )
767        raw_predictions = self._raw_predict_init(X)
768        for i in range(self.estimators_.shape[0]):
769            predict_stage(self.estimators_, i, X, self.learning_rate, raw_predictions)
770            yield raw_predictions.copy()
771
772    @property
773    def feature_importances_(self):
774        """The impurity-based feature importances.
775
776        The higher, the more important the feature.
777        The importance of a feature is computed as the (normalized)
778        total reduction of the criterion brought by that feature.  It is also
779        known as the Gini importance.
780
781        Warning: impurity-based feature importances can be misleading for
782        high cardinality features (many unique values). See
783        :func:`sklearn.inspection.permutation_importance` as an alternative.
784
785        Returns
786        -------
787        feature_importances_ : ndarray of shape (n_features,)
788            The values of this array sum to 1, unless all trees are single node
789            trees consisting of only the root node, in which case it will be an
790            array of zeros.
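
        Examples
        --------
        A minimal sketch showing the returned shape (synthetic data with 20
        features):

        >>> from sklearn.datasets import make_classification
        >>> from sklearn.ensemble import GradientBoostingClassifier
        >>> X, y = make_classification(random_state=0)
        >>> clf = GradientBoostingClassifier(n_estimators=5, random_state=0).fit(X, y)
        >>> clf.feature_importances_.shape
        (20,)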
791        """
792        self._check_initialized()
793
794        relevant_trees = [
795            tree
796            for stage in self.estimators_
797            for tree in stage
798            if tree.tree_.node_count > 1
799        ]
800        if not relevant_trees:
801            # degenerate case where all trees have only one node
802            return np.zeros(shape=self.n_features_in_, dtype=np.float64)
803
804        relevant_feature_importances = [
805            tree.tree_.compute_feature_importances(normalize=False)
806            for tree in relevant_trees
807        ]
808        avg_feature_importances = np.mean(
809            relevant_feature_importances, axis=0, dtype=np.float64
810        )
811        return avg_feature_importances / np.sum(avg_feature_importances)
812
813    def _compute_partial_dependence_recursion(self, grid, target_features):
814        """Fast partial dependence computation.
815
816        Parameters
817        ----------
818        grid : ndarray of shape (n_samples, n_target_features)
819            The grid points on which the partial dependence should be
820            evaluated.
821        target_features : ndarray of shape (n_target_features,)
822            The set of target features for which the partial dependence
823            should be evaluated.
824
825        Returns
826        -------
827        averaged_predictions : ndarray of shape \
828                (n_trees_per_iteration, n_samples)
829            The value of the partial dependence function on each grid point.
830        """
831        if self.init is not None:
832            warnings.warn(
833                "Using recursion method with a non-constant init predictor "
834                "will lead to incorrect partial dependence values. "
835                "Got init=%s."
836                % self.init,
837                UserWarning,
838            )
839        grid = np.asarray(grid, dtype=DTYPE, order="C")
840        n_estimators, n_trees_per_stage = self.estimators_.shape
841        averaged_predictions = np.zeros(
842            (n_trees_per_stage, grid.shape[0]), dtype=np.float64, order="C"
843        )
844        for stage in range(n_estimators):
845            for k in range(n_trees_per_stage):
846                tree = self.estimators_[stage, k].tree_
847                tree.compute_partial_dependence(
848                    grid, target_features, averaged_predictions[k]
849                )
850        averaged_predictions *= self.learning_rate
851
852        return averaged_predictions
853
854    def apply(self, X):
855        """Apply trees in the ensemble to X, return leaf indices.
856
857        .. versionadded:: 0.17
858
859        Parameters
860        ----------
861        X : {array-like, sparse matrix} of shape (n_samples, n_features)
862            The input samples. Internally, its dtype will be converted to
863            ``dtype=np.float32``. If a sparse matrix is provided, it will
864            be converted to a sparse ``csr_matrix``.
865
866        Returns
867        -------
868        X_leaves : array-like of shape (n_samples, n_estimators, n_classes)
869            For each datapoint x in X and for each tree in the ensemble,
            return the index of the leaf x ends up in.
871            In the case of binary classification n_classes is 1.
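
        Examples
        --------
        A minimal sketch of the returned shape (synthetic regression data, so
        the trailing dimension is 1):

        >>> from sklearn.datasets import make_regression
        >>> from sklearn.ensemble import GradientBoostingRegressor
        >>> X, y = make_regression(random_state=0)
        >>> est = GradientBoostingRegressor(n_estimators=3, random_state=0).fit(X, y)
        >>> est.apply(X).shape
        (100, 3, 1)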
872        """
873
874        self._check_initialized()
875        X = self.estimators_[0, 0]._validate_X_predict(X, check_input=True)
876
877        # n_classes will be equal to 1 in the binary classification or the
878        # regression case.
879        n_estimators, n_classes = self.estimators_.shape
880        leaves = np.zeros((X.shape[0], n_estimators, n_classes))
881
882        for i in range(n_estimators):
883            for j in range(n_classes):
884                estimator = self.estimators_[i, j]
885                leaves[:, i, j] = estimator.apply(X, check_input=False)
886
887        return leaves
888
889    # TODO: Remove in 1.2
890    # mypy error: Decorated property not supported
891    @deprecated(  # type: ignore
892        "Attribute `n_features_` was deprecated in version 1.0 and will be "
893        "removed in 1.2. Use `n_features_in_` instead."
894    )
895    @property
896    def n_features_(self):
897        return self.n_features_in_
898
899
900class GradientBoostingClassifier(ClassifierMixin, BaseGradientBoosting):
901    """Gradient Boosting for classification.
902
903    GB builds an additive model in a
904    forward stage-wise fashion; it allows for the optimization of
905    arbitrary differentiable loss functions. In each stage ``n_classes_``
906    regression trees are fit on the negative gradient of the
907    binomial or multinomial deviance loss function. Binary classification
908    is a special case where only a single regression tree is induced.
909
910    Read more in the :ref:`User Guide <gradient_boosting>`.
911
912    Parameters
913    ----------
914    loss : {'deviance', 'exponential'}, default='deviance'
915        The loss function to be optimized. 'deviance' refers to
916        deviance (= logistic regression) for classification
917        with probabilistic outputs. For loss 'exponential' gradient
918        boosting recovers the AdaBoost algorithm.
919
920    learning_rate : float, default=0.1
921        Learning rate shrinks the contribution of each tree by `learning_rate`.
922        There is a trade-off between learning_rate and n_estimators.
923
924    n_estimators : int, default=100
925        The number of boosting stages to perform. Gradient boosting
926        is fairly robust to over-fitting so a large number usually
927        results in better performance.
928
929    subsample : float, default=1.0
930        The fraction of samples to be used for fitting the individual base
931        learners. If smaller than 1.0 this results in Stochastic Gradient
932        Boosting. `subsample` interacts with the parameter `n_estimators`.
933        Choosing `subsample < 1.0` leads to a reduction of variance
934        and an increase in bias.
935
936    criterion : {'friedman_mse', 'squared_error', 'mse', 'mae'}, \
937            default='friedman_mse'
938        The function to measure the quality of a split. Supported criteria
939        are 'friedman_mse' for the mean squared error with improvement
940        score by Friedman, 'squared_error' for mean squared error, and 'mae'
941        for the mean absolute error. The default value of 'friedman_mse' is
942        generally the best as it can provide a better approximation in some
943        cases.
944
945        .. versionadded:: 0.18
946
947        .. deprecated:: 0.24
948            `criterion='mae'` is deprecated and will be removed in version
949            1.1 (renaming of 0.26). Use `criterion='friedman_mse'` or
950            `'squared_error'` instead, as trees should use a squared error
951            criterion in Gradient Boosting.
952
953        .. deprecated:: 1.0
954            Criterion 'mse' was deprecated in v1.0 and will be removed in
955            version 1.2. Use `criterion='squared_error'` which is equivalent.
956
957    min_samples_split : int or float, default=2
958        The minimum number of samples required to split an internal node:
959
960        - If int, then consider `min_samples_split` as the minimum number.
961        - If float, then `min_samples_split` is a fraction and
962          `ceil(min_samples_split * n_samples)` are the minimum
963          number of samples for each split.
964
965        .. versionchanged:: 0.18
966           Added float values for fractions.
967
968    min_samples_leaf : int or float, default=1
969        The minimum number of samples required to be at a leaf node.
970        A split point at any depth will only be considered if it leaves at
971        least ``min_samples_leaf`` training samples in each of the left and
972        right branches.  This may have the effect of smoothing the model,
973        especially in regression.
974
975        - If int, then consider `min_samples_leaf` as the minimum number.
976        - If float, then `min_samples_leaf` is a fraction and
977          `ceil(min_samples_leaf * n_samples)` are the minimum
978          number of samples for each node.
979
980        .. versionchanged:: 0.18
981           Added float values for fractions.
982
983    min_weight_fraction_leaf : float, default=0.0
984        The minimum weighted fraction of the sum total of weights (of all
985        the input samples) required to be at a leaf node. Samples have
986        equal weight when sample_weight is not provided.
987
988    max_depth : int, default=3
989        The maximum depth of the individual regression estimators. The maximum
990        depth limits the number of nodes in the tree. Tune this parameter
991        for best performance; the best value depends on the interaction
992        of the input variables.
993
994    min_impurity_decrease : float, default=0.0
995        A node will be split if this split induces a decrease of the impurity
996        greater than or equal to this value.
997
998        The weighted impurity decrease equation is the following::
999
1000            N_t / N * (impurity - N_t_R / N_t * right_impurity
1001                                - N_t_L / N_t * left_impurity)
1002
1003        where ``N`` is the total number of samples, ``N_t`` is the number of
1004        samples at the current node, ``N_t_L`` is the number of samples in the
1005        left child, and ``N_t_R`` is the number of samples in the right child.
1006
1007        ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
1008        if ``sample_weight`` is passed.
1009
1010        .. versionadded:: 0.19
1011
1012    init : estimator or 'zero', default=None
1013        An estimator object that is used to compute the initial predictions.
1014        ``init`` has to provide :meth:`fit` and :meth:`predict_proba`. If
1015        'zero', the initial raw predictions are set to zero. By default, a
        ``DummyEstimator`` predicting the class priors is used.
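
        A hypothetical custom initial estimator (sketch only; any estimator
        exposing ``fit`` and ``predict_proba`` works here)::

            from sklearn.dummy import DummyClassifier

            GradientBoostingClassifier(init=DummyClassifier(strategy="prior"))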
1017
1018    random_state : int, RandomState instance or None, default=None
1019        Controls the random seed given to each Tree estimator at each
1020        boosting iteration.
1021        In addition, it controls the random permutation of the features at
1022        each split (see Notes for more details).
1023        It also controls the random splitting of the training data to obtain a
1024        validation set if `n_iter_no_change` is not None.
1025        Pass an int for reproducible output across multiple function calls.
1026        See :term:`Glossary <random_state>`.
1027
1028    max_features : {'auto', 'sqrt', 'log2'}, int or float, default=None
1029        The number of features to consider when looking for the best split:
1030
1031        - If int, then consider `max_features` features at each split.
1032        - If float, then `max_features` is a fraction and
1033          `int(max_features * n_features)` features are considered at each
1034          split.
1035        - If 'auto', then `max_features=sqrt(n_features)`.
1036        - If 'sqrt', then `max_features=sqrt(n_features)`.
1037        - If 'log2', then `max_features=log2(n_features)`.
1038        - If None, then `max_features=n_features`.
1039
1040        Choosing `max_features < n_features` leads to a reduction of variance
1041        and an increase in bias.
1042
1043        Note: the search for a split does not stop until at least one
1044        valid partition of the node samples is found, even if it requires to
1045        effectively inspect more than ``max_features`` features.
1046
1047    verbose : int, default=0
1048        Enable verbose output. If 1 then it prints progress and performance
1049        once in a while (the more trees the lower the frequency). If greater
1050        than 1 then it prints progress and performance for every tree.
1051
1052    max_leaf_nodes : int, default=None
1053        Grow trees with ``max_leaf_nodes`` in best-first fashion.
1054        Best nodes are defined as relative reduction in impurity.
1055        If None then unlimited number of leaf nodes.
1056
1057    warm_start : bool, default=False
1058        When set to ``True``, reuse the solution of the previous call to fit
1059        and add more estimators to the ensemble, otherwise, just erase the
1060        previous solution. See :term:`the Glossary <warm_start>`.
1061
1062    validation_fraction : float, default=0.1
1063        The proportion of training data to set aside as validation set for
1064        early stopping. Must be between 0 and 1.
1065        Only used if ``n_iter_no_change`` is set to an integer.
1066
1067        .. versionadded:: 0.20
1068
1069    n_iter_no_change : int, default=None
1070        ``n_iter_no_change`` is used to decide if early stopping will be used
1071        to terminate training when validation score is not improving. By
1072        default it is set to None to disable early stopping. If set to a
1073        number, it will set aside ``validation_fraction`` size of the training
        data as validation and terminate training when the validation score
        has not improved in any of the previous ``n_iter_no_change``
        iterations. The split is stratified.
1077
1078        .. versionadded:: 0.20
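
        A hypothetical early-stopping configuration (sketch only)::

            GradientBoostingClassifier(
                n_estimators=1000,
                n_iter_no_change=5,
                validation_fraction=0.1,
                tol=1e-4,
            )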
1079
1080    tol : float, default=1e-4
1081        Tolerance for the early stopping. When the loss is not improving
1082        by at least tol for ``n_iter_no_change`` iterations (if set to a
1083        number), the training stops.
1084
1085        .. versionadded:: 0.20
1086
1087    ccp_alpha : non-negative float, default=0.0
1088        Complexity parameter used for Minimal Cost-Complexity Pruning. The
1089        subtree with the largest cost complexity that is smaller than
1090        ``ccp_alpha`` will be chosen. By default, no pruning is performed. See
1091        :ref:`minimal_cost_complexity_pruning` for details.
1092
1093        .. versionadded:: 0.22
1094
1095    Attributes
1096    ----------
1097    n_estimators_ : int
1098        The number of estimators as selected by early stopping (if
1099        ``n_iter_no_change`` is specified). Otherwise it is set to
1100        ``n_estimators``.
1101
1102        .. versionadded:: 0.20
1103
1104    feature_importances_ : ndarray of shape (n_features,)
1105        The impurity-based feature importances.
1106        The higher, the more important the feature.
1107        The importance of a feature is computed as the (normalized)
1108        total reduction of the criterion brought by that feature.  It is also
1109        known as the Gini importance.
1110
1111        Warning: impurity-based feature importances can be misleading for
1112        high cardinality features (many unique values). See
1113        :func:`sklearn.inspection.permutation_importance` as an alternative.
1114
1115    oob_improvement_ : ndarray of shape (n_estimators,)
1116        The improvement in loss (= deviance) on the out-of-bag samples
1117        relative to the previous iteration.
1118        ``oob_improvement_[0]`` is the improvement in
1119        loss of the first stage over the ``init`` estimator.
        Only available if ``subsample < 1.0``.
1121
1122    train_score_ : ndarray of shape (n_estimators,)
1123        The i-th score ``train_score_[i]`` is the deviance (= loss) of the
1124        model at iteration ``i`` on the in-bag sample.
1125        If ``subsample == 1`` this is the deviance on the training data.
1126
1127    loss_ : LossFunction
1128        The concrete ``LossFunction`` object.
1129
1130    init_ : estimator
1131        The estimator that provides the initial predictions.
1132        Set via the ``init`` argument or ``loss.init_estimator``.
1133
1134    estimators_ : ndarray of DecisionTreeRegressor of \
1135            shape (n_estimators, ``loss_.K``)
1136        The collection of fitted sub-estimators. ``loss_.K`` is 1 for binary
1137        classification, otherwise n_classes.
1138
1139    classes_ : ndarray of shape (n_classes,)
1140        The classes labels.
1141
1142    n_features_ : int
1143        The number of data features.
1144
1145        .. deprecated:: 1.0
1146            Attribute `n_features_` was deprecated in version 1.0 and will be
1147            removed in 1.2. Use `n_features_in_` instead.
1148
1149    n_features_in_ : int
1150        Number of features seen during :term:`fit`.
1151
1152        .. versionadded:: 0.24
1153
1154    feature_names_in_ : ndarray of shape (`n_features_in_`,)
1155        Names of features seen during :term:`fit`. Defined only when `X`
1156        has feature names that are all strings.
1157
1158        .. versionadded:: 1.0
1159
1160    n_classes_ : int
1161        The number of classes.
1162
1163    max_features_ : int
1164        The inferred value of max_features.
1165
1166    See Also
1167    --------
1168    HistGradientBoostingClassifier : Histogram-based Gradient Boosting
1169        Classification Tree.
1170    sklearn.tree.DecisionTreeClassifier : A decision tree classifier.
1171    RandomForestClassifier : A meta-estimator that fits a number of decision
1172        tree classifiers on various sub-samples of the dataset and uses
1173        averaging to improve the predictive accuracy and control over-fitting.
1174    AdaBoostClassifier : A meta-estimator that begins by fitting a classifier
1175        on the original dataset and then fits additional copies of the
1176        classifier on the same dataset where the weights of incorrectly
1177        classified instances are adjusted such that subsequent classifiers
1178        focus more on difficult cases.
1179
1180    Notes
1181    -----
1182    The features are always randomly permuted at each split. Therefore,
1183    the best found split may vary, even with the same training data and
1184    ``max_features=n_features``, if the improvement of the criterion is
1185    identical for several splits enumerated during the search of the best
1186    split. To obtain a deterministic behaviour during fitting,
1187    ``random_state`` has to be fixed.
1188
1189    References
1190    ----------
1191    J. Friedman, Greedy Function Approximation: A Gradient Boosting
1192    Machine, The Annals of Statistics, Vol. 29, No. 5, 2001.
1193
1194    J. Friedman, Stochastic Gradient Boosting, 1999
1195
1196    T. Hastie, R. Tibshirani and J. Friedman.
1197    Elements of Statistical Learning Ed. 2, Springer, 2009.
1198
1199    Examples
1200    --------
1201    The following example shows how to fit a gradient boosting classifier with
1202    100 decision stumps as weak learners.
1203
1204    >>> from sklearn.datasets import make_hastie_10_2
1205    >>> from sklearn.ensemble import GradientBoostingClassifier
1206
1207    >>> X, y = make_hastie_10_2(random_state=0)
1208    >>> X_train, X_test = X[:2000], X[2000:]
1209    >>> y_train, y_test = y[:2000], y[2000:]
1210
1211    >>> clf = GradientBoostingClassifier(n_estimators=100, learning_rate=1.0,
1212    ...     max_depth=1, random_state=0).fit(X_train, y_train)
1213    >>> clf.score(X_test, y_test)
1214    0.913...
1215    """
1216
1217    _SUPPORTED_LOSS = ("deviance", "exponential")
1218
1219    def __init__(
1220        self,
1221        *,
1222        loss="deviance",
1223        learning_rate=0.1,
1224        n_estimators=100,
1225        subsample=1.0,
1226        criterion="friedman_mse",
1227        min_samples_split=2,
1228        min_samples_leaf=1,
1229        min_weight_fraction_leaf=0.0,
1230        max_depth=3,
1231        min_impurity_decrease=0.0,
1232        init=None,
1233        random_state=None,
1234        max_features=None,
1235        verbose=0,
1236        max_leaf_nodes=None,
1237        warm_start=False,
1238        validation_fraction=0.1,
1239        n_iter_no_change=None,
1240        tol=1e-4,
1241        ccp_alpha=0.0,
1242    ):
1243
1244        super().__init__(
1245            loss=loss,
1246            learning_rate=learning_rate,
1247            n_estimators=n_estimators,
1248            criterion=criterion,
1249            min_samples_split=min_samples_split,
1250            min_samples_leaf=min_samples_leaf,
1251            min_weight_fraction_leaf=min_weight_fraction_leaf,
1252            max_depth=max_depth,
1253            init=init,
1254            subsample=subsample,
1255            max_features=max_features,
1256            random_state=random_state,
1257            verbose=verbose,
1258            max_leaf_nodes=max_leaf_nodes,
1259            min_impurity_decrease=min_impurity_decrease,
1260            warm_start=warm_start,
1261            validation_fraction=validation_fraction,
1262            n_iter_no_change=n_iter_no_change,
1263            tol=tol,
1264            ccp_alpha=ccp_alpha,
1265        )
1266
1267    def _validate_y(self, y, sample_weight):
1268        check_classification_targets(y)
1269        self.classes_, y = np.unique(y, return_inverse=True)
1270        n_trim_classes = np.count_nonzero(np.bincount(y, sample_weight))
1271        if n_trim_classes < 2:
1272            raise ValueError(
1273                "y contains %d class after sample_weight "
1274                "trimmed classes with zero weights, while a "
1275                "minimum of 2 classes are required." % n_trim_classes
1276            )
1277        self._n_classes = len(self.classes_)
1278        # expose n_classes_ attribute
1279        self.n_classes_ = self._n_classes
1280        return y
1281
1282    def _warn_mae_for_criterion(self):
1283        # TODO: This should raise an error from 1.1
1284        warnings.warn(
1285            "criterion='mae' was deprecated in version 0.24 and "
1286            "will be removed in version 1.1 (renaming of 0.26). Use "
1287            "criterion='friedman_mse' or 'squared_error' instead, as"
1288            " trees should use a squared error criterion in Gradient"
1289            " Boosting.",
1290            FutureWarning,
1291        )
1292
1293    def decision_function(self, X):
1294        """Compute the decision function of ``X``.
1295
1296        Parameters
1297        ----------
1298        X : {array-like, sparse matrix} of shape (n_samples, n_features)
1299            The input samples. Internally, it will be converted to
1300            ``dtype=np.float32`` and if a sparse matrix is provided
1301            to a sparse ``csr_matrix``.
1302
1303        Returns
1304        -------
1305        score : ndarray of shape (n_samples, n_classes) or (n_samples,)
1306            The decision function of the input samples, which corresponds to
            the raw values predicted from the trees of the ensemble. The
1308            order of the classes corresponds to that in the attribute
1309            :term:`classes_`. Regression and binary classification produce an
1310            array of shape (n_samples,).
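
        Examples
        --------
        A minimal sketch for a binary problem, where the output is 1-D:

        >>> from sklearn.datasets import make_classification
        >>> from sklearn.ensemble import GradientBoostingClassifier
        >>> X, y = make_classification(random_state=0)
        >>> clf = GradientBoostingClassifier(n_estimators=5, random_state=0).fit(X, y)
        >>> clf.decision_function(X[:3]).shape
        (3,)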
1311        """
1312        X = self._validate_data(
1313            X, dtype=DTYPE, order="C", accept_sparse="csr", reset=False
1314        )
1315        raw_predictions = self._raw_predict(X)
1316        if raw_predictions.shape[1] == 1:
1317            return raw_predictions.ravel()
1318        return raw_predictions
1319
1320    def staged_decision_function(self, X):
1321        """Compute decision function of ``X`` for each iteration.
1322
        This method allows monitoring (i.e. determining the error on a test
        set) after each stage.
1325
1326        Parameters
1327        ----------
1328        X : {array-like, sparse matrix} of shape (n_samples, n_features)
1329            The input samples. Internally, it will be converted to
1330            ``dtype=np.float32`` and if a sparse matrix is provided
1331            to a sparse ``csr_matrix``.
1332
1333        Yields
1334        ------
1335        score : generator of ndarray of shape (n_samples, k)
1336            The decision function of the input samples, which corresponds to
            the raw values predicted from the trees of the ensemble. The
            order of the classes corresponds to that in the attribute
            :term:`classes_`. Regression and binary classification are
            special cases with ``k == 1``, otherwise ``k==n_classes``.
1341        """
1342        yield from self._staged_raw_predict(X)
1343
1344    def predict(self, X):
1345        """Predict class for X.
1346
1347        Parameters
1348        ----------
1349        X : {array-like, sparse matrix} of shape (n_samples, n_features)
1350            The input samples. Internally, it will be converted to
1351            ``dtype=np.float32`` and if a sparse matrix is provided
1352            to a sparse ``csr_matrix``.
1353
1354        Returns
1355        -------
1356        y : ndarray of shape (n_samples,)
1357            The predicted values.
1358        """
1359        raw_predictions = self.decision_function(X)
1360        encoded_labels = self.loss_._raw_prediction_to_decision(raw_predictions)
1361        return self.classes_.take(encoded_labels, axis=0)
1362
1363    def staged_predict(self, X):
1364        """Predict class at each stage for X.
1365
        This method allows monitoring (i.e. determining the error on a test
        set) after each stage.
1368
1369        Parameters
1370        ----------
1371        X : {array-like, sparse matrix} of shape (n_samples, n_features)
1372            The input samples. Internally, it will be converted to
1373            ``dtype=np.float32`` and if a sparse matrix is provided
1374            to a sparse ``csr_matrix``.
1375
1376        Yields
1377        -------
1378        y : generator of ndarray of shape (n_samples,)
1379            The predicted value of the input samples.
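
        Examples
        --------
        An illustrative way to track test accuracy after each boosting stage
        (the data and estimator settings below are arbitrary):

        >>> from sklearn.datasets import make_classification
        >>> from sklearn.ensemble import GradientBoostingClassifier
        >>> from sklearn.metrics import accuracy_score
        >>> from sklearn.model_selection import train_test_split
        >>> X, y = make_classification(random_state=0)
        >>> X_train, X_test, y_train, y_test = train_test_split(
        ...     X, y, random_state=0)
        >>> clf = GradientBoostingClassifier(
        ...     n_estimators=10, random_state=0).fit(X_train, y_train)
        >>> scores = [accuracy_score(y_test, y_pred)
        ...           for y_pred in clf.staged_predict(X_test)]
        >>> len(scores)
        10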
1380        """
1381        for raw_predictions in self._staged_raw_predict(X):
1382            encoded_labels = self.loss_._raw_prediction_to_decision(raw_predictions)
1383            yield self.classes_.take(encoded_labels, axis=0)
1384
1385    def predict_proba(self, X):
1386        """Predict class probabilities for X.
1387
1388        Parameters
1389        ----------
1390        X : {array-like, sparse matrix} of shape (n_samples, n_features)
1391            The input samples. Internally, it will be converted to
1392            ``dtype=np.float32`` and if a sparse matrix is provided
1393            to a sparse ``csr_matrix``.
1394
1395        Returns
1396        -------
1397        p : ndarray of shape (n_samples, n_classes)
1398            The class probabilities of the input samples. The order of the
1399            classes corresponds to that in the attribute :term:`classes_`.
1400
1401        Raises
1402        ------
1403        AttributeError
1404            If the ``loss`` does not support probabilities.
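
        Examples
        --------
        A short sketch on synthetic binary data (settings are arbitrary);
        each row of the returned array sums to one:

        >>> import numpy as np
        >>> from sklearn.datasets import make_classification
        >>> from sklearn.ensemble import GradientBoostingClassifier
        >>> X, y = make_classification(random_state=0)
        >>> clf = GradientBoostingClassifier(
        ...     n_estimators=5, random_state=0).fit(X, y)
        >>> proba = clf.predict_proba(X)
        >>> proba.shape
        (100, 2)
        >>> np.allclose(proba.sum(axis=1), 1.0)
        True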
1405        """
1406        raw_predictions = self.decision_function(X)
1407        try:
1408            return self.loss_._raw_prediction_to_proba(raw_predictions)
1409        except NotFittedError:
1410            raise
1411        except AttributeError as e:
1412            raise AttributeError(
1413                "loss=%r does not support predict_proba" % self.loss
1414            ) from e
1415
1416    def predict_log_proba(self, X):
1417        """Predict class log-probabilities for X.
1418
1419        Parameters
1420        ----------
1421        X : {array-like, sparse matrix} of shape (n_samples, n_features)
1422            The input samples. Internally, it will be converted to
1423            ``dtype=np.float32`` and if a sparse matrix is provided
1424            to a sparse ``csr_matrix``.
1425
1426        Returns
1427        -------
1428        p : ndarray of shape (n_samples, n_classes)
1429            The class log-probabilities of the input samples. The order of the
1430            classes corresponds to that in the attribute :term:`classes_`.
1431
1432        Raises
1433        ------
1434        AttributeError
1435            If the ``loss`` does not support probabilities.
1436        """
1437        proba = self.predict_proba(X)
1438        return np.log(proba)
1439
1440    def staged_predict_proba(self, X):
1441        """Predict class probabilities at each stage for X.
1442
        This method allows monitoring (i.e. determining the error on a test
        set) after each stage.
1445
1446        Parameters
1447        ----------
1448        X : {array-like, sparse matrix} of shape (n_samples, n_features)
1449            The input samples. Internally, it will be converted to
1450            ``dtype=np.float32`` and if a sparse matrix is provided
1451            to a sparse ``csr_matrix``.
1452
1453        Yields
1454        ------
        y : generator of ndarray of shape (n_samples, n_classes)
            The predicted class probabilities of the input samples. The order
            of the classes corresponds to that in the attribute
            :term:`classes_`.
1457        """
1458        try:
1459            for raw_predictions in self._staged_raw_predict(X):
1460                yield self.loss_._raw_prediction_to_proba(raw_predictions)
1461        except NotFittedError:
1462            raise
1463        except AttributeError as e:
1464            raise AttributeError(
1465                "loss=%r does not support predict_proba" % self.loss
1466            ) from e
1467
1468
1469class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting):
1470    """Gradient Boosting for regression.
1471
1472    GB builds an additive model in a forward stage-wise fashion;
1473    it allows for the optimization of arbitrary differentiable loss functions.
1474    In each stage a regression tree is fit on the negative gradient of the
1475    given loss function.
1476
1477    Read more in the :ref:`User Guide <gradient_boosting>`.
1478
1479    Parameters
1480    ----------
1481    loss : {'squared_error', 'absolute_error', 'huber', 'quantile'}, \
1482            default='squared_error'
1483        Loss function to be optimized. 'squared_error' refers to the squared
        error for regression. 'absolute_error' refers to the absolute error
        for regression and is a robust loss function. 'huber' is a
1486        combination of the two. 'quantile' allows quantile regression (use
1487        `alpha` to specify the quantile).
1488
1489        .. deprecated:: 1.0
1490            The loss 'ls' was deprecated in v1.0 and will be removed in
1491            version 1.2. Use `loss='squared_error'` which is equivalent.
1492
1493        .. deprecated:: 1.0
1494            The loss 'lad' was deprecated in v1.0 and will be removed in
1495            version 1.2. Use `loss='absolute_error'` which is equivalent.
1496
1497    learning_rate : float, default=0.1
1498        Learning rate shrinks the contribution of each tree by `learning_rate`.
1499        There is a trade-off between learning_rate and n_estimators.
1500
1501    n_estimators : int, default=100
1502        The number of boosting stages to perform. Gradient boosting
1503        is fairly robust to over-fitting so a large number usually
1504        results in better performance.
1505
1506    subsample : float, default=1.0
1507        The fraction of samples to be used for fitting the individual base
1508        learners. If smaller than 1.0 this results in Stochastic Gradient
1509        Boosting. `subsample` interacts with the parameter `n_estimators`.
1510        Choosing `subsample < 1.0` leads to a reduction of variance
1511        and an increase in bias.
1512
1513    criterion : {'friedman_mse', 'squared_error', 'mse', 'mae'}, \
1514            default='friedman_mse'
1515        The function to measure the quality of a split. Supported criteria
1516        are "friedman_mse" for the mean squared error with improvement
1517        score by Friedman, "squared_error" for mean squared error, and "mae"
        for the mean absolute error. The default value of "friedman_mse" is
        generally preferred, as it can provide a better approximation in
        some cases.
1521
1522        .. versionadded:: 0.18
1523
1524        .. deprecated:: 0.24
1525            `criterion='mae'` is deprecated and will be removed in version
1526            1.1 (renaming of 0.26). The correct way of minimizing the absolute
1527            error is to use `loss='absolute_error'` instead.
1528
1529        .. deprecated:: 1.0
1530            Criterion 'mse' was deprecated in v1.0 and will be removed in
1531            version 1.2. Use `criterion='squared_error'` which is equivalent.
1532
1533    min_samples_split : int or float, default=2
1534        The minimum number of samples required to split an internal node:
1535
1536        - If int, then consider `min_samples_split` as the minimum number.
1537        - If float, then `min_samples_split` is a fraction and
1538          `ceil(min_samples_split * n_samples)` are the minimum
1539          number of samples for each split.
1540
1541        .. versionchanged:: 0.18
1542           Added float values for fractions.
1543
1544    min_samples_leaf : int or float, default=1
1545        The minimum number of samples required to be at a leaf node.
1546        A split point at any depth will only be considered if it leaves at
1547        least ``min_samples_leaf`` training samples in each of the left and
1548        right branches.  This may have the effect of smoothing the model,
1549        especially in regression.
1550
1551        - If int, then consider `min_samples_leaf` as the minimum number.
1552        - If float, then `min_samples_leaf` is a fraction and
1553          `ceil(min_samples_leaf * n_samples)` are the minimum
1554          number of samples for each node.
1555
1556        .. versionchanged:: 0.18
1557           Added float values for fractions.
1558
1559    min_weight_fraction_leaf : float, default=0.0
1560        The minimum weighted fraction of the sum total of weights (of all
1561        the input samples) required to be at a leaf node. Samples have
1562        equal weight when sample_weight is not provided.
1563
1564    max_depth : int, default=3
1565        Maximum depth of the individual regression estimators. The maximum
1566        depth limits the number of nodes in the tree. Tune this parameter
1567        for best performance; the best value depends on the interaction
1568        of the input variables.
1569
1570    min_impurity_decrease : float, default=0.0
1571        A node will be split if this split induces a decrease of the impurity
1572        greater than or equal to this value.
1573
1574        The weighted impurity decrease equation is the following::
1575
1576            N_t / N * (impurity - N_t_R / N_t * right_impurity
1577                                - N_t_L / N_t * left_impurity)
1578
1579        where ``N`` is the total number of samples, ``N_t`` is the number of
1580        samples at the current node, ``N_t_L`` is the number of samples in the
1581        left child, and ``N_t_R`` is the number of samples in the right child.
1582
1583        ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
1584        if ``sample_weight`` is passed.
1585
1586        .. versionadded:: 0.19
1587
1588    init : estimator or 'zero', default=None
1589        An estimator object that is used to compute the initial predictions.
1590        ``init`` has to provide :term:`fit` and :term:`predict`. If 'zero', the
        initial raw predictions are set to zero. By default a
        ``DummyRegressor`` is used, predicting either the average target value
        (for loss='squared_error'), or a quantile for the other losses.
1594
1595    random_state : int, RandomState instance or None, default=None
1596        Controls the random seed given to each Tree estimator at each
1597        boosting iteration.
1598        In addition, it controls the random permutation of the features at
1599        each split (see Notes for more details).
1600        It also controls the random splitting of the training data to obtain a
1601        validation set if `n_iter_no_change` is not None.
1602        Pass an int for reproducible output across multiple function calls.
1603        See :term:`Glossary <random_state>`.
1604
1605    max_features : {'auto', 'sqrt', 'log2'}, int or float, default=None
1606        The number of features to consider when looking for the best split:
1607
1608        - If int, then consider `max_features` features at each split.
1609        - If float, then `max_features` is a fraction and
1610          `int(max_features * n_features)` features are considered at each
1611          split.
1612        - If "auto", then `max_features=n_features`.
1613        - If "sqrt", then `max_features=sqrt(n_features)`.
1614        - If "log2", then `max_features=log2(n_features)`.
1615        - If None, then `max_features=n_features`.
1616
1617        Choosing `max_features < n_features` leads to a reduction of variance
1618        and an increase in bias.
1619
        Note: the search for a split does not stop until at least one
        valid partition of the node samples is found, even if it requires
        effectively inspecting more than ``max_features`` features.
1623
1624    alpha : float, default=0.9
1625        The alpha-quantile of the huber loss function and the quantile
1626        loss function. Only if ``loss='huber'`` or ``loss='quantile'``.
1627
1628    verbose : int, default=0
1629        Enable verbose output. If 1 then it prints progress and performance
1630        once in a while (the more trees the lower the frequency). If greater
1631        than 1 then it prints progress and performance for every tree.
1632
1633    max_leaf_nodes : int, default=None
1634        Grow trees with ``max_leaf_nodes`` in best-first fashion.
1635        Best nodes are defined as relative reduction in impurity.
1636        If None then unlimited number of leaf nodes.
1637
1638    warm_start : bool, default=False
1639        When set to ``True``, reuse the solution of the previous call to fit
1640        and add more estimators to the ensemble, otherwise, just erase the
1641        previous solution. See :term:`the Glossary <warm_start>`.
1642
1643    validation_fraction : float, default=0.1
1644        The proportion of training data to set aside as validation set for
1645        early stopping. Must be between 0 and 1.
1646        Only used if ``n_iter_no_change`` is set to an integer.
1647
1648        .. versionadded:: 0.20
1649
1650    n_iter_no_change : int, default=None
1651        ``n_iter_no_change`` is used to decide if early stopping will be used
1652        to terminate training when validation score is not improving. By
1653        default it is set to None to disable early stopping. If set to a
        number, it will set aside a ``validation_fraction`` sized part of the
        training data as a validation set and terminate training when the
        validation score has not improved in any of the previous
        ``n_iter_no_change`` iterations.
1658
1659        .. versionadded:: 0.20
1660
1661    tol : float, default=1e-4
1662        Tolerance for the early stopping. When the loss is not improving
1663        by at least tol for ``n_iter_no_change`` iterations (if set to a
1664        number), the training stops.
1665
1666        .. versionadded:: 0.20
1667
1668    ccp_alpha : non-negative float, default=0.0
1669        Complexity parameter used for Minimal Cost-Complexity Pruning. The
1670        subtree with the largest cost complexity that is smaller than
1671        ``ccp_alpha`` will be chosen. By default, no pruning is performed. See
1672        :ref:`minimal_cost_complexity_pruning` for details.
1673
1674        .. versionadded:: 0.22
1675
1676    Attributes
1677    ----------
1678    feature_importances_ : ndarray of shape (n_features,)
1679        The impurity-based feature importances.
1680        The higher, the more important the feature.
1681        The importance of a feature is computed as the (normalized)
1682        total reduction of the criterion brought by that feature.  It is also
1683        known as the Gini importance.
1684
1685        Warning: impurity-based feature importances can be misleading for
1686        high cardinality features (many unique values). See
1687        :func:`sklearn.inspection.permutation_importance` as an alternative.
1688
1689    oob_improvement_ : ndarray of shape (n_estimators,)
1690        The improvement in loss (= deviance) on the out-of-bag samples
1691        relative to the previous iteration.
1692        ``oob_improvement_[0]`` is the improvement in
1693        loss of the first stage over the ``init`` estimator.
        Only available if ``subsample < 1.0``.
1695
1696    train_score_ : ndarray of shape (n_estimators,)
1697        The i-th score ``train_score_[i]`` is the deviance (= loss) of the
1698        model at iteration ``i`` on the in-bag sample.
1699        If ``subsample == 1`` this is the deviance on the training data.
1700
1701    loss_ : LossFunction
1702        The concrete ``LossFunction`` object.
1703
1704    init_ : estimator
1705        The estimator that provides the initial predictions.
1706        Set via the ``init`` argument or ``loss.init_estimator``.
1707
1708    estimators_ : ndarray of DecisionTreeRegressor of shape (n_estimators, 1)
1709        The collection of fitted sub-estimators.
1710
1711    n_classes_ : int
1712        The number of classes, set to 1 for regressors.
1713
1714        .. deprecated:: 0.24
1715            Attribute ``n_classes_`` was deprecated in version 0.24 and
1716            will be removed in 1.1 (renaming of 0.26).
1717
1718    n_estimators_ : int
1719        The number of estimators as selected by early stopping (if
1720        ``n_iter_no_change`` is specified). Otherwise it is set to
1721        ``n_estimators``.
1722
1723    n_features_ : int
1724        The number of data features.
1725
1726        .. deprecated:: 1.0
1727            Attribute `n_features_` was deprecated in version 1.0 and will be
1728            removed in 1.2. Use `n_features_in_` instead.
1729
1730    n_features_in_ : int
1731        Number of features seen during :term:`fit`.
1732
1733        .. versionadded:: 0.24
1734
1735    feature_names_in_ : ndarray of shape (`n_features_in_`,)
1736        Names of features seen during :term:`fit`. Defined only when `X`
1737        has feature names that are all strings.
1738
1739        .. versionadded:: 1.0
1740
1741    max_features_ : int
1742        The inferred value of max_features.
1743
1744    See Also
1745    --------
    HistGradientBoostingRegressor : Histogram-based Gradient Boosting
        Regression Tree.
1748    sklearn.tree.DecisionTreeRegressor : A decision tree regressor.
1749    sklearn.ensemble.RandomForestRegressor : A random forest regressor.
1750
1751    Notes
1752    -----
1753    The features are always randomly permuted at each split. Therefore,
1754    the best found split may vary, even with the same training data and
1755    ``max_features=n_features``, if the improvement of the criterion is
1756    identical for several splits enumerated during the search of the best
1757    split. To obtain a deterministic behaviour during fitting,
1758    ``random_state`` has to be fixed.
1759
1760    References
1761    ----------
1762    J. Friedman, Greedy Function Approximation: A Gradient Boosting
1763    Machine, The Annals of Statistics, Vol. 29, No. 5, 2001.
1764
1765    J. Friedman, Stochastic Gradient Boosting, 1999
1766
1767    T. Hastie, R. Tibshirani and J. Friedman.
1768    Elements of Statistical Learning Ed. 2, Springer, 2009.
1769
1770    Examples
1771    --------
1772    >>> from sklearn.datasets import make_regression
1773    >>> from sklearn.ensemble import GradientBoostingRegressor
1774    >>> from sklearn.model_selection import train_test_split
1775    >>> X, y = make_regression(random_state=0)
1776    >>> X_train, X_test, y_train, y_test = train_test_split(
1777    ...     X, y, random_state=0)
1778    >>> reg = GradientBoostingRegressor(random_state=0)
1779    >>> reg.fit(X_train, y_train)
1780    GradientBoostingRegressor(random_state=0)
1781    >>> reg.predict(X_test[1:2])
1782    array([-61...])
1783    >>> reg.score(X_test, y_test)
1784    0.4...
1785    """
1786
1787    # TODO: remove "ls" in version 1.2
1788    _SUPPORTED_LOSS = (
1789        "squared_error",
1790        "ls",
1791        "absolute_error",
1792        "lad",
1793        "huber",
1794        "quantile",
1795    )
1796
1797    def __init__(
1798        self,
1799        *,
1800        loss="squared_error",
1801        learning_rate=0.1,
1802        n_estimators=100,
1803        subsample=1.0,
1804        criterion="friedman_mse",
1805        min_samples_split=2,
1806        min_samples_leaf=1,
1807        min_weight_fraction_leaf=0.0,
1808        max_depth=3,
1809        min_impurity_decrease=0.0,
1810        init=None,
1811        random_state=None,
1812        max_features=None,
1813        alpha=0.9,
1814        verbose=0,
1815        max_leaf_nodes=None,
1816        warm_start=False,
1817        validation_fraction=0.1,
1818        n_iter_no_change=None,
1819        tol=1e-4,
1820        ccp_alpha=0.0,
1821    ):
1822
1823        super().__init__(
1824            loss=loss,
1825            learning_rate=learning_rate,
1826            n_estimators=n_estimators,
1827            criterion=criterion,
1828            min_samples_split=min_samples_split,
1829            min_samples_leaf=min_samples_leaf,
1830            min_weight_fraction_leaf=min_weight_fraction_leaf,
1831            max_depth=max_depth,
1832            init=init,
1833            subsample=subsample,
1834            max_features=max_features,
1835            min_impurity_decrease=min_impurity_decrease,
1836            random_state=random_state,
1837            alpha=alpha,
1838            verbose=verbose,
1839            max_leaf_nodes=max_leaf_nodes,
1840            warm_start=warm_start,
1841            validation_fraction=validation_fraction,
1842            n_iter_no_change=n_iter_no_change,
1843            tol=tol,
1844            ccp_alpha=ccp_alpha,
1845        )
1846
1847    def _validate_y(self, y, sample_weight=None):
1848        if y.dtype.kind == "O":
1849            y = y.astype(DOUBLE)
1850        return y
1851
1852    def _warn_mae_for_criterion(self):
1853        # TODO: This should raise an error from 1.1
1854        warnings.warn(
1855            "criterion='mae' was deprecated in version 0.24 and "
1856            "will be removed in version 1.1 (renaming of 0.26). The "
1857            "correct way of minimizing the absolute error is to use "
1858            " loss='absolute_error' instead.",
1859            FutureWarning,
1860        )
1861
1862    def predict(self, X):
1863        """Predict regression target for X.
1864
1865        Parameters
1866        ----------
1867        X : {array-like, sparse matrix} of shape (n_samples, n_features)
1868            The input samples. Internally, it will be converted to
1869            ``dtype=np.float32`` and if a sparse matrix is provided
1870            to a sparse ``csr_matrix``.
1871
1872        Returns
1873        -------
1874        y : ndarray of shape (n_samples,)
1875            The predicted values.
1876        """
1877        X = self._validate_data(
1878            X, dtype=DTYPE, order="C", accept_sparse="csr", reset=False
1879        )
1880        # In regression we can directly return the raw value from the trees.
1881        return self._raw_predict(X).ravel()
1882
1883    def staged_predict(self, X):
1884        """Predict regression target at each stage for X.
1885
        This method allows monitoring (i.e. determining the error on a test
        set) after each stage.
1888
1889        Parameters
1890        ----------
1891        X : {array-like, sparse matrix} of shape (n_samples, n_features)
1892            The input samples. Internally, it will be converted to
1893            ``dtype=np.float32`` and if a sparse matrix is provided
1894            to a sparse ``csr_matrix``.
1895
1896        Yields
1897        ------
1898        y : generator of ndarray of shape (n_samples,)
1899            The predicted value of the input samples.
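
        Examples
        --------
        An illustrative way to track the test error after each boosting
        stage (the data and estimator settings below are arbitrary):

        >>> from sklearn.datasets import make_regression
        >>> from sklearn.ensemble import GradientBoostingRegressor
        >>> from sklearn.metrics import mean_squared_error
        >>> from sklearn.model_selection import train_test_split
        >>> X, y = make_regression(random_state=0)
        >>> X_train, X_test, y_train, y_test = train_test_split(
        ...     X, y, random_state=0)
        >>> reg = GradientBoostingRegressor(
        ...     n_estimators=10, random_state=0).fit(X_train, y_train)
        >>> errors = [mean_squared_error(y_test, y_pred)
        ...           for y_pred in reg.staged_predict(X_test)]
        >>> len(errors)
        10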
1900        """
1901        for raw_predictions in self._staged_raw_predict(X):
1902            yield raw_predictions.ravel()
1903
1904    def apply(self, X):
1905        """Apply trees in the ensemble to X, return leaf indices.
1906
1907        .. versionadded:: 0.17
1908
1909        Parameters
1910        ----------
1911        X : {array-like, sparse matrix} of shape (n_samples, n_features)
1912            The input samples. Internally, its dtype will be converted to
1913            ``dtype=np.float32``. If a sparse matrix is provided, it will
1914            be converted to a sparse ``csr_matrix``.
1915
1916        Returns
1917        -------
1918        X_leaves : array-like of shape (n_samples, n_estimators)
            For each datapoint x in X and for each tree in the ensemble,
            return the index of the leaf x ends up in.
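
        Examples
        --------
        An illustrative sketch on synthetic data (settings are arbitrary);
        the returned array has one column of leaf indices per fitted tree:

        >>> from sklearn.datasets import make_regression
        >>> from sklearn.ensemble import GradientBoostingRegressor
        >>> X, y = make_regression(random_state=0)
        >>> reg = GradientBoostingRegressor(n_estimators=5, random_state=0)
        >>> reg.fit(X, y).apply(X).shape
        (100, 5)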
1921        """
1922
1923        leaves = super().apply(X)
1924        leaves = leaves.reshape(X.shape[0], self.estimators_.shape[0])
1925        return leaves
1926
1927    # FIXME: to be removed in 1.1
1928    # mypy error: Decorated property not supported
1929    @deprecated(  # type: ignore
1930        "Attribute `n_classes_` was deprecated "
1931        "in version 0.24 and will be removed in 1.1 (renaming of 0.26)."
1932    )
1933    @property
1934    def n_classes_(self):
1935        try:
1936            check_is_fitted(self)
1937        except NotFittedError as nfe:
1938            raise AttributeError(
1939                "{} object has no n_classes_ attribute.".format(self.__class__.__name__)
1940            ) from nfe
1941        return 1
1942