// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

15 #include "dropout_x86.h"
16
17 #if __SSE2__
18 #include <emmintrin.h>
19 #if __AVX__
20 #include <immintrin.h>
21 #endif // __AVX__
22 #endif // __SSE2__
23
24 namespace ncnn {
25
Dropout_x86()26 Dropout_x86::Dropout_x86()
27 {
28 #if __SSE2__
29 support_packing = true;
30 #endif // __SSE2__
31 }
32
forward_inplace(Mat & bottom_top_blob,const Option & opt) const33 int Dropout_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
34 {
35 if (scale == 1.f)
36 {
37 return 0;
38 }
39
40 #if __SSE2__
41 int dims = bottom_top_blob.dims;
42 int elempack = bottom_top_blob.elempack;
43
44 #if __AVX__
45 if (elempack == 8)
46 {
47 int w = bottom_top_blob.w;
48 int h = bottom_top_blob.h;
49 int channels = bottom_top_blob.c;
50 int size = w * h;
51
52 __m256 _scale = _mm256_set1_ps(scale);
53
54 if (dims == 1)
55 {
56 #pragma omp parallel for num_threads(opt.num_threads)
57 for (int i = 0; i < w; i++)
58 {
59 float* ptr = (float*)bottom_top_blob + i * 8;
60 __m256 _p = _mm256_loadu_ps(ptr);
61 _p = _mm256_mul_ps(_p, _scale);
62 _mm256_storeu_ps(ptr, _p);
63 }
64 }
65
66 if (dims == 2)
67 {
68 #pragma omp parallel for num_threads(opt.num_threads)
69 for (int i = 0; i < h; i++)
70 {
71 float* ptr = bottom_top_blob.row(i);
72
73 for (int j = 0; j < w; j++)
74 {
75 __m256 _p = _mm256_loadu_ps(ptr);
76 _p = _mm256_mul_ps(_p, _scale);
77 _mm256_storeu_ps(ptr, _p);
78 ptr += 8;
79 }
80 }
81 }
82
83 if (dims == 3)
84 {
85 #pragma omp parallel for num_threads(opt.num_threads)
86 for (int q = 0; q < channels; q++)
87 {
88 float* ptr = bottom_top_blob.channel(q);
89
90 for (int i = 0; i < size; i++)
91 {
92 __m256 _p = _mm256_loadu_ps(ptr);
93 _p = _mm256_mul_ps(_p, _scale);
94 _mm256_storeu_ps(ptr, _p);
95 ptr += 8;
96 }
97 }
98 }
99
100 return 0;
101 }
102 #endif // __AVX__
103
104 if (elempack == 4)
105 {
106 int w = bottom_top_blob.w;
107 int h = bottom_top_blob.h;
108 int channels = bottom_top_blob.c;
109 int size = w * h;
110
111 __m128 _scale = _mm_set1_ps(scale);
112
113 if (dims == 1)
114 {
115 #pragma omp parallel for num_threads(opt.num_threads)
116 for (int i = 0; i < w; i++)
117 {
118 float* ptr = (float*)bottom_top_blob + i * 4;
119 __m128 _p = _mm_loadu_ps(ptr);
120 _p = _mm_mul_ps(_p, _scale);
121 _mm_storeu_ps(ptr, _p);
122 }
123 }
124
125 if (dims == 2)
126 {
127 #pragma omp parallel for num_threads(opt.num_threads)
128 for (int i = 0; i < h; i++)
129 {
130 float* ptr = bottom_top_blob.row(i);
131
132 for (int j = 0; j < w; j++)
133 {
134 __m128 _p = _mm_loadu_ps(ptr);
135 _p = _mm_mul_ps(_p, _scale);
136 _mm_storeu_ps(ptr, _p);
137 ptr += 4;
138 }
139 }
140 }
141
142 if (dims == 3)
143 {
144 #pragma omp parallel for num_threads(opt.num_threads)
145 for (int q = 0; q < channels; q++)
146 {
147 float* ptr = bottom_top_blob.channel(q);
148
149 for (int i = 0; i < size; i++)
150 {
151 __m128 _p = _mm_loadu_ps(ptr);
152 _p = _mm_mul_ps(_p, _scale);
153 _mm_storeu_ps(ptr, _p);
154 ptr += 4;
155 }
156 }
157 }
158
159 return 0;
160 }
161 #endif // __SSE2__
162
163 return Dropout::forward_inplace(bottom_top_blob, opt);
164 }
165
166 } // namespace ncnn
167