1<?php
2
3namespace Rubix\ML;
4
5use Rubix\ML\Datasets\Dataset;
6use Rubix\ML\Persisters\Persister;
7use Rubix\ML\Other\Helpers\Params;
8use Rubix\ML\Other\Traits\ProbaSingle;
9use Rubix\ML\AnomalyDetectors\Scoring;
10use Rubix\ML\Other\Traits\RanksSingle;
11use Rubix\ML\Other\Traits\PredictsSingle;
12use Rubix\ML\Exceptions\InvalidArgumentException;
13use Rubix\ML\Exceptions\RuntimeException;
14
/**
 * Persistent Model
 *
 * The Persistent Model wrapper gives the estimator two additional methods (`save()`
 * and `load()`) that allow the estimator to be saved and retrieved from storage.
 *
 * Estimator operations (type, compatibility, training, prediction, probability
 * estimation and scoring) are delegated to the wrapped base learner. Any other
 * method call not defined on the wrapper is forwarded to the base learner via
 * `__call()`.
 *
 * @category    Machine Learning
 * @package     Rubix/ML
 * @author      Andrew DalPino
 */
class PersistentModel implements Estimator, Learner, Wrapper, Probabilistic, Scoring, Ranking
{
    use PredictsSingle, ProbaSingle, RanksSingle;

    /**
     * The persistable base learner.
     *
     * @var \Rubix\ML\Learner
     */
    protected $base;

    /**
     * The persister used to interface with the storage medium.
     *
     * @var \Rubix\ML\Persisters\Persister
     */
    protected $persister;

    /**
     * Factory method to restore the model from persistence.
     *
     * The object returned by the persister must be a Learner; the constructor
     * additionally requires it to be Persistable.
     *
     * @param \Rubix\ML\Persisters\Persister $persister
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
     * @return self
     */
    public static function load(Persister $persister) : self
    {
        $base = $persister->load();

        if (!$base instanceof Learner) {
            throw new InvalidArgumentException('Persistable must'
                . ' implement the Learner interface.');
        }

        return new self($base, $persister);
    }

    /**
     * @param \Rubix\ML\Learner $base
     * @param \Rubix\ML\Persisters\Persister $persister
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
     */
    public function __construct(Learner $base, Persister $persister)
    {
        if (!$base instanceof Persistable) {
            throw new InvalidArgumentException('Base Learner must'
                . ' implement the Persistable interface.');
        }

        $this->base = $base;
        $this->persister = $persister;
    }

    /**
     * Return the estimator type.
     *
     * @internal
     *
     * @return \Rubix\ML\EstimatorType
     */
    public function type() : EstimatorType
    {
        return $this->base->type();
    }

    /**
     * Return the data types that the estimator is compatible with.
     *
     * @internal
     *
     * @return list<\Rubix\ML\DataType>
     */
    public function compatibility() : array
    {
        return $this->base->compatibility();
    }

    /**
     * Return the settings of the hyper-parameters in an associative array.
     *
     * @internal
     *
     * @return mixed[]
     */
    public function params() : array
    {
        return [
            'base' => $this->base,
            'persister' => $this->persister,
        ];
    }

    /**
     * Has the learner been trained?
     *
     * @return bool
     */
    public function trained() : bool
    {
        return $this->base->trained();
    }

    /**
     * Return the base estimator instance.
     *
     * @return \Rubix\ML\Estimator
     */
    public function base() : Estimator
    {
        return $this->base;
    }

    /**
     * Save the model to storage using the configured persister.
     */
    public function save() : void
    {
        // The constructor already guarantees the base is Persistable; this
        // check narrows the type of the protected property for the persister call.
        if ($this->base instanceof Persistable) {
            $this->persister->save($this->base);
        }
    }

    /**
     * Train the learner with a dataset.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     */
    public function train(Dataset $dataset) : void
    {
        $this->base->train($dataset);
    }

    /**
     * Make a prediction on a given sample dataset.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @return mixed[]
     */
    public function predict(Dataset $dataset) : array
    {
        return $this->base->predict($dataset);
    }

    /**
     * Estimate the joint probabilities for each possible outcome.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @throws \Rubix\ML\Exceptions\RuntimeException if the base is not Probabilistic
     * @return array[]
     */
    public function proba(Dataset $dataset) : array
    {
        if (!$this->base instanceof Probabilistic) {
            throw new RuntimeException('Base Estimator must'
                . ' implement the Probabilistic interface.');
        }

        return $this->base->proba($dataset);
    }

    /**
     * Return the anomaly scores assigned to the samples in a dataset.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @throws \Rubix\ML\Exceptions\RuntimeException if the base is not Scoring
     * @return float[]
     */
    public function score(Dataset $dataset) : array
    {
        // The message must name the interface actually being checked (Scoring).
        if (!$this->base instanceof Scoring) {
            throw new RuntimeException('Base Estimator must'
                . ' implement the Scoring interface.');
        }

        return $this->base->score($dataset);
    }

    /**
     * Return the anomaly scores assigned to the samples in a dataset.
     *
     * Deprecated alias that emits a deprecation warning and delegates to score().
     *
     * @deprecated
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @return float[]
     */
    public function rank(Dataset $dataset) : array
    {
        warn_deprecated('Rank() is deprecated, use score() instead.');

        return $this->score($dataset);
    }

    /**
     * Allow methods to be called on the model from the wrapper.
     *
     * Any method not defined on the wrapper is invoked on the base learner with
     * the same arguments, and its return value is passed through unchanged.
     *
     * @param string $name
     * @param mixed[] $arguments
     * @return mixed
     */
    public function __call(string $name, array $arguments)
    {
        return $this->base->$name(...$arguments);
    }

    /**
     * Return the string representation of the object.
     *
     * @return string
     */
    public function __toString() : string
    {
        return 'Persistent Model (' . Params::stringify($this->params()) . ')';
    }
}
238