1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
conv1x1s1_sgemm_transform_kernel_pack4_sse(const Mat & kernel,Mat & kernel_pack4,int inch,int outch)15 static void conv1x1s1_sgemm_transform_kernel_pack4_sse(const Mat& kernel, Mat& kernel_pack4, int inch, int outch)
16 {
17 // interleave
18 // src = inch-outch
19 // dst = 4b-4a-inch/4a-outch/4b
20 kernel_pack4.create(1, inch / 4, outch / 4, (size_t)4u * 16, 16);
21
22 int q = 0;
23 for (; q + 3 < outch; q += 4)
24 {
25 const float* k0 = (const float*)kernel + (q + 0) * inch;
26 const float* k1 = (const float*)kernel + (q + 1) * inch;
27 const float* k2 = (const float*)kernel + (q + 2) * inch;
28 const float* k3 = (const float*)kernel + (q + 3) * inch;
29
30 float* g0 = kernel_pack4.channel(q / 4);
31
32 for (int p = 0; p + 3 < inch; p += 4)
33 {
34 g0[0] = k0[0];
35 g0[1] = k1[0];
36 g0[2] = k2[0];
37 g0[3] = k3[0];
38
39 g0[4] = k0[1];
40 g0[5] = k1[1];
41 g0[6] = k2[1];
42 g0[7] = k3[1];
43
44 g0[8] = k0[2];
45 g0[9] = k1[2];
46 g0[10] = k2[2];
47 g0[11] = k3[2];
48
49 g0[12] = k0[3];
50 g0[13] = k1[3];
51 g0[14] = k2[3];
52 g0[15] = k3[3];
53
54 k0 += 4;
55 k1 += 4;
56 k2 += 4;
57 k3 += 4;
58 g0 += 16;
59 }
60 }
61 }
62
conv1x1s1_sgemm_pack4_sse(const Mat & bottom_blob,Mat & top_blob,const Mat & kernel,const Mat & _bias,const Option & opt)63 static void conv1x1s1_sgemm_pack4_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
64 {
65 int w = bottom_blob.w;
66 int h = bottom_blob.h;
67 int inch = bottom_blob.c;
68 int outch = top_blob.c;
69
70 size_t elemsize = bottom_blob.elemsize;
71 int elempack = bottom_blob.elempack;
72
73 const int size = w * h;
74
75 const float* bias = _bias;
76
77 // interleave
78 Mat tmp(4, inch, size / 4 + (size % 4) / 2 + size % 2, elemsize, elempack, opt.workspace_allocator);
79 {
80 int nn_size;
81 int remain_size_start;
82
83 remain_size_start = 0;
84 nn_size = (size - remain_size_start) >> 2;
85
86 #pragma omp parallel for num_threads(opt.num_threads)
87 for (int ii = 0; ii < nn_size; ii++)
88 {
89 int i = remain_size_start + ii * 4;
90
91 const float* img0 = bottom_blob.channel(0);
92 img0 += i * 4;
93
94 float* tmpptr = tmp.channel(i / 4);
95
96 for (int q = 0; q < inch; q++)
97 {
98 __m128 _r0 = _mm_loadu_ps(img0);
99 __m128 _r1 = _mm_loadu_ps(img0 + 4);
100 __m128 _r2 = _mm_loadu_ps(img0 + 8);
101 __m128 _r3 = _mm_loadu_ps(img0 + 12);
102 _mm_storeu_ps(tmpptr, _r0);
103 _mm_storeu_ps(tmpptr + 4, _r1);
104 _mm_storeu_ps(tmpptr + 8, _r2);
105 _mm_storeu_ps(tmpptr + 12, _r3);
106
107 tmpptr += 16;
108 img0 += bottom_blob.cstep * 4;
109 }
110 }
111
112 remain_size_start += nn_size << 2;
113 nn_size = (size - remain_size_start) >> 1;
114
115 #pragma omp parallel for num_threads(opt.num_threads)
116 for (int ii = 0; ii < nn_size; ii++)
117 {
118 int i = remain_size_start + ii * 2;
119
120 const float* img0 = bottom_blob.channel(0);
121 img0 += i * 4;
122
123 float* tmpptr = tmp.channel(i / 4 + (i % 4) / 2);
124
125 for (int q = 0; q < inch; q++)
126 {
127 __m128 _r0 = _mm_loadu_ps(img0);
128 __m128 _r1 = _mm_loadu_ps(img0 + 4);
129 _mm_storeu_ps(tmpptr, _r0);
130 _mm_storeu_ps(tmpptr + 4, _r1);
131
132 tmpptr += 8;
133 img0 += bottom_blob.cstep * 4;
134 }
135 }
136
137 remain_size_start += nn_size << 1;
138
139 #pragma omp parallel for num_threads(opt.num_threads)
140 for (int i = remain_size_start; i < size; i++)
141 {
142 const float* img0 = bottom_blob.channel(0);
143 img0 += i * 4;
144
145 float* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2);
146
147 for (int q = 0; q < inch; q++)
148 {
149 __m128 _r0 = _mm_loadu_ps(img0);
150 _mm_storeu_ps(tmpptr, _r0);
151
152 tmpptr += 4;
153 img0 += bottom_blob.cstep * 4;
154 }
155 }
156 }
157
158 #pragma omp parallel for num_threads(opt.num_threads)
159 for (int p = 0; p < outch; p++)
160 {
161 float* outptr0 = top_blob.channel(p);
162
163 const float zeros[4] = {0.f, 0.f, 0.f, 0.f};
164 const float* biasptr = bias ? bias + p * 4 : zeros;
165
166 int i = 0;
167 for (; i + 3 < size; i += 4)
168 {
169 float* tmpptr = tmp.channel(i / 4);
170 const float* kptr0 = (const float*)kernel.channel(p);
171
172 __m128 _sum0 = _mm_loadu_ps(biasptr);
173 __m128 _sum1 = _mm_loadu_ps(biasptr);
174 __m128 _sum2 = _mm_loadu_ps(biasptr);
175 __m128 _sum3 = _mm_loadu_ps(biasptr);
176
177 for (int q = 0; q < inch; q++)
178 {
179 __m128 _val00 = _mm_load1_ps(tmpptr);
180 __m128 _val01 = _mm_load1_ps(tmpptr + 1);
181 __m128 _val02 = _mm_load1_ps(tmpptr + 2);
182 __m128 _val03 = _mm_load1_ps(tmpptr + 3);
183 __m128 _val10 = _mm_load1_ps(tmpptr + 4);
184 __m128 _val11 = _mm_load1_ps(tmpptr + 5);
185 __m128 _val12 = _mm_load1_ps(tmpptr + 6);
186 __m128 _val13 = _mm_load1_ps(tmpptr + 7);
187 __m128 _val20 = _mm_load1_ps(tmpptr + 8);
188 __m128 _val21 = _mm_load1_ps(tmpptr + 9);
189 __m128 _val22 = _mm_load1_ps(tmpptr + 10);
190 __m128 _val23 = _mm_load1_ps(tmpptr + 11);
191 __m128 _val30 = _mm_load1_ps(tmpptr + 12);
192 __m128 _val31 = _mm_load1_ps(tmpptr + 13);
193 __m128 _val32 = _mm_load1_ps(tmpptr + 14);
194 __m128 _val33 = _mm_load1_ps(tmpptr + 15);
195
196 __m128 _w0 = _mm_load_ps(kptr0);
197 __m128 _w1 = _mm_load_ps(kptr0 + 4);
198 __m128 _w2 = _mm_load_ps(kptr0 + 8);
199 __m128 _w3 = _mm_load_ps(kptr0 + 12);
200
201 #if __AVX__
202 _sum0 = _mm_fmadd_ps(_w0, _val00, _sum0);
203 _sum0 = _mm_fmadd_ps(_w1, _val01, _sum0);
204 _sum0 = _mm_fmadd_ps(_w2, _val02, _sum0);
205 _sum0 = _mm_fmadd_ps(_w3, _val03, _sum0);
206 _sum1 = _mm_fmadd_ps(_w0, _val10, _sum1);
207 _sum1 = _mm_fmadd_ps(_w1, _val11, _sum1);
208 _sum1 = _mm_fmadd_ps(_w2, _val12, _sum1);
209 _sum1 = _mm_fmadd_ps(_w3, _val13, _sum1);
210 _sum2 = _mm_fmadd_ps(_w0, _val20, _sum2);
211 _sum2 = _mm_fmadd_ps(_w1, _val21, _sum2);
212 _sum2 = _mm_fmadd_ps(_w2, _val22, _sum2);
213 _sum2 = _mm_fmadd_ps(_w3, _val23, _sum2);
214 _sum3 = _mm_fmadd_ps(_w0, _val30, _sum3);
215 _sum3 = _mm_fmadd_ps(_w1, _val31, _sum3);
216 _sum3 = _mm_fmadd_ps(_w2, _val32, _sum3);
217 _sum3 = _mm_fmadd_ps(_w3, _val33, _sum3);
218 #else
219 _sum0 = _mm_add_ps(_mm_mul_ps(_w0, _val00), _sum0);
220 _sum0 = _mm_add_ps(_mm_mul_ps(_w1, _val01), _sum0);
221 _sum0 = _mm_add_ps(_mm_mul_ps(_w2, _val02), _sum0);
222 _sum0 = _mm_add_ps(_mm_mul_ps(_w3, _val03), _sum0);
223 _sum1 = _mm_add_ps(_mm_mul_ps(_w0, _val10), _sum1);
224 _sum1 = _mm_add_ps(_mm_mul_ps(_w1, _val11), _sum1);
225 _sum1 = _mm_add_ps(_mm_mul_ps(_w2, _val12), _sum1);
226 _sum1 = _mm_add_ps(_mm_mul_ps(_w3, _val13), _sum1);
227 _sum2 = _mm_add_ps(_mm_mul_ps(_w0, _val20), _sum2);
228 _sum2 = _mm_add_ps(_mm_mul_ps(_w1, _val21), _sum2);
229 _sum2 = _mm_add_ps(_mm_mul_ps(_w2, _val22), _sum2);
230 _sum2 = _mm_add_ps(_mm_mul_ps(_w3, _val23), _sum2);
231 _sum3 = _mm_add_ps(_mm_mul_ps(_w0, _val30), _sum3);
232 _sum3 = _mm_add_ps(_mm_mul_ps(_w1, _val31), _sum3);
233 _sum3 = _mm_add_ps(_mm_mul_ps(_w2, _val32), _sum3);
234 _sum3 = _mm_add_ps(_mm_mul_ps(_w3, _val33), _sum3);
235 #endif
236
237 tmpptr += 16;
238 kptr0 += 16;
239 }
240
241 _mm_store_ps(outptr0, _sum0);
242 _mm_store_ps(outptr0 + 4, _sum1);
243 _mm_store_ps(outptr0 + 8, _sum2);
244 _mm_store_ps(outptr0 + 12, _sum3);
245 outptr0 += 16;
246 }
247 for (; i + 1 < size; i += 2)
248 {
249 float* tmpptr = tmp.channel(i / 4 + (i % 4) / 2);
250 const float* kptr0 = (const float*)kernel.channel(p);
251
252 __m128 _sum0 = _mm_loadu_ps(biasptr);
253 __m128 _sum1 = _mm_loadu_ps(biasptr);
254
255 for (int q = 0; q < inch; q++)
256 {
257 __m128 _val00 = _mm_load1_ps(tmpptr);
258 __m128 _val01 = _mm_load1_ps(tmpptr + 1);
259 __m128 _val02 = _mm_load1_ps(tmpptr + 2);
260 __m128 _val03 = _mm_load1_ps(tmpptr + 3);
261 __m128 _val10 = _mm_load1_ps(tmpptr + 4);
262 __m128 _val11 = _mm_load1_ps(tmpptr + 5);
263 __m128 _val12 = _mm_load1_ps(tmpptr + 6);
264 __m128 _val13 = _mm_load1_ps(tmpptr + 7);
265
266 __m128 _w0 = _mm_load_ps(kptr0);
267 __m128 _w1 = _mm_load_ps(kptr0 + 4);
268 __m128 _w2 = _mm_load_ps(kptr0 + 8);
269 __m128 _w3 = _mm_load_ps(kptr0 + 12);
270
271 #if __AVX__
272 _sum0 = _mm_fmadd_ps(_w0, _val00, _sum0);
273 _sum0 = _mm_fmadd_ps(_w1, _val01, _sum0);
274 _sum0 = _mm_fmadd_ps(_w2, _val02, _sum0);
275 _sum0 = _mm_fmadd_ps(_w3, _val03, _sum0);
276 _sum1 = _mm_fmadd_ps(_w0, _val10, _sum1);
277 _sum1 = _mm_fmadd_ps(_w1, _val11, _sum1);
278 _sum1 = _mm_fmadd_ps(_w2, _val12, _sum1);
279 _sum1 = _mm_fmadd_ps(_w3, _val13, _sum1);
280 #else
281 _sum0 = _mm_add_ps(_mm_mul_ps(_w0, _val00), _sum0);
282 _sum0 = _mm_add_ps(_mm_mul_ps(_w1, _val01), _sum0);
283 _sum0 = _mm_add_ps(_mm_mul_ps(_w2, _val02), _sum0);
284 _sum0 = _mm_add_ps(_mm_mul_ps(_w3, _val03), _sum0);
285 _sum1 = _mm_add_ps(_mm_mul_ps(_w0, _val10), _sum1);
286 _sum1 = _mm_add_ps(_mm_mul_ps(_w1, _val11), _sum1);
287 _sum1 = _mm_add_ps(_mm_mul_ps(_w2, _val12), _sum1);
288 _sum1 = _mm_add_ps(_mm_mul_ps(_w3, _val13), _sum1);
289 #endif
290
291 tmpptr += 8;
292 kptr0 += 16;
293 }
294
295 _mm_store_ps(outptr0, _sum0);
296 _mm_store_ps(outptr0 + 4, _sum1);
297 outptr0 += 8;
298 }
299 for (; i < size; i++)
300 {
301 float* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2);
302 const float* kptr0 = (const float*)kernel.channel(p);
303
304 __m128 _sum = _mm_loadu_ps(biasptr);
305
306 for (int q = 0; q < inch; q++)
307 {
308 __m128 _val0 = _mm_load1_ps(tmpptr);
309 __m128 _val1 = _mm_load1_ps(tmpptr + 1);
310 __m128 _val2 = _mm_load1_ps(tmpptr + 2);
311 __m128 _val3 = _mm_load1_ps(tmpptr + 3);
312
313 __m128 _w0 = _mm_load_ps(kptr0);
314 __m128 _w1 = _mm_load_ps(kptr0 + 4);
315 __m128 _w2 = _mm_load_ps(kptr0 + 8);
316 __m128 _w3 = _mm_load_ps(kptr0 + 12);
317
318 #if __AVX__
319 _sum = _mm_fmadd_ps(_w0, _val0, _sum);
320 _sum = _mm_fmadd_ps(_w1, _val1, _sum);
321 _sum = _mm_fmadd_ps(_w2, _val2, _sum);
322 _sum = _mm_fmadd_ps(_w3, _val3, _sum);
323 #else
324 _sum = _mm_add_ps(_mm_mul_ps(_w0, _val0), _sum);
325 _sum = _mm_add_ps(_mm_mul_ps(_w1, _val1), _sum);
326 _sum = _mm_add_ps(_mm_mul_ps(_w2, _val2), _sum);
327 _sum = _mm_add_ps(_mm_mul_ps(_w3, _val3), _sum);
328 #endif
329
330 tmpptr += 4;
331 kptr0 += 16;
332 }
333
334 _mm_store_ps(outptr0, _sum);
335 outptr0 += 4;
336 }
337 }
338
339 // // NOTE sgemm
340 // for (; p<outch; p++)
341 // {
342 // Mat out0 = top_blob.channel(p);
343 //
344 // const float bias0 = bias ? bias[p] : 0.f;
345 //
346 // float* outptr0 = out0;
347 //
348 // for (int i=0; i<size; i++)
349 // {
350 // float sum = bias0;
351 //
352 // const float* kptr = _kernel.channel(p);
353 //
354 // for (int q=0; q<inch; q++)
355 // {
356 // const float* img0 = bottom_blob.channel(q);
357 //
358 // sum += img0[i] * kptr[0];
359 // kptr ++;
360 // }
361 //
362 // outptr0[i] = sum;
363 // }
364 // }
365 }
366
conv1x1s2_pack4_sse(const Mat & bottom_blob,Mat & top_blob,const Mat & kernel,const Mat & _bias,const Option & opt)367 static void conv1x1s2_pack4_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
368 {
369 int w = bottom_blob.w;
370 int channels = bottom_blob.c;
371 size_t elemsize = bottom_blob.elemsize;
372 int elempack = bottom_blob.elempack;
373
374 int outw = top_blob.w;
375 int outh = top_blob.h;
376
377 const int tailstep = (w - 2 * outw + w) * 4;
378
379 Mat bottom_blob_shrinked;
380 bottom_blob_shrinked.create(outw, outh, channels, elemsize, elempack, opt.workspace_allocator);
381
382 #pragma omp parallel for num_threads(opt.num_threads)
383 for (int p = 0; p < channels; p++)
384 {
385 const float* r0 = bottom_blob.channel(p);
386 float* outptr = bottom_blob_shrinked.channel(p);
387
388 for (int i = 0; i < outh; i++)
389 {
390 for (int j = 0; j < outw; j++)
391 {
392 __m128 _v = _mm_load_ps(r0);
393 _mm_store_ps(outptr, _v);
394
395 r0 += 8;
396 outptr += 4;
397 }
398
399 r0 += tailstep;
400 }
401 }
402
403 conv1x1s1_sgemm_pack4_sse(bottom_blob_shrinked, top_blob, kernel, _bias, opt);
404 }
405