<?php

namespace Rubix\ML\Transformers;

use Rubix\ML\DataType;
use Rubix\ML\Persistable;
use Rubix\ML\Datasets\Dataset;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Graph\Trees\Spatial;
use Rubix\ML\Other\Helpers\Stats;
use Rubix\ML\Graph\Trees\BallTree;
use Rubix\ML\Kernels\Distance\NaNSafe;
use Rubix\ML\Kernels\Distance\Distance;
use Rubix\ML\Other\Traits\AutotrackRevisions;
use Rubix\ML\Kernels\Distance\SafeEuclidean;
use Rubix\ML\Specifications\SamplesAreCompatibleWithTransformer;
use Rubix\ML\Exceptions\InvalidArgumentException;
use Rubix\ML\Exceptions\RuntimeException;

use function Rubix\ML\argmax;
use function in_array;
use function is_null;

/**
 * KNN Imputer
 *
 * An unsupervised imputer that replaces missing values in a dataset with an average of the
 * corresponding values of the sample's *k* nearest complete neighbors (donors), optionally
 * weighted by distance. For a continuous feature column the average is the mean of the donors'
 * values; for a categorical feature column it is the most frequent donor category.
 *
 * **Note:** Requires NaN-safe distance kernels, such as Safe Euclidean, for continuous features.
 *
 * References:
 * [1] O. Troyanskaya et al. (2001). Missing value estimation methods for DNA microarrays.
 *
 * @category    Machine Learning
 * @package     Rubix/ML
 * @author      Andrew DalPino
 */
class KNNImputer implements Transformer, Stateful, Persistable
{
    use AutotrackRevisions;

    /**
     * The number of neighbors to consider when imputing a value.
     *
     * @var int
     */
    protected $k;

    /**
     * Should we use the inverse distances as confidence scores when imputing values.
     *
     * @var bool
     */
    protected $weighted;

    /**
     * The placeholder category that denotes missing values.
     *
     * @var string
     */
    protected $categoricalPlaceholder;

    /**
     * The spatial tree used to run nearest neighbor searches.
     *
     * @var \Rubix\ML\Graph\Trees\Spatial
     */
    protected $tree;

    /**
     * The data types of the fitted feature columns.
     *
     * @var \Rubix\ML\DataType[]|null
     */
    protected $types;

    /**
     * @param int $k The number of neighbors to consult; must be at least 1.
     * @param bool $weighted Weight each neighbor by 1 / (1 + distance) when true.
     * @param string $categoricalPlaceholder The category value that marks a missing categorical value.
     * @param \Rubix\ML\Graph\Trees\Spatial|null $tree The search tree; defaults to a Ball Tree with Safe Euclidean distance.
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
     */
    public function __construct(
        int $k = 5,
        bool $weighted = true,
        string $categoricalPlaceholder = '?',
        ?Spatial $tree = null
    ) {
        if ($k < 1) {
            throw new InvalidArgumentException('At least 1 neighbor is required'
                . " to impute a value, $k given.");
        }

        if ($tree !== null) {
            $kernel = $tree->kernel();

            // Loose (==) comparison is intentional: it compares DataType objects by their
            // properties rather than by identity.
            if (in_array(DataType::continuous(), $kernel->compatibility()) and !$kernel instanceof NaNSafe) {
                // Continuous rows being searched may still contain NaNs, so the kernel must tolerate them.
                throw new InvalidArgumentException('Continuous distance kernels'
                    . ' must implement the NaNSafe interface.');
            }
        }

        $this->k = $k;
        $this->weighted = $weighted;
        $this->categoricalPlaceholder = $categoricalPlaceholder;
        $this->tree = $tree ?? new BallTree(30, new SafeEuclidean());
    }

    /**
     * Return the data types that this transformer is compatible with.
     *
     * @internal
     *
     * @return list<\Rubix\ML\DataType>
     */
    public function compatibility() : array
    {
        return $this->tree->kernel()->compatibility();
    }

    /**
     * Is the transformer fitted?
     *
     * Mirrors the precondition enforced by transform(): the tree must hold donors AND the
     * column types must have been recorded by fit().
     *
     * @return bool
     */
    public function fitted() : bool
    {
        return !$this->tree->bare() && !is_null($this->types);
    }

