1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "lstm_arm.h"
16 
17 #if __ARM_NEON
18 #include <arm_neon.h>
19 #include "neon_mathfun.h"
20 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
21 #include "neon_mathfun_fp16s.h"
22 #endif
23 #include "neon_activation.h"
24 #endif // __ARM_NEON
25 
26 #include <math.h>
27 
28 namespace ncnn {
29 
LSTM_arm()30 LSTM_arm::LSTM_arm()
31 {
32 #if __ARM_NEON
33 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
34     support_fp16_storage = true;
35 #endif
36 #endif // __ARM_NEON
37 
38     support_bf16_storage = true;
39 }
40 
// Repack the fp32 LSTM weights into a gate-interleaved (IFOG) layout.
// The model stores the four gate matrices in separate row blocks (I, F, O, G);
// this interleaves them so the four gate values for one output neuron are
// adjacent in memory, matching the 4-lane vector loads in lstm() below.
// Delegates to the fp16s / bf16s pipelines when those storage modes are enabled.
// Returns 0 on success.
int LSTM_arm::create_pipeline(const Option& opt)
{
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    if (opt.use_fp16_storage)
    {
        return create_pipeline_fp16s(opt);
    }
#endif

    if (opt.use_bf16_storage)
    {
        return create_pipeline_bf16s(opt);
    }

    // pack IFOG
    int num_directions = direction == 2 ? 2 : 1;
    // per-gate input width: total weight count / directions / neurons / 4 gates
    int size = weight_data_size / num_directions / num_output / 4;

    // elemsize 16u with elempack 4: one "element" is a 4-float IFOG quad
    weight_xc_data_packed.create(size, num_output, num_directions, 16u, 4);
    bias_c_data_packed.create(num_output, 1, num_directions, 16u, 4);
    weight_hc_data_packed.create(num_output, num_output, num_directions, 16u, 4);

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int dr = 0; dr < num_directions; dr++)
    {
        const Mat weight_xc = weight_xc_data.channel(dr);
        const Mat bias_c = bias_c_data.channel(dr);
        const Mat weight_hc = weight_hc_data.channel(dr);

        Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr);
        Mat bias_c_data_packed_dr = bias_c_data_packed.channel(dr);
        Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr);

        // source bias layout: one row per gate
        const float* bias_c_I = bias_c.row(0);
        const float* bias_c_F = bias_c.row(1);
        const float* bias_c_O = bias_c.row(2);
        const float* bias_c_G = bias_c.row(3);

        float* bias_c_IFOG = bias_c_data_packed_dr.row(0);

        for (int q = 0; q < num_output; q++)
        {
            // interleave the four gate biases for neuron q
            bias_c_IFOG[0] = bias_c_I[q];
            bias_c_IFOG[1] = bias_c_F[q];
            bias_c_IFOG[2] = bias_c_O[q];
            bias_c_IFOG[3] = bias_c_G[q];

            bias_c_IFOG += 4;

            // source weight layout: gate blocks of num_output rows each (I F O G)
            const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
            const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
            const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
            const float* weight_xc_G = weight_xc.row(num_output * 3 + q);

            const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
            const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
            const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
            const float* weight_hc_G = weight_hc.row(num_output * 3 + q);

            float* weight_xc_IFOG = weight_xc_data_packed_dr.row(q);
            float* weight_hc_IFOG = weight_hc_data_packed_dr.row(q);

            // input-to-hidden weights: IFOG quad per input element
            for (int i = 0; i < size; i++)
            {
                weight_xc_IFOG[0] = weight_xc_I[i];
                weight_xc_IFOG[1] = weight_xc_F[i];
                weight_xc_IFOG[2] = weight_xc_O[i];
                weight_xc_IFOG[3] = weight_xc_G[i];

                weight_xc_IFOG += 4;
            }

            // hidden-to-hidden weights: same interleaving, one quad per
            // previous-hidden element
            for (int i = 0; i < num_output; i++)
            {
                weight_hc_IFOG[0] = weight_hc_I[i];
                weight_hc_IFOG[1] = weight_hc_F[i];
                weight_hc_IFOG[2] = weight_hc_O[i];
                weight_hc_IFOG[3] = weight_hc_G[i];

                weight_hc_IFOG += 4;
            }
        }
    }

    return 0;
}
127 
lstm(const Mat & bottom_blob,Mat & top_blob,int reverse,const Mat & weight_xc,const Mat & bias_c,const Mat & weight_hc,Mat & hidden_state,Mat & cell_state,const Option & opt)128 static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
129 {
130     int size = bottom_blob.w;
131     int T = bottom_blob.h;
132 
133     int num_output = top_blob.w;
134 
135     // 4 x num_output
136     Mat gates(4, num_output, 4u, opt.workspace_allocator);
137     if (gates.empty())
138         return -100;
139 
140     // unroll
141     for (int t = 0; t < T; t++)
142     {
143         // clip hidden by continuation indicator
144         // h_cont_{t-1} = cont_t * h_{t-1}
145         // h_cont_{t-1} = h_{t-1} if cont_t == 1
146         //                0       otherwise
147         // calculate hidden
148         // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c
149 
150         int ti = reverse ? T - 1 - t : t;
151 
152         const float* x = bottom_blob.row(ti);
153         for (int q = 0; q < num_output; q++)
154         {
155             const float* bias_c_IFOG = (const float*)bias_c + q * 4;
156 
157             // gate I F O G
158             const float* weight_xc_IFOG = weight_xc.row(q);
159 
160             const float* weight_hc_IFOG = weight_hc.row(q);
161 
162 #if __ARM_NEON
163             float32x4_t _IFOG = vld1q_f32(bias_c_IFOG);
164             float32x4_t _sum1 = vdupq_n_f32(0.f);
165             float32x4_t _sum2 = vdupq_n_f32(0.f);
166             float32x4_t _sum3 = vdupq_n_f32(0.f);
167 #else
168             float I = bias_c_IFOG[0];
169             float F = bias_c_IFOG[1];
170             float O = bias_c_IFOG[2];
171             float G = bias_c_IFOG[3];
172 #endif // __ARM_NEON
173 
174             int i = 0;
175 #if __ARM_NEON
176             for (; i + 3 < size; i += 4)
177             {
178                 float32x4_t _xi = vld1q_f32(x + i);
179 
180                 float32x4_t _weight_xc_IFOG_0 = vld1q_f32(weight_xc_IFOG);
181                 float32x4_t _weight_xc_IFOG_1 = vld1q_f32(weight_xc_IFOG + 4);
182                 float32x4_t _weight_xc_IFOG_2 = vld1q_f32(weight_xc_IFOG + 8);
183                 float32x4_t _weight_xc_IFOG_3 = vld1q_f32(weight_xc_IFOG + 12);
184 
185 #if __aarch64__
186                 _IFOG = vfmaq_laneq_f32(_IFOG, _weight_xc_IFOG_0, _xi, 0);
187                 _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_IFOG_1, _xi, 1);
188                 _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_IFOG_2, _xi, 2);
189                 _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_IFOG_3, _xi, 3);
190 #else
191                 _IFOG = vmlaq_lane_f32(_IFOG, _weight_xc_IFOG_0, vget_low_f32(_xi), 0);
192                 _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_IFOG_1, vget_low_f32(_xi), 1);
193                 _sum2 = vmlaq_lane_f32(_sum2, _weight_xc_IFOG_2, vget_high_f32(_xi), 0);
194                 _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_IFOG_3, vget_high_f32(_xi), 1);
195 #endif
196 
197                 weight_xc_IFOG += 16;
198             }
199 #endif // __ARM_NEON
200             for (; i < size; i++)
201             {
202                 float xi = x[i];
203 
204 #if __ARM_NEON
205                 float32x4_t _xi = vdupq_n_f32(xi);
206                 float32x4_t _weight_xc_IFOG = vld1q_f32(weight_xc_IFOG);
207                 _IFOG = vmlaq_f32(_IFOG, _weight_xc_IFOG, _xi);
208 #else
209                 I += weight_xc_IFOG[0] * xi;
210                 F += weight_xc_IFOG[1] * xi;
211                 O += weight_xc_IFOG[2] * xi;
212                 G += weight_xc_IFOG[3] * xi;
213 #endif // __ARM_NEON
214 
215                 weight_xc_IFOG += 4;
216             }
217 
218             i = 0;
219 #if __ARM_NEON
220             for (; i + 3 < num_output; i += 4)
221             {
222                 float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i);
223 
224                 float32x4_t _weight_hc_IFOG_0 = vld1q_f32(weight_hc_IFOG);
225                 float32x4_t _weight_hc_IFOG_1 = vld1q_f32(weight_hc_IFOG + 4);
226                 float32x4_t _weight_hc_IFOG_2 = vld1q_f32(weight_hc_IFOG + 8);
227                 float32x4_t _weight_hc_IFOG_3 = vld1q_f32(weight_hc_IFOG + 12);
228 
229 #if __aarch64__
230                 _IFOG = vfmaq_laneq_f32(_IFOG, _weight_hc_IFOG_0, _h_cont, 0);
231                 _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_IFOG_1, _h_cont, 1);
232                 _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_IFOG_2, _h_cont, 2);
233                 _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_IFOG_3, _h_cont, 3);
234 #else
235                 _IFOG = vmlaq_lane_f32(_IFOG, _weight_hc_IFOG_0, vget_low_f32(_h_cont), 0);
236                 _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_IFOG_1, vget_low_f32(_h_cont), 1);
237                 _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_IFOG_2, vget_high_f32(_h_cont), 0);
238                 _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_IFOG_3, vget_high_f32(_h_cont), 1);
239 #endif
240 
241                 weight_hc_IFOG += 16;
242             }
243 #endif // __ARM_NEON
244             for (; i < num_output; i++)
245             {
246                 float h_cont = hidden_state[i];
247 
248 #if __ARM_NEON
249                 float32x4_t _h_cont = vdupq_n_f32(h_cont);
250                 float32x4_t _weight_hc_IFOG = vld1q_f32(weight_hc_IFOG);
251                 _IFOG = vmlaq_f32(_IFOG, _weight_hc_IFOG, _h_cont);
252 #else
253                 I += weight_hc_IFOG[0] * h_cont;
254                 F += weight_hc_IFOG[1] * h_cont;
255                 O += weight_hc_IFOG[2] * h_cont;
256                 G += weight_hc_IFOG[3] * h_cont;
257 #endif // __ARM_NEON
258 
259                 weight_hc_IFOG += 4;
260             }
261 
262             float* gates_data = gates.row(q);
263 
264 #if __ARM_NEON
265             _IFOG = vaddq_f32(_IFOG, _sum1);
266             _sum2 = vaddq_f32(_sum2, _sum3);
267             _IFOG = vaddq_f32(_IFOG, _sum2);
268 
269             vst1q_f32(gates_data, _IFOG);
270 #else
271             gates_data[0] = I;
272             gates_data[1] = F;
273             gates_data[2] = O;
274             gates_data[3] = G;
275 #endif // __ARM_NEON
276         }
277 
278         // lstm unit
279         // sigmoid(I)
280         // sigmoid(F)
281         // sigmoid(O)
282         // tanh(G)
283         // c_t := f_t .* c_{t-1} + i_t .* g_t
284         // h_t := o_t .* tanh[c_t]
285         float* output_data = top_blob.row(ti);
286 
287         float* cell_ptr = cell_state;
288         float* hidden_ptr = hidden_state;
289 
290         int q = 0;
291 #if __ARM_NEON
292         for (; q + 3 < num_output; q += 4)
293         {
294             const float* gates_data = gates.row(q);
295 
296             float32x4x4_t _IFOG_4x4 = vld4q_f32(gates_data);
297 
298             float32x4_t _I = sigmoid_ps(_IFOG_4x4.val[0]);
299             float32x4_t _F = sigmoid_ps(_IFOG_4x4.val[1]);
300             float32x4_t _O = sigmoid_ps(_IFOG_4x4.val[2]);
301             float32x4_t _G = tanh_ps(_IFOG_4x4.val[3]);
302 
303             float32x4_t _cell2 = vaddq_f32(vmulq_f32(_F, vld1q_f32(cell_ptr)), vmulq_f32(_I, _G));
304             float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2));
305 
306             vst1q_f32(cell_ptr, _cell2);
307             vst1q_f32(hidden_ptr, _H);
308             vst1q_f32(output_data, _H);
309 
310             cell_ptr += 4;
311             hidden_ptr += 4;
312             output_data += 4;
313         }
314 #endif // __ARM_NEON
315         for (; q < num_output; q++)
316         {
317             const float* gates_data = gates.row(q);
318 
319             float I = gates_data[0];
320             float F = gates_data[1];
321             float O = gates_data[2];
322             float G = gates_data[3];
323 
324             I = 1.f / (1.f + exp(-I));
325             F = 1.f / (1.f + exp(-F));
326             O = 1.f / (1.f + exp(-O));
327             G = tanh(G);
328 
329             float cell2 = F * *cell_ptr + I * G;
330             float H = O * tanh(cell2);
331 
332             *cell_ptr++ = cell2;
333             *hidden_ptr++ = H;
334             *output_data++ = H;
335         }
336     }
337 
338     return 0;
339 }
340 
forward(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const341 int LSTM_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
342 {
343     int elembits = bottom_blob.elembits();
344 
345 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
346     if (opt.use_fp16_storage && elembits == 16)
347     {
348         if (opt.use_fp16_arithmetic)
349             return forward_fp16sa(bottom_blob, top_blob, opt);
350         else
351             return forward_fp16s(bottom_blob, top_blob, opt);
352     }
353 #endif
354 
355     if (opt.use_bf16_storage && elembits == 16)
356         return forward_bf16s(bottom_blob, top_blob, opt);
357 
358     int T = bottom_blob.h;
359 
360     int num_directions = direction == 2 ? 2 : 1;
361 
362     // initial hidden state
363     Mat hidden(num_output, 4u, opt.workspace_allocator);
364     if (hidden.empty())
365         return -100;
366     hidden.fill(0.f);
367 
368     Mat cell(num_output, 4u, opt.workspace_allocator);
369     if (cell.empty())
370         return -100;
371     cell.fill(0.f);
372 
373     top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
374     if (top_blob.empty())
375         return -100;
376 
377     // Uni directional
378     if (direction == 0 || direction == 1)
379     {
380         int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
381         if (ret != 0)
382             return ret;
383     }
384 
385     if (direction == 2)
386     {
387         Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
388         if (top_blob_forward.empty())
389             return -100;
390 
391         Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
392         if (top_blob_reverse.empty())
393             return -100;
394 
395         int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
396         if (ret0 != 0)
397             return ret0;
398 
399         hidden.fill(0.0f);
400         cell.fill(0.0f);
401 
402         int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, cell, opt);
403         if (ret1 != 0)
404             return ret1;
405 
406         // concat w
407         for (int i = 0; i < T; i++)
408         {
409             const float* pf = top_blob_forward.row(i);
410             const float* pr = top_blob_reverse.row(i);
411             float* ptr = top_blob.row(i);
412 
413             memcpy(ptr, pf, num_output * sizeof(float));
414             memcpy(ptr + num_output, pr, num_output * sizeof(float));
415         }
416     }
417 
418     return 0;
419 }
420 
forward(const std::vector<Mat> & bottom_blobs,std::vector<Mat> & top_blobs,const Option & opt) const421 int LSTM_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
422 {
423     if (bottom_blobs.size() != 3 || top_blobs.size() != 3)
424     {
425         return forward(bottom_blobs[0], top_blobs[0], opt);
426     }
427 
428     const Mat& bottom_blob = bottom_blobs[0];
429 
430     int elembits = bottom_blob.elembits();
431 
432 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
433     if (opt.use_fp16_storage && elembits == 16)
434     {
435         if (opt.use_fp16_arithmetic)
436             return forward_fp16sa(bottom_blobs, top_blobs, opt);
437         else
438             return forward_fp16s(bottom_blobs, top_blobs, opt);
439     }
440 #endif
441 
442     if (opt.use_bf16_storage && elembits == 16)
443         return forward_bf16s(bottom_blobs, top_blobs, opt);
444 
445     int T = bottom_blob.h;
446     Mat& top_blob = top_blobs[0];
447     Mat& hidden_state = top_blobs[1];
448     Mat& cell_state = top_blobs[2];
449 
450     //Copy previous states
451     hidden_state = bottom_blobs[1].clone(opt.blob_allocator);
452     cell_state = bottom_blobs[2].clone(opt.blob_allocator);
453 
454     top_blob.create(num_output, T, 4u, opt.blob_allocator);
455     if (top_blob.empty())
456         return -100;
457 
458     // Uni directional
459     if (direction == 0 || direction == 1)
460     {
461         int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden_state, cell_state, opt);
462         if (ret != 0)
463             return ret;
464     }
465 
466     return 0;
467 }
468 
469 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// Run a single-direction LSTM with fp16 storage and fp32 arithmetic.
// Input, output and packed weights are stored as __fp16 and widened to
// fp32 for every multiply-accumulate; hidden_state and cell_state stay
// fp32 (loaded/stored with vld1q_f32/vst1q_f32 below).
// Parameters mirror lstm() above; returns 0 on success, -100 on
// workspace allocation failure.
static int lstm_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

    // 4 x num_output
    // fp32 scratch (4u elemsize) for the pre-activation I,F,O,G values
    Mat gates(4, num_output, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;

    // unroll
    for (int t = 0; t < T; t++)
    {
        // clip hidden by continuation indicator
        // h_cont_{t-1} = cont_t * h_{t-1}
        // h_cont_{t-1} = h_{t-1} if cont_t == 1
        //                0       otherwise
        // calculate hidden
        // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c

        int ti = reverse ? T - 1 - t : t;

        const __fp16* x = bottom_blob.row<const __fp16>(ti);
        for (int q = 0; q < num_output; q++)
        {
            // packed layout: 4 consecutive fp16 values per neuron = I,F,O,G
            const __fp16* bias_c_IFOG = (const __fp16*)bias_c + q * 4;

            // gate I F O G
            const __fp16* weight_xc_IFOG = weight_xc.row<const __fp16>(q);

            const __fp16* weight_hc_IFOG = weight_hc.row<const __fp16>(q);

            // widen the fp16 bias to fp32; three extra accumulators hide
            // fma latency and are reduced at the end
            float32x4_t _IFOG = vcvt_f32_f16(vld1_f16(bias_c_IFOG));
            float32x4_t _sum1 = vdupq_n_f32(0.f);
            float32x4_t _sum2 = vdupq_n_f32(0.f);
            float32x4_t _sum3 = vdupq_n_f32(0.f);

            // input-to-hidden contribution: W_xc * x_t
            int i = 0;
            for (; i + 3 < size; i += 4)
            {
                float32x4_t _xi = vcvt_f32_f16(vld1_f16(x + i));

                float32x4_t _weight_xc_IFOG_0 = vcvt_f32_f16(vld1_f16(weight_xc_IFOG));
                float32x4_t _weight_xc_IFOG_1 = vcvt_f32_f16(vld1_f16(weight_xc_IFOG + 4));
                float32x4_t _weight_xc_IFOG_2 = vcvt_f32_f16(vld1_f16(weight_xc_IFOG + 8));
                float32x4_t _weight_xc_IFOG_3 = vcvt_f32_f16(vld1_f16(weight_xc_IFOG + 12));

                _IFOG = vfmaq_laneq_f32(_IFOG, _weight_xc_IFOG_0, _xi, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_IFOG_1, _xi, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_IFOG_2, _xi, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_IFOG_3, _xi, 3);

                weight_xc_IFOG += 16;
            }
            // tail: one input element at a time
            for (; i < size; i++)
            {
                __fp16 xi = x[i];

                float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi));
                float32x4_t _weight_xc_IFOG = vcvt_f32_f16(vld1_f16(weight_xc_IFOG));
                _IFOG = vfmaq_f32(_IFOG, _weight_xc_IFOG, _xi);

                weight_xc_IFOG += 4;
            }

            // hidden-to-hidden contribution: W_hc * h_{t-1} (hidden is fp32)
            i = 0;
            for (; i + 3 < num_output; i += 4)
            {
                float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i);

                float32x4_t _weight_hc_IFOG_0 = vcvt_f32_f16(vld1_f16(weight_hc_IFOG));
                float32x4_t _weight_hc_IFOG_1 = vcvt_f32_f16(vld1_f16(weight_hc_IFOG + 4));
                float32x4_t _weight_hc_IFOG_2 = vcvt_f32_f16(vld1_f16(weight_hc_IFOG + 8));
                float32x4_t _weight_hc_IFOG_3 = vcvt_f32_f16(vld1_f16(weight_hc_IFOG + 12));

                _IFOG = vfmaq_laneq_f32(_IFOG, _weight_hc_IFOG_0, _h_cont, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_IFOG_1, _h_cont, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_IFOG_2, _h_cont, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_IFOG_3, _h_cont, 3);

                weight_hc_IFOG += 16;
            }
            for (; i < num_output; i++)
            {
                float h_cont = hidden_state[i];

                float32x4_t _h_cont = vdupq_n_f32(h_cont);
                float32x4_t _weight_hc_IFOG = vcvt_f32_f16(vld1_f16(weight_hc_IFOG));
                _IFOG = vfmaq_f32(_IFOG, _weight_hc_IFOG, _h_cont);

                weight_hc_IFOG += 4;
            }

            float* gates_data = gates.row(q);

            // reduce the four accumulators into the final IFOG quad
            _IFOG = vaddq_f32(_IFOG, _sum1);
            _sum2 = vaddq_f32(_sum2, _sum3);
            _IFOG = vaddq_f32(_IFOG, _sum2);

            vst1q_f32(gates_data, _IFOG);
        }

        // lstm unit
        // sigmoid(I)
        // sigmoid(F)
        // sigmoid(O)
        // tanh(G)
        // c_t := f_t .* c_{t-1} + i_t .* g_t
        // h_t := o_t .* tanh[c_t]
        __fp16* output_data = top_blob.row<__fp16>(ti);

        float* cell_ptr = cell_state;
        float* hidden_ptr = hidden_state;

        int q = 0;
        // 4 neurons at a time; vld4 de-interleaves IFOG quads per gate
        for (; q + 3 < num_output; q += 4)
        {
            const float* gates_data = gates.row(q);

            float32x4x4_t _IFOG_4x4 = vld4q_f32(gates_data);

            float32x4_t _I = sigmoid_ps(_IFOG_4x4.val[0]);
            float32x4_t _F = sigmoid_ps(_IFOG_4x4.val[1]);
            float32x4_t _O = sigmoid_ps(_IFOG_4x4.val[2]);
            float32x4_t _G = tanh_ps(_IFOG_4x4.val[3]);

            float32x4_t _cell2 = vaddq_f32(vmulq_f32(_F, vld1q_f32(cell_ptr)), vmulq_f32(_I, _G));
            float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2));

            // state stays fp32; only the output row is narrowed to fp16
            vst1q_f32(cell_ptr, _cell2);
            vst1q_f32(hidden_ptr, _H);
            vst1_f16(output_data, vcvt_f16_f32(_H));

            cell_ptr += 4;
            hidden_ptr += 4;
            output_data += 4;
        }
        // scalar tail of the lstm unit
        for (; q < num_output; q++)
        {
            const float* gates_data = gates.row(q);

            float I = gates_data[0];
            float F = gates_data[1];
            float O = gates_data[2];
            float G = gates_data[3];

            I = 1.f / (1.f + exp(-I));
            F = 1.f / (1.f + exp(-F));
            O = 1.f / (1.f + exp(-O));
            G = tanh(G);

            float cell2 = F * *cell_ptr + I * G;
            float H = O * tanh(cell2);

            *cell_ptr++ = cell2;
            *hidden_ptr++ = H;
            *output_data++ = (__fp16)(H);
        }
    }

    return 0;
}
634 
lstm_fp16sa(const Mat & bottom_blob,Mat & top_blob,int reverse,const Mat & weight_xc,const Mat & bias_c,const Mat & weight_hc,Mat & hidden_state,Mat & cell_state,const Option & opt)635 static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
636 {
637     int size = bottom_blob.w;
638     int T = bottom_blob.h;
639 
640     int num_output = top_blob.w;
641 
642     // 4 x num_output
643     Mat gates(4, num_output, 2u, opt.workspace_allocator);
644     if (gates.empty())
645         return -100;
646 
647     // unroll
648     for (int t = 0; t < T; t++)
649     {
650         // clip hidden by continuation indicator
651         // h_cont_{t-1} = cont_t * h_{t-1}
652         // h_cont_{t-1} = h_{t-1} if cont_t == 1
653         //                0       otherwise
654         // calculate hidden
655         // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c
656 
657         int ti = reverse ? T - 1 - t : t;
658 
659         int q = 0;
660         for (; q + 1 < num_output; q += 2)
661         {
662             const __fp16* bias_c_IFOG = (const __fp16*)bias_c + q * 4;
663 
664             // gate I F O G
665             const __fp16* weight_xc_IFOG = weight_xc.row<const __fp16>(q / 2);
666 
667             const __fp16* weight_hc_IFOG = weight_hc.row<const __fp16>(q / 2);
668 
669             float16x8_t _IFOG = vld1q_f16(bias_c_IFOG);
670             float16x8_t _sum1 = vdupq_n_f16((__fp16)0.f);
671             float16x8_t _sum2 = vdupq_n_f16((__fp16)0.f);
672             float16x8_t _sum3 = vdupq_n_f16((__fp16)0.f);
673 
674             const __fp16* x = bottom_blob.row<const __fp16>(ti);
675 
676             int i = 0;
677             for (; i + 3 < size; i += 4)
678             {
679                 asm volatile(
680                     "ld1    {v4.4h}, [%0], #8       \n"
681                     "ld1    {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n"
682                     "fmla   %2.8h, v0.8h, v4.h[0]   \n"
683                     "fmla   %3.8h, v1.8h, v4.h[1]   \n"
684                     "fmla   %4.8h, v2.8h, v4.h[2]   \n"
685                     "fmla   %5.8h, v3.8h, v4.h[3]   \n"
686                     : "=r"(x),
687                     "=r"(weight_xc_IFOG),
688                     "=w"(_IFOG),
689                     "=w"(_sum1),
690                     "=w"(_sum2),
691                     "=w"(_sum3)
692                     : "0"(x),
693                     "1"(weight_xc_IFOG),
694                     "2"(_IFOG),
695                     "3"(_sum1),
696                     "4"(_sum2),
697                     "5"(_sum3)
698                     : "memory", "v0", "v1", "v2", "v3", "v4");
699             }
700             for (; i < size; i++)
701             {
702                 __fp16 xi = *x++;
703 
704                 float16x8_t _xi = vdupq_n_f16(xi);
705                 float16x8_t _weight_xc_IFOG = vld1q_f16(weight_xc_IFOG);
706                 _IFOG = vfmaq_f16(_IFOG, _weight_xc_IFOG, _xi);
707 
708                 weight_xc_IFOG += 8;
709             }
710 
711             const float* hidden_ptr = hidden_state;
712 
713             i = 0;
714             for (; i + 3 < num_output; i += 4)
715             {
716                 asm volatile(
717                     "ld1    {v4.4s}, [%0], #16      \n"
718                     "ld1    {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n"
719                     "fcvtn  v4.4h, v4.4s            \n"
720                     "fmla   %2.8h, v0.8h, v4.h[0]   \n"
721                     "fmla   %3.8h, v1.8h, v4.h[1]   \n"
722                     "fmla   %4.8h, v2.8h, v4.h[2]   \n"
723                     "fmla   %5.8h, v3.8h, v4.h[3]   \n"
724                     : "=r"(hidden_ptr),
725                     "=r"(weight_hc_IFOG),
726                     "=w"(_IFOG),
727                     "=w"(_sum1),
728                     "=w"(_sum2),
729                     "=w"(_sum3)
730                     : "0"(hidden_ptr),
731                     "1"(weight_hc_IFOG),
732                     "2"(_IFOG),
733                     "3"(_sum1),
734                     "4"(_sum2),
735                     "5"(_sum3)
736                     : "memory", "v0", "v1", "v2", "v3", "v4");
737             }
738             for (; i < num_output; i++)
739             {
740                 float h_cont = *hidden_ptr++;
741 
742                 float16x8_t _h_cont = vdupq_n_f16((__fp16)h_cont);
743                 float16x8_t _weight_hc_IFOG = vld1q_f16(weight_hc_IFOG);
744                 _IFOG = vfmaq_f16(_IFOG, _weight_hc_IFOG, _h_cont);
745 
746                 weight_hc_IFOG += 8;
747             }
748 
749             __fp16* gates_data = gates.row<__fp16>(q);
750 
751             _IFOG = vaddq_f16(_IFOG, _sum1);
752             _sum2 = vaddq_f16(_sum2, _sum3);
753             _IFOG = vaddq_f16(_IFOG, _sum2);
754 
755             vst1q_f16(gates_data, _IFOG);
756         }
757         for (; q < num_output; q++)
758         {
759             const __fp16* bias_c_IFOG = (const __fp16*)bias_c + q * 4;
760 
761             // gate I F O G
762             const __fp16* weight_xc_IFOG = weight_xc.row<const __fp16>(q / 2 + q % 2);
763 
764             const __fp16* weight_hc_IFOG = weight_hc.row<const __fp16>(q / 2 + q % 2);
765 
766             float16x4_t _IFOG = vld1_f16(bias_c_IFOG);
767             float16x4_t _sum1 = vdup_n_f16((__fp16)0.f);
768             float16x4_t _sum2 = vdup_n_f16((__fp16)0.f);
769             float16x4_t _sum3 = vdup_n_f16((__fp16)0.f);
770 
771             const __fp16* x = bottom_blob.row<const __fp16>(ti);
772 
773             int i = 0;
774             for (; i + 3 < size; i += 4)
775             {
776                 asm volatile(
777                     "ld1    {v4.4h}, [%0], #8       \n"
778                     "ld1    {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n"
779                     "fmla   %2.4h, v0.4h, v4.h[0]   \n"
780                     "fmla   %3.4h, v1.4h, v4.h[1]   \n"
781                     "fmla   %4.4h, v2.4h, v4.h[2]   \n"
782                     "fmla   %5.4h, v3.4h, v4.h[3]   \n"
783                     : "=r"(x),
784                     "=r"(weight_xc_IFOG),
785                     "=w"(_IFOG),
786                     "=w"(_sum1),
787                     "=w"(_sum2),
788                     "=w"(_sum3)
789                     : "0"(x),
790                     "1"(weight_xc_IFOG),
791                     "2"(_IFOG),
792                     "3"(_sum1),
793                     "4"(_sum2),
794                     "5"(_sum3)
795                     : "memory", "v0", "v1", "v2", "v3", "v4");
796             }
797             for (; i < size; i++)
798             {
799                 __fp16 xi = *x++;
800 
801                 float16x4_t _xi = vdup_n_f16(xi);
802                 float16x4_t _weight_xc_IFOG = vld1_f16(weight_xc_IFOG);
803                 _IFOG = vfma_f16(_IFOG, _weight_xc_IFOG, _xi);
804 
805                 weight_xc_IFOG += 4;
806             }
807 
808             const float* hidden_ptr = hidden_state;
809 
810             i = 0;
811             for (; i + 3 < num_output; i += 4)
812             {
813                 asm volatile(
814                     "ld1    {v4.4s}, [%0], #16      \n"
815                     "ld1    {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n"
816                     "fcvtn  v4.4h, v4.4s            \n"
817                     "fmla   %2.4h, v0.4h, v4.h[0]   \n"
818                     "fmla   %3.4h, v1.4h, v4.h[1]   \n"
819                     "fmla   %4.4h, v2.4h, v4.h[2]   \n"
820                     "fmla   %5.4h, v3.4h, v4.h[3]   \n"
821                     : "=r"(hidden_ptr),
822                     "=r"(weight_hc_IFOG),
823                     "=w"(_IFOG),
824                     "=w"(_sum1),
825                     "=w"(_sum2),
826                     "=w"(_sum3)
827                     : "0"(hidden_ptr),
828                     "1"(weight_hc_IFOG),
829                     "2"(_IFOG),
830                     "3"(_sum1),
831                     "4"(_sum2),
832                     "5"(_sum3)
833                     : "memory", "v0", "v1", "v2", "v3", "v4");
834             }
835             for (; i < num_output; i++)
836             {
837                 float h_cont = *hidden_ptr++;
838 
839                 float16x4_t _h_cont = vdup_n_f16((__fp16)h_cont);
840                 float16x4_t _weight_hc_IFOG = vld1_f16(weight_hc_IFOG);
841                 _IFOG = vfma_f16(_IFOG, _weight_hc_IFOG, _h_cont);
842 
843                 weight_hc_IFOG += 4;
844             }
845 
846             __fp16* gates_data = gates.row<__fp16>(q);
847 
848             _IFOG = vadd_f16(_IFOG, _sum1);
849             _sum2 = vadd_f16(_sum2, _sum3);
850             _IFOG = vadd_f16(_IFOG, _sum2);
851 
852             vst1_f16(gates_data, _IFOG);
853         }
854 
855         // lstm unit
856         // sigmoid(I)
857         // sigmoid(F)
858         // sigmoid(O)
859         // tanh(G)
860         // c_t := f_t .* c_{t-1} + i_t .* g_t
861         // h_t := o_t .* tanh[c_t]
862         __fp16* output_data = top_blob.row<__fp16>(ti);
863 
864         float* cell_ptr = cell_state;
865         float* hidden_ptr = hidden_state;
866 
867         q = 0;
868         for (; q + 3 < num_output; q += 4)
869         {
870             const __fp16* gates_data = gates.row<const __fp16>(q);
871 
872             float16x4x4_t _IFOG_4x4 = vld4_f16(gates_data);
873 
874             float32x4_t _I = sigmoid_ps(vcvt_f32_f16(_IFOG_4x4.val[0]));
875             float32x4_t _F = sigmoid_ps(vcvt_f32_f16(_IFOG_4x4.val[1]));
876             float32x4_t _O = sigmoid_ps(vcvt_f32_f16(_IFOG_4x4.val[2]));
877             float32x4_t _G = tanh_ps(vcvt_f32_f16(_IFOG_4x4.val[3]));
878 
879             float32x4_t _cell2 = vaddq_f32(vmulq_f32(_F, vld1q_f32(cell_ptr)), vmulq_f32(_I, _G));
880             float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2));
881 
882             vst1q_f32(cell_ptr, _cell2);
883             vst1q_f32(hidden_ptr, _H);
884             vst1_f16(output_data, vcvt_f16_f32(_H));
885 
886             cell_ptr += 4;
887             hidden_ptr += 4;
888             output_data += 4;
889         }
890         for (; q < num_output; q++)
891         {
892             const __fp16* gates_data = gates.row<const __fp16>(q);
893 
894             float I = (float)gates_data[0];
895             float F = (float)gates_data[1];
896             float O = (float)gates_data[2];
897             float G = (float)gates_data[3];
898 
899             I = 1.f / (1.f + exp(-I));
900             F = 1.f / (1.f + exp(-F));
901             O = 1.f / (1.f + exp(-O));
902             G = tanh(G);
903 
904             float cell2 = F * *cell_ptr + I * G;
905             float H = O * tanh(cell2);
906 
907             *cell_ptr++ = cell2;
908             *hidden_ptr++ = H;
909             *output_data++ = (__fp16)H;
910         }
911     }
912 
913     return 0;
914 }
915 
int LSTM_arm::create_pipeline_fp16s(const Option& opt)
{
    // pack IFOG
    // Interleave the four gate weight planes (I, F, O, G) per output neuron
    // and convert them from fp32 to fp16 storage, once per direction.
    int num_directions = direction == 2 ? 2 : 1;
    int size = weight_data_size / num_directions / num_output / 4;

    if (opt.use_fp16_arithmetic)
    {
        // fp16-arithmetic kernels consume two output neurons per packed row
        // (8 fp16 values = IFOG of neuron q followed by IFOG of neuron q+1),
        // plus one tail row when num_output is odd
        weight_xc_data_packed.create(size, num_output / 2 + num_output % 2, num_directions, 16u, 8);
        bias_c_data_packed.create(num_output, 1, num_directions, 8u, 4);
        weight_hc_data_packed.create(num_output, num_output / 2 + num_output % 2, num_directions, 16u, 8);
    }
    else
    {
        // fp16 storage only: one output neuron (4 fp16 values) per packed row
        weight_xc_data_packed.create(size, num_output, num_directions, 8u, 4);
        bias_c_data_packed.create(num_output, 1, num_directions, 8u, 4);
        weight_hc_data_packed.create(num_output, num_output, num_directions, 8u, 4);
    }

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int dr = 0; dr < num_directions; dr++)
    {
        const Mat weight_xc = weight_xc_data.channel(dr);
        const Mat bias_c = bias_c_data.channel(dr);
        const Mat weight_hc = weight_hc_data.channel(dr);

        Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr);
        Mat bias_c_data_packed_dr = bias_c_data_packed.channel(dr);
        Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr);

        // source layout: one row per gate (I, F, O, G)
        const float* bias_c_I = bias_c.row(0);
        const float* bias_c_F = bias_c.row(1);
        const float* bias_c_O = bias_c.row(2);
        const float* bias_c_G = bias_c.row(3);

        // destination: IFOG quads written sequentially, one per neuron
        __fp16* bias_c_IFOG = bias_c_data_packed_dr.row<__fp16>(0);

        if (opt.use_fp16_arithmetic)
        {
            // main loop: pack neuron pairs (q, q+1) into one 8-wide row
            int q = 0;
            for (; q + 1 < num_output; q += 2)
            {
                bias_c_IFOG[0] = (__fp16)bias_c_I[q];
                bias_c_IFOG[1] = (__fp16)bias_c_F[q];
                bias_c_IFOG[2] = (__fp16)bias_c_O[q];
                bias_c_IFOG[3] = (__fp16)bias_c_G[q];
                bias_c_IFOG[4] = (__fp16)bias_c_I[q + 1];
                bias_c_IFOG[5] = (__fp16)bias_c_F[q + 1];
                bias_c_IFOG[6] = (__fp16)bias_c_O[q + 1];
                bias_c_IFOG[7] = (__fp16)bias_c_G[q + 1];

                bias_c_IFOG += 8;

                // gate rows for neuron q and its pair neuron q+1
                const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
                const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
                const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
                const float* weight_xc_G = weight_xc.row(num_output * 3 + q);
                const float* weight_xc_I_1 = weight_xc.row(num_output * 0 + q + 1);
                const float* weight_xc_F_1 = weight_xc.row(num_output * 1 + q + 1);
                const float* weight_xc_O_1 = weight_xc.row(num_output * 2 + q + 1);
                const float* weight_xc_G_1 = weight_xc.row(num_output * 3 + q + 1);

                const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
                const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
                const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
                const float* weight_hc_G = weight_hc.row(num_output * 3 + q);
                const float* weight_hc_I_1 = weight_hc.row(num_output * 0 + q + 1);
                const float* weight_hc_F_1 = weight_hc.row(num_output * 1 + q + 1);
                const float* weight_hc_O_1 = weight_hc.row(num_output * 2 + q + 1);
                const float* weight_hc_G_1 = weight_hc.row(num_output * 3 + q + 1);

                // each packed row holds a neuron pair -> row index q/2
                __fp16* weight_xc_IFOG = weight_xc_data_packed_dr.row<__fp16>(q / 2);
                __fp16* weight_hc_IFOG = weight_hc_data_packed_dr.row<__fp16>(q / 2);

                for (int i = 0; i < size; i++)
                {
                    weight_xc_IFOG[0] = (__fp16)weight_xc_I[i];
                    weight_xc_IFOG[1] = (__fp16)weight_xc_F[i];
                    weight_xc_IFOG[2] = (__fp16)weight_xc_O[i];
                    weight_xc_IFOG[3] = (__fp16)weight_xc_G[i];
                    weight_xc_IFOG[4] = (__fp16)weight_xc_I_1[i];
                    weight_xc_IFOG[5] = (__fp16)weight_xc_F_1[i];
                    weight_xc_IFOG[6] = (__fp16)weight_xc_O_1[i];
                    weight_xc_IFOG[7] = (__fp16)weight_xc_G_1[i];

                    weight_xc_IFOG += 8;
                }

                for (int i = 0; i < num_output; i++)
                {
                    weight_hc_IFOG[0] = (__fp16)weight_hc_I[i];
                    weight_hc_IFOG[1] = (__fp16)weight_hc_F[i];
                    weight_hc_IFOG[2] = (__fp16)weight_hc_O[i];
                    weight_hc_IFOG[3] = (__fp16)weight_hc_G[i];
                    weight_hc_IFOG[4] = (__fp16)weight_hc_I_1[i];
                    weight_hc_IFOG[5] = (__fp16)weight_hc_F_1[i];
                    weight_hc_IFOG[6] = (__fp16)weight_hc_O_1[i];
                    weight_hc_IFOG[7] = (__fp16)weight_hc_G_1[i];

                    weight_hc_IFOG += 8;
                }
            }
            // tail: last odd neuron packed alone, 4-wide, in the final row
            for (; q < num_output; q++)
            {
                bias_c_IFOG[0] = (__fp16)bias_c_I[q];
                bias_c_IFOG[1] = (__fp16)bias_c_F[q];
                bias_c_IFOG[2] = (__fp16)bias_c_O[q];
                bias_c_IFOG[3] = (__fp16)bias_c_G[q];

                bias_c_IFOG += 4;

                const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
                const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
                const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
                const float* weight_xc_G = weight_xc.row(num_output * 3 + q);

                const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
                const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
                const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
                const float* weight_hc_G = weight_hc.row(num_output * 3 + q);

                // q/2 + q%2 lands on the extra tail row allocated above
                __fp16* weight_xc_IFOG = weight_xc_data_packed_dr.row<__fp16>(q / 2 + q % 2);
                __fp16* weight_hc_IFOG = weight_hc_data_packed_dr.row<__fp16>(q / 2 + q % 2);

                for (int i = 0; i < size; i++)
                {
                    weight_xc_IFOG[0] = (__fp16)weight_xc_I[i];
                    weight_xc_IFOG[1] = (__fp16)weight_xc_F[i];
                    weight_xc_IFOG[2] = (__fp16)weight_xc_O[i];
                    weight_xc_IFOG[3] = (__fp16)weight_xc_G[i];

                    weight_xc_IFOG += 4;
                }

                for (int i = 0; i < num_output; i++)
                {
                    weight_hc_IFOG[0] = (__fp16)weight_hc_I[i];
                    weight_hc_IFOG[1] = (__fp16)weight_hc_F[i];
                    weight_hc_IFOG[2] = (__fp16)weight_hc_O[i];
                    weight_hc_IFOG[3] = (__fp16)weight_hc_G[i];

                    weight_hc_IFOG += 4;
                }
            }
        }
        else
        {
            // fp16 storage path: one IFOG quad per neuron per input element
            for (int q = 0; q < num_output; q++)
            {
                bias_c_IFOG[0] = (__fp16)bias_c_I[q];
                bias_c_IFOG[1] = (__fp16)bias_c_F[q];
                bias_c_IFOG[2] = (__fp16)bias_c_O[q];
                bias_c_IFOG[3] = (__fp16)bias_c_G[q];

                bias_c_IFOG += 4;

                const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
                const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
                const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
                const float* weight_xc_G = weight_xc.row(num_output * 3 + q);

                const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
                const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
                const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
                const float* weight_hc_G = weight_hc.row(num_output * 3 + q);

                __fp16* weight_xc_IFOG = weight_xc_data_packed_dr.row<__fp16>(q);
                __fp16* weight_hc_IFOG = weight_hc_data_packed_dr.row<__fp16>(q);

                for (int i = 0; i < size; i++)
                {
                    weight_xc_IFOG[0] = (__fp16)weight_xc_I[i];
                    weight_xc_IFOG[1] = (__fp16)weight_xc_F[i];
                    weight_xc_IFOG[2] = (__fp16)weight_xc_O[i];
                    weight_xc_IFOG[3] = (__fp16)weight_xc_G[i];

                    weight_xc_IFOG += 4;
                }

                for (int i = 0; i < num_output; i++)
                {
                    weight_hc_IFOG[0] = (__fp16)weight_hc_I[i];
                    weight_hc_IFOG[1] = (__fp16)weight_hc_F[i];
                    weight_hc_IFOG[2] = (__fp16)weight_hc_O[i];
                    weight_hc_IFOG[3] = (__fp16)weight_hc_G[i];

                    weight_hc_IFOG += 4;
                }
            }
        }
    }

    return 0;
}
1110 
forward_fp16s(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const1111 int LSTM_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
1112 {
1113     int T = bottom_blob.h;
1114 
1115     int num_directions = direction == 2 ? 2 : 1;
1116 
1117     // initial hidden state
1118     Mat hidden(num_output, 4u, opt.workspace_allocator);
1119     if (hidden.empty())
1120         return -100;
1121     hidden.fill(0.f);
1122 
1123     Mat cell(num_output, 4u, opt.workspace_allocator);
1124     if (cell.empty())
1125         return -100;
1126     cell.fill(0.f);
1127 
1128     top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator);
1129     if (top_blob.empty())
1130         return -100;
1131 
1132     // Uni directional
1133     if (direction == 0 || direction == 1)
1134     {
1135         int ret = lstm_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
1136         if (ret != 0)
1137             return ret;
1138     }
1139 
1140     if (direction == 2)
1141     {
1142         Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator);
1143         if (top_blob_forward.empty())
1144             return -100;
1145 
1146         Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator);
1147         if (top_blob_reverse.empty())
1148             return -100;
1149 
1150         int ret0 = lstm_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
1151         if (ret0 != 0)
1152             return ret0;
1153 
1154         hidden.fill(0.f);
1155         cell.fill(0.f);
1156 
1157         int ret1 = lstm_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, cell, opt);
1158         if (ret1 != 0)
1159             return ret1;
1160 
1161         // concat w
1162         for (int i = 0; i < T; i++)
1163         {
1164             const __fp16* pf = top_blob_forward.row<const __fp16>(i);
1165             const __fp16* pr = top_blob_reverse.row<const __fp16>(i);
1166             __fp16* ptr = top_blob.row<__fp16>(i);
1167 
1168             memcpy(ptr, pf, num_output * sizeof(__fp16));
1169             memcpy(ptr + num_output, pr, num_output * sizeof(__fp16));
1170         }
1171     }
1172 
1173     return 0;
1174 }
1175 
forward_fp16s(const std::vector<Mat> & bottom_blobs,std::vector<Mat> & top_blobs,const Option & opt) const1176 int LSTM_arm::forward_fp16s(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
1177 {
1178     const Mat& bottom_blob = bottom_blobs[0];
1179     int T = bottom_blob.h;
1180     Mat& top_blob = top_blobs[0];
1181 
1182     top_blob.create(num_output, T, 2u, opt.blob_allocator);
1183     if (top_blob.empty())
1184         return -100;
1185 
1186     // copy previous states
1187     Mat hidden;
1188     Mat cell;
1189     cast_float16_to_float32(bottom_blobs[1], hidden, opt);
1190     cast_float16_to_float32(bottom_blobs[2], cell, opt);
1191 
1192     // Uni directional
1193     if (direction == 0 || direction == 1)
1194     {
1195         int ret = lstm_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
1196         if (ret != 0)
1197             return ret;
1198     }
1199 
1200     cast_float32_to_float16(hidden, top_blobs[1], opt);
1201     cast_float32_to_float16(cell, top_blobs[2], opt);
1202 
1203     return 0;
1204 }
1205 
forward_fp16sa(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const1206 int LSTM_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
1207 {
1208     int T = bottom_blob.h;
1209 
1210     int num_directions = direction == 2 ? 2 : 1;
1211 
1212     // initial hidden state
1213     Mat hidden(num_output, 4u, opt.workspace_allocator);
1214     if (hidden.empty())
1215         return -100;
1216     hidden.fill(0.f);
1217 
1218     Mat cell(num_output, 4u, opt.workspace_allocator);
1219     if (cell.empty())
1220         return -100;
1221     cell.fill(0.f);
1222 
1223     top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator);
1224     if (top_blob.empty())
1225         return -100;
1226 
1227     // Uni directional
1228     if (direction == 0 || direction == 1)
1229     {
1230         int ret = lstm_fp16sa(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
1231         if (ret != 0)
1232             return ret;
1233     }
1234 
1235     if (direction == 2)
1236     {
1237         Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator);
1238         if (top_blob_forward.empty())
1239             return -100;
1240 
1241         Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator);
1242         if (top_blob_reverse.empty())
1243             return -100;
1244 
1245         int ret0 = lstm_fp16sa(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
1246         if (ret0 != 0)
1247             return ret0;
1248 
1249         hidden.fill(0.f);
1250         cell.fill(0.f);
1251 
1252         int ret1 = lstm_fp16sa(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, cell, opt);
1253         if (ret1 != 0)
1254             return ret1;
1255 
1256         // concat w
1257         for (int i = 0; i < T; i++)
1258         {
1259             const __fp16* pf = top_blob_forward.row<const __fp16>(i);
1260             const __fp16* pr = top_blob_reverse.row<const __fp16>(i);
1261             __fp16* ptr = top_blob.row<__fp16>(i);
1262 
1263             memcpy(ptr, pf, num_output * sizeof(__fp16));
1264             memcpy(ptr + num_output, pr, num_output * sizeof(__fp16));
1265         }
1266     }
1267 
1268     return 0;
1269 }
1270 
forward_fp16sa(const std::vector<Mat> & bottom_blobs,std::vector<Mat> & top_blobs,const Option & opt) const1271 int LSTM_arm::forward_fp16sa(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
1272 {
1273     const Mat& bottom_blob = bottom_blobs[0];
1274     int T = bottom_blob.h;
1275     Mat& top_blob = top_blobs[0];
1276 
1277     top_blob.create(num_output, T, 2u, opt.blob_allocator);
1278     if (top_blob.empty())
1279         return -100;
1280 
1281     // copy previous states
1282     Mat hidden;
1283     Mat cell;
1284     cast_float16_to_float32(bottom_blobs[1], hidden, opt);
1285     cast_float16_to_float32(bottom_blobs[2], cell, opt);
1286 
1287     // Uni directional
1288     if (direction == 0 || direction == 1)
1289     {
1290         int ret = lstm_fp16sa(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
1291         if (ret != 0)
1292             return ret;
1293     }
1294 
1295     cast_float32_to_float16(hidden, top_blobs[1], opt);
1296     cast_float32_to_float16(cell, top_blobs[2], opt);
1297 
1298     return 0;
1299 }
1300 #endif
1301 
// One LSTM pass over the whole sequence with bfloat16 storage.
// bottom_blob: T x size bf16 input; top_blob: T x num_output bf16 output.
// weight_xc / bias_c / weight_hc are the IFOG-interleaved packed weights
// produced by create_pipeline_bf16s; hidden_state / cell_state are fp32
// working states, updated in place each time step.
// reverse == 0 walks t = 0..T-1, reverse != 0 walks t = T-1..0.
static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

    // 4 x num_output
    // fp32 scratch holding the raw gate values (I,F,O,G) per neuron
    Mat gates(4, num_output, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;

    // unroll
    for (int t = 0; t < T; t++)
    {
        // clip hidden by continuation indicator
        // h_cont_{t-1} = cont_t * h_{t-1}
        // h_cont_{t-1} = h_{t-1} if cont_t == 1
        //                0       otherwise
        // calculate hidden
        // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c

        int ti = reverse ? T - 1 - t : t;

        const unsigned short* x = bottom_blob.row<const unsigned short>(ti);
        for (int q = 0; q < num_output; q++)
        {
            // packed bias: 4 bf16 values (I,F,O,G) per neuron
            const unsigned short* bias_c_IFOG = (const unsigned short*)bias_c + q * 4;

            // gate I F O G
            const unsigned short* weight_xc_IFOG = weight_xc.row<const unsigned short>(q);

            const unsigned short* weight_hc_IFOG = weight_hc.row<const unsigned short>(q);

#if __ARM_NEON
            // accumulate into 4 independent vectors to hide FMA latency;
            // they are summed once at the end
            float32x4_t _IFOG = vcvt_f32_bf16(vld1_u16(bias_c_IFOG));
            float32x4_t _sum1 = vdupq_n_f32(0.f);
            float32x4_t _sum2 = vdupq_n_f32(0.f);
            float32x4_t _sum3 = vdupq_n_f32(0.f);
#else
            float I = bfloat16_to_float32(bias_c_IFOG[0]);
            float F = bfloat16_to_float32(bias_c_IFOG[1]);
            float O = bfloat16_to_float32(bias_c_IFOG[2]);
            float G = bfloat16_to_float32(bias_c_IFOG[3]);
#endif // __ARM_NEON

            // input contribution: W_xc * x_t
            int i = 0;
#if __ARM_NEON
            for (; i + 3 < size; i += 4)
            {
                // widen 4 bf16 inputs and 4 IFOG weight quads to fp32
                float32x4_t _xi = vcvt_f32_bf16(vld1_u16(x + i));

                float32x4_t _weight_xc_IFOG_0 = vcvt_f32_bf16(vld1_u16(weight_xc_IFOG));
                float32x4_t _weight_xc_IFOG_1 = vcvt_f32_bf16(vld1_u16(weight_xc_IFOG + 4));
                float32x4_t _weight_xc_IFOG_2 = vcvt_f32_bf16(vld1_u16(weight_xc_IFOG + 8));
                float32x4_t _weight_xc_IFOG_3 = vcvt_f32_bf16(vld1_u16(weight_xc_IFOG + 12));

#if __aarch64__
                _IFOG = vfmaq_laneq_f32(_IFOG, _weight_xc_IFOG_0, _xi, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_IFOG_1, _xi, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_IFOG_2, _xi, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_IFOG_3, _xi, 3);
#else
                // armv7 has no laneq; split _xi into low/high halves
                _IFOG = vmlaq_lane_f32(_IFOG, _weight_xc_IFOG_0, vget_low_f32(_xi), 0);
                _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_IFOG_1, vget_low_f32(_xi), 1);
                _sum2 = vmlaq_lane_f32(_sum2, _weight_xc_IFOG_2, vget_high_f32(_xi), 0);
                _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_IFOG_3, vget_high_f32(_xi), 1);
#endif

                weight_xc_IFOG += 16;
            }
#endif // __ARM_NEON
            // scalar tail (and full loop on non-NEON builds)
            for (; i < size; i++)
            {
#if __ARM_NEON
                unsigned short xi = x[i];

                float32x4_t _xi = vcvt_f32_bf16(vdup_n_u16(xi));
                float32x4_t _weight_xc_IFOG = vcvt_f32_bf16(vld1_u16(weight_xc_IFOG));
                _IFOG = vmlaq_f32(_IFOG, _weight_xc_IFOG, _xi);
#else
                float xi = bfloat16_to_float32(x[i]);

                I += bfloat16_to_float32(weight_xc_IFOG[0]) * xi;
                F += bfloat16_to_float32(weight_xc_IFOG[1]) * xi;
                O += bfloat16_to_float32(weight_xc_IFOG[2]) * xi;
                G += bfloat16_to_float32(weight_xc_IFOG[3]) * xi;
#endif // __ARM_NEON

                weight_xc_IFOG += 4;
            }

            // recurrent contribution: W_hc * h_{t-1} (hidden state is fp32)
            i = 0;
#if __ARM_NEON
            for (; i + 3 < num_output; i += 4)
            {
                float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i);

                float32x4_t _weight_hc_IFOG_0 = vcvt_f32_bf16(vld1_u16(weight_hc_IFOG));
                float32x4_t _weight_hc_IFOG_1 = vcvt_f32_bf16(vld1_u16(weight_hc_IFOG + 4));
                float32x4_t _weight_hc_IFOG_2 = vcvt_f32_bf16(vld1_u16(weight_hc_IFOG + 8));
                float32x4_t _weight_hc_IFOG_3 = vcvt_f32_bf16(vld1_u16(weight_hc_IFOG + 12));

#if __aarch64__
                _IFOG = vfmaq_laneq_f32(_IFOG, _weight_hc_IFOG_0, _h_cont, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_IFOG_1, _h_cont, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_IFOG_2, _h_cont, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_IFOG_3, _h_cont, 3);
#else
                _IFOG = vmlaq_lane_f32(_IFOG, _weight_hc_IFOG_0, vget_low_f32(_h_cont), 0);
                _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_IFOG_1, vget_low_f32(_h_cont), 1);
                _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_IFOG_2, vget_high_f32(_h_cont), 0);
                _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_IFOG_3, vget_high_f32(_h_cont), 1);
#endif

                weight_hc_IFOG += 16;
            }
