1<?php
2
3namespace Rubix\ML\Tests\Classifiers;
4
5use Rubix\ML\Learner;
6use Rubix\ML\DataType;
7use Rubix\ML\Estimator;
8use Rubix\ML\EstimatorType;
9use Rubix\ML\Classifiers\SVC;
10use Rubix\ML\Kernels\SVM\RBF;
11use Rubix\ML\Datasets\Labeled;
12use Rubix\ML\Datasets\Unlabeled;
13use Rubix\ML\Datasets\Generators\Blob;
14use Rubix\ML\Transformers\ZScaleStandardizer;
15use Rubix\ML\Datasets\Generators\Agglomerate;
16use Rubix\ML\CrossValidation\Metrics\Accuracy;
17use Rubix\ML\Exceptions\InvalidArgumentException;
18use Rubix\ML\Exceptions\RuntimeException;
19use PHPUnit\Framework\TestCase;
20
21/**
22 * @group Classifiers
23 * @requires extension svm
24 * @covers \Rubix\ML\Classifiers\SVC
25 */
26class SVCTest extends TestCase
27{
28    /**
29     * The number of samples in the training set.
30     *
31     * @var int
32     */
33    protected const TRAIN_SIZE = 200;
34
35    /**
36     * The number of samples in the validation set.
37     *
38     * @var int
39     */
40    protected const TEST_SIZE = 20;
41
42    /**
43     * The minimum validation score required to pass the test.
44     *
45     * @var float
46     */
47    protected const MIN_SCORE = 0.9;
48
49    /**
50     * Constant used to see the random number generator.
51     *
52     * @var int
53     */
54    protected const RANDOM_SEED = 0;
55
56    /**
57     * @var \Rubix\ML\Datasets\Generators\Agglomerate
58     */
59    protected $generator;
60
61    /**
62     * @var \Rubix\ML\Classifiers\SVC
63     */
64    protected $estimator;
65
66    /**
67     * @var \Rubix\ML\CrossValidation\Metrics\Accuracy
68     */
69    protected $metric;
70
71    /**
72     * @before
73     */
74    protected function setUp() : void
75    {
76        $this->generator = new Agglomerate([
77            'male' => new Blob([69.2, 195.7, 40.0], [1.0, 3.0, 0.3]),
78            'female' => new Blob([63.7, 168.5, 38.1], [0.8, 2.5, 0.4]),
79        ], [0.45, 0.55]);
80
81        $this->estimator = new SVC(1.0, new RBF(), true, 1e-3);
82
83        $this->metric = new Accuracy();
84
85        srand(self::RANDOM_SEED);
86    }
87
88    /**
89     * @test
90     */
91    public function build() : void
92    {
93        $this->assertInstanceOf(SVC::class, $this->estimator);
94        $this->assertInstanceOf(Learner::class, $this->estimator);
95        $this->assertInstanceOf(Estimator::class, $this->estimator);
96    }
97
98    /**
99     * @test
100     */
101    public function type() : void
102    {
103        $this->assertEquals(EstimatorType::classifier(), $this->estimator->type());
104    }
105
106    /**
107     * @test
108     */
109    public function compatibility() : void
110    {
111        $expected = [
112            DataType::continuous(),
113        ];
114
115        $this->assertEquals($expected, $this->estimator->compatibility());
116    }
117
118    /**
119     * @test
120     */
121    public function params() : void
122    {
123        $expected = [
124            'c' => 1.0,
125            'kernel' => new RBF(),
126            'shrinking' => true,
127            'tolerance' => 1e-3,
128            'cache_size' => 100.0,
129        ];
130
131        $this->assertEquals($expected, $this->estimator->params());
132    }
133
134    /**
135     * @test
136     */
137    public function trainPredict() : void
138    {
139        $dataset = $this->generator->generate(self::TRAIN_SIZE + self::TEST_SIZE);
140
141        $dataset->apply(new ZScaleStandardizer());
142
143        $testing = $dataset->randomize()->take(self::TEST_SIZE);
144
145        $this->estimator->train($dataset);
146
147        $this->assertTrue($this->estimator->trained());
148
149        $predictions = $this->estimator->predict($testing);
150
151        $score = $this->metric->score($predictions, $testing->labels());
152
153        $this->assertGreaterThanOrEqual(self::MIN_SCORE, $score);
154    }
155
156    /**
157     * @test
158     */
159    public function trainIncompatible() : void
160    {
161        $this->expectException(InvalidArgumentException::class);
162
163        $this->estimator->train(Labeled::quick([['bad']]));
164    }
165
166    /**
167     * @test
168     */
169    public function predictUntrained() : void
170    {
171        $this->expectException(RuntimeException::class);
172
173        $this->estimator->predict(Unlabeled::quick([[1.5]]));
174    }
175
176    protected function assertPreConditions() : void
177    {
178        $this->assertFalse($this->estimator->trained());
179    }
180}
181