<?php

namespace Rubix\ML\CrossValidation;

use Rubix\ML\Learner;
use Rubix\ML\Estimator;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\CrossValidation\Metrics\Metric;
use Rubix\ML\Specifications\EstimatorIsCompatibleWithMetric;
use Rubix\ML\Exceptions\InvalidArgumentException;
use Rubix\ML\Exceptions\RuntimeException;

/**
 * Hold Out
 *
 * Hold Out is a quick and simple cross validation technique that uses a validation set
 * that is *held out* from the training data. The advantage of Hold Out is that the
 * validation score is quick to compute; however, it does not allow the learner to *both*
 * train and test on all the data in the training set.
 *
 * @category    Machine Learning
 * @package     Rubix/ML
 * @author      Andrew DalPino
 */
class HoldOut implements Validator
{
    /**
     * The hold out ratio. i.e. the ratio of samples to use for testing.
     *
     * @var float
     */
    protected $ratio;

    /**
     * @param float $ratio the fraction of samples to hold out for testing, strictly between 0 and 1
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException if the ratio is not strictly between 0 and 1 (NAN included)
     */
    public function __construct(float $ratio = 0.2)
    {
        // Every comparison against NAN is false, so the range check alone would
        // let NAN through. Reject it explicitly instead of failing later in test().
        if (is_nan($ratio) or $ratio <= 0.0 or $ratio >= 1.0) {
            throw new InvalidArgumentException('Ratio must be'
                . " between 0 and 1, $ratio given.");
        }

        $this->ratio = $ratio;
    }

    /**
     * Test the estimator with the supplied dataset and return a validation score.
     *
     * The dataset is randomized, then split into a testing and a training partition
     * using the hold out ratio. A stratified split is used when the labels are
     * categorical; otherwise a plain split is used. The learner is trained on the
     * training partition and scored by the metric on the testing partition.
     *
     * NOTE(review): randomize()'s return value is not used, so it presumably shuffles
     * the caller's dataset in place. Confirm against Labeled::randomize().
     *
     * @param \Rubix\ML\Learner $estimator
     * @param \Rubix\ML\Datasets\Labeled $dataset
     * @param \Rubix\ML\CrossValidation\Metrics\Metric $metric
     * @throws \Rubix\ML\Exceptions\RuntimeException if the testing partition ends up empty
     * @return float
     */
    public function test(Learner $estimator, Labeled $dataset, Metric $metric) : float
    {
        EstimatorIsCompatibleWithMetric::with($estimator, $metric)->check();

        $dataset->randomize();

        // The first partition returned by the split is used as the validation set.
        [$testing, $training] = $dataset->labelType()->isCategorical()
            ? $dataset->stratifiedSplit($this->ratio)
            : $dataset->split($this->ratio);

        if ($testing->empty()) {
            throw new RuntimeException('Dataset does not contain'
                . ' enough records to create a validation set with a'
                . " hold out ratio of {$this->ratio}.");
        }

        $estimator->train($training);

        $predictions = $estimator->predict($testing);

        return $metric->score($predictions, $testing->labels());
    }

    /**
     * Return the string representation of the object.
     *
     * @return string
     */
    public function __toString() : string
    {
        return "Hold Out (ratio: {$this->ratio})";
    }
}