<?php

namespace Rubix\ML\CrossValidation\Reports;

use Rubix\ML\Report;
use Rubix\ML\Estimator;
use Rubix\ML\EstimatorType;
use Rubix\ML\Specifications\PredictionAndLabelCountsAreEqual;

use function count;

use const Rubix\ML\EPSILON;

/**
 * Multiclass Breakdown
 *
 * A multiclass classification report that computes a number of metrics (Accuracy, Precision,
 * Recall, etc.) derived from each class's one-vs-rest confusion matrix, on both an overall
 * (macro-averaged) and an individual class basis.
 *
 * @category    Machine Learning
 * @package     Rubix/ML
 * @author      Andrew DalPino
 */
class MulticlassBreakdown implements ReportGenerator
{
    /**
     * The per-class metrics that are macro-averaged (unweighted mean over classes) in the
     * overall section of the report, listed in output order.
     *
     * @var list<string>
     */
    protected const AVERAGED_METRICS = [
        'accuracy', 'accuracy_balanced', 'f1_score', 'precision', 'recall', 'specificity',
        'negative_predictive_value', 'false_discovery_rate', 'miss_rate', 'fall_out',
        'false_omission_rate', 'threat_score', 'mcc', 'informedness', 'markedness',
    ];

    /**
     * The confusion matrix counts that are summed over all classes in the overall section,
     * listed in output order.
     *
     * @var list<string>
     */
    protected const SUMMED_COUNTS = [
        'true_positives', 'true_negatives', 'false_positives', 'false_negatives',
    ];

    /**
     * The estimator types that this report is compatible with.
     *
     * @internal
     *
     * @return list<\Rubix\ML\EstimatorType>
     */
    public function compatibility() : array
    {
        return [
            EstimatorType::classifier(),
            EstimatorType::anomalyDetector(),
        ];
    }

    /**
     * Generate the report.
     *
     * Every class that appears in either the predictions or the labels gets its own
     * one-vs-rest confusion matrix. For class c, a sample is:
     *  - a true positive if it is predicted c and labeled c,
     *  - a false positive if it is predicted c but labeled otherwise,
     *  - a false negative if it is labeled c but predicted otherwise,
     *  - a true negative otherwise (including misclassified samples of other classes).
     *
     * The overall section macro-averages the rate metrics over classes, sums the confusion
     * counts, and reports the total number of samples as its cardinality. With no samples,
     * the class table is empty and the overall metrics stay at zero.
     *
     * @param list<string|int> $predictions
     * @param list<string|int> $labels
     * @return \Rubix\ML\Report
     */
    public function generate(array $predictions, array $labels) : Report
    {
        PredictionAndLabelCountsAreEqual::with($predictions, $labels)->check();

        $classes = array_unique(array_merge($predictions, $labels));

        $n = count($predictions);
        $k = count($classes);

        $truePos = $falsePos = $falseNeg = array_fill_keys($classes, 0);

        foreach ($predictions as $i => $prediction) {
            $label = $labels[$i];

            // Compare string forms so equality agrees with how array_unique() and the array
            // keys above de-duplicate classes (loose == would e.g. equate '1.0' with '1').
            if ((string) $prediction === (string) $label) {
                ++$truePos[$prediction];
            } else {
                ++$falsePos[$prediction];
                ++$falseNeg[$label];
            }
        }

        $overall = array_fill_keys(self::AVERAGED_METRICS, 0.0)
            + array_fill_keys(self::SUMMED_COUNTS, 0);

        $table = [];

        foreach ($truePos as $class => $tp) {
            $fp = $falsePos[$class];
            $fn = $falseNeg[$class];

            // One-vs-rest: the four cells partition all n samples, so every sample that was
            // neither predicted as nor labeled with this class is a true negative.
            $tn = $n - $tp - $fp - $fn;

            $cardinality = $tp + $fn;

            $table[$class] = self::rateMetrics($tp, $tn, $fp, $fn) + [
                'true_positives' => $tp,
                'true_negatives' => $tn,
                'false_positives' => $fp,
                'false_negatives' => $fn,
                'cardinality' => $cardinality,
                'proportion' => $cardinality / $n,
            ];

            foreach (self::AVERAGED_METRICS as $metric) {
                $overall[$metric] += $table[$class][$metric];
            }

            foreach (self::SUMMED_COUNTS as $count) {
                $overall[$count] += $table[$class][$count];
            }
        }

        // Guard against an empty dataset, where there are no classes to average over.
        if ($k > 0) {
            foreach (self::AVERAGED_METRICS as $metric) {
                $overall[$metric] /= $k;
            }
        }

        $overall['cardinality'] = $n;

        return new Report([
            'overall' => $overall,
            'classes' => $table,
        ]);
    }

    /**
     * Compute the rate-based metrics of a single one-vs-rest confusion matrix.
     *
     * A ratio whose denominator is zero is divided by EPSILON instead; in every such case the
     * numerator is also zero, so the ratio evaluates to 0. Complementary rates (e.g. miss rate)
     * are then derived as 1 minus their counterpart.
     *
     * @param int $tp true positives
     * @param int $tn true negatives
     * @param int $fp false positives
     * @param int $fn false negatives
     * @return array<string,int|float>
     */
    protected static function rateMetrics(int $tp, int $tn, int $fp, int $fn) : array
    {
        $accuracy = ($tp + $tn) / (($tp + $tn + $fp + $fn) ?: EPSILON);
        $precision = $tp / (($tp + $fp) ?: EPSILON);
        $recall = $tp / (($tp + $fn) ?: EPSILON);
        $specificity = $tn / (($tn + $fp) ?: EPSILON);
        $npv = $tn / (($tn + $fn) ?: EPSILON);
        $threatScore = $tp / (($tp + $fn + $fp) ?: EPSILON);

        $f1score = 2.0 * (($precision * $recall)
            / (($precision + $recall) ?: EPSILON));

        $mcc = ($tp * $tn - $fp * $fn)
            / (sqrt(($tp + $fp) * ($tp + $fn)
            * ($tn + $fp) * ($tn + $fn)) ?: EPSILON);

        return [
            'accuracy' => $accuracy,
            'accuracy_balanced' => ($recall + $specificity) / 2.0,
            'f1_score' => $f1score,
            'precision' => $precision,
            'recall' => $recall,
            'specificity' => $specificity,
            'negative_predictive_value' => $npv,
            'false_discovery_rate' => 1.0 - $precision,
            'miss_rate' => 1.0 - $recall,
            'fall_out' => 1.0 - $specificity,
            'false_omission_rate' => 1.0 - $npv,
            'threat_score' => $threatScore,
            'informedness' => $recall + $specificity - 1.0,
            'markedness' => $precision + $npv - 1.0,
            'mcc' => $mcc,
        ];
    }
}