#include "chainerx/routines/reduction.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

#include <absl/types/optional.h>

#include "chainerx/array.h"
#include "chainerx/axes.h"
#include "chainerx/backprop_mode.h"
#include "chainerx/backward_builder.h"
#include "chainerx/backward_context.h"
#include "chainerx/dtype.h"
#include "chainerx/error.h"
#include "chainerx/graph.h"
#include "chainerx/kernels/reduction.h"
#include "chainerx/macro.h"
#include "chainerx/routines/arithmetic.h"
#include "chainerx/routines/creation.h"
#include "chainerx/routines/explog.h"
#include "chainerx/routines/indexing.h"
#include "chainerx/routines/logic.h"
#include "chainerx/routines/manipulation.h"
#include "chainerx/routines/routines_util.h"
#include "chainerx/routines/statistics.h"
#include "chainerx/routines/type_util.h"
#include "chainerx/shape.h"
31
32 namespace chainerx {
33
// Sums the elements of `a` over the given axes (all axes when `axis` is not
// given). When `keepdims` is true the reduced axes are kept with extent 1.
// Registers a backward definition that broadcasts the output gradient back to
// the input shape.
Array Sum(const Array& a, const OptionalAxes& axis, bool keepdims) {
    Axes sorted_axis = internal::GetSortedAxesOrAll(axis, a.ndim());

    // Decide the output dtype for integral input dtype.
    // Bool/signed-int inputs accumulate into int64 to avoid overflow of
    // narrow integer types.
    Dtype out_dtype{};
    switch (GetKind(a.dtype())) {
        case DtypeKind::kBool:
        case DtypeKind::kInt:  // fallthrough
            out_dtype = Dtype::kInt64;
            break;
        case DtypeKind::kUInt:
            out_dtype = Dtype::kInt64;  // TODO(niboshi): This should be kUInt64
            break;
        default:
            // Floating-point kinds keep the input dtype.
            out_dtype = a.dtype();
    }

    Array out = internal::EmptyReduced(a.shape(), out_dtype, sorted_axis, keepdims, a.device());
    {
        // The kernel call itself is not differentiable; backward is defined
        // explicitly below.
        NoBackpropModeScope scope{};
        a.device().backend().CallKernel<SumKernel>(a, sorted_axis, out);
    }

    BackwardBuilder bb{"sum", a, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([sorted_axis, in_shape = a.shape(), keepdims](BackwardContext& bctx) {
            const Array& gout = *bctx.output_grad();
            CHAINERX_ASSERT(std::is_sorted(sorted_axis.begin(), sorted_axis.end()));

            if (!(in_shape.ndim() == 0 || sorted_axis.empty() || keepdims)) {
                // The reduced axes were squeezed out of `gout`; re-insert them
                // as size-1 dimensions so it can be broadcast to `in_shape`.
                Shape out_shape_broadcastable = gout.shape();
                for (auto axis : sorted_axis) {
                    out_shape_broadcastable.insert(out_shape_broadcastable.begin() + axis, 1);
                }
                bctx.input_grad() = gout.Reshape(out_shape_broadcastable).BroadcastTo(in_shape);
            } else {
                // Shapes are already broadcast-compatible (0-d input, no
                // reduction, or kept dims).
                bctx.input_grad() = gout.BroadcastTo(in_shape);
            }
        });
    }
    bb.Finalize();
    return out;
}
77
// Computes the softmax of `x` along `axis` (axis 1 when not given), using the
// max-subtraction trick for numerical stability.
Array Softmax(const Array& x, const OptionalAxes& axis) {
    const Dtype math_dtype = internal::GetMathResultDtype(x.dtype());
    const Array& x_math = x.dtype() == math_dtype ? x : x.AsType(math_dtype);
    Axes reduce_axes = internal::GetSortedAxesOrAll(axis.has_value() ? axis : OptionalAxes{1}, x.ndim());
    // Subtract the per-axis maximum before exponentiating to avoid overflow.
    Array shifted_exps = Exp(x_math - AMax(x_math, reduce_axes, true));
    Array exp_sums = Sum(shifted_exps, reduce_axes, true);
    return shifted_exps * Reciprocal(exp_sums);
}
87
// Computes log(sum(exp(x))) over `axis` in a numerically stable way by
// factoring out the per-axis maximum.
Array LogSumExp(const Array& x, const OptionalAxes& axis, bool keepdims) {
    const Dtype math_dtype = internal::GetMathResultDtype(x.dtype());
    const Array& x_math = x.dtype() == math_dtype ? x : x.AsType(math_dtype);
    Axes reduce_axes = internal::GetSortedAxesOrAll(axis, x.ndim());
    Array maxima = AMax(x_math, reduce_axes, true);
    Array log_sums = Log(Sum(Exp(x_math - maxima), reduce_axes, keepdims));
    if (keepdims) {
        return maxima + log_sums;
    }
    // Drop the kept size-1 dims of `maxima` to match the reduced shape.
    return Squeeze(maxima, axis) + log_sums;
}
96
// Computes log(softmax(x)) along `axis` (axis 1 when not given) as
// x - logsumexp(x), which is more stable than taking Log(Softmax(x)).
Array LogSoftmax(const Array& x, const OptionalAxes& axis) {
    const Dtype math_dtype = internal::GetMathResultDtype(x.dtype());
    const Array& x_math = x.dtype() == math_dtype ? x : x.AsType(math_dtype);
    OptionalAxes effective_axis = axis.has_value() ? axis : OptionalAxes{1};
    return x_math - LogSumExp(x_math, effective_axis, true);
}
102
// Computes the cumulative sum of `a` along `axis`. When `axis` is not given,
// the array is flattened to 1-d first (NumPy-compatible behavior) and the
// cumulative sum is taken over the single remaining axis.
Array Cumsum(const Array& a, absl::optional<int8_t> axis) {
    int8_t axis_norm;
    Array a_reshaped{};
    if (axis.has_value()) {
        // Normalize a possibly-negative axis into [0, ndim).
        axis_norm = internal::NormalizeAxis(*axis, a.ndim());
        a_reshaped = a;
    } else {
        axis_norm = 0;
        // TODO(imanishi): Fix after chainerx::Ravel is supported.
        a_reshaped = a.Reshape(Shape{a.GetTotalSize()});
    }

    // Decide the output dtype for integral input dtype.
    // Bool/signed-int inputs accumulate into int64 to avoid overflow of
    // narrow integer types.
    Dtype out_dtype{};
    switch (GetKind(a_reshaped.dtype())) {
        case DtypeKind::kBool:
        case DtypeKind::kInt:  // fallthrough
            out_dtype = Dtype::kInt64;
            break;
        case DtypeKind::kUInt:
            out_dtype = Dtype::kInt64;  // TODO(niboshi): This should be kUInt64
            break;
        default:
            // Floating-point kinds keep the input dtype.
            out_dtype = a_reshaped.dtype();
    }

    const Array& out = Empty(a_reshaped.shape(), out_dtype, a_reshaped.device());

    {
        // The kernel call itself is not differentiable; backward is defined
        // explicitly below.
        NoBackpropModeScope scope{};
        a.device().backend().CallKernel<CumsumKernel>(a_reshaped, axis_norm, out);
    }

    // TODO(aksub99): Improve backward implementation to prevent flipping gout twice.
    BackwardBuilder bb{"cumsum", a_reshaped, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([axis_norm, in_shape = a_reshaped.shape()](BackwardContext& bctx) {
            const Array& gout = *bctx.output_grad();
            // The gradient of cumsum is a reversed cumulative sum of gout
            // along the same axis, implemented as flip -> cumsum -> flip.
            Array input_grad = Flip(Cumsum(Flip(gout, axis_norm), axis_norm), axis_norm);
            bctx.input_grad() = input_grad.Reshape(in_shape);
        });
    }
    bb.Finalize();
    return out;
}
148
// Sums the elements of `a` over the given axes (all axes when `axis` is not
// given), treating NaN entries as zero. When `keepdims` is true the reduced
// axes are kept with extent 1. The backward pass routes zero gradient to the
// NaN positions of the input.
Array Nansum(const Array& a, const OptionalAxes& axis, bool keepdims) {
    Axes sorted_axis = internal::GetSortedAxesOrAll(axis, a.ndim());
    // Replace NaN entries with 0 so the reduction ignores them.
    Array a_masked = Where(IsNan(a), 0, a);
    // Decide the output dtype for integral input dtype.
    // Bool/signed-int inputs accumulate into int64 to avoid overflow of
    // narrow integer types.
    Dtype out_dtype{};
    switch (GetKind(a_masked.dtype())) {
        case DtypeKind::kBool:
        case DtypeKind::kInt:  // fallthrough
            out_dtype = Dtype::kInt64;
            break;
        case DtypeKind::kUInt:
            out_dtype = Dtype::kInt64;  // TODO(niboshi): This should be kUInt64
            break;
        default:
            // Use the masked array's dtype: the switch dispatches on
            // a_masked.dtype() and the kernel below consumes a_masked, so
            // reading a.dtype() here was inconsistent. For floating-point
            // inputs the masking preserves the dtype, so this is equivalent.
            out_dtype = a_masked.dtype();
    }

    Array out = internal::EmptyReduced(a_masked.shape(), out_dtype, sorted_axis, keepdims, a_masked.device());
    {
        // The kernel call itself is not differentiable; backward is defined
        // explicitly below.
        NoBackpropModeScope scope{};
        a.device().backend().CallKernel<NansumKernel>(a_masked, sorted_axis, out);
    }

    BackwardBuilder bb{"nansum", a, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([a_tok = bb.RetainInput(0), sorted_axis, in_shape = a.shape(), keepdims](BackwardContext& bctx) {
            const Array& gout = *bctx.output_grad();
            const Array& input = bctx.GetRetainedInput(a_tok);
            Array& input_grad = bctx.input_grad();
            CHAINERX_ASSERT(std::is_sorted(sorted_axis.begin(), sorted_axis.end()));

            if (!(in_shape.ndim() == 0 || sorted_axis.empty() || keepdims)) {
                // The reduced axes were squeezed out of `gout`; re-insert them
                // as size-1 dimensions so it can be broadcast to `in_shape`.
                Shape out_shape_broadcastable = gout.shape();
                for (auto axis : sorted_axis) {
                    out_shape_broadcastable.insert(out_shape_broadcastable.begin() + axis, 1);
                }
                input_grad = gout.Reshape(out_shape_broadcastable).BroadcastTo(in_shape);
            } else {
                // Shapes are already broadcast-compatible (0-d input, no
                // reduction, or kept dims).
                input_grad = gout.BroadcastTo(in_shape);
            }
            // NaN inputs contributed nothing to the sum, so they receive zero
            // gradient.
            input_grad = Where(IsNan(input), 0, input_grad);
        });
    }
    bb.Finalize();
    return out;
}
195
196 } // namespace chainerx
197