1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "hardswish_mips.h"
16 
17 #if __mips_msa
18 #include <msa.h>
19 #endif // __mips_msa
20 
21 #include "mips_usability.h"
22 
23 namespace ncnn {
24 
HardSwish_mips::HardSwish_mips()
{
#if __mips_msa
    // When built with MSA, the kernels below handle elempack == 4 data
    // natively, so tell the framework this layer accepts packed layout.
    support_packing = true;
#endif // __mips_msa
}
31 
forward_inplace(Mat & bottom_top_blob,const Option & opt) const32 int HardSwish_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
33 {
34     int w = bottom_top_blob.w;
35     int h = bottom_top_blob.h;
36     int channels = bottom_top_blob.c;
37     int size = w * h;
38     int elempack = bottom_top_blob.elempack;
39 
40 #if __mips_msa
41     if (elempack == 4)
42     {
43         #pragma omp parallel for num_threads(opt.num_threads)
44         for (int q = 0; q < channels; q++)
45         {
46             float* ptr = bottom_top_blob.channel(q);
47 
48             v4f32 _zero = (v4f32)__msa_fill_w(0);
49             v4f32 _one = (v4f32)__msa_fill_w_f32(1.f);
50             v4f32 _alpha = (v4f32)__msa_fill_w_f32(alpha);
51             v4f32 _beta = (v4f32)__msa_fill_w_f32(beta);
52 
53             for (int i = 0; i < size; i++)
54             {
55                 __builtin_prefetch(ptr + 32);
56                 v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
57                 v4f32 _outp = __msa_fmadd_w(_beta, _p, _alpha);
58                 _outp = __msa_fmax_w(_outp, _zero);
59                 _outp = __msa_fmin_w(_outp, _one);
60                 _outp = __msa_fmul_w(_outp, _p);
61                 __msa_st_w((v4i32)_outp, ptr, 0);
62 
63                 ptr += 4;
64             }
65         }
66 
67         return 0;
68     }
69 #endif // __mips_msa
70 
71     #pragma omp parallel for num_threads(opt.num_threads)
72     for (int q = 0; q < channels; q++)
73     {
74         float* ptr = bottom_top_blob.channel(q);
75 
76         int i = 0;
77 #if __mips_msa
78         v4f32 _zero = (v4f32)__msa_fill_w(0);
79         v4f32 _one = (v4f32)__msa_fill_w_f32(1.f);
80         v4f32 _alpha = (v4f32)__msa_fill_w_f32(alpha);
81         v4f32 _beta = (v4f32)__msa_fill_w_f32(beta);
82 
83         for (; i + 3 < size; i += 4)
84         {
85             __builtin_prefetch(ptr + 32);
86             v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
87             v4f32 _outp = __msa_fmadd_w(_beta, _p, _alpha);
88             _outp = __msa_fmax_w(_outp, _zero);
89             _outp = __msa_fmin_w(_outp, _one);
90             _outp = __msa_fmul_w(_outp, _p);
91             __msa_st_w((v4i32)_outp, ptr, 0);
92 
93             ptr += 4;
94         }
95 #endif // __mips_msa
96         for (; i < size; i++)
97         {
98             if (*ptr < lower)
99                 *ptr = 0.f;
100             else if (*ptr > upper)
101                 ;
102             else
103                 *ptr = *ptr * (*ptr * alpha + beta);
104             ++ptr;
105         }
106     }
107 
108     return 0;
109 }
110 
111 } // namespace ncnn
112