<?php

namespace Rubix\ML;

use Rubix\ML\Datasets\Dataset;
use Rubix\ML\Persisters\Persister;
use Rubix\ML\Other\Helpers\Params;
use Rubix\ML\Other\Traits\ProbaSingle;
use Rubix\ML\AnomalyDetectors\Scoring;
use Rubix\ML\Other\Traits\RanksSingle;
use Rubix\ML\Other\Traits\PredictsSingle;
use Rubix\ML\Exceptions\InvalidArgumentException;
use Rubix\ML\Exceptions\RuntimeException;

/**
 * Persistent Model
 *
 * The Persistent Model wrapper gives the estimator two additional methods (`save()`
 * and `load()`) that allow the estimator to be saved and retrieved from storage.
 *
 * @category    Machine Learning
 * @package     Rubix/ML
 * @author      Andrew DalPino
 */
class PersistentModel implements Estimator, Learner, Wrapper, Probabilistic, Scoring, Ranking
{
    use PredictsSingle, ProbaSingle, RanksSingle;

    /**
     * The persistable base learner.
     *
     * @var \Rubix\ML\Learner
     */
    protected $base;

    /**
     * The persister used to interface with the storage medium.
     *
     * @var \Rubix\ML\Persisters\Persister
     */
    protected $persister;

    /**
     * Factory method to restore the model from persistence.
     *
     * The object returned by the persister must be a Learner; the constructor
     * additionally requires it to be Persistable.
     *
     * @param \Rubix\ML\Persisters\Persister $persister
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
     * @return self
     */
    public static function load(Persister $persister) : self
    {
        $base = $persister->load();

        if (!$base instanceof Learner) {
            throw new InvalidArgumentException('Persistable must'
                . ' implement the Learner interface.');
        }

        return new self($base, $persister);
    }

    /**
     * @param \Rubix\ML\Learner $base
     * @param \Rubix\ML\Persisters\Persister $persister
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
     */
    public function __construct(Learner $base, Persister $persister)
    {
        if (!$base instanceof Persistable) {
            throw new InvalidArgumentException('Base Learner must'
                . ' implement the Persistable interface.');
        }

        $this->base = $base;
        $this->persister = $persister;
    }

    /**
     * Return the estimator type.
     *
     * @internal
     *
     * @return \Rubix\ML\EstimatorType
     */
    public function type() : EstimatorType
    {
        return $this->base->type();
    }

    /**
     * Return the data types that the estimator is compatible with.
     *
     * @internal
     *
     * @return list<\Rubix\ML\DataType>
     */
    public function compatibility() : array
    {
        return $this->base->compatibility();
    }

    /**
     * Return the settings of the hyper-parameters in an associative array.
     *
     * @internal
     *
     * @return mixed[]
     */
    public function params() : array
    {
        return [
            'base' => $this->base,
            'persister' => $this->persister,
        ];
    }

    /**
     * Has the learner been trained?
     *
     * @return bool
     */
    public function trained() : bool
    {
        return $this->base->trained();
    }

    /**
     * Return the base estimator instance.
     *
     * @return \Rubix\ML\Estimator
     */
    public function base() : Estimator
    {
        return $this->base;
    }

    /**
     * Save the model to storage using the configured persister.
     */
    public function save() : void
    {
        // The constructor already guarantees Persistable; this check narrows the
        // type of the Learner-typed property for the persister's signature.
        if ($this->base instanceof Persistable) {
            $this->persister->save($this->base);
        }
    }

    /**
     * Train the learner with a dataset.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     */
    public function train(Dataset $dataset) : void
    {
        $this->base->train($dataset);
    }

    /**
     * Make a prediction on a given sample dataset.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @return mixed[]
     */
    public function predict(Dataset $dataset) : array
    {
        return $this->base->predict($dataset);
    }

    /**
     * Estimate the joint probabilities for each possible outcome.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @throws \Rubix\ML\Exceptions\RuntimeException if the base estimator is not Probabilistic
     * @return array[]
     */
    public function proba(Dataset $dataset) : array
    {
        if (!$this->base instanceof Probabilistic) {
            throw new RuntimeException('Base Estimator must'
                . ' implement the Probabilistic interface.');
        }

        return $this->base->proba($dataset);
    }

    /**
     * Return the anomaly scores assigned to the samples in a dataset.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @throws \Rubix\ML\Exceptions\RuntimeException if the base estimator is not Scoring
     * @return float[]
     */
    public function score(Dataset $dataset) : array
    {
        if (!$this->base instanceof Scoring) {
            // Message names the interface actually checked (previously said "Ranking").
            throw new RuntimeException('Base Estimator must'
                . ' implement the Scoring interface.');
        }

        return $this->base->score($dataset);
    }

    /**
     * Return the anomaly scores assigned to the samples in a dataset.
     *
     * @deprecated
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @throws \Rubix\ML\Exceptions\RuntimeException if the base estimator is not Scoring
     * @return float[]
     */
    public function rank(Dataset $dataset) : array
    {
        warn_deprecated('Rank() is deprecated, use score() instead.');

        // Delegates to score() so both entry points share the same interface check.
        return $this->score($dataset);
    }

    /**
     * Allow methods to be called on the model from the wrapper.
     *
     * Any method not defined on the wrapper is forwarded verbatim to the base learner.
     *
     * @param string $name
     * @param mixed[] $arguments
     * @return mixed
     */
    public function __call(string $name, array $arguments)
    {
        return $this->base->$name(...$arguments);
    }

    /**
     * Return the string representation of the object.
     *
     * @return string
     */
    public function __toString() : string
    {
        return 'Persistent Model (' . Params::stringify($this->params()) . ')';
    }
}