<?php

namespace Rubix\ML\NeuralNet\Layers;

use Tensor\Matrix;
use Rubix\ML\Deferred;
use Rubix\ML\NeuralNet\Optimizers\Optimizer;
use Rubix\ML\NeuralNet\CostFunctions\CrossEntropy;
use Rubix\ML\NeuralNet\ActivationFunctions\Sigmoid;
use Rubix\ML\NeuralNet\CostFunctions\ClassificationLoss;
use Rubix\ML\Exceptions\InvalidArgumentException;
use Rubix\ML\Exceptions\RuntimeException;

use function count;

/**
 * Binary
 *
 * This Binary layer consists of a single sigmoid neuron capable of distinguishing between
 * two discrete classes.
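 *
 * Example - a minimal sketch of driving the layer directly (the class labels
 * are hypothetical; in practice the network object wires the layer up
 * internally):
 *
 *     $layer = new Binary(['benign', 'malignant'], new CrossEntropy());
 *
 *     $layer->initialize(1);
 *
 *     $activations = $layer->forward(Matrix::quick([[0.5, -1.2, 2.1]]));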
 *
 * @internal
 *
 * @category    Machine Learning
 * @package     Rubix/ML
 * @author      Andrew DalPino
 */
class Binary implements Output
{
    /**
     * The class labels of the two possible outcomes, mapped to their binary
     * encodings of 0 and 1.
     *
     * @var int[]
     */
    protected $classes = [
        //
    ];

    /**
     * The function that computes the loss of erroneous activations.
     *
     * @var \Rubix\ML\NeuralNet\CostFunctions\ClassificationLoss
     */
    protected $costFn;

    /**
     * The sigmoid activation function.
     *
     * @var \Rubix\ML\NeuralNet\ActivationFunctions\Sigmoid
     */
    protected $activationFn;

    /**
     * The memorized input matrix.
     *
     * @var \Tensor\Matrix|null
     */
    protected $input;

    /**
     * The memorized activation matrix.
     *
     * @var \Tensor\Matrix|null
     */
    protected $computed;

    /**
     * @param string[] $classes
     * @param \Rubix\ML\NeuralNet\CostFunctions\ClassificationLoss|null $costFn
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
     */
    public function __construct(array $classes, ?ClassificationLoss $costFn = null)
    {
        $classes = array_unique($classes);

        if (count($classes) !== 2) {
            throw new InvalidArgumentException('Number of classes'
                . ' must be 2, ' . count($classes) . ' given.');
        }

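        // Map each unique class label to its binary encoding of 0 or 1.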
        $this->classes = array_flip(array_values($classes));
        $this->costFn = $costFn ?? new CrossEntropy();
        $this->activationFn = new Sigmoid();
    }

    /**
     * Return the width of the layer.
     *
     * @return int
     */
    public function width() : int
    {
        return 1;
    }

    /**
     * Initialize the layer with the fan in from the previous layer and return
     * the fan out for this layer.
     *
     * @param int $fanIn
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
     * @return int
     */
    public function initialize(int $fanIn) : int
    {
        if ($fanIn !== 1) {
            throw new InvalidArgumentException('Fan in must be'
                . " equal to 1, $fanIn given.");
        }

        return 1;
    }

    /**
     * Compute a forward pass through the layer.
     *
     * @param \Tensor\Matrix $input
     * @return \Tensor\Matrix
     */
    public function forward(Matrix $input) : Matrix
    {
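        // Memorize the input and activation matrices for the backward pass.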
        $this->input = $input;

        $this->computed = $this->activationFn->compute($input);

        return $this->computed;
    }

    /**
     * Compute an inferential pass through the layer.
     *
     * @param \Tensor\Matrix $input
     * @return \Tensor\Matrix
     */
    public function infer(Matrix $input) : Matrix
    {
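        // Inference does not backpropagate, so skip memorizing the matrices.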
        return $this->activationFn->compute($input);
    }

    /**
     * Compute the gradient and loss at the output.
     *
     * @param string[] $labels
     * @param \Rubix\ML\NeuralNet\Optimizers\Optimizer $optimizer
     * @throws \Rubix\ML\Exceptions\RuntimeException
     * @return (\Rubix\ML\Deferred|float)[]
     */
    public function back(array $labels, Optimizer $optimizer) : array
    {
        if (!$this->input or !$this->computed) {
            throw new RuntimeException('Must perform forward pass'
                . ' before backpropagating.');
        }

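        // Translate the labels into a single row of binary encodings.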
        $expected = [];

        foreach ($labels as $label) {
            $expected[] = $this->classes[$label];
        }

        $expected = Matrix::quick([$expected]);

        $input = $this->input;
        $computed = $this->computed;

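        // Defer the gradient computation until it is requested during backpropagation.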
        $gradient = new Deferred([$this, 'gradient'], [$input, $computed, $expected]);

        $loss = $this->costFn->compute($computed, $expected);

        $this->input = $this->computed = null;

        return [$gradient, $loss];
    }

    /**
     * Calculate the gradient for the previous layer.
     *
     * @param \Tensor\Matrix $input
     * @param \Tensor\Matrix $computed
     * @param \Tensor\Matrix $expected
     * @return \Tensor\Matrix
     */
    public function gradient(Matrix $input, Matrix $computed, Matrix $expected) : Matrix
    {
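        // With a sigmoid activation and cross entropy cost, the loss and
        // activation derivatives cancel such that the gradient w.r.t. the
        // pre-activation simplifies to (computed - expected) / n.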
        if ($this->costFn instanceof CrossEntropy) {
            return $computed->subtract($expected)
                ->divide($computed->n());
        }

        $dL = $this->costFn->differentiate($computed, $expected)
            ->divide($computed->n());

        return $this->activationFn->differentiate($input, $computed)
            ->multiply($dL);
    }
}