1<?php
2
3namespace Rubix\ML;
4
5use Rubix\ML\Backends\Serial;
6use Rubix\ML\Datasets\Labeled;
7use Rubix\ML\Datasets\Dataset;
8use Rubix\ML\Backends\Tasks\Task;
9use Rubix\ML\Other\Helpers\Params;
10use Rubix\ML\CrossValidation\KFold;
11use Rubix\ML\Other\Traits\LoggerAware;
12use Rubix\ML\CrossValidation\Validator;
13use Rubix\ML\Other\Traits\PredictsSingle;
14use Rubix\ML\Other\Traits\Multiprocessing;
15use Rubix\ML\CrossValidation\Metrics\RMSE;
16use Rubix\ML\CrossValidation\Metrics\FBeta;
17use Rubix\ML\CrossValidation\Metrics\Metric;
18use Rubix\ML\Other\Traits\AutotrackRevisions;
19use Rubix\ML\Specifications\DatasetIsLabeled;
20use Rubix\ML\CrossValidation\Metrics\Accuracy;
21use Rubix\ML\CrossValidation\Metrics\VMeasure;
22use Rubix\ML\Specifications\DatasetIsNotEmpty;
23use Rubix\ML\Specifications\SpecificationChain;
24use Rubix\ML\Specifications\LabelsAreCompatibleWithLearner;
25use Rubix\ML\Specifications\EstimatorIsCompatibleWithMetric;
26use Rubix\ML\Specifications\SamplesAreCompatibleWithEstimator;
27use Rubix\ML\Exceptions\InvalidArgumentException;
28
29use function count;
30
31/**
32 * Grid Search
33 *
34 * Grid Search is an algorithm that optimizes hyper-parameter selection. From
35 * the user's perspective, the process of training and predicting is the same,
36 * however, under the hood, Grid Search trains one estimator per combination
37 * of parameters and the best model is selected as the base estimator.
38 *
39 * > **Note:** You can choose the hyper-parameters manually or you can generate
40 * them randomly or in a grid using the Params helper.
41 *
42 * @category    Machine Learning
43 * @package     Rubix/ML
44 * @author      Andrew DalPino
45 */
class GridSearch implements Estimator, Learner, Parallel, Verbose, Wrapper, Persistable
{
    use AutotrackRevisions, Multiprocessing, PredictsSingle, LoggerAware;

    /**
     * The class name of the base estimator.
     *
     * @var string
     */
    protected $base;

    /**
     * An array of tuples containing the possible values for each of the base learner's constructor parameters.
     * Each tuple is de-duplicated and an empty tuple is stored as `[null]`.
     *
     * @var array[]
     */
    protected $params;

    /**
     * The validation metric used to score the estimator.
     *
     * @var \Rubix\ML\CrossValidation\Metrics\Metric
     */
    protected $metric;

    /**
     * The validator used to test the estimator.
     *
     * @var \Rubix\ML\CrossValidation\Validator
     */
    protected $validator;

    /**
     * The argument names for the base estimator's constructor.
     *
     * @var string[]
     */
    protected $args = [
        //
    ];

    /**
     * The results of the last hyper-parameter search ordered from the highest to the
     * lowest validation score, or null if no search has been run yet.
     *
     * @var array[]|null
     */
    protected $results;

    /**
     * The instance of the estimator with the best parameters. Until train() has completed
     * this holds the proxy built from the first value of every parameter tuple.
     *
     * @var \Rubix\ML\Learner
     */
    protected $estimator;

    /**
     * Cross validate a learner with a given dataset and return the score.
     *
     * This is the unit of work handed to the backend, one call per combination.
     *
     * @internal
     *
     * @param \Rubix\ML\Learner $estimator
     * @param \Rubix\ML\Datasets\Labeled $dataset
     * @param \Rubix\ML\CrossValidation\Validator $validator
     * @param \Rubix\ML\CrossValidation\Metrics\Metric $metric
     * @return mixed[] A 2-tuple of the validation score and the params() of the learner under test
     */
    public static function score(Learner $estimator, Labeled $dataset, Validator $validator, Metric $metric) : array
    {
        $score = $validator->test($estimator, $dataset, $metric);

        return [$score, $estimator->params()];
    }

