1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
15 #include "lstm_x86.h"
16
17 #include "x86_activation.h"
18 #include "x86_usability.h"
19
20 #include <math.h>
21 #include "layer_type.h"
22
23 namespace ncnn {
24
LSTM_x86()25 LSTM_x86::LSTM_x86()
26 {
27 #ifdef __AVX__
28 support_weight_fp16_storage = true;
29 #endif
30 one_blob_only = false;
31 support_inplace = false;
32 }
33
create_pipeline(const Option & opt)34 int LSTM_x86::create_pipeline(const Option& opt)
35 {
36 #if __AVX__
37 if (opt.use_weight_fp16_storage)
38 {
39 ncnn::cast_float32_to_float16(weight_xc_data, weight_xc_data_fp16, opt);
40 ncnn::cast_float32_to_float16(weight_hc_data, weight_hc_data_fp16, opt);
41 }
42 #else
43 (void)(opt);
44 #endif // __AVX__
45
46 return 0;
47 }
48 #ifdef __AVX__
49
// Single-direction LSTM over a whole sequence using fp16-stored weights.
//
// weight_xc / weight_hc rows hold half-precision values reinterpreted as
// unsigned short; loadfp16() widens 8 of them into an fp32 __m256.
// Row layout of both weight mats (and of bias_c) is gate-major: 4 blocks
// of num_output rows in I, F, O, G order.
//
// bottom_blob  : T x size input sequence (fp32)
// top_blob     : T x num_output output, one row written per time step
// reverse      : nonzero walks the sequence from t = T-1 down to 0
// hidden_state : num_output running hidden vector, updated in place
// cell_state   : num_output running cell vector, updated in place
// Returns 0 on success, -100 if the gate scratch buffer cannot be allocated.
static int lstm_fp16(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;
    // fprintf(stderr, "bottom_blob = %d x %d x %d num_output = %d \n", bottom_blob.w,bottom_blob.h,bottom_blob.c,num_output);
    // 4 x num_output
    // Scratch holding the pre-activation gate values for one time step.
    Mat gates(num_output, 4, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;
    // unroll
    for (int t = 0; t < T; t++)
    {
        // clip hidden by continuation indicator
        // h_cont_{t-1} = cont_t * h_{t-1}
        // h_cont_{t-1} = h_{t-1} if cont_t == 1
        //                0       otherwise
        // calculate hidden
        // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c
        int ti = reverse ? T - 1 - t : t;
        int remain_output = (num_output >> 1) << 1;
        // Main loop: two output neurons per iteration so eight dot products
        // (4 gates x 2 outputs) can share one HorizontalSums reduction.
        for (int q = 0; q + 1 < num_output; q += 2)
        {
            const float* x = bottom_blob.row(ti);
            const float* hidden_ptr_r = hidden_state;
            const float* bias_c_I = bias_c.row(0);
            const float* bias_c_F = bias_c.row(1);
            const float* bias_c_O = bias_c.row(2);
            const float* bias_c_G = bias_c.row(3);

            float* gates_data_I = gates.row(0);
            float* gates_data_F = gates.row(1);
            float* gates_data_O = gates.row(2);
            float* gates_data_G = gates.row(3);
            // gate I F O G
            const unsigned short* weight_xc_I_0 = (const unsigned short*)weight_xc.row(num_output * 0 + q);
            const unsigned short* weight_xc_F_0 = (const unsigned short*)weight_xc.row(num_output * 1 + q);
            const unsigned short* weight_xc_O_0 = (const unsigned short*)weight_xc.row(num_output * 2 + q);
            const unsigned short* weight_xc_G_0 = (const unsigned short*)weight_xc.row(num_output * 3 + q);
            const unsigned short* weight_xc_I_1 = (const unsigned short*)weight_xc.row(num_output * 0 + (q + 1));
            const unsigned short* weight_xc_F_1 = (const unsigned short*)weight_xc.row(num_output * 1 + (q + 1));
            const unsigned short* weight_xc_O_1 = (const unsigned short*)weight_xc.row(num_output * 2 + (q + 1));
            const unsigned short* weight_xc_G_1 = (const unsigned short*)weight_xc.row(num_output * 3 + (q + 1));

            const unsigned short* weight_hc_I_0 = (const unsigned short*)weight_hc.row(num_output * 0 + q);
            const unsigned short* weight_hc_F_0 = (const unsigned short*)weight_hc.row(num_output * 1 + q);
            const unsigned short* weight_hc_O_0 = (const unsigned short*)weight_hc.row(num_output * 2 + q);
            const unsigned short* weight_hc_G_0 = (const unsigned short*)weight_hc.row(num_output * 3 + q);
            const unsigned short* weight_hc_I_1 = (const unsigned short*)weight_hc.row(num_output * 0 + (q + 1));
            const unsigned short* weight_hc_F_1 = (const unsigned short*)weight_hc.row(num_output * 1 + (q + 1));
            const unsigned short* weight_hc_O_1 = (const unsigned short*)weight_hc.row(num_output * 2 + (q + 1));
            const unsigned short* weight_hc_G_1 = (const unsigned short*)weight_hc.row(num_output * 3 + (q + 1));

            // float I = bias_c_I[q];
            // float F = bias_c_F[q];
            // float O = bias_c_O[q];
            // float G = bias_c_G[q];
            // One 8-lane accumulator per (gate, output) pair; lanes are
            // reduced to a scalar by HorizontalSums after both dot products.
            __m256 _sumI_0 = _mm256_setzero_ps();
            __m256 _sumF_0 = _mm256_setzero_ps();
            __m256 _sumO_0 = _mm256_setzero_ps();
            __m256 _sumG_0 = _mm256_setzero_ps();
            __m256 _sumI_1 = _mm256_setzero_ps();
            __m256 _sumF_1 = _mm256_setzero_ps();
            __m256 _sumO_1 = _mm256_setzero_ps();
            __m256 _sumG_1 = _mm256_setzero_ps();
            int nn_num_size = size >> 3;
            int remain_size = size & 7;
            // W_xc * x_t, 8 input elements per iteration.
            for (; nn_num_size > 0; nn_num_size--)
            {
                __m256 xi = _mm256_loadu_ps(x);
                _sumI_0 = _mm256_fmadd_ps(loadfp16(weight_xc_I_0), xi, _sumI_0);
                _sumF_0 = _mm256_fmadd_ps(loadfp16(weight_xc_F_0), xi, _sumF_0);
                _sumO_0 = _mm256_fmadd_ps(loadfp16(weight_xc_O_0), xi, _sumO_0);
                _sumG_0 = _mm256_fmadd_ps(loadfp16(weight_xc_G_0), xi, _sumG_0);
                _sumI_1 = _mm256_fmadd_ps(loadfp16(weight_xc_I_1), xi, _sumI_1);
                _sumF_1 = _mm256_fmadd_ps(loadfp16(weight_xc_F_1), xi, _sumF_1);
                _sumO_1 = _mm256_fmadd_ps(loadfp16(weight_xc_O_1), xi, _sumO_1);
                _sumG_1 = _mm256_fmadd_ps(loadfp16(weight_xc_G_1), xi, _sumG_1);
                x += 8;
                weight_xc_I_0 += 8;
                weight_xc_F_0 += 8;
                weight_xc_O_0 += 8;
                weight_xc_G_0 += 8;
                weight_xc_I_1 += 8;
                weight_xc_F_1 += 8;
                weight_xc_O_1 += 8;
                weight_xc_G_1 += 8;
            }
            int nn_num_output = num_output >> 3;
            int remain_num_output = num_output & 7;
            // W_hc * h_{t-1}, 8 hidden elements per iteration.
            for (; nn_num_output > 0; nn_num_output--)
            {
                __m256 h_cont = _mm256_loadu_ps(hidden_ptr_r);

                _sumI_0 = _mm256_fmadd_ps(loadfp16(weight_hc_I_0), h_cont, _sumI_0);
                _sumF_0 = _mm256_fmadd_ps(loadfp16(weight_hc_F_0), h_cont, _sumF_0);
                _sumO_0 = _mm256_fmadd_ps(loadfp16(weight_hc_O_0), h_cont, _sumO_0);
                _sumG_0 = _mm256_fmadd_ps(loadfp16(weight_hc_G_0), h_cont, _sumG_0);
                _sumI_1 = _mm256_fmadd_ps(loadfp16(weight_hc_I_1), h_cont, _sumI_1);
                _sumF_1 = _mm256_fmadd_ps(loadfp16(weight_hc_F_1), h_cont, _sumF_1);
                _sumO_1 = _mm256_fmadd_ps(loadfp16(weight_hc_O_1), h_cont, _sumO_1);
                _sumG_1 = _mm256_fmadd_ps(loadfp16(weight_hc_G_1), h_cont, _sumG_1);
                hidden_ptr_r += 8;
                weight_hc_I_0 += 8;
                weight_hc_F_0 += 8;
                weight_hc_O_0 += 8;
                weight_hc_G_0 += 8;
                weight_hc_I_1 += 8;
                weight_hc_F_1 += 8;
                weight_hc_O_1 += 8;
                weight_hc_G_1 += 8;
            }
            // Input-dimension tail (size not a multiple of 8): pack the
            // leftover fp16 weights and inputs into zero-padded 8-lane
            // buffers so the same vector FMA can finish the dot products.
            if (remain_size != 0)
            {
                unsigned short fp16_weights[8][8] = {{0}};
                float _xi_f[8] = {0};
                // No fast way to convert to fp32 one element at the time
                // so batch an 8 lane vector.
                for (int i = 0; i < remain_size; i++)
                {
                    _xi_f[i] = *x;
                    fp16_weights[0][i] = *weight_xc_I_0;
                    fp16_weights[1][i] = *weight_xc_F_0;
                    fp16_weights[2][i] = *weight_xc_O_0;
                    fp16_weights[3][i] = *weight_xc_G_0;
                    fp16_weights[4][i] = *weight_xc_I_1;
                    fp16_weights[5][i] = *weight_xc_F_1;
                    fp16_weights[6][i] = *weight_xc_O_1;
                    fp16_weights[7][i] = *weight_xc_G_1;
                    x++;
                    weight_xc_I_0++;
                    weight_xc_F_0++;
                    weight_xc_O_0++;
                    weight_xc_G_0++;
                    weight_xc_I_1++;
                    weight_xc_F_1++;
                    weight_xc_O_1++;
                    weight_xc_G_1++;
                }
                __m256 xi = _mm256_loadu_ps(_xi_f);
                _sumI_0 = _mm256_fmadd_ps(loadfp16(fp16_weights[0]), xi, _sumI_0);
                _sumF_0 = _mm256_fmadd_ps(loadfp16(fp16_weights[1]), xi, _sumF_0);
                _sumO_0 = _mm256_fmadd_ps(loadfp16(fp16_weights[2]), xi, _sumO_0);
                _sumG_0 = _mm256_fmadd_ps(loadfp16(fp16_weights[3]), xi, _sumG_0);
                _sumI_1 = _mm256_fmadd_ps(loadfp16(fp16_weights[4]), xi, _sumI_1);
                _sumF_1 = _mm256_fmadd_ps(loadfp16(fp16_weights[5]), xi, _sumF_1);
                _sumO_1 = _mm256_fmadd_ps(loadfp16(fp16_weights[6]), xi, _sumO_1);
                _sumG_1 = _mm256_fmadd_ps(loadfp16(fp16_weights[7]), xi, _sumG_1);
            }
            // Hidden-dimension tail, same zero-padding trick.
            if (remain_num_output != 0)
            {
                unsigned short fp16_weights[8][8] = {{0}};
                float _hcont_f[8] = {0};
                // No fast way to convert to fp32 one element at the time
                // so batch an 8 lane vector.
                for (int i = 0; i < remain_num_output; i++)
                {
                    _hcont_f[i] = *hidden_ptr_r;
                    fp16_weights[0][i] = *weight_hc_I_0;
                    fp16_weights[1][i] = *weight_hc_F_0;
                    fp16_weights[2][i] = *weight_hc_O_0;
                    fp16_weights[3][i] = *weight_hc_G_0;
                    fp16_weights[4][i] = *weight_hc_I_1;
                    fp16_weights[5][i] = *weight_hc_F_1;
                    fp16_weights[6][i] = *weight_hc_O_1;
                    fp16_weights[7][i] = *weight_hc_G_1;
                    hidden_ptr_r++;
                    weight_hc_I_0++;
                    weight_hc_F_0++;
                    weight_hc_O_0++;
                    weight_hc_G_0++;
                    weight_hc_I_1++;
                    weight_hc_F_1++;
                    weight_hc_O_1++;
                    weight_hc_G_1++;
                }
                __m256 h_cont = _mm256_loadu_ps(_hcont_f);
                _sumI_0 = _mm256_fmadd_ps(loadfp16(fp16_weights[0]), h_cont, _sumI_0);
                _sumF_0 = _mm256_fmadd_ps(loadfp16(fp16_weights[1]), h_cont, _sumF_0);
                _sumO_0 = _mm256_fmadd_ps(loadfp16(fp16_weights[2]), h_cont, _sumO_0);
                _sumG_0 = _mm256_fmadd_ps(loadfp16(fp16_weights[3]), h_cont, _sumG_0);
                _sumI_1 = _mm256_fmadd_ps(loadfp16(fp16_weights[4]), h_cont, _sumI_1);
                _sumF_1 = _mm256_fmadd_ps(loadfp16(fp16_weights[5]), h_cont, _sumF_1);
                _sumO_1 = _mm256_fmadd_ps(loadfp16(fp16_weights[6]), h_cont, _sumO_1);
                _sumG_1 = _mm256_fmadd_ps(loadfp16(fp16_weights[7]), h_cont, _sumG_1);
            }
            // Reduce the 8 accumulators to 8 scalars, add the biases and
            // scatter into the gate scratch rows.
            float sums[8];
            _mm256_storeu_ps(sums, HorizontalSums(_sumI_0, _sumF_0, _sumO_0, _sumG_0, _sumI_1, _sumF_1, _sumO_1, _sumG_1));
            sums[0] += bias_c_I[q];
            sums[1] += bias_c_F[q];
            sums[2] += bias_c_O[q];
            sums[3] += bias_c_G[q];
            sums[4] += bias_c_I[q + 1];
            sums[5] += bias_c_F[q + 1];
            sums[6] += bias_c_O[q + 1];
            sums[7] += bias_c_G[q + 1];
            gates_data_I[q] = sums[0];
            gates_data_F[q] = sums[1];
            gates_data_O[q] = sums[2];
            gates_data_G[q] = sums[3];
            gates_data_I[q + 1] = sums[4];
            gates_data_F[q + 1] = sums[5];
            gates_data_O[q + 1] = sums[6];
            gates_data_G[q + 1] = sums[7];
        }

        // Last output neuron when num_output is odd: same computation with
        // only one set of accumulators.
        for (int q = remain_output; q < num_output; q++)
        {
            const float* x = bottom_blob.row(ti);
            const float* hidden_ptr_r = hidden_state;
            const float* bias_c_I = bias_c.row(0);
            const float* bias_c_F = bias_c.row(1);
            const float* bias_c_O = bias_c.row(2);
            const float* bias_c_G = bias_c.row(3);

            float* gates_data_I = gates.row(0);
            float* gates_data_F = gates.row(1);
            float* gates_data_O = gates.row(2);
            float* gates_data_G = gates.row(3);
            // gate I F O G
            const unsigned short* weight_xc_I = (const unsigned short*)weight_xc.row(num_output * 0 + q);
            const unsigned short* weight_xc_F = (const unsigned short*)weight_xc.row(num_output * 1 + q);
            const unsigned short* weight_xc_O = (const unsigned short*)weight_xc.row(num_output * 2 + q);
            const unsigned short* weight_xc_G = (const unsigned short*)weight_xc.row(num_output * 3 + q);

            const unsigned short* weight_hc_I = (const unsigned short*)weight_hc.row(num_output * 0 + q);
            const unsigned short* weight_hc_F = (const unsigned short*)weight_hc.row(num_output * 1 + q);
            const unsigned short* weight_hc_O = (const unsigned short*)weight_hc.row(num_output * 2 + q);
            const unsigned short* weight_hc_G = (const unsigned short*)weight_hc.row(num_output * 3 + q);

            // float I = bias_c_I[q];
            // float F = bias_c_F[q];
            // float O = bias_c_O[q];
            // float G = bias_c_G[q];
            __m256 _sumI = _mm256_setzero_ps();
            __m256 _sumF = _mm256_setzero_ps();
            __m256 _sumO = _mm256_setzero_ps();
            __m256 _sumG = _mm256_setzero_ps();
            int nn_num_size = size >> 3;
            int remain_size = size & 7;
            for (; nn_num_size > 0; nn_num_size--)
            {
                __m256 xi = _mm256_loadu_ps(x);
                _sumI = _mm256_fmadd_ps(loadfp16(weight_xc_I), xi, _sumI);
                _sumF = _mm256_fmadd_ps(loadfp16(weight_xc_F), xi, _sumF);
                _sumO = _mm256_fmadd_ps(loadfp16(weight_xc_O), xi, _sumO);
                _sumG = _mm256_fmadd_ps(loadfp16(weight_xc_G), xi, _sumG);
                x += 8;
                weight_xc_I += 8;
                weight_xc_F += 8;
                weight_xc_O += 8;
                weight_xc_G += 8;
            }
            int nn_num_output = num_output >> 3;
            int remain_num_output = num_output & 7;
            for (; nn_num_output > 0; nn_num_output--)
            {
                __m256 h_cont = _mm256_loadu_ps(hidden_ptr_r);

                _sumI = _mm256_fmadd_ps(loadfp16(weight_hc_I), h_cont, _sumI);
                _sumF = _mm256_fmadd_ps(loadfp16(weight_hc_F), h_cont, _sumF);
                _sumO = _mm256_fmadd_ps(loadfp16(weight_hc_O), h_cont, _sumO);
                _sumG = _mm256_fmadd_ps(loadfp16(weight_hc_G), h_cont, _sumG);
                hidden_ptr_r += 8;
                weight_hc_I += 8;
                weight_hc_F += 8;
                weight_hc_O += 8;
                weight_hc_G += 8;
            }
            // Input-dimension tail, zero-padded into one vector FMA.
            if (remain_size != 0)
            {
                unsigned short fp16_weights[4][8] = {{0}};
                float _xi_f[8] = {0};
                // No fast way to convert to fp32 one element at the time
                // so batch an 8 lane vector.
                for (int i = 0; i < remain_size; i++)
                {
                    _xi_f[i] = *x;
                    fp16_weights[0][i] = *weight_xc_I;
                    fp16_weights[1][i] = *weight_xc_F;
                    fp16_weights[2][i] = *weight_xc_O;
                    fp16_weights[3][i] = *weight_xc_G;
                    x++;
                    weight_xc_I++;
                    weight_xc_F++;
                    weight_xc_O++;
                    weight_xc_G++;
                }
                __m256 xi = _mm256_loadu_ps(_xi_f);
                _sumI = _mm256_fmadd_ps(loadfp16(fp16_weights[0]), xi, _sumI);
                _sumF = _mm256_fmadd_ps(loadfp16(fp16_weights[1]), xi, _sumF);
                _sumO = _mm256_fmadd_ps(loadfp16(fp16_weights[2]), xi, _sumO);
                _sumG = _mm256_fmadd_ps(loadfp16(fp16_weights[3]), xi, _sumG);
            }
            // Hidden-dimension tail, zero-padded into one vector FMA.
            if (remain_num_output != 0)
            {
                unsigned short fp16_weights[4][8] = {{0}};
                float _hcont_f[8] = {0};
                // No fast way to convert to fp32 one element at the time
                // so batch an 8 lane vector.
                for (int i = 0; i < remain_num_output; i++)
                {
                    _hcont_f[i] = *hidden_ptr_r;
                    fp16_weights[0][i] = *weight_hc_I;
                    fp16_weights[1][i] = *weight_hc_F;
                    fp16_weights[2][i] = *weight_hc_O;
                    fp16_weights[3][i] = *weight_hc_G;
                    hidden_ptr_r++;
                    weight_hc_I++;
                    weight_hc_F++;
                    weight_hc_O++;
                    weight_hc_G++;
                }
                __m256 h_cont = _mm256_loadu_ps(_hcont_f);
                _sumI = _mm256_fmadd_ps(loadfp16(fp16_weights[0]), h_cont, _sumI);
                _sumF = _mm256_fmadd_ps(loadfp16(fp16_weights[1]), h_cont, _sumF);
                _sumO = _mm256_fmadd_ps(loadfp16(fp16_weights[2]), h_cont, _sumO);
                _sumG = _mm256_fmadd_ps(loadfp16(fp16_weights[3]), h_cont, _sumG);
            }

            float sums[4];
            _mm_storeu_ps(sums, HorizontalSums(_sumI, _sumF, _sumO, _sumG));
            sums[0] += bias_c_I[q];
            sums[1] += bias_c_F[q];
            sums[2] += bias_c_O[q];
            sums[3] += bias_c_G[q];
            gates_data_I[q] = sums[0];
            gates_data_F[q] = sums[1];
            gates_data_O[q] = sums[2];
            gates_data_G[q] = sums[3];
        }

        // lstm unit
        // sigmoid(I)
        // sigmoid(F)
        // sigmoid(O)
        // tanh(G)
        // c_t := f_t .* c_{t-1} + i_t .* g_t
        // h_t := o_t .* tanh[c_t]
        // Apply activations and update cell/hidden state; the new hidden
        // vector is also written to this time step's output row.
        float* output_data = top_blob.row(ti);
        float* cell_ptr = cell_state;
        float* hidden_ptr = hidden_state;
        const float* gates_data_I = gates.row(0);
        const float* gates_data_F = gates.row(1);
        const float* gates_data_O = gates.row(2);
        const float* gates_data_G = gates.row(3);
        int nn_activation = num_output >> 3;
        int remain_activations = num_output & 7;
        for (; nn_activation > 0; nn_activation--)
        {
            __m256 I = sigmoid_avx(_mm256_loadu_ps(gates_data_I));
            __m256 F = sigmoid_avx(_mm256_loadu_ps(gates_data_F));
            __m256 O = sigmoid_avx(_mm256_loadu_ps(gates_data_O));
            __m256 G = tanh_avx(_mm256_loadu_ps(gates_data_G));
            __m256 cell2 = _mm256_add_ps(_mm256_mul_ps(F, _mm256_loadu_ps(cell_ptr)), _mm256_mul_ps(I, G));
            __m256 H = _mm256_mul_ps(O, tanh_avx(cell2));
            _mm256_storeu_ps(cell_ptr, cell2);
            _mm256_storeu_ps(hidden_ptr, H);
            _mm256_storeu_ps(output_data, H);
            cell_ptr += 8;
            output_data += 8;
            hidden_ptr += 8;
            gates_data_I += 8;
            gates_data_F += 8;
            gates_data_O += 8;
            gates_data_G += 8;
        }
        // Scalar tail of the activation/update pass.
        for (; remain_activations > 0; remain_activations--)
        {
            float I = *gates_data_I;
            float F = *gates_data_F;
            float O = *gates_data_O;
            float G = *gates_data_G;

            I = 1.f / (1.f + exp(-I));
            F = 1.f / (1.f + exp(-F));
            O = 1.f / (1.f + exp(-O));
            G = tanh(G);
            float cell2 = F * *cell_ptr + I * G;
            float H = O * tanh(cell2);
            *cell_ptr = cell2;
            *hidden_ptr = H;
            *output_data = H;
            cell_ptr++;
            output_data++;
            hidden_ptr++;
            gates_data_I++;
            gates_data_F++;
            gates_data_O++;
            gates_data_G++;
        }

        // no cell output here
    }

    return 0;
}
448
// Single-direction LSTM over a whole sequence using fp32 weights.
//
// Same structure as lstm_fp16 above, but weights are loaded directly with
// _mm256_loadu_ps and the non-multiple-of-8 tails are handled with plain
// scalar multiply-adds after the horizontal reduction instead of padded
// vectors. Row layout of weight_xc / weight_hc (and bias_c) is gate-major:
// 4 blocks of num_output rows in I, F, O, G order.
//
// bottom_blob  : T x size input sequence (fp32)
// top_blob     : T x num_output output, one row written per time step
// reverse      : nonzero walks the sequence from t = T-1 down to 0
// hidden_state : num_output running hidden vector, updated in place
// cell_state   : num_output running cell vector, updated in place
// Returns 0 on success, -100 if the gate scratch buffer cannot be allocated.
static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

    // 4 x num_output
    // Scratch holding the pre-activation gate values for one time step.
    Mat gates(num_output, 4, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;

    // unroll
    for (int t = 0; t < T; t++)
    {
        // clip hidden by continuation indicator
        // h_cont_{t-1} = cont_t * h_{t-1}
        // h_cont_{t-1} = h_{t-1} if cont_t == 1
        //                0       otherwise
        // calculate hidden
        // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c

        int ti = reverse ? T - 1 - t : t;
        int remain_output = (num_output >> 1) << 1;
        // Main loop: two output neurons per iteration so eight dot products
        // (4 gates x 2 outputs) can share one HorizontalSums reduction.
        for (int q = 0; q + 1 < num_output; q += 2)
        {
            const float* x = bottom_blob.row(ti);
            const float* hidden_ptr_r = hidden_state;
            const float* bias_c_I = bias_c.row(0);
            const float* bias_c_F = bias_c.row(1);
            const float* bias_c_O = bias_c.row(2);
            const float* bias_c_G = bias_c.row(3);

            float* gates_data_I = gates.row(0);
            float* gates_data_F = gates.row(1);
            float* gates_data_O = gates.row(2);
            float* gates_data_G = gates.row(3);
            // gate I F O G
            const float* weight_xc_I_0 = weight_xc.row(num_output * 0 + q);
            const float* weight_xc_F_0 = weight_xc.row(num_output * 1 + q);
            const float* weight_xc_O_0 = weight_xc.row(num_output * 2 + q);
            const float* weight_xc_G_0 = weight_xc.row(num_output * 3 + q);
            const float* weight_xc_I_1 = weight_xc.row(num_output * 0 + (q + 1));
            const float* weight_xc_F_1 = weight_xc.row(num_output * 1 + (q + 1));
            const float* weight_xc_O_1 = weight_xc.row(num_output * 2 + (q + 1));
            const float* weight_xc_G_1 = weight_xc.row(num_output * 3 + (q + 1));

            const float* weight_hc_I_0 = weight_hc.row(num_output * 0 + q);
            const float* weight_hc_F_0 = weight_hc.row(num_output * 1 + q);
            const float* weight_hc_O_0 = weight_hc.row(num_output * 2 + q);
            const float* weight_hc_G_0 = weight_hc.row(num_output * 3 + q);
            const float* weight_hc_I_1 = weight_hc.row(num_output * 0 + (q + 1));
            const float* weight_hc_F_1 = weight_hc.row(num_output * 1 + (q + 1));
            const float* weight_hc_O_1 = weight_hc.row(num_output * 2 + (q + 1));
            const float* weight_hc_G_1 = weight_hc.row(num_output * 3 + (q + 1));

            // float I = bias_c_I[q];
            // float F = bias_c_F[q];
            // float O = bias_c_O[q];
            // float G = bias_c_G[q];
            // One 8-lane accumulator per (gate, output) pair.
            __m256 _sumI_0 = _mm256_setzero_ps();
            __m256 _sumF_0 = _mm256_setzero_ps();
            __m256 _sumO_0 = _mm256_setzero_ps();
            __m256 _sumG_0 = _mm256_setzero_ps();
            __m256 _sumI_1 = _mm256_setzero_ps();
            __m256 _sumF_1 = _mm256_setzero_ps();
            __m256 _sumO_1 = _mm256_setzero_ps();
            __m256 _sumG_1 = _mm256_setzero_ps();
            int nn_num_size = size >> 3;
            int remain_size = size & 7;
            // W_xc * x_t, 8 input elements per iteration.
            for (; nn_num_size > 0; nn_num_size--)
            {
                __m256 xi = _mm256_loadu_ps(x);
                _sumI_0 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_I_0), xi, _sumI_0);
                _sumF_0 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_F_0), xi, _sumF_0);
                _sumO_0 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_O_0), xi, _sumO_0);
                _sumG_0 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_G_0), xi, _sumG_0);
                _sumI_1 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_I_1), xi, _sumI_1);
                _sumF_1 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_F_1), xi, _sumF_1);
                _sumO_1 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_O_1), xi, _sumO_1);
                _sumG_1 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_G_1), xi, _sumG_1);
                x += 8;
                weight_xc_I_0 += 8;
                weight_xc_F_0 += 8;
                weight_xc_O_0 += 8;
                weight_xc_G_0 += 8;
                weight_xc_I_1 += 8;
                weight_xc_F_1 += 8;
                weight_xc_O_1 += 8;
                weight_xc_G_1 += 8;
            }
            int nn_num_output = num_output >> 3;
            int remain_num_output = num_output & 7;
            // W_hc * h_{t-1}, 8 hidden elements per iteration.
            for (; nn_num_output > 0; nn_num_output--)
            {
                __m256 h_cont = _mm256_loadu_ps(hidden_ptr_r);

                _sumI_0 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_I_0), h_cont, _sumI_0);
                _sumF_0 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_F_0), h_cont, _sumF_0);
                _sumO_0 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_O_0), h_cont, _sumO_0);
                _sumG_0 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_G_0), h_cont, _sumG_0);
                _sumI_1 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_I_1), h_cont, _sumI_1);
                _sumF_1 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_F_1), h_cont, _sumF_1);
                _sumO_1 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_O_1), h_cont, _sumO_1);
                _sumG_1 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_G_1), h_cont, _sumG_1);
                hidden_ptr_r += 8;
                weight_hc_I_0 += 8;
                weight_hc_F_0 += 8;
                weight_hc_O_0 += 8;
                weight_hc_G_0 += 8;
                weight_hc_I_1 += 8;
                weight_hc_F_1 += 8;
                weight_hc_O_1 += 8;
                weight_hc_G_1 += 8;
            }
            // Reduce the vector accumulators first, then finish the
            // non-multiple-of-8 tails with scalar multiply-adds below.
            float sums[8];
            _mm256_storeu_ps(sums, HorizontalSums(_sumI_0, _sumF_0, _sumO_0, _sumG_0, _sumI_1, _sumF_1, _sumO_1, _sumG_1));
            sums[0] += bias_c_I[q];
            sums[1] += bias_c_F[q];
            sums[2] += bias_c_O[q];
            sums[3] += bias_c_G[q];
            sums[4] += bias_c_I[q + 1];
            sums[5] += bias_c_F[q + 1];
            sums[6] += bias_c_O[q + 1];
            sums[7] += bias_c_G[q + 1];

            // Scalar tail over the input dimension.
            for (; remain_size > 0; remain_size--)
            {
                float xi = *x;
                sums[0] += *weight_xc_I_0 * xi;
                sums[1] += *weight_xc_F_0 * xi;
                sums[2] += *weight_xc_O_0 * xi;
                sums[3] += *weight_xc_G_0 * xi;
                sums[4] += *weight_xc_I_1 * xi;
                sums[5] += *weight_xc_F_1 * xi;
                sums[6] += *weight_xc_O_1 * xi;
                sums[7] += *weight_xc_G_1 * xi;
                x++;
                weight_xc_I_0++;
                weight_xc_F_0++;
                weight_xc_O_0++;
                weight_xc_G_0++;
                weight_xc_I_1++;
                weight_xc_F_1++;
                weight_xc_O_1++;
                weight_xc_G_1++;
            }

            // Scalar tail over the hidden dimension.
            for (; remain_num_output > 0; remain_num_output--)
            {
                float h_cont = *hidden_ptr_r;
                sums[0] += *weight_hc_I_0 * h_cont;
                sums[1] += *weight_hc_F_0 * h_cont;
                sums[2] += *weight_hc_O_0 * h_cont;
                sums[3] += *weight_hc_G_0 * h_cont;
                sums[4] += *weight_hc_I_1 * h_cont;
                sums[5] += *weight_hc_F_1 * h_cont;
                sums[6] += *weight_hc_O_1 * h_cont;
                sums[7] += *weight_hc_G_1 * h_cont;
                hidden_ptr_r++;
                weight_hc_I_0++;
                weight_hc_F_0++;
                weight_hc_O_0++;
                weight_hc_G_0++;
                weight_hc_I_1++;
                weight_hc_F_1++;
                weight_hc_O_1++;
                weight_hc_G_1++;
            }
            gates_data_I[q] = sums[0];
            gates_data_F[q] = sums[1];
            gates_data_O[q] = sums[2];
            gates_data_G[q] = sums[3];
            gates_data_I[q + 1] = sums[4];
            gates_data_F[q + 1] = sums[5];
            gates_data_O[q + 1] = sums[6];
            gates_data_G[q + 1] = sums[7];
        }

        // Last output neuron when num_output is odd: same computation with
        // only one set of accumulators.
        for (int q = remain_output; q < num_output; q++)
        {
            const float* x = bottom_blob.row(ti);
            const float* hidden_ptr_r = hidden_state;
            const float* bias_c_I = bias_c.row(0);
            const float* bias_c_F = bias_c.row(1);
            const float* bias_c_O = bias_c.row(2);
            const float* bias_c_G = bias_c.row(3);

            float* gates_data_I = gates.row(0);
            float* gates_data_F = gates.row(1);
            float* gates_data_O = gates.row(2);
            float* gates_data_G = gates.row(3);
            // gate I F O G
            const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
            const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
            const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
            const float* weight_xc_G = weight_xc.row(num_output * 3 + q);

            const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
            const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
            const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
            const float* weight_hc_G = weight_hc.row(num_output * 3 + q);

            // float I = bias_c_I[q];
            // float F = bias_c_F[q];
            // float O = bias_c_O[q];
            // float G = bias_c_G[q];
            __m256 _sumI = _mm256_setzero_ps();
            __m256 _sumF = _mm256_setzero_ps();
            __m256 _sumO = _mm256_setzero_ps();
            __m256 _sumG = _mm256_setzero_ps();
            int nn_num_size = size >> 3;
            int remain_size = size & 7;
            for (; nn_num_size > 0; nn_num_size--)
            {
                __m256 xi = _mm256_loadu_ps(x);
                _sumI = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_I), xi, _sumI);
                _sumF = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_F), xi, _sumF);
                _sumO = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_O), xi, _sumO);
                _sumG = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_G), xi, _sumG);
                x += 8;
                weight_xc_I += 8;
                weight_xc_F += 8;
                weight_xc_O += 8;
                weight_xc_G += 8;
            }
            int nn_num_output = num_output >> 3;
            int remain_num_output = num_output & 7;
            for (; nn_num_output > 0; nn_num_output--)
            {
                __m256 h_cont = _mm256_loadu_ps(hidden_ptr_r);

                _sumI = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_I), h_cont, _sumI);
                _sumF = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_F), h_cont, _sumF);
                _sumO = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_O), h_cont, _sumO);
                _sumG = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_G), h_cont, _sumG);
                hidden_ptr_r += 8;
                weight_hc_I += 8;
                weight_hc_F += 8;
                weight_hc_O += 8;
                weight_hc_G += 8;
            }
            float sums[4];
            _mm_storeu_ps(sums, HorizontalSums(_sumI, _sumF, _sumO, _sumG));
            sums[0] += bias_c_I[q];
            sums[1] += bias_c_F[q];
            sums[2] += bias_c_O[q];
            sums[3] += bias_c_G[q];

            // Scalar tail over the input dimension.
            for (; remain_size > 0; remain_size--)
            {
                float xi = *x;
                sums[0] += *weight_xc_I * xi;
                sums[1] += *weight_xc_F * xi;
                sums[2] += *weight_xc_O * xi;
                sums[3] += *weight_xc_G * xi;
                x++;
                weight_xc_I++;
                weight_xc_F++;
                weight_xc_O++;
                weight_xc_G++;
            }

            // Scalar tail over the hidden dimension.
            for (; remain_num_output > 0; remain_num_output--)
            {
                float h_cont = *hidden_ptr_r;
                sums[0] += *weight_hc_I * h_cont;
                sums[1] += *weight_hc_F * h_cont;
                sums[2] += *weight_hc_O * h_cont;
                sums[3] += *weight_hc_G * h_cont;
                hidden_ptr_r++;
                weight_hc_I++;
                weight_hc_F++;
                weight_hc_O++;
                weight_hc_G++;
            }
            gates_data_I[q] = sums[0];
            gates_data_F[q] = sums[1];
            gates_data_O[q] = sums[2];
            gates_data_G[q] = sums[3];
        }

        // lstm unit
        // sigmoid(I)
        // sigmoid(F)
        // sigmoid(O)
        // tanh(G)
        // c_t := f_t .* c_{t-1} + i_t .* g_t
        // h_t := o_t .* tanh[c_t]
        // Apply activations and update cell/hidden state; the new hidden
        // vector is also written to this time step's output row.
        float* output_data = top_blob.row(ti);
        float* cell_ptr = cell_state;
        float* hidden_ptr = hidden_state;
        const float* gates_data_I = gates.row(0);
        const float* gates_data_F = gates.row(1);
        const float* gates_data_O = gates.row(2);
        const float* gates_data_G = gates.row(3);
        int nn_activation = num_output >> 3;
        int remain_activations = num_output & 7;
        for (; nn_activation > 0; nn_activation--)
        {
            __m256 I = sigmoid_avx(_mm256_loadu_ps(gates_data_I));
            __m256 F = sigmoid_avx(_mm256_loadu_ps(gates_data_F));
            __m256 O = sigmoid_avx(_mm256_loadu_ps(gates_data_O));
            __m256 G = tanh_avx(_mm256_loadu_ps(gates_data_G));
            __m256 cell2 = _mm256_add_ps(_mm256_mul_ps(F, _mm256_loadu_ps(cell_ptr)), _mm256_mul_ps(I, G));
            __m256 H = _mm256_mul_ps(O, tanh_avx(cell2));
            _mm256_storeu_ps(cell_ptr, cell2);
            _mm256_storeu_ps(hidden_ptr, H);
            _mm256_storeu_ps(output_data, H);
            cell_ptr += 8;
            output_data += 8;
            hidden_ptr += 8;
            gates_data_I += 8;
            gates_data_F += 8;
            gates_data_O += 8;
            gates_data_G += 8;
        }
        // Scalar tail of the activation/update pass.
        for (; remain_activations > 0; remain_activations--)
        {
            float I = *gates_data_I;
            float F = *gates_data_F;
            float O = *gates_data_O;
            float G = *gates_data_G;

            I = 1.f / (1.f + exp(-I));
            F = 1.f / (1.f + exp(-F));
            O = 1.f / (1.f + exp(-O));
            G = tanh(G);
            float cell2 = F * *cell_ptr + I * G;
            float H = O * tanh(cell2);
            *cell_ptr = cell2;
            *hidden_ptr = H;
            *output_data = H;
            cell_ptr++;
            output_data++;
            hidden_ptr++;
            gates_data_I++;
            gates_data_F++;
            gates_data_O++;
            gates_data_G++;
        }

        // no cell output here
    }

    return 0;
}
796 #endif
797
forward(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const798 int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
799 {
800 #if __AVX__
801 int T = bottom_blob.h;
802 int num_directions = direction == 2 ? 2 : 1;
803
804 // initial hidden state
805 Mat hidden(num_output, 4u, opt.workspace_allocator);
806 if (hidden.empty())
807 return -100;
808 hidden.fill(0.f);
809 // internal cell state
810 Mat cell(num_output, 4u, opt.workspace_allocator);
811 if (cell.empty())
812 return -100;
813 cell.fill(0.f);
814
815 top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
816 if (top_blob.empty())
817 return -100;
818
819 // Uni directional
820 if (direction == 0 || direction == 1)
821 {
822 if (opt.use_weight_fp16_storage)
823 {
824 // Uni directional
825 int ret = lstm_fp16(bottom_blob, top_blob, direction, weight_xc_data_fp16.channel(0), bias_c_data.channel(0), weight_hc_data_fp16.channel(0), hidden, cell, opt);
826 if (ret != 0)
827 return ret;
828 }
829 else
830 {
831 // Uni directional
832 int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
833 if (ret != 0)
834 return ret;
835 }
836 }
837
838 if (direction == 2)
839 {
840 Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
841 if (top_blob_forward.empty())
842 return -100;
843
844 Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
845 if (top_blob_reverse.empty())
846 return -100;
847
848 if (opt.use_weight_fp16_storage)
849 {
850 // Uni directional
851 int ret0 = lstm_fp16(bottom_blob, top_blob_forward, 0, weight_xc_data_fp16.channel(0), bias_c_data.channel(0), weight_hc_data_fp16.channel(0), hidden, cell, opt);
852 if (ret0 != 0)
853 return ret0;
854 }
855 else
856 {
857 // Uni directional
858 int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
859 if (ret0 != 0)
860 return ret0;
861 }
862
863 hidden.fill(0.0f);
864 cell.fill(0.0f);
865 if (opt.use_weight_fp16_storage)
866 {
867 // Uni directional
868 int ret1 = lstm_fp16(bottom_blob, top_blob_reverse, 1, weight_xc_data_fp16.channel(1), bias_c_data.channel(1), weight_hc_data_fp16.channel(1), hidden, cell, opt);
869 if (ret1 != 0)
870 return ret1;
871 }
872 else
873 {
874 // Uni directional
875 int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, cell, opt);
876 if (ret1 != 0)
877 return ret1;
878 }
879
880 // concat w
881 for (int i = 0; i < T; i++)
882 {
883 const float* pf = top_blob_forward.row(i);
884 const float* pr = top_blob_reverse.row(i);
885 float* ptr = top_blob.row(i);
886
887 memcpy(ptr, pf, num_output * sizeof(float));
888 memcpy(ptr + num_output, pr, num_output * sizeof(float));
889 }
890 }
891
892 return 0;
893 #else
894 return LSTM::forward(bottom_blob, top_blob, opt);
895 #endif
896 }
897
forward(const std::vector<Mat> & bottom_blobs,std::vector<Mat> & top_blobs,const Option & opt) const898 int LSTM_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
899 {
900 #if __AVX__
901 if (bottom_blobs.size() != 3 || top_blobs.size() != 3)
902 {
903 return forward(bottom_blobs[0], top_blobs[0], opt);
904 }
905 const Mat& bottom_blob = bottom_blobs[0];
906
907 int T = bottom_blob.h;
908 Mat& top_blob = top_blobs[0];
909 Mat& hidden_state = top_blobs[1];
910 Mat& cell_state = top_blobs[2];
911
912 //Copy previous states
913 hidden_state = bottom_blobs[1].clone(opt.blob_allocator);
914 cell_state = bottom_blobs[2].clone(opt.blob_allocator);
915
916 top_blob.create(num_output, T, 4u, opt.blob_allocator);
917 if (top_blob.empty())
918 return -100;
919
920 if (opt.use_weight_fp16_storage)
921 {
922 // Uni directional
923 int ret = lstm_fp16(bottom_blob, top_blob, direction, weight_xc_data_fp16.channel(0), bias_c_data.channel(0), weight_hc_data_fp16.channel(0), hidden_state, cell_state, opt);
924 if (ret != 0)
925 return ret;
926 }
927 else
928 {
929 // Uni directional
930 int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden_state, cell_state, opt);
931 if (ret != 0)
932 return ret;
933 }
934 return 0;
935 #else
936 return LSTM::forward(bottom_blobs, top_blobs, opt);
937 #endif
938 }
939
940 } // namespace ncnn
941