<?php

namespace Rubix\ML\Regressors;

use Tensor\Vector;
use Rubix\ML\Learner;
use Rubix\ML\Verbose;
use Rubix\ML\Estimator;
use Rubix\ML\Persistable;
use Rubix\ML\RanksFeatures;
use Rubix\ML\EstimatorType;
use Rubix\ML\Datasets\Dataset;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Other\Helpers\Params;
use Rubix\ML\Other\Strategies\Mean;
use Rubix\ML\Other\Traits\LoggerAware;
use Rubix\ML\Other\Traits\PredictsSingle;
use Rubix\ML\CrossValidation\Metrics\RMSE;
use Rubix\ML\CrossValidation\Metrics\Metric;
use Rubix\ML\Other\Traits\AutotrackRevisions;
use Rubix\ML\Specifications\DatasetIsLabeled;
use Rubix\ML\Specifications\DatasetIsNotEmpty;
use Rubix\ML\Specifications\SpecificationChain;
use Rubix\ML\Specifications\DatasetHasDimensionality;
use Rubix\ML\Specifications\LabelsAreCompatibleWithLearner;
use Rubix\ML\Specifications\EstimatorIsCompatibleWithMetric;
use Rubix\ML\Specifications\SamplesAreCompatibleWithEstimator;
use Rubix\ML\Exceptions\InvalidArgumentException;
use Rubix\ML\Exceptions\RuntimeException;

use function count;
use function is_nan;
use function array_slice;
use function get_class;
use function in_array;

/**
 * Gradient Boost
 *
 * Gradient Boost is a stage-wise additive ensemble that uses a Gradient Descent boosting
 * scheme for training boosters (Decision Trees) to correct the error residuals of a
 * series of *weak* base learners. Stochastic gradient boosting is achieved by varying
 * the ratio of samples to subsample uniformly at random from the training set.
 *
 * > **Note**: The default base regressor is a Dummy Regressor using the Mean strategy
 * and the default booster is a Regression Tree with a max height of 3.
 *
 * References:
 * [1] J. H. Friedman. (2001). Greedy Function Approximation: A Gradient
 * Boosting Machine.
 * [2] J. H. Friedman. (1999). Stochastic Gradient Boosting.
 * [3] Y. Wei. et al. (2017). Early stopping for kernel boosting algorithms:
 * A general analysis with localized complexities.
 *
 * @category    Machine Learning
 * @package     Rubix/ML
 * @author      Andrew DalPino
 */
class GradientBoost implements Estimator, Learner, RanksFeatures, Verbose, Persistable
{
    use AutotrackRevisions, PredictsSingle, LoggerAware;

    /**
     * The class names of the compatible learners to used as boosters.
     *
     * @var class-string[]
     */
    public const COMPATIBLE_BOOSTERS = [
        RegressionTree::class,
        ExtraTreeRegressor::class,
    ];

    /**
     * The minimum size of each training subset.
     *
     * @var int
     */
    protected const MIN_SUBSAMPLE = 1;

    /**
     * The regressor that will fix up the error residuals of the *weak* base learner.
     *
     * @var \Rubix\ML\Learner
     */
    protected $booster;

    /**
     * The learning rate of the ensemble i.e. the *shrinkage* applied to each step.
     *
     * @var float
     */
    protected $rate;

    /**
     * The ratio of samples to subsample from the training set for each booster.
     *
     * @var float
     */
    protected $ratio;

    /**
     * The max number of estimators to train in the ensemble.
     *
     * @var int
     */
    protected $estimators;

    /**
     * The minimum change in the training loss necessary to continue training.
     *
     * @var float
     */
    protected $minChange;

    /**
     * The number of epochs without improvement in the validation score to wait
     * before considering an early stop.
     *
     * @var int
     */
    protected $window;

    /**
     * The proportion of training samples to use for validation and progress monitoring.
     *
     * @var float
     */
    protected $holdOut;

    /**
     * The metric used to score the generalization performance of the model
     * during training.
     *
     * @var \Rubix\ML\CrossValidation\Metrics\Metric
     */
    protected $metric;

    /**
     * The *weak* base regressor to be boosted.
     *
     * @var \Rubix\ML\Learner
     */
    protected $base;

    /**
     * An ensemble of weak regressors.
     *
     * @var mixed[]
     */
    protected $ensemble = [
        //
    ];

    /**
     * The dimensionality of the training set.
     *
     * @var int|null
     */
    protected $featureCount;

    /**
     * The validation scores at each epoch.
     *
     * @var float[]|null
     */
    protected $scores;

    /**
     * The average training loss at each epoch.
     *
     * @var float[]|null
     */
    protected $steps;

