1 #pragma once 2 3 #include <memory> 4 #include <tuple> 5 #include <utility> 6 7 #include <absl/types/optional.h> 8 9 #include "chainerx/array.h" 10 #include "chainerx/axes.h" 11 #include "chainerx/dtype.h" 12 #include "chainerx/kernel.h" 13 #include "chainerx/scalar.h" 14 15 namespace chainerx { 16 17 // Intermediate results from `BatchNormKernel::Call` can be stored in this construct and be reused in `BatchNormGradKernel::Call`. 18 // The objects to store may vary depending on backend so each backend should derive this class to define the actual set of intermediate 19 // results. 20 class BatchNormGradState { 21 public: 22 BatchNormGradState() = default; 23 24 virtual ~BatchNormGradState() = default; 25 26 BatchNormGradState(const BatchNormGradState&) = delete; 27 BatchNormGradState(BatchNormGradState&&) = delete; 28 BatchNormGradState& operator=(const BatchNormGradState&) = delete; 29 BatchNormGradState& operator=(BatchNormGradState&&) = delete; 30 }; 31 32 class BatchNormKernel : public Kernel { 33 public: 34 // The returned state should be a `nullptr` if `return_state` is `false`. 35 virtual std::tuple<Array, std::unique_ptr<BatchNormGradState>> Call( 36 const Array& x, 37 const Array& gamma, 38 const Array& beta, 39 const Array& running_mean, 40 const Array& running_var, 41 Scalar eps, 42 Scalar decay, 43 const Axes& axis, 44 bool return_state, 45 const absl::optional<Array>& out) = 0; 46 }; 47 48 class BatchNormGradKernel : public Kernel { 49 public: 50 // Returns gx, ggamma, gbeta. 51 virtual std::tuple<Array, Array, Array> Call( 52 const Array& x, 53 const Array& gamma, 54 const Array& gout, 55 Scalar eps, 56 const Axes& axis, 57 const std::shared_ptr<BatchNormGradState>& state, 58 const absl::optional<Array>& gx, 59 const absl::optional<Array>& ggamma, 60 const absl::optional<Array>& gbeta) = 0; 61 }; 62 63 class GenericBatchNormGradState : public BatchNormGradState { 64 public: GenericBatchNormGradState(Array x_mean,Array x_inv_std,Dtype beta_dtype)65 GenericBatchNormGradState(Array x_mean, Array x_inv_std, Dtype beta_dtype) 66 : x_mean_{std::move(x_mean)}, x_inv_std_{std::move(x_inv_std)}, beta_dtype_{beta_dtype} {} 67 x_mean()68 const Array& x_mean() const { return x_mean_; } x_inv_std()69 const Array& x_inv_std() const { return x_inv_std_; } beta_dtype()70 Dtype beta_dtype() const { return beta_dtype_; } 71 72 private: 73 Array x_mean_; 74 Array x_inv_std_; 75 Dtype beta_dtype_; 76 }; 77 78 class GenericBatchNormKernel : public BatchNormKernel { 79 public: 80 std::tuple<Array, std::unique_ptr<BatchNormGradState>> Call( 81 const Array& x, 82 const Array& gamma, 83 const Array& beta, 84 const Array& running_mean, 85 const Array& running_var, 86 Scalar eps, 87 Scalar decay, 88 const Axes& axis, 89 bool return_state, 90 const absl::optional<Array>& out) override; 91 }; 92 93 class GenericBatchNormGradKernel : public BatchNormGradKernel { 94 public: 95 std::tuple<Array, Array, Array> Call( 96 const Array& x, 97 const Array& gamma, 98 const Array& gout, 99 Scalar eps, 100 const Axes& axis, 101 const std::shared_ptr<BatchNormGradState>& state, 102 const absl::optional<Array>& gx, 103 const absl::optional<Array>& ggamma, 104 const absl::optional<Array>& gbeta) override; 105 }; 106 107 class FixedBatchNormKernel : public Kernel { 108 public: 109 virtual Array Call( 110 const Array& x, 111 const Array& gamma, 112 const Array& beta, 113 const Array& mean, 114 const Array& var, 115 Scalar eps, 116 const Axes& axis, 117 const absl::optional<Array>& out) = 0; 118 }; 119 120 class GenericFixedBatchNormKernel : public FixedBatchNormKernel { 121 public: 122 Array Call( 123 const Array& x, 124 const Array& gamma, 125 const Array& beta, 126 const Array& mean, 127 const Array& var, 128 Scalar eps, 129 const Axes& axis, 130 const absl::optional<Array>& out) override; 131 }; 132 133 } // namespace chainerx 134