1# SPDX-License-Identifier: Apache-2.0
2
3from __future__ import absolute_import
4from __future__ import division
5from __future__ import print_function
6from __future__ import unicode_literals
7
8import numpy as np  # type: ignore
9from typing import Any, Tuple
10
11import onnx
12from ..base import Base
13from . import expect
14
15
class LSTM_Helper():
    """NumPy reference implementation of a single-direction LSTM forward pass.

    The helper is constructed from keyword arguments named after the LSTM
    inputs and then evaluated with :meth:`step`.

    Keyword arguments:
        X: input sequence; ``[seq_length, batch_size, input_size]`` when
            ``layout`` is 0, ``[batch_size, seq_length, input_size]`` otherwise.
        W: input weights with a leading ``num_directions`` axis (required).
        R: recurrence weights with a leading ``num_directions`` axis (required).
        B: optional bias; defaults to zeros of length ``8 * hidden_size``.
        P: optional peephole weights; defaults to zeros of length
            ``3 * hidden_size``.
        initial_h, initial_c: optional initial hidden / cell state; default to
            zeros of shape ``(batch_size, hidden_size)``.
        layout: optional integer, 0 (default) or non-zero for batch-major X.

    Every tensor other than ``X`` has its leading axis squeezed away, so only
    ``num_directions == 1`` is supported.

    Raises:
        AssertionError: if one of X, W or R is missing.
        NotImplementedError: if ``W.shape[0]`` (num_directions) is not 1.
    """

    def __init__(self, **params):  # type: (**Any) -> None
        # LSTM Input Names
        X = str('X')
        W = str('W')
        R = str('R')
        B = str('B')
        H_0 = str('initial_h')
        C_0 = str('initial_c')
        P = str('P')
        LAYOUT = str('layout')
        number_of_gates = 4
        number_of_peepholes = 3

        required_inputs = [X, W, R]
        for i in required_inputs:
            assert i in params, "Missing Required Input: {0}".format(i)

        self.num_directions = params[W].shape[0]

        if self.num_directions != 1:
            raise NotImplementedError()

        # `layout` is a scalar attribute rather than a tensor with a leading
        # direction axis, so it must not go through np.squeeze(axis=0).
        layout = params.get(LAYOUT, 0)
        for k in params:
            if k not in (X, LAYOUT):
                # NOTE(review): axis 0 is squeezed for initial_h / initial_c
                # regardless of layout — confirm against the operator spec
                # before passing batch-major initial states.
                params[k] = np.squeeze(params[k], axis=0)

        hidden_size = params[R].shape[-1]

        # Internally the sequence is always time-major.
        x = params[X]
        x = x if layout == 0 else np.swapaxes(x, 0, 1)
        # Read the batch size only after the layout swap; before it, axis 1 of
        # a batch-major X is the sequence length.
        batch_size = x.shape[1]

        b = params[B] if B in params else np.zeros(2 * number_of_gates * hidden_size, dtype=np.float32)
        p = params[P] if P in params else np.zeros(number_of_peepholes * hidden_size, dtype=np.float32)
        h_0 = params[H_0] if H_0 in params else np.zeros((batch_size, hidden_size), dtype=np.float32)
        c_0 = params[C_0] if C_0 in params else np.zeros((batch_size, hidden_size), dtype=np.float32)

        self.X = x
        self.W = params[W]
        self.R = params[R]
        self.B = b
        self.P = p
        self.H_0 = h_0
        self.C_0 = c_0
        self.LAYOUT = layout

    def f(self, x):  # type: (np.ndarray) -> np.ndarray
        """Gate activation: logistic sigmoid."""
        return 1 / (1 + np.exp(-x))

    def g(self, x):  # type: (np.ndarray) -> np.ndarray
        """Cell-candidate activation: tanh."""
        return np.tanh(x)

    def h(self, x):  # type: (np.ndarray) -> np.ndarray
        """Hidden-output activation: tanh."""
        return np.tanh(x)

    def step(self):  # type: () -> Tuple[np.ndarray, np.ndarray]
        """Run the LSTM over the whole sequence.

        Returns:
            A tuple ``(Y, Y_h)``. With ``layout == 0`` ``Y`` has shape
            ``[seq_length, num_directions, batch_size, hidden_size]`` and
            ``Y_h`` (the last time step) has shape
            ``[num_directions, batch_size, hidden_size]``. Otherwise ``Y`` is
            ``[batch_size, seq_length, num_directions, hidden_size]`` and
            ``Y_h`` is ``[batch_size, num_directions, hidden_size]``.
        """
        seq_length = self.X.shape[0]
        hidden_size = self.H_0.shape[-1]
        batch_size = self.X.shape[1]

        Y = np.empty([seq_length, self.num_directions, batch_size, hidden_size])
        h_list = []

        # Peephole weights are packed in the order input, output, forget.
        [p_i, p_o, p_f] = np.split(self.P, 3)
        # The bias packs the W-bias and the R-bias back to back; they are
        # simply summed, and the sum does not change between time steps.
        bias = np.add(*np.split(self.B, 2))
        W_t = np.transpose(self.W)
        R_t = np.transpose(self.R)

        H_t = self.H_0
        C_t = self.C_0
        for x in np.split(self.X, seq_length, axis=0):
            gates = np.dot(x, W_t) + np.dot(H_t, R_t) + bias
            # Gates are packed in the order input, output, forget, cell.
            i, o, f, c = np.split(gates, 4, -1)
            i = self.f(i + p_i * C_t)
            f = self.f(f + p_f * C_t)
            c = self.g(c)
            C = f * C_t + i * c
            # The output gate peeks at the *updated* cell state.
            o = self.f(o + p_o * C)
            H = o * self.h(C)
            h_list.append(H)
            H_t = H
            C_t = C

        # Each H keeps the leading length-1 time axis from np.split, so the
        # concatenation is [seq_length, batch_size, hidden_size].
        # __init__ guarantees a single direction.
        Y[:, 0, :, :] = np.concatenate(h_list)

        if self.LAYOUT == 0:
            Y_h = Y[-1]
        else:
            Y = np.transpose(Y, [2, 0, 1, 3])
            # Last time step lives on axis 1 of the batch-major output.
            Y_h = Y[:, -1, :, :]

        return Y, Y_h
