// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "swish_riscv.h"

#if __riscv_vector
#ifdef RVV_SPEC_0_7
#include "riscv_v_071_fix.h"
#else
#include <riscv_vector.h>
#endif
#include "rvv_mathfun.h"
#include "rvv_mathfun_fp16s.h"
#endif // __riscv_vector

#include <math.h>

namespace ncnn {

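// Swish activation applied in place: y = x * sigmoid(x) = x / (1 + exp(-x)),
// also known as SiLU. The op is elementwise, so packed layouts need no
// special handling; the constructor advertises packing support (and fp16
// storage when the zfh extension is available).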
Swish_riscv::Swish_riscv()
{
#if __riscv_vector
    support_packing = true;
#if __riscv_zfh
    support_fp16_storage = true;
#endif
#endif // __riscv_vector
}

int Swish_riscv::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
    int elembits = bottom_top_blob.elembits();

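    // With 16-bit storage, dispatch to the fp16 kernels below: fp16s widens
    // to fp32 for the math, fp16sa keeps the arithmetic in fp16 as well.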
#if __riscv_vector && __riscv_zfh
    if (opt.use_fp16_storage && elembits == 16)
    {
        if (opt.use_fp16_arithmetic)
            return forward_inplace_fp16sa(bottom_top_blob, opt);
        else
            return forward_inplace_fp16s(bottom_top_blob, opt);
    }
#endif

    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int channels = bottom_top_blob.c;
    int size = w * h;
    int elempack = bottom_top_blob.elempack;

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channels; q++)
    {
        float* ptr = bottom_top_blob.channel(q);

#if __riscv_vector
        int n = size * elempack;
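        // Strip-mined RVV loop: vsetvl_e32m8() returns how many of the
        // remaining n elements fit this iteration, so no separate tail
        // handling is needed. Each lane computes x / (1 + exp(-x)).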
        while (n > 0)
        {
            word_type vl = vsetvl_e32m8(n);

            vfloat32m8_t _p = vle32_v_f32m8(ptr, vl);
            _p = vfdiv_vv_f32m8(_p, vfadd_vf_f32m8(exp_ps(vfneg_v_f32m8(_p, vl), vl), 1.f, vl), vl);
            vse32_v_f32m8(ptr, _p, vl);

            ptr += vl;
            n -= vl;
        }
#else  // __riscv_vector
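        // Scalar fallback: same swish formula, one element at a time.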
        for (int i = 0; i < size; i++)
        {
            *ptr = *ptr / (1.f + exp(-*ptr));
            ptr++;
        }
#endif // __riscv_vector
    }

    return 0;
}

#if __riscv_vector && __riscv_zfh
int Swish_riscv::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const
{
    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int channels = bottom_top_blob.c;
    int size = w * h;
    int elempack = bottom_top_blob.elempack;

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channels; q++)
    {
        __fp16* ptr = bottom_top_blob.channel(q);

        int n = size * elempack;
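        // fp16 storage with fp32 arithmetic: widen each half-precision
        // batch to fp32 (m4 -> m8), apply swish, then narrow back to
        // fp16 on store.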
        while (n > 0)
        {
            word_type vl = vsetvl_e16m4(n);

            vfloat32m8_t _p = vfwcvt_f_f_v_f32m8(vle16_v_f16m4(ptr, vl), vl);
            _p = vfdiv_vv_f32m8(_p, vfadd_vf_f32m8(exp_ps(vfneg_v_f32m8(_p, vl), vl), 1.f, vl), vl);
            vse16_v_f16m4(ptr, vfncvt_f_f_w_f16m4(_p, vl), vl);

            ptr += vl;
            n -= vl;
        }
    }

    return 0;
}

int Swish_riscv::forward_inplace_fp16sa(Mat& bottom_top_blob, const Option& opt) const
{
    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int channels = bottom_top_blob.c;
    int size = w * h;
    int elempack = bottom_top_blob.elempack;

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channels; q++)
    {
        __fp16* ptr = bottom_top_blob.channel(q);

        int n = size * elempack;
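        // Full fp16 arithmetic: e16m8 covers twice as many elements per
        // iteration as the widening m4 path, using the fp16 exp_ps()
        // overload from rvv_mathfun_fp16s.h.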
        while (n > 0)
        {
            word_type vl = vsetvl_e16m8(n);

            vfloat16m8_t _p = vle16_v_f16m8(ptr, vl);
            _p = vfdiv_vv_f16m8(_p, vfadd_vf_f16m8(exp_ps(vfneg_v_f16m8(_p, vl), vl), 1.f, vl), vl);
            vse16_v_f16m8(ptr, _p, vl);

            ptr += vl;
            n -= vl;
        }
    }

    return 0;
}
#endif // __riscv_vector && __riscv_zfh

} // namespace ncnn