1<?php
2
3declare(strict_types=1);
4
5namespace Phpml\Metric;
6
7use Phpml\Exception\InvalidArgumentException;
8
/**
 * Per-label and averaged classification metrics (precision, recall, F1 score and support)
 * computed from two parallel arrays of actual and predicted labels.
 *
 * All per-label arrays are keyed by label. The label set is the sorted union of every label
 * that occurs in either input array, so a label that is only ever predicted still gets an entry.
 */
class ClassificationReport
{
    /**
     * Average computed from the true positive, false positive and false negative counts summed over all labels.
     */
    public const MICRO_AVERAGE = 1;

    /**
     * Unweighted mean of the per-label metrics.
     */
    public const MACRO_AVERAGE = 2;

    /**
     * Mean of the per-label metrics weighted by each label's support.
     */
    public const WEIGHTED_AVERAGE = 3;

    /**
     * @var array<int|string, int> number of correct predictions per label
     */
    private $truePositive = [];

    /**
     * @var array<int|string, int> number of times each label was predicted wrongly
     */
    private $falsePositive = [];

    /**
     * @var array<int|string, int> number of times each actual label was missed
     */
    private $falseNegative = [];

    /**
     * @var array<int|string, int> number of occurrences of each label in the actual labels
     */
    private $support = [];

    /**
     * @var array<int|string, float|int> precision per label
     */
    private $precision = [];

    /**
     * @var array<int|string, float|int> recall per label
     */
    private $recall = [];

    /**
     * @var array<int|string, float> F1 score per label
     */
    private $f1score = [];

    /**
     * @var array<string, float|int> averaged metrics with keys 'precision', 'recall' and 'f1score'
     */
    private $average = [];

    /**
     * @param array $actualLabels    ground-truth labels
     * @param array $predictedLabels predicted labels; must have the same size and keys as $actualLabels
     * @param int   $average         one of MICRO_AVERAGE, MACRO_AVERAGE or WEIGHTED_AVERAGE
     *
     * @throws InvalidArgumentException when the averaging method is unknown or the label arrays do not line up
     */
    public function __construct(array $actualLabels, array $predictedLabels, int $average = self::MACRO_AVERAGE)
    {
        $averagingMethods = range(self::MICRO_AVERAGE, self::WEIGHTED_AVERAGE);
        if (!in_array($average, $averagingMethods, true)) {
            throw new InvalidArgumentException('Averaging method must be MICRO_AVERAGE, MACRO_AVERAGE or WEIGHTED_AVERAGE');
        }

        // Predictions are matched to actual labels by key; arrays of different sizes cannot be paired.
        if (count($actualLabels) !== count($predictedLabels)) {
            throw new InvalidArgumentException('Size of given arrays does not match');
        }

        $this->aggregateClassificationResults($actualLabels, $predictedLabels);
        $this->computeMetrics();
        $this->computeAverage($average);
    }

    /**
     * @return array<int|string, float|int> precision per label
     */
    public function getPrecision(): array
    {
        return $this->precision;
    }

    /**
     * @return array<int|string, float|int> recall per label
     */
    public function getRecall(): array
    {
        return $this->recall;
    }

    /**
     * @return array<int|string, float> F1 score per label
     */
    public function getF1score(): array
    {
        return $this->f1score;
    }

    /**
     * @return array<int|string, int> occurrences of each label in the actual labels
     */
    public function getSupport(): array
    {
        return $this->support;
    }

    /**
     * @return array<string, float|int> averaged 'precision', 'recall' and 'f1score'
     */
    public function getAverage(): array
    {
        return $this->average;
    }

    /**
     * Counts true positives, false positives, false negatives and support for every label.
     *
     * @throws InvalidArgumentException when an actual label has no prediction under the same key
     */
    private function aggregateClassificationResults(array $actualLabels, array $predictedLabels): void
    {
        $truePositive = $falsePositive = $falseNegative = $support = self::getLabelIndexedArray($actualLabels, $predictedLabels);

        foreach ($actualLabels as $index => $actual) {
            // Without this check a missing key yields null and silently corrupts the counts.
            if (!array_key_exists($index, $predictedLabels)) {
                throw new InvalidArgumentException(sprintf('Missing predicted label for index "%s"', $index));
            }

            $predicted = $predictedLabels[$index];
            ++$support[$actual];

            if ($actual === $predicted) {
                ++$truePositive[$actual];
            } else {
                // A wrong prediction is a false positive for the predicted label
                // and a false negative for the actual one.
                ++$falsePositive[$predicted];
                ++$falseNegative[$actual];
            }
        }

        $this->truePositive = $truePositive;
        $this->falsePositive = $falsePositive;
        $this->falseNegative = $falseNegative;
        $this->support = $support;
    }

