// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
14
conv3x3s1_pack8to1_avx(const Mat & bottom_blob,Mat & top_blob,const Mat & kernel,const Mat & _bias,const Option & opt)15 static void conv3x3s1_pack8to1_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
16 {
17 int inch = bottom_blob.c;
18 int outw = top_blob.w;
19 int outh = top_blob.h;
20 int outch = top_blob.c;
21
22 const float* bias = _bias;
23
24 int remain_outch_start = 0;
25
26 #pragma omp parallel for num_threads(opt.num_threads)
27 for (int p = remain_outch_start; p < outch; p++)
28 {
29 Mat out0 = top_blob.channel(p);
30
31 const float bias0 = bias ? bias[p] : 0.f;
32 out0.fill(bias0);
33
34 const float* k0 = kernel.channel(p);
35
36 for (int q = 0; q < inch; q++)
37 {
38 float* outptr0 = out0.row(0);
39
40 const Mat img0 = bottom_blob.channel(q);
41
42 __m256 _k00 = _mm256_loadu_ps(k0);
43 __m256 _k01 = _mm256_loadu_ps(k0 + 8);
44 __m256 _k02 = _mm256_loadu_ps(k0 + 16);
45 __m256 _k10 = _mm256_loadu_ps(k0 + 24);
46 __m256 _k11 = _mm256_loadu_ps(k0 + 32);
47 __m256 _k12 = _mm256_loadu_ps(k0 + 40);
48 __m256 _k20 = _mm256_loadu_ps(k0 + 48);
49 __m256 _k21 = _mm256_loadu_ps(k0 + 56);
50 __m256 _k22 = _mm256_loadu_ps(k0 + 64);
51
52 int i = 0;
53
54 for (; i < outh; i++)
55 {
56 const float* r0 = img0.row(i);
57 const float* r1 = img0.row(i + 1);
58 const float* r2 = img0.row(i + 2);
59 int j = 0;
60 for (; j < outw; j++)
61 {
62 __m256 _r00 = _mm256_loadu_ps(r0);
63 __m256 _r01 = _mm256_loadu_ps(r0 + 8);
64 __m256 _r02 = _mm256_loadu_ps(r0 + 16);
65
66 __m256 _sum0 = _mm256_mul_ps(_k00, _r00);
67 __m256 _sum1 = _mm256_mul_ps(_k01, _r01);
68 __m256 _sum2 = _mm256_mul_ps(_k02, _r02);
69
70 __m256 _r10 = _mm256_loadu_ps(r1);
71 __m256 _r11 = _mm256_loadu_ps(r1 + 8);
72 __m256 _r12 = _mm256_loadu_ps(r1 + 16);
73
74 _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0);
75 _sum1 = _mm256_fmadd_ps(_k11, _r11, _sum1);
76 _sum2 = _mm256_fmadd_ps(_k12, _r12, _sum2);
77
78 __m256 _r20 = _mm256_loadu_ps(r2);
79 __m256 _r21 = _mm256_loadu_ps(r2 + 8);
80 __m256 _r22 = _mm256_loadu_ps(r2 + 16);
81
82 _sum0 = _mm256_fmadd_ps(_k20, _r20, _sum0);
83 _sum1 = _mm256_fmadd_ps(_k21, _r21, _sum1);
84 _sum2 = _mm256_fmadd_ps(_k22, _r22, _sum2);
85 __m128 _sum = HorizontalSums(_sum0, _sum1, _sum2);
86
87 *outptr0 += _mm_reduce_add_ps(_sum); // dot
88 outptr0++;
89 r0 += 8;
90 r1 += 8;
91 r2 += 8;
92 }
93 }
94
95 k0 += 9 * 8;
96 }
97 }
98 }
99