<?php

namespace Rubix\ML\Regressors;

use Tensor\Matrix;
use Tensor\Vector;
use Rubix\ML\Learner;
use Rubix\ML\DataType;
use Rubix\ML\Estimator;
use Rubix\ML\Persistable;
use Rubix\ML\RanksFeatures;
use Rubix\ML\EstimatorType;
use Rubix\ML\Datasets\Dataset;
use Rubix\ML\Other\Helpers\Params;
use Rubix\ML\Other\Traits\PredictsSingle;
use Rubix\ML\Other\Traits\AutotrackRevisions;
use Rubix\ML\Specifications\DatasetIsLabeled;
use Rubix\ML\Specifications\DatasetIsNotEmpty;
use Rubix\ML\Specifications\SpecificationChain;
use Rubix\ML\Specifications\DatasetHasDimensionality;
use Rubix\ML\Specifications\LabelsAreCompatibleWithLearner;
use Rubix\ML\Specifications\SamplesAreCompatibleWithEstimator;
use Rubix\ML\Exceptions\InvalidArgumentException;
use Rubix\ML\Exceptions\RuntimeException;

use function is_null;

/**
 * Ridge
 *
 * L2 regularized least squares linear model solved using a closed-form solution. The addition
 * of regularization, controlled by the *alpha* parameter, makes Ridge less prone to overfitting
 * than ordinary linear regression.
 *
 * @category Machine Learning
 * @package Rubix/ML
 * @author Andrew DalPino
 */
class Ridge implements Estimator, Learner, RanksFeatures, Persistable
{
    use AutotrackRevisions, PredictsSingle;

    /**
     * The strength of the L2 regularization penalty.
     *
     * @var float
     */
    protected $alpha;

    /**
     * The y intercept i.e. the bias added to the decision function.
     *
     * @var float|null
     */
    protected $bias;

    /**
     * The computed coefficients of the regression line.
     *
     * @var \Tensor\Vector|null
     */
    protected $coefficients;

    /**
     * @param float $alpha
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
     */
    public function __construct(float $alpha = 1.0)
    {
        if ($alpha < 0.0) {
            // An alpha of exactly 0 is valid and reduces Ridge to ordinary
            // least squares, so the message must say "greater than or equal".
            throw new InvalidArgumentException('Alpha must be'
                . " greater than or equal to 0, $alpha given.");
        }

        $this->alpha = $alpha;
    }

    /**
     * Return the estimator type.
     *
     * @internal
     *
     * @return \Rubix\ML\EstimatorType
     */
    public function type() : EstimatorType
    {
        return EstimatorType::regressor();
    }

    /**
     * Return the data types that the estimator is compatible with.
     *
     * @internal
     *
     * @return list<\Rubix\ML\DataType>
     */
    public function compatibility() : array
    {
        return [
            DataType::continuous(),
        ];
    }

    /**
     * Return the settings of the hyper-parameters in an associative array.
     *
     * @internal
     *
     * @return mixed[]
     */
    public function params() : array
    {
        return [
            'alpha' => $this->alpha,
        ];
    }

    /**
     * Has the learner been trained?
     *
     * @return bool
     */
    public function trained() : bool
    {
        // Both model parameters are set together at the end of train(), so
        // check them explicitly rather than relying on object truthiness.
        return $this->coefficients !== null && $this->bias !== null;
    }

    /**
     * Return the weights of features in the decision function.
     *
     * @return (int|float)[]|null
     */
    public function coefficients() : ?array
    {
        return $this->coefficients !== null ? $this->coefficients->asArray() : null;
    }

    /**
     * Return the bias added to the decision function.
     *
     * @return float|null
     */
    public function bias() : ?float
    {
        return $this->bias;
    }

    /**
     * Train the learner with a dataset.
     *
     * @param \Rubix\ML\Datasets\Labeled $dataset
     */
    public function train(Dataset $dataset) : void
    {
        SpecificationChain::with([
            new DatasetIsLabeled($dataset),
            new DatasetIsNotEmpty($dataset),
            new SamplesAreCompatibleWithEstimator($dataset, $this),
            new LabelsAreCompatibleWithLearner($dataset, $this),
        ])->check();

        // Prepend a column of ones so the bias term is learned jointly with
        // the feature coefficients.
        $biases = Matrix::ones($dataset->numRows(), 1);

        $x = Matrix::build($dataset->samples())->augmentLeft($biases);
        $y = Vector::build($dataset->labels());

        // Penalize every feature coefficient by alpha but leave the bias
        // (first column) unpenalized.
        $alphas = array_fill(0, $x->n() - 1, $this->alpha);

        array_unshift($alphas, 0.0);

        $penalties = Matrix::diagonal($alphas);

        $xT = $x->transpose();

        // Closed-form solution: w = (XᵀX + P)⁻¹ Xᵀy
        $coefficients = $xT->matmul($x)
            ->add($penalties)
            ->inverse()
            ->dot($xT->dot($y))
            ->asArray();

        $this->bias = (float) array_shift($coefficients);
        $this->coefficients = Vector::quick($coefficients);
    }

    /**
     * Make a prediction based on the line calculated from the training data.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @throws \Rubix\ML\Exceptions\RuntimeException
     * @return list<int|float>
     */
    public function predict(Dataset $dataset) : array
    {
        if ($this->coefficients === null || is_null($this->bias)) {
            throw new RuntimeException('Estimator has not been trained.');
        }

        DatasetHasDimensionality::with($dataset, count($this->coefficients))->check();

        return Matrix::build($dataset->samples())
            ->dot($this->coefficients)
            ->add($this->bias)
            ->asArray();
    }

    /**
     * Return the normalized importance scores of each feature column of the training set.
     *
     * @throws \Rubix\ML\Exceptions\RuntimeException
     * @return float[]
     */
    public function featureImportances() : array
    {
        if (is_null($this->coefficients)) {
            throw new RuntimeException('Learner has not been trained.');
        }

        // Importance is the magnitude of each coefficient normalized to sum to 1.
        $importances = $this->coefficients->abs();

        return $importances->divide($importances->sum())->asArray();
    }

    /**
     * Return the string representation of the object.
     *
     * @return string
     */
    public function __toString() : string
    {
        return 'Ridge (' . Params::stringify($this->params()) . ')';
    }
}