1<?php
2
3namespace Rubix\ML\CrossValidation;
4
5use Rubix\ML\Learner;
6use Rubix\ML\Estimator;
7use Rubix\ML\Datasets\Labeled;
8use Rubix\ML\CrossValidation\Metrics\Metric;
9use Rubix\ML\Specifications\EstimatorIsCompatibleWithMetric;
10use Rubix\ML\Exceptions\InvalidArgumentException;
11use Rubix\ML\Exceptions\RuntimeException;
12
/**
 * Hold Out
 *
 * Hold Out is a quick and simple cross validation technique that uses a validation set
 * that is *held out* from the training data. The advantage of Hold Out is that the
 * validation score is quick to compute; however, it does not allow the learner to *both*
 * train and test on all the data in the training set.
 *
 * @category    Machine Learning
 * @package     Rubix/ML
 * @author      Andrew DalPino
 */
class HoldOut implements Validator
{
    /**
     * The hold out ratio. i.e. the ratio of samples to use for testing.
     *
     * @var float
     */
    protected $ratio;

    /**
     * @param float $ratio The fraction of samples to hold out for testing, strictly between 0 and 1.
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
     */
    public function __construct(float $ratio = 0.2)
    {
        // Expressed as a negated open-interval test instead of `<= 0 or >= 1`
        // so that NAN (for which every comparison is false) is rejected as well.
        if (!($ratio > 0.0 and $ratio < 1.0)) {
            throw new InvalidArgumentException('Ratio must be'
                . " between 0 and 1, $ratio given.");
        }

        $this->ratio = $ratio;
    }

    /**
     * Test the estimator with the supplied dataset and return a validation score.
     *
     * The dataset is shuffled and then split. A stratified split is used when the
     * labels are categorical, and a plain split otherwise. The first part (sized by
     * the hold out ratio) is used for testing. The remainder is used to train the
     * estimator.
     *
     * @param \Rubix\ML\Learner $estimator
     * @param \Rubix\ML\Datasets\Labeled $dataset
     * @param \Rubix\ML\CrossValidation\Metrics\Metric $metric
     * @throws \Rubix\ML\Exceptions\RuntimeException
     * @return float
     */
    public function test(Learner $estimator, Labeled $dataset, Metric $metric) : float
    {
        EstimatorIsCompatibleWithMetric::with($estimator, $metric)->check();

        // randomize() is called only for its side effect (its return value is unused),
        // so the caller's dataset ends up shuffled.
        // NOTE(review): relies on in-place shuffling semantics; confirm in Labeled.
        $dataset->randomize();

        [$testing, $training] = $dataset->labelType()->isCategorical()
            ? $dataset->stratifiedSplit($this->ratio)
            : $dataset->split($this->ratio);

        // With too few samples the held-out portion can round down to nothing,
        // leaving no records to score against.
        if ($testing->empty()) {
            throw new RuntimeException('Dataset does not contain'
                . ' enough records to create a validation set with a'
                . " hold out ratio of {$this->ratio}.");
        }

        $estimator->train($training);

        $predictions = $estimator->predict($testing);

        return $metric->score($predictions, $testing->labels());
    }

    /**
     * Return the string representation of the object.
     *
     * @return string
     */
    public function __toString() : string
    {
        return "Hold Out (ratio: {$this->ratio})";
    }
}
90