1<?php
2
3declare(strict_types=1);
4
5namespace Phpml\Clustering\KMeans;
6
7use InvalidArgumentException;
8use LogicException;
9use Phpml\Clustering\KMeans;
10use SplObjectStorage;
11
/**
 * A fixed-dimension collection of Point objects that can be partitioned into clusters.
 *
 * The class is an SplObjectStorage whose attach() rejects anything that is not a Point;
 * points created through newPoint()/addPoint() must have exactly `$dimension` coordinates.
 */
class Space extends SplObjectStorage
{
    /**
     * Number of coordinates every point created by this space must have.
     *
     * @var int
     */
    protected $dimension;

    /**
     * @throws LogicException if $dimension is lower than 1
     */
    public function __construct(int $dimension)
    {
        if ($dimension < 1) {
            throw new LogicException('a space dimension cannot be null or negative');
        }

        $this->dimension = $dimension;
    }

    /**
     * Exports every attached point through its own toArray() method.
     *
     * @return array an array with a single 'points' key holding one entry per attached point
     */
    public function toArray(): array
    {
        $points = [];

        /** @var Point $point */
        foreach ($this as $point) {
            $points[] = $point->toArray();
        }

        return ['points' => $points];
    }

    /**
     * Creates a point for this space without attaching it.
     *
     * @param mixed $label
     *
     * @throws LogicException if the number of coordinates differs from the space dimension
     */
    public function newPoint(array $coordinates, $label = null): Point
    {
        if (count($coordinates) !== $this->dimension) {
            throw new LogicException('('.implode(',', $coordinates).') is not a point of this space');
        }

        return new Point($coordinates, $label);
    }

    /**
     * Creates a point with newPoint() and attaches it to this space.
     *
     * @param mixed $label
     * @param mixed $data  forwarded to attach() as the data associated with the point
     *
     * @throws LogicException if the number of coordinates differs from the space dimension
     */
    public function addPoint(array $coordinates, $label = null, $data = null): void
    {
        $this->attach($this->newPoint($coordinates, $label), $data);
    }

    /**
     * Attaches a point to the space; any other kind of value is rejected.
     *
     * @param object $point
     * @param mixed  $data
     *
     * @throws InvalidArgumentException if $point is not a Point instance
     */
    public function attach($point, $data = null): void
    {
        if (!$point instanceof Point) {
            throw new InvalidArgumentException('can only attach points to spaces');
        }

        parent::attach($point, $data);
    }

    public function getDimension(): int
    {
        return $this->dimension;
    }

    /**
     * Computes the per-axis minimum and maximum over all attached points.
     *
     * @return array|bool [Point $min, Point $max], or false when no point is attached
     */
    public function getBoundaries()
    {
        if (count($this) === 0) {
            return false;
        }

        // Both corners start with null coordinates, meaning "no value seen yet" on that axis.
        $min = $this->newPoint(array_fill(0, $this->dimension, null));
        $max = $this->newPoint(array_fill(0, $this->dimension, null));

        /** @var Point $point */
        foreach ($this as $point) {
            for ($n = 0; $n < $this->dimension; ++$n) {
                if ($min[$n] === null || $min[$n] > $point[$n]) {
                    $min[$n] = $point[$n];
                }

                if ($max[$n] === null || $max[$n] < $point[$n]) {
                    $max[$n] = $point[$n];
                }
            }
        }

        return [$min, $max];
    }

    /**
     * Draws a random point inside the axis-aligned box delimited by $min and $max.
     *
     * When both boundaries of an axis are integers the coordinate is an integer drawn with
     * random_int() (both ends inclusive). Otherwise (e.g. float boundaries) the coordinate is
     * a float interpolated between the two boundaries: random_int() only accepts integers and
     * raises a TypeError for floats when strict_types is enabled.
     */
    public function getRandomPoint(Point $min, Point $max): Point
    {
        $point = $this->newPoint(array_fill(0, $this->dimension, null));

        for ($n = 0; $n < $this->dimension; ++$n) {
            $low = $min[$n];
            $high = $max[$n];

            if (is_int($low) && is_int($high)) {
                $point[$n] = random_int($low, $high);

                continue;
            }

            // Random ratio in [0, 1] applied to the width of the interval.
            $point[$n] = $low + ($high - $low) * (random_int(0, PHP_INT_MAX) / PHP_INT_MAX);
        }

        return $point;
    }

    /**
     * Partitions the attached points into clusters.
     *
     * @return Cluster[] the clusters once a pass moves no point; an empty array when no
     *                   cluster could be initialized (e.g. unsupported $initMethod)
     *
     * @throws LogicException if a supported $initMethod is used on a space without points
     */
    public function cluster(int $clustersNumber, int $initMethod = KMeans::INIT_RANDOM): array
    {
        $clusters = $this->initializeClusters($clustersNumber, $initMethod);

        // iterate() returns true once a whole pass reassigned no point.
        while (!$this->iterate($clusters)) {
            // keep reassigning points until convergence
        }

        return $clusters;
    }

