1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "dropout_x86.h"
16 
17 #if __SSE2__
18 #include <emmintrin.h>
19 #if __AVX__
20 #include <immintrin.h>
21 #endif // __AVX__
22 #endif // __SSE2__
23 
24 namespace ncnn {
25 
Dropout_x86::Dropout_x86()
{
#if __SSE2__
    // Advertise packed-layout support: forward_inplace handles elempack 4
    // (SSE2) and elempack 8 (AVX) directly, so no unpacking is needed.
    support_packing = true;
#endif // __SSE2__
}
32 
forward_inplace(Mat & bottom_top_blob,const Option & opt) const33 int Dropout_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
34 {
35     if (scale == 1.f)
36     {
37         return 0;
38     }
39 
40 #if __SSE2__
41     int dims = bottom_top_blob.dims;
42     int elempack = bottom_top_blob.elempack;
43 
44 #if __AVX__
45     if (elempack == 8)
46     {
47         int w = bottom_top_blob.w;
48         int h = bottom_top_blob.h;
49         int channels = bottom_top_blob.c;
50         int size = w * h;
51 
52         __m256 _scale = _mm256_set1_ps(scale);
53 
54         if (dims == 1)
55         {
56             #pragma omp parallel for num_threads(opt.num_threads)
57             for (int i = 0; i < w; i++)
58             {
59                 float* ptr = (float*)bottom_top_blob + i * 8;
60                 __m256 _p = _mm256_loadu_ps(ptr);
61                 _p = _mm256_mul_ps(_p, _scale);
62                 _mm256_storeu_ps(ptr, _p);
63             }
64         }
65 
66         if (dims == 2)
67         {
68             #pragma omp parallel for num_threads(opt.num_threads)
69             for (int i = 0; i < h; i++)
70             {
71                 float* ptr = bottom_top_blob.row(i);
72 
73                 for (int j = 0; j < w; j++)
74                 {
75                     __m256 _p = _mm256_loadu_ps(ptr);
76                     _p = _mm256_mul_ps(_p, _scale);
77                     _mm256_storeu_ps(ptr, _p);
78                     ptr += 8;
79                 }
80             }
81         }
82 
83         if (dims == 3)
84         {
85             #pragma omp parallel for num_threads(opt.num_threads)
86             for (int q = 0; q < channels; q++)
87             {
88                 float* ptr = bottom_top_blob.channel(q);
89 
90                 for (int i = 0; i < size; i++)
91                 {
92                     __m256 _p = _mm256_loadu_ps(ptr);
93                     _p = _mm256_mul_ps(_p, _scale);
94                     _mm256_storeu_ps(ptr, _p);
95                     ptr += 8;
96                 }
97             }
98         }
99 
100         return 0;
101     }
102 #endif // __AVX__
103 
104     if (elempack == 4)
105     {
106         int w = bottom_top_blob.w;
107         int h = bottom_top_blob.h;
108         int channels = bottom_top_blob.c;
109         int size = w * h;
110 
111         __m128 _scale = _mm_set1_ps(scale);
112 
113         if (dims == 1)
114         {
115             #pragma omp parallel for num_threads(opt.num_threads)
116             for (int i = 0; i < w; i++)
117             {
118                 float* ptr = (float*)bottom_top_blob + i * 4;
119                 __m128 _p = _mm_loadu_ps(ptr);
120                 _p = _mm_mul_ps(_p, _scale);
121                 _mm_storeu_ps(ptr, _p);
122             }
123         }
124 
125         if (dims == 2)
126         {
127             #pragma omp parallel for num_threads(opt.num_threads)
128             for (int i = 0; i < h; i++)
129             {
130                 float* ptr = bottom_top_blob.row(i);
131 
132                 for (int j = 0; j < w; j++)
133                 {
134                     __m128 _p = _mm_loadu_ps(ptr);
135                     _p = _mm_mul_ps(_p, _scale);
136                     _mm_storeu_ps(ptr, _p);
137                     ptr += 4;
138                 }
139             }
140         }
141 
142         if (dims == 3)
143         {
144             #pragma omp parallel for num_threads(opt.num_threads)
145             for (int q = 0; q < channels; q++)
146             {
147                 float* ptr = bottom_top_blob.channel(q);
148 
149                 for (int i = 0; i < size; i++)
150                 {
151                     __m128 _p = _mm_loadu_ps(ptr);
152                     _p = _mm_mul_ps(_p, _scale);
153                     _mm_storeu_ps(ptr, _p);
154                     ptr += 4;
155                 }
156             }
157         }
158 
159         return 0;
160     }
161 #endif // __SSE2__
162 
163     return Dropout::forward_inplace(bottom_top_blob, opt);
164 }
165 
166 } // namespace ncnn
167