1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "lstm.h"
16 
17 #include <math.h>
18 
19 namespace ncnn {
20 
LSTM()21 LSTM::LSTM()
22 {
23     one_blob_only = false;
24     support_inplace = false;
25 }
26 
load_param(const ParamDict & pd)27 int LSTM::load_param(const ParamDict& pd)
28 {
29     num_output = pd.get(0, 0);
30     weight_data_size = pd.get(1, 0);
31     direction = pd.get(2, 0);
32     return 0;
33 }
34 
load_model(const ModelBin & mb)35 int LSTM::load_model(const ModelBin& mb)
36 {
37     int num_directions = direction == 2 ? 2 : 1;
38 
39     int size = weight_data_size / num_directions / num_output / 4;
40 
41     // raw weight data
42     weight_xc_data = mb.load(size, num_output * 4, num_directions, 0);
43     if (weight_xc_data.empty())
44         return -100;
45 
46     bias_c_data = mb.load(num_output, 4, num_directions, 0);
47     if (bias_c_data.empty())
48         return -100;
49 
50     weight_hc_data = mb.load(num_output, num_output * 4, num_directions, 0);
51     if (weight_hc_data.empty())
52         return -100;
53 
54     return 0;
55 }
56 
lstm(const Mat & bottom_blob,Mat & top_blob,int reverse,const Mat & weight_xc,const Mat & bias_c,const Mat & weight_hc,Mat & hidden_state,Mat & cell_state,const Option & opt)57 static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
58 {
59     int size = bottom_blob.w;
60     int T = bottom_blob.h;
61 
62     int num_output = top_blob.w;
63 
64     // 4 x num_output
65     Mat gates(4, num_output, 4u, opt.workspace_allocator);
66     if (gates.empty())
67         return -100;
68 
69     // unroll
70     for (int t = 0; t < T; t++)
71     {
72         // clip hidden by continuation indicator
73         // h_cont_{t-1} = cont_t * h_{t-1}
74         // h_cont_{t-1} = h_{t-1} if cont_t == 1
75         //                0       otherwise
76         // calculate hidden
77         // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c
78 
79         int ti = reverse ? T - 1 - t : t;
80 
81         const float* x = bottom_blob.row(ti);
82         for (int q = 0; q < num_output; q++)
83         {
84             const float* bias_c_I = bias_c.row(0);
85             const float* bias_c_F = bias_c.row(1);
86             const float* bias_c_O = bias_c.row(2);
87             const float* bias_c_G = bias_c.row(3);
88 
89             float* gates_data = gates.row(q);
90 
91             // gate I F O G
92             const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
93             const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
94             const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
95             const float* weight_xc_G = weight_xc.row(num_output * 3 + q);
96 
97             const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
98             const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
99             const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
100             const float* weight_hc_G = weight_hc.row(num_output * 3 + q);
101 
102             float I = bias_c_I[q];
103             float F = bias_c_F[q];
104             float O = bias_c_O[q];
105             float G = bias_c_G[q];
106 
107             for (int i = 0; i < size; i++)
108             {
109                 float xi = x[i];
110 
111                 I += weight_xc_I[i] * xi;
112                 F += weight_xc_F[i] * xi;
113                 O += weight_xc_O[i] * xi;
114                 G += weight_xc_G[i] * xi;
115             }
116 
117             for (int i = 0; i < num_output; i++)
118             {
119                 float h_cont = hidden_state[i];
120 
121                 I += weight_hc_I[i] * h_cont;
122                 F += weight_hc_F[i] * h_cont;
123                 O += weight_hc_O[i] * h_cont;
124                 G += weight_hc_G[i] * h_cont;
125             }
126 
127             gates_data[0] = I;
128             gates_data[1] = F;
129             gates_data[2] = O;
130             gates_data[3] = G;
131         }
132 
133         // lstm unit
134         // sigmoid(I)
135         // sigmoid(F)
136         // sigmoid(O)
137         // tanh(G)
138         // c_t := f_t .* c_{t-1} + i_t .* g_t
139         // h_t := o_t .* tanh[c_t]
140         float* output_data = top_blob.row(ti);
141         for (int q = 0; q < num_output; q++)
142         {
143             const float* gates_data = gates.row(q);
144 
145             float I = gates_data[0];
146             float F = gates_data[1];
147             float O = gates_data[2];
148             float G = gates_data[3];
149 
150             I = 1.f / (1.f + exp(-I));
151             F = 1.f / (1.f + exp(-F));
152             O = 1.f / (1.f + exp(-O));
153             G = tanh(G);
154 
155             float cell2 = F * cell_state[q] + I * G;
156             float H = O * tanh(cell2);
157             cell_state[q] = cell2;
158             hidden_state[q] = H;
159             output_data[q] = H;
160         }
161     }
162 
163     return 0;
164 }
165 
forward(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const166 int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
167 {
168     int T = bottom_blob.h;
169 
170     int num_directions = direction == 2 ? 2 : 1;
171 
172     // initial hidden state
173     Mat hidden(num_output, 4u, opt.workspace_allocator);
174     if (hidden.empty())
175         return -100;
176     hidden.fill(0.f);
177 
178     Mat cell(num_output, 4u, opt.workspace_allocator);
179     if (cell.empty())
180         return -100;
181     cell.fill(0.f);
182 
183     top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
184     if (top_blob.empty())
185         return -100;
186 
187     // Uni directional
188     if (direction == 0 || direction == 1)
189     {
190         int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
191         if (ret != 0)
192             return ret;
193     }
194 
195     if (direction == 2)
196     {
197         Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
198         if (top_blob_forward.empty())
199             return -100;
200 
201         Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
202         if (top_blob_reverse.empty())
203             return -100;
204 
205         int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
206         if (ret0 != 0)
207             return ret0;
208 
209         hidden.fill(0.0f);
210         cell.fill(0.0f);
211 
212         int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, cell, opt);
213         if (ret1 != 0)
214             return ret1;
215 
216         // concat w
217         for (int i = 0; i < T; i++)
218         {
219             const float* pf = top_blob_forward.row(i);
220             const float* pr = top_blob_reverse.row(i);
221             float* ptr = top_blob.row(i);
222 
223             memcpy(ptr, pf, num_output * sizeof(float));
224             memcpy(ptr + num_output, pr, num_output * sizeof(float));
225         }
226     }
227 
228     return 0;
229 }
230 
forward(const std::vector<Mat> & bottom_blobs,std::vector<Mat> & top_blobs,const Option & opt) const231 int LSTM::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
232 {
233     const Mat& bottom_blob = bottom_blobs[0];
234     int T = bottom_blob.h;
235     int num_directions = direction == 2 ? 2 : 1;
236 
237     Mat hidden;
238     Mat cell;
239     Allocator* hidden_cell_allocator = top_blobs.size() == 3 ? opt.blob_allocator : opt.workspace_allocator;
240     if (bottom_blobs.size() == 3)
241     {
242         hidden = bottom_blobs[1].clone(hidden_cell_allocator);
243         cell = bottom_blobs[2].clone(hidden_cell_allocator);
244     }
245     else
246     {
247         hidden.create(num_output, num_directions, 4u, hidden_cell_allocator);
248         if (hidden.empty())
249             return -100;
250         hidden.fill(0.f);
251 
252         cell.create(num_output, num_directions, 4u, hidden_cell_allocator);
253         if (cell.empty())
254             return -100;
255         cell.fill(0.f);
256     }
257 
258     Mat& top_blob = top_blobs[0];
259     top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
260     if (top_blob.empty())
261         return -100;
262 
263     // Uni directional
264     if (direction == 0 || direction == 1)
265     {
266         int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
267         if (ret != 0)
268             return ret;
269     }
270 
271     if (direction == 2)
272     {
273         Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
274         if (top_blob_forward.empty())
275             return -100;
276 
277         Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
278         if (top_blob_reverse.empty())
279             return -100;
280 
281         Mat hidden0 = hidden.row_range(0, 1);
282         Mat cell0 = cell.row_range(0, 1);
283         int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, cell0, opt);
284         if (ret0 != 0)
285             return ret0;
286 
287         Mat hidden1 = hidden.row_range(1, 1);
288         Mat cell1 = cell.row_range(1, 1);
289         int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, cell1, opt);
290         if (ret1 != 0)
291             return ret1;
292 
293         // concat w
294         for (int i = 0; i < T; i++)
295         {
296             const float* pf = top_blob_forward.row(i);
297             const float* pr = top_blob_reverse.row(i);
298             float* ptr = top_blob.row(i);
299 
300             memcpy(ptr, pf, num_output * sizeof(float));
301             memcpy(ptr + num_output, pr, num_output * sizeof(float));
302         }
303     }
304 
305     if (top_blobs.size() == 3)
306     {
307         top_blobs[1] = hidden;
308         top_blobs[2] = cell;
309     }
310 
311     return 0;
312 }
313 
314 } // namespace ncnn
315