    /**
     * @param \Rubix\ML\Learner|null $booster
     * @param float $rate
     * @param float $ratio
     * @param int $estimators
     * @param float $minChange
     * @param int $window
     * @param float $holdOut
     * @param \Rubix\ML\CrossValidation\Metrics\Metric|null $metric
     * @param \Rubix\ML\Learner|null $base
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
     */
    public function __construct(
        ?Learner $booster = null,
        float $rate = 0.1,
        float $ratio = 0.5,
        int $estimators = 1000,
        float $minChange = 1e-4,
        int $window = 10,
        float $holdOut = 0.1,
        ?Metric $metric = null,
        ?Learner $base = null
    ) {
        // Strict comparison: both needle and haystack are class-string values.
        if ($booster and !in_array(get_class($booster), self::COMPATIBLE_BOOSTERS, true)) {
            throw new InvalidArgumentException('Booster is not compatible'
                . ' with the ensemble.');
        }

        // The rate acts as shrinkage so it must lie in the half-open interval (0, 1].
        if ($rate <= 0.0 or $rate > 1.0) {
            throw new InvalidArgumentException('Learning rate must be'
                . " greater than 0 and less than or equal to 1, $rate given.");
        }

        if ($ratio <= 0.0 or $ratio > 1.0) {
            throw new InvalidArgumentException('Ratio must be'
                . " between 0 and 1, $ratio given.");
        }

        if ($estimators < 1) {
            throw new InvalidArgumentException('Number of estimators'
                . " must be greater than 0, $estimators given.");
        }

        // Zero is allowed and means "train until max estimators or early stop."
        if ($minChange < 0.0) {
            throw new InvalidArgumentException('Minimum change must be'
                . " greater than or equal to 0, $minChange given.");
        }

        if ($window < 1) {
            throw new InvalidArgumentException('Window must be'
                . " greater than 0, $window given.");
        }

        if ($holdOut < 0.0 or $holdOut > 0.5) {
            throw new InvalidArgumentException('Hold out ratio must be'
                . " between 0 and 0.5, $holdOut given.");
        }

        if ($metric) {
            EstimatorIsCompatibleWithMetric::with($this, $metric)->check();
        }

        if ($base and $base->type() != EstimatorType::regressor()) {
            throw new InvalidArgumentException('Base Estimator must be a'
                . " regressor, {$base->type()} given.");
        }

        $this->booster = $booster ?? new RegressionTree(3);
        $this->rate = $rate;
        $this->ratio = $ratio;
        $this->estimators = $estimators;
        $this->minChange = $minChange;
        $this->window = $window;
        $this->holdOut = $holdOut;
        $this->metric = $metric ?? new RMSE();
        $this->base = $base ?? new DummyRegressor(new Mean());
    }

    /**
     * Return the estimator type.
     *
     * @internal
     *
     * @return \Rubix\ML\EstimatorType
     */
    public function type() : EstimatorType
    {
        return EstimatorType::regressor();
    }

    /**
     * Return the data types that the estimator is compatible with.
     *
     * @internal
     *
     * @return list<\Rubix\ML\DataType>
     */
    public function compatibility() : array
    {
        // The ensemble is only compatible with data types that both the
        // booster and the base learner can handle.
        $compatibility = array_intersect(
            $this->booster->compatibility(),
            $this->base->compatibility()
        );

        return array_values($compatibility);
    }

    /**
     * Return the settings of the hyper-parameters in an associative array.
     *
     * @internal
     *
     * @return mixed[]
     */
    public function params() : array
    {
        return [
            'booster' => $this->booster,
            'rate' => $this->rate,
            'ratio' => $this->ratio,
            'estimators' => $this->estimators,
            'min_change' => $this->minChange,
            'window' => $this->window,
            'hold_out' => $this->holdOut,
            'metric' => $this->metric,
            'base' => $this->base,
        ];
    }

    /**
     * Has the learner been trained?
     *
     * @return bool
     */
    public function trained() : bool
    {
        return $this->base->trained() and $this->ensemble;
    }

    /**
     * Return the validation scores at each epoch from the last training session.
     *
     * @return float[]|null
     */
    public function scores() : ?array
    {
        return $this->scores;
    }

    /**
     * Return the loss at each epoch from the last training session.
     *
     * @return float[]|null
     */
    public function steps() : ?array
    {
        return $this->steps;
    }

