1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #ifndef RISCV_ACTIVATION_H
16 #define RISCV_ACTIVATION_H
17 
18 #include "fused_activation.h"
19 
20 #if __riscv_vector
21 #ifdef RVV_SPEC_0_7
22 #include "riscv_v_071_fix.h"
23 #else
24 #include <riscv_vector.h>
25 #endif
26 #include "rvv_mathfun.h"
27 #include "rvv_mathfun_fp16s.h"
28 
// Generates one overload of activation_ps() for the vector type vfloat<SEW>m<LMUL>_t.
//   SEW  - element width in bits (instantiated with 16 and 32)
//   LMUL - register group multiplier (instantiated with 1, 2, 4 and 8)
//   MLEN - SEW / LMUL, selecting the matching vbool<MLEN>_t mask type
//
// The generated function applies the fused activation chosen by activation_type to _v
// (operating on vl elements) and returns the result:
//   1 = relu       max(x, 0)
//   2 = leakyrelu  x <= 0 ? x * activation_params[0] : x
//   3 = clip       min(max(x, activation_params[0]), activation_params[1])
//   4 = sigmoid    sigmoid_ps(x)
//   5 = mish       x * tanh(log(exp(x) + 1))
//   6 = hardswish  x < lower ? 0 : (x > upper ? x : x * (alpha * x + beta)),
//                  alpha = activation_params[0], beta = activation_params[1],
//                  lower = -beta / alpha, upper = 1 / alpha + lower
// Any other activation_type returns _v unchanged.
// Comments inside the macro body are written as /* */ because a // comment would
// absorb the following backslash-continued line.
#define _RVV_FLOAT_ACTIVATION_PS(SEW, LMUL, MLEN)                                                                                                            \
    static inline vfloat##SEW##m##LMUL##_t activation_ps(vfloat##SEW##m##LMUL##_t _v, int activation_type, const ncnn::Mat& activation_params, word_type vl) \
    {                                                                                                                                                        \
        switch (activation_type)                                                                                                                             \
        {                                                                                                                                                    \
        case 1: /* relu: max(x, 0) */                                                                                                                        \
            return vfmax_vf_f##SEW##m##LMUL(_v, 0.f, vl);                                                                                                    \
        case 2: /* leakyrelu: lanes <= 0 are scaled by activation_params[0] */                                                                               \
        {                                                                                                                                                    \
            vbool##MLEN##_t _nonpos = vmfle_vf_f##SEW##m##LMUL##_b##MLEN(_v, 0.f, vl);                                                                       \
            return vfmul_vf_f##SEW##m##LMUL##_m(_nonpos, _v, _v, activation_params[0], vl);                                                                  \
        }                                                                                                                                                    \
        case 3: /* clip to [activation_params[0], activation_params[1]] */                                                                                   \
        {                                                                                                                                                    \
            vfloat##SEW##m##LMUL##_t _floored = vfmax_vf_f##SEW##m##LMUL(_v, activation_params[0], vl);                                                      \
            return vfmin_vf_f##SEW##m##LMUL(_floored, activation_params[1], vl);                                                                             \
        }                                                                                                                                                    \
        case 4: /* sigmoid */                                                                                                                                \
            return sigmoid_ps(_v, vl);                                                                                                                       \
        case 5: /* mish: x * tanh(log(exp(x) + 1)) */                                                                                                        \
        {                                                                                                                                                    \
            vfloat##SEW##m##LMUL##_t _softplus = log_ps(vfadd_vf_f##SEW##m##LMUL(exp_ps(_v, vl), 1.f, vl), vl);                                              \
            return vfmul_vv_f##SEW##m##LMUL(_v, tanh_ps(_softplus, vl), vl);                                                                                 \
        }                                                                                                                                                    \
        case 6: /* hardswish */                                                                                                                              \
        {                                                                                                                                                    \
            const float alpha = activation_params[0];                                                                                                        \
            const float beta = activation_params[1];                                                                                                         \
            const float lower = -beta / alpha;                                                                                                               \
            const float upper = (1.f / alpha) + lower;                                                                                                       \
            /* lanes < lower -> 0, lanes > upper -> x, others -> x * (alpha * x + beta) */                                                                   \
            vbool##MLEN##_t _below = vmflt_vf_f##SEW##m##LMUL##_b##MLEN(_v, lower, vl);                                                                      \
            vbool##MLEN##_t _above = vmfgt_vf_f##SEW##m##LMUL##_b##MLEN(_v, upper, vl);                                                                      \
            vbool##MLEN##_t _ramp = vmnor_mm_b##MLEN(_below, _above, vl);                                                                                    \
            vfloat##SEW##m##LMUL##_t _x = vfmerge_vfm_f##SEW##m##LMUL(_below, _v, .0f, vl);                                                                  \
            vfloat##SEW##m##LMUL##_t _ax = vfmul_vf_f##SEW##m##LMUL##_m(_ramp, _x, _x, alpha, vl);                                                           \
            vfloat##SEW##m##LMUL##_t _axb = vfadd_vf_f##SEW##m##LMUL##_m(_ramp, _x, _ax, beta, vl);                                                          \
            return vfmul_vv_f##SEW##m##LMUL##_m(_ramp, _x, _x, _axb, vl);                                                                                    \
        }                                                                                                                                                    \
        default: /* any other type: identity */                                                                                                              \
            break;                                                                                                                                           \
        }                                                                                                                                                    \
        return _v;                                                                                                                                           \
    }
73 
74 _RVV_FLOAT_ACTIVATION_PS(16, 1, 16)
75 _RVV_FLOAT_ACTIVATION_PS(16, 2, 8)
76 _RVV_FLOAT_ACTIVATION_PS(16, 4, 4)
77 _RVV_FLOAT_ACTIVATION_PS(16, 8, 2)
78 _RVV_FLOAT_ACTIVATION_PS(32, 1, 32)
79 _RVV_FLOAT_ACTIVATION_PS(32, 2, 16)
80 _RVV_FLOAT_ACTIVATION_PS(32, 4, 8)
81 _RVV_FLOAT_ACTIVATION_PS(32, 8, 4)
82 
83 #endif // __riscv_vector
84 
85 #endif // RISCV_ACTIVATION_H
86