// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "lstm_x86.h"

#include "x86_activation.h"
#include "x86_usability.h"

#include <math.h>
#include "layer_type.h"

namespace ncnn {

LSTM_x86::LSTM_x86()
{
#ifdef __AVX__
    support_weight_fp16_storage = true;
#endif
    one_blob_only = false;
    support_inplace = false;
}

int LSTM_x86::create_pipeline(const Option& opt)
{
#if __AVX__
    if (opt.use_weight_fp16_storage)
    {
        ncnn::cast_float32_to_float16(weight_xc_data, weight_xc_data_fp16, opt);
        ncnn::cast_float32_to_float16(weight_hc_data, weight_hc_data_fp16, opt);
    }
#else
    (void)(opt);
#endif // __AVX__

    return 0;
}
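
// With opt.use_weight_fp16_storage the weights are converted once to fp16 to halve
// weight memory; the lstm_fp16() path below widens them back to fp32 eight lanes
// at a time via loadfp16() (presumably the helper from x86_usability.h) right
// before the FMA accumulation.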
#ifdef __AVX__

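// One direction over the whole sequence, with fp16-stored weights.
// weight_xc and weight_hc each hold 4 * num_output rows in gate order I, F, O, G:
// a weight_xc row carries `size` taps against x_t, a weight_hc row carries
// num_output taps against the previous hidden state. `reverse` flips the
// time-step order so the same routine can serve the backward half of a
// bidirectional LSTM.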
static int lstm_fp16(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;
    // fprintf(stderr, "bottom_blob = %d x %d x %d num_output = %d \n", bottom_blob.w,bottom_blob.h,bottom_blob.c,num_output);
    // 4 x num_output
    Mat gates(num_output, 4, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;
    // unroll
    for (int t = 0; t < T; t++)
    {
        // clip hidden by continuation indicator
        // h_cont_{t-1} = cont_t * h_{t-1}
        // h_cont_{t-1} = h_{t-1} if cont_t == 1
        //                0       otherwise
        // calculate hidden
        // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c
        int ti = reverse ? T - 1 - t : t;
        int remain_output = (num_output >> 1) << 1;
        for (int q = 0; q + 1 < num_output; q += 2)
        {
            const float* x = bottom_blob.row(ti);
            const float* hidden_ptr_r = hidden_state;
            const float* bias_c_I = bias_c.row(0);
            const float* bias_c_F = bias_c.row(1);
            const float* bias_c_O = bias_c.row(2);
            const float* bias_c_G = bias_c.row(3);

            float* gates_data_I = gates.row(0);
            float* gates_data_F = gates.row(1);
            float* gates_data_O = gates.row(2);
            float* gates_data_G = gates.row(3);
            // gate I F O G
            const unsigned short* weight_xc_I_0 = (const unsigned short*)weight_xc.row(num_output * 0 + q);
            const unsigned short* weight_xc_F_0 = (const unsigned short*)weight_xc.row(num_output * 1 + q);
            const unsigned short* weight_xc_O_0 = (const unsigned short*)weight_xc.row(num_output * 2 + q);
            const unsigned short* weight_xc_G_0 = (const unsigned short*)weight_xc.row(num_output * 3 + q);
            const unsigned short* weight_xc_I_1 = (const unsigned short*)weight_xc.row(num_output * 0 + (q + 1));
            const unsigned short* weight_xc_F_1 = (const unsigned short*)weight_xc.row(num_output * 1 + (q + 1));
            const unsigned short* weight_xc_O_1 = (const unsigned short*)weight_xc.row(num_output * 2 + (q + 1));
            const unsigned short* weight_xc_G_1 = (const unsigned short*)weight_xc.row(num_output * 3 + (q + 1));

            const unsigned short* weight_hc_I_0 = (const unsigned short*)weight_hc.row(num_output * 0 + q);
            const unsigned short* weight_hc_F_0 = (const unsigned short*)weight_hc.row(num_output * 1 + q);
            const unsigned short* weight_hc_O_0 = (const unsigned short*)weight_hc.row(num_output * 2 + q);
            const unsigned short* weight_hc_G_0 = (const unsigned short*)weight_hc.row(num_output * 3 + q);
            const unsigned short* weight_hc_I_1 = (const unsigned short*)weight_hc.row(num_output * 0 + (q + 1));
            const unsigned short* weight_hc_F_1 = (const unsigned short*)weight_hc.row(num_output * 1 + (q + 1));
            const unsigned short* weight_hc_O_1 = (const unsigned short*)weight_hc.row(num_output * 2 + (q + 1));
            const unsigned short* weight_hc_G_1 = (const unsigned short*)weight_hc.row(num_output * 3 + (q + 1));

            // float I = bias_c_I[q];
            // float F = bias_c_F[q];
            // float O = bias_c_O[q];
            // float G = bias_c_G[q];
            __m256 _sumI_0 = _mm256_setzero_ps();
            __m256 _sumF_0 = _mm256_setzero_ps();
            __m256 _sumO_0 = _mm256_setzero_ps();
            __m256 _sumG_0 = _mm256_setzero_ps();
            __m256 _sumI_1 = _mm256_setzero_ps();
            __m256 _sumF_1 = _mm256_setzero_ps();
            __m256 _sumO_1 = _mm256_setzero_ps();
            __m256 _sumG_1 = _mm256_setzero_ps();
            int nn_num_size = size >> 3;
            int remain_size = size & 7;
            for (; nn_num_size > 0; nn_num_size--)
            {
                __m256 xi = _mm256_loadu_ps(x);
                _sumI_0 = _mm256_fmadd_ps(loadfp16(weight_xc_I_0), xi, _sumI_0);
                _sumF_0 = _mm256_fmadd_ps(loadfp16(weight_xc_F_0), xi, _sumF_0);
                _sumO_0 = _mm256_fmadd_ps(loadfp16(weight_xc_O_0), xi, _sumO_0);
                _sumG_0 = _mm256_fmadd_ps(loadfp16(weight_xc_G_0), xi, _sumG_0);
                _sumI_1 = _mm256_fmadd_ps(loadfp16(weight_xc_I_1), xi, _sumI_1);
                _sumF_1 = _mm256_fmadd_ps(loadfp16(weight_xc_F_1), xi, _sumF_1);
                _sumO_1 = _mm256_fmadd_ps(loadfp16(weight_xc_O_1), xi, _sumO_1);
                _sumG_1 = _mm256_fmadd_ps(loadfp16(weight_xc_G_1), xi, _sumG_1);
                x += 8;
                weight_xc_I_0 += 8;
                weight_xc_F_0 += 8;
                weight_xc_O_0 += 8;
                weight_xc_G_0 += 8;
                weight_xc_I_1 += 8;
                weight_xc_F_1 += 8;
                weight_xc_O_1 += 8;
                weight_xc_G_1 += 8;
            }
            int nn_num_output = num_output >> 3;
            int remain_num_output = num_output & 7;
            for (; nn_num_output > 0; nn_num_output--)
            {
                __m256 h_cont = _mm256_loadu_ps(hidden_ptr_r);

                _sumI_0 = _mm256_fmadd_ps(loadfp16(weight_hc_I_0), h_cont, _sumI_0);
                _sumF_0 = _mm256_fmadd_ps(loadfp16(weight_hc_F_0), h_cont, _sumF_0);
                _sumO_0 = _mm256_fmadd_ps(loadfp16(weight_hc_O_0), h_cont, _sumO_0);
                _sumG_0 = _mm256_fmadd_ps(loadfp16(weight_hc_G_0), h_cont, _sumG_0);
                _sumI_1 = _mm256_fmadd_ps(loadfp16(weight_hc_I_1), h_cont, _sumI_1);
                _sumF_1 = _mm256_fmadd_ps(loadfp16(weight_hc_F_1), h_cont, _sumF_1);
                _sumO_1 = _mm256_fmadd_ps(loadfp16(weight_hc_O_1), h_cont, _sumO_1);
                _sumG_1 = _mm256_fmadd_ps(loadfp16(weight_hc_G_1), h_cont, _sumG_1);
                hidden_ptr_r += 8;
                weight_hc_I_0 += 8;
                weight_hc_F_0 += 8;
                weight_hc_O_0 += 8;
                weight_hc_G_0 += 8;
                weight_hc_I_1 += 8;
                weight_hc_F_1 += 8;
                weight_hc_O_1 += 8;
                weight_hc_G_1 += 8;
            }
            if (remain_size != 0)
            {
                unsigned short fp16_weights[8][8] = {{0}};
                float _xi_f[8] = {0};
                // No fast way to convert fp16 to fp32 one element at a time,
                // so batch the remainder into an 8-lane vector.
                for (int i = 0; i < remain_size; i++)
                {
                    _xi_f[i] = *x;
                    fp16_weights[0][i] = *weight_xc_I_0;
                    fp16_weights[1][i] = *weight_xc_F_0;
                    fp16_weights[2][i] = *weight_xc_O_0;
                    fp16_weights[3][i] = *weight_xc_G_0;
                    fp16_weights[4][i] = *weight_xc_I_1;
                    fp16_weights[5][i] = *weight_xc_F_1;
                    fp16_weights[6][i] = *weight_xc_O_1;
                    fp16_weights[7][i] = *weight_xc_G_1;
                    x++;
                    weight_xc_I_0++;
                    weight_xc_F_0++;
                    weight_xc_O_0++;
                    weight_xc_G_0++;
                    weight_xc_I_1++;
                    weight_xc_F_1++;
                    weight_xc_O_1++;
                    weight_xc_G_1++;
                }
                __m256 xi = _mm256_loadu_ps(_xi_f);
                _sumI_0 = _mm256_fmadd_ps(loadfp16(fp16_weights[0]), xi, _sumI_0);
                _sumF_0 = _mm256_fmadd_ps(loadfp16(fp16_weights[1]), xi, _sumF_0);
                _sumO_0 = _mm256_fmadd_ps(loadfp16(fp16_weights[2]), xi, _sumO_0);
                _sumG_0 = _mm256_fmadd_ps(loadfp16(fp16_weights[3]), xi, _sumG_0);
                _sumI_1 = _mm256_fmadd_ps(loadfp16(fp16_weights[4]), xi, _sumI_1);
                _sumF_1 = _mm256_fmadd_ps(loadfp16(fp16_weights[5]), xi, _sumF_1);
                _sumO_1 = _mm256_fmadd_ps(loadfp16(fp16_weights[6]), xi, _sumO_1);
                _sumG_1 = _mm256_fmadd_ps(loadfp16(fp16_weights[7]), xi, _sumG_1);
            }
            if (remain_num_output != 0)
            {
                unsigned short fp16_weights[8][8] = {{0}};
                float _hcont_f[8] = {0};
                // No fast way to convert fp16 to fp32 one element at a time,
                // so batch the remainder into an 8-lane vector.
                for (int i = 0; i < remain_num_output; i++)
                {
                    _hcont_f[i] = *hidden_ptr_r;
                    fp16_weights[0][i] = *weight_hc_I_0;
                    fp16_weights[1][i] = *weight_hc_F_0;
                    fp16_weights[2][i] = *weight_hc_O_0;
                    fp16_weights[3][i] = *weight_hc_G_0;
                    fp16_weights[4][i] = *weight_hc_I_1;
                    fp16_weights[5][i] = *weight_hc_F_1;
                    fp16_weights[6][i] = *weight_hc_O_1;
                    fp16_weights[7][i] = *weight_hc_G_1;
                    hidden_ptr_r++;
                    weight_hc_I_0++;
                    weight_hc_F_0++;
                    weight_hc_O_0++;
                    weight_hc_G_0++;
                    weight_hc_I_1++;
                    weight_hc_F_1++;
                    weight_hc_O_1++;
                    weight_hc_G_1++;
                }
                __m256 h_cont = _mm256_loadu_ps(_hcont_f);
                _sumI_0 = _mm256_fmadd_ps(loadfp16(fp16_weights[0]), h_cont, _sumI_0);
                _sumF_0 = _mm256_fmadd_ps(loadfp16(fp16_weights[1]), h_cont, _sumF_0);
                _sumO_0 = _mm256_fmadd_ps(loadfp16(fp16_weights[2]), h_cont, _sumO_0);
                _sumG_0 = _mm256_fmadd_ps(loadfp16(fp16_weights[3]), h_cont, _sumG_0);
                _sumI_1 = _mm256_fmadd_ps(loadfp16(fp16_weights[4]), h_cont, _sumI_1);
                _sumF_1 = _mm256_fmadd_ps(loadfp16(fp16_weights[5]), h_cont, _sumF_1);
                _sumO_1 = _mm256_fmadd_ps(loadfp16(fp16_weights[6]), h_cont, _sumO_1);
                _sumG_1 = _mm256_fmadd_ps(loadfp16(fp16_weights[7]), h_cont, _sumG_1);
            }
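            // reduce the eight per-gate accumulators to eight scalar dot products
            // (I/F/O/G for output q and q + 1) in one horizontal pass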
            float sums[8];
            _mm256_storeu_ps(sums, HorizontalSums(_sumI_0, _sumF_0, _sumO_0, _sumG_0, _sumI_1, _sumF_1, _sumO_1, _sumG_1));
            sums[0] += bias_c_I[q];
            sums[1] += bias_c_F[q];
            sums[2] += bias_c_O[q];
            sums[3] += bias_c_G[q];
            sums[4] += bias_c_I[q + 1];
            sums[5] += bias_c_F[q + 1];
            sums[6] += bias_c_O[q + 1];
            sums[7] += bias_c_G[q + 1];
            gates_data_I[q] = sums[0];
            gates_data_F[q] = sums[1];
            gates_data_O[q] = sums[2];
            gates_data_G[q] = sums[3];
            gates_data_I[q + 1] = sums[4];
            gates_data_F[q + 1] = sums[5];
            gates_data_O[q + 1] = sums[6];
            gates_data_G[q + 1] = sums[7];
        }

        for (int q = remain_output; q < num_output; q++)
        {
            const float* x = bottom_blob.row(ti);
            const float* hidden_ptr_r = hidden_state;
            const float* bias_c_I = bias_c.row(0);
            const float* bias_c_F = bias_c.row(1);
            const float* bias_c_O = bias_c.row(2);
            const float* bias_c_G = bias_c.row(3);

            float* gates_data_I = gates.row(0);
            float* gates_data_F = gates.row(1);
            float* gates_data_O = gates.row(2);
            float* gates_data_G = gates.row(3);
            // gate I F O G
            const unsigned short* weight_xc_I = (const unsigned short*)weight_xc.row(num_output * 0 + q);
            const unsigned short* weight_xc_F = (const unsigned short*)weight_xc.row(num_output * 1 + q);
            const unsigned short* weight_xc_O = (const unsigned short*)weight_xc.row(num_output * 2 + q);
            const unsigned short* weight_xc_G = (const unsigned short*)weight_xc.row(num_output * 3 + q);

            const unsigned short* weight_hc_I = (const unsigned short*)weight_hc.row(num_output * 0 + q);
            const unsigned short* weight_hc_F = (const unsigned short*)weight_hc.row(num_output * 1 + q);
            const unsigned short* weight_hc_O = (const unsigned short*)weight_hc.row(num_output * 2 + q);
            const unsigned short* weight_hc_G = (const unsigned short*)weight_hc.row(num_output * 3 + q);

            // float I = bias_c_I[q];
            // float F = bias_c_F[q];
            // float O = bias_c_O[q];
            // float G = bias_c_G[q];
            __m256 _sumI = _mm256_setzero_ps();
            __m256 _sumF = _mm256_setzero_ps();
            __m256 _sumO = _mm256_setzero_ps();
            __m256 _sumG = _mm256_setzero_ps();
            int nn_num_size = size >> 3;
            int remain_size = size & 7;
            for (; nn_num_size > 0; nn_num_size--)
            {
                __m256 xi = _mm256_loadu_ps(x);
                _sumI = _mm256_fmadd_ps(loadfp16(weight_xc_I), xi, _sumI);
                _sumF = _mm256_fmadd_ps(loadfp16(weight_xc_F), xi, _sumF);
                _sumO = _mm256_fmadd_ps(loadfp16(weight_xc_O), xi, _sumO);
                _sumG = _mm256_fmadd_ps(loadfp16(weight_xc_G), xi, _sumG);
                x += 8;
                weight_xc_I += 8;
                weight_xc_F += 8;
                weight_xc_O += 8;
                weight_xc_G += 8;
            }
            int nn_num_output = num_output >> 3;
            int remain_num_output = num_output & 7;
            for (; nn_num_output > 0; nn_num_output--)
            {
                __m256 h_cont = _mm256_loadu_ps(hidden_ptr_r);

                _sumI = _mm256_fmadd_ps(loadfp16(weight_hc_I), h_cont, _sumI);
                _sumF = _mm256_fmadd_ps(loadfp16(weight_hc_F), h_cont, _sumF);
                _sumO = _mm256_fmadd_ps(loadfp16(weight_hc_O), h_cont, _sumO);
                _sumG = _mm256_fmadd_ps(loadfp16(weight_hc_G), h_cont, _sumG);
                hidden_ptr_r += 8;
                weight_hc_I += 8;
                weight_hc_F += 8;
                weight_hc_O += 8;
                weight_hc_G += 8;
            }
            if (remain_size != 0)
            {
                unsigned short fp16_weights[4][8] = {{0}};
                float _xi_f[8] = {0};
                // No fast way to convert fp16 to fp32 one element at a time,
                // so batch the remainder into an 8-lane vector.
                for (int i = 0; i < remain_size; i++)
                {
                    _xi_f[i] = *x;
                    fp16_weights[0][i] = *weight_xc_I;
                    fp16_weights[1][i] = *weight_xc_F;
                    fp16_weights[2][i] = *weight_xc_O;
                    fp16_weights[3][i] = *weight_xc_G;
                    x++;
                    weight_xc_I++;
                    weight_xc_F++;
                    weight_xc_O++;
                    weight_xc_G++;
                }
                __m256 xi = _mm256_loadu_ps(_xi_f);
                _sumI = _mm256_fmadd_ps(loadfp16(fp16_weights[0]), xi, _sumI);
                _sumF = _mm256_fmadd_ps(loadfp16(fp16_weights[1]), xi, _sumF);
                _sumO = _mm256_fmadd_ps(loadfp16(fp16_weights[2]), xi, _sumO);
                _sumG = _mm256_fmadd_ps(loadfp16(fp16_weights[3]), xi, _sumG);
            }
            if (remain_num_output != 0)
            {
                unsigned short fp16_weights[4][8] = {{0}};
                float _hcont_f[8] = {0};
                // No fast way to convert fp16 to fp32 one element at a time,
                // so batch the remainder into an 8-lane vector.
                for (int i = 0; i < remain_num_output; i++)
                {
                    _hcont_f[i] = *hidden_ptr_r;
                    fp16_weights[0][i] = *weight_hc_I;
                    fp16_weights[1][i] = *weight_hc_F;
                    fp16_weights[2][i] = *weight_hc_O;
                    fp16_weights[3][i] = *weight_hc_G;
                    hidden_ptr_r++;
                    weight_hc_I++;
                    weight_hc_F++;
                    weight_hc_O++;
                    weight_hc_G++;
                }
                __m256 h_cont = _mm256_loadu_ps(_hcont_f);
                _sumI = _mm256_fmadd_ps(loadfp16(fp16_weights[0]), h_cont, _sumI);
                _sumF = _mm256_fmadd_ps(loadfp16(fp16_weights[1]), h_cont, _sumF);
                _sumO = _mm256_fmadd_ps(loadfp16(fp16_weights[2]), h_cont, _sumO);
                _sumG = _mm256_fmadd_ps(loadfp16(fp16_weights[3]), h_cont, _sumG);
            }

            float sums[4];
            _mm_storeu_ps(sums, HorizontalSums(_sumI, _sumF, _sumO, _sumG));
            sums[0] += bias_c_I[q];
            sums[1] += bias_c_F[q];
            sums[2] += bias_c_O[q];
            sums[3] += bias_c_G[q];
            gates_data_I[q] = sums[0];
            gates_data_F[q] = sums[1];
            gates_data_O[q] = sums[2];
            gates_data_G[q] = sums[3];
        }

        // lstm unit
        // sigmoid(I)
        // sigmoid(F)
        // sigmoid(O)
        // tanh(G)
        // c_t := f_t .* c_{t-1} + i_t .* g_t
        // h_t := o_t .* tanh[c_t]
        float* output_data = top_blob.row(ti);
        float* cell_ptr = cell_state;
        float* hidden_ptr = hidden_state;
        const float* gates_data_I = gates.row(0);
        const float* gates_data_F = gates.row(1);
        const float* gates_data_O = gates.row(2);
        const float* gates_data_G = gates.row(3);
        int nn_activation = num_output >> 3;
        int remain_activations = num_output & 7;
        for (; nn_activation > 0; nn_activation--)
        {
            __m256 I = sigmoid_avx(_mm256_loadu_ps(gates_data_I));
            __m256 F = sigmoid_avx(_mm256_loadu_ps(gates_data_F));
            __m256 O = sigmoid_avx(_mm256_loadu_ps(gates_data_O));
            __m256 G = tanh_avx(_mm256_loadu_ps(gates_data_G));
            __m256 cell2 = _mm256_add_ps(_mm256_mul_ps(F, _mm256_loadu_ps(cell_ptr)), _mm256_mul_ps(I, G));
            __m256 H = _mm256_mul_ps(O, tanh_avx(cell2));
            _mm256_storeu_ps(cell_ptr, cell2);
            _mm256_storeu_ps(hidden_ptr, H);
            _mm256_storeu_ps(output_data, H);
            cell_ptr += 8;
            output_data += 8;
            hidden_ptr += 8;
            gates_data_I += 8;
            gates_data_F += 8;
            gates_data_O += 8;
            gates_data_G += 8;
        }
        for (; remain_activations > 0; remain_activations--)
        {
            float I = *gates_data_I;
            float F = *gates_data_F;
            float O = *gates_data_O;
            float G = *gates_data_G;

            I = 1.f / (1.f + exp(-I));
            F = 1.f / (1.f + exp(-F));
            O = 1.f / (1.f + exp(-O));
            G = tanh(G);
            float cell2 = F * *cell_ptr + I * G;
            float H = O * tanh(cell2);
            *cell_ptr = cell2;
            *hidden_ptr = H;
            *output_data = H;
            cell_ptr++;
            output_data++;
            hidden_ptr++;
            gates_data_I++;
            gates_data_F++;
            gates_data_O++;
            gates_data_G++;
        }

        // no cell output here
    }

    return 0;
}

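// Plain fp32 variant of lstm_fp16() above: same traversal and gate math, but the
// weights are read directly with _mm256_loadu_ps and the per-row remainder is
// accumulated with scalar multiply-adds instead of a batched 8-lane load.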
static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

    // 4 x num_output
    Mat gates(num_output, 4, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;

    // unroll
    for (int t = 0; t < T; t++)
    {
        // clip hidden by continuation indicator
        // h_cont_{t-1} = cont_t * h_{t-1}
        // h_cont_{t-1} = h_{t-1} if cont_t == 1
        //                0       otherwise
        // calculate hidden
        // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c

        int ti = reverse ? T - 1 - t : t;
        int remain_output = (num_output >> 1) << 1;
        for (int q = 0; q + 1 < num_output; q += 2)
        {
            const float* x = bottom_blob.row(ti);
            const float* hidden_ptr_r = hidden_state;
            const float* bias_c_I = bias_c.row(0);
            const float* bias_c_F = bias_c.row(1);
            const float* bias_c_O = bias_c.row(2);
            const float* bias_c_G = bias_c.row(3);

            float* gates_data_I = gates.row(0);
            float* gates_data_F = gates.row(1);
            float* gates_data_O = gates.row(2);
            float* gates_data_G = gates.row(3);
            // gate I F O G
            const float* weight_xc_I_0 = weight_xc.row(num_output * 0 + q);
            const float* weight_xc_F_0 = weight_xc.row(num_output * 1 + q);
            const float* weight_xc_O_0 = weight_xc.row(num_output * 2 + q);
            const float* weight_xc_G_0 = weight_xc.row(num_output * 3 + q);
            const float* weight_xc_I_1 = weight_xc.row(num_output * 0 + (q + 1));
            const float* weight_xc_F_1 = weight_xc.row(num_output * 1 + (q + 1));
            const float* weight_xc_O_1 = weight_xc.row(num_output * 2 + (q + 1));
            const float* weight_xc_G_1 = weight_xc.row(num_output * 3 + (q + 1));

            const float* weight_hc_I_0 = weight_hc.row(num_output * 0 + q);
            const float* weight_hc_F_0 = weight_hc.row(num_output * 1 + q);
            const float* weight_hc_O_0 = weight_hc.row(num_output * 2 + q);
            const float* weight_hc_G_0 = weight_hc.row(num_output * 3 + q);
            const float* weight_hc_I_1 = weight_hc.row(num_output * 0 + (q + 1));
            const float* weight_hc_F_1 = weight_hc.row(num_output * 1 + (q + 1));
            const float* weight_hc_O_1 = weight_hc.row(num_output * 2 + (q + 1));
            const float* weight_hc_G_1 = weight_hc.row(num_output * 3 + (q + 1));

            // float I = bias_c_I[q];
            // float F = bias_c_F[q];
            // float O = bias_c_O[q];
            // float G = bias_c_G[q];
            __m256 _sumI_0 = _mm256_setzero_ps();
            __m256 _sumF_0 = _mm256_setzero_ps();
            __m256 _sumO_0 = _mm256_setzero_ps();
            __m256 _sumG_0 = _mm256_setzero_ps();
            __m256 _sumI_1 = _mm256_setzero_ps();
            __m256 _sumF_1 = _mm256_setzero_ps();
            __m256 _sumO_1 = _mm256_setzero_ps();
            __m256 _sumG_1 = _mm256_setzero_ps();
            int nn_num_size = size >> 3;
            int remain_size = size & 7;
            for (; nn_num_size > 0; nn_num_size--)
            {
                __m256 xi = _mm256_loadu_ps(x);
                _sumI_0 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_I_0), xi, _sumI_0);
                _sumF_0 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_F_0), xi, _sumF_0);
                _sumO_0 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_O_0), xi, _sumO_0);
                _sumG_0 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_G_0), xi, _sumG_0);
                _sumI_1 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_I_1), xi, _sumI_1);
                _sumF_1 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_F_1), xi, _sumF_1);
                _sumO_1 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_O_1), xi, _sumO_1);
                _sumG_1 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_G_1), xi, _sumG_1);
                x += 8;
                weight_xc_I_0 += 8;
                weight_xc_F_0 += 8;
                weight_xc_O_0 += 8;
                weight_xc_G_0 += 8;
                weight_xc_I_1 += 8;
                weight_xc_F_1 += 8;
                weight_xc_O_1 += 8;
                weight_xc_G_1 += 8;
            }
            int nn_num_output = num_output >> 3;
            int remain_num_output = num_output & 7;
            for (; nn_num_output > 0; nn_num_output--)
            {
                __m256 h_cont = _mm256_loadu_ps(hidden_ptr_r);

                _sumI_0 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_I_0), h_cont, _sumI_0);
                _sumF_0 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_F_0), h_cont, _sumF_0);
                _sumO_0 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_O_0), h_cont, _sumO_0);
                _sumG_0 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_G_0), h_cont, _sumG_0);
                _sumI_1 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_I_1), h_cont, _sumI_1);
                _sumF_1 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_F_1), h_cont, _sumF_1);
                _sumO_1 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_O_1), h_cont, _sumO_1);
                _sumG_1 = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_G_1), h_cont, _sumG_1);
                hidden_ptr_r += 8;
                weight_hc_I_0 += 8;
                weight_hc_F_0 += 8;
                weight_hc_O_0 += 8;
                weight_hc_G_0 += 8;
                weight_hc_I_1 += 8;
                weight_hc_F_1 += 8;
                weight_hc_O_1 += 8;
                weight_hc_G_1 += 8;
            }
            float sums[8];
            _mm256_storeu_ps(sums, HorizontalSums(_sumI_0, _sumF_0, _sumO_0, _sumG_0, _sumI_1, _sumF_1, _sumO_1, _sumG_1));
            sums[0] += bias_c_I[q];
            sums[1] += bias_c_F[q];
            sums[2] += bias_c_O[q];
            sums[3] += bias_c_G[q];
            sums[4] += bias_c_I[q + 1];
            sums[5] += bias_c_F[q + 1];
            sums[6] += bias_c_O[q + 1];
            sums[7] += bias_c_G[q + 1];

            for (; remain_size > 0; remain_size--)
            {
                float xi = *x;
                sums[0] += *weight_xc_I_0 * xi;
                sums[1] += *weight_xc_F_0 * xi;
                sums[2] += *weight_xc_O_0 * xi;
                sums[3] += *weight_xc_G_0 * xi;
                sums[4] += *weight_xc_I_1 * xi;
                sums[5] += *weight_xc_F_1 * xi;
                sums[6] += *weight_xc_O_1 * xi;
                sums[7] += *weight_xc_G_1 * xi;
                x++;
                weight_xc_I_0++;
                weight_xc_F_0++;
                weight_xc_O_0++;
                weight_xc_G_0++;
                weight_xc_I_1++;
                weight_xc_F_1++;
                weight_xc_O_1++;
                weight_xc_G_1++;
            }

            for (; remain_num_output > 0; remain_num_output--)
            {
                float h_cont = *hidden_ptr_r;
                sums[0] += *weight_hc_I_0 * h_cont;
                sums[1] += *weight_hc_F_0 * h_cont;
                sums[2] += *weight_hc_O_0 * h_cont;
                sums[3] += *weight_hc_G_0 * h_cont;
                sums[4] += *weight_hc_I_1 * h_cont;
                sums[5] += *weight_hc_F_1 * h_cont;
                sums[6] += *weight_hc_O_1 * h_cont;
                sums[7] += *weight_hc_G_1 * h_cont;
                hidden_ptr_r++;
                weight_hc_I_0++;
                weight_hc_F_0++;
                weight_hc_O_0++;
                weight_hc_G_0++;
                weight_hc_I_1++;
                weight_hc_F_1++;
                weight_hc_O_1++;
                weight_hc_G_1++;
            }
            gates_data_I[q] = sums[0];
            gates_data_F[q] = sums[1];
            gates_data_O[q] = sums[2];
            gates_data_G[q] = sums[3];
            gates_data_I[q + 1] = sums[4];
            gates_data_F[q + 1] = sums[5];
            gates_data_O[q + 1] = sums[6];
            gates_data_G[q + 1] = sums[7];
        }

        for (int q = remain_output; q < num_output; q++)
        {
            const float* x = bottom_blob.row(ti);
            const float* hidden_ptr_r = hidden_state;
            const float* bias_c_I = bias_c.row(0);
            const float* bias_c_F = bias_c.row(1);
            const float* bias_c_O = bias_c.row(2);
            const float* bias_c_G = bias_c.row(3);

            float* gates_data_I = gates.row(0);
            float* gates_data_F = gates.row(1);
            float* gates_data_O = gates.row(2);
            float* gates_data_G = gates.row(3);
            // gate I F O G
            const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
            const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
            const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
            const float* weight_xc_G = weight_xc.row(num_output * 3 + q);

            const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
            const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
            const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
            const float* weight_hc_G = weight_hc.row(num_output * 3 + q);

            // float I = bias_c_I[q];
            // float F = bias_c_F[q];
            // float O = bias_c_O[q];
            // float G = bias_c_G[q];
            __m256 _sumI = _mm256_setzero_ps();
            __m256 _sumF = _mm256_setzero_ps();
            __m256 _sumO = _mm256_setzero_ps();
            __m256 _sumG = _mm256_setzero_ps();
            int nn_num_size = size >> 3;
            int remain_size = size & 7;
            for (; nn_num_size > 0; nn_num_size--)
            {
                __m256 xi = _mm256_loadu_ps(x);
                _sumI = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_I), xi, _sumI);
                _sumF = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_F), xi, _sumF);
                _sumO = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_O), xi, _sumO);
                _sumG = _mm256_fmadd_ps(_mm256_loadu_ps(weight_xc_G), xi, _sumG);
                x += 8;
                weight_xc_I += 8;
                weight_xc_F += 8;
                weight_xc_O += 8;
                weight_xc_G += 8;
            }
            int nn_num_output = num_output >> 3;
            int remain_num_output = num_output & 7;
            for (; nn_num_output > 0; nn_num_output--)
            {
                __m256 h_cont = _mm256_loadu_ps(hidden_ptr_r);

                _sumI = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_I), h_cont, _sumI);
                _sumF = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_F), h_cont, _sumF);
                _sumO = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_O), h_cont, _sumO);
                _sumG = _mm256_fmadd_ps(_mm256_loadu_ps(weight_hc_G), h_cont, _sumG);
                hidden_ptr_r += 8;
                weight_hc_I += 8;
                weight_hc_F += 8;
                weight_hc_O += 8;
                weight_hc_G += 8;
            }
            float sums[4];
            _mm_storeu_ps(sums, HorizontalSums(_sumI, _sumF, _sumO, _sumG));
            sums[0] += bias_c_I[q];
            sums[1] += bias_c_F[q];
            sums[2] += bias_c_O[q];
            sums[3] += bias_c_G[q];

            for (; remain_size > 0; remain_size--)
            {
                float xi = *x;
                sums[0] += *weight_xc_I * xi;
                sums[1] += *weight_xc_F * xi;
                sums[2] += *weight_xc_O * xi;
                sums[3] += *weight_xc_G * xi;
                x++;
                weight_xc_I++;
                weight_xc_F++;
                weight_xc_O++;
                weight_xc_G++;
            }

            for (; remain_num_output > 0; remain_num_output--)
            {
                float h_cont = *hidden_ptr_r;
                sums[0] += *weight_hc_I * h_cont;
                sums[1] += *weight_hc_F * h_cont;
                sums[2] += *weight_hc_O * h_cont;
                sums[3] += *weight_hc_G * h_cont;
                hidden_ptr_r++;
                weight_hc_I++;
                weight_hc_F++;
                weight_hc_O++;
                weight_hc_G++;
            }
            gates_data_I[q] = sums[0];
            gates_data_F[q] = sums[1];
            gates_data_O[q] = sums[2];
            gates_data_G[q] = sums[3];
        }

        // lstm unit
        // sigmoid(I)
        // sigmoid(F)
        // sigmoid(O)
        // tanh(G)
        // c_t := f_t .* c_{t-1} + i_t .* g_t
        // h_t := o_t .* tanh[c_t]
        float* output_data = top_blob.row(ti);
        float* cell_ptr = cell_state;
        float* hidden_ptr = hidden_state;
        const float* gates_data_I = gates.row(0);
        const float* gates_data_F = gates.row(1);
        const float* gates_data_O = gates.row(2);
        const float* gates_data_G = gates.row(3);
        int nn_activation = num_output >> 3;
        int remain_activations = num_output & 7;
        for (; nn_activation > 0; nn_activation--)
        {
            __m256 I = sigmoid_avx(_mm256_loadu_ps(gates_data_I));
            __m256 F = sigmoid_avx(_mm256_loadu_ps(gates_data_F));
            __m256 O = sigmoid_avx(_mm256_loadu_ps(gates_data_O));
            __m256 G = tanh_avx(_mm256_loadu_ps(gates_data_G));
            __m256 cell2 = _mm256_add_ps(_mm256_mul_ps(F, _mm256_loadu_ps(cell_ptr)), _mm256_mul_ps(I, G));
            __m256 H = _mm256_mul_ps(O, tanh_avx(cell2));
            _mm256_storeu_ps(cell_ptr, cell2);
            _mm256_storeu_ps(hidden_ptr, H);
            _mm256_storeu_ps(output_data, H);
            cell_ptr += 8;
            output_data += 8;
            hidden_ptr += 8;
            gates_data_I += 8;
            gates_data_F += 8;
            gates_data_O += 8;
            gates_data_G += 8;
        }
        for (; remain_activations > 0; remain_activations--)
        {
            float I = *gates_data_I;
            float F = *gates_data_F;
            float O = *gates_data_O;
            float G = *gates_data_G;

            I = 1.f / (1.f + exp(-I));
            F = 1.f / (1.f + exp(-F));
            O = 1.f / (1.f + exp(-O));
            G = tanh(G);
            float cell2 = F * *cell_ptr + I * G;
            float H = O * tanh(cell2);
            *cell_ptr = cell2;
            *hidden_ptr = H;
            *output_data = H;
            cell_ptr++;
            output_data++;
            hidden_ptr++;
            gates_data_I++;
            gates_data_F++;
            gates_data_O++;
            gates_data_G++;
        }

        // no cell output here
    }

    return 0;
}
#endif

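// Single-input forward: hidden and cell state start at zero. For direction == 2
// the forward and reverse passes run separately and their outputs are
// concatenated per time step along the width.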
int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
#if __AVX__
    int T = bottom_blob.h;
    int num_directions = direction == 2 ? 2 : 1;

    // initial hidden state
    Mat hidden(num_output, 4u, opt.workspace_allocator);
    if (hidden.empty())
        return -100;
    hidden.fill(0.f);
    // internal cell state
    Mat cell(num_output, 4u, opt.workspace_allocator);
    if (cell.empty())
        return -100;
    cell.fill(0.f);

    top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        if (opt.use_weight_fp16_storage)
        {
            // Uni directional
            int ret = lstm_fp16(bottom_blob, top_blob, direction, weight_xc_data_fp16.channel(0), bias_c_data.channel(0), weight_hc_data_fp16.channel(0), hidden, cell, opt);
            if (ret != 0)
                return ret;
        }
        else
        {
            // Uni directional
            int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
            if (ret != 0)
                return ret;
        }
    }

    if (direction == 2)
    {
        Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_forward.empty())
            return -100;

        Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_reverse.empty())
            return -100;

        if (opt.use_weight_fp16_storage)
        {
            // Uni directional
            int ret0 = lstm_fp16(bottom_blob, top_blob_forward, 0, weight_xc_data_fp16.channel(0), bias_c_data.channel(0), weight_hc_data_fp16.channel(0), hidden, cell, opt);
            if (ret0 != 0)
                return ret0;
        }
        else
        {
            // Uni directional
            int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
            if (ret0 != 0)
                return ret0;
        }

        hidden.fill(0.0f);
        cell.fill(0.0f);
        if (opt.use_weight_fp16_storage)
        {
            // Uni directional
            int ret1 = lstm_fp16(bottom_blob, top_blob_reverse, 1, weight_xc_data_fp16.channel(1), bias_c_data.channel(1), weight_hc_data_fp16.channel(1), hidden, cell, opt);
            if (ret1 != 0)
                return ret1;
        }
        else
        {
            // Uni directional
            int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, cell, opt);
            if (ret1 != 0)
                return ret1;
        }

        // concat w
        for (int i = 0; i < T; i++)
        {
            const float* pf = top_blob_forward.row(i);
            const float* pr = top_blob_reverse.row(i);
            float* ptr = top_blob.row(i);

            memcpy(ptr, pf, num_output * sizeof(float));
            memcpy(ptr + num_output, pr, num_output * sizeof(float));
        }
    }

    return 0;
#else
    return LSTM::forward(bottom_blob, top_blob, opt);
#endif
}

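// Three-blob forward: bottom_blobs[1] and bottom_blobs[2] supply the initial
// hidden and cell state, top_blobs[1] and top_blobs[2] return the updated state,
// so the layer can be stepped across calls; anything else falls back to the
// single-blob path above.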
int LSTM_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
#if __AVX__
    if (bottom_blobs.size() != 3 || top_blobs.size() != 3)
    {
        return forward(bottom_blobs[0], top_blobs[0], opt);
    }
    const Mat& bottom_blob = bottom_blobs[0];

    int T = bottom_blob.h;
    Mat& top_blob = top_blobs[0];
    Mat& hidden_state = top_blobs[1];
    Mat& cell_state = top_blobs[2];

    // Copy previous states
    hidden_state = bottom_blobs[1].clone(opt.blob_allocator);
    cell_state = bottom_blobs[2].clone(opt.blob_allocator);

    top_blob.create(num_output, T, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    if (opt.use_weight_fp16_storage)
    {
        // Uni directional
        int ret = lstm_fp16(bottom_blob, top_blob, direction, weight_xc_data_fp16.channel(0), bias_c_data.channel(0), weight_hc_data_fp16.channel(0), hidden_state, cell_state, opt);
        if (ret != 0)
            return ret;
    }
    else
    {
        // Uni directional
        int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden_state, cell_state, opt);
        if (ret != 0)
            return ret;
    }
    return 0;
#else
    return LSTM::forward(bottom_blobs, top_blobs, opt);
#endif
}

} // namespace ncnn