1<?php
2
3namespace Rubix\ML\AnomalyDetectors;
4
5use Rubix\ML\Learner;
6use Rubix\ML\DataType;
7use Rubix\ML\Estimator;
8use Rubix\ML\EstimatorType;
9use Rubix\ML\Kernels\SVM\RBF;
10use Rubix\ML\Datasets\Dataset;
11use Rubix\ML\Kernels\SVM\Kernel;
12use Rubix\ML\Other\Helpers\Params;
13use Rubix\ML\Specifications\ExtensionIsLoaded;
14use Rubix\ML\Specifications\DatasetIsNotEmpty;
15use Rubix\ML\Specifications\SpecificationChain;
16use Rubix\ML\Specifications\SamplesAreCompatibleWithEstimator;
17use Rubix\ML\Exceptions\InvalidArgumentException;
18use Rubix\ML\Exceptions\RuntimeException;
19use svmmodel;
20use svm;
21
/**
 * One Class SVM
 *
 * An unsupervised Support Vector Machine (SVM) used for anomaly detection. The One
 * Class SVM aims to find a maximum margin between a set of data points and the
 * *origin*, rather than between classes such as with SVC. Predictions are 1 for a
 * sample judged to be an anomaly and 0 for an inlier.
 *
 * > **Note:** This estimator requires the SVM extension which uses the libsvm engine
 * under the hood.
 *
 * References:
 * [1] C. Chang et al. (2011). LIBSVM: A library for support vector machines.
 *
 * @category    Machine Learning
 * @package     Rubix/ML
 * @author      Andrew DalPino
 */
class OneClassSVM implements Estimator, Learner
{
    /**
     * The support vector machine instance.
     *
     * @var \svm
     */
    protected $svm;

    /**
     * The hyper-parameters of the model.
     *
     * @var mixed[]
     */
    protected $params;

    /**
     * The trained model instance.
     *
     * @var \svmmodel|null
     */
    protected $model;

    /**
     * @param float $nu An upper bound on the fraction of margin errors and a lower bound on the fraction of support vectors, in [0, 1].
     * @param \Rubix\ML\Kernels\SVM\Kernel|null $kernel The kernel function, an RBF kernel when null.
     * @param bool $shrinking Should the shrinking heuristic be used during optimization?
     * @param float $tolerance The stopping criterion tolerance (libsvm epsilon), must be greater than 0.
     * @param float $cacheSize The size of the kernel cache in megabytes, must be greater than 0.
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
     */
    public function __construct(
        float $nu = 0.5,
        ?Kernel $kernel = null,
        bool $shrinking = true,
        float $tolerance = 1e-3,
        float $cacheSize = 100.0
    ) {
        ExtensionIsLoaded::with('svm')->check();

        if ($nu < 0.0 or $nu > 1.0) {
            throw new InvalidArgumentException('Nu must be between'
                . " 0 and 1, $nu given.");
        }

        $kernel = $kernel ?? new RBF();

        // A tolerance of exactly 0 is rejected as well, matching the error message
        // (libsvm refuses a non-positive epsilon at training time anyway).
        if ($tolerance <= 0.0) {
            throw new InvalidArgumentException('Tolerance must be'
                . " greater than 0, $tolerance given.");
        }

        if ($cacheSize <= 0.0) {
            throw new InvalidArgumentException('Cache size must be'
                . " greater than 0M, {$cacheSize}M given.");
        }

        $options = [
            svm::OPT_TYPE => svm::ONE_CLASS,
            svm::OPT_NU => $nu,
            svm::OPT_SHRINKING => $shrinking,
            svm::OPT_EPS => $tolerance,
            svm::OPT_CACHE_SIZE => $cacheSize,
        ];

        // Kernel options are added with +, so the keys set above take precedence.
        $options += $kernel->options();

        $svm = new svm();

        $svm->setOptions($options);

        $this->svm = $svm;

        $this->params = [
            'nu' => $nu,
            'kernel' => $kernel,
            'shrinking' => $shrinking,
            'tolerance' => $tolerance,
            'cache_size' => $cacheSize,
        ];
    }

    /**
     * Return the estimator type.
     *
     * @internal
     *
     * @return \Rubix\ML\EstimatorType
     */
    public function type() : EstimatorType
    {
        return EstimatorType::anomalyDetector();
    }

    /**
     * Return the data types that the estimator is compatible with.
     *
     * @internal
     *
     * @return list<\Rubix\ML\DataType>
     */
    public function compatibility() : array
    {
        return [
            DataType::continuous(),
        ];
    }

    /**
     * Return the settings of the hyper-parameters in an associative array.
     *
     * @internal
     *
     * @return mixed[]
     */
    public function params() : array
    {
        return $this->params;
    }

    /**
     * Has the learner been trained?
     *
     * @return bool
     */
    public function trained() : bool
    {
        return isset($this->model);
    }

    /**
     * Train the learner with a dataset.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     */
    public function train(Dataset $dataset) : void
    {
        SpecificationChain::with([
            new DatasetIsNotEmpty($dataset),
            new SamplesAreCompatibleWithEstimator($dataset, $this),
        ])->check();

        $data = [];

        foreach ($dataset->samples() as $sample) {
            // The svm extension reads the first element of every training row as the
            // class label and the rest as index => value features. A one-class model
            // does not learn from labels, so prepend a constant placeholder instead of
            // sacrificing the first feature. array_merge() renumbers the row, which
            // places the features at indices 1..n.
            $data[] = array_merge([1], array_values($sample));
        }

        $this->model = $this->svm->train($data);
    }

    /**
     * Make predictions from a dataset.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @throws \Rubix\ML\Exceptions\RuntimeException
     * @return list<int>
     */
    public function predict(Dataset $dataset) : array
    {
        if (!$this->model) {
            throw new RuntimeException('Estimator has not been trained.');
        }

        return array_map([$this, 'predictSample'], $dataset->samples());
    }

    /**
     * Predict a single sample and return the result: 1 for an anomaly, 0 for an inlier.
     *
     * @internal
     *
     * @param list<int|float> $sample
     * @throws \Rubix\ML\Exceptions\RuntimeException
     * @return int
     */
    public function predictSample(array $sample) : int
    {
        if (!$this->model) {
            throw new RuntimeException('Estimator has not been trained.');
        }

        $features = [];

        // Shift every feature index by one so the layout matches the rows used in
        // train(), where index 0 is occupied by the placeholder label.
        foreach (array_values($sample) as $offset => $value) {
            $features[$offset + 1] = $value;
        }

        // libsvm outputs +1 for a one-class inlier and -1 for an outlier, whereas
        // anomaly detectors report 0 for an inlier and 1 for an anomaly.
        return $this->model->predict($features) > 0.0 ? 0 : 1;
    }

    /**
     * Save the model data to the filesystem.
     *
     * @param string $path
     * @throws \Rubix\ML\Exceptions\RuntimeException
     */
    public function save(string $path) : void
    {
        if (!$this->model) {
            throw new RuntimeException('Learner must be trained before saving.');
        }

        $this->model->save($path);
    }

    /**
     * Load model data from the filesystem.
     *
     * @param string $path
     */
    public function load(string $path) : void
    {
        $this->model = new svmmodel($path);
    }

    /**
     * Return the string representation of the object.
     *
     * @return string
     */
    public function __toString() : string
    {
        return 'One Class SVM (' . Params::stringify($this->params()) . ')';
    }
}
248