    /**
     * @param class-string $base The class name of the learner to tune
     * @param array[] $params One tuple of candidate values per constructor argument of the base learner
     * @param \Rubix\ML\CrossValidation\Metrics\Metric|null $metric Defaults to a metric chosen by estimator type
     * @param \Rubix\ML\CrossValidation\Validator|null $validator Defaults to 3-fold cross validation
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
     */
    public function __construct(
        string $base,
        array $params,
        ?Metric $metric = null,
        ?Validator $validator = null
    ) {
        if (!class_exists($base)) {
            throw new InvalidArgumentException("Class $base does not exist.");
        }

        // Normalise the grid before anything reads from it so that the proxy below is
        // built from exactly the same values as the first combination of the search.
        // Assigning by key (rather than iterating by reference) leaves no dangling
        // reference inside the array that ends up stored on the instance.
        foreach ($params as $position => $tuple) {
            if (!is_array($tuple)) {
                throw new InvalidArgumentException('Parameter tuple at'
                    . " position $position must be an array, "
                    . gettype($tuple) . ' given.');
            }

            $params[$position] = empty($tuple)
                ? [null]
                : array_unique($tuple, SORT_REGULAR);
        }

        // Instantiate a throwaway learner from the first candidate of every tuple in
        // order to inspect its interface and estimator type.
        $proxy = new $base(...array_map('current', $params));

        if (!$proxy instanceof Learner) {
            throw new InvalidArgumentException('Base class must'
                . ' implement the Learner Interface.');
        }

        if ($metric) {
            EstimatorIsCompatibleWithMetric::with($proxy, $metric)->check();
        } else {
            switch ($proxy->type()) {
                case EstimatorType::classifier():
                    $metric = new FBeta();

                    break;

                case EstimatorType::regressor():
                    $metric = new RMSE();

                    break;

                case EstimatorType::clusterer():
                    $metric = new VMeasure();

                    break;

                case EstimatorType::anomalyDetector():
                    $metric = new FBeta();

                    break;

                default:
                    $metric = new Accuracy();
            }
        }

        $this->base = $base;
        $this->params = $params;
        $this->metric = $metric;
        $this->validator = $validator ?? new KFold(3);
        $this->estimator = $proxy;
        $this->backend = new Serial();
    }

    /**
     * Return the estimator type. Delegates to the wrapped estimator.
     *
     * @internal
     *
     * @return \Rubix\ML\EstimatorType
     */
    public function type() : EstimatorType
    {
        return $this->estimator->type();
    }

    /**
     * Return the data types that the estimator is compatible with. Before training every
     * data type is reported, afterwards the answer comes from the selected estimator.
     *
     * @internal
     *
     * @return list<\Rubix\ML\DataType>
     */
    public function compatibility() : array
    {
        return $this->trained()
            ? $this->estimator->compatibility()
            : DataType::all();
    }

    /**
     * Return the settings of the hyper-parameters in an associative array.
     *
     * @internal
     *
     * @return mixed[]
     */
    public function params() : array
    {
        return [
            'base' => $this->base,
            'params' => $this->params,
            'metric' => $this->metric,
            'validator' => $this->validator,
        ];
    }

    /**
     * Has the learner been trained? Delegates to the wrapped estimator.
     *
     * @return bool
     */
    public function trained() : bool
    {
        return $this->estimator->trained();
    }

    /**
     * Return an array containing the validation scores and hyper-parameters under test
     * for each combination in a 2-tuple, ordered from the highest to the lowest score.
     *
     * @return array[]|null Null if no search has been run yet
     */
    public function results() : ?array
    {
        return $this->results;
    }

    /**
     * Return an array containing the hyper-parameters with the highest validation score
     * from the last search.
     *
     * @return mixed[]|null Null if no search has been run yet
     */
    public function best() : ?array
    {
        return $this->results ? $this->results[0][1] : null;
    }

    /**
     * Return the base estimator instance.
     *
     * @return \Rubix\ML\Estimator
     */
    public function base() : Estimator
    {
        return $this->estimator;
    }

