1<?php
2
3namespace Rubix\ML\Transformers;
4
5use Rubix\ML\DataType;
6use Rubix\ML\Persistable;
7use Rubix\ML\Datasets\Dataset;
8use Rubix\ML\Datasets\Labeled;
9use Rubix\ML\Graph\Trees\Spatial;
10use Rubix\ML\Other\Helpers\Stats;
11use Rubix\ML\Graph\Trees\BallTree;
12use Rubix\ML\Kernels\Distance\NaNSafe;
13use Rubix\ML\Kernels\Distance\Distance;
14use Rubix\ML\Other\Traits\AutotrackRevisions;
15use Rubix\ML\Kernels\Distance\SafeEuclidean;
16use Rubix\ML\Specifications\SamplesAreCompatibleWithTransformer;
17use Rubix\ML\Exceptions\InvalidArgumentException;
18use Rubix\ML\Exceptions\RuntimeException;
19
20use function Rubix\ML\argmax;
21use function in_array;
22use function is_null;
23
/**
 * KNN Imputer
 *
 * An unsupervised imputer that replaces missing values in datasets with the distance-weighted
 * average of the samples' *k* nearest neighbors' values. The average for a continuous feature
 * column is defined as the mean of the values of each donor sample, while the average for a
 * categorical feature column is defined as the most frequent category among the donor samples.
 *
 * **Note:** Requires NaN-safe distance kernels, such as Safe Euclidean, for continuous features.
 *
 * References:
 * [1] O. Troyanskaya et al. (2001). Missing value estimation methods for DNA microarrays.
 *
 * @category    Machine Learning
 * @package     Rubix/ML
 * @author      Andrew DalPino
 */
class KNNImputer implements Transformer, Stateful, Persistable
{
    use AutotrackRevisions;

    /**
     * The number of neighbors to consider when imputing a value.
     *
     * @var int
     */
    protected $k;

    /**
     * Should we use the inverse distances as confidence scores when imputing values.
     *
     * @var bool
     */
    protected $weighted;

    /**
     * The placeholder category that denotes missing values.
     *
     * @var string
     */
    protected $categoricalPlaceholder;

    /**
     * The spatial tree used to run nearest neighbor searches.
     *
     * @var \Rubix\ML\Graph\Trees\Spatial
     */
    protected $tree;

    /**
     * The data types of the fitted feature columns or null if the transformer
     * has not been fitted yet.
     *
     * @var \Rubix\ML\DataType[]|null
     */
    protected $types;

    /**
     * @param int $k The number of nearest donors used to impute a value, must be at least 1.
     * @param bool $weighted Whether donors are weighted by 1 / (1 + distance) when averaging.
     * @param string $categoricalPlaceholder The category that marks a missing categorical value.
     * @param \Rubix\ML\Graph\Trees\Spatial|null $tree The spatial tree used for the neighbor
     *     search. Defaults to a Ball Tree with the Safe Euclidean kernel when omitted.
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException If k is less than 1 or if the tree's
     *     kernel handles continuous features without implementing the NaNSafe interface.
     */
    public function __construct(
        int $k = 5,
        bool $weighted = true,
        string $categoricalPlaceholder = '?',
        ?Spatial $tree = null
    ) {
        if ($k < 1) {
            throw new InvalidArgumentException('At least 1 neighbor is required'
                . " to impute a value, $k given.");
        }

        // NOTE(review): the non-strict in_array() is intentional here - DataType objects are
        // compared by value, presumably because DataType::continuous() may return a new
        // instance on every call. Confirm before switching to strict comparison.
        if ($tree and in_array(DataType::continuous(), $tree->kernel()->compatibility())) {
            $kernel = $tree->kernel();

            // Continuous samples handed to the tree at transform time still contain NaNs.
            if (!$kernel instanceof NaNSafe) {
                throw new InvalidArgumentException('Continuous distance kernels'
                    . ' must implement the NaNSafe interface.');
            }
        }

        $this->k = $k;
        $this->weighted = $weighted;
        $this->categoricalPlaceholder = $categoricalPlaceholder;
        $this->tree = $tree ?? new BallTree(30, new SafeEuclidean());
    }

    /**
     * Return the data types that this transformer is compatible with. These are
     * the data types that the distance kernel of the underlying tree supports.
     *
     * @internal
     *
     * @return list<\Rubix\ML\DataType>
     */
    public function compatibility() : array
    {
        return $this->tree->kernel()->compatibility();
    }

    /**
     * Is the transformer fitted?
     *
     * Uses the same condition as transform() so that a tree that was grown outside
     * of this transformer is not mistaken for a fitted imputer.
     *
     * @return bool
     */
    public function fitted() : bool
    {
        return !$this->tree->bare() && isset($this->types);
    }

