1<?php
2
3namespace Rubix\ML\Regressors;
4
5use Tensor\Vector;
6use Rubix\ML\Learner;
7use Rubix\ML\Verbose;
8use Rubix\ML\Estimator;
9use Rubix\ML\Persistable;
10use Rubix\ML\RanksFeatures;
11use Rubix\ML\EstimatorType;
12use Rubix\ML\Datasets\Dataset;
13use Rubix\ML\Datasets\Labeled;
14use Rubix\ML\Other\Helpers\Params;
15use Rubix\ML\Other\Strategies\Mean;
16use Rubix\ML\Other\Traits\LoggerAware;
17use Rubix\ML\Other\Traits\PredictsSingle;
18use Rubix\ML\CrossValidation\Metrics\RMSE;
19use Rubix\ML\CrossValidation\Metrics\Metric;
20use Rubix\ML\Other\Traits\AutotrackRevisions;
21use Rubix\ML\Specifications\DatasetIsLabeled;
22use Rubix\ML\Specifications\DatasetIsNotEmpty;
23use Rubix\ML\Specifications\SpecificationChain;
24use Rubix\ML\Specifications\DatasetHasDimensionality;
25use Rubix\ML\Specifications\LabelsAreCompatibleWithLearner;
26use Rubix\ML\Specifications\EstimatorIsCompatibleWithMetric;
27use Rubix\ML\Specifications\SamplesAreCompatibleWithEstimator;
28use Rubix\ML\Exceptions\InvalidArgumentException;
29use Rubix\ML\Exceptions\RuntimeException;
30
31use function count;
32use function is_nan;
33use function array_slice;
34use function get_class;
35use function in_array;
36
37/**
38 * Gradient Boost
39 *
40 * Gradient Boost is a stage-wise additive ensemble that uses a Gradient Descent boosting
 * scheme for training boosters (Decision Trees) to correct the error residuals of a
42 * series of *weak* base learners. Stochastic gradient boosting is achieved by varying
43 * the ratio of samples to subsample uniformly at random from the training set.
44 *
 * > **Note**: The default base regressor is a Dummy Regressor using the Mean strategy
 * and the default booster is a Regression Tree with a max height of 3.
47 *
48 * References:
49 * [1] J. H. Friedman. (2001). Greedy Function Approximation: A Gradient
50 * Boosting Machine.
51 * [2] J. H. Friedman. (1999). Stochastic Gradient Boosting.
52 * [3] Y. Wei. et al. (2017). Early stopping for kernel boosting algorithms:
53 * A general analysis with localized complexities.
54 *
55 * @category    Machine Learning
56 * @package     Rubix/ML
57 * @author      Andrew DalPino
58 */
class GradientBoost implements Estimator, Learner, RanksFeatures, Verbose, Persistable
{
    use AutotrackRevisions, PredictsSingle, LoggerAware;

    /**
     * The class names of the compatible learners to be used as boosters.
     *
     * @var class-string[]
     */
    public const COMPATIBLE_BOOSTERS = [
        RegressionTree::class,
        ExtraTreeRegressor::class,
    ];

    /**
     * The minimum size of each training subset.
     *
     * @var int
     */
    protected const MIN_SUBSAMPLE = 1;

    /**
     * The regressor that will fix up the error residuals of the *weak* base learner.
     *
     * @var \Rubix\ML\Learner
     */
    protected $booster;

    /**
     * The learning rate of the ensemble i.e. the *shrinkage* applied to each step.
     *
     * @var float
     */
    protected $rate;

    /**
     * The ratio of samples to subsample from the training set for each booster.
     *
     * @var float
     */
    protected $ratio;

    /**
     * The max number of estimators to train in the ensemble.
     *
     * @var int
     */
    protected $estimators;

    /**
     * The minimum change in the training loss necessary to continue training.
     *
     * @var float
     */
    protected $minChange;

    /**
     * The number of epochs without improvement in the validation score to wait
     * before considering an early stop.
     *
     * @var int
     */
    protected $window;

    /**
     * The proportion of training samples to use for validation and progress monitoring.
     *
     * @var float
     */
    protected $holdOut;

    /**
     * The metric used to score the generalization performance of the model
     * during training.
     *
     * @var \Rubix\ML\CrossValidation\Metrics\Metric
     */
    protected $metric;

    /**
     * The *weak* base regressor to be boosted.
     *
     * @var \Rubix\ML\Learner
     */
    protected $base;

    /**
     * An ensemble of weak regressors.
     *
     * @var mixed[]
     */
    protected $ensemble = [
        //
    ];

    /**
     * The dimensionality of the training set.
     *
     * @var int|null
     */
    protected $featureCount;

    /**
     * The validation scores at each epoch.
     *
     * @var float[]|null
     */
    protected $scores;

    /**
     * The average training loss at each epoch.
     *
     * @var float[]|null
     */
    protected $steps;

