// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

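// conv3x3s1_pack8to1_avx: 3x3 convolution with stride 1; the input blob is packed
// 8 channels per element (pack8) and the output is written unpacked (pack1).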
static void conv3x3s1_pack8to1_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    int inch = bottom_blob.c;
    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    const float* bias = _bias;

    int remain_outch_start = 0;

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = remain_outch_start; p < outch; p++)
    {
        Mat out0 = top_blob.channel(p);

        // initialize the whole output channel with its bias (or zero if there is no bias)
        const float bias0 = bias ? bias[p] : 0.f;
        out0.fill(bias0);

        // per-output-channel weights: inch blocks of 3x3 taps, 8 packed floats per tap
        const float* k0 = kernel.channel(p);

        for (int q = 0; q < inch; q++)
        {
            float* outptr0 = out0.row(0);

            const Mat img0 = bottom_blob.channel(q);

            // load the 9 taps of the 3x3 kernel for this input channel, 8 packed values each
            __m256 _k00 = _mm256_loadu_ps(k0);
            __m256 _k01 = _mm256_loadu_ps(k0 + 8);
            __m256 _k02 = _mm256_loadu_ps(k0 + 16);
            __m256 _k10 = _mm256_loadu_ps(k0 + 24);
            __m256 _k11 = _mm256_loadu_ps(k0 + 32);
            __m256 _k12 = _mm256_loadu_ps(k0 + 40);
            __m256 _k20 = _mm256_loadu_ps(k0 + 48);
            __m256 _k21 = _mm256_loadu_ps(k0 + 56);
            __m256 _k22 = _mm256_loadu_ps(k0 + 64);

            int i = 0;

            for (; i < outh; i++)
            {
                const float* r0 = img0.row(i);
                const float* r1 = img0.row(i + 1);
                const float* r2 = img0.row(i + 2);
                int j = 0;
                for (; j < outw; j++)
                {
                    // three partial sums, one per column of the 3x3 window;
                    // each tap is an 8-wide multiply over the packed input channels
                    __m256 _r00 = _mm256_loadu_ps(r0);
                    __m256 _r01 = _mm256_loadu_ps(r0 + 8);
                    __m256 _r02 = _mm256_loadu_ps(r0 + 16);

                    __m256 _sum0 = _mm256_mul_ps(_k00, _r00);
                    __m256 _sum1 = _mm256_mul_ps(_k01, _r01);
                    __m256 _sum2 = _mm256_mul_ps(_k02, _r02);

                    __m256 _r10 = _mm256_loadu_ps(r1);
                    __m256 _r11 = _mm256_loadu_ps(r1 + 8);
                    __m256 _r12 = _mm256_loadu_ps(r1 + 16);

                    _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0);
                    _sum1 = _mm256_fmadd_ps(_k11, _r11, _sum1);
                    _sum2 = _mm256_fmadd_ps(_k12, _r12, _sum2);

                    __m256 _r20 = _mm256_loadu_ps(r2);
                    __m256 _r21 = _mm256_loadu_ps(r2 + 8);
                    __m256 _r22 = _mm256_loadu_ps(r2 + 16);

                    _sum0 = _mm256_fmadd_ps(_k20, _r20, _sum0);
                    _sum1 = _mm256_fmadd_ps(_k21, _r21, _sum1);
                    _sum2 = _mm256_fmadd_ps(_k22, _r22, _sum2);
                    // reduce the three 8-wide partial sums to a single scalar and
                    // accumulate it into the pack1 output
                    __m128 _sum = HorizontalSums(_sum0, _sum1, _sum2);

                    *outptr0 += _mm_reduce_add_ps(_sum); // dot
                    outptr0++;
                    r0 += 8; // stride 1 in elements = 8 floats in the pack8 layout
                    r1 += 8;
                    r2 += 8;
                }
            }

            // advance to the next input channel's 3x3 block of packed weights (9 taps x 8 floats)
            k0 += 9 * 8;
        }
    }
}
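
// A minimal calling sketch (illustrative only; the shapes, names and values below are
// assumptions, and the real weight/activation packing is produced by the surrounding
// convolution layer, which is not shown here):
//
//     const int w = 18, h = 18, inch = 2, outch = 4;        // inch counted in pack8 blocks
//     Mat bottom_blob(w, h, inch, (size_t)8 * 4u, 8);       // pack8 input, 8 floats per element
//     Mat top_blob(w - 2, h - 2, outch, (size_t)4u, 1);     // pack1 output, valid convolution
//     Mat kernel(9 * 8 * inch, 1, outch);                   // per output channel: inch x 9 taps x 8 floats
//     Mat bias(outch);
//     bias.fill(0.f);
//     Option opt;
//     opt.num_threads = 1;
//     conv3x3s1_pack8to1_avx(bottom_blob, top_blob, kernel, bias, opt);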