1<?php
2
3declare(strict_types=1);
4
5namespace Phpml\Helper\Optimizer;
6
7use Closure;
8use Phpml\Exception\InvalidOperationException;
9
/**
 * Batch version of Gradient Descent to optimize the weights
 * of a classifier given samples, targets and the objective function to minimize.
 *
 * Each iteration evaluates the gradient callback on the whole sample set and
 * then applies one update to every weight at once.
 */
class GD extends StochasticGD
{
    /**
     * Number of samples given to the current optimization run
     * (reset to null by clear()).
     *
     * @var int|null
     */
    protected $sampleCount;

    /**
     * Runs batch gradient descent and returns the optimized weights.
     *
     * @param array   $samples    training samples; each sample is indexed by feature column
     * @param array   $targets    target for each sample, looked up with the sample's key
     * @param Closure $gradientCb callback invoked as ($theta, $sample, $target) that returns
     *                            [cost, gradient, penalty]; missing trailing entries default to 0
     *
     * @return array the optimized weight vector (theta)
     *
     * @throws InvalidOperationException when $samples is empty
     */
    public function runOptimization(array $samples, array $targets, Closure $gradientCb): array
    {
        $sampleCount = count($samples);
        if ($sampleCount === 0) {
            // Costs and penalties are averaged over the samples; an empty set
            // would otherwise end in a division by zero.
            throw new InvalidOperationException('Cannot run gradient descent without samples');
        }

        $this->samples = $samples;
        $this->targets = $targets;
        $this->gradientCb = $gradientCb;
        $this->sampleCount = $sampleCount;

        try {
            // Batch learning is executed:
            $currIter = 0;
            $this->costValues = [];
            while ($this->maxIterations > $currIter++) {
                // Weights before this iteration; earlyStop() compares against them.
                $theta = $this->theta;

                // Per-sample costs and gradients plus the sample-averaged penalty
                [$errors, $updates, $penalty] = $this->gradient($theta);

                $this->updateWeightsWithUpdates($updates, $penalty);

                // Mean cost over all samples for this iteration
                $this->costValues[] = array_sum($errors) / $this->sampleCount;

                if ($this->earlyStop($theta)) {
                    break;
                }
            }
        } finally {
            // Always drop references to the training data, even when the
            // gradient callback throws part-way through.
            $this->clear();
        }

        return $this->theta;
    }

    /**
     * Evaluates the gradient callback on every sample for the given weights.
     *
     * @param array $theta weights passed to the callback
     *
     * @return array a list [costs, gradients, penalty]: costs and gradients hold one
     *               entry per sample (re-indexed from 0, in sample iteration order),
     *               penalty is the callback's penalty averaged over sampleCount
     *
     * @throws InvalidOperationException when no gradient callback is set
     */
    protected function gradient(array $theta): array
    {
        if ($this->gradientCb === null) {
            throw new InvalidOperationException('Gradient callback is not defined');
        }

        $costs = [];
        $gradient = [];
        $totalPenalty = 0;

        foreach ($this->samples as $index => $sample) {
            $target = $this->targets[$index];

            $result = ($this->gradientCb)($theta, $sample, $target);
            // The callback may omit trailing values (e.g. the penalty); treat them as 0.
            [$cost, $grad, $penalty] = array_pad($result, 3, 0);

            $costs[] = $cost;
            $gradient[] = $grad;
            $totalPenalty += $penalty;
        }

        // Average penalty per sample
        $totalPenalty /= $this->sampleCount;

        return [$costs, $gradient, $totalPenalty];
    }

    /**
     * Applies one batch update to all weights.
     *
     * theta[0] is moved by the sum of all per-sample updates; theta[i] (i >= 1)
     * is moved by the sum of update * feature (column i - 1) over all samples,
     * plus the penalty scaled by the current weight.
     *
     * @param array $updates per-sample gradient values, indexed 0..n-1 in sample order
     * @param float $penalty sample-averaged penalty factor
     */
    protected function updateWeightsWithUpdates(array $updates, float $penalty): void
    {
        // Updates all weights at once
        for ($i = 0; $i <= $this->dimensions; ++$i) {
            if ($i === 0) {
                $this->theta[0] -= $this->learningRate * array_sum($updates);
            } else {
                // array_column re-indexes from 0, matching the order of $updates.
                // NOTE(review): rows lacking column $i - 1 are skipped by array_column,
                // which would misalign the pairing; samples are assumed rectangular.
                $col = array_column($this->samples, $i - 1);

                $error = 0;
                foreach ($col as $index => $val) {
                    $error += $val * $updates[$index];
                }

                $this->theta[$i] -= $this->learningRate *
                    ($error + $penalty * $this->theta[$i]);
            }
        }
    }

    /**
     * Clears the optimizer internal vars after the optimization process.
     */
    protected function clear(): void
    {
        $this->sampleCount = null;
        parent::clear();
    }
}
112