1<?php
2
3namespace Rubix\ML\Classifiers;
4
5use Rubix\ML\Learner;
6use Rubix\ML\Parallel;
7use Rubix\ML\Estimator;
8use Rubix\ML\Persistable;
9use Rubix\ML\Probabilistic;
10use Rubix\ML\RanksFeatures;
11use Rubix\ML\EstimatorType;
12use Rubix\ML\Backends\Serial;
13use Rubix\ML\Datasets\Dataset;
14use Rubix\ML\Other\Helpers\Params;
15use Rubix\ML\Backends\Tasks\Proba;
16use Rubix\ML\Backends\Tasks\Predict;
17use Rubix\ML\Other\Traits\ProbaSingle;
18use Rubix\ML\Other\Traits\PredictsSingle;
19use Rubix\ML\Backends\Tasks\TrainLearner;
20use Rubix\ML\Other\Traits\Multiprocessing;
21use Rubix\ML\Other\Traits\AutotrackRevisions;
22use Rubix\ML\Specifications\DatasetIsLabeled;
23use Rubix\ML\Specifications\DatasetIsNotEmpty;
24use Rubix\ML\Specifications\SpecificationChain;
25use Rubix\ML\Specifications\DatasetHasDimensionality;
26use Rubix\ML\Specifications\LabelsAreCompatibleWithLearner;
27use Rubix\ML\Specifications\SamplesAreCompatibleWithEstimator;
28use Rubix\ML\Exceptions\InvalidArgumentException;
29use Rubix\ML\Exceptions\RuntimeException;
30
31use function Rubix\ML\argmax;
32use function Rubix\ML\array_transpose;
33use function get_class;
34use function in_array;
35
/**
 * Random Forest
 *
 * An ensemble classifier that trains a number of Decision Trees (Classification or Extra Trees),
 * each on a random subset (*bootstrap* set) of the training data drawn with replacement. Class
 * predictions are decided by a majority vote over the individual trees' predictions, while
 * probability estimates are the equally weighted average of the joint probabilities returned
 * by each tree in the forest.
 *
 * References:
 * [1] L. Breiman. (2001). Random Forests.
 * [2] P. Geurts et al. (2006). Extremely Randomized Trees.
 *
 * @category    Machine Learning
 * @package     Rubix/ML
 * @author      Andrew DalPino
 */
class RandomForest implements Estimator, Learner, Probabilistic, Parallel, RanksFeatures, Persistable
{
    use AutotrackRevisions, Multiprocessing, PredictsSingle, ProbaSingle;

    /**
     * The class names of the learners that are compatible with the ensemble.
     *
     * @var class-string[]
     */
    public const COMPATIBLE_LEARNERS = [
        ClassificationTree::class,
        ExtraTreeClassifier::class,
    ];

    /**
     * The minimum size of each training subset.
     *
     * @var int
     */
    protected const MIN_SUBSAMPLE = 1;

    /**
     * The base learner that is cloned to create each tree of the forest.
     *
     * @var \Rubix\ML\Learner
     */
    protected $base;

    /**
     * The number of learners to train in the ensemble.
     *
     * @var int
     */
    protected $estimators;

    /**
     * The ratio of samples from the training set to randomly subsample to train each base learner.
     *
     * @var float
     */
    protected $ratio;

    /**
     * Should we sample the bootstrap set to compensate for imbalanced class labels?
     *
     * @var bool
     */
    protected $balanced;

    /**
     * The decision trees that make up the forest.
     *
     * @var list<ClassificationTree|ExtraTreeClassifier>|null
     */
    protected $trees;

    /**
     * The zero vector for the possible class outcomes i.e. a map of class label to 0.0.
     *
     * @var float[]|null
     */
    protected $classes;

    /**
     * The dimensionality of the training set.
     *
     * @var int|null
     */
    protected $featureCount;

    /**
     * @param \Rubix\ML\Learner|null $base The tree to clone for each member, a Classification Tree if null.
     * @param int $estimators The number of trees to train, must be at least 1.
     * @param float $ratio The size of each bootstrap set relative to the training set, in (0, 1.5].
     * @param bool $balanced Whether to weight the bootstrap sampling by inverse class frequency.
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
     */
    public function __construct(
        ?Learner $base = null,
        int $estimators = 100,
        float $ratio = 0.2,
        bool $balanced = false
    ) {
        // Only the exact classes listed are accepted, compared strictly by class name.
        if ($base and !in_array(get_class($base), self::COMPATIBLE_LEARNERS, true)) {
            throw new InvalidArgumentException('Base Learner must be'
                . ' compatible with ensemble.');
        }

        if ($estimators < 1) {
            throw new InvalidArgumentException('Number of estimators'
                . " must be greater than 0, $estimators given.");
        }

        if ($ratio <= 0.0 or $ratio > 1.5) {
            throw new InvalidArgumentException('Ratio must be between'
                . " 0 and 1.5, $ratio given.");
        }

        $this->base = $base ?? new ClassificationTree();
        $this->estimators = $estimators;
        $this->ratio = $ratio;
        $this->balanced = $balanced;
        $this->backend = new Serial();
    }

    /**
     * Return the estimator type.
     *
     * @internal
     *
     * @return \Rubix\ML\EstimatorType
     */
    public function type() : EstimatorType
    {
        return EstimatorType::classifier();
    }

    /**
     * Return the data types that the estimator is compatible with. The answer is
     * delegated to the base learner.
     *
     * @internal
     *
     * @return list<\Rubix\ML\DataType>
     */
    public function compatibility() : array
    {
        return $this->base->compatibility();
    }

