1<?php
2
3declare(strict_types=1);
4
5namespace Phpml\Classification\Ensemble;
6
7use Phpml\Classification\Classifier;
8use Phpml\Classification\DecisionTree;
9use Phpml\Exception\InvalidArgumentException;
10
class RandomForest extends Bagging
{
    /**
     * Fraction of the columns (float) or the rule ('sqrt' / 'log') used to
     * decide how many features each tree may consider.
     *
     * @var float|string
     */
    protected $featureSubsetRatio = 'log';

    /**
     * Names for the dataset columns, passed on to every tree. When null,
     * numeric indexes 0..featureCount-1 are generated on first use.
     *
     * @var array|null
     */
    protected $columnNames;

    /**
     * Initializes RandomForest with the given number of trees. More trees
     * may increase the prediction performance while it will also substantially
     * increase the processing time and the required memory
     */
    public function __construct(int $numClassifier = 50)
    {
        parent::__construct($numClassifier);

        // Each tree is trained on a sample subset sized with ratio 1.0;
        // randomness across trees comes from the feature subset instead.
        $this->setSubsetRatio(1.0);
    }

    /**
     * This method is used to determine how many of the original columns (features)
     * will be used to construct subsets to train base classifiers.<br>
     *
     * Allowed values: 'sqrt', 'log' or any float number between 0.1 and 1.0 <br>
     *
     * Default value for the ratio is 'log' which results in log(numFeatures, 2) + 1
     * features to be taken into consideration while selecting subspace of features
     *
     * @param string|float $ratio
     *
     * @throws InvalidArgumentException when the ratio is neither a float in
     *                                  [0.1, 1.0] nor one of 'sqrt' / 'log'
     */
    public function setFeatureSubsetRatio($ratio): self
    {
        if (!is_string($ratio) && !is_float($ratio)) {
            throw new InvalidArgumentException('Feature subset ratio must be a string or a float');
        }

        if (is_float($ratio) && ($ratio < 0.1 || $ratio > 1.0)) {
            throw new InvalidArgumentException('When a float is given, feature subset ratio should be between 0.1 and 1.0');
        }

        if (is_string($ratio) && $ratio !== 'sqrt' && $ratio !== 'log') {
            throw new InvalidArgumentException("When a string is given, feature subset ratio can only be 'sqrt' or 'log'");
        }

        $this->featureSubsetRatio = $ratio;

        return $this;
    }

    /**
     * RandomForest algorithm is usable *only* with DecisionTree
     *
     * The (misspelled) method name is kept because it overrides the parent's.
     *
     * @throws InvalidArgumentException when $classifier is not DecisionTree
     *
     * @return $this
     */
    public function setClassifer(string $classifier, array $classifierOptions = [])
    {
        if ($classifier !== DecisionTree::class) {
            throw new InvalidArgumentException('RandomForest can only use DecisionTree as base classifier');
        }

        return parent::setClassifer($classifier, $classifierOptions);
    }

    /**
     * This will return an array including an importance value for
     * each column in the given dataset. Importance values for a column
     * is the average importance of that column in all trees in the forest
     *
     * The values are normalized so that they add up to 1 and are sorted in
     * descending order, keyed by column. If the summed importance is zero
     * (e.g. no trees, or no tree ever split on a column) the raw sums are
     * returned unnormalized instead of dividing by zero.
     */
    public function getFeatureImportances(): array
    {
        // Sum the importance of each column over every tree in the forest
        $sum = [];
        /** @var DecisionTree $tree */
        foreach ($this->classifiers as $tree) {
            foreach ($tree->getFeatureImportances() as $column => $importance) {
                $sum[$column] = ($sum[$column] ?? 0) + $importance;
            }
        }

        // Normalize only when there is something to normalize: a zero total
        // would otherwise raise DivisionByZeroError (PHP 8) or produce NaN/INF.
        $total = array_sum($sum);
        if ($total > 0) {
            foreach ($sum as $column => $importance) {
                $sum[$column] = $importance / $total;
            }
        }

        arsort($sum);

        return $sum;
    }

    /**
     * A string array to represent the columns is given. They are useful
     * when trying to print some information about the trees such as feature importances
     *
     * @return $this
     */
    public function setColumnNames(array $names)
    {
        $this->columnNames = $names;

        return $this;
    }

    /**
     * Configures a single tree with the column names and the number of
     * features it may consider, derived from $featureSubsetRatio.
     *
     * @param DecisionTree $classifier
     *
     * @return DecisionTree
     */
    protected function initSingleClassifier(Classifier $classifier): Classifier
    {
        if (is_float($this->featureSubsetRatio)) {
            $featureCount = (int) ($this->featureSubsetRatio * $this->featureCount);
        } elseif ($this->featureSubsetRatio === 'sqrt') {
            $featureCount = (int) ($this->featureCount ** .5) + 1;
        } else {
            $featureCount = (int) log($this->featureCount, 2) + 1;
        }

        // A float ratio on a narrow dataset can truncate to 0 (e.g. 0.1 * 5),
        // leaving a tree with no candidate column. NOTE(review): DecisionTree
        // appears to treat 0 as "use all features", which would silently invert
        // the intent. Clamp to at least one feature, but never more than exist.
        $featureCount = min(max(1, $featureCount), $this->featureCount);

        if ($this->columnNames === null) {
            $this->columnNames = range(0, $this->featureCount - 1);
        }

        return $classifier
            ->setColumnNames($this->columnNames)
            ->setNumFeatures($featureCount);
    }
}
152