#endif // __ARM_NEON
            for (; i < num_output; i++)
            {
                float h_cont = hidden_state[i];

#if __ARM_NEON
                float32x4_t _h_cont = vdupq_n_f32(h_cont);
                float32x4_t _weight_hc_IFOG = vcvt_f32_bf16(vld1_u16(weight_hc_IFOG));
                _IFOG = vmlaq_f32(_IFOG, _weight_hc_IFOG, _h_cont);
#else
                I += bfloat16_to_float32(weight_hc_IFOG[0]) * h_cont;
                F += bfloat16_to_float32(weight_hc_IFOG[1]) * h_cont;
                O += bfloat16_to_float32(weight_hc_IFOG[2]) * h_cont;
                G += bfloat16_to_float32(weight_hc_IFOG[3]) * h_cont;
#endif // __ARM_NEON

                weight_hc_IFOG += 4;
            }

            float* gates_data = gates.row(q);

#if __ARM_NEON
            // reduce the four partial accumulators into the final IFOG
            _IFOG = vaddq_f32(_IFOG, _sum1);
            _sum2 = vaddq_f32(_sum2, _sum3);
            _IFOG = vaddq_f32(_IFOG, _sum2);

            vst1q_f32(gates_data, _IFOG);
#else
            gates_data[0] = I;
            gates_data[1] = F;
            gates_data[2] = O;
            gates_data[3] = G;
