<?php

namespace Rubix\ML\AnomalyDetectors;

use Rubix\ML\Learner;
use Rubix\ML\DataType;
use Rubix\ML\Estimator;
use Rubix\ML\EstimatorType;
use Rubix\ML\Kernels\SVM\RBF;
use Rubix\ML\Datasets\Dataset;
use Rubix\ML\Kernels\SVM\Kernel;
use Rubix\ML\Other\Helpers\Params;
use Rubix\ML\Specifications\ExtensionIsLoaded;
use Rubix\ML\Specifications\DatasetIsNotEmpty;
use Rubix\ML\Specifications\SpecificationChain;
use Rubix\ML\Specifications\SamplesAreCompatibleWithEstimator;
use Rubix\ML\Exceptions\InvalidArgumentException;
use Rubix\ML\Exceptions\RuntimeException;
use svmmodel;
use svm;

/**
 * One Class SVM
 *
 * An unsupervised Support Vector Machine (SVM) used for anomaly detection. The One
 * Class SVM aims to find a maximum margin between a set of data points and the
 * *origin*, rather than between classes such as with SVC.
 *
 * > **Note:** This estimator requires the SVM extension which uses the libsvm engine
 * under the hood.
 *
 * References:
 * [1] C. Chang et al. (2011). LIBSVM: A library for support vector machines.
 *
 * @category    Machine Learning
 * @package     Rubix/ML
 * @author      Andrew DalPino
 */
class OneClassSVM implements Estimator, Learner
{
    /**
     * The support vector machine instance, configured once in the constructor.
     *
     * @var \svm
     */
    protected $svm;

    /**
     * The hyper-parameters of the model, keyed by name.
     *
     * @var mixed[]
     */
    protected $params;

    /**
     * The trained model instance, or null until train() or load() has been called.
     *
     * @var \svmmodel|null
     */
    protected $model;

    /**
     * Validate the hyper-parameters and configure the underlying svm instance.
     *
     * @param float $nu must lie within [0, 1]
     * @param \Rubix\ML\Kernels\SVM\Kernel|null $kernel defaults to RBF when null
     * @param bool $shrinking value passed through as svm::OPT_SHRINKING
     * @param float $tolerance value passed through as svm::OPT_EPS, must be greater than 0
     * @param float $cacheSize value passed through as svm::OPT_CACHE_SIZE, must be greater than 0
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
     */
    public function __construct(
        float $nu = 0.5,
        ?Kernel $kernel = null,
        bool $shrinking = true,
        float $tolerance = 1e-3,
        float $cacheSize = 100.0
    ) {
        ExtensionIsLoaded::with('svm')->check();

        // NOTE(review): a nu of exactly 0 is accepted here; confirm that the
        // underlying engine accepts it for one-class training.
        if ($nu < 0.0 || $nu > 1.0) {
            throw new InvalidArgumentException('Nu must be between'
                . " 0 and 1, $nu given.");
        }

        $kernel = $kernel ?? new RBF();

        // Zero is rejected as well so that the guard agrees with its own
        // message and with the cache size guard below.
        if ($tolerance <= 0.0) {
            throw new InvalidArgumentException('Tolerance must be'
                . " greater than 0, $tolerance given.");
        }

        if ($cacheSize <= 0.0) {
            throw new InvalidArgumentException('Cache size must be'
                . " greater than 0M, {$cacheSize}M given.");
        }

        $options = [
            svm::OPT_TYPE => svm::ONE_CLASS,
            svm::OPT_NU => $nu,
            svm::OPT_SHRINKING => $shrinking,
            svm::OPT_EPS => $tolerance,
            svm::OPT_CACHE_SIZE => $cacheSize,
        ];

        // Array union keeps the keys already present, so the options above
        // win over any kernel-supplied option that uses the same key.
        $options += $kernel->options();

        $svm = new svm();

        $svm->setOptions($options);

        $this->svm = $svm;

        $this->params = [
            'nu' => $nu,
            'kernel' => $kernel,
            'shrinking' => $shrinking,
            'tolerance' => $tolerance,
            'cache_size' => $cacheSize,
        ];
    }

    /**
     * Return the estimator type.
     *
     * @internal
     *
     * @return \Rubix\ML\EstimatorType
     */
    public function type() : EstimatorType
    {
        return EstimatorType::anomalyDetector();
    }

    /**
     * Return the data types that the estimator is compatible with.
     *
     * @internal
     *
     * @return list<\Rubix\ML\DataType>
     */
    public function compatibility() : array
    {
        return [
            DataType::continuous(),
        ];
    }

    /**
     * Return the settings of the hyper-parameters in an associative array.
     *
     * @internal
     *
     * @return mixed[]
     */
    public function params() : array
    {
        return $this->params;
    }

    /**
     * Has the learner been trained? True once a model has been set by either
     * train() or load().
     *
     * @return bool
     */
    public function trained() : bool
    {
        return isset($this->model);
    }

    /**
     * Train the learner with a dataset. The dataset must pass the not-empty
     * and sample compatibility specifications before training starts.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     */
    public function train(Dataset $dataset) : void
    {
        SpecificationChain::with([
            new DatasetIsNotEmpty($dataset),
            new SamplesAreCompatibleWithEstimator($dataset, $this),
        ])->check();

        // NOTE(review): the svm extension documents array training rows as
        // starting with a class label - confirm whether the raw samples need
        // a placeholder label prepended before being handed over.
        $this->model = $this->svm->train($dataset->samples());
    }

    /**
     * Make predictions from a dataset by predicting each sample in turn.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @throws \Rubix\ML\Exceptions\RuntimeException
     * @return list<int>
     */
    public function predict(Dataset $dataset) : array
    {
        return array_map([$this, 'predictSample'], $dataset->samples());
    }

    /**
     * Predict a single sample and return the result.
     *
     * @internal
     *
     * @param list<int|float> $sample
     * @throws \Rubix\ML\Exceptions\RuntimeException
     * @return int 1 when the model outputs exactly 1.0 for the sample, 0 otherwise
     */
    public function predictSample(array $sample) : int
    {
        if (!$this->model) {
            throw new RuntimeException('Estimator has not been trained.');
        }

        // NOTE(review): confirm that mapping a model output of 1.0 to 1
        // matches the label convention the callers expect.
        return $this->model->predict($sample) !== 1.0 ? 0 : 1;
    }

    /**
     * Save the model data to the filesystem.
     *
     * @param string $path
     * @throws \Rubix\ML\Exceptions\RuntimeException
     */
    public function save(string $path) : void
    {
        if (!$this->model) {
            throw new RuntimeException('Learner must be trained before saving.');
        }

        $this->model->save($path);
    }

    /**
     * Load model data from the filesystem, replacing any model currently held.
     *
     * @param string $path
     */
    public function load(string $path) : void
    {
        $this->model = new svmmodel($path);
    }

    /**
     * Return the string representation of the object.
     *
     * @return string
     */
    public function __toString() : string
    {
        return 'One Class SVM (' . Params::stringify($this->params()) . ')';
    }
}