    /**
     * Fit the transformer to a dataset.
     *
     * Only complete samples i.e. samples without any missing values are used as
     * donors and inserted into the spatial tree.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @throws \Rubix\ML\Exceptions\RuntimeException If the dataset contains no complete samples.
     */
    public function fit(Dataset $dataset) : void
    {
        SamplesAreCompatibleWithTransformer::with($dataset, $this)->check();

        $donors = [];

        foreach ($dataset->samples() as $sample) {
            foreach ($sample as $value) {
                if ($this->isMissing($value)) {
                    // Skip the whole sample, it cannot act as a donor.
                    continue 2;
                }
            }

            $donors[] = $sample;
        }

        if (empty($donors)) {
            throw new RuntimeException('No complete donors found in dataset.');
        }

        // The tree is grown from a labeled dataset so the donors get dummy labels.
        $labels = array_fill(0, count($donors), '');

        $this->tree->grow(Labeled::quick($donors, $labels));

        $this->types = $dataset->columnTypes();
    }

    /**
     * Transform the dataset in place.
     *
     * Every missing value (NaN or the categorical placeholder) is replaced with a
     * value imputed from the sample's k nearest donors.
     *
     * @param list<list<mixed>> $samples
     * @throws \Rubix\ML\Exceptions\RuntimeException If the transformer has not been fitted.
     */
    public function transform(array &$samples) : void
    {
        if ($this->tree->bare() or is_null($this->types)) {
            throw new RuntimeException('Transformer has not been fitted.');
        }

        foreach ($samples as &$sample) {
            $neighbors = $distances = [];

            foreach ($sample as $column => &$value) {
                if (!$this->isMissing($value)) {
                    continue;
                }

                // Search for the donors lazily and only once per sample.
                if (empty($neighbors)) {
                    [$neighbors, , $distances] = $this->tree->nearest($sample, $this->k);
                }

                $values = array_column($neighbors, $column);

                $value = $this->impute($values, $distances, $this->types[$column]);
            }

            // Break the reference so that it cannot be written through by accident.
            unset($value);
        }

        unset($sample);
    }

    /**
     * Is the given value a missing value i.e. NaN or the categorical placeholder?
     *
     * @param mixed $value
     * @return bool
     */
    private function isMissing($value) : bool
    {
        if (is_float($value)) {
            return is_nan($value);
        }

        return $value === $this->categoricalPlaceholder;
    }

    /**
     * Choose a value to impute from a given set of values.
     *
     * @param (string|int|float)[] $values The donor values of the column to impute.
     * @param float[] $distances The distances to the donors in the same order as the values.
     * @param \Rubix\ML\DataType $type The data type of the column to impute.
     * @return string|int|float
     */
    protected function impute(array $values, array $distances, DataType $type)
    {
        // NOTE(review): switch compares the DataType objects loosely (by value). Do not
        // replace with match or === without confirming that DataType instances are shared.
        switch ($type) {
            case DataType::continuous():
                return $this->imputeContinuous($values, $distances);

            case DataType::categorical():
            default:
                return $this->imputeCategorical($values, $distances);
        }
    }

    /**
     * Return an imputed continuous value. This is the mean of the donor values,
     * weighted by 1 / (1 + distance) when weighting is enabled.
     *
     * @param (string|int|float)[] $values
     * @param float[] $distances
     * @return int|float
     */
    protected function imputeContinuous(array $values, array $distances)
    {
        if ($this->weighted) {
            $weights = [];

            foreach ($distances as $distance) {
                // Adding 1 to the distance avoids a division by zero for exact matches.
                $weights[] = 1.0 / (1.0 + $distance);
            }

            return Stats::weightedMean($values, $weights);
        }

        return Stats::mean($values);
    }

    /**
     * Return an imputed categorical value. Each donor votes for its category, either
     * with a weight of 1 / (1 + distance) or with a single count when unweighted.
     *
     * @param (string|int|float)[] $values
     * @param float[] $distances
     * @return string
     */
    protected function imputeCategorical(array $values, array $distances) : string
    {
        if ($this->weighted) {
            $weights = array_fill_keys($values, 0.0);

            foreach ($distances as $i => $distance) {
                $weights[$values[$i]] += 1.0 / (1.0 + $distance);
            }
        } else {
            $weights = array_count_values($values);
        }

        // Presumably argmax returns the key (category) with the highest total - verify
        // against the helper's definition.
        return argmax($weights);
    }

    /**
     * Return the string representation of the object.
     *
     * @return string
     */
    public function __toString() : string
    {
        // A boolean would be interpolated as '1' or '' so it is spelled out explicitly.
        $weighted = $this->weighted ? 'true' : 'false';

        return "KNN Imputer (k: {$this->k}, weighted: {$weighted},"
            . " categorical_placeholder: {$this->categoricalPlaceholder},"
            . " tree: {$this->tree})";
    }
}
281