109
110
class LSTM(Base):
    """Test-case generators for the ``LSTM`` operator.

    Each ``export_*`` method builds an ``LSTM`` node, computes the expected
    outputs with ``LSTM_Helper`` and registers the case through ``expect``.
    """

    @staticmethod
    def export_defaults():  # type: () -> None
        """Only the required inputs X, W and R; checks ``Y_h``."""
        n_features = 2
        n_hidden = 3
        n_gates = 4
        scale = 0.1

        x = np.array([[[1., 2.], [3., 4.], [5., 6.]]], dtype=np.float32)

        node = onnx.helper.make_node(
            'LSTM',
            inputs=['X', 'W', 'R'],
            outputs=['', 'Y_h'],
            hidden_size=n_hidden
        )

        W = scale * np.ones((1, n_gates * n_hidden, n_features), dtype=np.float32)
        R = scale * np.ones((1, n_gates * n_hidden, n_hidden), dtype=np.float32)

        reference = LSTM_Helper(X=x, W=W, R=R)
        last_hidden = reference.step()[1]
        expect(
            node,
            inputs=[x, W, R],
            outputs=[last_hidden.astype(np.float32)],
            name='test_lstm_defaults',
        )

    @staticmethod
    def export_initial_bias():  # type: () -> None
        """Adds a bias input B whose W-half is constant and R-half is zero."""
        n_features = 3
        n_hidden = 4
        n_gates = 4
        scale = 0.1
        bias_value = 0.1

        x = np.array([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]], dtype=np.float32)

        node = onnx.helper.make_node(
            'LSTM',
            inputs=['X', 'W', 'R', 'B'],
            outputs=['', 'Y_h'],
            hidden_size=n_hidden
        )

        W = scale * np.ones((1, n_gates * n_hidden, n_features), dtype=np.float32)
        R = scale * np.ones((1, n_gates * n_hidden, n_hidden), dtype=np.float32)

        # B is the input-weight bias followed by the recurrence-weight bias.
        input_bias = bias_value * np.ones((1, n_gates * n_hidden), dtype=np.float32)
        recurrent_bias = np.zeros((1, n_gates * n_hidden), dtype=np.float32)
        B = np.concatenate((input_bias, recurrent_bias), axis=1)

        reference = LSTM_Helper(X=x, W=W, R=R, B=B)
        last_hidden = reference.step()[1]
        expect(
            node,
            inputs=[x, W, R, B],
            outputs=[last_hidden.astype(np.float32)],
            name='test_lstm_with_initial_bias',
        )

    @staticmethod
    def export_peepholes():  # type: () -> None
        """Supplies every optional input, including peephole weights P."""
        n_features = 4
        n_hidden = 3
        n_gates = 4
        n_peepholes = 3
        scale = 0.1

        x = np.array([[[1., 2., 3., 4.], [5., 6., 7., 8.]]], dtype=np.float32)
        dim0, dim1 = x.shape[0], x.shape[1]

        node = onnx.helper.make_node(
            'LSTM',
            inputs=['X', 'W', 'R', 'B', 'sequence_lens', 'initial_h', 'initial_c', 'P'],
            outputs=['', 'Y_h'],
            hidden_size=n_hidden
        )

        # Initializing Inputs
        W = scale * np.ones((1, n_gates * n_hidden, n_features), dtype=np.float32)
        R = scale * np.ones((1, n_gates * n_hidden, n_hidden), dtype=np.float32)
        B = np.zeros((1, 2 * n_gates * n_hidden), dtype=np.float32)
        # One entry per index of X's second axis, each set to X's first dimension.
        seq_lens = np.full(dim1, dim0, dtype=np.int32)
        init_h = np.zeros((1, dim1, n_hidden), dtype=np.float32)
        init_c = np.zeros((1, dim1, n_hidden), dtype=np.float32)
        P = scale * np.ones((1, n_peepholes * n_hidden), dtype=np.float32)

        reference = LSTM_Helper(X=x, W=W, R=R, B=B, P=P, initial_c=init_c, initial_h=init_h)
        last_hidden = reference.step()[1]
        expect(
            node,
            inputs=[x, W, R, B, seq_lens, init_h, init_c, P],
            outputs=[last_hidden.astype(np.float32)],
            name='test_lstm_with_peepholes',
        )

    @staticmethod
    def export_batchwise():  # type: () -> None
        """Uses ``layout=1`` and checks both ``Y`` and ``Y_h``."""
        n_features = 2
        n_hidden = 7
        n_gates = 4
        scale = 0.3
        layout = 1

        x = np.array([[[1., 2.]], [[3., 4.]], [[5., 6.]]], dtype=np.float32)

        node = onnx.helper.make_node(
            'LSTM',
            inputs=['X', 'W', 'R'],
            outputs=['Y', 'Y_h'],
            hidden_size=n_hidden,
            layout=layout
        )

        W = scale * np.ones((1, n_gates * n_hidden, n_features), dtype=np.float32)
        R = scale * np.ones((1, n_gates * n_hidden, n_hidden), dtype=np.float32)

        reference = LSTM_Helper(X=x, W=W, R=R, layout=layout)
        all_hidden, last_hidden = reference.step()
        expect(
            node,
            inputs=[x, W, R],
            outputs=[all_hidden.astype(np.float32), last_hidden.astype(np.float32)],
            name='test_lstm_batchwise',
        )
220