#endif // __ARM_NEON
        }

        // lstm unit
        // sigmoid(I)
        // sigmoid(F)
        // sigmoid(O)
        // tanh(G)
        // c_t := f_t .* c_{t-1} + i_t .* g_t
        // h_t := o_t .* tanh[c_t]
        unsigned short* output_data = top_blob.row<unsigned short>(ti);

        float* cell_ptr = cell_state;
        float* hidden_ptr = hidden_state;

        int q = 0;
#if __ARM_NEON
        // 4 neurons at a time; vld4q de-interleaves the IFOG quads
        for (; q + 3 < num_output; q += 4)
        {
            const float* gates_data = gates.row(q);

            float32x4x4_t _IFOG_4x4 = vld4q_f32(gates_data);

            float32x4_t _I = sigmoid_ps(_IFOG_4x4.val[0]);
            float32x4_t _F = sigmoid_ps(_IFOG_4x4.val[1]);
            float32x4_t _O = sigmoid_ps(_IFOG_4x4.val[2]);
            float32x4_t _G = tanh_ps(_IFOG_4x4.val[3]);

            float32x4_t _cell2 = vaddq_f32(vmulq_f32(_F, vld1q_f32(cell_ptr)), vmulq_f32(_I, _G));
            float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2));

            // states stay fp32; only the output is narrowed to bf16
            vst1q_f32(cell_ptr, _cell2);
            vst1q_f32(hidden_ptr, _H);
            vst1_u16(output_data, vcvt_bf16_f32(_H));

            cell_ptr += 4;
            hidden_ptr += 4;
            output_data += 4;
        }
