<?php

namespace Rubix\ML;

use Rubix\ML\Backends\Serial;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Datasets\Dataset;
use Rubix\ML\Backends\Tasks\Task;
use Rubix\ML\Other\Helpers\Params;
use Rubix\ML\CrossValidation\KFold;
use Rubix\ML\Other\Traits\LoggerAware;
use Rubix\ML\CrossValidation\Validator;
use Rubix\ML\Other\Traits\PredictsSingle;
use Rubix\ML\Other\Traits\Multiprocessing;
use Rubix\ML\CrossValidation\Metrics\RMSE;
use Rubix\ML\CrossValidation\Metrics\FBeta;
use Rubix\ML\CrossValidation\Metrics\Metric;
use Rubix\ML\Other\Traits\AutotrackRevisions;
use Rubix\ML\Specifications\DatasetIsLabeled;
use Rubix\ML\CrossValidation\Metrics\Accuracy;
use Rubix\ML\CrossValidation\Metrics\VMeasure;
use Rubix\ML\Specifications\DatasetIsNotEmpty;
use Rubix\ML\Specifications\SpecificationChain;
use Rubix\ML\Specifications\LabelsAreCompatibleWithLearner;
use Rubix\ML\Specifications\EstimatorIsCompatibleWithMetric;
use Rubix\ML\Specifications\SamplesAreCompatibleWithEstimator;
use Rubix\ML\Exceptions\InvalidArgumentException;

use function count;

/**
 * Grid Search
 *
 * Grid Search is an algorithm that optimizes hyper-parameter selection. From
 * the user's perspective, the process of training and predicting is the same,
 * however, under the hood, Grid Search trains one estimator per combination
 * of parameters and the best model is selected as the base estimator.
 *
 * > **Note:** You can choose the hyper-parameters manually or you can generate
 * them randomly or in a grid using the Params helper.
 *
 * @category    Machine Learning
 * @package     Rubix/ML
 * @author      Andrew DalPino
 */
class GridSearch implements Estimator, Learner, Parallel, Verbose, Wrapper, Persistable
{
    use AutotrackRevisions, Multiprocessing, PredictsSingle, LoggerAware;

    /**
     * The class name of the base estimator.
     *
     * @var string
     */
    protected $base;

    /**
     * An array of tuples containing the possible values for each of the base learner's
     * constructor parameters. Tuples are de-duplicated and empty tuples are stored as
     * `[null]` by the constructor.
     *
     * @var array[]
     */
    protected $params;

    /**
     * The validation metric used to score the estimator.
     *
     * @var \Rubix\ML\CrossValidation\Metrics\Metric
     */
    protected $metric;

    /**
     * The validator used to test the estimator.
     *
     * @var \Rubix\ML\CrossValidation\Validator
     */
    protected $validator;

    /**
     * The argument names for the base estimator's constructor.
     *
     * @var string[]
     */
    protected $args = [
        //
    ];

    /**
     * The results of the last hyper-parameter search as a list of [score, params]
     * tuples ordered from the highest score to the lowest, or null if no search
     * has been run yet.
     *
     * @var array[]|null
     */
    protected $results;

    /**
     * The instance of the estimator with the best parameters. Until `train()` has
     * completed this holds the proxy instance built in the constructor.
     *
     * @var \Rubix\ML\Learner
     */
    protected $estimator;

    /**
     * Cross validate a learner with a given dataset and return the score along with
     * the parameters of the learner under test.
     *
     * @internal
     *
     * @param \Rubix\ML\Learner $estimator
     * @param \Rubix\ML\Datasets\Labeled $dataset
     * @param \Rubix\ML\CrossValidation\Validator $validator
     * @param \Rubix\ML\CrossValidation\Metrics\Metric $metric
     * @return mixed[] A 2-tuple of [score, params].
     */
    public static function score(Learner $estimator, Labeled $dataset, Validator $validator, Metric $metric) : array
    {
        $score = $validator->test($estimator, $dataset, $metric);

        return [$score, $estimator->params()];
    }

    /**
     * @param class-string $base
     * @param array[] $params
     * @param \Rubix\ML\CrossValidation\Metrics\Metric|null $metric
     * @param \Rubix\ML\CrossValidation\Validator|null $validator
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
     */
    public function __construct(
        string $base,
        array $params,
        ?Metric $metric = null,
        ?Validator $validator = null
    ) {
        if (!class_exists($base)) {
            throw new InvalidArgumentException("Class $base does not exist.");
        }

        // Normalize the grid before instantiating anything so that the proxy is
        // built from the same values the search will use. An empty tuple stands
        // for a single null argument (current() on an empty array would yield
        // false instead), and duplicate values would only repeat work.
        foreach ($params as &$tuple) {
            $tuple = empty($tuple) ? [null] : array_unique($tuple, SORT_REGULAR);
        }

        // Break the reference left behind by the by-reference foreach.
        unset($tuple);

        // Instantiate a proxy from the first value of every tuple so that the
        // type of the base estimator can be inspected before training.
        $proxy = new $base(...array_map('current', $params));

        if (!$proxy instanceof Learner) {
            throw new InvalidArgumentException('Base class must'
                . ' implement the Learner Interface.');
        }

        if ($metric) {
            EstimatorIsCompatibleWithMetric::with($proxy, $metric)->check();
        } else {
            // Choose a default metric based on the type of the base estimator.
            // The cases are objects, so the loose comparison of switch is required.
            switch ($proxy->type()) {
                case EstimatorType::classifier():
                case EstimatorType::anomalyDetector():
                    $metric = new FBeta();

                    break;

                case EstimatorType::regressor():
                    $metric = new RMSE();

                    break;

                case EstimatorType::clusterer():
                    $metric = new VMeasure();

                    break;

                default:
                    $metric = new Accuracy();
            }
        }

        $this->base = $base;
        $this->params = $params;
        $this->metric = $metric;
        $this->validator = $validator ?? new KFold(3);
        $this->estimator = $proxy;
        $this->backend = new Serial();
    }

