1<?php
2
3namespace Rubix\ML\Regressors;
4
5use Tensor\Matrix;
6use Tensor\Vector;
7use Rubix\ML\Learner;
8use Rubix\ML\DataType;
9use Rubix\ML\Estimator;
10use Rubix\ML\Persistable;
11use Rubix\ML\RanksFeatures;
12use Rubix\ML\EstimatorType;
13use Rubix\ML\Datasets\Dataset;
14use Rubix\ML\Other\Helpers\Params;
15use Rubix\ML\Other\Traits\PredictsSingle;
16use Rubix\ML\Other\Traits\AutotrackRevisions;
17use Rubix\ML\Specifications\DatasetIsLabeled;
18use Rubix\ML\Specifications\DatasetIsNotEmpty;
19use Rubix\ML\Specifications\SpecificationChain;
20use Rubix\ML\Specifications\DatasetHasDimensionality;
21use Rubix\ML\Specifications\LabelsAreCompatibleWithLearner;
22use Rubix\ML\Specifications\SamplesAreCompatibleWithEstimator;
23use Rubix\ML\Exceptions\InvalidArgumentException;
24use Rubix\ML\Exceptions\RuntimeException;
25
26use function is_null;
27
/**
 * Ridge
 *
 * L2 regularized least squares linear model solved using a closed-form solution. The addition
 * of regularization, controlled by the *alpha* parameter, makes Ridge less prone to overfitting
 * than ordinary linear regression.
 *
 * The bias (intercept) term is solved jointly with the feature weights but is excluded from
 * the regularization penalty.
 *
 * @category    Machine Learning
 * @package     Rubix/ML
 * @author      Andrew DalPino
 */
class Ridge implements Estimator, Learner, RanksFeatures, Persistable
{
    use AutotrackRevisions, PredictsSingle;

    /**
     * The strength of the L2 regularization penalty.
     *
     * @var float
     */
    protected $alpha;

    /**
     * The y intercept i.e. the bias added to the decision function. Null until trained.
     *
     * @var float|null
     */
    protected $bias;

    /**
     * The computed coefficients of the regression line. Null until trained.
     *
     * @var \Tensor\Vector|null
     */
    protected $coefficients;

    /**
     * @param float $alpha The strength of the L2 penalty. Zero disables the penalty.
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException If alpha is negative or not a number.
     */
    public function __construct(float $alpha = 1.0)
    {
        // NaN compares false against every number, so it has to be rejected explicitly.
        if (is_nan($alpha) or $alpha < 0.0) {
            throw new InvalidArgumentException('Alpha must be'
                . " greater than or equal to 0, $alpha given.");
        }

        $this->alpha = $alpha;
    }

    /**
     * Return the estimator type.
     *
     * @internal
     *
     * @return \Rubix\ML\EstimatorType
     */
    public function type() : EstimatorType
    {
        return EstimatorType::regressor();
    }

    /**
     * Return the data types that the estimator is compatible with.
     *
     * @internal
     *
     * @return list<\Rubix\ML\DataType>
     */
    public function compatibility() : array
    {
        return [
            DataType::continuous(),
        ];
    }

    /**
     * Return the settings of the hyper-parameters in an associative array.
     *
     * @internal
     *
     * @return mixed[]
     */
    public function params() : array
    {
        return [
            'alpha' => $this->alpha,
        ];
    }

    /**
     * Has the learner been trained?
     *
     * @return bool True once both the coefficients and the bias have been computed.
     */
    public function trained() : bool
    {
        return !is_null($this->coefficients) and !is_null($this->bias);
    }

    /**
     * Return the weights of features in the decision function.
     *
     * @return (int|float)[]|null The weights, or null if the learner has not been trained.
     */
    public function coefficients() : ?array
    {
        if (is_null($this->coefficients)) {
            return null;
        }

        return $this->coefficients->asArray();
    }

    /**
     * Return the bias added to the decision function.
     *
     * @return float|null The bias, or null if the learner has not been trained.
     */
    public function bias() : ?float
    {
        return $this->bias;
    }

    /**
     * Train the learner with a dataset.
     *
     * Solves (XᵀX + P)⁻¹ Xᵀy where X is the sample matrix with a leading column of ones
     * and P is a diagonal penalty matrix whose first entry (the bias) is zero.
     *
     * @param \Rubix\ML\Datasets\Labeled $dataset
     */
    public function train(Dataset $dataset) : void
    {
        SpecificationChain::with([
            new DatasetIsLabeled($dataset),
            new DatasetIsNotEmpty($dataset),
            new SamplesAreCompatibleWithEstimator($dataset, $this),
            new LabelsAreCompatibleWithLearner($dataset, $this),
        ])->check();

        // A leading column of ones lets the first solved weight act as the bias.
        $biases = Matrix::ones($dataset->numRows(), 1);

        $x = Matrix::build($dataset->samples())->augmentLeft($biases);
        $y = Vector::build($dataset->labels());

        // One penalty per feature column (the augmented matrix has one extra column).
        $alphas = array_fill(0, $x->n() - 1, $this->alpha);

        // The bias column gets a zero penalty so the intercept is not regularized.
        array_unshift($alphas, 0.0);

        $penalties = Matrix::diagonal($alphas);

        $xT = $x->transpose();

        $coefficients = $xT->matmul($x)
            ->add($penalties)
            ->inverse()
            ->dot($xT->dot($y))
            ->asArray();

        // The first solved weight belongs to the column of ones, i.e. the bias.
        $this->bias = (float) array_shift($coefficients);
        $this->coefficients = Vector::quick($coefficients);
    }

    /**
     * Make a prediction based on the line calculated from the training data.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @throws \Rubix\ML\Exceptions\RuntimeException If the estimator has not been trained.
     * @return list<int|float>
     */
    public function predict(Dataset $dataset) : array
    {
        if (is_null($this->coefficients) or is_null($this->bias)) {
            throw new RuntimeException('Estimator has not been trained.');
        }

        // The dataset must have as many feature columns as there are learned weights.
        DatasetHasDimensionality::with($dataset, count($this->coefficients))->check();

        return Matrix::build($dataset->samples())
            ->dot($this->coefficients)
            ->add($this->bias)
            ->asArray();
    }

    /**
     * Return the normalized importance scores of each feature column of the training set.
     *
     * Scores are the absolute coefficients divided by their sum. When every coefficient
     * is zero the scores cannot be normalized and are returned as they are (all zeros).
     *
     * @throws \Rubix\ML\Exceptions\RuntimeException If the learner has not been trained.
     * @return float[]
     */
    public function featureImportances() : array
    {
        if (is_null($this->coefficients)) {
            throw new RuntimeException('Learner has not been trained.');
        }

        $importances = $this->coefficients->abs();

        $total = $importances->sum();

        // Guard against dividing by zero when all of the coefficients are zero.
        if ($total == 0) {
            return $importances->asArray();
        }

        return $importances->divide($total)->asArray();
    }

    /**
     * Return the string representation of the object.
     *
     * @return string
     */
    public function __toString() : string
    {
        return 'Ridge (' . Params::stringify($this->params()) . ')';
    }
}
233