<?php

namespace Rubix\ML\Tests\Classifiers;

use Rubix\ML\Learner;
use Rubix\ML\DataType;
use Rubix\ML\Estimator;
use Rubix\ML\EstimatorType;
use Rubix\ML\Classifiers\SVC;
use Rubix\ML\Kernels\SVM\RBF;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Datasets\Generators\Blob;
use Rubix\ML\Transformers\ZScaleStandardizer;
use Rubix\ML\Datasets\Generators\Agglomerate;
use Rubix\ML\CrossValidation\Metrics\Accuracy;
use Rubix\ML\Exceptions\InvalidArgumentException;
use Rubix\ML\Exceptions\RuntimeException;
use PHPUnit\Framework\TestCase;

/**
 * @group Classifiers
 * @requires extension svm
 * @covers \Rubix\ML\Classifiers\SVC
 */
class SVCTest extends TestCase
{
    /**
     * The number of samples in the training set.
     *
     * @var int
     */
    protected const TRAIN_SIZE = 200;

    /**
     * The number of samples in the validation set.
     *
     * @var int
     */
    protected const TEST_SIZE = 20;

    /**
     * The minimum validation score required to pass the test.
     *
     * @var float
     */
    protected const MIN_SCORE = 0.9;

    /**
     * Constant used to seed the random number generator.
     *
     * @var int
     */
    protected const RANDOM_SEED = 0;

    /**
     * @var \Rubix\ML\Datasets\Generators\Agglomerate
     */
    protected $generator;

    /**
     * @var \Rubix\ML\Classifiers\SVC
     */
    protected $estimator;

    /**
     * @var \Rubix\ML\CrossValidation\Metrics\Accuracy
     */
    protected $metric;

    /**
     * Build fresh fixtures before each test.
     *
     * setUp() is PHPUnit's built-in per-test hook, so it needs no @before
     * annotation.
     */
    protected function setUp() : void
    {
        // Two labelled Gaussian blobs with 3 features each, mixed 45% / 55%.
        $this->generator = new Agglomerate([
            'male' => new Blob([69.2, 195.7, 40.0], [1.0, 3.0, 0.3]),
            'female' => new Blob([63.7, 168.5, 38.1], [0.8, 2.5, 0.4]),
        ], [0.45, 0.55]);

        $this->estimator = new SVC(1.0, new RBF(), true, 1e-3);

        $this->metric = new Accuracy();

        srand(self::RANDOM_SEED);
    }

    /**
     * The estimator implements the expected interfaces.
     *
     * @test
     */
    public function build() : void
    {
        $this->assertInstanceOf(SVC::class, $this->estimator);
        $this->assertInstanceOf(Learner::class, $this->estimator);
        $this->assertInstanceOf(Estimator::class, $this->estimator);
    }

    /**
     * The estimator reports itself as a classifier.
     *
     * @test
     */
    public function type() : void
    {
        $this->assertEquals(EstimatorType::classifier(), $this->estimator->type());
    }

    /**
     * Only continuous features are supported.
     *
     * @test
     */
    public function compatibility() : void
    {
        $expected = [
            DataType::continuous(),
        ];

        $this->assertEquals($expected, $this->estimator->compatibility());
    }

    /**
     * params() reflects the constructor arguments plus the default cache size.
     *
     * @test
     */
    public function params() : void
    {
        $expected = [
            'c' => 1.0,
            'kernel' => new RBF(),
            'shrinking' => true,
            'tolerance' => 1e-3,
            'cache_size' => 100.0,
        ];

        $this->assertEquals($expected, $this->estimator->params());
    }

    /**
     * Training on standardized synthetic data reaches at least MIN_SCORE
     * accuracy on a held-out split.
     *
     * @test
     */
    public function trainPredict() : void
    {
        $dataset = $this->generator->generate(self::TRAIN_SIZE + self::TEST_SIZE);

        $dataset->apply(new ZScaleStandardizer());

        // take() splits off TEST_SIZE samples; the remainder is used for training.
        $testing = $dataset->randomize()->take(self::TEST_SIZE);

        $this->estimator->train($dataset);

        $this->assertTrue($this->estimator->trained());

        $predictions = $this->estimator->predict($testing);

        $score = $this->metric->score($predictions, $testing->labels());

        $this->assertGreaterThanOrEqual(self::MIN_SCORE, $score);
    }

    /**
     * Training on a categorical feature is rejected.
     *
     * @test
     */
    public function trainIncompatible() : void
    {
        $this->expectException(InvalidArgumentException::class);

        // A matching label keeps the dataset itself well-formed. Without it,
        // the expected exception could come from the sample/label count
        // mismatch instead of the estimator's compatibility check.
        $this->estimator->train(Labeled::quick([['bad']], ['green']));
    }

    /**
     * Predicting before training is an error.
     *
     * @test
     */
    public function predictUntrained() : void
    {
        $this->expectException(RuntimeException::class);

        $this->estimator->predict(Unlabeled::quick([[1.5]]));
    }

    /**
     * Every test starts from an untrained estimator.
     */
    protected function assertPreConditions() : void
    {
        $this->assertFalse($this->estimator->trained());
    }
}