1
2 // Tencent is pleased to support the open source community by making ncnn available.
3 //
4 // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
5 //
6 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
7 // in compliance with the License. You may obtain a copy of the License at
8 //
9 // https://opensource.org/licenses/BSD-3-Clause
10 //
11 // Unless required by applicable law or agreed to in writing, software distributed
12 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
13 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
14 // specific language governing permissions and limitations under the License.
15
pooling3x3s2_max_pack8_avx(const Mat & bottom_blob,Mat & top_blob,const Option & opt)16 static void pooling3x3s2_max_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
17 {
18 int w = bottom_blob.w;
19 int inch = bottom_blob.c;
20
21 int outw = top_blob.w;
22 int outh = top_blob.h;
23
24 const int tailstep = (w - 2 * outw + w) * 8;
25
26 #pragma omp parallel for num_threads(opt.num_threads)
27 for (int q = 0; q < inch; q++)
28 {
29 const Mat img0 = bottom_blob.channel(q);
30 float* outptr = top_blob.channel(q);
31
32 const float* r0 = img0.row(0);
33 const float* r1 = img0.row(1);
34 const float* r2 = img0.row(2);
35 for (int i = 0; i < outh; i++)
36 {
37 int j = 0;
38 for (; j + 3 < outw; j += 4)
39 {
40 __m256 _r00 = _mm256_loadu_ps(r0);
41 __m256 _r01 = _mm256_loadu_ps(r0 + 8);
42 __m256 _r02 = _mm256_loadu_ps(r0 + 16);
43 __m256 _r10 = _mm256_loadu_ps(r1);
44 __m256 _r11 = _mm256_loadu_ps(r1 + 8);
45 __m256 _r12 = _mm256_loadu_ps(r1 + 16);
46 __m256 _r20 = _mm256_loadu_ps(r2);
47 __m256 _r21 = _mm256_loadu_ps(r2 + 8);
48 __m256 _r22 = _mm256_loadu_ps(r2 + 16);
49
50 __m256 _max00 = _mm256_max_ps(_r00, _r01);
51 _max00 = _mm256_max_ps(_max00, _r02);
52 _max00 = _mm256_max_ps(_max00, _r10);
53 _max00 = _mm256_max_ps(_max00, _r11);
54 __m256 _max01 = _mm256_max_ps(_r12, _r20);
55 _max01 = _mm256_max_ps(_max01, _r21);
56 _max01 = _mm256_max_ps(_max01, _r22);
57
58 __m256 _r03 = _mm256_loadu_ps(r0 + 24);
59 __m256 _r04 = _mm256_loadu_ps(r0 + 32);
60 __m256 _r13 = _mm256_loadu_ps(r1 + 24);
61 __m256 _r14 = _mm256_loadu_ps(r1 + 32);
62 __m256 _r23 = _mm256_loadu_ps(r2 + 24);
63 __m256 _r24 = _mm256_loadu_ps(r2 + 32);
64
65 _mm256_storeu_ps(outptr, _mm256_max_ps(_max00, _max01));
66
67 __m256 _max10 = _mm256_max_ps(_r03, _r04);
68 _max10 = _mm256_max_ps(_max10, _r02);
69 _max10 = _mm256_max_ps(_max10, _r13);
70 _max10 = _mm256_max_ps(_max10, _r14);
71 __m256 _max11 = _mm256_max_ps(_r12, _r23);
72 _max10 = _mm256_max_ps(_max10, _r24);
73 _max10 = _mm256_max_ps(_max10, _r22);
74
75 __m256 _r05 = _mm256_loadu_ps(r0 + 40);
76 __m256 _r06 = _mm256_loadu_ps(r0 + 48);
77 __m256 _r15 = _mm256_loadu_ps(r1 + 40);
78 __m256 _r16 = _mm256_loadu_ps(r1 + 48);
79 __m256 _r25 = _mm256_loadu_ps(r2 + 40);
80 __m256 _r26 = _mm256_loadu_ps(r2 + 48);
81
82 _mm256_storeu_ps(outptr + 8, _mm256_max_ps(_max10, _max11));
83
84 __m256 _max20 = _mm256_max_ps(_r05, _r06);
85 _max20 = _mm256_max_ps(_max20, _r04);
86 _max20 = _mm256_max_ps(_max20, _r15);
87 _max20 = _mm256_max_ps(_max20, _r16);
88 __m256 _max21 = _mm256_max_ps(_r14, _r25);
89 _max20 = _mm256_max_ps(_max20, _r26);
90 _max20 = _mm256_max_ps(_max20, _r24);
91
92 __m256 _r07 = _mm256_loadu_ps(r0 + 56);
93 __m256 _r08 = _mm256_loadu_ps(r0 + 64);
94 __m256 _r17 = _mm256_loadu_ps(r1 + 56);
95 __m256 _r18 = _mm256_loadu_ps(r1 + 64);
96 __m256 _r27 = _mm256_loadu_ps(r2 + 56);
97 __m256 _r28 = _mm256_loadu_ps(r2 + 64);
98
99 _mm256_storeu_ps(outptr + 16, _mm256_max_ps(_max20, _max21));
100
101 __m256 _max30 = _mm256_max_ps(_r07, _r08);
102 _max30 = _mm256_max_ps(_max30, _r06);
103 _max30 = _mm256_max_ps(_max30, _r17);
104 _max30 = _mm256_max_ps(_max30, _r18);
105 __m256 _max31 = _mm256_max_ps(_r16, _r27);
106 _max30 = _mm256_max_ps(_max30, _r28);
107 _max30 = _mm256_max_ps(_max30, _r26);
108
109 _mm256_storeu_ps(outptr + 24, _mm256_max_ps(_max30, _max31));
110
111 r0 += 64;
112 r1 += 64;
113 r2 += 64;
114 outptr += 32;
115 }
116 for (; j + 1 < outw; j += 2)
117 {
118 __m256 _r00 = _mm256_loadu_ps(r0);
119 __m256 _r01 = _mm256_loadu_ps(r0 + 8);
120 __m256 _r02 = _mm256_loadu_ps(r0 + 16);
121 __m256 _r10 = _mm256_loadu_ps(r1);
122 __m256 _r11 = _mm256_loadu_ps(r1 + 8);
123 __m256 _r12 = _mm256_loadu_ps(r1 + 16);
124 __m256 _r20 = _mm256_loadu_ps(r2);
125 __m256 _r21 = _mm256_loadu_ps(r2 + 8);
126 __m256 _r22 = _mm256_loadu_ps(r2 + 16);
127
128 __m256 _max00 = _mm256_max_ps(_r00, _r01);
129 _max00 = _mm256_max_ps(_max00, _r02);
130 _max00 = _mm256_max_ps(_max00, _r10);
131 _max00 = _mm256_max_ps(_max00, _r11);
132 __m256 _max01 = _mm256_max_ps(_r12, _r20);
133 _max01 = _mm256_max_ps(_max01, _r21);
134 _max01 = _mm256_max_ps(_max01, _r22);
135
136 __m256 _r03 = _mm256_loadu_ps(r0 + 24);
137 __m256 _r04 = _mm256_loadu_ps(r0 + 32);
138 __m256 _r13 = _mm256_loadu_ps(r1 + 24);
139 __m256 _r14 = _mm256_loadu_ps(r1 + 32);
140 __m256 _r23 = _mm256_loadu_ps(r2 + 24);
141 __m256 _r24 = _mm256_loadu_ps(r2 + 32);
142
143 _mm256_storeu_ps(outptr, _mm256_max_ps(_max00, _max01));
144
145 __m256 _max10 = _mm256_max_ps(_r03, _r04);
146 _max10 = _mm256_max_ps(_max10, _r02);
147 _max10 = _mm256_max_ps(_max10, _r13);
148 _max10 = _mm256_max_ps(_max10, _r14);
149 __m256 _max11 = _mm256_max_ps(_r12, _r23);
150 _max10 = _mm256_max_ps(_max10, _r24);
151 _max10 = _mm256_max_ps(_max10, _r22);
152
153 _mm256_storeu_ps(outptr + 8, _mm256_max_ps(_max10, _max11));
154
155 r0 += 32;
156 r1 += 32;
157 r2 += 32;
158 outptr += 16;
159 }
160
161 for (; j < outw; j++)
162 {
163 __m256 _r00 = _mm256_loadu_ps(r0);
164 __m256 _r01 = _mm256_loadu_ps(r0 + 8);
165 __m256 _r02 = _mm256_loadu_ps(r0 + 16);
166 __m256 _r10 = _mm256_loadu_ps(r1);
167 __m256 _r11 = _mm256_loadu_ps(r1 + 8);
168 __m256 _r12 = _mm256_loadu_ps(r1 + 16);
169 __m256 _r20 = _mm256_loadu_ps(r2);
170 __m256 _r21 = _mm256_loadu_ps(r2 + 8);
171 __m256 _r22 = _mm256_loadu_ps(r2 + 16);
172
173 __m256 _max0 = _mm256_max_ps(_r00, _r01);
174 _max0 = _mm256_max_ps(_max0, _r02);
175 _max0 = _mm256_max_ps(_max0, _r10);
176 _max0 = _mm256_max_ps(_max0, _r11);
177 __m256 _max1 = _mm256_max_ps(_r12, _r20);
178 _max1 = _mm256_max_ps(_max1, _r21);
179 _max1 = _mm256_max_ps(_max1, _r22);
180
181 _mm256_storeu_ps(outptr, _mm256_max_ps(_max0, _max1));
182
183 r0 += 16;
184 r1 += 16;
185 r2 += 16;
186 outptr += 8;
187 }
188
189 r0 += tailstep;
190 r1 += tailstep;
191 r2 += tailstep;
192 }
193 }
194 }
195