    /**
     * Fit the transformer to a dataset.
     *
     * Every sample without any missing value becomes a donor and is indexed in the spatial tree.
     * The dataset's column types are remembered so transform() knows how to average each column.
     *
     * @param \Rubix\ML\Datasets\Dataset $dataset
     * @throws \Rubix\ML\Exceptions\RuntimeException
     */
    public function fit(Dataset $dataset) : void
    {
        SamplesAreCompatibleWithTransformer::with($dataset, $this)->check();

        $donors = [];

        foreach ($dataset->samples() as $sample) {
            foreach ($sample as $value) {
                if ($this->isMissing($value)) {
                    // Incomplete samples cannot donate values; skip to the next sample.
                    continue 2;
                }
            }

            $donors[] = $sample;
        }

        if (empty($donors)) {
            throw new RuntimeException('No complete donors found in dataset.');
        }

        // The tree is grown from a Labeled dataset, but only the samples are used when
        // imputing, so every donor gets the same dummy label.
        $labels = array_fill(0, count($donors), '');

        $this->tree->grow(Labeled::quick($donors, $labels));

        $this->types = $dataset->columnTypes();
    }

    /**
     * Transform the dataset in place.
     *
     * Each missing value is replaced with the average of that column across the sample's k
     * nearest donors. The neighbor search runs at most once per sample, the first time a
     * missing value is encountered.
     *
     * @param list<list<mixed>> $samples
     * @throws \Rubix\ML\Exceptions\RuntimeException
     */
    public function transform(array &$samples) : void
    {
        if ($this->tree->bare() or is_null($this->types)) {
            throw new RuntimeException('Transformer has not been fitted.');
        }

        foreach ($samples as &$sample) {
            $neighbors = $distances = [];

            foreach ($sample as $column => &$value) {
                if (!$this->isMissing($value)) {
                    continue;
                }

                if (empty($neighbors)) {
                    [$neighbors, , $distances] = $this->tree->nearest($sample, $this->k);
                }

                $values = array_column($neighbors, $column);

                $value = $this->impute($values, $distances, $this->types[$column]);
            }

            // Release the by-reference loop variable so it no longer aliases the last column.
            unset($value);
        }

        // Release the by-reference loop variable so it no longer aliases the caller's last sample.
        unset($sample);
    }

    /**
     * Does the given value denote a missing value?
     *
     * A float is missing when it is NaN. Any other value is missing when it is identical to
     * the categorical placeholder.
     *
     * @param mixed $value
     * @return bool
     */
    protected function isMissing($value) : bool
    {
        if (is_float($value)) {
            return is_nan($value);
        }

        return $value === $this->categoricalPlaceholder;
    }

    /**
     * Choose a value to impute from a given set of values.
     *
     * @param (string|int|float)[] $values
     * @param float[] $distances
     * @param \Rubix\ML\DataType $type
     * @return string|int|float
     */
    protected function impute(array $values, array $distances, DataType $type)
    {
        // switch (loose ==) rather than match (strict ===) so DataType objects compare by value.
        switch ($type) {
            case DataType::continuous():
                return $this->imputeContinuous($values, $distances);

            case DataType::categorical():
            default:
                return $this->imputeCategorical($values, $distances);
        }
    }

    /**
     * Return an imputed continuous value.
     *
     * @param (string|int|float)[] $values
     * @param float[] $distances
     * @return int|float
     */
    protected function imputeContinuous(array $values, array $distances)
    {
        if ($this->weighted) {
            $weights = [];

            foreach ($distances as $distance) {
                // 1 / (1 + d) stays finite even when a donor is at distance zero.
                $weights[] = 1.0 / (1.0 + $distance);
            }

            return Stats::weightedMean($values, $weights);
        }

        return Stats::mean($values);
    }

    /**
     * Return an imputed categorical value.
     *
     * @param (string|int|float)[] $values
     * @param float[] $distances
     * @return string
     */
    protected function imputeCategorical(array $values, array $distances) : string
    {
        if ($this->weighted) {
            $weights = array_fill_keys($values, 0.0);

            foreach ($distances as $i => $distance) {
                $weights[$values[$i]] += 1.0 / (1.0 + $distance);
            }
        } else {
            $weights = array_count_values($values);
        }

        // PHP turns numeric-string array keys (e.g. "3") into ints, so argmax may hand back an
        // int; cast explicitly rather than relying on weak-mode return coercion.
        return (string) argmax($weights);
    }

    /**
     * Return the string representation of the object.
     *
     * @return string
     */
    public function __toString() : string
    {
        // Booleans interpolate as "1"/"" in PHP, so spell the flag out explicitly.
        $weighted = $this->weighted ? 'true' : 'false';

        return "KNN Imputer (k: {$this->k}, weighted: {$weighted},"
            . " categorical_placeholder: {$this->categoricalPlaceholder},"
            . " tree: {$this->tree})";
    }
}