    /**
     * @param \Rubix\ML\Learner|null $booster
     * @param float $rate
     * @param float $ratio
     * @param int $estimators
     * @param float $minChange
     * @param int $window
     * @param float $holdOut
     * @param \Rubix\ML\CrossValidation\Metrics\Metric|null $metric
     * @param \Rubix\ML\Learner|null $base
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
     */
    public function __construct(
        ?Learner $booster = null,
        float $rate = 0.1,
        float $ratio = 0.5,
        int $estimators = 1000,
        float $minChange = 1e-4,
        int $window = 10,
        float $holdOut = 0.1,
        ?Metric $metric = null,
        ?Learner $base = null
    ) {
        // Strict comparison since we are matching exact class-name strings.
        if ($booster and !in_array(get_class($booster), self::COMPATIBLE_BOOSTERS, true)) {
            throw new InvalidArgumentException('Booster is not compatible'
                . ' with the ensemble.');
        }

        if ($rate <= 0.0 or $rate > 1.0) {
            throw new InvalidArgumentException('Learning rate must be'
                . " between 0 and 1, $rate given.");
        }

        if ($ratio <= 0.0 or $ratio > 1.0) {
            throw new InvalidArgumentException('Ratio must be'
                . " between 0 and 1, $ratio given.");
        }

        if ($estimators < 1) {
            throw new InvalidArgumentException('Number of estimators'
                . " must be greater than 0, $estimators given.");
        }

        if ($minChange < 0.0) {
            throw new InvalidArgumentException('Minimum change must be'
                . " greater than or equal to 0, $minChange given.");
        }

        if ($window < 1) {
            throw new InvalidArgumentException('Window must be'
                . " greater than 0, $window given.");
        }

        if ($holdOut < 0.0 or $holdOut > 0.5) {
            throw new InvalidArgumentException('Hold out ratio must be'
                . " between 0 and 0.5, $holdOut given.");
        }

        if ($metric) {
            EstimatorIsCompatibleWithMetric::with($this, $metric)->check();
        }

        if ($base and $base->type() != EstimatorType::regressor()) {
            throw new InvalidArgumentException('Base Estimator must be a'
                . " regressor, {$base->type()} given.");
        }

        $this->booster = $booster ?? new RegressionTree(3);
        $this->rate = $rate;
        $this->ratio = $ratio;
        $this->estimators = $estimators;
        $this->minChange = $minChange;
        $this->window = $window;
        $this->holdOut = $holdOut;
        $this->metric = $metric ?? new RMSE();
        $this->base = $base ?? new DummyRegressor(new Mean());
    }

    /**
     * Return the estimator type.
     *
     * @internal
     *
     * @return \Rubix\ML\EstimatorType
     */
    public function type() : EstimatorType
    {
        return EstimatorType::regressor();
    }

    /**
     * Return the data types that the estimator is compatible with.
     *
     * @internal
     *
     * @return list<\Rubix\ML\DataType>
     */
    public function compatibility() : array
    {
        // The ensemble can only handle data types that both the booster
        // and the base learner support.
        $compatibility = array_intersect(
            $this->booster->compatibility(),
            $this->base->compatibility()
        );

        return array_values($compatibility);
    }

    /**
     * Return the settings of the hyper-parameters in an associative array.
     *
     * @internal
     *
     * @return mixed[]
     */
    public function params() : array
    {
        return [
            'booster' => $this->booster,
            'rate' => $this->rate,
            'ratio' => $this->ratio,
            'estimators' => $this->estimators,
            'min_change' => $this->minChange,
            'window' => $this->window,
            'hold_out' => $this->holdOut,
            'metric' => $this->metric,
            'base' => $this->base,
        ];
    }

    /**
     * Has the learner been trained?
     *
     * @return bool
     */
    public function trained() : bool
    {
        return $this->base->trained() and $this->ensemble;
    }

    /**
     * Return the validation scores at each epoch from the last training session.
     *
     * @return float[]|null
     */
    public function scores() : ?array
    {
        return $this->scores;
    }

    /**
     * Return the loss at each epoch from the last training session.
     *
     * @return float[]|null
     */
    public function steps() : ?array
    {
        return $this->steps;
    }