#endif // __ARM_NEON
        for (; q < num_output; q++)
        {
            const float* gates_data = gates.row(q);

            float I = gates_data[0];
            float F = gates_data[1];
            float O = gates_data[2];
            float G = gates_data[3];

            I = 1.f / (1.f + exp(-I));
            F = 1.f / (1.f + exp(-F));
            O = 1.f / (1.f + exp(-O));
            G = tanh(G);

            float cell2 = F * *cell_ptr + I * G;
            float H = O * tanh(cell2);

            *cell_ptr++ = cell2;
            *hidden_ptr++ = H;
            *output_data++ = float32_to_bfloat16(H);
        }
    }

    return 0;
}
1516 
create_pipeline_bf16s(const Option & opt)1517 int LSTM_arm::create_pipeline_bf16s(const Option& opt)
1518 {
1519     // pack IFOG
1520     int num_directions = direction == 2 ? 2 : 1;
1521     int size = weight_data_size / num_directions / num_output / 4;
1522 
1523     weight_xc_data_packed.create(size, num_output, num_directions, 8u, 4);
1524     bias_c_data_packed.create(num_output, 1, num_directions, 8u, 4);
1525     weight_hc_data_packed.create(num_output, num_output, num_directions, 8u, 4);
1526 
1527     #pragma omp parallel for num_threads(opt.num_threads)
1528     for (int dr = 0; dr < num_directions; dr++)
1529     {
1530         const Mat weight_xc = weight_xc_data.channel(dr);
1531         const Mat bias_c = bias_c_data.channel(dr);
1532         const Mat weight_hc = weight_hc_data.channel(dr);
1533 
1534         Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr);
1535         Mat bias_c_data_packed_dr = bias_c_data_packed.channel(dr);
1536         Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr);
1537 
1538         const float* bias_c_I = bias_c.row(0);
1539         const float* bias_c_F = bias_c.row(1);
1540         const float* bias_c_O = bias_c.row(2);
1541         const float* bias_c_G = bias_c.row(3);
1542 
1543         unsigned short* bias_c_IFOG = bias_c_data_packed_dr.row<unsigned short>(0);
1544 
1545         for (int q = 0; q < num_output; q++)
1546         {
1547             bias_c_IFOG[0] = float32_to_bfloat16(bias_c_I[q]);
1548             bias_c_IFOG[1] = float32_to_bfloat16(bias_c_F[q]);
1549             bias_c_IFOG[2] = float32_to_bfloat16(bias_c_O[q]);
1550             bias_c_IFOG[3] = float32_to_bfloat16(bias_c_G[q]);
1551 
1552             bias_c_IFOG += 4;
1553 
1554             const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
1555             const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
1556             const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
1557             const float* weight_xc_G = weight_xc.row(num_output * 3 + q);
1558 
1559             const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
1560             const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
1561             const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
1562             const float* weight_hc_G = weight_hc.row(num_output * 3 + q);
1563 
1564             unsigned short* weight_xc_IFOG = weight_xc_data_packed_dr.row<unsigned short>(q);
1565             unsigned short* weight_hc_IFOG = weight_hc_data_packed_dr.row<unsigned short>(q);
1566 
1567             for (int i = 0; i < size; i++)
1568             {
1569                 weight_xc_IFOG[0] = float32_to_bfloat16(weight_xc_I[i]);
1570                 weight_xc_IFOG[1] = float32_to_bfloat16(weight_xc_F[i]);
1571                 weight_xc_IFOG[2] = float32_to_bfloat16(weight_xc_O[i]);
1572                 weight_xc_IFOG[3] = float32_to_bfloat16(weight_xc_G[i]);
1573 
1574                 weight_xc_IFOG += 4;
1575             }
1576 
1577             for (int i = 0; i < num_output; i++)
1578             {
1579                 weight_hc_IFOG[0] = float32_to_bfloat16(weight_hc_I[i]);
1580                 weight_hc_IFOG[1] = float32_to_bfloat16(weight_hc_F[i]);
1581                 weight_hc_IFOG[2] = float32_to_bfloat16(weight_hc_O[i]);
1582                 weight_hc_IFOG[3] = float32_to_bfloat16(weight_hc_G[i]);
1583 
1584                 weight_hc_IFOG += 4;
1585             }
1586         }
1587     }
1588 
1589     return 0;
1590 }
1591 
forward_bf16s(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const1592 int LSTM_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
1593 {
1594     int T = bottom_blob.h;
1595 
1596     int num_directions = direction == 2 ? 2 : 1;
1597 
1598     // initial hidden state
1599     Mat hidden(num_output, 4u, opt.workspace_allocator);
1600     if (hidden.empty())
1601         return -100;
1602     hidden.fill(0.f);
1603 
1604     Mat cell(num_output, 4u, opt.workspace_allocator);
1605     if (cell.empty())
1606         return -100;
1607     cell.fill(0.f);
1608 
1609     top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator);
1610     if (top_blob.empty())
1611         return -100;
1612 
1613     // Uni directional
1614     if (direction == 0 || direction == 1)
1615     {
1616         int ret = lstm_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
1617         if (ret != 0)
1618             return ret;
1619     }
1620 
1621     if (direction == 2)
1622     {
1623         Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator);
1624         if (top_blob_forward.empty())
1625             return -100;
1626 
1627         Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator);
1628         if (top_blob_reverse.empty())
1629             return -100;
1630 
1631         int ret0 = lstm_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
1632         if (ret0 != 0)
1633             return ret0;
1634 
1635         hidden.fill(0.f);
1636         cell.fill(0.f);
1637 
1638         int ret1 = lstm_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, cell, opt);
1639         if (ret1 != 0)
1640             return ret1;
1641 
1642         // concat w
1643         for (int i = 0; i < T; i++)
1644         {
1645             const unsigned short* pf = top_blob_forward.row<const unsigned short>(i);
1646             const unsigned short* pr = top_blob_reverse.row<const unsigned short>(i);
1647             unsigned short* ptr = top_blob.row<unsigned short>(i);
1648 
1649             memcpy(ptr, pf, num_output * sizeof(unsigned short));
1650             memcpy(ptr + num_output, pr, num_output * sizeof(unsigned short));
1651         }
1652     }
1653 
1654     return 0;
1655 }
1656 
forward_bf16s(const std::vector<Mat> & bottom_blobs,std::vector<Mat> & top_blobs,const Option & opt) const1657 int LSTM_arm::forward_bf16s(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
1658 {
1659     const Mat& bottom_blob = bottom_blobs[0];
1660     int T = bottom_blob.h;
1661     Mat& top_blob = top_blobs[0];
1662 
1663     top_blob.create(num_output, T, 2u, opt.blob_allocator);
1664     if (top_blob.empty())
1665         return -100;
1666 
1667     // copy previous states
1668     Mat hidden;
1669     Mat cell;
1670     cast_bfloat16_to_float32(bottom_blobs[1], hidden, opt);
1671     cast_bfloat16_to_float32(bottom_blobs[2], cell, opt);
1672 
1673     // Uni directional
1674     if (direction == 0 || direction == 1)
1675     {
1676         int ret = lstm_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
1677         if (ret != 0)
1678             return ret;
1679     }
1680 
1681     cast_float32_to_bfloat16(hidden, top_blobs[1], opt);
1682     cast_float32_to_bfloat16(cell, top_blobs[2], opt);
1683 
1684     return 0;
1685 }
1686 
1687 } // namespace ncnn
1688