#include "chainerx/routines/normalization.h"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <tuple>
#include <utility>

#include <absl/types/optional.h>

#include "chainerx/array.h"
#include "chainerx/axes.h"
#include "chainerx/backprop_mode.h"
#include "chainerx/backward_builder.h"
#include "chainerx/backward_context.h"
#include "chainerx/device.h"
#include "chainerx/dtype.h"
#include "chainerx/error.h"
#include "chainerx/graph.h"
#include "chainerx/kernels/normalization.h"
#include "chainerx/macro.h"
#include "chainerx/routines/arithmetic.h"
#include "chainerx/routines/creation.h"
#include "chainerx/routines/misc.h"
#include "chainerx/routines/routines_util.h"
#include "chainerx/routines/statistics.h"
#include "chainerx/routines/type_util.h"
#include "chainerx/scalar.h"
#include "chainerx/shape.h"

namespace chainerx {
namespace {

struct PreprocessBatchNormResult {
    // Arrays are reshaped if necessary
    Array gamma;
    Array beta;
    Array mean;
    Array var;
    Axes sorted_axis;
};

// Reshapes the array. If the shape is unchanged, an array with identical array body is returned. Note that chainerx::Reshape() returns
// a view with a different array body even if the shape is unchanged.
Array ReshapeOrIdentity(const Array& a, const Shape& shape) {
    if (a.shape() == shape) {
        return a;
    }
    return a.Reshape(shape);
}

void CheckBatchNormSupportedKind(const Array& array) {
    // BatchNorm only supports inputs of float kind.
    if (GetKind(array.dtype()) != DtypeKind::kFloat) {
        throw DtypeError{"BatchNorm only supports floating kind inputs."};
    }
}

// Reshapes the input arrays (except x) as needed.
// The sorted axes are also returned.
PreprocessBatchNormResult PreprocessBatchNorm(
        const Array& x, const Array& gamma, const Array& beta, const Array& mean, const Array& var, const OptionalAxes& axis) {
    CheckBatchNormSupportedKind(x);
    CheckBatchNormSupportedKind(gamma);
    CheckBatchNormSupportedKind(beta);
    CheckBatchNormSupportedKind(mean);
    CheckBatchNormSupportedKind(var);

    Axes sorted_axis = axis.has_value() ? internal::GetSortedAxes(*axis, x.ndim()) : Axes{0};

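    // The parameter arrays must have the same number of elements as the reduced input; e.g. for an NCHW
    // input normalized over axes {0, 2, 3}, the reduced shape is (1, C, 1, 1).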
    Shape reduced_shape = internal::ReduceShape(x.shape(), sorted_axis, true);
    int64_t reduced_size = reduced_shape.GetTotalSize();

    if (gamma.GetTotalSize() != reduced_size) {
        throw DimensionError{
                "Gamma must have the same size as the reduced input. Actual: ", gamma.GetTotalSize(), ". Expected: ", reduced_size, "."};
    }
    if (beta.GetTotalSize() != reduced_size) {
        throw DimensionError{
                "Beta must have the same size as the reduced input. Actual: ", beta.GetTotalSize(), ". Expected: ", reduced_size, "."};
    }
    if (mean.GetTotalSize() != reduced_size) {
        throw DimensionError{
                "Mean must have the same size as the reduced input. Actual: ", mean.GetTotalSize(), ". Expected: ", reduced_size, "."};
    }
    if (var.GetTotalSize() != reduced_size) {
        throw DimensionError{
                "Variance must have the same size as the reduced input. Actual: ", var.GetTotalSize(), ". Expected: ", reduced_size, "."};
    }

    Array gamma_reshaped = ReshapeOrIdentity(gamma, reduced_shape);
    Array beta_reshaped = ReshapeOrIdentity(beta, reduced_shape);
    Array mean_reshaped = ReshapeOrIdentity(mean, reduced_shape);
    Array var_reshaped = ReshapeOrIdentity(var, reduced_shape);
    CHAINERX_ASSERT(gamma_reshaped.data() == gamma.data());  // No data copy should occur
    CHAINERX_ASSERT(beta_reshaped.data() == beta.data());
    CHAINERX_ASSERT(mean_reshaped.data() == mean.data());
    CHAINERX_ASSERT(var_reshaped.data() == var.data());

    return {std::move(gamma_reshaped), std::move(beta_reshaped), std::move(mean_reshaped), std::move(var_reshaped), sorted_axis};
}

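// Returns `array` cast to `dtype` if it is given; otherwise returns zeros with the shape and device of
// `zeros_template`. Used for output gradients that may be absent (absl::nullopt).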
Array ArrayOrZeros(const absl::optional<Array>& array, const Array& zeros_template, Dtype dtype) {
    if (array.has_value()) {
        if (array->dtype() == dtype) {
            return *array;
        }
        return array->AsType(dtype);
    }
    return Zeros(zeros_template.shape(), dtype, zeros_template.device());
}

std::tuple<Array, std::unique_ptr<BatchNormGradState>> ApplyGenericBatchNorm(
        const Array& x,
        const Array& gamma,
        const Array& beta,
        const Array& mean,
        const Array& var,
        Scalar eps,
        const Axes& axis,
        Dtype interm_dtype,
        bool return_state,
        const absl::optional<Array>& out) {
    if (CHAINERX_DEBUG) {
        Shape reduced_shape = internal::ReduceShape(x.shape(), axis, true);
        CHAINERX_ASSERT(gamma.shape() == reduced_shape);
        CHAINERX_ASSERT(beta.shape() == reduced_shape);

        int64_t reduced_total_size = reduced_shape.GetTotalSize();
        CHAINERX_ASSERT(mean.GetTotalSize() == reduced_total_size);
        CHAINERX_ASSERT(var.GetTotalSize() == reduced_total_size);
    }
    // TODO(hvy): Implement and test the `out` argument.
    if (out.has_value()) {
        throw NotImplementedError{"Passing out as an argument is not yet supported."};
    }

    // TODO(hvy): Avoid `AsType` by passing dtype arguments to the following routines to minimize copies.
    const Array& x_cast = x.AsType(interm_dtype, false);
    const Array& gamma_cast = gamma.AsType(interm_dtype, false);
    const Array& beta_cast = beta.AsType(interm_dtype, false);
    Array mean_cast = mean.AsType(interm_dtype, false);
    const Array& var_cast = var.AsType(interm_dtype, false);

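    // Normalize: out = gamma * (x - mean) / sqrt(var + eps) + beta, computed in the intermediate dtype
    // and cast back to the dtype of x if necessary.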
    Array inv_std = Reciprocal(Sqrt(var_cast + eps));
    Array out_cast = (x_cast - mean_cast) * inv_std * gamma_cast + beta_cast;
    const Array& actual_out = out_cast.dtype() == x.dtype() ? out_cast : out_cast.AsType(x.dtype());

    std::unique_ptr<BatchNormGradState> state =
            return_state ? std::make_unique<GenericBatchNormGradState>(std::move(mean_cast), std::move(inv_std), beta.dtype()) : nullptr;

    return std::make_tuple(actual_out, std::move(state));
}

}  // namespace

std::tuple<Array, std::unique_ptr<BatchNormGradState>> GenericBatchNormKernel::Call(
        const Array& x,
        const Array& gamma,
        const Array& beta,
        const Array& running_mean,
        const Array& running_var,
        Scalar eps,
        Scalar decay,
        const Axes& axis,
        bool return_state,
        const absl::optional<Array>& out) {
    CHAINERX_ASSERT(internal::GetArrayBody(x)->nodes().empty());
    CHAINERX_ASSERT(internal::GetArrayBody(gamma)->nodes().empty());
    CHAINERX_ASSERT(internal::GetArrayBody(beta)->nodes().empty());
    CHAINERX_ASSERT(GetKind(x.dtype()) == DtypeKind::kFloat);
    CHAINERX_ASSERT(GetKind(gamma.dtype()) == DtypeKind::kFloat);
    CHAINERX_ASSERT(GetKind(beta.dtype()) == DtypeKind::kFloat);
    CHAINERX_ASSERT(GetKind(running_mean.dtype()) == DtypeKind::kFloat);
    CHAINERX_ASSERT(GetKind(running_var.dtype()) == DtypeKind::kFloat);

    // Compute the mean and variance of x with promoted dtype if the parameters have higher precisions.
    Dtype interm_dtype = ResultType(x, gamma, beta);
    const Array& x_cast = x.dtype() == interm_dtype ? x : x.AsType(interm_dtype);
    Array x_mean = Mean(x_cast, axis, true);
    Array x_var = Var(x_cast, axis, true);
    std::tuple<Array, std::unique_ptr<BatchNormGradState>> result =
            ApplyGenericBatchNorm(x, gamma, beta, x_mean, x_var, eps, axis, interm_dtype, return_state, out);

    // Update running values.
    // TODO(hvy): Avoid `AsType` when `IAdd` supports mixed dtypes.
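    // running_stat <- decay * running_stat + (1 - decay) * batch_stat, where n is the number of elements
    // reduced per parameter element. The batch variance is the biased (1/n) estimator, so it is scaled by
    // n / max(n - 1, 1) to accumulate an unbiased estimate into running_var.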
    Scalar inv_decay = Scalar{1.0 - static_cast<double>(decay)};
    int64_t n = x.GetTotalSize() / gamma.GetTotalSize();
    running_mean *= decay;
    running_mean += (inv_decay * x_mean).AsType(running_mean.dtype(), false);
    running_var *= decay;
    running_var += (inv_decay * (static_cast<double>(n) / std::max(n - 1, int64_t{1})) * x_var).AsType(running_var.dtype(), false);

    return result;
}

std::tuple<Array, Array, Array> GenericBatchNormGradKernel::Call(
        const Array& x,
        const Array& gamma,
        const Array& gout,
        Scalar /*eps*/,
        const Axes& axis,
        const std::shared_ptr<BatchNormGradState>& state,
        const absl::optional<Array>& gx,
        const absl::optional<Array>& ggamma,
        const absl::optional<Array>& gbeta) {
    // TODO(hvy): Implement and test the `gx` argument.
    if (gx.has_value()) {
        throw NotImplementedError{"Passing gx as an argument is not yet supported."};
    }
    // TODO(hvy): Implement and test the `ggamma` argument.
    if (ggamma.has_value()) {
        throw NotImplementedError{"Passing ggamma as an argument is not yet supported."};
    }
    // TODO(hvy): Implement and test the `gbeta` argument.
    if (gbeta.has_value()) {
        throw NotImplementedError{"Passing gbeta as an argument is not yet supported."};
    }

    // TODO(hvy): Implement recomputation of x_mean and x_inv_std in case they are not given by the state.
    CHAINERX_ASSERT(state != nullptr);
    auto& generic_state = dynamic_cast<GenericBatchNormGradState&>(*state);
    // x_mean and x_inv_std have promoted dtypes.
    const Array& x_mean = generic_state.x_mean();
    const Array& x_inv_std = generic_state.x_inv_std();  // Note: x_inv_std_ already incorporates eps.
    Dtype beta_dtype = generic_state.beta_dtype();

    // TODO(hvy): Avoid `AsType`.
    Dtype interm_dtype = x_mean.dtype();
    int64_t n = x.GetTotalSize() / gamma.GetTotalSize();
    double inv_n = 1.0 / n;
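    // Standard batch normalization backward:
    //   x_hat  = (x - mean) * inv_std
    //   ggamma = sum(gout * x_hat)
    //   gbeta  = sum(gout)
    //   gx     = gamma * inv_std * (gout - (x_hat * ggamma + gbeta) / n)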
    Array gout_cast = gout.AsType(interm_dtype, false);
    Array x_hat = (x.AsType(interm_dtype, false) - x_mean) * x_inv_std;
    Array actual_ggamma = (gout_cast * x_hat).Sum(axis, true);
    Array actual_gbeta = gout_cast.Sum(axis, true);
    Array actual_gx = (gamma.AsType(interm_dtype, false) * x_inv_std) * (gout_cast - (x_hat * actual_ggamma + actual_gbeta) * inv_n);

    if (actual_gx.dtype() != x.dtype()) {
        actual_gx = actual_gx.AsType(x.dtype());
    }
    if (actual_ggamma.dtype() != gamma.dtype()) {
        actual_ggamma = actual_ggamma.AsType(gamma.dtype());
    }
    if (actual_gbeta.dtype() != beta_dtype) {
        actual_gbeta = actual_gbeta.AsType(beta_dtype);
    }

    return std::make_tuple(std::move(actual_gx), std::move(actual_ggamma), std::move(actual_gbeta));
}

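// Inference-mode batch normalization: normalizes with the given fixed mean and variance and does not
// return a gradient state.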
Array GenericFixedBatchNormKernel::Call(
        const Array& x,
        const Array& gamma,
        const Array& beta,
        const Array& mean,
        const Array& var,
        Scalar eps,
        const Axes& axis,
        const absl::optional<Array>& out) {
    Dtype interm_dtype = ResultType(x, gamma, beta, mean, var);
    std::tuple<Array, std::unique_ptr<BatchNormGradState>> result =
            ApplyGenericBatchNorm(x, gamma, beta, mean, var, eps, axis, interm_dtype, false, out);
    return out.has_value() ? *out : std::get<0>(result);
}

Array BatchNorm(
        const Array& x,
        const Array& gamma,
        const Array& beta,
        const Array& running_mean,
        const Array& running_var,
        Scalar eps,
        Scalar decay,
        const OptionalAxes& axis) {
    // Preprocess inputs.
    PreprocessBatchNormResult result = PreprocessBatchNorm(x, gamma, beta, running_mean, running_var, axis);
    const Array& gamma_reshaped = result.gamma;
    const Array& beta_reshaped = result.beta;
    const Array& mean_reshaped = result.mean;
    const Array& var_reshaped = result.var;
    const Axes& sorted_axis = result.sorted_axis;

    Device& device = x.device();

    // Compute forward.
    Array out{};
    std::shared_ptr<BatchNormGradState> state{};
    {
        NoBackpropModeScope scope{};
        std::tie(out, state) = device.backend().CallKernel<BatchNormKernel>(
                x.AsGradStopped(),
                gamma_reshaped.AsGradStopped(),
                beta_reshaped.AsGradStopped(),
                mean_reshaped,
                var_reshaped,
                eps,
                decay,
                sorted_axis,
                true,
                absl::nullopt);
    }
    CHAINERX_ASSERT(state != nullptr);

    internal::MakeViewForForwardBackwardOutput(out);

    BackwardBuilder bb{"batch_norm", {x, gamma_reshaped, beta_reshaped}, {out}};
    if (BackwardBuilder::Target bt = bb.CreateTarget({0, 1, 2})) {
        bt.Define([state = std::move(state),
                   x_tok = bb.RetainInput(0),
                   gamma_tok = bb.RetainInput(1),
                   eps,
                   sorted_axis,
                   beta_shape = beta_reshaped.shape(),
                   beta_dtype = beta_reshaped.dtype()](BackwardContext& bctx) {
            const Array& gout = *bctx.output_grad();
            const Array& x = bctx.GetRetainedInput(x_tok);
            const Array& gamma_reshaped = bctx.GetRetainedInput(gamma_tok);

            Device& device = x.device();

            // Compute backward.
            Array gx{};
            Array ggamma{};
            Array gbeta{};
            {
                NoBackpropModeScope scope{};
                std::tie(gx, ggamma, gbeta) = device.backend().CallKernel<BatchNormGradKernel>(
                        x, gamma_reshaped, gout, eps, sorted_axis, state, absl::nullopt, absl::nullopt, absl::nullopt);
            }
            CHAINERX_ASSERT(internal::GetArrayBody(gx)->nodes().empty());
            CHAINERX_ASSERT(internal::GetArrayBody(ggamma)->nodes().empty());
            CHAINERX_ASSERT(internal::GetArrayBody(gbeta)->nodes().empty());

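            // Build the double-backprop graph only when higher-order gradients may be required.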
            if (bctx.next_required()) {
                BackwardBuilder bb2{"batch_norm_backward", {x, gamma_reshaped, gout}, {gx, ggamma, gbeta}};
                if (BackwardBuilder::Target bt2 = bb2.CreateTarget({0, 1, 2})) {
                    bt2.Define([x_tok = bb2.RetainInput(0),
                                gamma2_tok = bb2.RetainInput(1),
                                gout_tok = bb2.RetainInput(2),
                                eps,
                                sorted_axis,
                                gx_tok = bb2.RetainOutput(0),
                                ggamma_tok = bb2.RetainOutput(1)](BackwardContext& bctx2) {
                        const Array& x_retained = bctx2.GetRetainedInput(x_tok);
                        const Array& gamma_reshaped_retained = bctx2.GetRetainedInput(gamma2_tok);
                        const Array& gout_retained = bctx2.GetRetainedInput(gout_tok);

                        // TODO(hvy): Avoid AsType by passing dtype arguments to Mean, Var, etc. to minimize copies.
                        Dtype interm_dtype = ResultType(gout_retained, x_retained, gamma_reshaped_retained);
                        const Array& x = x_retained.AsType(interm_dtype, false);
                        const Array& gamma_reshaped = gamma_reshaped_retained.AsType(interm_dtype, false);
                        const Array& gout = gout_retained.AsType(interm_dtype, false);

                        Array ggx = ArrayOrZeros(bctx2.output_grad(0), x, interm_dtype);
                        Array gggamma = ArrayOrZeros(bctx2.output_grad(1), gamma_reshaped, interm_dtype);
                        Array ggbeta = ArrayOrZeros(bctx2.output_grad(2), gamma_reshaped, interm_dtype);

                        const Array& x_mean = Mean(x, sorted_axis, true).AsType(interm_dtype, false);
                        const Array& x_var = Var(x, sorted_axis, true).AsType(interm_dtype, false);
                        const Array& x_inv_std = Reciprocal(Sqrt(x_var + eps)).AsType(interm_dtype, false);

                        const Array& gx = bctx2.GetRetainedOutput(gx_tok).AsType(interm_dtype, false);
                        const Array& ggamma = bctx2.GetRetainedOutput(ggamma_tok).AsType(interm_dtype, false);

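                        // Double backward of batch normalization: given the upstream gradients ggx, gggamma
                        // and ggbeta of (gx, ggamma, gbeta), compute the gradients with respect to x, gamma
                        // and gout.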
                        // Auxiliary values
                        int64_t n = x.GetTotalSize() / gamma_reshaped.GetTotalSize();
                        double inv_n = 1.0 / n;
                        Array r = (gx * ggx).Sum(sorted_axis, true);
                        Array coeff = gamma_reshaped * x_inv_std;
                        Array coeff_m = coeff * inv_n;
                        Array x_hat = (x - x_mean) * x_inv_std;

                        Array gggamma2 = gggamma - coeff_m * (x_hat * ggx).Sum(sorted_axis, true);
                        Array ggbeta2 = ggbeta - coeff_m * ggx.Sum(sorted_axis, true);

                        Array gx_hat2 = gggamma2 * gout - coeff_m * ggamma * ggx;
                        Array gstd2 = -x_inv_std * (r + (x_hat * gx_hat2).Sum(sorted_axis, true));
                        Array gmean2 = -x_inv_std * gx_hat2.Sum(sorted_axis, true);
                        Array gx2 = x_inv_std * gx_hat2 + inv_n * (gmean2 + x_hat * gstd2);
                        Array ggout2 = gggamma2 * x_hat + ggbeta2 + coeff * ggx;

                        Array ggamma2 = r / gamma_reshaped;

                        if (gx2.dtype() != x_retained.dtype()) {
                            gx2 = gx2.AsType(x_retained.dtype());
                        }
                        if (ggamma2.dtype() != gamma_reshaped_retained.dtype()) {
                            ggamma2 = ggamma2.AsType(gamma_reshaped_retained.dtype());
                        }

                        if (ggout2.dtype() != gout_retained.dtype()) {
                            ggout2 = ggout2.AsType(gout_retained.dtype());
                        }

                        bctx2.input_grad(0) = std::move(gx2);
                        bctx2.input_grad(1) = std::move(ggamma2);
                        bctx2.input_grad(2) = std::move(ggout2);
                    });
                }
                bb2.Finalize();
            }

            // TODO(niboshi): Assign at once
            bctx.input_grad(0) = std::move(gx);
            bctx.input_grad(1) = std::move(ggamma);
            bctx.input_grad(2) = std::move(gbeta);
        });
    }
    bb.Finalize();

    return out;
}

Array FixedBatchNorm(
        const Array& x, const Array& gamma, const Array& beta, const Array& mean, const Array& var, Scalar eps, const OptionalAxes& axis) {
    PreprocessBatchNormResult result =
            PreprocessBatchNorm(x, gamma.AsGradStopped(), beta.AsGradStopped(), mean.AsGradStopped(), var.AsGradStopped(), axis);

    {
        NoBackpropModeScope scope{};
        return x.device().backend().CallKernel<FixedBatchNormKernel>(
                x.AsGradStopped(), result.gamma, result.beta, result.mean, result.var, eps, result.sorted_axis, absl::nullopt);
    }
}

}  // namespace chainerx