1 #include "chainerx/routines/normalization.h"
2
3 #include <algorithm>
4 #include <cstdint>
5 #include <memory>
6 #include <tuple>
7 #include <utility>
8
9 #include <absl/types/optional.h>
10
11 #include "chainerx/array.h"
12 #include "chainerx/axes.h"
13 #include "chainerx/backprop_mode.h"
14 #include "chainerx/backward_builder.h"
15 #include "chainerx/backward_context.h"
16 #include "chainerx/device.h"
17 #include "chainerx/dtype.h"
18 #include "chainerx/error.h"
19 #include "chainerx/graph.h"
20 #include "chainerx/kernels/normalization.h"
21 #include "chainerx/macro.h"
22 #include "chainerx/routines/arithmetic.h"
23 #include "chainerx/routines/creation.h"
24 #include "chainerx/routines/misc.h"
25 #include "chainerx/routines/routines_util.h"
26 #include "chainerx/routines/statistics.h"
27 #include "chainerx/routines/type_util.h"
28 #include "chainerx/scalar.h"
29 #include "chainerx/shape.h"
30
31 namespace chainerx {
32 namespace {
33
// Result of PreprocessBatchNorm(): the per-channel BatchNorm operands together with the normalized axes.
struct PreprocessBatchNormResult {
    // Arrays are reshaped if necessary
    // (to the reduced shape of the input x; PreprocessBatchNorm() asserts that no data copy occurs).
    Array gamma;
    Array beta;
    Array mean;
    Array var;
    // Normalized axes of x, as returned by internal::GetSortedAxes() (or {0} when no axis was given).
    Axes sorted_axis;
};
42
43 // Reshapes the array. If the shape is unchanged, an array with identical array body is returned. Note that chainerx::Reshape() returns
44 // a view with different array body if the shape is unchanged.
ReshapeOrIdentity(const Array & a,const Shape & shape)45 Array ReshapeOrIdentity(const Array& a, const Shape& shape) {
46 if (a.shape() == shape) {
47 return a;
48 }
49 return a.Reshape(shape);
50 }
51
CheckBatchNormSupportedKind(const Array & array)52 void CheckBatchNormSupportedKind(const Array& array) {
53 // BatchNorm only supports inputs of float kind.
54 if (GetKind(array.dtype()) != DtypeKind::kFloat) {
55 throw DtypeError{"BatchNorm only supports floating kind inputs."};
56 }
57 }
58
59 // Reshapes the input arrays (except x) as needed.
60 // Sorted axes is also returned.
PreprocessBatchNorm(const Array & x,const Array & gamma,const Array & beta,const Array & mean,const Array & var,const OptionalAxes & axis)61 PreprocessBatchNormResult PreprocessBatchNorm(
62 const Array& x, const Array& gamma, const Array& beta, const Array& mean, const Array& var, const OptionalAxes& axis) {
63 CheckBatchNormSupportedKind(x);
64 CheckBatchNormSupportedKind(gamma);
65 CheckBatchNormSupportedKind(beta);
66 CheckBatchNormSupportedKind(mean);
67 CheckBatchNormSupportedKind(var);
68
69 Axes sorted_axis = axis.has_value() ? internal::GetSortedAxes(*axis, x.ndim()) : Axes{0};
70
71 Shape reduced_shape = internal::ReduceShape(x.shape(), sorted_axis, true);
72 int64_t reduced_size = reduced_shape.GetTotalSize();
73
74 if (gamma.GetTotalSize() != reduced_size) {
75 throw DimensionError{
76 "Gamma must have the same size as the reduced input. Actual: ", gamma.GetTotalSize(), ". Expected: ", reduced_size, "."};
77 }
78 if (beta.GetTotalSize() != reduced_size) {
79 throw DimensionError{
80 "Beta must have the same size as the reduced input. Actual: ", beta.GetTotalSize(), ". Expected: ", reduced_size, "."};
81 }
82 if (mean.GetTotalSize() != reduced_size) {
83 throw DimensionError{
84 "Mean must have the same size as the reduced input. Actual: ", mean.GetTotalSize(), ". Expected: ", reduced_size, "."};
85 }
86 if (var.GetTotalSize() != reduced_size) {
87 throw DimensionError{
88 "Variance must have the same size as the reduced input. Actual: ", var.GetTotalSize(), ". Expected: ", reduced_size, "."};
89 }
90
91 Array gamma_reshaped = ReshapeOrIdentity(gamma, reduced_shape);
92 Array beta_reshaped = ReshapeOrIdentity(beta, reduced_shape);
93 Array mean_reshaped = ReshapeOrIdentity(mean, reduced_shape);
94 Array var_reshaped = ReshapeOrIdentity(var, reduced_shape);
95 CHAINERX_ASSERT(gamma_reshaped.data() == gamma.data()); // No data copy should occur
96 CHAINERX_ASSERT(beta_reshaped.data() == beta.data());
97 CHAINERX_ASSERT(mean_reshaped.data() == mean.data());
98 CHAINERX_ASSERT(var_reshaped.data() == var.data());
99
100 return {std::move(gamma_reshaped), std::move(beta_reshaped), std::move(mean_reshaped), std::move(var_reshaped), sorted_axis};
101 }
102
ArrayOrZeros(const absl::optional<Array> & array,const Array & zeros_template,Dtype dtype)103 Array ArrayOrZeros(const absl::optional<Array>& array, const Array& zeros_template, Dtype dtype) {
104 if (array.has_value()) {
105 if (array->dtype() == dtype) {
106 return *array;
107 }
108 return array->AsType(dtype);
109 }
110 return Zeros(zeros_template.shape(), dtype, zeros_template.device());
111 }
112
ApplyGenericBatchNorm(const Array & x,const Array & gamma,const Array & beta,const Array & mean,const Array & var,Scalar eps,const Axes & axis,Dtype interm_dtype,bool return_state,const absl::optional<Array> & out)113 std::tuple<Array, std::unique_ptr<BatchNormGradState>> ApplyGenericBatchNorm(
114 const Array& x,
115 const Array& gamma,
116 const Array& beta,
117 const Array& mean,
118 const Array& var,
119 Scalar eps,
120 const Axes& axis,
121 Dtype interm_dtype,
122 bool return_state,
123 const absl::optional<Array>& out) {
124 if (CHAINERX_DEBUG) {
125 Shape reduced_shape = internal::ReduceShape(x.shape(), axis, true);
126 CHAINERX_ASSERT(gamma.shape() == reduced_shape);
127 CHAINERX_ASSERT(beta.shape() == reduced_shape);
128
129 int64_t reduced_total_size = reduced_shape.GetTotalSize();
130 CHAINERX_ASSERT(mean.GetTotalSize() == reduced_total_size);
131 CHAINERX_ASSERT(var.GetTotalSize() == reduced_total_size);
132 }
133 // TODO(hvy): Implement and test the `out` argument.
134 if (out.has_value()) {
135 throw NotImplementedError{"Passing out as an argument is not yet supported."};
136 }
137
138 // TODO(hvy): Avoid `AsType` by passing dtype arguments to the following routines to minimize copies.
139 const Array& x_cast = x.AsType(interm_dtype, false);
140 const Array& gamma_cast = gamma.AsType(interm_dtype, false);
141 const Array& beta_cast = beta.AsType(interm_dtype, false);
142 Array mean_cast = mean.AsType(interm_dtype, false);
143 const Array& var_cast = var.AsType(interm_dtype, false);
144
145 Array inv_std = Reciprocal(Sqrt(var_cast + eps));
146 Array out_cast = (x_cast - mean_cast) * inv_std * gamma_cast + beta_cast;
147 const Array& actual_out = out_cast.dtype() == x.dtype() ? out_cast : out_cast.AsType(x.dtype());
148
149 std::unique_ptr<BatchNormGradState> state =
150 return_state ? std::make_unique<GenericBatchNormGradState>(std::move(mean_cast), std::move(inv_std), beta.dtype()) : nullptr;
151
152 return std::make_tuple(actual_out, std::move(state));
153 }
154
155 } // namespace
156
Call(const Array & x,const Array & gamma,const Array & beta,const Array & running_mean,const Array & running_var,Scalar eps,Scalar decay,const Axes & axis,bool return_state,const absl::optional<Array> & out)157 std::tuple<Array, std::unique_ptr<BatchNormGradState>> GenericBatchNormKernel::Call(
158 const Array& x,
159 const Array& gamma,
160 const Array& beta,
161 const Array& running_mean,
162 const Array& running_var,
163 Scalar eps,
164 Scalar decay,
165 const Axes& axis,
166 bool return_state,
167 const absl::optional<Array>& out) {
168 CHAINERX_ASSERT(internal::GetArrayBody(x)->nodes().empty());
169 CHAINERX_ASSERT(internal::GetArrayBody(gamma)->nodes().empty());
170 CHAINERX_ASSERT(internal::GetArrayBody(beta)->nodes().empty());
171 CHAINERX_ASSERT(GetKind(x.dtype()) == DtypeKind::kFloat);
172 CHAINERX_ASSERT(GetKind(gamma.dtype()) == DtypeKind::kFloat);
173 CHAINERX_ASSERT(GetKind(beta.dtype()) == DtypeKind::kFloat);
174 CHAINERX_ASSERT(GetKind(running_mean.dtype()) == DtypeKind::kFloat);
175 CHAINERX_ASSERT(GetKind(running_var.dtype()) == DtypeKind::kFloat);
176
177 // Compute the mean and variance of x with promoted dtype if the parameters have higher precisions.
178 Dtype interm_dtype = ResultType(x, gamma, beta);
179 const Array& x_cast = x.dtype() == interm_dtype ? x : x.AsType(interm_dtype);
180 Array x_mean = Mean(x_cast, axis, true);
181 Array x_var = Var(x_cast, axis, true);
182 std::tuple<Array, std::unique_ptr<BatchNormGradState>> result =
183 ApplyGenericBatchNorm(x, gamma, beta, x_mean, x_var, eps, axis, interm_dtype, return_state, out);
184
185 // Update running values.
186 // TODO(hvy): Avoid `AsType` when `IAdd` supports mixed dtypes.
187 Scalar inv_decay = Scalar{1.0 - static_cast<double>(decay)};
188 int64_t n = x.GetTotalSize() / gamma.GetTotalSize();
189 running_mean *= decay;
190 running_mean += (inv_decay * x_mean).AsType(running_mean.dtype(), false);
191 running_var *= decay;
192 running_var += (inv_decay * (static_cast<double>(n) / std::max(n - 1, int64_t{1})) * x_var).AsType(running_var.dtype(), false);
193
194 return result;
195 }
196
Call(const Array & x,const Array & gamma,const Array & gout,Scalar,const Axes & axis,const std::shared_ptr<BatchNormGradState> & state,const absl::optional<Array> & gx,const absl::optional<Array> & ggamma,const absl::optional<Array> & gbeta)197 std::tuple<Array, Array, Array> GenericBatchNormGradKernel::Call(
198 const Array& x,
199 const Array& gamma,
200 const Array& gout,
201 Scalar /*eps*/,
202 const Axes& axis,
203 const std::shared_ptr<BatchNormGradState>& state,
204 const absl::optional<Array>& gx,
205 const absl::optional<Array>& ggamma,
206 const absl::optional<Array>& gbeta) {
207 // TODO(hvy): Implement and test the `gx` argument.
208 if (gx.has_value()) {
209 throw NotImplementedError{"Passing gx as an argument is not yet supported."};
210 }
211 // TODO(hvy): Implement and test the `ggamma` argument.
212 if (ggamma.has_value()) {
213 throw NotImplementedError{"Passing ggamma as an argument is not yet supported."};
214 }
215 // TODO(hvy): Implement and test the `gbeta` argument.
216 if (gbeta.has_value()) {
217 throw NotImplementedError{"Passing gbeta as an argument is not yet supported."};
218 }
219
220 // TODO(hvy): Implement recomputation of x_mean and x_inv_std in case they are not given by the state.
221 CHAINERX_ASSERT(state != nullptr);
222 auto& generic_state = dynamic_cast<GenericBatchNormGradState&>(*state);
223 // x_mean and x_inv_std have promoted dtypes.
224 const Array& x_mean = generic_state.x_mean();
225 const Array& x_inv_std = generic_state.x_inv_std(); // Note: x_inv_std_ has the information of eps.
226 Dtype beta_dtype = generic_state.beta_dtype();
227
228 // TODO(hvy): Avoid `AsType`.
229 Dtype interm_dtype = x_mean.dtype();
230 int64_t n = x.GetTotalSize() / gamma.GetTotalSize();
231 double inv_n = 1.0 / n;
232 Array gout_cast = gout.AsType(interm_dtype, false);
233 Array x_hat = (x.AsType(interm_dtype, false) - x_mean) * x_inv_std;
234 Array actual_ggamma = (gout_cast * x_hat).Sum(axis, true);
235 Array actual_gbeta = gout_cast.Sum(axis, true);
236 Array actual_gx = (gamma.AsType(interm_dtype, false) * x_inv_std) * (gout_cast - (x_hat * actual_ggamma + actual_gbeta) * inv_n);
237
238 if (actual_gx.dtype() != x.dtype()) {
239 actual_gx = actual_gx.AsType(x.dtype());
240 }
241 if (actual_ggamma.dtype() != gamma.dtype()) {
242 actual_ggamma = actual_ggamma.AsType(gamma.dtype());
243 }
244 if (actual_gbeta.dtype() != beta_dtype) {
245 actual_gbeta = actual_gbeta.AsType(beta_dtype);
246 }
247
248 return std::make_tuple(std::move(actual_gx), std::move(actual_ggamma), std::move(actual_gbeta));
249 }
250
// Generic (backend-independent) batch normalization with the given (fixed) mean and var.
// The computation happens in the dtype promoted over all five inputs; the output has x's dtype.
// Passing `out` is not supported yet: ApplyGenericBatchNorm throws NotImplementedError for it.
Array GenericFixedBatchNormKernel::Call(
        const Array& x,
        const Array& gamma,
        const Array& beta,
        const Array& mean,
        const Array& var,
        Scalar eps,
        const Axes& axis,
        const absl::optional<Array>& out) {
    const Dtype promoted_dtype = ResultType(x, gamma, beta, mean, var);
    std::tuple<Array, std::unique_ptr<BatchNormGradState>> normalized =
            ApplyGenericBatchNorm(x, gamma, beta, mean, var, eps, axis, promoted_dtype, false, out);
    if (out.has_value()) {
        return *out;
    }
    return std::get<0>(normalized);
}
265
// Batch normalization in training mode.
//
// Forward:
// - x is normalized over `axis` (axis 0 when not given), then scaled by gamma and shifted by beta.
// - gamma, beta, running_mean and running_var must each have as many elements as the reduced shape of x. They are
//   reshaped to that shape without copying.
// - running_mean, running_var and `decay` are handed to BatchNormKernel. Presumably they are updated in place, as
//   GenericBatchNormKernel does.
//
// Backward:
// - The output is differentiable with respect to x, gamma and beta.
// - The backward pass is itself differentiable (double backprop) with respect to x, gamma and the output gradient.
Array BatchNorm(
        const Array& x,
        const Array& gamma,
        const Array& beta,
        const Array& running_mean,
        const Array& running_var,
        Scalar eps,
        Scalar decay,
        const OptionalAxes& axis) {
    // Preprocess inputs.
    PreprocessBatchNormResult result = PreprocessBatchNorm(x, gamma, beta, running_mean, running_var, axis);
    const Array& gamma_reshaped = result.gamma;
    const Array& beta_reshaped = result.beta;
    const Array& mean_reshaped = result.mean;
    const Array& var_reshaped = result.var;
    const Axes& sorted_axis = result.sorted_axis;

    Device& device = x.device();

    // Compute forward.
    // The kernel runs on grad-stopped inputs in no-backprop mode; the graph is connected by the BackwardBuilder below.
    Array out{};
    std::shared_ptr<BatchNormGradState> state{};
    {
        NoBackpropModeScope scope{};
        std::tie(out, state) = device.backend().CallKernel<BatchNormKernel>(
                x.AsGradStopped(),
                gamma_reshaped.AsGradStopped(),
                beta_reshaped.AsGradStopped(),
                mean_reshaped,
                var_reshaped,
                eps,
                decay,
                sorted_axis,
                true,
                absl::nullopt);
    }
    CHAINERX_ASSERT(state != nullptr);

    internal::MakeViewForForwardBackwardOutput(out);

    BackwardBuilder bb{"batch_norm", {x, gamma_reshaped, beta_reshaped}, {out}};
    if (BackwardBuilder::Target bt = bb.CreateTarget({0, 1, 2})) {
        bt.Define([state = std::move(state), x_tok = bb.RetainInput(0), gamma_tok = bb.RetainInput(1), eps, sorted_axis](
                          BackwardContext& bctx) {
            const Array& gout = *bctx.output_grad();
            const Array& x = bctx.GetRetainedInput(x_tok);
            const Array& gamma_reshaped = bctx.GetRetainedInput(gamma_tok);

            Device& device = x.device();

            // Compute backward.
            Array gx{};
            Array ggamma{};
            Array gbeta{};
            {
                NoBackpropModeScope scope{};
                std::tie(gx, ggamma, gbeta) = device.backend().CallKernel<BatchNormGradKernel>(
                        x, gamma_reshaped, gout, eps, sorted_axis, state, absl::nullopt, absl::nullopt, absl::nullopt);
            }
            CHAINERX_ASSERT(internal::GetArrayBody(gx)->nodes().empty());
            CHAINERX_ASSERT(internal::GetArrayBody(ggamma)->nodes().empty());
            CHAINERX_ASSERT(internal::GetArrayBody(gbeta)->nodes().empty());

            if (bctx.next_required()) {
                // Double backprop: differentiate (gx, ggamma, gbeta) with respect to (x, gamma, gout).
                BackwardBuilder bb2{"batch_norm_backward", {x, gamma_reshaped, gout}, {gx, ggamma, gbeta}};
                if (BackwardBuilder::Target bt2 = bb2.CreateTarget({0, 1, 2})) {
                    bt2.Define([x_tok = bb2.RetainInput(0),
                                gamma2_tok = bb2.RetainInput(1),
                                gout_tok = bb2.RetainInput(2),
                                eps,
                                sorted_axis,
                                gx_tok = bb2.RetainOutput(0),
                                ggamma_tok = bb2.RetainOutput(1)](BackwardContext& bctx2) {
                        const Array& x_retained = bctx2.GetRetainedInput(x_tok);
                        const Array& gamma_reshaped_retained = bctx2.GetRetainedInput(gamma2_tok);
                        const Array& gout_retained = bctx2.GetRetainedInput(gout_tok);

                        // TODO(hvy): Avoid AsType by passing dtype arguments to Mean, Var, etc. to minimize copies.
                        Dtype interm_dtype = ResultType(gout_retained, x_retained, gamma_reshaped_retained);
                        const Array& x = x_retained.AsType(interm_dtype, false);
                        const Array& gamma_reshaped = gamma_reshaped_retained.AsType(interm_dtype, false);
                        const Array& gout = gout_retained.AsType(interm_dtype, false);

                        // Missing output gradients are treated as zeros.
                        Array ggx = ArrayOrZeros(bctx2.output_grad(0), x, interm_dtype);
                        Array gggamma = ArrayOrZeros(bctx2.output_grad(1), gamma_reshaped, interm_dtype);
                        Array ggbeta = ArrayOrZeros(bctx2.output_grad(2), gamma_reshaped, interm_dtype);

                        // The statistics are recomputed from x here rather than taken from the forward state.
                        const Array& x_mean = Mean(x, sorted_axis, true).AsType(interm_dtype, false);
                        const Array& x_var = Var(x, sorted_axis, true).AsType(interm_dtype, false);
                        const Array& x_inv_std = Reciprocal(Sqrt(x_var + eps)).AsType(interm_dtype, false);

                        const Array& gx = bctx2.GetRetainedOutput(gx_tok).AsType(interm_dtype, false);
                        const Array& ggamma = bctx2.GetRetainedOutput(ggamma_tok).AsType(interm_dtype, false);

                        // Auxiliary values.
                        // n is the number of elements reduced per channel: the product of x's extents along the
                        // normalized axes. Unlike x.GetTotalSize() / gamma.GetTotalSize(), this cannot divide by
                        // zero when gamma is empty.
                        int64_t n = 1;
                        for (int8_t i : sorted_axis) {
                            n *= x.shape()[i];
                        }
                        double inv_n = 1.0 / n;
                        Array r = (gx * ggx).Sum(sorted_axis, true);
                        Array coeff = gamma_reshaped * x_inv_std;
                        Array coeff_m = coeff * inv_n;
                        Array x_hat = (x - x_mean) * x_inv_std;

                        Array gggamma2 = gggamma - coeff_m * (x_hat * ggx).Sum(sorted_axis, true);
                        Array ggbeta2 = ggbeta - coeff_m * ggx.Sum(sorted_axis, true);

                        Array gx_hat2 = gggamma2 * gout - coeff_m * ggamma * ggx;
                        Array gstd2 = -x_inv_std * (r + (x_hat * gx_hat2).Sum(sorted_axis, true));
                        Array gmean2 = -x_inv_std * gx_hat2.Sum(sorted_axis, true);
                        Array gx2 = x_inv_std * gx_hat2 + inv_n * (gmean2 + x_hat * gstd2);
                        Array ggout2 = gggamma2 * x_hat + ggbeta2 + coeff * ggx;

                        Array ggamma2 = r / gamma_reshaped;

                        // Convert the gradients back to the dtypes of the corresponding inputs.
                        if (gx2.dtype() != x_retained.dtype()) {
                            gx2 = gx2.AsType(x_retained.dtype());
                        }
                        if (ggamma2.dtype() != gamma_reshaped_retained.dtype()) {
                            ggamma2 = ggamma2.AsType(gamma_reshaped_retained.dtype());
                        }

                        if (ggout2.dtype() != gout_retained.dtype()) {
                            ggout2 = ggout2.AsType(gout_retained.dtype());
                        }

                        bctx2.input_grad(0) = std::move(gx2);
                        bctx2.input_grad(1) = std::move(ggamma2);
                        bctx2.input_grad(2) = std::move(ggout2);
                    });
                }
                bb2.Finalize();
            }

            // TODO(niboshi): Assign at once
            bctx.input_grad(0) = std::move(gx);
            bctx.input_grad(1) = std::move(ggamma);
            bctx.input_grad(2) = std::move(gbeta);
        });
    }
    bb.Finalize();

    return out;
}
413
// Batch normalization in inference mode.
//
// x is normalized over `axis` with the given (fixed) mean and var instead of batch statistics, then gamma and beta are
// applied. All inputs are grad-stopped and the kernel runs in no-backprop mode, and no backward function is defined.
// The returned array is therefore not connected to the computational graph.
Array FixedBatchNorm(
        const Array& x, const Array& gamma, const Array& beta, const Array& mean, const Array& var, Scalar eps, const OptionalAxes& axis) {
    PreprocessBatchNormResult preprocessed =
            PreprocessBatchNorm(x, gamma.AsGradStopped(), beta.AsGradStopped(), mean.AsGradStopped(), var.AsGradStopped(), axis);

    NoBackpropModeScope no_backprop_scope{};
    return x.device().backend().CallKernel<FixedBatchNormKernel>(
            x.AsGradStopped(),
            preprocessed.gamma,
            preprocessed.beta,
            preprocessed.mean,
            preprocessed.var,
            eps,
            preprocessed.sorted_axis,
            absl::nullopt);
}
425
426 } // namespace chainerx
427