    /**
     * Train the estimator with a dataset.
     *
     * @param \Rubix\ML\Datasets\Labeled $dataset
     */
    public function train(Dataset $dataset) : void
    {
        SpecificationChain::with([
            new DatasetIsLabeled($dataset),
            new DatasetIsNotEmpty($dataset),
            new SamplesAreCompatibleWithEstimator($dataset, $this),
            new LabelsAreCompatibleWithLearner($dataset, $this),
        ])->check();

        if ($this->logger) {
            $this->logger->info("$this initialized");
        }

        $this->featureCount = $dataset->numColumns();

        // Carve off a hold-out set used for validation scoring and early stopping.
        [$testing, $training] = $dataset->randomize()->split($this->holdOut);

        [$min, $max] = $this->metric->range();

        if ($this->logger) {
            $this->logger->info("Training {$this->base}");
        }

        $this->base->train($training);

        $this->ensemble = $this->scores = $this->steps = [];

        /** @var list<int|float> $predictions */
        $predictions = $this->base->predict($training);

        $out = $prevOut = Vector::quick($predictions);
        $target = Vector::quick($training->labels());

        if (!$testing->empty()) {
            /** @var list<int|float> $predictions */
            $predictions = $this->base->predict($testing);

            $prevPred = Vector::quick($predictions);
        }

        // Subset size for stochastic boosting, never below the minimum.
        $p = max(self::MIN_SUBSAMPLE, (int) round($this->ratio * $training->numRows()));

        $bestScore = $min;
        $bestEpoch = $delta = 0;
        $score = null;
        $prevLoss = INF;

        for ($epoch = 1; $epoch <= $this->estimators; ++$epoch) {
            // The residuals of the current ensemble become the next booster's targets.
            $gradient = $target->subtract($out);

            $training = Labeled::quick($training->samples(), $gradient->asArray());

            $booster = clone $this->booster;

            $subset = $training->randomSubset($p);

            $booster->train($subset);

            $this->ensemble[] = $booster;

            /** @var list<int|float> $predictions */
            $predictions = $booster->predict($training);

            // Shrink the booster's correction by the learning rate before adding it.
            $out = Vector::quick($predictions)
                ->multiply($this->rate)
                ->add($prevOut);

            $loss = $gradient->square()->mean();

            if (is_nan($loss)) {
                if ($this->logger) {
                    $this->logger->info('Numerical instability detected');
                }

                break;
            }

            $this->steps[] = $loss;

            if (isset($prevPred)) {
                /** @var list<int|float> $predictions */
                $predictions = $booster->predict($testing);

                $pred = Vector::quick($predictions)
                    ->multiply($this->rate)
                    ->add($prevPred);

                $score = $this->metric->score($pred->asArray(), $testing->labels());

                $this->scores[] = $score;
            }

            if ($this->logger) {
                $this->logger->info("Epoch $epoch - {$this->metric}: "
                    . ($score ?? 'n/a') . ", L2 Loss: $loss");
            }

            if (isset($pred)) {
                if ($score >= $max) {
                    break;
                }

                if ($score > $bestScore) {
                    $bestScore = $score;
                    $bestEpoch = $epoch;

                    $delta = 0;
                } else {
                    ++$delta;
                }

                // Early stop after $window epochs without improvement.
                if ($delta >= $this->window) {
                    break;
                }

                $prevPred = $pred;
            }

            if ($loss <= 0.0) {
                break;
            }

            if (abs($prevLoss - $loss) < $this->minChange) {
                break;
            }

            $prevOut = $out;
            $prevLoss = $loss;
        }

        // Roll the ensemble back to the epoch with the best validation score.
        if ($this->scores and end($this->scores) < $bestScore) {
            if ($this->logger) {
                $this->logger->info("Restoring ensemble state to epoch $bestEpoch");
            }

            $this->ensemble = array_slice($this->ensemble, 0, $bestEpoch);
        }

        if ($this->logger) {
            $this->logger->info('Training complete');
        }
    }

    /**
     * Make a prediction from a dataset.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @throws \Rubix\ML\Exceptions\RuntimeException
     * @return list<int|float>
     */
    public function predict(Dataset $dataset) : array
    {
        if (!$this->ensemble or !$this->featureCount) {
            throw new RuntimeException('Estimator has not been trained.');
        }

        DatasetHasDimensionality::with($dataset, $this->featureCount)->check();

        /** @var list<int|float> $predictions */
        $predictions = $this->base->predict($dataset);

        // Each booster contributes its shrunken correction on top of the base prediction.
        foreach ($this->ensemble as $estimator) {
            /** @var int $j */
            foreach ($estimator->predict($dataset) as $j => $prediction) {
                $predictions[$j] += $this->rate * $prediction;
            }
        }

        return $predictions;
    }

    /**
     * Return the normalized importance scores of each feature column of the training set.
     *
     * @throws \Rubix\ML\Exceptions\RuntimeException
     * @return float[]
     */
    public function featureImportances() : array
    {
        if (!$this->ensemble or !$this->featureCount) {
            throw new RuntimeException('Estimator has not been trained.');
        }

        $importances = array_fill(0, $this->featureCount, 0.0);

        // Average the per-column importances reported by each booster tree.
        foreach ($this->ensemble as $tree) {
            foreach ($tree->featureImportances() as $column => $importance) {
                $importances[$column] += $importance;
            }
        }

        $n = count($this->ensemble);

        foreach ($importances as &$importance) {
            $importance /= $n;
        }

        return $importances;
    }

    /**
     * Return the string representation of the object.
     *
     * @return string
     */
    public function __toString() : string
    {
        return 'Gradient Boost (' . Params::stringify($this->params()) . ')';
    }
}