<?php

namespace Rubix\ML\Classifiers;

use Rubix\ML\Learner;
use Rubix\ML\Parallel;
use Rubix\ML\Estimator;
use Rubix\ML\Persistable;
use Rubix\ML\Probabilistic;
use Rubix\ML\RanksFeatures;
use Rubix\ML\EstimatorType;
use Rubix\ML\Backends\Serial;
use Rubix\ML\Datasets\Dataset;
use Rubix\ML\Other\Helpers\Params;
use Rubix\ML\Backends\Tasks\Proba;
use Rubix\ML\Backends\Tasks\Predict;
use Rubix\ML\Other\Traits\ProbaSingle;
use Rubix\ML\Other\Traits\PredictsSingle;
use Rubix\ML\Backends\Tasks\TrainLearner;
use Rubix\ML\Other\Traits\Multiprocessing;
use Rubix\ML\Other\Traits\AutotrackRevisions;
use Rubix\ML\Specifications\DatasetIsLabeled;
use Rubix\ML\Specifications\DatasetIsNotEmpty;
use Rubix\ML\Specifications\SpecificationChain;
use Rubix\ML\Specifications\DatasetHasDimensionality;
use Rubix\ML\Specifications\LabelsAreCompatibleWithLearner;
use Rubix\ML\Specifications\SamplesAreCompatibleWithEstimator;
use Rubix\ML\Exceptions\InvalidArgumentException;
use Rubix\ML\Exceptions\RuntimeException;

use function Rubix\ML\argmax;
use function Rubix\ML\array_transpose;
use function get_class;
use function in_array;

/**
 * Random Forest
 *
 * An ensemble classifier that trains an ensemble of Decision Trees (Classification or Extra Trees)
 * on random subsets (*bootstrap* set) of the training data. Class predictions are decided by
 * tallying the class predicted by each tree in the forest, whereas the probability estimates are
 * the joint probabilities returned from each tree, averaged and weighted equally.
 *
 * References:
 * [1] L. Breiman. (2001). Random Forests.
 * [2] P. Geurts et al. (2006). Extremely Randomized Trees.
 *
 * @category    Machine Learning
 * @package     Rubix/ML
 * @author      Andrew DalPino
 */
class RandomForest implements Estimator, Learner, Probabilistic, Parallel, RanksFeatures, Persistable
{
    use AutotrackRevisions, Multiprocessing, PredictsSingle, ProbaSingle;

    /**
     * The class names of the learners that are compatible with the ensemble.
     *
     * @var class-string[]
     */
    public const COMPATIBLE_LEARNERS = [
        ClassificationTree::class,
        ExtraTreeClassifier::class,
    ];

    /**
     * The minimum size of each training subset.
     *
     * @var int
     */
    protected const MIN_SUBSAMPLE = 1;

    /**
     * The base learner.
     *
     * @var \Rubix\ML\Learner
     */
    protected $base;

    /**
     * The number of learners to train in the ensemble.
     *
     * @var int
     */
    protected $estimators;

    /**
     * The ratio of samples from the training set to randomly subsample to train each base learner.
     *
     * @var float
     */
    protected $ratio;

    /**
     * Should we sample the bootstrap set to compensate for imbalanced class labels?
     *
     * @var bool
     */
    protected $balanced;

    /**
     * The decision trees that make up the forest.
     *
     * @var list<ClassificationTree|ExtraTreeClassifier>|null
     */
    protected $trees;

    /**
     * The zero vector for the possible class outcomes.
     *
     * @var float[]|null
     */
    protected $classes;

    /**
     * The dimensionality of the training set.
     *
     * @var int|null
     */
    protected $featureCount;

    /**
     * Validate the hyper-parameters and set up the ensemble. When no base learner is given
     * a default Classification Tree is used, and processing starts out on the Serial backend.
     *
     * @param \Rubix\ML\Learner|null $base The tree to clone for each member of the forest, must be one of COMPATIBLE_LEARNERS
     * @param int $estimators The number of trees to train, at least 1
     * @param float $ratio The subsample ratio, greater than 0 and at most 1.5
     * @param bool $balanced Whether to weight the bootstrap sampling by inverse class frequency
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
     */
    public function __construct(
        ?Learner $base = null,
        int $estimators = 100,
        float $ratio = 0.2,
        bool $balanced = false
    ) {
        // Strict comparison, exact class names only (subclasses of a compatible learner are rejected).
        if ($base and !in_array(get_class($base), self::COMPATIBLE_LEARNERS, true)) {
            throw new InvalidArgumentException('Base Learner must be'
                . ' compatible with ensemble.');
        }

        if ($estimators < 1) {
            throw new InvalidArgumentException('Number of estimators'
                . " must be greater than 0, $estimators given.");
        }

        if ($ratio <= 0.0 or $ratio > 1.5) {
            throw new InvalidArgumentException('Ratio must be between'
                . " 0 and 1.5, $ratio given.");
        }

        $this->base = $base ?? new ClassificationTree();
        $this->estimators = $estimators;
        $this->ratio = $ratio;
        $this->balanced = $balanced;

        // NOTE(review): $backend is not declared in this class, presumably provided by the Multiprocessing trait.
        $this->backend = new Serial();
    }

    /**
     * Return the estimator type.
     *
     * @internal
     *
     * @return \Rubix\ML\EstimatorType
     */
    public function type() : EstimatorType
    {
        return EstimatorType::classifier();
    }

    /**
     * Return the data types that the estimator is compatible with, which are
     * those of the base learner.
     *
     * @internal
     *
     * @return list<\Rubix\ML\DataType>
     */
    public function compatibility() : array
    {
        return $this->base->compatibility();
    }

