// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "rnn_arm.h"

#if __ARM_NEON
#include <arm_neon.h>
#endif // __ARM_NEON

#include "arm_activation.h"

#include <math.h>
#include <string.h> // memcpy

namespace ncnn {
RNN_arm::RNN_arm()
{
#if __ARM_NEON
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    support_fp16_storage = true;
#endif
#endif // __ARM_NEON

    support_bf16_storage = true;
}
int RNN_arm::create_pipeline(const Option& opt)
{
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    if (opt.use_fp16_storage)
    {
        return create_pipeline_fp16s(opt);
    }
#endif

    if (opt.use_bf16_storage)
    {
        return create_pipeline_bf16s(opt);
    }

    int num_directions = direction == 2 ? 2 : 1;
    int size = weight_data_size / num_directions / num_output;

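    // interleave the kernel so each packed row carries 4 output neurons:
    // row q/4 stores w[q][i], w[q+1][i], w[q+2][i], w[q+3][i], ... letting the
    // NEON kernel produce 4 outputs per loaded input element; the remaining
    // num_output % 4 outputs keep one plain row each after the packed rows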
#if __ARM_NEON
    weight_xc_data_packed.create(size * 4, num_output / 4 + num_output % 4, num_directions);
    weight_hc_data_packed.create(num_output * 4, num_output / 4 + num_output % 4, num_directions);

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int dr = 0; dr < num_directions; dr++)
    {
        const Mat weight_xc = weight_xc_data.channel(dr);
        const Mat weight_hc = weight_hc_data.channel(dr);

        Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr);
        Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr);

        int q = 0;
#if __ARM_NEON
        for (; q + 3 < num_output; q += 4)
        {
            const float* weight_xc_0 = weight_xc.row(q);
            const float* weight_xc_1 = weight_xc.row(q + 1);
            const float* weight_xc_2 = weight_xc.row(q + 2);
            const float* weight_xc_3 = weight_xc.row(q + 3);

            const float* weight_hc_0 = weight_hc.row(q);
            const float* weight_hc_1 = weight_hc.row(q + 1);
            const float* weight_hc_2 = weight_hc.row(q + 2);
            const float* weight_hc_3 = weight_hc.row(q + 3);

            float* weight_xc = weight_xc_data_packed_dr.row(q / 4);
            float* weight_hc = weight_hc_data_packed_dr.row(q / 4);

            for (int i = 0; i < size; i++)
            {
                weight_xc[0] = weight_xc_0[i];
                weight_xc[1] = weight_xc_1[i];
                weight_xc[2] = weight_xc_2[i];
                weight_xc[3] = weight_xc_3[i];

                weight_xc += 4;
            }

            for (int i = 0; i < num_output; i++)
            {
                weight_hc[0] = weight_hc_0[i];
                weight_hc[1] = weight_hc_1[i];
                weight_hc[2] = weight_hc_2[i];
                weight_hc[3] = weight_hc_3[i];

                weight_hc += 4;
            }
        }
#endif // __ARM_NEON
        for (; q < num_output; q++)
        {
            const float* weight_xc_0 = weight_xc.row(q);
            const float* weight_hc_0 = weight_hc.row(q);

#if __ARM_NEON
            float* weight_xc = weight_xc_data_packed_dr.row(q / 4 + q % 4);
            float* weight_hc = weight_hc_data_packed_dr.row(q / 4 + q % 4);
#else
            float* weight_xc = weight_xc_data_packed_dr.row(q);
            float* weight_hc = weight_hc_data_packed_dr.row(q);
#endif // __ARM_NEON

            for (int i = 0; i < size; i++)
            {
                weight_xc[i] = weight_xc_0[i];
            }

            for (int i = 0; i < num_output; i++)
            {
                weight_hc[i] = weight_hc_0[i];
            }
        }
    }
#else
    weight_xc_data_packed = weight_xc_data;
    weight_hc_data_packed = weight_hc_data;
#endif

    bias_c_data_packed = bias_c_data;

    return 0;
}

static int rnn(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

    // gates buffer: one activation per output neuron for the current time step
    Mat gates(num_output, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;

    // unroll over time steps
    for (int t = 0; t < T; t++)
    {
        int ti = reverse ? T - 1 - t : t;

        const float* x = bottom_blob.row(ti);

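        // 4 outputs at a time: the weights were pre-interleaved in groups of 4,
        // and four accumulators (_H, _sum1.._sum3) keep independent FMA chains
        // across the x and hidden loops before the final reduction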
        int q = 0;
#if __ARM_NEON
        for (; q + 3 < num_output; q += 4)
        {
            const float* weight_xc_ptr = weight_xc.row(q / 4);
            const float* weight_hc_ptr = weight_hc.row(q / 4);

            float32x4_t _H = vld1q_f32((const float*)bias_c + q);
            float32x4_t _sum1 = vdupq_n_f32(0.f);
            float32x4_t _sum2 = vdupq_n_f32(0.f);
            float32x4_t _sum3 = vdupq_n_f32(0.f);

            int i = 0;
            for (; i + 3 < size; i += 4)
            {
                float32x4_t _x = vld1q_f32(x + i);
                float32x4_t _weight_xc = vld1q_f32(weight_xc_ptr);
                float32x4_t _weight_xc_1 = vld1q_f32(weight_xc_ptr + 4);
                float32x4_t _weight_xc_2 = vld1q_f32(weight_xc_ptr + 8);
                float32x4_t _weight_xc_3 = vld1q_f32(weight_xc_ptr + 12);
#if __aarch64__
                _H = vfmaq_laneq_f32(_H, _weight_xc, _x, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_1, _x, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_2, _x, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_3, _x, 3);
#else
                _H = vmlaq_lane_f32(_H, _weight_xc, vget_low_f32(_x), 0);
                _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_1, vget_low_f32(_x), 1);
                _sum2 = vmlaq_lane_f32(_sum2, _weight_xc_2, vget_high_f32(_x), 0);
                _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_3, vget_high_f32(_x), 1);
#endif

                weight_xc_ptr += 16;
            }
            for (; i < size; i++)
            {
                float32x4_t _x = vdupq_n_f32(x[i]);
                float32x4_t _weight_xc = vld1q_f32(weight_xc_ptr);
                _H = vmlaq_f32(_H, _weight_xc, _x);

                weight_xc_ptr += 4;
            }

            i = 0;
            for (; i + 3 < num_output; i += 4)
            {
                float32x4_t _hidden_state = vld1q_f32((const float*)hidden_state + i);
                float32x4_t _weight_hc = vld1q_f32(weight_hc_ptr);
                float32x4_t _weight_hc_1 = vld1q_f32(weight_hc_ptr + 4);
                float32x4_t _weight_hc_2 = vld1q_f32(weight_hc_ptr + 8);
                float32x4_t _weight_hc_3 = vld1q_f32(weight_hc_ptr + 12);
#if __aarch64__
                _H = vfmaq_laneq_f32(_H, _weight_hc, _hidden_state, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_1, _hidden_state, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_2, _hidden_state, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_3, _hidden_state, 3);
#else
                _H = vmlaq_lane_f32(_H, _weight_hc, vget_low_f32(_hidden_state), 0);
                _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_1, vget_low_f32(_hidden_state), 1);
                _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_2, vget_high_f32(_hidden_state), 0);
                _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_3, vget_high_f32(_hidden_state), 1);
#endif

                weight_hc_ptr += 16;
            }
            for (; i < num_output; i++)
            {
                float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]);
                float32x4_t _weight_hc = vld1q_f32(weight_hc_ptr);
                _H = vmlaq_f32(_H, _weight_hc, _hidden_state);

                weight_hc_ptr += 4;
            }

            _H = vaddq_f32(_H, _sum1);
            _sum2 = vaddq_f32(_sum2, _sum3);
            _H = vaddq_f32(_H, _sum2);

            _H = tanh_ps(_H);

            vst1q_f32((float*)gates + q, _H);
        }
#endif // __ARM_NEON
        for (; q < num_output; q++)
        {
#if __ARM_NEON
            const float* weight_xc_ptr = weight_xc.row(q / 4 + q % 4);
            const float* weight_hc_ptr = weight_hc.row(q / 4 + q % 4);
#else
            const float* weight_xc_ptr = weight_xc.row(q);
            const float* weight_hc_ptr = weight_hc.row(q);
#endif // __ARM_NEON

            float H = bias_c[q];

            for (int i = 0; i < size; i++)
            {
                H += weight_xc_ptr[i] * x[i];
            }

            for (int i = 0; i < num_output; i++)
            {
                H += weight_hc_ptr[i] * hidden_state[i];
            }

            H = tanh(H);

            gates[q] = H;
        }

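        // commit the new hidden state only after every output of this time step
        // has been computed, since the hc loops above still read the old state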
        float* output_data = top_blob.row(ti);

        float* hidden_ptr = hidden_state;

        q = 0;
#if __ARM_NEON
        for (; q + 3 < num_output; q += 4)
        {
            float32x4_t _H = vld1q_f32((float*)gates + q);

            vst1q_f32(hidden_ptr, _H);
            vst1q_f32(output_data, _H);

            hidden_ptr += 4;
            output_data += 4;
        }
#endif // __ARM_NEON
        for (; q < num_output; q++)
        {
            float H = gates[q];

            *hidden_ptr++ = H;
            *output_data++ = H;
        }
    }

    return 0;
}

int RNN_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int elembits = bottom_blob.elembits();

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    if (opt.use_fp16_storage && elembits == 16)
    {
        if (opt.use_fp16_arithmetic)
            return forward_fp16sa(bottom_blob, top_blob, opt);
        else
            return forward_fp16s(bottom_blob, top_blob, opt);
    }
#endif

    if (opt.use_bf16_storage && elembits == 16)
        return forward_bf16s(bottom_blob, top_blob, opt);

    int T = bottom_blob.h;

    int num_directions = direction == 2 ? 2 : 1;

    // initial hidden state
    Mat hidden(num_output, 4u, opt.workspace_allocator);
    if (hidden.empty())
        return -100;
    hidden.fill(0.f);

    top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // unidirectional
    if (direction == 0 || direction == 1)
    {
        int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret != 0)
            return ret;
    }

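    // bidirectional: run one pass per direction with its own weights, resetting
    // the hidden state in between, then concat the two results per time step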
    if (direction == 2)
    {
        Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_forward.empty())
            return -100;

        Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_reverse.empty())
            return -100;

        int ret0 = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret0 != 0)
            return ret0;

        hidden.fill(0.0f);

        int ret1 = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt);
        if (ret1 != 0)
            return ret1;

        // concat forward and reverse results along w
        for (int i = 0; i < T; i++)
        {
            const float* pf = top_blob_forward.row(i);
            const float* pr = top_blob_reverse.row(i);
            float* ptr = top_blob.row(i);

            memcpy(ptr, pf, num_output * sizeof(float));
            memcpy(ptr + num_output, pr, num_output * sizeof(float));
        }
    }

    return 0;
}

int RNN_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
    if (bottom_blobs.size() != 2 || top_blobs.size() != 2)
    {
        return forward(bottom_blobs[0], top_blobs[0], opt);
    }

    const Mat& bottom_blob = bottom_blobs[0];

    int elembits = bottom_blob.elembits();

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    if (opt.use_fp16_storage && elembits == 16)
    {
        if (opt.use_fp16_arithmetic)
            return forward_fp16sa(bottom_blobs, top_blobs, opt);
        else
            return forward_fp16s(bottom_blobs, top_blobs, opt);
    }
#endif

    if (opt.use_bf16_storage && elembits == 16)
        return forward_bf16s(bottom_blobs, top_blobs, opt);

    int T = bottom_blob.h;
    Mat& top_blob = top_blobs[0];
    Mat& hidden_state = top_blobs[1];

    // copy previous states
    hidden_state = bottom_blobs[1].clone(opt.blob_allocator);

    top_blob.create(num_output, T, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // unidirectional
    if (direction == 0 || direction == 1)
    {
        int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden_state, opt);
        if (ret != 0)
            return ret;
    }

    return 0;
}

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
static int rnn_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

    // gates buffer: one activation per output neuron for the current time step
    Mat gates(num_output, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;

    // unroll over time steps
    for (int t = 0; t < T; t++)
    {
        int ti = reverse ? T - 1 - t : t;

        const __fp16* x = bottom_blob.row<const __fp16>(ti);

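        // fp16 storage path: weights and activations are stored as __fp16 but
        // widened to fp32 for the multiply-accumulate; the hidden state stays
        // in fp32 across time steps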
        int q = 0;
        for (; q + 3 < num_output; q += 4)
        {
            const __fp16* weight_xc_ptr = weight_xc.row<const __fp16>(q / 4);
            const __fp16* weight_hc_ptr = weight_hc.row<const __fp16>(q / 4);

            float32x4_t _H = vcvt_f32_f16(vld1_f16((const __fp16*)bias_c + q));
            float32x4_t _sum1 = vdupq_n_f32(0.f);
            float32x4_t _sum2 = vdupq_n_f32(0.f);
            float32x4_t _sum3 = vdupq_n_f32(0.f);

            int i = 0;
            for (; i + 3 < size; i += 4)
            {
                float32x4_t _x = vcvt_f32_f16(vld1_f16(x + i));
                float32x4_t _weight_xc = vcvt_f32_f16(vld1_f16(weight_xc_ptr));
                float32x4_t _weight_xc_1 = vcvt_f32_f16(vld1_f16(weight_xc_ptr + 4));
                float32x4_t _weight_xc_2 = vcvt_f32_f16(vld1_f16(weight_xc_ptr + 8));
                float32x4_t _weight_xc_3 = vcvt_f32_f16(vld1_f16(weight_xc_ptr + 12));
                _H = vfmaq_laneq_f32(_H, _weight_xc, _x, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_1, _x, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_2, _x, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_3, _x, 3);

                weight_xc_ptr += 16;
            }
            for (; i < size; i++)
            {
                float32x4_t _x = vcvt_f32_f16(vdup_n_f16(x[i]));
                float32x4_t _weight_xc = vcvt_f32_f16(vld1_f16(weight_xc_ptr));
                _H = vfmaq_f32(_H, _weight_xc, _x);

                weight_xc_ptr += 4;
            }

            i = 0;
            for (; i + 3 < num_output; i += 4)
            {
                float32x4_t _hidden_state = vld1q_f32((const float*)hidden_state + i);
                float32x4_t _weight_hc = vcvt_f32_f16(vld1_f16(weight_hc_ptr));
                float32x4_t _weight_hc_1 = vcvt_f32_f16(vld1_f16(weight_hc_ptr + 4));
                float32x4_t _weight_hc_2 = vcvt_f32_f16(vld1_f16(weight_hc_ptr + 8));
                float32x4_t _weight_hc_3 = vcvt_f32_f16(vld1_f16(weight_hc_ptr + 12));
                _H = vfmaq_laneq_f32(_H, _weight_hc, _hidden_state, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_1, _hidden_state, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_2, _hidden_state, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_3, _hidden_state, 3);

                weight_hc_ptr += 16;
            }
            for (; i < num_output; i++)
            {
                float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]);
                float32x4_t _weight_hc = vcvt_f32_f16(vld1_f16(weight_hc_ptr));
                _H = vfmaq_f32(_H, _weight_hc, _hidden_state);

                weight_hc_ptr += 4;
            }

            _H = vaddq_f32(_H, _sum1);
            _sum2 = vaddq_f32(_sum2, _sum3);
            _H = vaddq_f32(_H, _sum2);

            _H = tanh_ps(_H);

            vst1q_f32((float*)gates + q, _H);
        }
        for (; q < num_output; q++)
        {
            const __fp16* weight_xc_ptr = weight_xc.row<const __fp16>(q / 4 + q % 4);
            const __fp16* weight_hc_ptr = weight_hc.row<const __fp16>(q / 4 + q % 4);

            float H = (float)(((const __fp16*)bias_c)[q]);

            for (int i = 0; i < size; i++)
            {
                H += (float)weight_xc_ptr[i] * (float)x[i];
            }

            for (int i = 0; i < num_output; i++)
            {
                H += (float)weight_hc_ptr[i] * hidden_state[i];
            }

            H = tanh(H);

            gates[q] = H;
        }

        __fp16* output_data = top_blob.row<__fp16>(ti);

        float* hidden_ptr = hidden_state;

        q = 0;
        for (; q + 3 < num_output; q += 4)
        {
            float32x4_t _H = vld1q_f32((float*)gates + q);

            vst1q_f32(hidden_ptr, _H);
            vst1_f16(output_data, vcvt_f16_f32(_H));

            hidden_ptr += 4;
            output_data += 4;
        }
        for (; q < num_output; q++)
        {
            float H = gates[q];

            *hidden_ptr++ = H;
            *output_data++ = (__fp16)H;
        }
    }

    return 0;
}

static int rnn_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

    // gates buffer: one activation per output neuron for the current time step
    Mat gates(num_output, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;

    // unroll over time steps
    for (int t = 0; t < T; t++)
    {
        int ti = reverse ? T - 1 - t : t;

        const __fp16* x = bottom_blob.row<const __fp16>(ti);

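        // fp16 arithmetic path: accumulate directly in fp16 with rows packed by
        // 8 outputs, then fall back to the pack-4 and scalar tails below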
        int q = 0;
        for (; q + 7 < num_output; q += 8)
        {
            const __fp16* weight_xc_ptr = weight_xc.row<const __fp16>(q / 8);
            const __fp16* weight_hc_ptr = weight_hc.row<const __fp16>(q / 8);

            float16x8_t _H = vld1q_f16((const __fp16*)bias_c + q);
            float16x8_t _sum1 = vdupq_n_f16(0.f);
            float16x8_t _sum2 = vdupq_n_f16(0.f);
            float16x8_t _sum3 = vdupq_n_f16(0.f);

            int i = 0;
            for (; i + 3 < size; i += 4)
            {
                float16x4_t _x = vld1_f16(x + i);
                float16x8_t _weight_xc = vld1q_f16(weight_xc_ptr);
                float16x8_t _weight_xc_1 = vld1q_f16(weight_xc_ptr + 8);
                float16x8_t _weight_xc_2 = vld1q_f16(weight_xc_ptr + 16);
                float16x8_t _weight_xc_3 = vld1q_f16(weight_xc_ptr + 24);
                _H = vfmaq_lane_f16(_H, _weight_xc, _x, 0);
                _sum1 = vfmaq_lane_f16(_sum1, _weight_xc_1, _x, 1);
                _sum2 = vfmaq_lane_f16(_sum2, _weight_xc_2, _x, 2);
                _sum3 = vfmaq_lane_f16(_sum3, _weight_xc_3, _x, 3);

                weight_xc_ptr += 32;
            }
            for (; i < size; i++)
            {
                float16x8_t _x = vdupq_n_f16(x[i]);
                float16x8_t _weight_xc = vld1q_f16(weight_xc_ptr);
                _H = vfmaq_f16(_H, _weight_xc, _x);

                weight_xc_ptr += 8;
            }

            i = 0;
            for (; i + 3 < num_output; i += 4)
            {
                float16x4_t _hidden_state = vcvt_f16_f32(vld1q_f32((const float*)hidden_state + i));
                float16x8_t _weight_hc = vld1q_f16(weight_hc_ptr);
                float16x8_t _weight_hc_1 = vld1q_f16(weight_hc_ptr + 8);
                float16x8_t _weight_hc_2 = vld1q_f16(weight_hc_ptr + 16);
                float16x8_t _weight_hc_3 = vld1q_f16(weight_hc_ptr + 24);
                _H = vfmaq_lane_f16(_H, _weight_hc, _hidden_state, 0);
                _sum1 = vfmaq_lane_f16(_sum1, _weight_hc_1, _hidden_state, 1);
                _sum2 = vfmaq_lane_f16(_sum2, _weight_hc_2, _hidden_state, 2);
                _sum3 = vfmaq_lane_f16(_sum3, _weight_hc_3, _hidden_state, 3);

                weight_hc_ptr += 32;
            }
            for (; i < num_output; i++)
            {
                float16x8_t _hidden_state = vdupq_n_f16((__fp16)hidden_state[i]);
                float16x8_t _weight_hc = vld1q_f16(weight_hc_ptr);
                _H = vfmaq_f16(_H, _weight_hc, _hidden_state);

                weight_hc_ptr += 8;
            }

            _H = vaddq_f16(_H, _sum1);
            _sum2 = vaddq_f16(_sum2, _sum3);
            _H = vaddq_f16(_H, _sum2);

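            // widen to fp32 for tanh so the activation keeps full precision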
            float32x4_t _H32low = tanh_ps(vcvt_f32_f16(vget_low_f16(_H)));
            float32x4_t _H32high = tanh_ps(vcvt_f32_f16(vget_high_f16(_H)));

            vst1q_f32((float*)gates + q, _H32low);
            vst1q_f32((float*)gates + q + 4, _H32high);
        }
        for (; q + 3 < num_output; q += 4)
        {
            const __fp16* weight_xc_ptr = weight_xc.row<const __fp16>(q / 8 + (q % 8) / 4);
            const __fp16* weight_hc_ptr = weight_hc.row<const __fp16>(q / 8 + (q % 8) / 4);

            float16x4_t _H = vld1_f16((const __fp16*)bias_c + q);
            float16x4_t _sum1 = vdup_n_f16(0.f);
            float16x4_t _sum2 = vdup_n_f16(0.f);
            float16x4_t _sum3 = vdup_n_f16(0.f);

            int i = 0;
            for (; i + 3 < size; i += 4)
            {
                float16x4_t _x = vld1_f16(x + i);
                float16x4_t _weight_xc = vld1_f16(weight_xc_ptr);
                float16x4_t _weight_xc_1 = vld1_f16(weight_xc_ptr + 4);
                float16x4_t _weight_xc_2 = vld1_f16(weight_xc_ptr + 8);
                float16x4_t _weight_xc_3 = vld1_f16(weight_xc_ptr + 12);
                _H = vfma_lane_f16(_H, _weight_xc, _x, 0);
                _sum1 = vfma_lane_f16(_sum1, _weight_xc_1, _x, 1);
                _sum2 = vfma_lane_f16(_sum2, _weight_xc_2, _x, 2);
                _sum3 = vfma_lane_f16(_sum3, _weight_xc_3, _x, 3);

                weight_xc_ptr += 16;
            }
            for (; i < size; i++)
            {
                float16x4_t _x = vdup_n_f16(x[i]);
                float16x4_t _weight_xc = vld1_f16(weight_xc_ptr);
                _H = vfma_f16(_H, _weight_xc, _x);

                weight_xc_ptr += 4;
            }

            i = 0;
            for (; i + 3 < num_output; i += 4)
            {
                float16x4_t _hidden_state = vcvt_f16_f32(vld1q_f32((const float*)hidden_state + i));
                float16x4_t _weight_hc = vld1_f16(weight_hc_ptr);
                float16x4_t _weight_hc_1 = vld1_f16(weight_hc_ptr + 4);
                float16x4_t _weight_hc_2 = vld1_f16(weight_hc_ptr + 8);
                float16x4_t _weight_hc_3 = vld1_f16(weight_hc_ptr + 12);
                _H = vfma_lane_f16(_H, _weight_hc, _hidden_state, 0);
                _sum1 = vfma_lane_f16(_sum1, _weight_hc_1, _hidden_state, 1);
                _sum2 = vfma_lane_f16(_sum2, _weight_hc_2, _hidden_state, 2);
                _sum3 = vfma_lane_f16(_sum3, _weight_hc_3, _hidden_state, 3);

                weight_hc_ptr += 16;
            }
            for (; i < num_output; i++)
            {
                float16x4_t _hidden_state = vdup_n_f16((__fp16)hidden_state[i]);
                float16x4_t _weight_hc = vld1_f16(weight_hc_ptr);
                _H = vfma_f16(_H, _weight_hc, _hidden_state);

                weight_hc_ptr += 4;
            }

            _H = vadd_f16(_H, _sum1);
            _sum2 = vadd_f16(_sum2, _sum3);
            _H = vadd_f16(_H, _sum2);

            float32x4_t _H32 = tanh_ps(vcvt_f32_f16(_H));

            vst1q_f32((float*)gates + q, _H32);
        }
        for (; q < num_output; q++)
        {
            const __fp16* weight_xc_ptr = weight_xc.row<const __fp16>(q / 8 + (q % 8) / 4 + q % 4);
            const __fp16* weight_hc_ptr = weight_hc.row<const __fp16>(q / 8 + (q % 8) / 4 + q % 4);

            __fp16 H = ((const __fp16*)bias_c)[q];

            for (int i = 0; i < size; i++)
            {
                H += weight_xc_ptr[i] * x[i];
            }

            for (int i = 0; i < num_output; i++)
            {
                H += weight_hc_ptr[i] * (__fp16)hidden_state[i];
            }

            float H32 = tanh((float)H);

            gates[q] = H32;
        }

        __fp16* output_data = top_blob.row<__fp16>(ti);

        float* hidden_ptr = hidden_state;

        q = 0;
        for (; q + 3 < num_output; q += 4)
        {
            float32x4_t _H = vld1q_f32((float*)gates + q);

            vst1q_f32(hidden_ptr, _H);
            vst1_f16(output_data, vcvt_f16_f32(_H));

            hidden_ptr += 4;
            output_data += 4;
        }
        for (; q < num_output; q++)
        {
            float H = gates[q];

            *hidden_ptr++ = H;
            *output_data++ = (__fp16)H;
        }
    }

    return 0;
}

int RNN_arm::create_pipeline_fp16s(const Option& opt)
{
    int num_directions = direction == 2 ? 2 : 1;
    int size = weight_data_size / num_directions / num_output;

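    // with fp16 arithmetic the kernel reads 8 interleaved outputs per row, so
    // the packed rows split as num_output / 8 pack8 rows, (num_output % 8) / 4
    // pack4 rows and num_output % 4 scalar rows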
    if (opt.use_fp16_arithmetic)
    {
        weight_xc_data_packed.create(size * 8, num_output / 8 + (num_output % 8) / 4 + num_output % 4, num_directions, 2u, 1);
        weight_hc_data_packed.create(num_output * 8, num_output / 8 + (num_output % 8) / 4 + num_output % 4, num_directions, 2u, 1);
    }
    else
    {
        weight_xc_data_packed.create(size * 4, num_output / 4 + num_output % 4, num_directions, 2u, 1);
        weight_hc_data_packed.create(num_output * 4, num_output / 4 + num_output % 4, num_directions, 2u, 1);
    }

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int dr = 0; dr < num_directions; dr++)
    {
        const Mat weight_xc = weight_xc_data.channel(dr);
        const Mat weight_hc = weight_hc_data.channel(dr);

        Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr);
        Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr);

        int q = 0;
        if (opt.use_fp16_arithmetic)
        {
            for (; q + 7 < num_output; q += 8)
            {
                const float* weight_xc_0 = weight_xc.row(q);
                const float* weight_xc_1 = weight_xc.row(q + 1);
                const float* weight_xc_2 = weight_xc.row(q + 2);
                const float* weight_xc_3 = weight_xc.row(q + 3);
                const float* weight_xc_4 = weight_xc.row(q + 4);
                const float* weight_xc_5 = weight_xc.row(q + 5);
                const float* weight_xc_6 = weight_xc.row(q + 6);
                const float* weight_xc_7 = weight_xc.row(q + 7);

                const float* weight_hc_0 = weight_hc.row(q);
                const float* weight_hc_1 = weight_hc.row(q + 1);
                const float* weight_hc_2 = weight_hc.row(q + 2);
                const float* weight_hc_3 = weight_hc.row(q + 3);
                const float* weight_hc_4 = weight_hc.row(q + 4);
                const float* weight_hc_5 = weight_hc.row(q + 5);
                const float* weight_hc_6 = weight_hc.row(q + 6);
                const float* weight_hc_7 = weight_hc.row(q + 7);

                __fp16* weight_xc = weight_xc_data_packed_dr.row<__fp16>(q / 8);
                __fp16* weight_hc = weight_hc_data_packed_dr.row<__fp16>(q / 8);

                for (int i = 0; i < size; i++)
                {
                    weight_xc[0] = (__fp16)weight_xc_0[i];
                    weight_xc[1] = (__fp16)weight_xc_1[i];
                    weight_xc[2] = (__fp16)weight_xc_2[i];
                    weight_xc[3] = (__fp16)weight_xc_3[i];
                    weight_xc[4] = (__fp16)weight_xc_4[i];
                    weight_xc[5] = (__fp16)weight_xc_5[i];
                    weight_xc[6] = (__fp16)weight_xc_6[i];
                    weight_xc[7] = (__fp16)weight_xc_7[i];

                    weight_xc += 8;
                }

                for (int i = 0; i < num_output; i++)
                {
                    weight_hc[0] = (__fp16)weight_hc_0[i];
                    weight_hc[1] = (__fp16)weight_hc_1[i];
                    weight_hc[2] = (__fp16)weight_hc_2[i];
                    weight_hc[3] = (__fp16)weight_hc_3[i];
                    weight_hc[4] = (__fp16)weight_hc_4[i];
                    weight_hc[5] = (__fp16)weight_hc_5[i];
                    weight_hc[6] = (__fp16)weight_hc_6[i];
                    weight_hc[7] = (__fp16)weight_hc_7[i];

                    weight_hc += 8;
                }
            }
        }
        for (; q + 3 < num_output; q += 4)
        {
            const float* weight_xc_0 = weight_xc.row(q);
            const float* weight_xc_1 = weight_xc.row(q + 1);
            const float* weight_xc_2 = weight_xc.row(q + 2);
            const float* weight_xc_3 = weight_xc.row(q + 3);

            const float* weight_hc_0 = weight_hc.row(q);
            const float* weight_hc_1 = weight_hc.row(q + 1);
            const float* weight_hc_2 = weight_hc.row(q + 2);
            const float* weight_hc_3 = weight_hc.row(q + 3);

            __fp16* weight_xc = opt.use_fp16_arithmetic ? weight_xc_data_packed_dr.row<__fp16>(q / 8 + (q % 8) / 4) : weight_xc_data_packed_dr.row<__fp16>(q / 4);
            __fp16* weight_hc = opt.use_fp16_arithmetic ? weight_hc_data_packed_dr.row<__fp16>(q / 8 + (q % 8) / 4) : weight_hc_data_packed_dr.row<__fp16>(q / 4);

            for (int i = 0; i < size; i++)
            {
                weight_xc[0] = (__fp16)weight_xc_0[i];
                weight_xc[1] = (__fp16)weight_xc_1[i];
                weight_xc[2] = (__fp16)weight_xc_2[i];
                weight_xc[3] = (__fp16)weight_xc_3[i];

                weight_xc += 4;
            }

            for (int i = 0; i < num_output; i++)
            {
                weight_hc[0] = (__fp16)weight_hc_0[i];
                weight_hc[1] = (__fp16)weight_hc_1[i];
                weight_hc[2] = (__fp16)weight_hc_2[i];
                weight_hc[3] = (__fp16)weight_hc_3[i];

                weight_hc += 4;
            }
        }
        for (; q < num_output; q++)
        {
            const float* weight_xc_0 = weight_xc.row(q);
            const float* weight_hc_0 = weight_hc.row(q);

            __fp16* weight_xc = opt.use_fp16_arithmetic ? weight_xc_data_packed_dr.row<__fp16>(q / 8 + (q % 8) / 4 + q % 4) : weight_xc_data_packed_dr.row<__fp16>(q / 4 + q % 4);
            __fp16* weight_hc = opt.use_fp16_arithmetic ? weight_hc_data_packed_dr.row<__fp16>(q / 8 + (q % 8) / 4 + q % 4) : weight_hc_data_packed_dr.row<__fp16>(q / 4 + q % 4);

            for (int i = 0; i < size; i++)
            {
                weight_xc[i] = (__fp16)weight_xc_0[i];
            }

            for (int i = 0; i < num_output; i++)
            {
                weight_hc[i] = (__fp16)weight_hc_0[i];
            }
        }
    }

    cast_float32_to_float16(bias_c_data, bias_c_data_packed);

    return 0;
}

int RNN_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int T = bottom_blob.h;

    int num_directions = direction == 2 ? 2 : 1;

    // initial hidden state
    Mat hidden(num_output, 4u, opt.workspace_allocator);
    if (hidden.empty())
        return -100;
    hidden.fill(0.f);

    top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // unidirectional
    if (direction == 0 || direction == 1)
    {
        int ret = rnn_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret != 0)
            return ret;
    }

    if (direction == 2)
    {
        Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator);
        if (top_blob_forward.empty())
            return -100;

        Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator);
        if (top_blob_reverse.empty())
            return -100;

        int ret0 = rnn_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret0 != 0)
            return ret0;

        hidden.fill(0.f);

        int ret1 = rnn_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt);
        if (ret1 != 0)
            return ret1;

        // concat forward and reverse results along w
        for (int i = 0; i < T; i++)
        {
            const __fp16* pf = top_blob_forward.row<const __fp16>(i);
            const __fp16* pr = top_blob_reverse.row<const __fp16>(i);
            __fp16* ptr = top_blob.row<__fp16>(i);

            memcpy(ptr, pf, num_output * sizeof(__fp16));
            memcpy(ptr + num_output, pr, num_output * sizeof(__fp16));
        }
    }

    return 0;
}

int RNN_arm::forward_fp16s(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
    const Mat& bottom_blob = bottom_blobs[0];
    int T = bottom_blob.h;
    Mat& top_blob = top_blobs[0];

    top_blob.create(num_output, T, 2u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // copy previous states
    Mat hidden;
    cast_float16_to_float32(bottom_blobs[1], hidden, opt);

    // unidirectional
    if (direction == 0 || direction == 1)
    {
        int ret = rnn_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret != 0)
            return ret;
    }

    cast_float32_to_float16(hidden, top_blobs[1], opt);

    return 0;
}

int RNN_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int T = bottom_blob.h;

    int num_directions = direction == 2 ? 2 : 1;

    // initial hidden state
    Mat hidden(num_output, 4u, opt.workspace_allocator);
    if (hidden.empty())
        return -100;
    hidden.fill(0.f);

    top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // unidirectional
    if (direction == 0 || direction == 1)
    {
        int ret = rnn_fp16sa(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret != 0)
            return ret;
    }

    if (direction == 2)
    {
        Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator);
        if (top_blob_forward.empty())
            return -100;

        Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator);
        if (top_blob_reverse.empty())
            return -100;

        int ret0 = rnn_fp16sa(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret0 != 0)
            return ret0;

        hidden.fill(0.f);

        int ret1 = rnn_fp16sa(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt);
        if (ret1 != 0)
            return ret1;

        // concat forward and reverse results along w
        for (int i = 0; i < T; i++)
        {
            const __fp16* pf = top_blob_forward.row<const __fp16>(i);
            const __fp16* pr = top_blob_reverse.row<const __fp16>(i);
            __fp16* ptr = top_blob.row<__fp16>(i);

            memcpy(ptr, pf, num_output * sizeof(__fp16));
            memcpy(ptr + num_output, pr, num_output * sizeof(__fp16));
        }
    }

    return 0;
}

int RNN_arm::forward_fp16sa(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
    const Mat& bottom_blob = bottom_blobs[0];
    int T = bottom_blob.h;
    Mat& top_blob = top_blobs[0];

    top_blob.create(num_output, T, 2u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // copy previous states
    Mat hidden;
    cast_float16_to_float32(bottom_blobs[1], hidden, opt);

    // unidirectional
    if (direction == 0 || direction == 1)
    {
        int ret = rnn_fp16sa(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret != 0)
            return ret;
    }

    cast_float32_to_float16(hidden, top_blobs[1], opt);

    return 0;
}
#endif

static int rnn_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

    // gates buffer: one activation per output neuron for the current time step
    Mat gates(num_output, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;

    // unroll over time steps
    for (int t = 0; t < T; t++)
    {
        int ti = reverse ? T - 1 - t : t;

        const unsigned short* x = bottom_blob.row<const unsigned short>(ti);

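        // bf16 storage path: weights and activations are stored as bfloat16 and
        // widened to fp32 for the arithmetic; the hidden state stays in fp32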
        int q = 0;
#if __ARM_NEON
        for (; q + 3 < num_output; q += 4)
        {
            const unsigned short* weight_xc_ptr = weight_xc.row<const unsigned short>(q / 4);
            const unsigned short* weight_hc_ptr = weight_hc.row<const unsigned short>(q / 4);

            float32x4_t _H = vcvt_f32_bf16(vld1_u16((const unsigned short*)bias_c + q));
            float32x4_t _sum1 = vdupq_n_f32(0.f);
            float32x4_t _sum2 = vdupq_n_f32(0.f);
            float32x4_t _sum3 = vdupq_n_f32(0.f);

            int i = 0;
            for (; i + 3 < size; i += 4)
            {
                float32x4_t _x = vcvt_f32_bf16(vld1_u16(x + i));
                float32x4_t _weight_xc = vcvt_f32_bf16(vld1_u16(weight_xc_ptr));
                float32x4_t _weight_xc_1 = vcvt_f32_bf16(vld1_u16(weight_xc_ptr + 4));
                float32x4_t _weight_xc_2 = vcvt_f32_bf16(vld1_u16(weight_xc_ptr + 8));
                float32x4_t _weight_xc_3 = vcvt_f32_bf16(vld1_u16(weight_xc_ptr + 12));
#if __aarch64__
                _H = vfmaq_laneq_f32(_H, _weight_xc, _x, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_1, _x, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_2, _x, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_3, _x, 3);
#else
                _H = vmlaq_lane_f32(_H, _weight_xc, vget_low_f32(_x), 0);
                _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_1, vget_low_f32(_x), 1);
                _sum2 = vmlaq_lane_f32(_sum2, _weight_xc_2, vget_high_f32(_x), 0);
                _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_3, vget_high_f32(_x), 1);
#endif

                weight_xc_ptr += 16;
            }
            for (; i < size; i++)
            {
                float32x4_t _x = vcvt_f32_bf16(vdup_n_u16(x[i]));
                float32x4_t _weight_xc = vcvt_f32_bf16(vld1_u16(weight_xc_ptr));
                _H = vmlaq_f32(_H, _weight_xc, _x);

                weight_xc_ptr += 4;
            }

            i = 0;
            for (; i + 3 < num_output; i += 4)
            {
                float32x4_t _hidden_state = vld1q_f32((const float*)hidden_state + i);
                float32x4_t _weight_hc = vcvt_f32_bf16(vld1_u16(weight_hc_ptr));
                float32x4_t _weight_hc_1 = vcvt_f32_bf16(vld1_u16(weight_hc_ptr + 4));
                float32x4_t _weight_hc_2 = vcvt_f32_bf16(vld1_u16(weight_hc_ptr + 8));
                float32x4_t _weight_hc_3 = vcvt_f32_bf16(vld1_u16(weight_hc_ptr + 12));
#if __aarch64__
                _H = vfmaq_laneq_f32(_H, _weight_hc, _hidden_state, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_1, _hidden_state, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_2, _hidden_state, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_3, _hidden_state, 3);
#else
                _H = vmlaq_lane_f32(_H, _weight_hc, vget_low_f32(_hidden_state), 0);
                _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_1, vget_low_f32(_hidden_state), 1);
                _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_2, vget_high_f32(_hidden_state), 0);
                _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_3, vget_high_f32(_hidden_state), 1);
#endif

                weight_hc_ptr += 16;
            }
            for (; i < num_output; i++)
            {
                float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]);
                float32x4_t _weight_hc = vcvt_f32_bf16(vld1_u16(weight_hc_ptr));
                _H = vmlaq_f32(_H, _weight_hc, _hidden_state);

                weight_hc_ptr += 4;
            }

            _H = vaddq_f32(_H, _sum1);
            _sum2 = vaddq_f32(_sum2, _sum3);
            _H = vaddq_f32(_H, _sum2);

            _H = tanh_ps(_H);

            vst1q_f32((float*)gates + q, _H);
        }
#endif // __ARM_NEON
        for (; q < num_output; q++)
        {
#if __ARM_NEON
            const unsigned short* weight_xc_ptr = weight_xc.row<const unsigned short>(q / 4 + q % 4);
            const unsigned short* weight_hc_ptr = weight_hc.row<const unsigned short>(q / 4 + q % 4);
#else
            const unsigned short* weight_xc_ptr = weight_xc.row<const unsigned short>(q);
            const unsigned short* weight_hc_ptr = weight_hc.row<const unsigned short>(q);
#endif // __ARM_NEON

            float H = bfloat16_to_float32(((const unsigned short*)bias_c)[q]);

            for (int i = 0; i < size; i++)
            {
                H += bfloat16_to_float32(weight_xc_ptr[i]) * bfloat16_to_float32(x[i]);
            }

            for (int i = 0; i < num_output; i++)
            {
                H += bfloat16_to_float32(weight_hc_ptr[i]) * hidden_state[i];
            }

            H = tanh(H);

            gates[q] = H;
        }

        unsigned short* output_data = top_blob.row<unsigned short>(ti);

        float* hidden_ptr = hidden_state;

        q = 0;
#if __ARM_NEON
        for (; q + 3 < num_output; q += 4)
        {
            float32x4_t _H = vld1q_f32((float*)gates + q);

            vst1q_f32(hidden_ptr, _H);
            vst1_u16(output_data, vcvt_bf16_f32(_H));

            hidden_ptr += 4;
            output_data += 4;
        }
#endif // __ARM_NEON
        for (; q < num_output; q++)
        {
            float H = gates[q];

            *hidden_ptr++ = H;
            *output_data++ = float32_to_bfloat16(H);
        }
    }

    return 0;
}

int RNN_arm::create_pipeline_bf16s(const Option& opt)
{
    int num_directions = direction == 2 ? 2 : 1;
    int size = weight_data_size / num_directions / num_output;

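    // same pack-4 interleaved layout as the fp32 path, but the weights are
    // pre-converted to bfloat16 so the kernel only widens them at load time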
#if __ARM_NEON
    weight_xc_data_packed.create(size * 4, num_output / 4 + num_output % 4, num_directions, 2u, 1);
    weight_hc_data_packed.create(num_output * 4, num_output / 4 + num_output % 4, num_directions, 2u, 1);

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int dr = 0; dr < num_directions; dr++)
    {
        const Mat weight_xc = weight_xc_data.channel(dr);
        const Mat weight_hc = weight_hc_data.channel(dr);

        Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr);
        Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr);

        int q = 0;
#if __ARM_NEON
        for (; q + 3 < num_output; q += 4)
        {
            const float* weight_xc_0 = weight_xc.row(q);
            const float* weight_xc_1 = weight_xc.row(q + 1);
            const float* weight_xc_2 = weight_xc.row(q + 2);
            const float* weight_xc_3 = weight_xc.row(q + 3);

            const float* weight_hc_0 = weight_hc.row(q);
            const float* weight_hc_1 = weight_hc.row(q + 1);
            const float* weight_hc_2 = weight_hc.row(q + 2);
            const float* weight_hc_3 = weight_hc.row(q + 3);

            unsigned short* weight_xc = weight_xc_data_packed_dr.row<unsigned short>(q / 4);
            unsigned short* weight_hc = weight_hc_data_packed_dr.row<unsigned short>(q / 4);

            for (int i = 0; i < size; i++)
            {
                weight_xc[0] = float32_to_bfloat16(weight_xc_0[i]);
                weight_xc[1] = float32_to_bfloat16(weight_xc_1[i]);
                weight_xc[2] = float32_to_bfloat16(weight_xc_2[i]);
                weight_xc[3] = float32_to_bfloat16(weight_xc_3[i]);

                weight_xc += 4;
            }

            for (int i = 0; i < num_output; i++)
            {
                weight_hc[0] = float32_to_bfloat16(weight_hc_0[i]);
                weight_hc[1] = float32_to_bfloat16(weight_hc_1[i]);
                weight_hc[2] = float32_to_bfloat16(weight_hc_2[i]);
                weight_hc[3] = float32_to_bfloat16(weight_hc_3[i]);

                weight_hc += 4;
            }
        }
#endif // __ARM_NEON
        for (; q < num_output; q++)
        {
            const float* weight_xc_0 = weight_xc.row(q);
            const float* weight_hc_0 = weight_hc.row(q);

#if __ARM_NEON
            unsigned short* weight_xc = weight_xc_data_packed_dr.row<unsigned short>(q / 4 + q % 4);
            unsigned short* weight_hc = weight_hc_data_packed_dr.row<unsigned short>(q / 4 + q % 4);
#else
            unsigned short* weight_xc = weight_xc_data_packed_dr.row<unsigned short>(q);
            unsigned short* weight_hc = weight_hc_data_packed_dr.row<unsigned short>(q);
#endif // __ARM_NEON

            for (int i = 0; i < size; i++)
            {
                weight_xc[i] = float32_to_bfloat16(weight_xc_0[i]);
            }

            for (int i = 0; i < num_output; i++)
            {
                weight_hc[i] = float32_to_bfloat16(weight_hc_0[i]);
            }
        }
    }
#else
    cast_float32_to_bfloat16(weight_xc_data, weight_xc_data_packed);
    cast_float32_to_bfloat16(weight_hc_data, weight_hc_data_packed);
#endif

    cast_float32_to_bfloat16(bias_c_data, bias_c_data_packed);

    return 0;
}

int RNN_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int T = bottom_blob.h;

    int num_directions = direction == 2 ? 2 : 1;

    // initial hidden state
    Mat hidden(num_output, 4u, opt.workspace_allocator);
    if (hidden.empty())
        return -100;
    hidden.fill(0.f);

    top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // unidirectional
    if (direction == 0 || direction == 1)
    {
        int ret = rnn_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret != 0)
            return ret;
    }

    if (direction == 2)
    {
        Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator);
        if (top_blob_forward.empty())
            return -100;

        Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator);
        if (top_blob_reverse.empty())
            return -100;

        int ret0 = rnn_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret0 != 0)
            return ret0;

        hidden.fill(0.f);

        int ret1 = rnn_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt);
        if (ret1 != 0)
            return ret1;

        // concat forward and reverse results along w
        for (int i = 0; i < T; i++)
        {
            const unsigned short* pf = top_blob_forward.row<const unsigned short>(i);
            const unsigned short* pr = top_blob_reverse.row<const unsigned short>(i);
            unsigned short* ptr = top_blob.row<unsigned short>(i);

            memcpy(ptr, pf, num_output * sizeof(unsigned short));
            memcpy(ptr + num_output, pr, num_output * sizeof(unsigned short));
        }
    }

    return 0;
}

int RNN_arm::forward_bf16s(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
    const Mat& bottom_blob = bottom_blobs[0];
    int T = bottom_blob.h;
    Mat& top_blob = top_blobs[0];

    top_blob.create(num_output, T, 2u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // copy previous states
    Mat hidden;
    cast_bfloat16_to_float32(bottom_blobs[1], hidden, opt);

    // unidirectional
    if (direction == 0 || direction == 1)
    {
        int ret = rnn_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret != 0)
            return ret;
    }

    cast_float32_to_bfloat16(hidden, top_blobs[1], opt);

    return 0;
}

} // namespace ncnn