// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "tanh_riscv.h"

#if __riscv_vector
#ifdef RVV_SPEC_0_7
#include "riscv_v_071_fix.h"
#else
#include <riscv_vector.h>
#endif
#include "rvv_mathfun.h"
#include "rvv_mathfun_fp16s.h"
#endif // __riscv_vector

#include <math.h>

namespace ncnn {

TanH_riscv::TanH_riscv()
{
#if __riscv_vector
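    // tanh is applied lane-wise, so packed layouts (elempack > 1) can be
    // processed as-is, as one flat run of elements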
    support_packing = true;
#if __riscv_zfh
    support_fp16_storage = true;
#endif
#endif // __riscv_vector
}

int TanH_riscv::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
    int elembits = bottom_top_blob.elembits();

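    // blobs stored as fp16 take the fp16 code paths: fp16sa also computes in
    // fp16, while fp16s widens to fp32 for the tanh evaluation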
#if __riscv_vector && __riscv_zfh
    if (opt.use_fp16_storage && elembits == 16)
    {
        if (opt.use_fp16_arithmetic)
            return forward_inplace_fp16sa(bottom_top_blob, opt);
        else
            return forward_inplace_fp16s(bottom_top_blob, opt);
    }
#endif

    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int channels = bottom_top_blob.c;
    int size = w * h;
    int elempack = bottom_top_blob.elempack;

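    // channels are independent, so each one is handled by a worker thread as
    // a contiguous run of size * elempack elements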
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channels; q++)
    {
        float* ptr = bottom_top_blob.channel(q);

#if __riscv_vector
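        // strip-mined RVV loop: vsetvl_e32m8 returns how many 32-bit lanes
        // this iteration covers (at most VLMAX for LMUL=8), so the tail needs
        // no special casing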
        int n = size * elempack;
        while (n > 0)
        {
            word_type vl = vsetvl_e32m8(n);

            vfloat32m8_t _p = vle32_v_f32m8(ptr, vl);
            _p = tanh_ps(_p, vl);
            vse32_v_f32m8(ptr, _p, vl);

            ptr += vl;
            n -= vl;
        }
#else // __riscv_vector
        // packing is only enabled with vector support, so elempack is 1 here
        for (int i = 0; i < size; i++)
        {
            *ptr = tanhf(*ptr);
            ptr++;
        }
#endif // __riscv_vector
    }

    return 0;
}

#if __riscv_vector && __riscv_zfh
int TanH_riscv::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const
{
    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int channels = bottom_top_blob.c;
    int size = w * h;
    int elempack = bottom_top_blob.elempack;

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channels; q++)
    {
        __fp16* ptr = bottom_top_blob.channel(q);

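        // fp16 storage with fp32 arithmetic: widen on load (LMUL doubles from
        // m4 to m8), evaluate tanh in fp32, then narrow back to fp16 on store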
        int n = size * elempack;
        while (n > 0)
        {
            word_type vl = vsetvl_e16m4(n);

            vfloat32m8_t _p = vfwcvt_f_f_v_f32m8(vle16_v_f16m4(ptr, vl), vl);
            _p = tanh_ps(_p, vl);
            vse16_v_f16m4(ptr, vfncvt_f_f_w_f16m4(_p, vl), vl);

            ptr += vl;
            n -= vl;
        }
    }

    return 0;
}

int TanH_riscv::forward_inplace_fp16sa(Mat& bottom_top_blob, const Option& opt) const
{
    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int channels = bottom_top_blob.c;
    int size = w * h;
    int elempack = bottom_top_blob.elempack;

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channels; q++)
    {
        __fp16* ptr = bottom_top_blob.channel(q);

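        // fp16 storage and fp16 arithmetic: tanh is evaluated directly on
        // fp16 lanes, trading a little precision for throughput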
        int n = size * elempack;
        while (n > 0)
        {
            word_type vl = vsetvl_e16m8(n);

            vfloat16m8_t _p = vle16_v_f16m8(ptr, vl);
            _p = tanh_ps(_p, vl);
            vse16_v_f16m8(ptr, _p, vl);

            ptr += vl;
            n -= vl;
        }
    }

    return 0;
}
#endif // __riscv_vector && __riscv_zfh

} // namespace ncnn