1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "sigmoid_riscv.h"
16 
17 #if __riscv_vector
18 #ifdef RVV_SPEC_0_7
19 #include "riscv_v_071_fix.h"
20 #else
21 #include <riscv_vector.h>
22 #endif
23 #include "rvv_mathfun.h"
24 #include "rvv_mathfun_fp16s.h"
25 #endif // __riscv_vector
26 
27 #include <math.h>
28 
29 namespace ncnn {
30 
// Constructor: advertise layer capabilities based on the build target.
Sigmoid_riscv::Sigmoid_riscv()
{
#if __riscv_vector
    // Vector builds can process packed (elempack > 1) layouts directly,
    // since the vector loop walks size * elempack contiguous elements.
    support_packing = true;
#if __riscv_zfh
    // The zfh extension provides native fp16 loads/stores, enabling the
    // fp16 storage code paths (forward_inplace_fp16s / _fp16sa).
    support_fp16_storage = true;
#endif
#endif // __riscv_vector
}
40 
forward_inplace(Mat & bottom_top_blob,const Option & opt) const41 int Sigmoid_riscv::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
42 {
43     int elembits = bottom_top_blob.elembits();
44 
45 #if __riscv_vector && __riscv_zfh
46     if (opt.use_fp16_storage && elembits == 16)
47     {
48         if (opt.use_fp16_arithmetic)
49             return forward_inplace_fp16sa(bottom_top_blob, opt);
50         else
51             return forward_inplace_fp16s(bottom_top_blob, opt);
52     }
53 #endif
54 
55     int w = bottom_top_blob.w;
56     int h = bottom_top_blob.h;
57     int channels = bottom_top_blob.c;
58     int size = w * h;
59     int elempack = bottom_top_blob.elempack;
60 
61     #pragma omp parallel for num_threads(opt.num_threads)
62     for (int q = 0; q < channels; q++)
63     {
64         float* ptr = bottom_top_blob.channel(q);
65 
66 #if __riscv_vector
67         int n = size * elempack;
68         while (n > 0)
69         {
70             word_type vl = vsetvl_e32m8(n);
71 
72             vfloat32m8_t _p = vle32_v_f32m8(ptr, vl);
73             _p = sigmoid_ps(_p, vl);
74             vse32_v_f32m8(ptr, _p, vl);
75 
76             ptr += vl;
77             n -= vl;
78         }
79 #else  // __riscv_vector
80         for (int i = 0; i < size; i++)
81         {
82             *ptr = 1.f / (1.f + exp(-*ptr));
83 
84             ptr++;
85         }
86 #endif // __riscv_vector
87     }
88 
89     return 0;
90 }
91 
92 #if __riscv_vector && __riscv_zfh
// fp16 storage / fp32 arithmetic variant: elements are stored as __fp16,
// but each batch is widened to fp32 (vfwcvt), run through the fp32
// sigmoid_ps, then narrowed back to fp16 (vfncvt) before storing.
int Sigmoid_riscv::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const
{
    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int channels = bottom_top_blob.c;
    int size = w * h;
    int elempack = bottom_top_blob.elempack;

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channels; q++)
    {
        __fp16* ptr = bottom_top_blob.channel(q);

        // Strip-mined loop over all elements, vl lanes per pass.
        // e16m4 is used for vl so the widened fp32 values fit in m8 registers.
        int n = size * elempack;
        while (n > 0)
        {
            word_type vl = vsetvl_e16m4(n);

            // load fp16 -> widen to fp32 -> sigmoid -> narrow to fp16 -> store
            vfloat32m8_t _p = vfwcvt_f_f_v_f32m8(vle16_v_f16m4(ptr, vl), vl);
            _p = sigmoid_ps(_p, vl);
            vse16_v_f16m4(ptr, vfncvt_f_f_w_f16m4(_p, vl), vl);

            ptr += vl;
            n -= vl;
        }
    }

    return 0;
}
122 
// fp16 storage / fp16 arithmetic variant: the whole sigmoid is computed in
// half precision (fp16 sigmoid_ps overload from rvv_mathfun_fp16s.h), so
// m8 register grouping can be used for the full 16-bit element width.
int Sigmoid_riscv::forward_inplace_fp16sa(Mat& bottom_top_blob, const Option& opt) const
{
    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int channels = bottom_top_blob.c;
    int size = w * h;
    int elempack = bottom_top_blob.elempack;

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channels; q++)
    {
        __fp16* ptr = bottom_top_blob.channel(q);

        // Strip-mined loop over all elements, vl lanes per pass.
        int n = size * elempack;
        while (n > 0)
        {
            word_type vl = vsetvl_e16m8(n);

            vfloat16m8_t _p = vle16_v_f16m8(ptr, vl);
            _p = sigmoid_ps(_p, vl);
            vse16_v_f16m8(ptr, _p, vl);

            ptr += vl;
            n -= vl;
        }
    }

    return 0;
}
152 #endif // __riscv_vector && __riscv_zfh
153 
154 } // namespace ncnn
155