    /**
     * Return the settings of the hyper-parameters in an associative array.
     *
     * @internal
     *
     * @return mixed[]
     */
    public function params() : array
    {
        return [
            'base' => $this->base,
            'estimators' => $this->estimators,
            'ratio' => $this->ratio,
            'balanced' => $this->balanced,
        ];
    }

    /**
     * Has the learner been trained?
     *
     * @return bool
     */
    public function trained() : bool
    {
        return !empty($this->trees);
    }

    /**
     * Train the learner with a dataset. Each tree is a clone of the base learner trained
     * on its own bootstrap set drawn with replacement from the training set.
     *
     * @param \Rubix\ML\Datasets\Labeled $dataset
     */
    public function train(Dataset $dataset) : void
    {
        SpecificationChain::with([
            new DatasetIsLabeled($dataset),
            new DatasetIsNotEmpty($dataset),
            new SamplesAreCompatibleWithEstimator($dataset, $this),
            new LabelsAreCompatibleWithLearner($dataset, $this),
        ])->check();

        // The size of each bootstrap set, never smaller than the minimum subsample size.
        $p = max(self::MIN_SUBSAMPLE, (int) ceil($this->ratio * $dataset->numRows()));

        $weights = null;

        if ($this->balanced) {
            $labels = $dataset->labels();

            $counts = array_count_values($labels);

            $min = min($counts);

            $weights = [];

            // Samples of the rarest class get weight 1, more frequent classes proportionally less.
            foreach ($labels as $label) {
                $weights[] = $min / $counts[$label];
            }
        }

        $this->backend->flush();

        for ($i = 0; $i < $this->estimators; ++$i) {
            $estimator = clone $this->base;

            if ($weights !== null) {
                $subset = $dataset->randomWeightedSubsetWithReplacement($p, $weights);
            } else {
                $subset = $dataset->randomSubsetWithReplacement($p);
            }

            $this->backend->enqueue(new TrainLearner($estimator, $subset));
        }

        $this->trees = $this->backend->process();

        $this->classes = array_fill_keys($dataset->possibleOutcomes(), 0.0);

        $this->featureCount = $dataset->numColumns();
    }

    /**
     * Make predictions from a dataset. Every tree predicts each sample and the label
     * chosen by argmax over the per-sample vote counts is returned.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @throws \Rubix\ML\Exceptions\RuntimeException
     * @return list<string>
     */
    public function predict(Dataset $dataset) : array
    {
        if (!$this->trees or !$this->featureCount) {
            throw new RuntimeException('Estimator has not been trained.');
        }

        DatasetHasDimensionality::with($dataset, $this->featureCount)->check();

        $this->backend->flush();

        foreach ($this->trees as $estimator) {
            $this->backend->enqueue(new Predict($estimator, $dataset));
        }

        // Transpose from one row of predictions per tree to one row of votes per sample.
        $aggregate = array_transpose($this->backend->process());

        $predictions = [];

        foreach ($aggregate as $votes) {
            $predictions[] = argmax(array_count_values($votes));
        }

        return $predictions;
    }

    /**
     * Estimate the joint probabilities for each possible outcome. The joint probabilities
     * returned by each tree are summed per sample and class and then divided by the number
     * of estimators.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @throws \Rubix\ML\Exceptions\RuntimeException
     * @return list<float[]>
     */
    public function proba(Dataset $dataset) : array
    {
        if (!$this->trees or !$this->classes or !$this->featureCount) {
            throw new RuntimeException('Estimator has not been trained.');
        }

        DatasetHasDimensionality::with($dataset, $this->featureCount)->check();

        // Start every sample from the zero vector so that all known classes are present.
        $probabilities = array_fill(0, $dataset->numRows(), $this->classes);

        $this->backend->flush();

        foreach ($this->trees as $estimator) {
            $this->backend->enqueue(new Proba($estimator, $dataset));
        }

        $aggregate = $this->backend->process();

        foreach ($aggregate as $proba) {
            /** @var int $i */
            foreach ($proba as $i => $joint) {
                foreach ($joint as $class => $probability) {
                    $probabilities[$i][$class] += $probability;
                }
            }
        }

        // Write back through keys rather than by reference so that no dangling
        // reference is left bound to an element of the returned array.
        foreach ($probabilities as $i => $joint) {
            foreach ($joint as $class => $total) {
                $probabilities[$i][$class] = $total / $this->estimators;
            }
        }

        return $probabilities;
    }

    /**
     * Return the importance scores of each feature column of the training set, computed
     * as the sum of the importances reported by each tree divided by the number of estimators.
     *
     * @throws \Rubix\ML\Exceptions\RuntimeException
     * @return float[]
     */
    public function featureImportances() : array
    {
        if (!$this->trees or !$this->featureCount) {
            throw new RuntimeException('Estimator has not been trained.');
        }

        $importances = array_fill(0, $this->featureCount, 0.0);

        foreach ($this->trees as $tree) {
            foreach ($tree->featureImportances() as $column => $importance) {
                $importances[$column] += $importance;
            }
        }

        // Write back through keys rather than by reference so that no dangling
        // reference is left bound to the last element of the returned array.
        foreach ($importances as $column => $total) {
            $importances[$column] = $total / $this->estimators;
        }

        return $importances;
    }

    /**
     * Return the string representation of the object.
     *
     * @return string
     */
    public function __toString() : string
    {
        return 'Random Forest (' . Params::stringify($this->params()) . ')';
    }
}
368