1 
2 // Tencent is pleased to support the open source community by making ncnn available.
3 //
4 // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
5 //
6 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
7 // in compliance with the License. You may obtain a copy of the License at
8 //
9 // https://opensource.org/licenses/BSD-3-Clause
10 //
11 // Unless required by applicable law or agreed to in writing, software distributed
12 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
13 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
14 // specific language governing permissions and limitations under the License.
15 
pooling3x3s2_max_pack8_avx(const Mat & bottom_blob,Mat & top_blob,const Option & opt)16 static void pooling3x3s2_max_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
17 {
18     int w = bottom_blob.w;
19     int inch = bottom_blob.c;
20 
21     int outw = top_blob.w;
22     int outh = top_blob.h;
23 
24     const int tailstep = (w - 2 * outw + w) * 8;
25 
26     #pragma omp parallel for num_threads(opt.num_threads)
27     for (int q = 0; q < inch; q++)
28     {
29         const Mat img0 = bottom_blob.channel(q);
30         float* outptr = top_blob.channel(q);
31 
32         const float* r0 = img0.row(0);
33         const float* r1 = img0.row(1);
34         const float* r2 = img0.row(2);
35         for (int i = 0; i < outh; i++)
36         {
37             int j = 0;
38             for (; j + 3 < outw; j += 4)
39             {
40                 __m256 _r00 = _mm256_loadu_ps(r0);
41                 __m256 _r01 = _mm256_loadu_ps(r0 + 8);
42                 __m256 _r02 = _mm256_loadu_ps(r0 + 16);
43                 __m256 _r10 = _mm256_loadu_ps(r1);
44                 __m256 _r11 = _mm256_loadu_ps(r1 + 8);
45                 __m256 _r12 = _mm256_loadu_ps(r1 + 16);
46                 __m256 _r20 = _mm256_loadu_ps(r2);
47                 __m256 _r21 = _mm256_loadu_ps(r2 + 8);
48                 __m256 _r22 = _mm256_loadu_ps(r2 + 16);
49 
50                 __m256 _max00 = _mm256_max_ps(_r00, _r01);
51                 _max00 = _mm256_max_ps(_max00, _r02);
52                 _max00 = _mm256_max_ps(_max00, _r10);
53                 _max00 = _mm256_max_ps(_max00, _r11);
54                 __m256 _max01 = _mm256_max_ps(_r12, _r20);
55                 _max01 = _mm256_max_ps(_max01, _r21);
56                 _max01 = _mm256_max_ps(_max01, _r22);
57 
58                 __m256 _r03 = _mm256_loadu_ps(r0 + 24);
59                 __m256 _r04 = _mm256_loadu_ps(r0 + 32);
60                 __m256 _r13 = _mm256_loadu_ps(r1 + 24);
61                 __m256 _r14 = _mm256_loadu_ps(r1 + 32);
62                 __m256 _r23 = _mm256_loadu_ps(r2 + 24);
63                 __m256 _r24 = _mm256_loadu_ps(r2 + 32);
64 
65                 _mm256_storeu_ps(outptr, _mm256_max_ps(_max00, _max01));
66 
67                 __m256 _max10 = _mm256_max_ps(_r03, _r04);
68                 _max10 = _mm256_max_ps(_max10, _r02);
69                 _max10 = _mm256_max_ps(_max10, _r13);
70                 _max10 = _mm256_max_ps(_max10, _r14);
71                 __m256 _max11 = _mm256_max_ps(_r12, _r23);
72                 _max10 = _mm256_max_ps(_max10, _r24);
73                 _max10 = _mm256_max_ps(_max10, _r22);
74 
75                 __m256 _r05 = _mm256_loadu_ps(r0 + 40);
76                 __m256 _r06 = _mm256_loadu_ps(r0 + 48);
77                 __m256 _r15 = _mm256_loadu_ps(r1 + 40);
78                 __m256 _r16 = _mm256_loadu_ps(r1 + 48);
79                 __m256 _r25 = _mm256_loadu_ps(r2 + 40);
80                 __m256 _r26 = _mm256_loadu_ps(r2 + 48);
81 
82                 _mm256_storeu_ps(outptr + 8, _mm256_max_ps(_max10, _max11));
83 
84                 __m256 _max20 = _mm256_max_ps(_r05, _r06);
85                 _max20 = _mm256_max_ps(_max20, _r04);
86                 _max20 = _mm256_max_ps(_max20, _r15);
87                 _max20 = _mm256_max_ps(_max20, _r16);
88                 __m256 _max21 = _mm256_max_ps(_r14, _r25);
89                 _max20 = _mm256_max_ps(_max20, _r26);
90                 _max20 = _mm256_max_ps(_max20, _r24);
91 
92                 __m256 _r07 = _mm256_loadu_ps(r0 + 56);
93                 __m256 _r08 = _mm256_loadu_ps(r0 + 64);
94                 __m256 _r17 = _mm256_loadu_ps(r1 + 56);
95                 __m256 _r18 = _mm256_loadu_ps(r1 + 64);
96                 __m256 _r27 = _mm256_loadu_ps(r2 + 56);
97                 __m256 _r28 = _mm256_loadu_ps(r2 + 64);
98 
99                 _mm256_storeu_ps(outptr + 16, _mm256_max_ps(_max20, _max21));
100 
101                 __m256 _max30 = _mm256_max_ps(_r07, _r08);
102                 _max30 = _mm256_max_ps(_max30, _r06);
103                 _max30 = _mm256_max_ps(_max30, _r17);
104                 _max30 = _mm256_max_ps(_max30, _r18);
105                 __m256 _max31 = _mm256_max_ps(_r16, _r27);
106                 _max30 = _mm256_max_ps(_max30, _r28);
107                 _max30 = _mm256_max_ps(_max30, _r26);
108 
109                 _mm256_storeu_ps(outptr + 24, _mm256_max_ps(_max30, _max31));
110 
111                 r0 += 64;
112                 r1 += 64;
113                 r2 += 64;
114                 outptr += 32;
115             }
116             for (; j + 1 < outw; j += 2)
117             {
118                 __m256 _r00 = _mm256_loadu_ps(r0);
119                 __m256 _r01 = _mm256_loadu_ps(r0 + 8);
120                 __m256 _r02 = _mm256_loadu_ps(r0 + 16);
121                 __m256 _r10 = _mm256_loadu_ps(r1);
122                 __m256 _r11 = _mm256_loadu_ps(r1 + 8);
123                 __m256 _r12 = _mm256_loadu_ps(r1 + 16);
124                 __m256 _r20 = _mm256_loadu_ps(r2);
125                 __m256 _r21 = _mm256_loadu_ps(r2 + 8);
126                 __m256 _r22 = _mm256_loadu_ps(r2 + 16);
127 
128                 __m256 _max00 = _mm256_max_ps(_r00, _r01);
129                 _max00 = _mm256_max_ps(_max00, _r02);
130                 _max00 = _mm256_max_ps(_max00, _r10);
131                 _max00 = _mm256_max_ps(_max00, _r11);
132                 __m256 _max01 = _mm256_max_ps(_r12, _r20);
133                 _max01 = _mm256_max_ps(_max01, _r21);
134                 _max01 = _mm256_max_ps(_max01, _r22);
135 
136                 __m256 _r03 = _mm256_loadu_ps(r0 + 24);
137                 __m256 _r04 = _mm256_loadu_ps(r0 + 32);
138                 __m256 _r13 = _mm256_loadu_ps(r1 + 24);
139                 __m256 _r14 = _mm256_loadu_ps(r1 + 32);
140                 __m256 _r23 = _mm256_loadu_ps(r2 + 24);
141                 __m256 _r24 = _mm256_loadu_ps(r2 + 32);
142 
143                 _mm256_storeu_ps(outptr, _mm256_max_ps(_max00, _max01));
144 
145                 __m256 _max10 = _mm256_max_ps(_r03, _r04);
146                 _max10 = _mm256_max_ps(_max10, _r02);
147                 _max10 = _mm256_max_ps(_max10, _r13);
148                 _max10 = _mm256_max_ps(_max10, _r14);
149                 __m256 _max11 = _mm256_max_ps(_r12, _r23);
150                 _max10 = _mm256_max_ps(_max10, _r24);
151                 _max10 = _mm256_max_ps(_max10, _r22);
152 
153                 _mm256_storeu_ps(outptr + 8, _mm256_max_ps(_max10, _max11));
154 
155                 r0 += 32;
156                 r1 += 32;
157                 r2 += 32;
158                 outptr += 16;
159             }
160 
161             for (; j < outw; j++)
162             {
163                 __m256 _r00 = _mm256_loadu_ps(r0);
164                 __m256 _r01 = _mm256_loadu_ps(r0 + 8);
165                 __m256 _r02 = _mm256_loadu_ps(r0 + 16);
166                 __m256 _r10 = _mm256_loadu_ps(r1);
167                 __m256 _r11 = _mm256_loadu_ps(r1 + 8);
168                 __m256 _r12 = _mm256_loadu_ps(r1 + 16);
169                 __m256 _r20 = _mm256_loadu_ps(r2);
170                 __m256 _r21 = _mm256_loadu_ps(r2 + 8);
171                 __m256 _r22 = _mm256_loadu_ps(r2 + 16);
172 
173                 __m256 _max0 = _mm256_max_ps(_r00, _r01);
174                 _max0 = _mm256_max_ps(_max0, _r02);
175                 _max0 = _mm256_max_ps(_max0, _r10);
176                 _max0 = _mm256_max_ps(_max0, _r11);
177                 __m256 _max1 = _mm256_max_ps(_r12, _r20);
178                 _max1 = _mm256_max_ps(_max1, _r21);
179                 _max1 = _mm256_max_ps(_max1, _r22);
180 
181                 _mm256_storeu_ps(outptr, _mm256_max_ps(_max0, _max1));
182 
183                 r0 += 16;
184                 r1 += 16;
185                 r2 += 16;
186                 outptr += 8;
187             }
188 
189             r0 += tailstep;
190             r1 += tailstep;
191             r2 += tailstep;
192         }
193     }
194 }
195