// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "rnn_arm.h"

#if __ARM_NEON
#include <arm_neon.h>
#endif // __ARM_NEON

#include "arm_activation.h"

#include <math.h>

namespace ncnn {

RNN_arm::RNN_arm()
{
#if __ARM_NEON
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    support_fp16_storage = true;
#endif
#endif // __ARM_NEON

    support_bf16_storage = true;
}

int RNN_arm::create_pipeline(const Option& opt)
{
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    if (opt.use_fp16_storage)
    {
        return create_pipeline_fp16s(opt);
    }
#endif

    if (opt.use_bf16_storage)
    {
        return create_pipeline_bf16s(opt);
    }

    int num_directions = direction == 2 ? 2 : 1;
    int size = weight_data_size / num_directions / num_output;

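    // Pack the weights for the NEON kernels: four consecutive output rows are
    // interleaved element by element so a single vld1q_f32 feeds four outputs
    // at once; the up-to-three leftover rows are stored unpacked after the
    // packed ones (hence rows = num_output / 4 + num_output % 4).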
#if __ARM_NEON
    weight_xc_data_packed.create(size * 4, num_output / 4 + num_output % 4, num_directions);
    weight_hc_data_packed.create(num_output * 4, num_output / 4 + num_output % 4, num_directions);

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int dr = 0; dr < num_directions; dr++)
    {
        const Mat weight_xc = weight_xc_data.channel(dr);
        const Mat weight_hc = weight_hc_data.channel(dr);

        Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr);
        Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr);

        int q = 0;
#if __ARM_NEON
        for (; q + 3 < num_output; q += 4)
        {
            const float* weight_xc_0 = weight_xc.row(q);
            const float* weight_xc_1 = weight_xc.row(q + 1);
            const float* weight_xc_2 = weight_xc.row(q + 2);
            const float* weight_xc_3 = weight_xc.row(q + 3);

            const float* weight_hc_0 = weight_hc.row(q);
            const float* weight_hc_1 = weight_hc.row(q + 1);
            const float* weight_hc_2 = weight_hc.row(q + 2);
            const float* weight_hc_3 = weight_hc.row(q + 3);

            float* weight_xc = weight_xc_data_packed_dr.row(q / 4);
            float* weight_hc = weight_hc_data_packed_dr.row(q / 4);

            for (int i = 0; i < size; i++)
            {
                weight_xc[0] = weight_xc_0[i];
                weight_xc[1] = weight_xc_1[i];
                weight_xc[2] = weight_xc_2[i];
                weight_xc[3] = weight_xc_3[i];

                weight_xc += 4;
            }

            for (int i = 0; i < num_output; i++)
            {
                weight_hc[0] = weight_hc_0[i];
                weight_hc[1] = weight_hc_1[i];
                weight_hc[2] = weight_hc_2[i];
                weight_hc[3] = weight_hc_3[i];

                weight_hc += 4;
            }
        }
#endif // __ARM_NEON
        for (; q < num_output; q++)
        {
            const float* weight_xc_0 = weight_xc.row(q);
            const float* weight_hc_0 = weight_hc.row(q);

#if __ARM_NEON
            float* weight_xc = weight_xc_data_packed_dr.row(q / 4 + q % 4);
            float* weight_hc = weight_hc_data_packed_dr.row(q / 4 + q % 4);
#else
            float* weight_xc = weight_xc_data_packed_dr.row(q);
            float* weight_hc = weight_hc_data_packed_dr.row(q);
#endif // __ARM_NEON

            for (int i = 0; i < size; i++)
            {
                weight_xc[i] = weight_xc_0[i];
            }

            for (int i = 0; i < num_output; i++)
            {
                weight_hc[i] = weight_hc_0[i];
            }
        }
    }
#else
    weight_xc_data_packed = weight_xc_data;
    weight_hc_data_packed = weight_hc_data;
#endif

    bias_c_data_packed = bias_c_data;

    return 0;
}

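// One direction of a vanilla RNN over T time steps:
//   H_t = tanh(W_xc * x_t + W_hc * H_(t-1) + b_c)
// weight_xc / weight_hc are expected in the packed layout produced by
// create_pipeline(); hidden_state is read and updated in place.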
static int rnn(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

    // gate values, one per output element
    Mat gates(num_output, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;

    // unroll over time steps
    for (int t = 0; t < T; t++)
    {
        int ti = reverse ? T - 1 - t : t;

        const float* x = bottom_blob.row(ti);

        int q = 0;
#if __ARM_NEON
        for (; q + 3 < num_output; q += 4)
        {
            const float* weight_xc_ptr = weight_xc.row(q / 4);
            const float* weight_hc_ptr = weight_hc.row(q / 4);

            float32x4_t _H = vld1q_f32((const float*)bias_c + q);
            float32x4_t _sum1 = vdupq_n_f32(0.f);
            float32x4_t _sum2 = vdupq_n_f32(0.f);
            float32x4_t _sum3 = vdupq_n_f32(0.f);

            int i = 0;
            for (; i + 3 < size; i += 4)
            {
                float32x4_t _x = vld1q_f32(x + i);
                float32x4_t _weight_xc = vld1q_f32(weight_xc_ptr);
                float32x4_t _weight_xc_1 = vld1q_f32(weight_xc_ptr + 4);
                float32x4_t _weight_xc_2 = vld1q_f32(weight_xc_ptr + 8);
                float32x4_t _weight_xc_3 = vld1q_f32(weight_xc_ptr + 12);
#if __aarch64__
                _H = vfmaq_laneq_f32(_H, _weight_xc, _x, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_1, _x, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_2, _x, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_3, _x, 3);
#else
                _H = vmlaq_lane_f32(_H, _weight_xc, vget_low_f32(_x), 0);
                _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_1, vget_low_f32(_x), 1);
                _sum2 = vmlaq_lane_f32(_sum2, _weight_xc_2, vget_high_f32(_x), 0);
                _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_3, vget_high_f32(_x), 1);
#endif

                weight_xc_ptr += 16;
            }
            for (; i < size; i++)
            {
                float32x4_t _x = vdupq_n_f32(x[i]);
                float32x4_t _weight_xc = vld1q_f32(weight_xc_ptr);
                _H = vmlaq_f32(_H, _weight_xc, _x);

                weight_xc_ptr += 4;
            }

            i = 0;
            for (; i + 3 < num_output; i += 4)
            {
                float32x4_t _hidden_state = vld1q_f32((const float*)hidden_state + i);
                float32x4_t _weight_hc = vld1q_f32(weight_hc_ptr);
                float32x4_t _weight_hc_1 = vld1q_f32(weight_hc_ptr + 4);
                float32x4_t _weight_hc_2 = vld1q_f32(weight_hc_ptr + 8);
                float32x4_t _weight_hc_3 = vld1q_f32(weight_hc_ptr + 12);
#if __aarch64__
                _H = vfmaq_laneq_f32(_H, _weight_hc, _hidden_state, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_1, _hidden_state, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_2, _hidden_state, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_3, _hidden_state, 3);
#else
                _H = vmlaq_lane_f32(_H, _weight_hc, vget_low_f32(_hidden_state), 0);
                _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_1, vget_low_f32(_hidden_state), 1);
                _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_2, vget_high_f32(_hidden_state), 0);
                _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_3, vget_high_f32(_hidden_state), 1);
#endif

                weight_hc_ptr += 16;
            }
            for (; i < num_output; i++)
            {
                float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]);
                float32x4_t _weight_hc = vld1q_f32(weight_hc_ptr);
                _H = vmlaq_f32(_H, _weight_hc, _hidden_state);

                weight_hc_ptr += 4;
            }

            _H = vaddq_f32(_H, _sum1);
            _sum2 = vaddq_f32(_sum2, _sum3);
            _H = vaddq_f32(_H, _sum2);

            _H = tanh_ps(_H);

            vst1q_f32((float*)gates + q, _H);
        }
#endif // __ARM_NEON
        for (; q < num_output; q++)
        {
#if __ARM_NEON
            const float* weight_xc_ptr = weight_xc.row(q / 4 + q % 4);
            const float* weight_hc_ptr = weight_hc.row(q / 4 + q % 4);
#else
            const float* weight_xc_ptr = weight_xc.row(q);
            const float* weight_hc_ptr = weight_hc.row(q);
#endif // __ARM_NEON

            float H = bias_c[q];

            for (int i = 0; i < size; i++)
            {
                H += weight_xc_ptr[i] * x[i];
            }

            for (int i = 0; i < num_output; i++)
            {
                H += weight_hc_ptr[i] * hidden_state[i];
            }

            H = tanh(H);

            gates[q] = H;
        }

        float* output_data = top_blob.row(ti);

        float* hidden_ptr = hidden_state;

        q = 0;
#if __ARM_NEON
        for (; q + 3 < num_output; q += 4)
        {
            float32x4_t _H = vld1q_f32((float*)gates + q);

            vst1q_f32(hidden_ptr, _H);
            vst1q_f32(output_data, _H);

            hidden_ptr += 4;
            output_data += 4;
        }
#endif // __ARM_NEON
        for (; q < num_output; q++)
        {
            float H = gates[q];

            *hidden_ptr++ = H;
            *output_data++ = H;
        }
    }

    return 0;
}

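// direction: 0 = forward, 1 = reverse, 2 = bidirectional; the bidirectional
// case runs both passes and concatenates them along w.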
int RNN_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int elembits = bottom_blob.elembits();

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    if (opt.use_fp16_storage && elembits == 16)
    {
        if (opt.use_fp16_arithmetic)
            return forward_fp16sa(bottom_blob, top_blob, opt);
        else
            return forward_fp16s(bottom_blob, top_blob, opt);
    }
#endif

    if (opt.use_bf16_storage && elembits == 16)
        return forward_bf16s(bottom_blob, top_blob, opt);

    int T = bottom_blob.h;

    int num_directions = direction == 2 ? 2 : 1;

    // initial hidden state
    Mat hidden(num_output, 4u, opt.workspace_allocator);
    if (hidden.empty())
        return -100;
    hidden.fill(0.f);

    top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret != 0)
            return ret;
    }

    if (direction == 2)
    {
        Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_forward.empty())
            return -100;

        Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_reverse.empty())
            return -100;

        int ret0 = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret0 != 0)
            return ret0;

        hidden.fill(0.0f);

        int ret1 = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt);
        if (ret1 != 0)
            return ret1;

        // concat forward and reverse outputs along w
        for (int i = 0; i < T; i++)
        {
            const float* pf = top_blob_forward.row(i);
            const float* pr = top_blob_reverse.row(i);
            float* ptr = top_blob.row(i);

            memcpy(ptr, pf, num_output * sizeof(float));
            memcpy(ptr + num_output, pr, num_output * sizeof(float));
        }
    }

    return 0;
}

int RNN_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
    if (bottom_blobs.size() != 2 || top_blobs.size() != 2)
    {
        return forward(bottom_blobs[0], top_blobs[0], opt);
    }

    const Mat& bottom_blob = bottom_blobs[0];

    int elembits = bottom_blob.elembits();

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    if (opt.use_fp16_storage && elembits == 16)
    {
        if (opt.use_fp16_arithmetic)
            return forward_fp16sa(bottom_blobs, top_blobs, opt);
        else
            return forward_fp16s(bottom_blobs, top_blobs, opt);
    }
#endif

    if (opt.use_bf16_storage && elembits == 16)
        return forward_bf16s(bottom_blobs, top_blobs, opt);

    int T = bottom_blob.h;
    Mat& top_blob = top_blobs[0];
    Mat& hidden_state = top_blobs[1];

    // copy previous states
    hidden_state = bottom_blobs[1].clone(opt.blob_allocator);

    top_blob.create(num_output, T, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden_state, opt);
        if (ret != 0)
            return ret;
    }

    return 0;
}

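// fp16 variants: rnn_fp16s stores weights and activations as fp16 but keeps
// the multiply-accumulate in fp32, while rnn_fp16sa also runs the arithmetic
// in fp16 (both require __ARM_FEATURE_FP16_VECTOR_ARITHMETIC).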
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
static int rnn_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

    // gate values, one per output element
    Mat gates(num_output, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;

    // unroll over time steps
    for (int t = 0; t < T; t++)
    {
        int ti = reverse ? T - 1 - t : t;

        const __fp16* x = bottom_blob.row<const __fp16>(ti);

        int q = 0;
        for (; q + 3 < num_output; q += 4)
        {
            const __fp16* weight_xc_ptr = weight_xc.row<const __fp16>(q / 4);
            const __fp16* weight_hc_ptr = weight_hc.row<const __fp16>(q / 4);

            float32x4_t _H = vcvt_f32_f16(vld1_f16((const __fp16*)bias_c + q));
            float32x4_t _sum1 = vdupq_n_f32(0.f);
            float32x4_t _sum2 = vdupq_n_f32(0.f);
            float32x4_t _sum3 = vdupq_n_f32(0.f);

            int i = 0;
            for (; i + 3 < size; i += 4)
            {
                float32x4_t _x = vcvt_f32_f16(vld1_f16(x + i));
                float32x4_t _weight_xc = vcvt_f32_f16(vld1_f16(weight_xc_ptr));
                float32x4_t _weight_xc_1 = vcvt_f32_f16(vld1_f16(weight_xc_ptr + 4));
                float32x4_t _weight_xc_2 = vcvt_f32_f16(vld1_f16(weight_xc_ptr + 8));
                float32x4_t _weight_xc_3 = vcvt_f32_f16(vld1_f16(weight_xc_ptr + 12));
                _H = vfmaq_laneq_f32(_H, _weight_xc, _x, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_1, _x, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_2, _x, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_3, _x, 3);

                weight_xc_ptr += 16;
            }
            for (; i < size; i++)
            {
                float32x4_t _x = vcvt_f32_f16(vdup_n_f16(x[i]));
                float32x4_t _weight_xc = vcvt_f32_f16(vld1_f16(weight_xc_ptr));
                _H = vfmaq_f32(_H, _weight_xc, _x);

                weight_xc_ptr += 4;
            }

            i = 0;
            for (; i + 3 < num_output; i += 4)
            {
                float32x4_t _hidden_state = vld1q_f32((const float*)hidden_state + i);
                float32x4_t _weight_hc = vcvt_f32_f16(vld1_f16(weight_hc_ptr));
                float32x4_t _weight_hc_1 = vcvt_f32_f16(vld1_f16(weight_hc_ptr + 4));
                float32x4_t _weight_hc_2 = vcvt_f32_f16(vld1_f16(weight_hc_ptr + 8));
                float32x4_t _weight_hc_3 = vcvt_f32_f16(vld1_f16(weight_hc_ptr + 12));
                _H = vfmaq_laneq_f32(_H, _weight_hc, _hidden_state, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_1, _hidden_state, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_2, _hidden_state, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_3, _hidden_state, 3);

                weight_hc_ptr += 16;
            }
            for (; i < num_output; i++)
            {
                float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]);
                float32x4_t _weight_hc = vcvt_f32_f16(vld1_f16(weight_hc_ptr));
                _H = vfmaq_f32(_H, _weight_hc, _hidden_state);

                weight_hc_ptr += 4;
            }

            _H = vaddq_f32(_H, _sum1);
            _sum2 = vaddq_f32(_sum2, _sum3);
            _H = vaddq_f32(_H, _sum2);

            _H = tanh_ps(_H);

            vst1q_f32((float*)gates + q, _H);
        }
        for (; q < num_output; q++)
        {
            const __fp16* weight_xc_ptr = weight_xc.row<const __fp16>(q / 4 + q % 4);
            const __fp16* weight_hc_ptr = weight_hc.row<const __fp16>(q / 4 + q % 4);

            float H = (float)(((const __fp16*)bias_c)[q]);

            for (int i = 0; i < size; i++)
            {
                H += (float)weight_xc_ptr[i] * (float)x[i];
            }

            for (int i = 0; i < num_output; i++)
            {
                H += (float)weight_hc_ptr[i] * hidden_state[i];
            }

            H = tanh(H);

            gates[q] = H;
        }

        __fp16* output_data = top_blob.row<__fp16>(ti);

        float* hidden_ptr = hidden_state;

        q = 0;
        for (; q + 3 < num_output; q += 4)
        {
            float32x4_t _H = vld1q_f32((float*)gates + q);

            vst1q_f32(hidden_ptr, _H);
            vst1_f16(output_data, vcvt_f16_f32(_H));

            hidden_ptr += 4;
            output_data += 4;
        }
        for (; q < num_output; q++)
        {
            float H = gates[q];

            *hidden_ptr++ = H;
            *output_data++ = (__fp16)H;
        }
    }

    return 0;
}

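// fp16 storage + fp16 arithmetic variant: outputs are processed eight at a
// time (float16x8), then four, then scalar, matching the
// q / 8 + (q % 8) / 4 + q % 4 row layout built by create_pipeline_fp16s().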
static int rnn_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

    // gate values, one per output element
    Mat gates(num_output, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;

    // unroll over time steps
    for (int t = 0; t < T; t++)
    {
        int ti = reverse ? T - 1 - t : t;

        const __fp16* x = bottom_blob.row<const __fp16>(ti);

        int q = 0;
        for (; q + 7 < num_output; q += 8)
        {
            const __fp16* weight_xc_ptr = weight_xc.row<const __fp16>(q / 8);
            const __fp16* weight_hc_ptr = weight_hc.row<const __fp16>(q / 8);

            float16x8_t _H = vld1q_f16((const __fp16*)bias_c + q);
            float16x8_t _sum1 = vdupq_n_f16(0.f);
            float16x8_t _sum2 = vdupq_n_f16(0.f);
            float16x8_t _sum3 = vdupq_n_f16(0.f);

            int i = 0;
            for (; i + 3 < size; i += 4)
            {
                float16x4_t _x = vld1_f16(x + i);
                float16x8_t _weight_xc = vld1q_f16(weight_xc_ptr);
                float16x8_t _weight_xc_1 = vld1q_f16(weight_xc_ptr + 8);
                float16x8_t _weight_xc_2 = vld1q_f16(weight_xc_ptr + 16);
                float16x8_t _weight_xc_3 = vld1q_f16(weight_xc_ptr + 24);
                _H = vfmaq_lane_f16(_H, _weight_xc, _x, 0);
                _sum1 = vfmaq_lane_f16(_sum1, _weight_xc_1, _x, 1);
                _sum2 = vfmaq_lane_f16(_sum2, _weight_xc_2, _x, 2);
                _sum3 = vfmaq_lane_f16(_sum3, _weight_xc_3, _x, 3);

                weight_xc_ptr += 32;
            }
            for (; i < size; i++)
            {
                float16x8_t _x = vdupq_n_f16(x[i]);
                float16x8_t _weight_xc = vld1q_f16(weight_xc_ptr);
                _H = vfmaq_f16(_H, _weight_xc, _x);

                weight_xc_ptr += 8;
            }

            i = 0;
            for (; i + 3 < num_output; i += 4)
            {
                float16x4_t _hidden_state = vcvt_f16_f32(vld1q_f32((const float*)hidden_state + i));
                float16x8_t _weight_hc = vld1q_f16(weight_hc_ptr);
                float16x8_t _weight_hc_1 = vld1q_f16(weight_hc_ptr + 8);
                float16x8_t _weight_hc_2 = vld1q_f16(weight_hc_ptr + 16);
                float16x8_t _weight_hc_3 = vld1q_f16(weight_hc_ptr + 24);
                _H = vfmaq_lane_f16(_H, _weight_hc, _hidden_state, 0);
                _sum1 = vfmaq_lane_f16(_sum1, _weight_hc_1, _hidden_state, 1);
                _sum2 = vfmaq_lane_f16(_sum2, _weight_hc_2, _hidden_state, 2);
                _sum3 = vfmaq_lane_f16(_sum3, _weight_hc_3, _hidden_state, 3);

                weight_hc_ptr += 32;
            }
            for (; i < num_output; i++)
            {
                float16x8_t _hidden_state = vdupq_n_f16((__fp16)hidden_state[i]);
                float16x8_t _weight_hc = vld1q_f16(weight_hc_ptr);
                _H = vfmaq_f16(_H, _weight_hc, _hidden_state);

                weight_hc_ptr += 8;
            }

            _H = vaddq_f16(_H, _sum1);
            _sum2 = vaddq_f16(_sum2, _sum3);
            _H = vaddq_f16(_H, _sum2);

            float32x4_t _H32low = tanh_ps(vcvt_f32_f16(vget_low_f16(_H)));
            float32x4_t _H32high = tanh_ps(vcvt_f32_f16(vget_high_f16(_H)));

            vst1q_f32((float*)gates + q, _H32low);
            vst1q_f32((float*)gates + q + 4, _H32high);
        }
        for (; q + 3 < num_output; q += 4)
        {
            const __fp16* weight_xc_ptr = weight_xc.row<const __fp16>(q / 8 + (q % 8) / 4);
            const __fp16* weight_hc_ptr = weight_hc.row<const __fp16>(q / 8 + (q % 8) / 4);

            float16x4_t _H = vld1_f16((const __fp16*)bias_c + q);
            float16x4_t _sum1 = vdup_n_f16(0.f);
            float16x4_t _sum2 = vdup_n_f16(0.f);
            float16x4_t _sum3 = vdup_n_f16(0.f);

            int i = 0;
            for (; i + 3 < size; i += 4)
            {
                float16x4_t _x = vld1_f16(x + i);
                float16x4_t _weight_xc = vld1_f16(weight_xc_ptr);
                float16x4_t _weight_xc_1 = vld1_f16(weight_xc_ptr + 4);
                float16x4_t _weight_xc_2 = vld1_f16(weight_xc_ptr + 8);
                float16x4_t _weight_xc_3 = vld1_f16(weight_xc_ptr + 12);
                _H = vfma_lane_f16(_H, _weight_xc, _x, 0);
                _sum1 = vfma_lane_f16(_sum1, _weight_xc_1, _x, 1);
                _sum2 = vfma_lane_f16(_sum2, _weight_xc_2, _x, 2);
                _sum3 = vfma_lane_f16(_sum3, _weight_xc_3, _x, 3);

                weight_xc_ptr += 16;
            }
            for (; i < size; i++)
            {
                float16x4_t _x = vdup_n_f16(x[i]);
                float16x4_t _weight_xc = vld1_f16(weight_xc_ptr);
                _H = vfma_f16(_H, _weight_xc, _x);

                weight_xc_ptr += 4;
            }

            i = 0;
            for (; i + 3 < num_output; i += 4)
            {
                float16x4_t _hidden_state = vcvt_f16_f32(vld1q_f32((const float*)hidden_state + i));
                float16x4_t _weight_hc = vld1_f16(weight_hc_ptr);
                float16x4_t _weight_hc_1 = vld1_f16(weight_hc_ptr + 4);
                float16x4_t _weight_hc_2 = vld1_f16(weight_hc_ptr + 8);
                float16x4_t _weight_hc_3 = vld1_f16(weight_hc_ptr + 12);
                _H = vfma_lane_f16(_H, _weight_hc, _hidden_state, 0);
                _sum1 = vfma_lane_f16(_sum1, _weight_hc_1, _hidden_state, 1);
                _sum2 = vfma_lane_f16(_sum2, _weight_hc_2, _hidden_state, 2);
                _sum3 = vfma_lane_f16(_sum3, _weight_hc_3, _hidden_state, 3);

                weight_hc_ptr += 16;
            }
            for (; i < num_output; i++)
            {
                float16x4_t _hidden_state = vdup_n_f16((__fp16)hidden_state[i]);
                float16x4_t _weight_hc = vld1_f16(weight_hc_ptr);
                _H = vfma_f16(_H, _weight_hc, _hidden_state);

                weight_hc_ptr += 4;
            }

            _H = vadd_f16(_H, _sum1);
            _sum2 = vadd_f16(_sum2, _sum3);
            _H = vadd_f16(_H, _sum2);

            float32x4_t _H32 = tanh_ps(vcvt_f32_f16(_H));

            vst1q_f32((float*)gates + q, _H32);
        }
        for (; q < num_output; q++)
        {
            const __fp16* weight_xc_ptr = weight_xc.row<const __fp16>(q / 8 + (q % 8) / 4 + q % 4);
            const __fp16* weight_hc_ptr = weight_hc.row<const __fp16>(q / 8 + (q % 8) / 4 + q % 4);

            __fp16 H = ((const __fp16*)bias_c)[q];

            for (int i = 0; i < size; i++)
            {
                H += weight_xc_ptr[i] * x[i];
            }

            for (int i = 0; i < num_output; i++)
            {
                H += weight_hc_ptr[i] * (__fp16)hidden_state[i];
            }

            float H32 = tanh((float)H);

            gates[q] = H32;
        }

        __fp16* output_data = top_blob.row<__fp16>(ti);

        float* hidden_ptr = hidden_state;

        q = 0;
        for (; q + 3 < num_output; q += 4)
        {
            float32x4_t _H = vld1q_f32((float*)gates + q);

            vst1q_f32(hidden_ptr, _H);
            vst1_f16(output_data, vcvt_f16_f32(_H));

            hidden_ptr += 4;
            output_data += 4;
        }
        for (; q < num_output; q++)
        {
            float H = gates[q];

            *hidden_ptr++ = H;
            *output_data++ = (__fp16)H;
        }
    }

    return 0;
}

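// Pack fp16 weights: with fp16 arithmetic the rows are interleaved in groups
// of eight (then four, then one for the tail); otherwise in groups of four
// (then one), mirroring the loads in rnn_fp16sa() / rnn_fp16s().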
int RNN_arm::create_pipeline_fp16s(const Option& opt)
{
    int num_directions = direction == 2 ? 2 : 1;
    int size = weight_data_size / num_directions / num_output;

    if (opt.use_fp16_arithmetic)
    {
        weight_xc_data_packed.create(size * 8, num_output / 8 + (num_output % 8) / 4 + num_output % 4, num_directions, 2u, 1);
        weight_hc_data_packed.create(num_output * 8, num_output / 8 + (num_output % 8) / 4 + num_output % 4, num_directions, 2u, 1);
    }
    else
    {
        weight_xc_data_packed.create(size * 4, num_output / 4 + num_output % 4, num_directions, 2u, 1);
        weight_hc_data_packed.create(num_output * 4, num_output / 4 + num_output % 4, num_directions, 2u, 1);
    }

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int dr = 0; dr < num_directions; dr++)
    {
        const Mat weight_xc = weight_xc_data.channel(dr);
        const Mat weight_hc = weight_hc_data.channel(dr);

        Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr);
        Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr);

        int q = 0;
        if (opt.use_fp16_arithmetic)
        {
            for (; q + 7 < num_output; q += 8)
            {
                const float* weight_xc_0 = weight_xc.row(q);
                const float* weight_xc_1 = weight_xc.row(q + 1);
                const float* weight_xc_2 = weight_xc.row(q + 2);
                const float* weight_xc_3 = weight_xc.row(q + 3);
                const float* weight_xc_4 = weight_xc.row(q + 4);
                const float* weight_xc_5 = weight_xc.row(q + 5);
                const float* weight_xc_6 = weight_xc.row(q + 6);
                const float* weight_xc_7 = weight_xc.row(q + 7);

                const float* weight_hc_0 = weight_hc.row(q);
                const float* weight_hc_1 = weight_hc.row(q + 1);
                const float* weight_hc_2 = weight_hc.row(q + 2);
                const float* weight_hc_3 = weight_hc.row(q + 3);
                const float* weight_hc_4 = weight_hc.row(q + 4);
                const float* weight_hc_5 = weight_hc.row(q + 5);
                const float* weight_hc_6 = weight_hc.row(q + 6);
                const float* weight_hc_7 = weight_hc.row(q + 7);

                __fp16* weight_xc = weight_xc_data_packed_dr.row<__fp16>(q / 8);
                __fp16* weight_hc = weight_hc_data_packed_dr.row<__fp16>(q / 8);

                for (int i = 0; i < size; i++)
                {
                    weight_xc[0] = (__fp16)weight_xc_0[i];
                    weight_xc[1] = (__fp16)weight_xc_1[i];
                    weight_xc[2] = (__fp16)weight_xc_2[i];
                    weight_xc[3] = (__fp16)weight_xc_3[i];
                    weight_xc[4] = (__fp16)weight_xc_4[i];
                    weight_xc[5] = (__fp16)weight_xc_5[i];
                    weight_xc[6] = (__fp16)weight_xc_6[i];
                    weight_xc[7] = (__fp16)weight_xc_7[i];

                    weight_xc += 8;
                }

                for (int i = 0; i < num_output; i++)
                {
                    weight_hc[0] = (__fp16)weight_hc_0[i];
                    weight_hc[1] = (__fp16)weight_hc_1[i];
                    weight_hc[2] = (__fp16)weight_hc_2[i];
                    weight_hc[3] = (__fp16)weight_hc_3[i];
                    weight_hc[4] = (__fp16)weight_hc_4[i];
                    weight_hc[5] = (__fp16)weight_hc_5[i];
                    weight_hc[6] = (__fp16)weight_hc_6[i];
                    weight_hc[7] = (__fp16)weight_hc_7[i];

                    weight_hc += 8;
                }
            }
        }
        for (; q + 3 < num_output; q += 4)
        {
            const float* weight_xc_0 = weight_xc.row(q);
            const float* weight_xc_1 = weight_xc.row(q + 1);
            const float* weight_xc_2 = weight_xc.row(q + 2);
            const float* weight_xc_3 = weight_xc.row(q + 3);

            const float* weight_hc_0 = weight_hc.row(q);
            const float* weight_hc_1 = weight_hc.row(q + 1);
            const float* weight_hc_2 = weight_hc.row(q + 2);
            const float* weight_hc_3 = weight_hc.row(q + 3);

            __fp16* weight_xc = opt.use_fp16_arithmetic ? weight_xc_data_packed_dr.row<__fp16>(q / 8 + (q % 8) / 4) : weight_xc_data_packed_dr.row<__fp16>(q / 4);
            __fp16* weight_hc = opt.use_fp16_arithmetic ? weight_hc_data_packed_dr.row<__fp16>(q / 8 + (q % 8) / 4) : weight_hc_data_packed_dr.row<__fp16>(q / 4);

            for (int i = 0; i < size; i++)
            {
                weight_xc[0] = (__fp16)weight_xc_0[i];
                weight_xc[1] = (__fp16)weight_xc_1[i];
                weight_xc[2] = (__fp16)weight_xc_2[i];
                weight_xc[3] = (__fp16)weight_xc_3[i];

                weight_xc += 4;
            }

            for (int i = 0; i < num_output; i++)
            {
                weight_hc[0] = (__fp16)weight_hc_0[i];
                weight_hc[1] = (__fp16)weight_hc_1[i];
                weight_hc[2] = (__fp16)weight_hc_2[i];
                weight_hc[3] = (__fp16)weight_hc_3[i];

                weight_hc += 4;
            }
        }
        for (; q < num_output; q++)
        {
            const float* weight_xc_0 = weight_xc.row(q);
            const float* weight_hc_0 = weight_hc.row(q);

            __fp16* weight_xc = opt.use_fp16_arithmetic ? weight_xc_data_packed_dr.row<__fp16>(q / 8 + (q % 8) / 4 + q % 4) : weight_xc_data_packed_dr.row<__fp16>(q / 4 + q % 4);
            __fp16* weight_hc = opt.use_fp16_arithmetic ? weight_hc_data_packed_dr.row<__fp16>(q / 8 + (q % 8) / 4 + q % 4) : weight_hc_data_packed_dr.row<__fp16>(q / 4 + q % 4);

            for (int i = 0; i < size; i++)
            {
                weight_xc[i] = (__fp16)weight_xc_0[i];
            }

            for (int i = 0; i < num_output; i++)
            {
                weight_hc[i] = (__fp16)weight_hc_0[i];
            }
        }
    }

    cast_float32_to_float16(bias_c_data, bias_c_data_packed);

    return 0;
}

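// Note the hidden state stays fp32 (4u) even on the fp16 paths, since the
// recurrence accumulates into it across time steps.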
int RNN_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int T = bottom_blob.h;

    int num_directions = direction == 2 ? 2 : 1;

    // initial hidden state
    Mat hidden(num_output, 4u, opt.workspace_allocator);
    if (hidden.empty())
        return -100;
    hidden.fill(0.f);

    top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = rnn_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret != 0)
            return ret;
    }

    if (direction == 2)
    {
        Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator);
        if (top_blob_forward.empty())
            return -100;

        Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator);
        if (top_blob_reverse.empty())
            return -100;

        int ret0 = rnn_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret0 != 0)
            return ret0;

        hidden.fill(0.f);

        int ret1 = rnn_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt);
        if (ret1 != 0)
            return ret1;

        // concat forward and reverse outputs along w
        for (int i = 0; i < T; i++)
        {
            const __fp16* pf = top_blob_forward.row<const __fp16>(i);
            const __fp16* pr = top_blob_reverse.row<const __fp16>(i);
            __fp16* ptr = top_blob.row<__fp16>(i);

            memcpy(ptr, pf, num_output * sizeof(__fp16));
            memcpy(ptr + num_output, pr, num_output * sizeof(__fp16));
        }
    }

    return 0;
}

int RNN_arm::forward_fp16s(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
    const Mat& bottom_blob = bottom_blobs[0];
    int T = bottom_blob.h;
    Mat& top_blob = top_blobs[0];

    top_blob.create(num_output, T, 2u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // copy previous states
    Mat hidden;
    cast_float16_to_float32(bottom_blobs[1], hidden, opt);

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = rnn_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret != 0)
            return ret;
    }

    cast_float32_to_float16(hidden, top_blobs[1], opt);

    return 0;
}

int RNN_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int T = bottom_blob.h;

    int num_directions = direction == 2 ? 2 : 1;

    // initial hidden state
    Mat hidden(num_output, 4u, opt.workspace_allocator);
    if (hidden.empty())
        return -100;
    hidden.fill(0.f);

    top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = rnn_fp16sa(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret != 0)
            return ret;
    }

    if (direction == 2)
    {
        Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator);
        if (top_blob_forward.empty())
            return -100;

        Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator);
        if (top_blob_reverse.empty())
            return -100;

        int ret0 = rnn_fp16sa(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret0 != 0)
            return ret0;

        hidden.fill(0.f);

        int ret1 = rnn_fp16sa(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt);
        if (ret1 != 0)
            return ret1;

        // concat forward and reverse outputs along w
        for (int i = 0; i < T; i++)
        {
            const __fp16* pf = top_blob_forward.row<const __fp16>(i);
            const __fp16* pr = top_blob_reverse.row<const __fp16>(i);
            __fp16* ptr = top_blob.row<__fp16>(i);

            memcpy(ptr, pf, num_output * sizeof(__fp16));
            memcpy(ptr + num_output, pr, num_output * sizeof(__fp16));
        }
    }

    return 0;
}

int RNN_arm::forward_fp16sa(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
    const Mat& bottom_blob = bottom_blobs[0];
    int T = bottom_blob.h;
    Mat& top_blob = top_blobs[0];

    top_blob.create(num_output, T, 2u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // copy previous states
    Mat hidden;
    cast_float16_to_float32(bottom_blobs[1], hidden, opt);

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = rnn_fp16sa(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret != 0)
            return ret;
    }

    cast_float32_to_float16(hidden, top_blobs[1], opt);

    return 0;
}
#endif

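// bf16 path: weights and activations are stored as bfloat16 (the top 16 bits
// of an fp32) and widened back to fp32 for the actual arithmetic.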
static int rnn_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

    // gate values, one per output element
    Mat gates(num_output, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;

    // unroll over time steps
    for (int t = 0; t < T; t++)
    {
        int ti = reverse ? T - 1 - t : t;

        const unsigned short* x = bottom_blob.row<const unsigned short>(ti);

        int q = 0;
#if __ARM_NEON
        for (; q + 3 < num_output; q += 4)
        {
            const unsigned short* weight_xc_ptr = weight_xc.row<const unsigned short>(q / 4);
            const unsigned short* weight_hc_ptr = weight_hc.row<const unsigned short>(q / 4);

            float32x4_t _H = vcvt_f32_bf16(vld1_u16((const unsigned short*)bias_c + q));
            float32x4_t _sum1 = vdupq_n_f32(0.f);
            float32x4_t _sum2 = vdupq_n_f32(0.f);
            float32x4_t _sum3 = vdupq_n_f32(0.f);

            int i = 0;
            for (; i + 3 < size; i += 4)
            {
                float32x4_t _x = vcvt_f32_bf16(vld1_u16(x + i));
                float32x4_t _weight_xc = vcvt_f32_bf16(vld1_u16(weight_xc_ptr));
                float32x4_t _weight_xc_1 = vcvt_f32_bf16(vld1_u16(weight_xc_ptr + 4));
                float32x4_t _weight_xc_2 = vcvt_f32_bf16(vld1_u16(weight_xc_ptr + 8));
                float32x4_t _weight_xc_3 = vcvt_f32_bf16(vld1_u16(weight_xc_ptr + 12));
#if __aarch64__
                _H = vfmaq_laneq_f32(_H, _weight_xc, _x, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_1, _x, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_2, _x, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_3, _x, 3);
#else
                _H = vmlaq_lane_f32(_H, _weight_xc, vget_low_f32(_x), 0);
                _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_1, vget_low_f32(_x), 1);
                _sum2 = vmlaq_lane_f32(_sum2, _weight_xc_2, vget_high_f32(_x), 0);
                _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_3, vget_high_f32(_x), 1);
#endif

                weight_xc_ptr += 16;
            }
            for (; i < size; i++)
            {
                float32x4_t _x = vcvt_f32_bf16(vdup_n_u16(x[i]));
                float32x4_t _weight_xc = vcvt_f32_bf16(vld1_u16(weight_xc_ptr));
                _H = vmlaq_f32(_H, _weight_xc, _x);

                weight_xc_ptr += 4;
            }

            i = 0;
            for (; i + 3 < num_output; i += 4)
            {
                float32x4_t _hidden_state = vld1q_f32((const float*)hidden_state + i);
                float32x4_t _weight_hc = vcvt_f32_bf16(vld1_u16(weight_hc_ptr));
                float32x4_t _weight_hc_1 = vcvt_f32_bf16(vld1_u16(weight_hc_ptr + 4));
                float32x4_t _weight_hc_2 = vcvt_f32_bf16(vld1_u16(weight_hc_ptr + 8));
                float32x4_t _weight_hc_3 = vcvt_f32_bf16(vld1_u16(weight_hc_ptr + 12));
#if __aarch64__
                _H = vfmaq_laneq_f32(_H, _weight_hc, _hidden_state, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_1, _hidden_state, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_2, _hidden_state, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_3, _hidden_state, 3);
#else
                _H = vmlaq_lane_f32(_H, _weight_hc, vget_low_f32(_hidden_state), 0);
                _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_1, vget_low_f32(_hidden_state), 1);
                _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_2, vget_high_f32(_hidden_state), 0);
                _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_3, vget_high_f32(_hidden_state), 1);
#endif

                weight_hc_ptr += 16;
            }
            for (; i < num_output; i++)
            {
                float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]);
                float32x4_t _weight_hc = vcvt_f32_bf16(vld1_u16(weight_hc_ptr));
                _H = vmlaq_f32(_H, _weight_hc, _hidden_state);

                weight_hc_ptr += 4;
            }

            _H = vaddq_f32(_H, _sum1);
            _sum2 = vaddq_f32(_sum2, _sum3);
            _H = vaddq_f32(_H, _sum2);

            _H = tanh_ps(_H);

            vst1q_f32((float*)gates + q, _H);
        }
#endif // __ARM_NEON
        for (; q < num_output; q++)
        {
#if __ARM_NEON
            const unsigned short* weight_xc_ptr = weight_xc.row<const unsigned short>(q / 4 + q % 4);
            const unsigned short* weight_hc_ptr = weight_hc.row<const unsigned short>(q / 4 + q % 4);
#else
            const unsigned short* weight_xc_ptr = weight_xc.row<const unsigned short>(q);
            const unsigned short* weight_hc_ptr = weight_hc.row<const unsigned short>(q);
#endif // __ARM_NEON

            float H = bfloat16_to_float32(((const unsigned short*)bias_c)[q]);

            for (int i = 0; i < size; i++)
            {
                H += bfloat16_to_float32(weight_xc_ptr[i]) * bfloat16_to_float32(x[i]);
            }

            for (int i = 0; i < num_output; i++)
            {
                H += bfloat16_to_float32(weight_hc_ptr[i]) * hidden_state[i];
            }

            H = tanh(H);

            gates[q] = H;
        }

        unsigned short* output_data = top_blob.row<unsigned short>(ti);

        float* hidden_ptr = hidden_state;

        q = 0;
#if __ARM_NEON
        for (; q + 3 < num_output; q += 4)
        {
            float32x4_t _H = vld1q_f32((float*)gates + q);

            vst1q_f32(hidden_ptr, _H);
            vst1_u16(output_data, vcvt_bf16_f32(_H));

            hidden_ptr += 4;
            output_data += 4;
        }
#endif // __ARM_NEON
        for (; q < num_output; q++)
        {
            float H = gates[q];

            *hidden_ptr++ = H;
            *output_data++ = float32_to_bfloat16(H);
        }
    }

    return 0;
}

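// bf16 weight packing mirrors the fp32 path (groups of four plus a scalar
// tail); without NEON the weights are simply cast to bfloat16 as-is.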
int RNN_arm::create_pipeline_bf16s(const Option& opt)
{
    int num_directions = direction == 2 ? 2 : 1;
    int size = weight_data_size / num_directions / num_output;

#if __ARM_NEON
    weight_xc_data_packed.create(size * 4, num_output / 4 + num_output % 4, num_directions, 2u, 1);
    weight_hc_data_packed.create(num_output * 4, num_output / 4 + num_output % 4, num_directions, 2u, 1);

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int dr = 0; dr < num_directions; dr++)
    {
        const Mat weight_xc = weight_xc_data.channel(dr);
        const Mat weight_hc = weight_hc_data.channel(dr);

        Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr);
        Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr);

        int q = 0;
#if __ARM_NEON
        for (; q + 3 < num_output; q += 4)
        {
            const float* weight_xc_0 = weight_xc.row(q);
            const float* weight_xc_1 = weight_xc.row(q + 1);
            const float* weight_xc_2 = weight_xc.row(q + 2);
            const float* weight_xc_3 = weight_xc.row(q + 3);

            const float* weight_hc_0 = weight_hc.row(q);
            const float* weight_hc_1 = weight_hc.row(q + 1);
            const float* weight_hc_2 = weight_hc.row(q + 2);
            const float* weight_hc_3 = weight_hc.row(q + 3);

            unsigned short* weight_xc = weight_xc_data_packed_dr.row<unsigned short>(q / 4);
            unsigned short* weight_hc = weight_hc_data_packed_dr.row<unsigned short>(q / 4);

            for (int i = 0; i < size; i++)
            {
                weight_xc[0] = float32_to_bfloat16(weight_xc_0[i]);
                weight_xc[1] = float32_to_bfloat16(weight_xc_1[i]);
                weight_xc[2] = float32_to_bfloat16(weight_xc_2[i]);
                weight_xc[3] = float32_to_bfloat16(weight_xc_3[i]);

                weight_xc += 4;
            }

            for (int i = 0; i < num_output; i++)
            {
                weight_hc[0] = float32_to_bfloat16(weight_hc_0[i]);
                weight_hc[1] = float32_to_bfloat16(weight_hc_1[i]);
                weight_hc[2] = float32_to_bfloat16(weight_hc_2[i]);
                weight_hc[3] = float32_to_bfloat16(weight_hc_3[i]);

                weight_hc += 4;
            }
        }
#endif // __ARM_NEON
        for (; q < num_output; q++)
        {
            const float* weight_xc_0 = weight_xc.row(q);
            const float* weight_hc_0 = weight_hc.row(q);

#if __ARM_NEON
            unsigned short* weight_xc = weight_xc_data_packed_dr.row<unsigned short>(q / 4 + q % 4);
            unsigned short* weight_hc = weight_hc_data_packed_dr.row<unsigned short>(q / 4 + q % 4);
#else
            unsigned short* weight_xc = weight_xc_data_packed_dr.row<unsigned short>(q);
            unsigned short* weight_hc = weight_hc_data_packed_dr.row<unsigned short>(q);
#endif // __ARM_NEON

            for (int i = 0; i < size; i++)
            {
                weight_xc[i] = float32_to_bfloat16(weight_xc_0[i]);
            }

            for (int i = 0; i < num_output; i++)
            {
                weight_hc[i] = float32_to_bfloat16(weight_hc_0[i]);
            }
        }
    }
#else
    cast_float32_to_bfloat16(weight_xc_data, weight_xc_data_packed);
    cast_float32_to_bfloat16(weight_hc_data, weight_hc_data_packed);
#endif

    cast_float32_to_bfloat16(bias_c_data, bias_c_data_packed);

    return 0;
}

int RNN_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int T = bottom_blob.h;

    int num_directions = direction == 2 ? 2 : 1;

    // initial hidden state
    Mat hidden(num_output, 4u, opt.workspace_allocator);
    if (hidden.empty())
        return -100;
    hidden.fill(0.f);

    top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = rnn_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret != 0)
            return ret;
    }

    if (direction == 2)
    {
        Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator);
        if (top_blob_forward.empty())
            return -100;

        Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator);
        if (top_blob_reverse.empty())
            return -100;

        int ret0 = rnn_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret0 != 0)
            return ret0;

        hidden.fill(0.f);

        int ret1 = rnn_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt);
        if (ret1 != 0)
            return ret1;

        // concat forward and reverse outputs along w
        for (int i = 0; i < T; i++)
        {
            const unsigned short* pf = top_blob_forward.row<const unsigned short>(i);
            const unsigned short* pr = top_blob_reverse.row<const unsigned short>(i);
            unsigned short* ptr = top_blob.row<unsigned short>(i);

            memcpy(ptr, pf, num_output * sizeof(unsigned short));
            memcpy(ptr + num_output, pr, num_output * sizeof(unsigned short));
        }
    }

    return 0;
}

int RNN_arm::forward_bf16s(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
    const Mat& bottom_blob = bottom_blobs[0];
    int T = bottom_blob.h;
    Mat& top_blob = top_blobs[0];

    top_blob.create(num_output, T, 2u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // copy previous states
    Mat hidden;
    cast_bfloat16_to_float32(bottom_blobs[1], hidden, opt);

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = rnn_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt);
        if (ret != 0)
            return ret;
    }

    cast_float32_to_bfloat16(hidden, top_blobs[1], opt);

    return 0;
}

} // namespace ncnn
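// A minimal usage sketch (not part of this translation unit): this layer is
// normally reached through ncnn's graph API rather than called directly.
// The param/model file names below are placeholders.
//
//   #include "net.h"
//
//   ncnn::Net net;
//   net.opt.use_fp16_storage = true; // opt into the fp16 paths above
//   net.load_param("rnn.param");
//   net.load_model("rnn.bin");
//
//   ncnn::Mat input(size, T); // one row per time step
//   ncnn::Extractor ex = net.create_extractor();
//   ex.input("input", input);
//   ncnn::Mat output; // (num_output * num_directions) x T
//   ex.extract("output", output);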