    /**
     * Train the estimator with a dataset.
     *
     * @param \Rubix\ML\Datasets\Labeled $dataset
     */
    public function train(Dataset $dataset) : void
    {
        SpecificationChain::with([
            new DatasetIsLabeled($dataset),
            new DatasetIsNotEmpty($dataset),
            new SamplesAreCompatibleWithEstimator($dataset, $this),
            new LabelsAreCompatibleWithLearner($dataset, $this),
        ])->check();

        if ($this->logger) {
            $this->logger->info("$this initialized");
        }

        $this->featureCount = $dataset->numColumns();

        // Hold out a portion of the training set for validation scoring.
        [$testing, $training] = $dataset->randomize()->split($this->holdOut);

        [$min, $max] = $this->metric->range();

        if ($this->logger) {
            $this->logger->info("Training {$this->base}");
        }

        $this->base->train($training);

        $this->ensemble = $this->scores = $this->steps = [];

        /** @var list<int|float> $predictions */
        $predictions = $this->base->predict($training);

        $out = $prevOut = Vector::quick($predictions);
        $target = Vector::quick($training->labels());

        if (!$testing->empty()) {
            /** @var list<int|float> $predictions */
            $predictions = $this->base->predict($testing);

            $prevPred = Vector::quick($predictions);
        }

        // Subsample size for stochastic gradient boosting.
        $p = max(self::MIN_SUBSAMPLE, (int) round($this->ratio * $training->numRows()));

        $bestScore = $min;
        $bestEpoch = $delta = 0;
        $score = null;
        $prevLoss = INF;

        for ($epoch = 1; $epoch <= $this->estimators; ++$epoch) {
            // The residuals of the current ensemble output serve as the
            // negative gradient of the squared error loss.
            $gradient = $target->subtract($out);

            $training = Labeled::quick($training->samples(), $gradient->asArray());

            $booster = clone $this->booster;

            $subset = $training->randomSubset($p);

            $booster->train($subset);

            $this->ensemble[] = $booster;

            /** @var list<int|float> $predictions */
            $predictions = $booster->predict($training);

            // Apply shrinkage (learning rate) to the booster's correction.
            $out = Vector::quick($predictions)
                ->multiply($this->rate)
                ->add($prevOut);

            $loss = $gradient->square()->mean();

            if (is_nan($loss)) {
                if ($this->logger) {
                    $this->logger->info('Numerical instability detected');
                }

                break;
            }

            $this->steps[] = $loss;

            if (isset($prevPred)) {
                /** @var list<int|float> $predictions */
                $predictions = $booster->predict($testing);

                $pred = Vector::quick($predictions)
                    ->multiply($this->rate)
                    ->add($prevPred);

                $score = $this->metric->score($pred->asArray(), $testing->labels());

                $this->scores[] = $score;
            }

            if ($this->logger) {
                $this->logger->info("Epoch $epoch - {$this->metric}: "
                    . ($score ?? 'n/a') . ", L2 Loss: $loss");
            }

            if (isset($pred)) {
                // Stop early if the metric's upper bound has been reached.
                if ($score >= $max) {
                    break;
                }

                if ($score > $bestScore) {
                    $bestScore = $score;
                    $bestEpoch = $epoch;

                    $delta = 0;
                } else {
                    ++$delta;
                }

                // Early stopping when the validation score has not improved
                // for `window` consecutive epochs.
                if ($delta >= $this->window) {
                    break;
                }

                $prevPred = $pred;
            }

            if ($loss <= 0.0) {
                break;
            }

            if (abs($prevLoss - $loss) < $this->minChange) {
                break;
            }

            $prevOut = $out;
            $prevLoss = $loss;
        }

        // If training ended past the best epoch, roll the ensemble back
        // to the epoch with the highest validation score.
        if ($this->scores and end($this->scores) < $bestScore) {
            if ($this->logger) {
                $this->logger->info("Restoring ensemble state to epoch $bestEpoch");
            }

            $this->ensemble = array_slice($this->ensemble, 0, $bestEpoch);
        }

        if ($this->logger) {
            $this->logger->info('Training complete');
        }
    }

    /**
     * Make a prediction from a dataset.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @throws \Rubix\ML\Exceptions\RuntimeException
     * @return list<int|float>
     */
    public function predict(Dataset $dataset) : array
    {
        if (!$this->ensemble or !$this->featureCount) {
            throw new RuntimeException('Estimator has not been trained.');
        }

        DatasetHasDimensionality::with($dataset, $this->featureCount)->check();

        // Start from the base learner's output then add each booster's
        // shrunken correction.
        /** @var list<int|float> $predictions */
        $predictions = $this->base->predict($dataset);

        foreach ($this->ensemble as $estimator) {
            /** @var int $j */
            foreach ($estimator->predict($dataset) as $j => $prediction) {
                $predictions[$j] += $this->rate * $prediction;
            }
        }

        return $predictions;
    }

    /**
     * Return the normalized importance scores of each feature column of the training set.
     *
     * @throws \Rubix\ML\Exceptions\RuntimeException
     * @return float[]
     */
    public function featureImportances() : array
    {
        if (!$this->ensemble or !$this->featureCount) {
            throw new RuntimeException('Estimator has not been trained.');
        }

        $importances = array_fill(0, $this->featureCount, 0.0);

        // Average the importance scores over every booster in the ensemble.
        foreach ($this->ensemble as $tree) {
            foreach ($tree->featureImportances() as $column => $importance) {
                $importances[$column] += $importance;
            }
        }

        $n = count($this->ensemble);

        foreach ($importances as &$importance) {
            $importance /= $n;
        }

        return $importances;
    }

    /**
     * Return the string representation of the object.
     *
     * @return string
     */
    public function __toString() : string
    {
        return 'Gradient Boost (' . Params::stringify($this->params()) . ')';
    }
}
549