// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

static void conv1x1s1_sgemm_transform_kernel_pack4_sse(const Mat& kernel, Mat& kernel_pack4, int inch, int outch)
{
    // interleave
    // src = inch-outch
    // dst = 4b-4a-inch/4a-outch/4b
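    // each 16-float block is one 4x4 tile: the 4 output channels (b) are stored innermost,
    // followed by the 4 input channels (a) of the current group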
    kernel_pack4.create(1, inch / 4, outch / 4, (size_t)4u * 16, 16);

    int q = 0;
    for (; q + 3 < outch; q += 4)
    {
        const float* k0 = (const float*)kernel + (q + 0) * inch;
        const float* k1 = (const float*)kernel + (q + 1) * inch;
        const float* k2 = (const float*)kernel + (q + 2) * inch;
        const float* k3 = (const float*)kernel + (q + 3) * inch;

        float* g0 = kernel_pack4.channel(q / 4);

        for (int p = 0; p + 3 < inch; p += 4)
        {
            g0[0] = k0[0];
            g0[1] = k1[0];
            g0[2] = k2[0];
            g0[3] = k3[0];

            g0[4] = k0[1];
            g0[5] = k1[1];
            g0[6] = k2[1];
            g0[7] = k3[1];

            g0[8] = k0[2];
            g0[9] = k1[2];
            g0[10] = k2[2];
            g0[11] = k3[2];

            g0[12] = k0[3];
            g0[13] = k1[3];
            g0[14] = k2[3];
            g0[15] = k3[3];

            k0 += 4;
            k1 += 4;
            k2 += 4;
            k3 += 4;
            g0 += 16;
        }
    }
}

static void conv1x1s1_sgemm_pack4_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int inch = bottom_blob.c;
    int outch = top_blob.c;

    size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

    const int size = w * h;

    const float* bias = _bias;

    // interleave
    Mat tmp(4, inch, size / 4 + (size % 4) / 2 + size % 2, elemsize, elempack, opt.workspace_allocator);
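    // tmp holds one channel per pixel tile: size/4 tiles of 4 pixels, then the 2-pixel
    // and 1-pixel remainders; each tile stores all inch packed values contiguously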
    {
        int nn_size;
        int remain_size_start;

        remain_size_start = 0;
        nn_size = (size - remain_size_start) >> 2;

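        // 4-pixel tiles: copy 4 consecutive pack4 pixels from every input channel into one contiguous block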
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = remain_size_start + ii * 4;

            const float* img0 = bottom_blob.channel(0);
            img0 += i * 4;

            float* tmpptr = tmp.channel(i / 4);

            for (int q = 0; q < inch; q++)
            {
                __m128 _r0 = _mm_loadu_ps(img0);
                __m128 _r1 = _mm_loadu_ps(img0 + 4);
                __m128 _r2 = _mm_loadu_ps(img0 + 8);
                __m128 _r3 = _mm_loadu_ps(img0 + 12);
                _mm_storeu_ps(tmpptr, _r0);
                _mm_storeu_ps(tmpptr + 4, _r1);
                _mm_storeu_ps(tmpptr + 8, _r2);
                _mm_storeu_ps(tmpptr + 12, _r3);

                tmpptr += 16;
                img0 += bottom_blob.cstep * 4;
            }
        }

        remain_size_start += nn_size << 2;
        nn_size = (size - remain_size_start) >> 1;

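        // 2-pixel tiles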
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = remain_size_start + ii * 2;

            const float* img0 = bottom_blob.channel(0);
            img0 += i * 4;

            float* tmpptr = tmp.channel(i / 4 + (i % 4) / 2);

            for (int q = 0; q < inch; q++)
            {
                __m128 _r0 = _mm_loadu_ps(img0);
                __m128 _r1 = _mm_loadu_ps(img0 + 4);
                _mm_storeu_ps(tmpptr, _r0);
                _mm_storeu_ps(tmpptr + 4, _r1);

                tmpptr += 8;
                img0 += bottom_blob.cstep * 4;
            }
        }

        remain_size_start += nn_size << 1;

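        // leftover single pixels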
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_size_start; i < size; i++)
        {
            const float* img0 = bottom_blob.channel(0);
            img0 += i * 4;

            float* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2);

            for (int q = 0; q < inch; q++)
            {
                __m128 _r0 = _mm_loadu_ps(img0);
                _mm_storeu_ps(tmpptr, _r0);

                tmpptr += 4;
                img0 += bottom_blob.cstep * 4;
            }
        }
    }

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < outch; p++)
    {
        float* outptr0 = top_blob.channel(p);

        const float zeros[4] = {0.f, 0.f, 0.f, 0.f};
        const float* biasptr = bias ? bias + p * 4 : zeros;

        int i = 0;
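        // 4-pixel tiles: one accumulator per output pixel, each holding the 4 packed output channels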
        for (; i + 3 < size; i += 4)
        {
            float* tmpptr = tmp.channel(i / 4);
            const float* kptr0 = (const float*)kernel.channel(p);

            __m128 _sum0 = _mm_loadu_ps(biasptr);
            __m128 _sum1 = _mm_loadu_ps(biasptr);
            __m128 _sum2 = _mm_loadu_ps(biasptr);
            __m128 _sum3 = _mm_loadu_ps(biasptr);

            for (int q = 0; q < inch; q++)
            {
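                // broadcast each of the 16 input values (4 pixels x 4 input-channel lanes)
                // and accumulate against the packed 4x4 kernel tile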
                __m128 _val00 = _mm_load1_ps(tmpptr);
                __m128 _val01 = _mm_load1_ps(tmpptr + 1);
                __m128 _val02 = _mm_load1_ps(tmpptr + 2);
                __m128 _val03 = _mm_load1_ps(tmpptr + 3);
                __m128 _val10 = _mm_load1_ps(tmpptr + 4);
                __m128 _val11 = _mm_load1_ps(tmpptr + 5);
                __m128 _val12 = _mm_load1_ps(tmpptr + 6);
                __m128 _val13 = _mm_load1_ps(tmpptr + 7);
                __m128 _val20 = _mm_load1_ps(tmpptr + 8);
                __m128 _val21 = _mm_load1_ps(tmpptr + 9);
                __m128 _val22 = _mm_load1_ps(tmpptr + 10);
                __m128 _val23 = _mm_load1_ps(tmpptr + 11);
                __m128 _val30 = _mm_load1_ps(tmpptr + 12);
                __m128 _val31 = _mm_load1_ps(tmpptr + 13);
                __m128 _val32 = _mm_load1_ps(tmpptr + 14);
                __m128 _val33 = _mm_load1_ps(tmpptr + 15);

                __m128 _w0 = _mm_load_ps(kptr0);
                __m128 _w1 = _mm_load_ps(kptr0 + 4);
                __m128 _w2 = _mm_load_ps(kptr0 + 8);
                __m128 _w3 = _mm_load_ps(kptr0 + 12);

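                // fused multiply-add in the AVX build, plain multiply + add otherwise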
#if __AVX__
                _sum0 = _mm_fmadd_ps(_w0, _val00, _sum0);
                _sum0 = _mm_fmadd_ps(_w1, _val01, _sum0);
                _sum0 = _mm_fmadd_ps(_w2, _val02, _sum0);
                _sum0 = _mm_fmadd_ps(_w3, _val03, _sum0);
                _sum1 = _mm_fmadd_ps(_w0, _val10, _sum1);
                _sum1 = _mm_fmadd_ps(_w1, _val11, _sum1);
                _sum1 = _mm_fmadd_ps(_w2, _val12, _sum1);
                _sum1 = _mm_fmadd_ps(_w3, _val13, _sum1);
                _sum2 = _mm_fmadd_ps(_w0, _val20, _sum2);
                _sum2 = _mm_fmadd_ps(_w1, _val21, _sum2);
                _sum2 = _mm_fmadd_ps(_w2, _val22, _sum2);
                _sum2 = _mm_fmadd_ps(_w3, _val23, _sum2);
                _sum3 = _mm_fmadd_ps(_w0, _val30, _sum3);
                _sum3 = _mm_fmadd_ps(_w1, _val31, _sum3);
                _sum3 = _mm_fmadd_ps(_w2, _val32, _sum3);
                _sum3 = _mm_fmadd_ps(_w3, _val33, _sum3);
#else
                _sum0 = _mm_add_ps(_mm_mul_ps(_w0, _val00), _sum0);
                _sum0 = _mm_add_ps(_mm_mul_ps(_w1, _val01), _sum0);
                _sum0 = _mm_add_ps(_mm_mul_ps(_w2, _val02), _sum0);
                _sum0 = _mm_add_ps(_mm_mul_ps(_w3, _val03), _sum0);
                _sum1 = _mm_add_ps(_mm_mul_ps(_w0, _val10), _sum1);
                _sum1 = _mm_add_ps(_mm_mul_ps(_w1, _val11), _sum1);
                _sum1 = _mm_add_ps(_mm_mul_ps(_w2, _val12), _sum1);
                _sum1 = _mm_add_ps(_mm_mul_ps(_w3, _val13), _sum1);
                _sum2 = _mm_add_ps(_mm_mul_ps(_w0, _val20), _sum2);
                _sum2 = _mm_add_ps(_mm_mul_ps(_w1, _val21), _sum2);
                _sum2 = _mm_add_ps(_mm_mul_ps(_w2, _val22), _sum2);
                _sum2 = _mm_add_ps(_mm_mul_ps(_w3, _val23), _sum2);
                _sum3 = _mm_add_ps(_mm_mul_ps(_w0, _val30), _sum3);
                _sum3 = _mm_add_ps(_mm_mul_ps(_w1, _val31), _sum3);
                _sum3 = _mm_add_ps(_mm_mul_ps(_w2, _val32), _sum3);
                _sum3 = _mm_add_ps(_mm_mul_ps(_w3, _val33), _sum3);
#endif

                tmpptr += 16;
                kptr0 += 16;
            }

            _mm_store_ps(outptr0, _sum0);
            _mm_store_ps(outptr0 + 4, _sum1);
            _mm_store_ps(outptr0 + 8, _sum2);
            _mm_store_ps(outptr0 + 12, _sum3);
            outptr0 += 16;
        }
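        // 2-pixel tiles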
        for (; i + 1 < size; i += 2)
        {
            float* tmpptr = tmp.channel(i / 4 + (i % 4) / 2);
            const float* kptr0 = (const float*)kernel.channel(p);

            __m128 _sum0 = _mm_loadu_ps(biasptr);
            __m128 _sum1 = _mm_loadu_ps(biasptr);

            for (int q = 0; q < inch; q++)
            {
                __m128 _val00 = _mm_load1_ps(tmpptr);
                __m128 _val01 = _mm_load1_ps(tmpptr + 1);
                __m128 _val02 = _mm_load1_ps(tmpptr + 2);
                __m128 _val03 = _mm_load1_ps(tmpptr + 3);
                __m128 _val10 = _mm_load1_ps(tmpptr + 4);
                __m128 _val11 = _mm_load1_ps(tmpptr + 5);
                __m128 _val12 = _mm_load1_ps(tmpptr + 6);
                __m128 _val13 = _mm_load1_ps(tmpptr + 7);

                __m128 _w0 = _mm_load_ps(kptr0);
                __m128 _w1 = _mm_load_ps(kptr0 + 4);
                __m128 _w2 = _mm_load_ps(kptr0 + 8);
                __m128 _w3 = _mm_load_ps(kptr0 + 12);

#if __AVX__
                _sum0 = _mm_fmadd_ps(_w0, _val00, _sum0);
                _sum0 = _mm_fmadd_ps(_w1, _val01, _sum0);
                _sum0 = _mm_fmadd_ps(_w2, _val02, _sum0);
                _sum0 = _mm_fmadd_ps(_w3, _val03, _sum0);
                _sum1 = _mm_fmadd_ps(_w0, _val10, _sum1);
                _sum1 = _mm_fmadd_ps(_w1, _val11, _sum1);
                _sum1 = _mm_fmadd_ps(_w2, _val12, _sum1);
                _sum1 = _mm_fmadd_ps(_w3, _val13, _sum1);
#else
                _sum0 = _mm_add_ps(_mm_mul_ps(_w0, _val00), _sum0);
                _sum0 = _mm_add_ps(_mm_mul_ps(_w1, _val01), _sum0);
                _sum0 = _mm_add_ps(_mm_mul_ps(_w2, _val02), _sum0);
                _sum0 = _mm_add_ps(_mm_mul_ps(_w3, _val03), _sum0);
                _sum1 = _mm_add_ps(_mm_mul_ps(_w0, _val10), _sum1);
                _sum1 = _mm_add_ps(_mm_mul_ps(_w1, _val11), _sum1);
                _sum1 = _mm_add_ps(_mm_mul_ps(_w2, _val12), _sum1);
                _sum1 = _mm_add_ps(_mm_mul_ps(_w3, _val13), _sum1);
#endif

                tmpptr += 8;
                kptr0 += 16;
            }

            _mm_store_ps(outptr0, _sum0);
            _mm_store_ps(outptr0 + 4, _sum1);
            outptr0 += 8;
        }
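        // leftover single pixels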
        for (; i < size; i++)
        {
            float* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2);
            const float* kptr0 = (const float*)kernel.channel(p);

            __m128 _sum = _mm_loadu_ps(biasptr);

            for (int q = 0; q < inch; q++)
            {
                __m128 _val0 = _mm_load1_ps(tmpptr);
                __m128 _val1 = _mm_load1_ps(tmpptr + 1);
                __m128 _val2 = _mm_load1_ps(tmpptr + 2);
                __m128 _val3 = _mm_load1_ps(tmpptr + 3);

                __m128 _w0 = _mm_load_ps(kptr0);
                __m128 _w1 = _mm_load_ps(kptr0 + 4);
                __m128 _w2 = _mm_load_ps(kptr0 + 8);
                __m128 _w3 = _mm_load_ps(kptr0 + 12);

#if __AVX__
                _sum = _mm_fmadd_ps(_w0, _val0, _sum);
                _sum = _mm_fmadd_ps(_w1, _val1, _sum);
                _sum = _mm_fmadd_ps(_w2, _val2, _sum);
                _sum = _mm_fmadd_ps(_w3, _val3, _sum);
#else
                _sum = _mm_add_ps(_mm_mul_ps(_w0, _val0), _sum);
                _sum = _mm_add_ps(_mm_mul_ps(_w1, _val1), _sum);
                _sum = _mm_add_ps(_mm_mul_ps(_w2, _val2), _sum);
                _sum = _mm_add_ps(_mm_mul_ps(_w3, _val3), _sum);
#endif

                tmpptr += 4;
                kptr0 += 16;
            }

            _mm_store_ps(outptr0, _sum);
            outptr0 += 4;
        }
    }

    //     // NOTE sgemm
    //     for (; p<outch; p++)
    //     {
    //         Mat out0 = top_blob.channel(p);
    //
    //         const float bias0 = bias ? bias[p] : 0.f;
    //
    //         float* outptr0 = out0;
    //
    //         for (int i=0; i<size; i++)
    //         {
    //             float sum = bias0;
    //
    //             const float* kptr = _kernel.channel(p);
    //
    //             for (int q=0; q<inch; q++)
    //             {
    //                 const float* img0 = bottom_blob.channel(q);
    //
    //                 sum += img0[i] * kptr[0];
    //                 kptr ++;
    //             }
    //
    //             outptr0[i] = sum;
    //         }
    //     }
}

static void conv1x1s2_pack4_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    int w = bottom_blob.w;
    int channels = bottom_blob.c;
    size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

    int outw = top_blob.w;
    int outh = top_blob.h;

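    // after a stride-2 row: skip the tail of the current input row plus one full row, in pack4 floats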
    const int tailstep = (w - 2 * outw + w) * 4;

    Mat bottom_blob_shrinked;
    bottom_blob_shrinked.create(outw, outh, channels, elemsize, elempack, opt.workspace_allocator);

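    // gather every second pixel in both directions into a dense outw x outh blob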
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < channels; p++)
    {
        const float* r0 = bottom_blob.channel(p);
        float* outptr = bottom_blob_shrinked.channel(p);

        for (int i = 0; i < outh; i++)
        {
            for (int j = 0; j < outw; j++)
            {
                __m128 _v = _mm_load_ps(r0);
                _mm_store_ps(outptr, _v);

                r0 += 8;
                outptr += 4;
            }

            r0 += tailstep;
        }
    }

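    // the shrunken input is now a plain stride-1 1x1 convolution problem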
    conv1x1s1_sgemm_pack4_sse(bottom_blob_shrinked, top_blob, kernel, _bias, opt);
}