1<?php
2
3declare(strict_types=1);
4
5namespace Phpml\Preprocessing;
6
7use Phpml\Exception\NormalizerException;
8use Phpml\Math\Statistic\Mean;
9use Phpml\Math\Statistic\StandardDeviation;
10
/**
 * Rescales every sample (row) of a dataset in place.
 *
 * - NORM_L1:  divides each feature by the sample's L1 norm (sum of absolute values).
 * - NORM_L2:  divides each feature by the sample's L2 (Euclidean) norm.
 * - NORM_STD: maps each feature column to (value - mean) / population standard deviation,
 *             using per-feature statistics learned by fit().
 */
class Normalizer implements Preprocessor
{
    public const NORM_L1 = 1;

    public const NORM_L2 = 2;

    public const NORM_STD = 3;

    /**
     * @var int one of the NORM_* constants, validated in the constructor
     */
    private $norm;

    /**
     * @var bool true once fit() has run on a non-empty dataset; further fit() calls are no-ops
     */
    private $fitted = false;

    /**
     * @var array per-feature-key result of StandardDeviation::population() (NORM_STD only)
     */
    private $std = [];

    /**
     * @var array per-feature-key result of Mean::arithmetic() (NORM_STD only)
     */
    private $mean = [];

    /**
     * @param int $norm one of NORM_L1, NORM_L2 or NORM_STD
     *
     * @throws NormalizerException when $norm is not one of the supported constants
     */
    public function __construct(int $norm = self::NORM_L2)
    {
        if (!in_array($norm, [self::NORM_L1, self::NORM_L2, self::NORM_STD], true)) {
            throw new NormalizerException('Unknown norm supplied.');
        }

        $this->norm = $norm;
    }

    /**
     * Learns per-feature mean and population standard deviation for NORM_STD.
     *
     * Only the first call with a non-empty dataset has an effect; afterwards the
     * learned statistics are reused by every transform(). For NORM_L1 / NORM_L2
     * there is nothing to learn and the instance is simply marked as fitted.
     *
     * @param array      $samples list of samples, each an array of numeric features
     * @param array|null $targets unused; accepted for interface compatibility
     */
    public function fit(array $samples, ?array $targets = null): void
    {
        if ($this->fitted) {
            return;
        }

        // An empty dataset carries no statistics. Don't lock in an empty fit,
        // otherwise a later transform() with real data would read undefined keys.
        if ($samples === []) {
            return;
        }

        if ($this->norm === self::NORM_STD) {
            // Use the first sample's actual keys rather than assuming 0..n-1, so this
            // agrees with normalizeSTD(), which iterates each sample's own keys, and
            // a zero-feature sample yields no columns instead of the bogus range [0, -1].
            $firstSample = reset($samples);
            foreach (array_keys($firstSample) as $i) {
                $values = array_column($samples, $i);
                $this->std[$i] = StandardDeviation::population($values);
                $this->mean[$i] = Mean::arithmetic($values);
            }
        }

        $this->fitted = true;
    }

    /**
     * Normalizes every sample of $samples in place, fitting first if needed.
     *
     * @param array $samples list of samples, modified by reference
     */
    public function transform(array &$samples): void
    {
        $methods = [
            self::NORM_L1 => 'normalizeL1',
            self::NORM_L2 => 'normalizeL2',
            self::NORM_STD => 'normalizeSTD',
        ];
        $method = $methods[$this->norm];

        $this->fit($samples);

        foreach ($samples as &$sample) {
            $this->{$method}($sample);
        }
        // Break the reference so $sample no longer aliases the caller's last row.
        unset($sample);
    }

    /**
     * Scales $sample so its L1 norm is 1.
     *
     * An all-zero sample becomes a uniform vector of 1/count; an empty sample is
     * left untouched.
     */
    private function normalizeL1(array &$sample): void
    {
        $count = count($sample);
        if ($count === 0) {
            // Nothing to scale; also avoids 1.0 / 0 in the all-zero branch below.
            return;
        }

        $norm1 = 0;
        foreach ($sample as $feature) {
            $norm1 += abs($feature);
        }

        if ($norm1 == 0) {
            $sample = array_fill(0, $count, 1.0 / $count);
        } else {
            array_walk($sample, function (&$feature) use ($norm1): void {
                $feature /= $norm1;
            });
        }
    }

    /**
     * Scales $sample by its Euclidean (L2) norm.
     *
     * An all-zero sample has every feature set to 1. Note that, unlike the L1
     * branch, that result does not have unit norm.
     */
    private function normalizeL2(array &$sample): void
    {
        $norm2 = 0;
        foreach ($sample as $feature) {
            $norm2 += $feature * $feature;
        }

        $norm2 **= .5;

        if ($norm2 == 0) {
            $sample = array_fill(0, count($sample), 1);
        } else {
            array_walk($sample, function (&$feature) use ($norm2): void {
                $feature /= $norm2;
            });
        }
    }

    /**
     * Standardizes each feature using the mean and standard deviation learned by fit().
     *
     * A feature whose fitted standard deviation is zero (constant column) is set to 0.
     */
    private function normalizeSTD(array &$sample): void
    {
        foreach (array_keys($sample) as $i) {
            if ($this->std[$i] != 0) {
                $sample[$i] = ($sample[$i] - $this->mean[$i]) / $this->std[$i];
            } else {
                // Same value for all samples.
                $sample[$i] = 0;
            }
        }
    }
}
132