    /**
     * Train one estimator per combination of parameters given by the grid and
     * assign the best one as the base estimator of this instance.
     *
     * The dataset is first run through a chain of specification checks (labeled,
     * not empty, compatible samples and labels). Every combination is then scored
     * on the backend, the results are ranked from the highest to the lowest score,
     * and a fresh estimator built from the top ranked parameters is trained on the
     * full dataset.
     *
     * @param \Rubix\ML\Datasets\Labeled $dataset
     */
    public function train(Dataset $dataset) : void
    {
        SpecificationChain::with([
            new DatasetIsLabeled($dataset),
            new DatasetIsNotEmpty($dataset),
            new SamplesAreCompatibleWithEstimator($dataset, $this),
            new LabelsAreCompatibleWithLearner($dataset, $this),
        ])->check();

        $combinations = $this->combinations();

        if ($this->logger) {
            $this->logger->info("$this initialized");

            $this->logger->info('Searching ' . count($combinations)
                . ' combinations of hyper-parameters');
        }

        // Discard anything left in the queue from a previous run.
        $this->backend->flush();

        foreach ($combinations as $params) {
            $estimator = new $this->base(...$params);

            $this->backend->enqueue(
                new Task(
                    [self::class, 'score'],
                    [
                        $estimator,
                        $dataset,
                        $this->validator,
                        $this->metric,
                    ]
                ),
                [$this, 'afterScore']
            );
        }

        // Each result is a [score, params] tuple, split them into two parallel lists.
        [$scores, $combinations] = array_transpose($this->backend->process());

        // The order flag of array_multisort() applies to the array that precedes it,
        // so SORT_DESC has to follow $scores for the highest score to come first.
        // $combinations is reordered alongside and only compared to break ties.
        array_multisort($scores, SORT_DESC, $combinations);

        $this->results = array_transpose([$scores, $combinations]);

        $best = reset($combinations);

        // $best is the associative array reported by score(), its values are passed
        // positionally to the constructor of the base learner.
        $estimator = new $this->base(...array_values($best));

        if ($this->logger) {
            $this->logger->info('Training with best hyper-parameters');
        }

        $estimator->train($dataset);

        $this->estimator = $estimator;

        if ($this->logger) {
            $this->logger->info('Training complete');
        }
    }

    /**
     * Make a prediction on a given sample dataset. Delegates to the wrapped estimator.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @throws \Rubix\ML\Exceptions\RuntimeException
     * @return mixed[]
     */
    public function predict(Dataset $dataset) : array
    {
        return $this->estimator->predict($dataset);
    }

    /**
     * The callback that executes after the scoring task. Logs the score and the
     * parameters under test when a logger is set, otherwise does nothing.
     *
     * @internal
     *
     * @param mixed[] $result The 2-tuple returned by score()
     */
    public function afterScore(array $result) : void
    {
        if ($this->logger) {
            [$score, $params] = $result;

            $this->logger->info(
                "{$this->metric}: $score, params: [" . Params::stringify($params) . ']'
            );
        }
    }

    /**
     * Return an array of all possible combinations of parameters. i.e the Cartesian product of
     * the user-supplied parameter array.
     *
     * @return array[] One array per combination, keyed like the parameter array
     */
    protected function combinations() : array
    {
        // Start from a single empty combination and extend it one parameter at a time.
        $combinations = [[]];

        foreach ($this->params as $i => $params) {
            $append = [];

            foreach ($combinations as $product) {
                foreach ($params as $param) {
                    // $product is a per-iteration copy, overwriting slot $i is safe.
                    $product[$i] = $param;
                    $append[] = $product;
                }
            }

            $combinations = $append;
        }

        return $combinations;
    }

    /**
     * Allow methods to be called on the estimator from the wrapper.
     *
     * @param string $name
     * @param mixed[] $arguments
     * @return mixed
     */
    public function __call(string $name, array $arguments)
    {
        return $this->estimator->$name(...$arguments);
    }

    /**
     * Return the string representation of the object.
     *
     * @return string
     */
    public function __toString() : string
    {
        return 'Grid Search (' . Params::stringify($this->params()) . ')';
    }
}
413