<?php

namespace Rubix\ML\CrossValidation\Reports;

use Rubix\ML\Report;
use Rubix\ML\EstimatorType;
use Rubix\ML\Specifications\PredictionAndLabelCountsAreEqual;

use function count;

use const Rubix\ML\EPSILON;

/**
 * Multiclass Breakdown
 *
 * A multiclass classification report that computes a number of metrics (Accuracy, Precision,
 * Recall, etc.) derived from the confusion matrix on both an overall and a per-class basis.
 *
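 * A minimal usage sketch, assuming $predictions and $labels are parallel arrays of class
 * labels obtained elsewhere (e.g. predictions from a trained classifier and the ground-truth
 * labels of a testing set):
 *
 *     use Rubix\ML\CrossValidation\Reports\MulticlassBreakdown;
 *
 *     $report = new MulticlassBreakdown();
 *
 *     $results = $report->generate($predictions, $labels);
 *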
 * @category    Machine Learning
 * @package     Rubix/ML
 * @author      Andrew DalPino
 */
class MulticlassBreakdown implements ReportGenerator
{
    /**
     * The estimator types that this report is compatible with.
     *
     * @internal
     *
     * @return list<\Rubix\ML\EstimatorType>
     */
    public function compatibility() : array
    {
        return [
            EstimatorType::classifier(),
            EstimatorType::anomalyDetector(),
        ];
    }

    /**
     * Generate the report.
     *
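     * As an illustrative sketch with hypothetical data, the two arguments are parallel
     * arrays of predicted and ground-truth class labels:
     *
     *     $predictions = ['cat', 'dog', 'cat'];
     *     $labels = ['cat', 'cat', 'cat'];
     *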
     * @param list<string|int> $predictions
     * @param list<string|int> $labels
     * @return \Rubix\ML\Report
     */
    public function generate(array $predictions, array $labels) : Report
    {
        PredictionAndLabelCountsAreEqual::with($predictions, $labels)->check();

        // The set of unique classes present in either the predictions or the labels.
        $classes = array_unique(array_merge($predictions, $labels));

        $n = count($predictions);
        $k = count($classes);

        // Per-class confusion matrix tallies, keyed by class label.
        $truePos = $trueNeg = $falsePos = $falseNeg = array_fill_keys($classes, 0);

        // Tally true/false positives and negatives for each class.
        foreach ($predictions as $i => $prediction) {
            $label = $labels[$i];

            if ($prediction == $label) {
                ++$truePos[$prediction];

                foreach ($classes as $class) {
                    if ($class != $prediction) {
                        ++$trueNeg[$class];
                    }
                }
            } else {
                ++$falsePos[$prediction];
                ++$falseNeg[$label];
            }
        }
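
        // At this point, given hypothetical inputs $predictions = ['a', 'b', 'a'] and
        // $labels = ['a', 'a', 'a'], the tallies would be $truePos = ['a' => 2, 'b' => 0],
        // $trueNeg = ['a' => 0, 'b' => 2], $falsePos = ['a' => 0, 'b' => 1], and
        // $falseNeg = ['a' => 1, 'b' => 0].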

        // Running sums used to compute the macro-averaged overall metrics.
        $averages = array_fill_keys([
            'accuracy', 'accuracy_balanced', 'f1_score', 'precision', 'recall', 'specificity',
            'negative_predictive_value', 'false_discovery_rate', 'miss_rate', 'fall_out',
            'false_omission_rate', 'threat_score', 'mcc', 'informedness', 'markedness',
        ], 0.0);

        // Overall confusion matrix counts summed over all classes.
        $counts = array_fill_keys([
            'true_positives', 'true_negatives', 'false_positives', 'false_negatives',
        ], 0);

        $overall = $averages + $counts;

        $table = [];

        // Compute the metrics for each class from its confusion matrix counts,
        // guarding zero denominators with EPSILON.
        foreach ($truePos as $label => $tp) {
            $tn = $trueNeg[$label];
            $fp = $falsePos[$label];
            $fn = $falseNeg[$label];

            $accuracy = ($tp + $tn) / (($tp + $tn + $fp + $fn) ?: EPSILON);
            $precision = $tp / (($tp + $fp) ?: EPSILON);
            $recall = $tp / (($tp + $fn) ?: EPSILON);
            $specificity = $tn / (($tn + $fp) ?: EPSILON);
            $npv = $tn / (($tn + $fn) ?: EPSILON);
            $threatScore = $tp / (($tp + $fn + $fp) ?: EPSILON);

            // F1 score is the harmonic mean of precision and recall.
            $f1score = 2.0 * (($precision * $recall)
                / (($precision + $recall) ?: EPSILON));

            // Matthews correlation coefficient.
            $mcc = ($tp * $tn - $fp * $fn)
                / (sqrt(($tp + $fp) * ($tp + $fn)
                * ($tn + $fp) * ($tn + $fn)) ?: EPSILON);

            // The number of samples whose ground-truth label is this class.
            $cardinality = $tp + $fn;

            $table[$label] = [
                'accuracy' => $accuracy,
                'accuracy_balanced' => ($recall + $specificity) / 2.0,
                'f1_score' => $f1score,
                'precision' => $precision,
                'recall' => $recall,
                'specificity' => $specificity,
                'negative_predictive_value' => $npv,
                'false_discovery_rate' => 1.0 - $precision,
                'miss_rate' => 1.0 - $recall,
                'fall_out' => 1.0 - $specificity,
                'false_omission_rate' => 1.0 - $npv,
                'threat_score' => $threatScore,
                'informedness' => $recall + $specificity - 1.0,
                'markedness' => $precision + $npv - 1.0,
                'mcc' => $mcc,
                'true_positives' => $tp,
                'true_negatives' => $tn,
                'false_positives' => $fp,
                'false_negatives' => $fn,
                'cardinality' => $cardinality,
                'proportion' => $cardinality / $n,
            ];

            // Accumulate the sums for the overall (macro-averaged) metrics and total counts.
            $overall['accuracy'] += $accuracy;
            $overall['accuracy_balanced'] += ($recall + $specificity) / 2.0;
            $overall['f1_score'] += $f1score;
            $overall['precision'] += $precision;
            $overall['recall'] += $recall;
            $overall['specificity'] += $specificity;
            $overall['negative_predictive_value'] += $npv;
            $overall['false_discovery_rate'] += 1.0 - $precision;
            $overall['miss_rate'] += 1.0 - $recall;
            $overall['fall_out'] += 1.0 - $specificity;
            $overall['false_omission_rate'] += 1.0 - $npv;
            $overall['threat_score'] += $threatScore;
            $overall['informedness'] += $recall + $specificity - 1.0;
            $overall['markedness'] += $precision + $npv - 1.0;
            $overall['mcc'] += $mcc;
            $overall['true_positives'] += $tp;
            $overall['true_negatives'] += $tn;
            $overall['false_positives'] += $fp;
            $overall['false_negatives'] += $fn;
        }

        // Macro average: divide each summed metric by the number of classes.
        foreach (array_keys($averages) as $metric) {
            $overall[$metric] /= $k;
        }

        $overall += [
            'cardinality' => $n,
        ];

        return new Report([
            'overall' => $overall,
            'classes' => $table,
        ]);
    }
}