    /**
     * Return the settings of the hyper-parameters in an associative array.
     *
     * @internal
     *
     * @return mixed[]
     */
    public function params() : array
    {
        return [
            'base' => $this->base,
            'estimators' => $this->estimators,
            'ratio' => $this->ratio,
            'balanced' => $this->balanced,
        ];
    }

    /**
     * Has the learner been trained? True once the forest holds at least one tree.
     *
     * @return bool
     */
    public function trained() : bool
    {
        return !empty($this->trees);
    }

    /**
     * Train the learner with a dataset. Each tree is a clone of the base learner trained on its
     * own random subset of the dataset drawn with replacement.
     *
     * @param \Rubix\ML\Datasets\Labeled $dataset
     */
    public function train(Dataset $dataset) : void
    {
        SpecificationChain::with([
            new DatasetIsLabeled($dataset),
            new DatasetIsNotEmpty($dataset),
            new SamplesAreCompatibleWithEstimator($dataset, $this),
            new LabelsAreCompatibleWithLearner($dataset, $this),
        ])->check();

        // The size of each bootstrap set, never smaller than MIN_SUBSAMPLE.
        $p = max(self::MIN_SUBSAMPLE, (int) ceil($this->ratio * $dataset->numRows()));

        if ($this->balanced) {
            $labels = $dataset->labels();

            $counts = array_count_values($labels);

            $min = min($counts);

            $weights = [];

            // Weight each sample by (rarest class count / own class count) so that samples
            // of the rarest class get a weight of 1 and more frequent classes get less.
            foreach ($labels as $label) {
                $weights[] = $min / $counts[$label];
            }
        }

        $this->backend->flush();

        for ($i = 0; $i < $this->estimators; ++$i) {
            $estimator = clone $this->base;

            // $weights is only defined when balanced sampling was requested above.
            if (isset($weights)) {
                $subset = $dataset->randomWeightedSubsetWithReplacement($p, $weights);
            } else {
                $subset = $dataset->randomSubsetWithReplacement($p);
            }

            $this->backend->enqueue(new TrainLearner($estimator, $subset));
        }

        $this->trees = $this->backend->process();

        // Template row of zeros keyed by class label, used by proba() to accumulate scores.
        $this->classes = array_fill_keys($dataset->possibleOutcomes(), 0.0);

        $this->featureCount = $dataset->numColumns();
    }

    /**
     * Make predictions from a dataset. Every tree predicts a class for each sample and the
     * votes are tallied per sample, with the outcome chosen by argmax over the vote counts.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @throws \Rubix\ML\Exceptions\RuntimeException
     * @return list<string>
     */
    public function predict(Dataset $dataset) : array
    {
        if (!$this->trees or !$this->featureCount) {
            throw new RuntimeException('Estimator has not been trained.');
        }

        DatasetHasDimensionality::with($dataset, $this->featureCount)->check();

        $this->backend->flush();

        foreach ($this->trees as $estimator) {
            $this->backend->enqueue(new Predict($estimator, $dataset));
        }

        // Transpose the per-tree results so that each row holds the votes for a single sample.
        $aggregate = array_transpose($this->backend->process());

        $predictions = [];

        foreach ($aggregate as $votes) {
            $predictions[] = argmax(array_count_values($votes));
        }

        return $predictions;
    }

    /**
     * Estimate the joint probabilities for each possible outcome. The probabilities returned
     * by the trees are summed per sample and class and then divided by the number of estimators.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @throws \Rubix\ML\Exceptions\RuntimeException
     * @return list<float[]>
     */
    public function proba(Dataset $dataset) : array
    {
        if (!$this->trees or !$this->classes or !$this->featureCount) {
            throw new RuntimeException('Estimator has not been trained.');
        }

        DatasetHasDimensionality::with($dataset, $this->featureCount)->check();

        // Start every sample off with a zero score for each class seen during training.
        $probabilities = array_fill(0, $dataset->numRows(), $this->classes);

        $this->backend->flush();

        foreach ($this->trees as $estimator) {
            $this->backend->enqueue(new Proba($estimator, $dataset));
        }

        $aggregate = $this->backend->process();

        foreach ($aggregate as $proba) {
            /** @var int $i */
            foreach ($proba as $i => $joint) {
                foreach ($joint as $class => $probability) {
                    $probabilities[$i][$class] += $probability;
                }
            }
        }

        foreach ($probabilities as &$joint) {
            foreach ($joint as &$probability) {
                $probability /= $this->estimators;
            }
        }

        // Break the references left behind by the by-reference loops so that the
        // last row and element cannot be modified through the stale variables.
        unset($joint, $probability);

        return $probabilities;
    }

    /**
     * Return the importance scores of each feature column of the training set, computed as the
     * sum of the importances reported by each tree divided by the number of estimators.
     *
     * @throws \Rubix\ML\Exceptions\RuntimeException
     * @return float[]
     */
    public function featureImportances() : array
    {
        if (!$this->trees or !$this->featureCount) {
            throw new RuntimeException('Estimator has not been trained.');
        }

        $importances = array_fill(0, $this->featureCount, 0.0);

        foreach ($this->trees as $tree) {
            foreach ($tree->featureImportances() as $column => $importance) {
                $importances[$column] += $importance;
            }
        }

        foreach ($importances as &$importance) {
            $importance /= $this->estimators;
        }

        // Break the reference left behind by the by-reference loop.
        unset($importance);

        return $importances;
    }

    /**
     * Return the string representation of the object.
     *
     * @return string
     */
    public function __toString() : string
    {
        return 'Random Forest (' . Params::stringify($this->params()) . ')';
    }
}