// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "lstm.h"

#include <math.h>
#include <string.h>

namespace ncnn {

LSTM::LSTM()
{
    one_blob_only = false;
    support_inplace = false;
}

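// param ids: 0 = num_output, 1 = weight_data_size, 2 = direction (0 = forward, 1 = reverse, 2 = bidirectional)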
int LSTM::load_param(const ParamDict& pd)
{
    num_output = pd.get(0, 0);
    weight_data_size = pd.get(1, 0);
    direction = pd.get(2, 0);
    return 0;
}

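// weight layout per direction, gate order I F O G
//   weight_xc: 4 * num_output rows of `size` input weights
//   bias_c:    4 rows of num_output biases
//   weight_hc: 4 * num_output rows of num_output recurrent weights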
int LSTM::load_model(const ModelBin& mb)
{
    int num_directions = direction == 2 ? 2 : 1;

    int size = weight_data_size / num_directions / num_output / 4;

    // raw weight data
    weight_xc_data = mb.load(size, num_output * 4, num_directions, 0);
    if (weight_xc_data.empty())
        return -100;

    bias_c_data = mb.load(num_output, 4, num_directions, 0);
    if (bias_c_data.empty())
        return -100;

    weight_hc_data = mb.load(num_output, num_output * 4, num_directions, 0);
    if (weight_hc_data.empty())
        return -100;

    return 0;
}

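// run a single direction over all T timesteps
// hidden_state and cell_state are updated in place, so a caller can chain passes
// reverse != 0 walks the input sequence from the last row to the first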
static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

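    // one row of gate pre-activations I, F, O, G per output neuron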
    // 4 x num_output
    Mat gates(4, num_output, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;

    // unroll
    for (int t = 0; t < T; t++)
    {
        // clip hidden by continuation indicator
        // h_cont_{t-1} = cont_t * h_{t-1}
        // h_cont_{t-1} = h_{t-1} if cont_t == 1
        //                0       otherwise
        // calculate hidden
        // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c

        int ti = reverse ? T - 1 - t : t;

        const float* x = bottom_blob.row(ti);
        for (int q = 0; q < num_output; q++)
        {
            const float* bias_c_I = bias_c.row(0);
            const float* bias_c_F = bias_c.row(1);
            const float* bias_c_O = bias_c.row(2);
            const float* bias_c_G = bias_c.row(3);

            float* gates_data = gates.row(q);

            // gate I F O G
            const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
            const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
            const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
            const float* weight_xc_G = weight_xc.row(num_output * 3 + q);

            const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
            const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
            const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
            const float* weight_hc_G = weight_hc.row(num_output * 3 + q);

            float I = bias_c_I[q];
            float F = bias_c_F[q];
            float O = bias_c_O[q];
            float G = bias_c_G[q];

            for (int i = 0; i < size; i++)
            {
                float xi = x[i];

                I += weight_xc_I[i] * xi;
                F += weight_xc_F[i] * xi;
                O += weight_xc_O[i] * xi;
                G += weight_xc_G[i] * xi;
            }

            for (int i = 0; i < num_output; i++)
            {
                float h_cont = hidden_state[i];

                I += weight_hc_I[i] * h_cont;
                F += weight_hc_F[i] * h_cont;
                O += weight_hc_O[i] * h_cont;
                G += weight_hc_G[i] * h_cont;
            }

            gates_data[0] = I;
            gates_data[1] = F;
            gates_data[2] = O;
            gates_data[3] = G;
        }

        // lstm unit
        // sigmoid(I)
        // sigmoid(F)
        // sigmoid(O)
        // tanh(G)
        // c_t := f_t .* c_{t-1} + i_t .* g_t
        // h_t := o_t .* tanh[c_t]
        float* output_data = top_blob.row(ti);
        for (int q = 0; q < num_output; q++)
        {
            const float* gates_data = gates.row(q);

            float I = gates_data[0];
            float F = gates_data[1];
            float O = gates_data[2];
            float G = gates_data[3];

            I = 1.f / (1.f + exp(-I));
            F = 1.f / (1.f + exp(-F));
            O = 1.f / (1.f + exp(-O));
            G = tanh(G);

            float cell2 = F * cell_state[q] + I * G;
            float H = O * tanh(cell2);
            cell_state[q] = cell2;
            hidden_state[q] = H;
            output_data[q] = H;
        }
    }

    return 0;
}

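// single input / single output: bottom_blob is T x size, top_blob is T x (num_output * num_directions)
// hidden and cell states start at zero and are not exposed to the caller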
int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int T = bottom_blob.h;

    int num_directions = direction == 2 ? 2 : 1;

    // initial hidden state
    Mat hidden(num_output, 4u, opt.workspace_allocator);
    if (hidden.empty())
        return -100;
    hidden.fill(0.f);

    Mat cell(num_output, 4u, opt.workspace_allocator);
    if (cell.empty())
        return -100;
    cell.fill(0.f);

    top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
        if (ret != 0)
            return ret;
    }

    if (direction == 2)
    {
        Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_forward.empty())
            return -100;

        Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_reverse.empty())
            return -100;

        int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
        if (ret0 != 0)
            return ret0;

        hidden.fill(0.0f);
        cell.fill(0.0f);

        int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, cell, opt);
        if (ret1 != 0)
            return ret1;

        // concat w
        for (int i = 0; i < T; i++)
        {
            const float* pf = top_blob_forward.row(i);
            const float* pr = top_blob_reverse.row(i);
            float* ptr = top_blob.row(i);

            memcpy(ptr, pf, num_output * sizeof(float));
            memcpy(ptr + num_output, pr, num_output * sizeof(float));
        }
    }

    return 0;
}

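// variant with optional states
//   bottom_blobs = { input } or { input, initial hidden, initial cell }
//   top_blobs    = { output } or { output, final hidden, final cell }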
int LSTM::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
    const Mat& bottom_blob = bottom_blobs[0];
    int T = bottom_blob.h;
    int num_directions = direction == 2 ? 2 : 1;

    Mat hidden;
    Mat cell;
    Allocator* hidden_cell_allocator = top_blobs.size() == 3 ? opt.blob_allocator : opt.workspace_allocator;
    if (bottom_blobs.size() == 3)
    {
        hidden = bottom_blobs[1].clone(hidden_cell_allocator);
        cell = bottom_blobs[2].clone(hidden_cell_allocator);
    }
    else
    {
        hidden.create(num_output, num_directions, 4u, hidden_cell_allocator);
        if (hidden.empty())
            return -100;
        hidden.fill(0.f);

        cell.create(num_output, num_directions, 4u, hidden_cell_allocator);
        if (cell.empty())
            return -100;
        cell.fill(0.f);
    }

    Mat& top_blob = top_blobs[0];
    top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
        if (ret != 0)
            return ret;
    }

    if (direction == 2)
    {
        Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_forward.empty())
            return -100;

        Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_reverse.empty())
            return -100;

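        // row 0 of hidden/cell holds the forward direction states, row 1 the reverse direction states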
        Mat hidden0 = hidden.row_range(0, 1);
        Mat cell0 = cell.row_range(0, 1);
        int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, cell0, opt);
        if (ret0 != 0)
            return ret0;

        Mat hidden1 = hidden.row_range(1, 1);
        Mat cell1 = cell.row_range(1, 1);
        int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, cell1, opt);
        if (ret1 != 0)
            return ret1;

        // concat w
        for (int i = 0; i < T; i++)
        {
            const float* pf = top_blob_forward.row(i);
            const float* pr = top_blob_reverse.row(i);
            float* ptr = top_blob.row(i);

            memcpy(ptr, pf, num_output * sizeof(float));
            memcpy(ptr + num_output, pr, num_output * sizeof(float));
        }
    }

    if (top_blobs.size() == 3)
    {
        top_blobs[1] = hidden;
        top_blobs[2] = cell;
    }

    return 0;
}

} // namespace ncnn