1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
15 #include "sigmoid_riscv.h"
16
17 #if __riscv_vector
18 #ifdef RVV_SPEC_0_7
19 #include "riscv_v_071_fix.h"
20 #else
21 #include <riscv_vector.h>
22 #endif
23 #include "rvv_mathfun.h"
24 #include "rvv_mathfun_fp16s.h"
25 #endif // __riscv_vector
26
27 #include <math.h>
28
29 namespace ncnn {
30
// Declare the optional capabilities of this layer implementation:
// packed element layout when RVV is available, and fp16 storage
// when the zfh half-precision extension is also compiled in.
Sigmoid_riscv::Sigmoid_riscv()
{
#if __riscv_vector
    support_packing = true;
#if __riscv_zfh
    support_fp16_storage = true;
#endif
#endif // __riscv_vector
}
40
forward_inplace(Mat & bottom_top_blob,const Option & opt) const41 int Sigmoid_riscv::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
42 {
43 int elembits = bottom_top_blob.elembits();
44
45 #if __riscv_vector && __riscv_zfh
46 if (opt.use_fp16_storage && elembits == 16)
47 {
48 if (opt.use_fp16_arithmetic)
49 return forward_inplace_fp16sa(bottom_top_blob, opt);
50 else
51 return forward_inplace_fp16s(bottom_top_blob, opt);
52 }
53 #endif
54
55 int w = bottom_top_blob.w;
56 int h = bottom_top_blob.h;
57 int channels = bottom_top_blob.c;
58 int size = w * h;
59 int elempack = bottom_top_blob.elempack;
60
61 #pragma omp parallel for num_threads(opt.num_threads)
62 for (int q = 0; q < channels; q++)
63 {
64 float* ptr = bottom_top_blob.channel(q);
65
66 #if __riscv_vector
67 int n = size * elempack;
68 while (n > 0)
69 {
70 word_type vl = vsetvl_e32m8(n);
71
72 vfloat32m8_t _p = vle32_v_f32m8(ptr, vl);
73 _p = sigmoid_ps(_p, vl);
74 vse32_v_f32m8(ptr, _p, vl);
75
76 ptr += vl;
77 n -= vl;
78 }
79 #else // __riscv_vector
80 for (int i = 0; i < size; i++)
81 {
82 *ptr = 1.f / (1.f + exp(-*ptr));
83
84 ptr++;
85 }
86 #endif // __riscv_vector
87 }
88
89 return 0;
90 }
91
92 #if __riscv_vector && __riscv_zfh
forward_inplace_fp16s(Mat & bottom_top_blob,const Option & opt) const93 int Sigmoid_riscv::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const
94 {
95 int w = bottom_top_blob.w;
96 int h = bottom_top_blob.h;
97 int channels = bottom_top_blob.c;
98 int size = w * h;
99 int elempack = bottom_top_blob.elempack;
100
101 #pragma omp parallel for num_threads(opt.num_threads)
102 for (int q = 0; q < channels; q++)
103 {
104 __fp16* ptr = bottom_top_blob.channel(q);
105
106 int n = size * elempack;
107 while (n > 0)
108 {
109 word_type vl = vsetvl_e16m4(n);
110
111 vfloat32m8_t _p = vfwcvt_f_f_v_f32m8(vle16_v_f16m4(ptr, vl), vl);
112 _p = sigmoid_ps(_p, vl);
113 vse16_v_f16m4(ptr, vfncvt_f_f_w_f16m4(_p, vl), vl);
114
115 ptr += vl;
116 n -= vl;
117 }
118 }
119
120 return 0;
121 }
122
forward_inplace_fp16sa(Mat & bottom_top_blob,const Option & opt) const123 int Sigmoid_riscv::forward_inplace_fp16sa(Mat& bottom_top_blob, const Option& opt) const
124 {
125 int w = bottom_top_blob.w;
126 int h = bottom_top_blob.h;
127 int channels = bottom_top_blob.c;
128 int size = w * h;
129 int elempack = bottom_top_blob.elempack;
130
131 #pragma omp parallel for num_threads(opt.num_threads)
132 for (int q = 0; q < channels; q++)
133 {
134 __fp16* ptr = bottom_top_blob.channel(q);
135
136 int n = size * elempack;
137 while (n > 0)
138 {
139 word_type vl = vsetvl_e16m8(n);
140
141 vfloat16m8_t _p = vle16_v_f16m8(ptr, vl);
142 _p = sigmoid_ps(_p, vl);
143 vse16_v_f16m8(ptr, _p, vl);
144
145 ptr += vl;
146 n -= vl;
147 }
148 }
149
150 return 0;
151 }
152 #endif // __riscv_vector && __riscv_zfh
153
154 } // namespace ncnn
155