1from __future__ import absolute_import
2from __future__ import division
3from __future__ import print_function
4from __future__ import unicode_literals
5
6import numpy as np  # type: ignore
7from typing import Any, Tuple
8
9import onnx
10from ..base import Base
11from . import expect
12
13
class LSTM_Helper():
    """NumPy reference implementation of a single-direction ONNX LSTM.

    Inputs are supplied as keyword arguments named after the ONNX LSTM
    inputs: ``X``, ``W`` and ``R`` are required; ``B``, ``initial_h``,
    ``initial_c`` and ``P`` are optional and default to zero-filled float32
    arrays. Every input except ``X`` carries a leading ``num_directions``
    axis which must be 1 -- any other value raises ``NotImplementedError``.

    Activations are fixed: ``f`` = sigmoid, ``g`` = ``h`` = tanh. The gate
    blocks of ``W``/``R``/``B`` are split in i, o, f, c order and the
    peephole vector ``P`` in i, o, f order.
    """

    def __init__(self, **params):  # type: (**Any) -> None
        # LSTM Input Names
        X = str('X')
        W = str('W')
        R = str('R')
        B = str('B')
        H_0 = str('initial_h')
        C_0 = str('initial_c')
        P = str('P')
        number_of_gates = 4
        number_of_peepholes = 3

        required_inputs = [X, W, R]
        for i in required_inputs:
            assert i in params, "Missing Required Input: {0}".format(i)

        self.num_directions = params[W].shape[0]

        if self.num_directions != 1:
            # Only the single-direction case is implemented.
            raise NotImplementedError()

        # Drop the leading num_directions axis from every input except X,
        # whose leading axis is the time (sequence) axis.
        params = {k: (v if k == X else np.squeeze(v, axis=0)) for k, v in params.items()}

        hidden_size = params[R].shape[-1]
        batch_size = params[X].shape[1]

        self.X = params[X]
        self.W = params[W]
        self.R = params[R]
        # Optional inputs fall back to zeros.
        self.B = params[B] if B in params else np.zeros(2 * number_of_gates * hidden_size, dtype=np.float32)
        self.P = params[P] if P in params else np.zeros(number_of_peepholes * hidden_size, dtype=np.float32)
        self.H_0 = params[H_0] if H_0 in params else np.zeros((batch_size, hidden_size), dtype=np.float32)
        self.C_0 = params[C_0] if C_0 in params else np.zeros((batch_size, hidden_size), dtype=np.float32)

    def f(self, x):  # type: (np.ndarray) -> np.ndarray
        """Gate activation: logistic sigmoid."""
        return 1 / (1 + np.exp(-x))

    def g(self, x):  # type: (np.ndarray) -> np.ndarray
        """Cell-input activation: tanh."""
        return np.tanh(x)

    def h(self, x):  # type: (np.ndarray) -> np.ndarray
        """Hidden-output activation: tanh."""
        return np.tanh(x)

    def step(self):  # type: () -> Tuple[np.ndarray, np.ndarray]
        """Run the LSTM over every time step of ``self.X``.

        Returns:
            A ``(Y, Y_h)`` pair. ``Y`` stacks the hidden state of every step
            with shape ``(seq_length, 1, batch_size, hidden_size)``; ``Y_h`` is
            the last hidden state with shape ``(1, batch_size, hidden_size)``.
        """
        [p_i, p_o, p_f] = np.split(self.P, 3)
        # Loop-invariant terms, computed once instead of once per time step.
        w_t = np.transpose(self.W)
        r_t = np.transpose(self.R)
        bias = np.add(*np.split(self.B, 2))  # Wb + Rb
        h_list = []
        H_t = self.H_0
        C_t = self.C_0
        # Each x keeps a length-1 leading axis, so H and C pick it up after
        # the first step and Y_h comes out as (1, batch_size, hidden_size).
        for x in np.split(self.X, self.X.shape[0], axis=0):
            gates = np.dot(x, w_t) + np.dot(H_t, r_t) + bias
            i, o, f, c = np.split(gates, 4, -1)
            i = self.f(i + p_i * C_t)
            f = self.f(f + p_f * C_t)
            c = self.g(c)
            C = f * C_t + i * c
            # The output-gate peephole reads the *new* cell state.
            o = self.f(o + p_o * C)
            H = o * self.h(C)
            h_list.append(H)
            H_t = H
            C_t = C
        concatenated = np.concatenate(h_list)
        # Insert the num_directions axis; __init__ guarantees it is 1, so
        # `output` is always bound (the old conditional could leave it unset).
        output = np.expand_dims(concatenated, 1)
        return output, h_list[-1]
