1 #pragma once
2 
3 #include <memory>
4 #include <tuple>
5 #include <utility>
6 
7 #include <absl/types/optional.h>
8 
9 #include "chainerx/array.h"
10 #include "chainerx/axes.h"
11 #include "chainerx/dtype.h"
12 #include "chainerx/kernel.h"
13 #include "chainerx/scalar.h"
14 
15 namespace chainerx {
16 
// Intermediate results from `BatchNormKernel::Call` can be stored in an object of this class and reused in `BatchNormGradKernel::Call`.
// Which results are worth storing varies by backend, so each backend should derive from this class to define its actual set of
// intermediate results.
class BatchNormGradState {
public:
    // Neither copyable nor movable. The kernel interfaces in this header exchange states only through owning smart pointers
    // (`std::unique_ptr` from the forward kernel, `std::shared_ptr` into the backward kernel), so value semantics are never needed
    // and disabling them rules out accidental slicing of backend-specific subclasses.
    BatchNormGradState(const BatchNormGradState&) = delete;
    BatchNormGradState& operator=(const BatchNormGradState&) = delete;
    BatchNormGradState(BatchNormGradState&&) = delete;
    BatchNormGradState& operator=(BatchNormGradState&&) = delete;

    BatchNormGradState() = default;

    // Virtual so that a backend-specific state is destroyed correctly when released through a base-class pointer.
    virtual ~BatchNormGradState() = default;
};
31 
// Interface for the forward pass of batch normalization. Its optional state output is meant to be reused by
// `BatchNormGradKernel::Call`.
class BatchNormKernel : public Kernel {
public:
    // Returns a tuple of (presumably the normalized output array, the optional state for the backward pass).
    //
    // NOTE(review): the parameter meanings below are inferred from the names; confirm against the implementations.
    //   x:                         input array.
    //   gamma, beta:               presumably the scale and shift applied after normalization.
    //   running_mean, running_var: presumably running statistics that are updated using `decay` — confirm whether in place.
    //   eps:                       presumably added to the variance for numerical stability.
    //   decay:                     presumably the decay factor for the running statistics.
    //   axis:                      presumably the axes over which the batch statistics are computed.
    //   return_state:              whether a `BatchNormGradState` should be returned alongside the output.
    //   out:                       optional output array; presumably written to when given — confirm.
    //
    // The returned state should be a `nullptr` if `return_state` is `false`.
    virtual std::tuple<Array, std::unique_ptr<BatchNormGradState>> Call(
            const Array& x,
            const Array& gamma,
            const Array& beta,
            const Array& running_mean,
            const Array& running_var,
            Scalar eps,
            Scalar decay,
            const Axes& axis,
            bool return_state,
            const absl::optional<Array>& out) = 0;
};
47 
// Interface for the backward pass of batch normalization. It can reuse the intermediate results that
// `BatchNormKernel::Call` stored in a `BatchNormGradState`.
class BatchNormGradKernel : public Kernel {
public:
    // Computes the gradients with respect to `x`, `gamma` and `beta`.
    //
    // NOTE(review): the parameter meanings below are inferred from the names; confirm against the implementations.
    //   x, gamma:          presumably the same arrays that were passed to the forward `BatchNormKernel::Call`.
    //   gout:              presumably the gradient flowing back into the forward output.
    //   eps, axis:         presumably the same values used in the forward call.
    //   state:             intermediate results from the forward call. The forward kernel returns a `std::unique_ptr`, so the
    //                      caller presumably converts it to a `std::shared_ptr`. Whether `state` may be null (forward called
    //                      with `return_state == false`) is not visible here — confirm.
    //   gx, ggamma, gbeta: optional output arrays; presumably written to when given — confirm.
    //
    // Returns gx, ggamma, gbeta.
    virtual std::tuple<Array, Array, Array> Call(
            const Array& x,
            const Array& gamma,
            const Array& gout,
            Scalar eps,
            const Axes& axis,
            const std::shared_ptr<BatchNormGradState>& state,
            const absl::optional<Array>& gx,
            const absl::optional<Array>& ggamma,
            const absl::optional<Array>& gbeta) = 0;
};
62 
63 class GenericBatchNormGradState : public BatchNormGradState {
64 public:
GenericBatchNormGradState(Array x_mean,Array x_inv_std,Dtype beta_dtype)65     GenericBatchNormGradState(Array x_mean, Array x_inv_std, Dtype beta_dtype)
66         : x_mean_{std::move(x_mean)}, x_inv_std_{std::move(x_inv_std)}, beta_dtype_{beta_dtype} {}
67 
x_mean()68     const Array& x_mean() const { return x_mean_; }
x_inv_std()69     const Array& x_inv_std() const { return x_inv_std_; }
beta_dtype()70     Dtype beta_dtype() const { return beta_dtype_; }
71 
72 private:
73     Array x_mean_;
74     Array x_inv_std_;
75     Dtype beta_dtype_;
76 };
77 
// Generic implementation of `BatchNormKernel`. It is only declared here; the definition lives in another file.
// NOTE(review): presumably backend-independent and built on other array routines, and presumably returns a
// `GenericBatchNormGradState` when `return_state` is true — confirm in the definition.
class GenericBatchNormKernel : public BatchNormKernel {
public:
    std::tuple<Array, std::unique_ptr<BatchNormGradState>> Call(
            const Array& x,
            const Array& gamma,
            const Array& beta,
            const Array& running_mean,
            const Array& running_var,
            Scalar eps,
            Scalar decay,
            const Axes& axis,
            bool return_state,
            const absl::optional<Array>& out) override;
};
92 
// Generic implementation of `BatchNormGradKernel`. It is only declared here; the definition lives in another file.
// NOTE(review): presumably expects `state` to hold a `GenericBatchNormGradState` (or to be null) — confirm in the definition.
class GenericBatchNormGradKernel : public BatchNormGradKernel {
public:
    // Returns gx, ggamma, gbeta.
    std::tuple<Array, Array, Array> Call(
            const Array& x,
            const Array& gamma,
            const Array& gout,
            Scalar eps,
            const Axes& axis,
            const std::shared_ptr<BatchNormGradState>& state,
            const absl::optional<Array>& gx,
            const absl::optional<Array>& ggamma,
            const absl::optional<Array>& gbeta) override;
};
106 
// Interface for batch normalization with caller-supplied statistics.
// Unlike `BatchNormKernel`, it takes `mean` and `var` directly instead of running statistics, has no `decay` parameter and
// returns only the output array, with no state for a backward pass.
// NOTE(review): presumably intended for inference, where the statistics are fixed — confirm.
class FixedBatchNormKernel : public Kernel {
public:
    // NOTE(review): `eps` is presumably added to `var` for numerical stability, `axis` presumably selects the normalized axes,
    // and `out` is presumably an optional preallocated output — confirm against the implementations.
    virtual Array Call(
            const Array& x,
            const Array& gamma,
            const Array& beta,
            const Array& mean,
            const Array& var,
            Scalar eps,
            const Axes& axis,
            const absl::optional<Array>& out) = 0;
};
119 
// Generic implementation of `FixedBatchNormKernel`. It is only declared here; the definition lives in another file.
// NOTE(review): presumably backend-independent, like the other `Generic*` kernels in this header — confirm in the definition.
class GenericFixedBatchNormKernel : public FixedBatchNormKernel {
public:
    Array Call(
            const Array& x,
            const Array& gamma,
            const Array& beta,
            const Array& mean,
            const Array& var,
            Scalar eps,
            const Axes& axis,
            const absl::optional<Array>& out) override;
};
132 
133 }  // namespace chainerx
134