    /**
     * Builds the initial clusters with the requested method and puts every point of the
     * space in the first one.
     *
     * @return Cluster[] an empty array when $initMethod is not handled or no cluster was created
     *
     * @throws LogicException if a supported $initMethod is used on a space without points
     */
    protected function initializeClusters(int $clustersNumber, int $initMethod): array
    {
        switch ($initMethod) {
            case KMeans::INIT_RANDOM:
                $clusters = $this->initializeRandomClusters($clustersNumber);

                break;

            case KMeans::INIT_KMEANS_PLUS_PLUS:
                $clusters = $this->initializeKMPPClusters($clustersNumber);

                break;

            default:
                return [];
        }

        // No cluster was created (e.g. $clustersNumber < 1): there is no first cluster to fill.
        if ($clusters === []) {
            return [];
        }

        // Every point starts in the first cluster; iterate() moves each one to its closest cluster.
        $clusters[0]->attachAll($this);

        return $clusters;
    }

    /**
     * Runs one reassignment pass: every point whose closest cluster is not its current one is
     * moved, then each cluster is asked to update its centroid.
     *
     * @param Cluster[] $clusters
     *
     * @return bool true when no point changed cluster during this pass
     */
    protected function iterate(array $clusters): bool
    {
        $convergence = true;

        // Moves are collected per cluster here and applied after the traversal below.
        $attach = new SplObjectStorage();
        $detach = new SplObjectStorage();

        foreach ($clusters as $cluster) {
            foreach ($cluster as $point) {
                $closest = $point->getClosest($clusters);

                if ($closest === $cluster) {
                    continue;
                }

                if (!isset($attach[$closest])) {
                    $attach[$closest] = new SplObjectStorage();
                }

                if (!isset($detach[$cluster])) {
                    $detach[$cluster] = new SplObjectStorage();
                }

                $attach[$closest]->attach($point);
                $detach[$cluster]->attach($point);

                $convergence = false;
            }
        }

        /** @var Cluster $cluster */
        foreach ($attach as $cluster) {
            $cluster->attachAll($attach[$cluster]);
        }

        /** @var Cluster $cluster */
        foreach ($detach as $cluster) {
            $cluster->detachAll($detach[$cluster]);
        }

        foreach ($clusters as $cluster) {
            $cluster->updateCentroid();
        }

        return $convergence;
    }

    /**
     * Seeds the clusters one at a time: the first cluster uses the coordinates of the first
     * attached point, each following one uses the coordinates of a point drawn at random with
     * a weight equal to its distance (as returned by getDistanceWith()) to its closest
     * already-created cluster.
     *
     * @return Cluster[]
     *
     * @throws LogicException if the space holds no point
     */
    protected function initializeKMPPClusters(int $clustersNumber): array
    {
        if (count($this) === 0) {
            throw new LogicException('cannot initialize clusters from an empty space');
        }

        $clusters = [];
        $this->rewind();

        /** @var Point $current */
        $current = $this->current();

        $clusters[] = new Cluster($this, $current->getCoordinates());

        for ($i = 1; $i < $clustersNumber; ++$i) {
            // Rebuilt on every round so a skipped point never reuses a stale distance.
            $distances = new SplObjectStorage();
            $sum = 0;

            /** @var Point $point */
            foreach ($this as $point) {
                $closest = $point->getClosest($clusters);
                if ($closest === null) {
                    continue;
                }

                $distance = $point->getDistanceWith($closest);
                $distances[$point] = $distance;
                $sum += $distance;
            }

            // Float threshold in [0, $sum]: truncating $sum to an integer would make the draw
            // always 0 (hence always select the first point) when the total distance is below 1.
            $remaining = $sum * (random_int(0, PHP_INT_MAX) / PHP_INT_MAX);
            $chosen = null;

            /** @var Point $point */
            foreach ($this as $point) {
                if (!$distances->offsetExists($point)) {
                    continue;
                }

                // Remember the last candidate so rounding errors cannot leave the round empty.
                $chosen = $point;
                $remaining -= $distances[$point];

                if ($remaining <= 0) {
                    break;
                }
            }

            if ($chosen !== null) {
                $clusters[] = new Cluster($this, $chosen->getCoordinates());
            }
        }

        return $clusters;
    }

    /**
     * Seeds $clustersNumber clusters at random positions inside the boundaries of the space.
     *
     * @return Cluster[]
     *
     * @throws LogicException if the space holds no point (no boundaries to draw from)
     */
    private function initializeRandomClusters(int $clustersNumber): array
    {
        $boundaries = $this->getBoundaries();
        if (!is_array($boundaries)) {
            throw new LogicException('cannot initialize clusters from an empty space');
        }

        [$min, $max] = $boundaries;
        $clusters = [];

        for ($n = 0; $n < $clustersNumber; ++$n) {
            $clusters[] = new Cluster($this, $this->getRandomPoint($min, $max)->getCoordinates());
        }

        return $clusters;
    }
}
260