    /**
     * Fills the per-label precision, recall and F1 score arrays from the aggregated counts.
     */
    private function computeMetrics(): void
    {
        foreach ($this->truePositive as $label => $tp) {
            $this->precision[$label] = $this->computePrecision($tp, $this->falsePositive[$label]);
            $this->recall[$label] = $this->computeRecall($tp, $this->falseNegative[$label]);
            // Precision/recall may be ints (exact int division), so cast for the float-typed helper.
            $this->f1score[$label] = $this->computeF1Score((float) $this->precision[$label], (float) $this->recall[$label]);
        }
    }

    /**
     * Dispatches to the averaging strategy selected in the constructor (already validated there).
     */
    private function computeAverage(int $average): void
    {
        switch ($average) {
            case self::MICRO_AVERAGE:
                $this->computeMicroAverage();

                return;
            case self::MACRO_AVERAGE:
                $this->computeMacroAverage();

                return;
            case self::WEIGHTED_AVERAGE:
                $this->computeWeightedAverage();

                return;
        }
    }

    /**
     * Computes the metrics once from the counts summed over all labels.
     */
    private function computeMicroAverage(): void
    {
        $truePositive = (int) array_sum($this->truePositive);
        $falsePositive = (int) array_sum($this->falsePositive);
        $falseNegative = (int) array_sum($this->falseNegative);

        $precision = $this->computePrecision($truePositive, $falsePositive);
        $recall = $this->computeRecall($truePositive, $falseNegative);
        $f1score = $this->computeF1Score((float) $precision, (float) $recall);

        $this->average = compact('precision', 'recall', 'f1score');
    }

    /**
     * Averages each per-label metric with equal weight per label; 0.0 when there are no labels.
     */
    private function computeMacroAverage(): void
    {
        foreach (['precision', 'recall', 'f1score'] as $metric) {
            $values = $this->{$metric};
            if (count($values) === 0) {
                $this->average[$metric] = 0.0;

                continue;
            }

            $this->average[$metric] = array_sum($values) / count($values);
        }
    }

    /**
     * Averages each per-label metric weighted by the label's support; 0.0 when there is nothing to weigh.
     */
    private function computeWeightedAverage(): void
    {
        // Loop-invariant; zero only when there are no actual labels.
        $totalSupport = array_sum($this->support);

        foreach (['precision', 'recall', 'f1score'] as $metric) {
            $values = $this->{$metric};
            if (count($values) === 0 || $totalSupport === 0) {
                $this->average[$metric] = 0.0;

                continue;
            }

            $sum = 0;
            foreach ($values as $label => $value) {
                $sum += $value * $this->support[$label];
            }

            $this->average[$metric] = $sum / $totalSupport;
        }
    }

    /**
     * Precision = TP / (TP + FP), or 0.0 when nothing was predicted.
     *
     * @return float|int int when the division is exact (PHP int/int semantics)
     */
    private function computePrecision(int $truePositive, int $falsePositive)
    {
        $divider = $truePositive + $falsePositive;
        if ($divider === 0) {
            return 0.0;
        }

        return $truePositive / $divider;
    }

    /**
     * Recall = TP / (TP + FN), or 0.0 when the label never occurs.
     *
     * @return float|int int when the division is exact (PHP int/int semantics)
     */
    private function computeRecall(int $truePositive, int $falseNegative)
    {
        $divider = $truePositive + $falseNegative;
        if ($divider === 0) {
            return 0.0;
        }

        return $truePositive / $divider;
    }

    /**
     * Harmonic mean of precision and recall, or 0.0 when both are zero.
     */
    private function computeF1Score(float $precision, float $recall): float
    {
        $divider = $precision + $recall;
        if ($divider == 0) {
            return 0.0;
        }

        return 2.0 * (($precision * $recall) / $divider);
    }

    /**
     * Builds a zero-initialised array keyed by the sorted, de-duplicated union of both label arrays.
     *
     * @return array<int|string, int>
     */
    private static function getLabelIndexedArray(array $actualLabels, array $predictedLabels): array
    {
        $labels = array_values(array_unique(array_merge($actualLabels, $predictedLabels)));
        sort($labels);

        // The (array) cast guards against array_combine returning false on old PHP versions.
        return (array) array_combine($labels, array_fill(0, count($labels), 0));
    }
}
233