87
88
class LSTM(Base):
    """Test-case exporters for the ONNX ``LSTM`` operator.

    Each exporter builds constant inputs, computes the expected final hidden
    state with ``LSTM_Helper`` and records the case via ``expect``.
    """

    @staticmethod
    def export_defaults():  # type: () -> None
        # Only the required inputs X, W and R; everything optional stays at
        # the helper's zero defaults.
        X = np.array([[[1., 2.], [3., 4.], [5., 6.]]]).astype(np.float32)

        input_size, hidden_size = 2, 3
        weight_scale = 0.1
        number_of_gates = 4
        gate_rows = number_of_gates * hidden_size

        node = onnx.helper.make_node(
            'LSTM',
            inputs=['X', 'W', 'R'],
            outputs=['', 'Y'],
            hidden_size=hidden_size
        )

        W = np.full((1, gate_rows, input_size), weight_scale, dtype=np.float32)
        R = np.full((1, gate_rows, hidden_size), weight_scale, dtype=np.float32)

        _, Y_h = LSTM_Helper(X=X, W=W, R=R).step()
        expect(node, inputs=[X, W, R], outputs=[Y_h.astype(np.float32)], name='test_lstm_defaults')

    @staticmethod
    def export_initial_bias():  # type: () -> None
        # Same as the defaults case plus an explicit bias: the W-half of B is
        # a constant, the R-half is zero.
        X = np.array([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]]).astype(np.float32)

        input_size, hidden_size = 3, 4
        weight_scale = 0.1
        custom_bias = 0.1
        number_of_gates = 4
        gate_rows = number_of_gates * hidden_size

        node = onnx.helper.make_node(
            'LSTM',
            inputs=['X', 'W', 'R', 'B'],
            outputs=['', 'Y'],
            hidden_size=hidden_size
        )

        W = np.full((1, gate_rows, input_size), weight_scale, dtype=np.float32)
        R = np.full((1, gate_rows, hidden_size), weight_scale, dtype=np.float32)

        input_bias = np.full((1, gate_rows), custom_bias, dtype=np.float32)
        recurrent_bias = np.zeros((1, gate_rows), dtype=np.float32)
        B = np.concatenate((input_bias, recurrent_bias), 1)

        _, Y_h = LSTM_Helper(X=X, W=W, R=R, B=B).step()
        expect(node, inputs=[X, W, R, B], outputs=[Y_h.astype(np.float32)], name='test_lstm_with_initial_bias')

    @staticmethod
    def export_peepholes():  # type: () -> None
        # All eight inputs are supplied; P is the only non-zero optional one.
        X = np.array([[[1., 2., 3., 4.], [5., 6., 7., 8.]]]).astype(np.float32)
        seq_length, batch_size = X.shape[0], X.shape[1]

        input_size, hidden_size = 4, 3
        weight_scale = 0.1
        number_of_gates = 4
        number_of_peepholes = 3
        gate_rows = number_of_gates * hidden_size

        node = onnx.helper.make_node(
            'LSTM',
            inputs=['X', 'W', 'R', 'B', 'sequence_lens', 'initial_h', 'initial_c', 'P'],
            outputs=['', 'Y'],
            hidden_size=hidden_size
        )

        W = np.full((1, gate_rows, input_size), weight_scale, dtype=np.float32)
        R = np.full((1, gate_rows, hidden_size), weight_scale, dtype=np.float32)
        B = np.zeros((1, 2 * gate_rows), dtype=np.float32)
        # Every batch entry runs for the full sequence length.
        seq_lens = np.full(batch_size, seq_length, dtype=np.int32)
        init_h = np.zeros((1, batch_size, hidden_size), dtype=np.float32)
        init_c = np.zeros((1, batch_size, hidden_size), dtype=np.float32)
        P = np.full((1, number_of_peepholes * hidden_size), weight_scale, dtype=np.float32)

        helper = LSTM_Helper(X=X, W=W, R=R, B=B, P=P, initial_c=init_c, initial_h=init_h)
        _, Y_h = helper.step()
        expect(node, inputs=[X, W, R, B, seq_lens, init_h, init_c, P], outputs=[Y_h.astype(np.float32)],
               name='test_lstm_with_peepholes')
173