// Tencent is pleased to support the open source community by making ncnn
// available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this
// file except in compliance with the License. You may obtain a copy of the
// License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

pooling2x2s2_max_avx(const Mat & bottom_blob,Mat & top_blob,const Option & opt)18 static void pooling2x2s2_max_avx(const Mat& bottom_blob, Mat& top_blob,
19                                  const Option& opt)
20 {
21     int w = bottom_blob.w;
22     int inch = bottom_blob.c;
23 
24     int outw = top_blob.w;
25     int outh = top_blob.h;
26 
27     const int tailstep = w - 2 * outw + w;
28     #pragma omp parallel for num_threads(opt.num_threads)
29     for (int q = 0; q < inch; q++)
30     {
31         const float* img0 = bottom_blob.channel(q);
32         float* outptr = top_blob.channel(q);
33         int outcount = 0;
34         const float* r0 = img0;
35         const float* r1 = img0 + w;
36 #if __AVX2__
37         __m256i permute_mask = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
38 #endif // __AVX__
39 
40         for (int i = 0; i < outh; i++)
41         {
42 #if __AVX2__
43             int nn = outw >> 2;
44             int remain = outw - (nn << 2);
45 #else
46             int remain = outw;
47 #endif // __AVX__
48 
49 #if __AVX2__
50             for (; nn > 0; nn--)
51             {
52                 __m256 _r0 = _mm256_loadu_ps(r0);
53                 __m256 _r1 = _mm256_loadu_ps(r1);
54                 __m256 _max_r0_r1 = _mm256_max_ps(_r0, _r1);
55                 _max_r0_r1 = _mm256_castsi256_ps(_mm256_permutevar8x32_epi32(
56                                                      _mm256_castps_si256(_max_r0_r1), permute_mask));
57                 __m128 _max_0 = _mm256_extractf128_ps(_max_r0_r1, 0);
58                 __m128 _max_1 = _mm256_extractf128_ps(_max_r0_r1, 1);
59                 __m128 _max = _mm_max_ps(_max_0, _max_1);
60                 _mm_storeu_ps(outptr, _max);
61                 r0 += 8;
62                 r1 += 8;
63                 outptr += 4;
64                 outcount += 4;
65             }
66 #endif // __AVX__
67             for (; remain > 0; remain--)
68             {
69                 float max0 = std::max(r0[0], r0[1]);
70                 float max1 = std::max(r1[0], r1[1]);
71 
72                 *outptr = std::max(max0, max1);
73 
74                 r0 += 2;
75                 r1 += 2;
76                 outptr++;
77                 outcount++;
78             }
79             r0 += tailstep;
80             r1 += tailstep;
81         }
82     }
83 }
84