1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
15 #include "mvn.h"
16
17 #include <math.h>
18
19 namespace ncnn {
20
MVN()21 MVN::MVN()
22 {
23 one_blob_only = true;
24 support_inplace = false;
25 }
26
load_param(const ParamDict & pd)27 int MVN::load_param(const ParamDict& pd)
28 {
29 normalize_variance = pd.get(0, 0);
30 across_channels = pd.get(1, 0);
31 eps = pd.get(2, 0.0001f);
32
33 return 0;
34 }
35
forward(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const36 int MVN::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
37 {
38 int w = bottom_blob.w;
39 int h = bottom_blob.h;
40 int channels = bottom_blob.c;
41 size_t elemsize = bottom_blob.elemsize;
42 int size = w * h;
43
44 top_blob.create(w, h, channels, elemsize, opt.blob_allocator);
45 if (top_blob.empty())
46 return -100;
47
48 // prepare sum per channel
49 Mat sum(channels, elemsize, opt.workspace_allocator);
50 if (sum.empty())
51 return -100;
52
53 #pragma omp parallel for num_threads(opt.num_threads)
54 for (int q = 0; q < channels; q++)
55 {
56 const float* ptr = bottom_blob.channel(q);
57
58 float s = 0.f;
59 for (int i = 0; i < size; i++)
60 {
61 s += ptr[i];
62 }
63
64 sum[q] = s;
65 }
66
67 if (across_channels)
68 {
69 // compute mean across channels
70 float mean = 0.f;
71 for (int q = 0; q < channels; q++)
72 {
73 mean += sum[q];
74 }
75 mean = mean / (channels * size);
76
77 // subtract mean
78 #pragma omp parallel for num_threads(opt.num_threads)
79 for (int q = 0; q < channels; q++)
80 {
81 const float* ptr = bottom_blob.channel(q);
82 float* outptr = top_blob.channel(q);
83
84 for (int i = 0; i < size; i++)
85 {
86 outptr[i] = ptr[i] - mean;
87 }
88 }
89 }
90 else
91 {
92 // subtract mean
93 #pragma omp parallel for num_threads(opt.num_threads)
94 for (int q = 0; q < channels; q++)
95 {
96 const float* ptr = bottom_blob.channel(q);
97 float* outptr = top_blob.channel(q);
98 float mean = sum[q] / size;
99
100 for (int i = 0; i < size; i++)
101 {
102 outptr[i] = ptr[i] - mean;
103 }
104 }
105 }
106
107 if (normalize_variance)
108 {
109 // prepare squared sum per channel
110 Mat sqsum(channels, elemsize, opt.workspace_allocator);
111 if (sqsum.empty())
112 return -100;
113
114 #pragma omp parallel for num_threads(opt.num_threads)
115 for (int q = 0; q < channels; q++)
116 {
117 const float* ptr = top_blob.channel(q);
118
119 float s = 0.f;
120 for (int i = 0; i < size; i++)
121 {
122 s += ptr[i] * ptr[i];
123 }
124
125 sqsum[q] = s;
126 }
127
128 if (across_channels)
129 {
130 // compute squared mean across channels
131 float sqmean = 0.f;
132 for (int q = 0; q < channels; q++)
133 {
134 sqmean += sqsum[q];
135 }
136 sqmean = sqmean / (channels * size);
137
138 // normalize variance
139 float norm_var = static_cast<float>(sqrt(sqmean) + eps);
140 float norm_var_inv = 1.f / norm_var;
141
142 // apply normalize_variance
143 #pragma omp parallel for num_threads(opt.num_threads)
144 for (int q = 0; q < channels; q++)
145 {
146 float* outptr = top_blob.channel(q);
147
148 for (int i = 0; i < size; i++)
149 {
150 outptr[i] = outptr[i] * norm_var_inv;
151 }
152 }
153 }
154 else
155 {
156 // apply normalize_variance
157 #pragma omp parallel for num_threads(opt.num_threads)
158 for (int q = 0; q < channels; q++)
159 {
160 float* outptr = top_blob.channel(q);
161 float sqmean = sqsum[q] / size;
162 float norm_var = static_cast<float>(sqrt(sqmean) + eps);
163 float norm_var_inv = 1.f / norm_var;
164
165 for (int i = 0; i < size; i++)
166 {
167 outptr[i] = outptr[i] * norm_var_inv;
168 }
169 }
170 }
171 }
172
173 return 0;
174 }
175
176 } // namespace ncnn
177