1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
convdw3x3s1_sse(const Mat & bottom_blob,Mat & top_blob,const Mat & _kernel,const Mat & _bias,const Option & opt)15 static void convdw3x3s1_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Mat& _bias, const Option& opt)
16 {
17     int w = bottom_blob.w;
18 
19     int outw = top_blob.w;
20     int outh = top_blob.h;
21 
22     const int group = bottom_blob.c;
23 
24     const float* kernel = _kernel;
25     const float* bias = _bias;
26 
27     #pragma omp parallel for num_threads(opt.num_threads)
28     for (int g = 0; g < group; g++)
29     {
30         Mat out = top_blob.channel(g);
31 
32         const float bias0 = bias ? bias[g] : 0.f;
33 
34         const float* kernel0 = kernel + g * 9;
35 
36         float* outptr = out;
37         float* outptr2 = outptr + outw;
38 
39         const float* img0 = bottom_blob.channel(g);
40 
41         const float* r0 = img0;
42         const float* r1 = img0 + w;
43         const float* r2 = img0 + w * 2;
44         const float* r3 = img0 + w * 3;
45 
46         const float* k0 = kernel0;
47         const float* k1 = kernel0 + 3;
48         const float* k2 = kernel0 + 6;
49 
50         int i = 0;
51 
52         for (; i + 1 < outh; i += 2)
53         {
54             int remain = outw;
55 
56             for (; remain > 0; remain--)
57             {
58                 float sum = bias0;
59                 sum += r0[0] * k0[0];
60                 sum += r0[1] * k0[1];
61                 sum += r0[2] * k0[2];
62                 sum += r1[0] * k1[0];
63                 sum += r1[1] * k1[1];
64                 sum += r1[2] * k1[2];
65                 sum += r2[0] * k2[0];
66                 sum += r2[1] * k2[1];
67                 sum += r2[2] * k2[2];
68 
69                 float sum2 = bias0;
70                 sum2 += r1[0] * k0[0];
71                 sum2 += r1[1] * k0[1];
72                 sum2 += r1[2] * k0[2];
73                 sum2 += r2[0] * k1[0];
74                 sum2 += r2[1] * k1[1];
75                 sum2 += r2[2] * k1[2];
76                 sum2 += r3[0] * k2[0];
77                 sum2 += r3[1] * k2[1];
78                 sum2 += r3[2] * k2[2];
79 
80                 *outptr = sum;
81                 *outptr2 = sum2;
82 
83                 r0++;
84                 r1++;
85                 r2++;
86                 r3++;
87                 outptr++;
88                 outptr2++;
89             }
90 
91             r0 += 2 + w;
92             r1 += 2 + w;
93             r2 += 2 + w;
94             r3 += 2 + w;
95 
96             outptr += outw;
97             outptr2 += outw;
98         }
99 
100         for (; i < outh; i++)
101         {
102             int remain = outw;
103 
104             for (; remain > 0; remain--)
105             {
106                 float sum = bias0;
107                 sum += r0[0] * k0[0];
108                 sum += r0[1] * k0[1];
109                 sum += r0[2] * k0[2];
110                 sum += r1[0] * k1[0];
111                 sum += r1[1] * k1[1];
112                 sum += r1[2] * k1[2];
113                 sum += r2[0] * k2[0];
114                 sum += r2[1] * k2[1];
115                 sum += r2[2] * k2[2];
116 
117                 *outptr = sum;
118 
119                 r0++;
120                 r1++;
121                 r2++;
122                 outptr++;
123             }
124 
125             r0 += 2;
126             r1 += 2;
127             r2 += 2;
128         }
129     }
130 }
131 
convdw3x3s2_sse(const Mat & bottom_blob,Mat & top_blob,const Mat & _kernel,const Mat & _bias,const Option & opt)132 static void convdw3x3s2_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Mat& _bias, const Option& opt)
133 {
134     int w = bottom_blob.w;
135 
136     int outw = top_blob.w;
137     int outh = top_blob.h;
138 
139     const int group = bottom_blob.c;
140 
141     const int tailstep = w - 2 * outw + w;
142 
143     const float* kernel = _kernel;
144     const float* bias = _bias;
145 
146     #pragma omp parallel for num_threads(opt.num_threads)
147     for (int g = 0; g < group; g++)
148     {
149         Mat out = top_blob.channel(g);
150 
151         const float bias0 = bias ? bias[g] : 0.f;
152 
153         const float* kernel0 = kernel + g * 9;
154 
155         float* outptr = out;
156 
157         const float* img0 = bottom_blob.channel(g);
158 
159         const float* r0 = img0;
160         const float* r1 = img0 + w;
161         const float* r2 = img0 + w * 2;
162 
163         const float* k0 = kernel0;
164         const float* k1 = kernel0 + 3;
165         const float* k2 = kernel0 + 6;
166 
167         int i = 0;
168 
169         for (; i < outh; i++)
170         {
171             int remain = outw;
172 
173             for (; remain > 0; remain--)
174             {
175                 float sum = bias0;
176                 sum += r0[0] * k0[0];
177                 sum += r0[1] * k0[1];
178                 sum += r0[2] * k0[2];
179                 sum += r1[0] * k1[0];
180                 sum += r1[1] * k1[1];
181                 sum += r1[2] * k1[2];
182                 sum += r2[0] * k2[0];
183                 sum += r2[1] * k2[1];
184                 sum += r2[2] * k2[2];
185 
186                 *outptr = sum;
187 
188                 r0 += 2;
189                 r1 += 2;
190                 r2 += 2;
191                 outptr++;
192             }
193 
194             r0 += tailstep;
195             r1 += tailstep;
196             r2 += tailstep;
197         }
198     }
199 }
200