// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "lstm_arm.h"

#if __ARM_NEON
#include <arm_neon.h>
#include "neon_mathfun.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include "neon_mathfun_fp16s.h"
#endif
#include "neon_activation.h"
#endif // __ARM_NEON

#include <math.h>

namespace ncnn {

LSTM_arm::LSTM_arm()
{
#if __ARM_NEON
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    support_fp16_storage = true;
#endif
#endif // __ARM_NEON

    support_bf16_storage = true;
}

int LSTM_arm::create_pipeline(const Option& opt)
{
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    if (opt.use_fp16_storage)
    {
        return create_pipeline_fp16s(opt);
    }
#endif

    if (opt.use_bf16_storage)
    {
        return create_pipeline_bf16s(opt);
    }

    // pack IFOG
    int num_directions = direction == 2 ? 2 : 1;
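    // recover the input feature dimension: the flattened weight blob is laid
    // out as num_directions x 4 gates x num_output rows of size elements each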
    int size = weight_data_size / num_directions / num_output / 4;

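    // interleave the four gate coefficients of each output unit as I,F,O,G
    // quadruplets so the kernels below can fetch all four gates of one
    // input element with a single contiguous load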
    weight_xc_data_packed.create(size, num_output, num_directions, 16u, 4);
    bias_c_data_packed.create(num_output, 1, num_directions, 16u, 4);
    weight_hc_data_packed.create(num_output, num_output, num_directions, 16u, 4);

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int dr = 0; dr < num_directions; dr++)
    {
        const Mat weight_xc = weight_xc_data.channel(dr);
        const Mat bias_c = bias_c_data.channel(dr);
        const Mat weight_hc = weight_hc_data.channel(dr);

        Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr);
        Mat bias_c_data_packed_dr = bias_c_data_packed.channel(dr);
        Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr);

        const float* bias_c_I = bias_c.row(0);
        const float* bias_c_F = bias_c.row(1);
        const float* bias_c_O = bias_c.row(2);
        const float* bias_c_G = bias_c.row(3);

        float* bias_c_IFOG = bias_c_data_packed_dr.row(0);

        for (int q = 0; q < num_output; q++)
        {
            bias_c_IFOG[0] = bias_c_I[q];
            bias_c_IFOG[1] = bias_c_F[q];
            bias_c_IFOG[2] = bias_c_O[q];
            bias_c_IFOG[3] = bias_c_G[q];

            bias_c_IFOG += 4;

            const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
            const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
            const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
            const float* weight_xc_G = weight_xc.row(num_output * 3 + q);

            const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
            const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
            const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
            const float* weight_hc_G = weight_hc.row(num_output * 3 + q);

            float* weight_xc_IFOG = weight_xc_data_packed_dr.row(q);
            float* weight_hc_IFOG = weight_hc_data_packed_dr.row(q);

            for (int i = 0; i < size; i++)
            {
                weight_xc_IFOG[0] = weight_xc_I[i];
                weight_xc_IFOG[1] = weight_xc_F[i];
                weight_xc_IFOG[2] = weight_xc_O[i];
                weight_xc_IFOG[3] = weight_xc_G[i];

                weight_xc_IFOG += 4;
            }

            for (int i = 0; i < num_output; i++)
            {
                weight_hc_IFOG[0] = weight_hc_I[i];
                weight_hc_IFOG[1] = weight_hc_F[i];
                weight_hc_IFOG[2] = weight_hc_O[i];
                weight_hc_IFOG[3] = weight_hc_G[i];

                weight_hc_IFOG += 4;
            }
        }
    }

    return 0;
}

static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

    // 4 x num_output
    Mat gates(4, num_output, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;
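    // gates caches the pre-activation I,F,O,G of every output unit for the
    // current timestep before the activations are applied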

    // unroll
    for (int t = 0; t < T; t++)
    {
        // clip hidden by continuation indicator
        // h_cont_{t-1} = cont_t * h_{t-1}
        // h_cont_{t-1} = h_{t-1} if cont_t == 1
        //                0       otherwise
        // calculate hidden
        // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c

        int ti = reverse ? T - 1 - t : t;

        const float* x = bottom_blob.row(ti);
        for (int q = 0; q < num_output; q++)
        {
            const float* bias_c_IFOG = (const float*)bias_c + q * 4;

            // gate I F O G
            const float* weight_xc_IFOG = weight_xc.row(q);

            const float* weight_hc_IFOG = weight_hc.row(q);

#if __ARM_NEON
            float32x4_t _IFOG = vld1q_f32(bias_c_IFOG);
            float32x4_t _sum1 = vdupq_n_f32(0.f);
            float32x4_t _sum2 = vdupq_n_f32(0.f);
            float32x4_t _sum3 = vdupq_n_f32(0.f);
#else
            float I = bias_c_IFOG[0];
            float F = bias_c_IFOG[1];
            float O = bias_c_IFOG[2];
            float G = bias_c_IFOG[3];
#endif // __ARM_NEON

            int i = 0;
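            // accumulate x * weight_xc; the NEON path broadcasts one input
            // lane per multiply-accumulate and keeps four independent partial
            // sums to hide instruction latency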
#if __ARM_NEON
            for (; i + 3 < size; i += 4)
            {
                float32x4_t _xi = vld1q_f32(x + i);

                float32x4_t _weight_xc_IFOG_0 = vld1q_f32(weight_xc_IFOG);
                float32x4_t _weight_xc_IFOG_1 = vld1q_f32(weight_xc_IFOG + 4);
                float32x4_t _weight_xc_IFOG_2 = vld1q_f32(weight_xc_IFOG + 8);
                float32x4_t _weight_xc_IFOG_3 = vld1q_f32(weight_xc_IFOG + 12);

#if __aarch64__
                _IFOG = vfmaq_laneq_f32(_IFOG, _weight_xc_IFOG_0, _xi, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_IFOG_1, _xi, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_IFOG_2, _xi, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_IFOG_3, _xi, 3);
#else
                _IFOG = vmlaq_lane_f32(_IFOG, _weight_xc_IFOG_0, vget_low_f32(_xi), 0);
                _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_IFOG_1, vget_low_f32(_xi), 1);
                _sum2 = vmlaq_lane_f32(_sum2, _weight_xc_IFOG_2, vget_high_f32(_xi), 0);
                _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_IFOG_3, vget_high_f32(_xi), 1);
#endif

                weight_xc_IFOG += 16;
            }
#endif // __ARM_NEON
            for (; i < size; i++)
            {
                float xi = x[i];

#if __ARM_NEON
                float32x4_t _xi = vdupq_n_f32(xi);
                float32x4_t _weight_xc_IFOG = vld1q_f32(weight_xc_IFOG);
                _IFOG = vmlaq_f32(_IFOG, _weight_xc_IFOG, _xi);
#else
                I += weight_xc_IFOG[0] * xi;
                F += weight_xc_IFOG[1] * xi;
                O += weight_xc_IFOG[2] * xi;
                G += weight_xc_IFOG[3] * xi;
#endif // __ARM_NEON

                weight_xc_IFOG += 4;
            }

            i = 0;
#if __ARM_NEON
            for (; i + 3 < num_output; i += 4)
            {
                float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i);

                float32x4_t _weight_hc_IFOG_0 = vld1q_f32(weight_hc_IFOG);
                float32x4_t _weight_hc_IFOG_1 = vld1q_f32(weight_hc_IFOG + 4);
                float32x4_t _weight_hc_IFOG_2 = vld1q_f32(weight_hc_IFOG + 8);
                float32x4_t _weight_hc_IFOG_3 = vld1q_f32(weight_hc_IFOG + 12);

#if __aarch64__
                _IFOG = vfmaq_laneq_f32(_IFOG, _weight_hc_IFOG_0, _h_cont, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_IFOG_1, _h_cont, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_IFOG_2, _h_cont, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_IFOG_3, _h_cont, 3);
#else
                _IFOG = vmlaq_lane_f32(_IFOG, _weight_hc_IFOG_0, vget_low_f32(_h_cont), 0);
                _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_IFOG_1, vget_low_f32(_h_cont), 1);
                _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_IFOG_2, vget_high_f32(_h_cont), 0);
                _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_IFOG_3, vget_high_f32(_h_cont), 1);
#endif

                weight_hc_IFOG += 16;
            }
#endif // __ARM_NEON
            for (; i < num_output; i++)
            {
                float h_cont = hidden_state[i];

#if __ARM_NEON
                float32x4_t _h_cont = vdupq_n_f32(h_cont);
                float32x4_t _weight_hc_IFOG = vld1q_f32(weight_hc_IFOG);
                _IFOG = vmlaq_f32(_IFOG, _weight_hc_IFOG, _h_cont);
#else
                I += weight_hc_IFOG[0] * h_cont;
                F += weight_hc_IFOG[1] * h_cont;
                O += weight_hc_IFOG[2] * h_cont;
                G += weight_hc_IFOG[3] * h_cont;
#endif // __ARM_NEON

                weight_hc_IFOG += 4;
            }

            float* gates_data = gates.row(q);

#if __ARM_NEON
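            // fold the four partial sums into the final pre-activation IFOG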
            _IFOG = vaddq_f32(_IFOG, _sum1);
            _sum2 = vaddq_f32(_sum2, _sum3);
            _IFOG = vaddq_f32(_IFOG, _sum2);

            vst1q_f32(gates_data, _IFOG);
#else
            gates_data[0] = I;
            gates_data[1] = F;
            gates_data[2] = O;
            gates_data[3] = G;
#endif // __ARM_NEON
        }

        // lstm unit
        // sigmoid(I)
        // sigmoid(F)
        // sigmoid(O)
        // tanh(G)
        // c_t := f_t .* c_{t-1} + i_t .* g_t
        // h_t := o_t .* tanh[c_t]
        float* output_data = top_blob.row(ti);

        float* cell_ptr = cell_state;
        float* hidden_ptr = hidden_state;

        int q = 0;
#if __ARM_NEON
        for (; q + 3 < num_output; q += 4)
        {
            const float* gates_data = gates.row(q);

            float32x4x4_t _IFOG_4x4 = vld4q_f32(gates_data);

            float32x4_t _I = sigmoid_ps(_IFOG_4x4.val[0]);
            float32x4_t _F = sigmoid_ps(_IFOG_4x4.val[1]);
            float32x4_t _O = sigmoid_ps(_IFOG_4x4.val[2]);
            float32x4_t _G = tanh_ps(_IFOG_4x4.val[3]);

            float32x4_t _cell2 = vaddq_f32(vmulq_f32(_F, vld1q_f32(cell_ptr)), vmulq_f32(_I, _G));
            float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2));

            vst1q_f32(cell_ptr, _cell2);
            vst1q_f32(hidden_ptr, _H);
            vst1q_f32(output_data, _H);

            cell_ptr += 4;
            hidden_ptr += 4;
            output_data += 4;
        }
#endif // __ARM_NEON
        for (; q < num_output; q++)
        {
            const float* gates_data = gates.row(q);

            float I = gates_data[0];
            float F = gates_data[1];
            float O = gates_data[2];
            float G = gates_data[3];

            I = 1.f / (1.f + exp(-I));
            F = 1.f / (1.f + exp(-F));
            O = 1.f / (1.f + exp(-O));
            G = tanh(G);

            float cell2 = F * *cell_ptr + I * G;
            float H = O * tanh(cell2);

            *cell_ptr++ = cell2;
            *hidden_ptr++ = H;
            *output_data++ = H;
        }
    }

    return 0;
}

int LSTM_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int elembits = bottom_blob.elembits();

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    if (opt.use_fp16_storage && elembits == 16)
    {
        if (opt.use_fp16_arithmetic)
            return forward_fp16sa(bottom_blob, top_blob, opt);
        else
            return forward_fp16s(bottom_blob, top_blob, opt);
    }
#endif

    if (opt.use_bf16_storage && elembits == 16)
        return forward_bf16s(bottom_blob, top_blob, opt);

    int T = bottom_blob.h;

    int num_directions = direction == 2 ? 2 : 1;

    // initial hidden state
    Mat hidden(num_output, 4u, opt.workspace_allocator);
    if (hidden.empty())
        return -100;
    hidden.fill(0.f);

    Mat cell(num_output, 4u, opt.workspace_allocator);
    if (cell.empty())
        return -100;
    cell.fill(0.f);

    top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
        if (ret != 0)
            return ret;
    }

    if (direction == 2)
    {
        Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_forward.empty())
            return -100;

        Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_reverse.empty())
            return -100;

        int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
        if (ret0 != 0)
            return ret0;

        hidden.fill(0.0f);
        cell.fill(0.0f);

        int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, cell, opt);
        if (ret1 != 0)
            return ret1;

        // concat w
        for (int i = 0; i < T; i++)
        {
            const float* pf = top_blob_forward.row(i);
            const float* pr = top_blob_reverse.row(i);
            float* ptr = top_blob.row(i);

            memcpy(ptr, pf, num_output * sizeof(float));
            memcpy(ptr + num_output, pr, num_output * sizeof(float));
        }
    }

    return 0;
}

int LSTM_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
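    // without explicit initial hidden/cell state blobs, fall back to the
    // stateless single-blob path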
    if (bottom_blobs.size() != 3 || top_blobs.size() != 3)
    {
        return forward(bottom_blobs[0], top_blobs[0], opt);
    }

    const Mat& bottom_blob = bottom_blobs[0];

    int elembits = bottom_blob.elembits();

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    if (opt.use_fp16_storage && elembits == 16)
    {
        if (opt.use_fp16_arithmetic)
            return forward_fp16sa(bottom_blobs, top_blobs, opt);
        else
            return forward_fp16s(bottom_blobs, top_blobs, opt);
    }
#endif

    if (opt.use_bf16_storage && elembits == 16)
        return forward_bf16s(bottom_blobs, top_blobs, opt);

    int T = bottom_blob.h;
    Mat& top_blob = top_blobs[0];
    Mat& hidden_state = top_blobs[1];
    Mat& cell_state = top_blobs[2];

    // copy previous states
    hidden_state = bottom_blobs[1].clone(opt.blob_allocator);
    cell_state = bottom_blobs[2].clone(opt.blob_allocator);

    top_blob.create(num_output, T, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden_state, cell_state, opt);
        if (ret != 0)
            return ret;
    }

    return 0;
}

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
static int lstm_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

    // 4 x num_output
    Mat gates(4, num_output, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;
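    // fp16 storage path: weights and activations live in memory as __fp16
    // but are widened to fp32 on load, so all accumulation stays in fp32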

    // unroll
    for (int t = 0; t < T; t++)
    {
        // clip hidden by continuation indicator
        // h_cont_{t-1} = cont_t * h_{t-1}
        // h_cont_{t-1} = h_{t-1} if cont_t == 1
        //                0       otherwise
        // calculate hidden
        // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c

        int ti = reverse ? T - 1 - t : t;

        const __fp16* x = bottom_blob.row<const __fp16>(ti);
        for (int q = 0; q < num_output; q++)
        {
            const __fp16* bias_c_IFOG = (const __fp16*)bias_c + q * 4;

            // gate I F O G
            const __fp16* weight_xc_IFOG = weight_xc.row<const __fp16>(q);

            const __fp16* weight_hc_IFOG = weight_hc.row<const __fp16>(q);

            float32x4_t _IFOG = vcvt_f32_f16(vld1_f16(bias_c_IFOG));
            float32x4_t _sum1 = vdupq_n_f32(0.f);
            float32x4_t _sum2 = vdupq_n_f32(0.f);
            float32x4_t _sum3 = vdupq_n_f32(0.f);

            int i = 0;
            for (; i + 3 < size; i += 4)
            {
                float32x4_t _xi = vcvt_f32_f16(vld1_f16(x + i));

                float32x4_t _weight_xc_IFOG_0 = vcvt_f32_f16(vld1_f16(weight_xc_IFOG));
                float32x4_t _weight_xc_IFOG_1 = vcvt_f32_f16(vld1_f16(weight_xc_IFOG + 4));
                float32x4_t _weight_xc_IFOG_2 = vcvt_f32_f16(vld1_f16(weight_xc_IFOG + 8));
                float32x4_t _weight_xc_IFOG_3 = vcvt_f32_f16(vld1_f16(weight_xc_IFOG + 12));

                _IFOG = vfmaq_laneq_f32(_IFOG, _weight_xc_IFOG_0, _xi, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_IFOG_1, _xi, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_IFOG_2, _xi, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_IFOG_3, _xi, 3);

                weight_xc_IFOG += 16;
            }
            for (; i < size; i++)
            {
                __fp16 xi = x[i];

                float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi));
                float32x4_t _weight_xc_IFOG = vcvt_f32_f16(vld1_f16(weight_xc_IFOG));
                _IFOG = vfmaq_f32(_IFOG, _weight_xc_IFOG, _xi);

                weight_xc_IFOG += 4;
            }

            i = 0;
            for (; i + 3 < num_output; i += 4)
            {
                float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i);

                float32x4_t _weight_hc_IFOG_0 = vcvt_f32_f16(vld1_f16(weight_hc_IFOG));
                float32x4_t _weight_hc_IFOG_1 = vcvt_f32_f16(vld1_f16(weight_hc_IFOG + 4));
                float32x4_t _weight_hc_IFOG_2 = vcvt_f32_f16(vld1_f16(weight_hc_IFOG + 8));
                float32x4_t _weight_hc_IFOG_3 = vcvt_f32_f16(vld1_f16(weight_hc_IFOG + 12));

                _IFOG = vfmaq_laneq_f32(_IFOG, _weight_hc_IFOG_0, _h_cont, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_IFOG_1, _h_cont, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_IFOG_2, _h_cont, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_IFOG_3, _h_cont, 3);

                weight_hc_IFOG += 16;
            }
            for (; i < num_output; i++)
            {
                float h_cont = hidden_state[i];

                float32x4_t _h_cont = vdupq_n_f32(h_cont);
                float32x4_t _weight_hc_IFOG = vcvt_f32_f16(vld1_f16(weight_hc_IFOG));
                _IFOG = vfmaq_f32(_IFOG, _weight_hc_IFOG, _h_cont);

                weight_hc_IFOG += 4;
            }

            float* gates_data = gates.row(q);

            _IFOG = vaddq_f32(_IFOG, _sum1);
            _sum2 = vaddq_f32(_sum2, _sum3);
            _IFOG = vaddq_f32(_IFOG, _sum2);

            vst1q_f32(gates_data, _IFOG);
        }

        // lstm unit
        // sigmoid(I)
        // sigmoid(F)
        // sigmoid(O)
        // tanh(G)
        // c_t := f_t .* c_{t-1} + i_t .* g_t
        // h_t := o_t .* tanh[c_t]
        __fp16* output_data = top_blob.row<__fp16>(ti);

        float* cell_ptr = cell_state;
        float* hidden_ptr = hidden_state;

        int q = 0;
        for (; q + 3 < num_output; q += 4)
        {
            const float* gates_data = gates.row(q);

            float32x4x4_t _IFOG_4x4 = vld4q_f32(gates_data);

            float32x4_t _I = sigmoid_ps(_IFOG_4x4.val[0]);
            float32x4_t _F = sigmoid_ps(_IFOG_4x4.val[1]);
            float32x4_t _O = sigmoid_ps(_IFOG_4x4.val[2]);
            float32x4_t _G = tanh_ps(_IFOG_4x4.val[3]);

            float32x4_t _cell2 = vaddq_f32(vmulq_f32(_F, vld1q_f32(cell_ptr)), vmulq_f32(_I, _G));
            float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2));

            vst1q_f32(cell_ptr, _cell2);
            vst1q_f32(hidden_ptr, _H);
            vst1_f16(output_data, vcvt_f16_f32(_H));

            cell_ptr += 4;
            hidden_ptr += 4;
            output_data += 4;
        }
        for (; q < num_output; q++)
        {
            const float* gates_data = gates.row(q);

            float I = gates_data[0];
            float F = gates_data[1];
            float O = gates_data[2];
            float G = gates_data[3];

            I = 1.f / (1.f + exp(-I));
            F = 1.f / (1.f + exp(-F));
            O = 1.f / (1.f + exp(-O));
            G = tanh(G);

            float cell2 = F * *cell_ptr + I * G;
            float H = O * tanh(cell2);

            *cell_ptr++ = cell2;
            *hidden_ptr++ = H;
            *output_data++ = (__fp16)(H);
        }
    }

    return 0;
}

static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

    // 4 x num_output
    Mat gates(4, num_output, 2u, opt.workspace_allocator);
    if (gates.empty())
        return -100;

    // unroll
    for (int t = 0; t < T; t++)
    {
        // clip hidden by continuation indicator
        // h_cont_{t-1} = cont_t * h_{t-1}
        // h_cont_{t-1} = h_{t-1} if cont_t == 1
        //                0       otherwise
        // calculate hidden
        // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c

        int ti = reverse ? T - 1 - t : t;

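        // fp16 arithmetic path: two output units are packed per weight row,
        // so one float16x8_t carries the IFOG pre-activations of both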
        int q = 0;
        for (; q + 1 < num_output; q += 2)
        {
            const __fp16* bias_c_IFOG = (const __fp16*)bias_c + q * 4;

            // gate I F O G
            const __fp16* weight_xc_IFOG = weight_xc.row<const __fp16>(q / 2);

            const __fp16* weight_hc_IFOG = weight_hc.row<const __fp16>(q / 2);

            float16x8_t _IFOG = vld1q_f16(bias_c_IFOG);
            float16x8_t _sum1 = vdupq_n_f16((__fp16)0.f);
            float16x8_t _sum2 = vdupq_n_f16((__fp16)0.f);
            float16x8_t _sum3 = vdupq_n_f16((__fp16)0.f);

            const __fp16* x = bottom_blob.row<const __fp16>(ti);

            int i = 0;
            for (; i + 3 < size; i += 4)
            {
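                // consume 4 fp16 inputs (v4) against 4 packed IFOG weight
                // vectors (v0-v3); each fmla broadcasts one input lane across
                // the 8 fp16 lanes, i.e. the IFOG of two output units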
                asm volatile(
                    "ld1    {v4.4h}, [%0], #8       \n"
                    "ld1    {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n"
                    "fmla   %2.8h, v0.8h, v4.h[0]   \n"
                    "fmla   %3.8h, v1.8h, v4.h[1]   \n"
                    "fmla   %4.8h, v2.8h, v4.h[2]   \n"
                    "fmla   %5.8h, v3.8h, v4.h[3]   \n"
                    : "=r"(x),
                    "=r"(weight_xc_IFOG),
                    "=w"(_IFOG),
                    "=w"(_sum1),
                    "=w"(_sum2),
                    "=w"(_sum3)
                    : "0"(x),
                    "1"(weight_xc_IFOG),
                    "2"(_IFOG),
                    "3"(_sum1),
                    "4"(_sum2),
                    "5"(_sum3)
                    : "memory", "v0", "v1", "v2", "v3", "v4");
            }
            for (; i < size; i++)
            {
                __fp16 xi = *x++;

                float16x8_t _xi = vdupq_n_f16(xi);
                float16x8_t _weight_xc_IFOG = vld1q_f16(weight_xc_IFOG);
                _IFOG = vfmaq_f16(_IFOG, _weight_xc_IFOG, _xi);

                weight_xc_IFOG += 8;
            }

            const float* hidden_ptr = hidden_state;

            i = 0;
            for (; i + 3 < num_output; i += 4)
            {
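                // the hidden state stays fp32 in memory; fcvtn narrows four
                // values to fp16 in-register before the lane-broadcast fmla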
                asm volatile(
                    "ld1    {v4.4s}, [%0], #16      \n"
                    "ld1    {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n"
                    "fcvtn  v4.4h, v4.4s            \n"
                    "fmla   %2.8h, v0.8h, v4.h[0]   \n"
                    "fmla   %3.8h, v1.8h, v4.h[1]   \n"
                    "fmla   %4.8h, v2.8h, v4.h[2]   \n"
                    "fmla   %5.8h, v3.8h, v4.h[3]   \n"
                    : "=r"(hidden_ptr),
                    "=r"(weight_hc_IFOG),
                    "=w"(_IFOG),
                    "=w"(_sum1),
                    "=w"(_sum2),
                    "=w"(_sum3)
                    : "0"(hidden_ptr),
                    "1"(weight_hc_IFOG),
                    "2"(_IFOG),
                    "3"(_sum1),
                    "4"(_sum2),
                    "5"(_sum3)
                    : "memory", "v0", "v1", "v2", "v3", "v4");
            }
            for (; i < num_output; i++)
            {
                float h_cont = *hidden_ptr++;

                float16x8_t _h_cont = vdupq_n_f16((__fp16)h_cont);
                float16x8_t _weight_hc_IFOG = vld1q_f16(weight_hc_IFOG);
                _IFOG = vfmaq_f16(_IFOG, _weight_hc_IFOG, _h_cont);

                weight_hc_IFOG += 8;
            }

            __fp16* gates_data = gates.row<__fp16>(q);

            _IFOG = vaddq_f16(_IFOG, _sum1);
            _sum2 = vaddq_f16(_sum2, _sum3);
            _IFOG = vaddq_f16(_IFOG, _sum2);

            vst1q_f16(gates_data, _IFOG);
        }
        for (; q < num_output; q++)
        {
            const __fp16* bias_c_IFOG = (const __fp16*)bias_c + q * 4;

            // gate I F O G
            const __fp16* weight_xc_IFOG = weight_xc.row<const __fp16>(q / 2 + q % 2);

            const __fp16* weight_hc_IFOG = weight_hc.row<const __fp16>(q / 2 + q % 2);

            float16x4_t _IFOG = vld1_f16(bias_c_IFOG);
            float16x4_t _sum1 = vdup_n_f16((__fp16)0.f);
            float16x4_t _sum2 = vdup_n_f16((__fp16)0.f);
            float16x4_t _sum3 = vdup_n_f16((__fp16)0.f);

            const __fp16* x = bottom_blob.row<const __fp16>(ti);

            int i = 0;
            for (; i + 3 < size; i += 4)
            {
                asm volatile(
                    "ld1    {v4.4h}, [%0], #8       \n"
                    "ld1    {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n"
                    "fmla   %2.4h, v0.4h, v4.h[0]   \n"
                    "fmla   %3.4h, v1.4h, v4.h[1]   \n"
                    "fmla   %4.4h, v2.4h, v4.h[2]   \n"
                    "fmla   %5.4h, v3.4h, v4.h[3]   \n"
                    : "=r"(x),
                    "=r"(weight_xc_IFOG),
                    "=w"(_IFOG),
                    "=w"(_sum1),
                    "=w"(_sum2),
                    "=w"(_sum3)
                    : "0"(x),
                    "1"(weight_xc_IFOG),
                    "2"(_IFOG),
                    "3"(_sum1),
                    "4"(_sum2),
                    "5"(_sum3)
                    : "memory", "v0", "v1", "v2", "v3", "v4");
            }
            for (; i < size; i++)
            {
                __fp16 xi = *x++;

                float16x4_t _xi = vdup_n_f16(xi);
                float16x4_t _weight_xc_IFOG = vld1_f16(weight_xc_IFOG);
                _IFOG = vfma_f16(_IFOG, _weight_xc_IFOG, _xi);

                weight_xc_IFOG += 4;
            }

            const float* hidden_ptr = hidden_state;

            i = 0;
            for (; i + 3 < num_output; i += 4)
            {
                asm volatile(
                    "ld1    {v4.4s}, [%0], #16      \n"
                    "ld1    {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n"
                    "fcvtn  v4.4h, v4.4s            \n"
                    "fmla   %2.4h, v0.4h, v4.h[0]   \n"
                    "fmla   %3.4h, v1.4h, v4.h[1]   \n"
                    "fmla   %4.4h, v2.4h, v4.h[2]   \n"
                    "fmla   %5.4h, v3.4h, v4.h[3]   \n"
                    : "=r"(hidden_ptr),
                    "=r"(weight_hc_IFOG),
                    "=w"(_IFOG),
                    "=w"(_sum1),
                    "=w"(_sum2),
                    "=w"(_sum3)
                    : "0"(hidden_ptr),
                    "1"(weight_hc_IFOG),
                    "2"(_IFOG),
                    "3"(_sum1),
                    "4"(_sum2),
                    "5"(_sum3)
                    : "memory", "v0", "v1", "v2", "v3", "v4");
            }
            for (; i < num_output; i++)
            {
                float h_cont = *hidden_ptr++;

                float16x4_t _h_cont = vdup_n_f16((__fp16)h_cont);
                float16x4_t _weight_hc_IFOG = vld1_f16(weight_hc_IFOG);
                _IFOG = vfma_f16(_IFOG, _weight_hc_IFOG, _h_cont);

                weight_hc_IFOG += 4;
            }

            __fp16* gates_data = gates.row<__fp16>(q);

            _IFOG = vadd_f16(_IFOG, _sum1);
            _sum2 = vadd_f16(_sum2, _sum3);
            _IFOG = vadd_f16(_IFOG, _sum2);

            vst1_f16(gates_data, _IFOG);
        }

        // lstm unit
        // sigmoid(I)
        // sigmoid(F)
        // sigmoid(O)
        // tanh(G)
        // c_t := f_t .* c_{t-1} + i_t .* g_t
        // h_t := o_t .* tanh[c_t]
        __fp16* output_data = top_blob.row<__fp16>(ti);

        float* cell_ptr = cell_state;
        float* hidden_ptr = hidden_state;

        q = 0;
        for (; q + 3 < num_output; q += 4)
        {
            const __fp16* gates_data = gates.row<const __fp16>(q);

            float16x4x4_t _IFOG_4x4 = vld4_f16(gates_data);

            float32x4_t _I = sigmoid_ps(vcvt_f32_f16(_IFOG_4x4.val[0]));
            float32x4_t _F = sigmoid_ps(vcvt_f32_f16(_IFOG_4x4.val[1]));
            float32x4_t _O = sigmoid_ps(vcvt_f32_f16(_IFOG_4x4.val[2]));
            float32x4_t _G = tanh_ps(vcvt_f32_f16(_IFOG_4x4.val[3]));

            float32x4_t _cell2 = vaddq_f32(vmulq_f32(_F, vld1q_f32(cell_ptr)), vmulq_f32(_I, _G));
            float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2));

            vst1q_f32(cell_ptr, _cell2);
            vst1q_f32(hidden_ptr, _H);
            vst1_f16(output_data, vcvt_f16_f32(_H));

            cell_ptr += 4;
            hidden_ptr += 4;
            output_data += 4;
        }
        for (; q < num_output; q++)
        {
            const __fp16* gates_data = gates.row<const __fp16>(q);

            float I = (float)gates_data[0];
            float F = (float)gates_data[1];
            float O = (float)gates_data[2];
            float G = (float)gates_data[3];

            I = 1.f / (1.f + exp(-I));
            F = 1.f / (1.f + exp(-F));
            O = 1.f / (1.f + exp(-O));
            G = tanh(G);

            float cell2 = F * *cell_ptr + I * G;
            float H = O * tanh(cell2);

            *cell_ptr++ = cell2;
            *hidden_ptr++ = H;
            *output_data++ = (__fp16)H;
        }
    }

    return 0;
}

int LSTM_arm::create_pipeline_fp16s(const Option& opt)
{
    // pack IFOG
    int num_directions = direction == 2 ? 2 : 1;
    int size = weight_data_size / num_directions / num_output / 4;

    if (opt.use_fp16_arithmetic)
    {
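        // fp16 arithmetic kernels read two output units per packed row, so
        // num_output / 2 rows are allocated plus a tail row when num_output
        // is odd, each row holding interleaved IFOG pairs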
        weight_xc_data_packed.create(size, num_output / 2 + num_output % 2, num_directions, 16u, 8);
        bias_c_data_packed.create(num_output, 1, num_directions, 8u, 4);
        weight_hc_data_packed.create(num_output, num_output / 2 + num_output % 2, num_directions, 16u, 8);
    }
    else
    {
        weight_xc_data_packed.create(size, num_output, num_directions, 8u, 4);
        bias_c_data_packed.create(num_output, 1, num_directions, 8u, 4);
        weight_hc_data_packed.create(num_output, num_output, num_directions, 8u, 4);
    }

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int dr = 0; dr < num_directions; dr++)
    {
        const Mat weight_xc = weight_xc_data.channel(dr);
        const Mat bias_c = bias_c_data.channel(dr);
        const Mat weight_hc = weight_hc_data.channel(dr);

        Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr);
        Mat bias_c_data_packed_dr = bias_c_data_packed.channel(dr);
        Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr);

        const float* bias_c_I = bias_c.row(0);
        const float* bias_c_F = bias_c.row(1);
        const float* bias_c_O = bias_c.row(2);
        const float* bias_c_G = bias_c.row(3);

        __fp16* bias_c_IFOG = bias_c_data_packed_dr.row<__fp16>(0);

        if (opt.use_fp16_arithmetic)
        {
            int q = 0;
            for (; q + 1 < num_output; q += 2)
            {
                bias_c_IFOG[0] = (__fp16)bias_c_I[q];
                bias_c_IFOG[1] = (__fp16)bias_c_F[q];
                bias_c_IFOG[2] = (__fp16)bias_c_O[q];
                bias_c_IFOG[3] = (__fp16)bias_c_G[q];
                bias_c_IFOG[4] = (__fp16)bias_c_I[q + 1];
                bias_c_IFOG[5] = (__fp16)bias_c_F[q + 1];
                bias_c_IFOG[6] = (__fp16)bias_c_O[q + 1];
                bias_c_IFOG[7] = (__fp16)bias_c_G[q + 1];

                bias_c_IFOG += 8;

                const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
                const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
                const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
                const float* weight_xc_G = weight_xc.row(num_output * 3 + q);
                const float* weight_xc_I_1 = weight_xc.row(num_output * 0 + q + 1);
                const float* weight_xc_F_1 = weight_xc.row(num_output * 1 + q + 1);
                const float* weight_xc_O_1 = weight_xc.row(num_output * 2 + q + 1);
                const float* weight_xc_G_1 = weight_xc.row(num_output * 3 + q + 1);

                const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
                const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
                const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
                const float* weight_hc_G = weight_hc.row(num_output * 3 + q);
                const float* weight_hc_I_1 = weight_hc.row(num_output * 0 + q + 1);
                const float* weight_hc_F_1 = weight_hc.row(num_output * 1 + q + 1);
                const float* weight_hc_O_1 = weight_hc.row(num_output * 2 + q + 1);
                const float* weight_hc_G_1 = weight_hc.row(num_output * 3 + q + 1);

                __fp16* weight_xc_IFOG = weight_xc_data_packed_dr.row<__fp16>(q / 2);
                __fp16* weight_hc_IFOG = weight_hc_data_packed_dr.row<__fp16>(q / 2);

                for (int i = 0; i < size; i++)
                {
                    weight_xc_IFOG[0] = (__fp16)weight_xc_I[i];
                    weight_xc_IFOG[1] = (__fp16)weight_xc_F[i];
                    weight_xc_IFOG[2] = (__fp16)weight_xc_O[i];
                    weight_xc_IFOG[3] = (__fp16)weight_xc_G[i];
                    weight_xc_IFOG[4] = (__fp16)weight_xc_I_1[i];
                    weight_xc_IFOG[5] = (__fp16)weight_xc_F_1[i];
                    weight_xc_IFOG[6] = (__fp16)weight_xc_O_1[i];
                    weight_xc_IFOG[7] = (__fp16)weight_xc_G_1[i];

                    weight_xc_IFOG += 8;
                }

                for (int i = 0; i < num_output; i++)
                {
                    weight_hc_IFOG[0] = (__fp16)weight_hc_I[i];
                    weight_hc_IFOG[1] = (__fp16)weight_hc_F[i];
                    weight_hc_IFOG[2] = (__fp16)weight_hc_O[i];
                    weight_hc_IFOG[3] = (__fp16)weight_hc_G[i];
                    weight_hc_IFOG[4] = (__fp16)weight_hc_I_1[i];
                    weight_hc_IFOG[5] = (__fp16)weight_hc_F_1[i];
                    weight_hc_IFOG[6] = (__fp16)weight_hc_O_1[i];
                    weight_hc_IFOG[7] = (__fp16)weight_hc_G_1[i];

                    weight_hc_IFOG += 8;
                }
            }
            for (; q < num_output; q++)
            {
                bias_c_IFOG[0] = (__fp16)bias_c_I[q];
                bias_c_IFOG[1] = (__fp16)bias_c_F[q];
                bias_c_IFOG[2] = (__fp16)bias_c_O[q];
                bias_c_IFOG[3] = (__fp16)bias_c_G[q];

                bias_c_IFOG += 4;

                const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
                const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
                const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
                const float* weight_xc_G = weight_xc.row(num_output * 3 + q);

                const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
                const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
                const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
                const float* weight_hc_G = weight_hc.row(num_output * 3 + q);

                __fp16* weight_xc_IFOG = weight_xc_data_packed_dr.row<__fp16>(q / 2 + q % 2);
                __fp16* weight_hc_IFOG = weight_hc_data_packed_dr.row<__fp16>(q / 2 + q % 2);

                for (int i = 0; i < size; i++)
                {
                    weight_xc_IFOG[0] = (__fp16)weight_xc_I[i];
                    weight_xc_IFOG[1] = (__fp16)weight_xc_F[i];
                    weight_xc_IFOG[2] = (__fp16)weight_xc_O[i];
                    weight_xc_IFOG[3] = (__fp16)weight_xc_G[i];

                    weight_xc_IFOG += 4;
                }

                for (int i = 0; i < num_output; i++)
                {
                    weight_hc_IFOG[0] = (__fp16)weight_hc_I[i];
                    weight_hc_IFOG[1] = (__fp16)weight_hc_F[i];
                    weight_hc_IFOG[2] = (__fp16)weight_hc_O[i];
                    weight_hc_IFOG[3] = (__fp16)weight_hc_G[i];

                    weight_hc_IFOG += 4;
                }
            }
        }
        else
        {
            for (int q = 0; q < num_output; q++)
            {
                bias_c_IFOG[0] = (__fp16)bias_c_I[q];
                bias_c_IFOG[1] = (__fp16)bias_c_F[q];
                bias_c_IFOG[2] = (__fp16)bias_c_O[q];
                bias_c_IFOG[3] = (__fp16)bias_c_G[q];

                bias_c_IFOG += 4;

                const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
                const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
                const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
                const float* weight_xc_G = weight_xc.row(num_output * 3 + q);

                const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
                const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
                const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
                const float* weight_hc_G = weight_hc.row(num_output * 3 + q);

                __fp16* weight_xc_IFOG = weight_xc_data_packed_dr.row<__fp16>(q);
                __fp16* weight_hc_IFOG = weight_hc_data_packed_dr.row<__fp16>(q);

                for (int i = 0; i < size; i++)
                {
                    weight_xc_IFOG[0] = (__fp16)weight_xc_I[i];
                    weight_xc_IFOG[1] = (__fp16)weight_xc_F[i];
                    weight_xc_IFOG[2] = (__fp16)weight_xc_O[i];
                    weight_xc_IFOG[3] = (__fp16)weight_xc_G[i];

                    weight_xc_IFOG += 4;
                }

                for (int i = 0; i < num_output; i++)
                {
                    weight_hc_IFOG[0] = (__fp16)weight_hc_I[i];
                    weight_hc_IFOG[1] = (__fp16)weight_hc_F[i];
                    weight_hc_IFOG[2] = (__fp16)weight_hc_O[i];
                    weight_hc_IFOG[3] = (__fp16)weight_hc_G[i];

                    weight_hc_IFOG += 4;
                }
            }
        }
    }

    return 0;
}

int LSTM_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int T = bottom_blob.h;

    int num_directions = direction == 2 ? 2 : 1;

    // initial hidden state
    Mat hidden(num_output, 4u, opt.workspace_allocator);
    if (hidden.empty())
        return -100;
    hidden.fill(0.f);

    Mat cell(num_output, 4u, opt.workspace_allocator);
    if (cell.empty())
        return -100;
    cell.fill(0.f);

    top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = lstm_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
        if (ret != 0)
            return ret;
    }

    if (direction == 2)
    {
        Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator);
        if (top_blob_forward.empty())
            return -100;

        Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator);
        if (top_blob_reverse.empty())
            return -100;

        int ret0 = lstm_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
        if (ret0 != 0)
            return ret0;

        hidden.fill(0.f);
        cell.fill(0.f);

        int ret1 = lstm_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, cell, opt);
        if (ret1 != 0)
            return ret1;

        // concat w
        for (int i = 0; i < T; i++)
        {
            const __fp16* pf = top_blob_forward.row<const __fp16>(i);
            const __fp16* pr = top_blob_reverse.row<const __fp16>(i);
            __fp16* ptr = top_blob.row<__fp16>(i);

            memcpy(ptr, pf, num_output * sizeof(__fp16));
            memcpy(ptr + num_output, pr, num_output * sizeof(__fp16));
        }
    }

    return 0;
}

int LSTM_arm::forward_fp16s(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
    const Mat& bottom_blob = bottom_blobs[0];
    int T = bottom_blob.h;
    Mat& top_blob = top_blobs[0];

    top_blob.create(num_output, T, 2u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // copy previous states
    Mat hidden;
    Mat cell;
    cast_float16_to_float32(bottom_blobs[1], hidden, opt);
    cast_float16_to_float32(bottom_blobs[2], cell, opt);

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = lstm_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
        if (ret != 0)
            return ret;
    }

    cast_float32_to_float16(hidden, top_blobs[1], opt);
    cast_float32_to_float16(cell, top_blobs[2], opt);

    return 0;
}

int LSTM_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int T = bottom_blob.h;

    int num_directions = direction == 2 ? 2 : 1;

    // initial hidden state
    Mat hidden(num_output, 4u, opt.workspace_allocator);
    if (hidden.empty())
        return -100;
    hidden.fill(0.f);

    Mat cell(num_output, 4u, opt.workspace_allocator);
    if (cell.empty())
        return -100;
    cell.fill(0.f);

    top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = lstm_fp16sa(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
        if (ret != 0)
            return ret;
    }

    if (direction == 2)
    {
        Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator);
        if (top_blob_forward.empty())
            return -100;

        Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator);
        if (top_blob_reverse.empty())
            return -100;

        int ret0 = lstm_fp16sa(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
        if (ret0 != 0)
            return ret0;

        hidden.fill(0.f);
        cell.fill(0.f);

        int ret1 = lstm_fp16sa(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, cell, opt);
        if (ret1 != 0)
            return ret1;

        // concat w
        for (int i = 0; i < T; i++)
        {
            const __fp16* pf = top_blob_forward.row<const __fp16>(i);
            const __fp16* pr = top_blob_reverse.row<const __fp16>(i);
            __fp16* ptr = top_blob.row<__fp16>(i);

            memcpy(ptr, pf, num_output * sizeof(__fp16));
            memcpy(ptr + num_output, pr, num_output * sizeof(__fp16));
        }
    }

    return 0;
}

int LSTM_arm::forward_fp16sa(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
    const Mat& bottom_blob = bottom_blobs[0];
    int T = bottom_blob.h;
    Mat& top_blob = top_blobs[0];

    top_blob.create(num_output, T, 2u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // copy previous states
    Mat hidden;
    Mat cell;
    cast_float16_to_float32(bottom_blobs[1], hidden, opt);
    cast_float16_to_float32(bottom_blobs[2], cell, opt);

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = lstm_fp16sa(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
        if (ret != 0)
            return ret;
    }

    cast_float32_to_float16(hidden, top_blobs[1], opt);
    cast_float32_to_float16(cell, top_blobs[2], opt);

    return 0;
}
#endif

static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

    // 4 x num_output
    Mat gates(4, num_output, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;
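    // bf16 storage path: weights and activations are stored as bfloat16 and
    // widened to fp32 on load, so all accumulation stays in fp32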

    // unroll
    for (int t = 0; t < T; t++)
    {
        // clip hidden by continuation indicator
        // h_cont_{t-1} = cont_t * h_{t-1}
        // h_cont_{t-1} = h_{t-1} if cont_t == 1
        //                0       otherwise
        // calculate hidden
        // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c

        int ti = reverse ? T - 1 - t : t;

        const unsigned short* x = bottom_blob.row<const unsigned short>(ti);
        for (int q = 0; q < num_output; q++)
        {
            const unsigned short* bias_c_IFOG = (const unsigned short*)bias_c + q * 4;

            // gate I F O G
            const unsigned short* weight_xc_IFOG = weight_xc.row<const unsigned short>(q);

            const unsigned short* weight_hc_IFOG = weight_hc.row<const unsigned short>(q);

#if __ARM_NEON
            float32x4_t _IFOG = vcvt_f32_bf16(vld1_u16(bias_c_IFOG));
            float32x4_t _sum1 = vdupq_n_f32(0.f);
            float32x4_t _sum2 = vdupq_n_f32(0.f);
            float32x4_t _sum3 = vdupq_n_f32(0.f);
#else
            float I = bfloat16_to_float32(bias_c_IFOG[0]);
            float F = bfloat16_to_float32(bias_c_IFOG[1]);
            float O = bfloat16_to_float32(bias_c_IFOG[2]);
            float G = bfloat16_to_float32(bias_c_IFOG[3]);
#endif // __ARM_NEON

            int i = 0;
#if __ARM_NEON
            for (; i + 3 < size; i += 4)
            {
                float32x4_t _xi = vcvt_f32_bf16(vld1_u16(x + i));

                float32x4_t _weight_xc_IFOG_0 = vcvt_f32_bf16(vld1_u16(weight_xc_IFOG));
                float32x4_t _weight_xc_IFOG_1 = vcvt_f32_bf16(vld1_u16(weight_xc_IFOG + 4));
                float32x4_t _weight_xc_IFOG_2 = vcvt_f32_bf16(vld1_u16(weight_xc_IFOG + 8));
                float32x4_t _weight_xc_IFOG_3 = vcvt_f32_bf16(vld1_u16(weight_xc_IFOG + 12));

#if __aarch64__
                _IFOG = vfmaq_laneq_f32(_IFOG, _weight_xc_IFOG_0, _xi, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_IFOG_1, _xi, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_IFOG_2, _xi, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_IFOG_3, _xi, 3);
#else
                _IFOG = vmlaq_lane_f32(_IFOG, _weight_xc_IFOG_0, vget_low_f32(_xi), 0);
                _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_IFOG_1, vget_low_f32(_xi), 1);
                _sum2 = vmlaq_lane_f32(_sum2, _weight_xc_IFOG_2, vget_high_f32(_xi), 0);
                _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_IFOG_3, vget_high_f32(_xi), 1);
#endif

                weight_xc_IFOG += 16;
            }
#endif // __ARM_NEON
            for (; i < size; i++)
            {
#if __ARM_NEON
                unsigned short xi = x[i];

                float32x4_t _xi = vcvt_f32_bf16(vdup_n_u16(xi));
                float32x4_t _weight_xc_IFOG = vcvt_f32_bf16(vld1_u16(weight_xc_IFOG));
                _IFOG = vmlaq_f32(_IFOG, _weight_xc_IFOG, _xi);
#else
                float xi = bfloat16_to_float32(x[i]);

                I += bfloat16_to_float32(weight_xc_IFOG[0]) * xi;
                F += bfloat16_to_float32(weight_xc_IFOG[1]) * xi;
                O += bfloat16_to_float32(weight_xc_IFOG[2]) * xi;
                G += bfloat16_to_float32(weight_xc_IFOG[3]) * xi;
#endif // __ARM_NEON

                weight_xc_IFOG += 4;
            }

            i = 0;
#if __ARM_NEON
            for (; i + 3 < num_output; i += 4)
            {
                float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i);

                float32x4_t _weight_hc_IFOG_0 = vcvt_f32_bf16(vld1_u16(weight_hc_IFOG));
                float32x4_t _weight_hc_IFOG_1 = vcvt_f32_bf16(vld1_u16(weight_hc_IFOG + 4));
                float32x4_t _weight_hc_IFOG_2 = vcvt_f32_bf16(vld1_u16(weight_hc_IFOG + 8));
                float32x4_t _weight_hc_IFOG_3 = vcvt_f32_bf16(vld1_u16(weight_hc_IFOG + 12));

#if __aarch64__
                _IFOG = vfmaq_laneq_f32(_IFOG, _weight_hc_IFOG_0, _h_cont, 0);
                _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_IFOG_1, _h_cont, 1);
                _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_IFOG_2, _h_cont, 2);
                _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_IFOG_3, _h_cont, 3);
#else
                _IFOG = vmlaq_lane_f32(_IFOG, _weight_hc_IFOG_0, vget_low_f32(_h_cont), 0);
                _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_IFOG_1, vget_low_f32(_h_cont), 1);
                _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_IFOG_2, vget_high_f32(_h_cont), 0);
                _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_IFOG_3, vget_high_f32(_h_cont), 1);
#endif

                weight_hc_IFOG += 16;
            }
#endif // __ARM_NEON
            for (; i < num_output; i++)
            {
                float h_cont = hidden_state[i];

#if __ARM_NEON
                float32x4_t _h_cont = vdupq_n_f32(h_cont);
                float32x4_t _weight_hc_IFOG = vcvt_f32_bf16(vld1_u16(weight_hc_IFOG));
                _IFOG = vmlaq_f32(_IFOG, _weight_hc_IFOG, _h_cont);
#else
                I += bfloat16_to_float32(weight_hc_IFOG[0]) * h_cont;
                F += bfloat16_to_float32(weight_hc_IFOG[1]) * h_cont;
                O += bfloat16_to_float32(weight_hc_IFOG[2]) * h_cont;
                G += bfloat16_to_float32(weight_hc_IFOG[3]) * h_cont;
#endif // __ARM_NEON

                weight_hc_IFOG += 4;
            }

            float* gates_data = gates.row(q);

#if __ARM_NEON
            _IFOG = vaddq_f32(_IFOG, _sum1);
            _sum2 = vaddq_f32(_sum2, _sum3);
            _IFOG = vaddq_f32(_IFOG, _sum2);

            vst1q_f32(gates_data, _IFOG);
#else
            gates_data[0] = I;
            gates_data[1] = F;
            gates_data[2] = O;
            gates_data[3] = G;
#endif // __ARM_NEON
        }

        // lstm unit
        // sigmoid(I)
        // sigmoid(F)
        // sigmoid(O)
        // tanh(G)
        // c_t := f_t .* c_{t-1} + i_t .* g_t
        // h_t := o_t .* tanh[c_t]
        unsigned short* output_data = top_blob.row<unsigned short>(ti);

        float* cell_ptr = cell_state;
        float* hidden_ptr = hidden_state;

        int q = 0;
#if __ARM_NEON
        for (; q + 3 < num_output; q += 4)
        {
            const float* gates_data = gates.row(q);

            float32x4x4_t _IFOG_4x4 = vld4q_f32(gates_data);

            float32x4_t _I = sigmoid_ps(_IFOG_4x4.val[0]);
            float32x4_t _F = sigmoid_ps(_IFOG_4x4.val[1]);
            float32x4_t _O = sigmoid_ps(_IFOG_4x4.val[2]);
            float32x4_t _G = tanh_ps(_IFOG_4x4.val[3]);

            float32x4_t _cell2 = vaddq_f32(vmulq_f32(_F, vld1q_f32(cell_ptr)), vmulq_f32(_I, _G));
            float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2));

            vst1q_f32(cell_ptr, _cell2);
            vst1q_f32(hidden_ptr, _H);
            vst1_u16(output_data, vcvt_bf16_f32(_H));

            cell_ptr += 4;
            hidden_ptr += 4;
            output_data += 4;
        }
#endif // __ARM_NEON
        for (; q < num_output; q++)
        {
            const float* gates_data = gates.row(q);

            float I = gates_data[0];
            float F = gates_data[1];
            float O = gates_data[2];
            float G = gates_data[3];

            I = 1.f / (1.f + exp(-I));
            F = 1.f / (1.f + exp(-F));
            O = 1.f / (1.f + exp(-O));
            G = tanh(G);

            float cell2 = F * *cell_ptr + I * G;
            float H = O * tanh(cell2);

            *cell_ptr++ = cell2;
            *hidden_ptr++ = H;
            *output_data++ = float32_to_bfloat16(H);
        }
    }

    return 0;
}

int LSTM_arm::create_pipeline_bf16s(const Option& opt)
{
    // pack IFOG
    int num_directions = direction == 2 ? 2 : 1;
    int size = weight_data_size / num_directions / num_output / 4;

    weight_xc_data_packed.create(size, num_output, num_directions, 8u, 4);
    bias_c_data_packed.create(num_output, 1, num_directions, 8u, 4);
    weight_hc_data_packed.create(num_output, num_output, num_directions, 8u, 4);

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int dr = 0; dr < num_directions; dr++)
    {
        const Mat weight_xc = weight_xc_data.channel(dr);
        const Mat bias_c = bias_c_data.channel(dr);
        const Mat weight_hc = weight_hc_data.channel(dr);

        Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr);
        Mat bias_c_data_packed_dr = bias_c_data_packed.channel(dr);
        Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr);

        const float* bias_c_I = bias_c.row(0);
        const float* bias_c_F = bias_c.row(1);
        const float* bias_c_O = bias_c.row(2);
        const float* bias_c_G = bias_c.row(3);

        unsigned short* bias_c_IFOG = bias_c_data_packed_dr.row<unsigned short>(0);

        for (int q = 0; q < num_output; q++)
        {
            bias_c_IFOG[0] = float32_to_bfloat16(bias_c_I[q]);
            bias_c_IFOG[1] = float32_to_bfloat16(bias_c_F[q]);
            bias_c_IFOG[2] = float32_to_bfloat16(bias_c_O[q]);
            bias_c_IFOG[3] = float32_to_bfloat16(bias_c_G[q]);

            bias_c_IFOG += 4;

            const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
            const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
            const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
            const float* weight_xc_G = weight_xc.row(num_output * 3 + q);

            const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
            const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
            const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
            const float* weight_hc_G = weight_hc.row(num_output * 3 + q);

            unsigned short* weight_xc_IFOG = weight_xc_data_packed_dr.row<unsigned short>(q);
            unsigned short* weight_hc_IFOG = weight_hc_data_packed_dr.row<unsigned short>(q);

            for (int i = 0; i < size; i++)
            {
                weight_xc_IFOG[0] = float32_to_bfloat16(weight_xc_I[i]);
                weight_xc_IFOG[1] = float32_to_bfloat16(weight_xc_F[i]);
                weight_xc_IFOG[2] = float32_to_bfloat16(weight_xc_O[i]);
                weight_xc_IFOG[3] = float32_to_bfloat16(weight_xc_G[i]);

                weight_xc_IFOG += 4;
            }

            for (int i = 0; i < num_output; i++)
            {
                weight_hc_IFOG[0] = float32_to_bfloat16(weight_hc_I[i]);
                weight_hc_IFOG[1] = float32_to_bfloat16(weight_hc_F[i]);
                weight_hc_IFOG[2] = float32_to_bfloat16(weight_hc_O[i]);
                weight_hc_IFOG[3] = float32_to_bfloat16(weight_hc_G[i]);

                weight_hc_IFOG += 4;
            }
        }
    }

    return 0;
}

int LSTM_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int T = bottom_blob.h;

    int num_directions = direction == 2 ? 2 : 1;

    // initial hidden state
    Mat hidden(num_output, 4u, opt.workspace_allocator);
    if (hidden.empty())
        return -100;
    hidden.fill(0.f);

    Mat cell(num_output, 4u, opt.workspace_allocator);
    if (cell.empty())
        return -100;
    cell.fill(0.f);

    top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = lstm_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
        if (ret != 0)
            return ret;
    }

    if (direction == 2)
    {
        Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator);
        if (top_blob_forward.empty())
            return -100;

        Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator);
        if (top_blob_reverse.empty())
            return -100;

        int ret0 = lstm_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
        if (ret0 != 0)
            return ret0;

        hidden.fill(0.f);
        cell.fill(0.f);

        int ret1 = lstm_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, cell, opt);
        if (ret1 != 0)
            return ret1;

        // concat w
        for (int i = 0; i < T; i++)
        {
            const unsigned short* pf = top_blob_forward.row<const unsigned short>(i);
            const unsigned short* pr = top_blob_reverse.row<const unsigned short>(i);
            unsigned short* ptr = top_blob.row<unsigned short>(i);

            memcpy(ptr, pf, num_output * sizeof(unsigned short));
            memcpy(ptr + num_output, pr, num_output * sizeof(unsigned short));
        }
    }

    return 0;
}

int LSTM_arm::forward_bf16s(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
    const Mat& bottom_blob = bottom_blobs[0];
    int T = bottom_blob.h;
    Mat& top_blob = top_blobs[0];

    top_blob.create(num_output, T, 2u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // copy previous states
    Mat hidden;
    Mat cell;
    cast_bfloat16_to_float32(bottom_blobs[1], hidden, opt);
    cast_bfloat16_to_float32(bottom_blobs[2], cell, opt);

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = lstm_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
        if (ret != 0)
            return ret;
    }

    cast_float32_to_bfloat16(hidden, top_blobs[1], opt);
    cast_float32_to_bfloat16(cell, top_blobs[2], opt);

    return 0;
}

} // namespace ncnn