    /**
     * Return the estimator type. The type is delegated to the wrapped estimator.
     *
     * @internal
     *
     * @return \Rubix\ML\EstimatorType
     */
    public function type() : EstimatorType
    {
        return $this->estimator->type();
    }

    /**
     * Return the data types that the estimator is compatible with. Once trained
     * the answer is delegated to the wrapped estimator, before that every data
     * type is reported.
     *
     * @internal
     *
     * @return list<\Rubix\ML\DataType>
     */
    public function compatibility() : array
    {
        return $this->trained()
            ? $this->estimator->compatibility()
            : DataType::all();
    }

    /**
     * Return the settings of the hyper-parameters in an associative array.
     *
     * @internal
     *
     * @return mixed[]
     */
    public function params() : array
    {
        return [
            'base' => $this->base,
            'params' => $this->params,
            'metric' => $this->metric,
            'validator' => $this->validator,
        ];
    }

    /**
     * Has the learner been trained?
     *
     * @return bool
     */
    public function trained() : bool
    {
        return $this->estimator->trained();
    }

    /**
     * Return an array containing the validation scores and hyper-parameters under test
     * for each combination in a 2-tuple, ordered from the highest score to the lowest.
     *
     * @return array[]|null
     */
    public function results() : ?array
    {
        return $this->results;
    }

    /**
     * Return an array containing the hyper-parameters with the highest validation score
     * from the last search, or null if no search has been run.
     *
     * @return mixed[]|null
     */
    public function best() : ?array
    {
        return $this->results ? $this->results[0][1] : null;
    }

    /**
     * Return the base estimator instance.
     *
     * @return \Rubix\ML\Estimator
     */
    public function base() : Estimator
    {
        return $this->estimator;
    }

    /**
     * Train one estimator per combination of parameters given by the grid and
     * assign the best one as the base estimator of this instance.
     *
     * @param \Rubix\ML\Datasets\Labeled $dataset
     */
    public function train(Dataset $dataset) : void
    {
        SpecificationChain::with([
            new DatasetIsLabeled($dataset),
            new DatasetIsNotEmpty($dataset),
            new SamplesAreCompatibleWithEstimator($dataset, $this),
            new LabelsAreCompatibleWithLearner($dataset, $this),
        ])->check();

        $combinations = $this->combinations();

        if ($this->logger) {
            $this->logger->info("$this initialized");

            $this->logger->info('Searching ' . count($combinations)
                . ' combinations of hyper-parameters');
        }

        $this->backend->flush();

        // Queue one cross validation task per combination of hyper-parameters.
        foreach ($combinations as $params) {
            $estimator = new $this->base(...$params);

            $this->backend->enqueue(
                new Task(
                    [self::class, 'score'],
                    [
                        $estimator,
                        $dataset,
                        $this->validator,
                        $this->metric,
                    ]
                ),
                [$this, 'afterScore']
            );
        }

        [$scores, $combinations] = array_transpose($this->backend->process());

        // An order flag applies to the array argument that precedes it, so
        // SORT_DESC must directly follow $scores for the results to be ranked
        // from the highest score to the lowest. The combinations are reordered
        // alongside the scores.
        array_multisort($scores, SORT_DESC, $combinations);

        $this->results = array_transpose([$scores, $combinations]);

        $best = reset($combinations);

        // The parameters come back keyed by name from params(), the constructor
        // takes them positionally.
        $estimator = new $this->base(...array_values($best));

        if ($this->logger) {
            $this->logger->info('Training with best hyper-parameters');
        }

        $estimator->train($dataset);

        $this->estimator = $estimator;

        if ($this->logger) {
            $this->logger->info('Training complete');
        }
    }

    /**
     * Make a prediction on a given sample dataset using the wrapped estimator.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @throws \Rubix\ML\Exceptions\RuntimeException
     * @return mixed[]
     */
    public function predict(Dataset $dataset) : array
    {
        return $this->estimator->predict($dataset);
    }

    /**
     * The callback that executes after the scoring task. Logs the score and the
     * parameters under test when a logger is set.
     *
     * @internal
     *
     * @param mixed[] $result A 2-tuple of [score, params] as returned by score().
     */
    public function afterScore(array $result) : void
    {
        if ($this->logger) {
            [$score, $params] = $result;

            $this->logger->info(
                "{$this->metric}: $score, params: [" . Params::stringify($params) . ']'
            );
        }
    }

    /**
     * Return an array of all possible combinations of parameters. i.e the Cartesian product of
     * the user-supplied parameter array.
     *
     * @return array[]
     */
    protected function combinations() : array
    {
        // Start from a single empty combination and extend every partial
        // combination with each candidate value of the next parameter.
        $combinations = [[]];

        foreach ($this->params as $i => $params) {
            $append = [];

            foreach ($combinations as $product) {
                foreach ($params as $param) {
                    $product[$i] = $param;
                    $append[] = $product;
                }
            }

            $combinations = $append;
        }

        return $combinations;
    }

    /**
     * Allow methods to be called on the estimator from the wrapper.
     *
     * @param string $name
     * @param mixed[] $arguments
     * @return mixed
     */
    public function __call(string $name, array $arguments)
    {
        return $this->estimator->$name(...$arguments);
    }

    /**
     * Return the string representation of the object.
     *
     * @return string
     */
    public function __toString() : string
    {
        return 'Grid Search (' . Params::stringify($this->params()) . ')';
    }
}