1 #include "chainerx/routines/reduction.h"
2 
3 #include <cmath>
4 #include <cstdint>
5 #include <numeric>
6 #include <utility>
7 #include <vector>
8 
9 #include <absl/types/optional.h>
10 
11 #include "chainerx/array.h"
12 #include "chainerx/axes.h"
13 #include "chainerx/backprop_mode.h"
14 #include "chainerx/backward_builder.h"
15 #include "chainerx/backward_context.h"
16 #include "chainerx/dtype.h"
17 #include "chainerx/error.h"
18 #include "chainerx/graph.h"
19 #include "chainerx/kernels/reduction.h"
20 #include "chainerx/macro.h"
21 #include "chainerx/routines/arithmetic.h"
22 #include "chainerx/routines/creation.h"
23 #include "chainerx/routines/explog.h"
24 #include "chainerx/routines/indexing.h"
25 #include "chainerx/routines/logic.h"
26 #include "chainerx/routines/manipulation.h"
27 #include "chainerx/routines/routines_util.h"
28 #include "chainerx/routines/statistics.h"
29 #include "chainerx/routines/type_util.h"
30 #include "chainerx/shape.h"
31 
32 namespace chainerx {
33 
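// Computes the sum of elements over the given axes, or over all axes when none are given.
// Bool and integer inputs are accumulated into int64 outputs.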
Sum(const Array & a,const OptionalAxes & axis,bool keepdims)34 Array Sum(const Array& a, const OptionalAxes& axis, bool keepdims) {
35     Axes sorted_axis = internal::GetSortedAxesOrAll(axis, a.ndim());
36 
37     // Decide the output dtype for integral input dtype.
38     Dtype out_dtype{};
39     switch (GetKind(a.dtype())) {
40         case DtypeKind::kBool:
41         case DtypeKind::kInt:  // fallthrough
42             out_dtype = Dtype::kInt64;
43             break;
44         case DtypeKind::kUInt:
45             out_dtype = Dtype::kInt64;  // TODO(niboshi): This should be kUInt64
46             break;
47         default:
48             out_dtype = a.dtype();
49     }
50 
51     Array out = internal::EmptyReduced(a.shape(), out_dtype, sorted_axis, keepdims, a.device());
52     {
53         NoBackpropModeScope scope{};
54         a.device().backend().CallKernel<SumKernel>(a, sorted_axis, out);
55     }
56 
57     BackwardBuilder bb{"sum", a, out};
58     if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
59         bt.Define([sorted_axis, in_shape = a.shape(), keepdims](BackwardContext& bctx) {
60             const Array& gout = *bctx.output_grad();
61             CHAINERX_ASSERT(std::is_sorted(sorted_axis.begin(), sorted_axis.end()));
62 
63             if (!(in_shape.ndim() == 0 || sorted_axis.empty() || keepdims)) {
64                 Shape out_shape_broadcastable = gout.shape();
65                 for (auto axis : sorted_axis) {
66                     out_shape_broadcastable.insert(out_shape_broadcastable.begin() + axis, 1);
67                 }
68                 bctx.input_grad() = gout.Reshape(out_shape_broadcastable).BroadcastTo(in_shape);
69             } else {
70                 bctx.input_grad() = gout.BroadcastTo(in_shape);
71             }
72         });
73     }
74     bb.Finalize();
75     return out;
76 }
77 
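// Computes the softmax along the given axes (axis 1 when none are given). The maximum is
// subtracted before exponentiation for numerical stability.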
Softmax(const Array & x,const OptionalAxes & axis)78 Array Softmax(const Array& x, const OptionalAxes& axis) {
79     Dtype dtype = internal::GetMathResultDtype(x.dtype());
80     const Array& x_cast = x.dtype() == dtype ? x : x.AsType(dtype);
81     Axes sorted_axis = internal::GetSortedAxesOrAll(axis.has_value() ? axis : OptionalAxes{1}, x.ndim());
82     Array xmax = AMax(x_cast, sorted_axis, true);
83     Array exps = Exp(x_cast - xmax);
84     Array sums = Sum(exps, sorted_axis, true);
85     return exps * Reciprocal(sums);
86 }
87 
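// Computes log(sum(exp(x))) over the given axes in a numerically stable way by factoring
// the maximum out of the exponentials.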
LogSumExp(const Array & x,const OptionalAxes & axis,bool keepdims)88 Array LogSumExp(const Array& x, const OptionalAxes& axis, bool keepdims) {
89     Dtype dtype = internal::GetMathResultDtype(x.dtype());
90     const Array& x_cast = x.dtype() == dtype ? x : x.AsType(dtype);
91     Axes sorted_axis = internal::GetSortedAxesOrAll(axis, x.ndim());
92     Array xmax = AMax(x_cast, sorted_axis, true);
93     Array logs = Log(Sum(Exp(x_cast - xmax), sorted_axis, keepdims));
94     return (keepdims ? xmax : Squeeze(xmax, axis)) + logs;
95 }
96 
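// Computes log(softmax(x)) along the given axes (axis 1 when none are given) as
// x - LogSumExp(x).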
LogSoftmax(const Array & x,const OptionalAxes & axis)97 Array LogSoftmax(const Array& x, const OptionalAxes& axis) {
98     Dtype dtype = internal::GetMathResultDtype(x.dtype());
99     const Array& x_cast = x.dtype() == dtype ? x : x.AsType(dtype);
100     return x_cast - LogSumExp(x_cast, axis.has_value() ? axis : OptionalAxes{1}, true);
101 }
102 
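// Computes the cumulative sum along the given axis. When no axis is given, the input is
// flattened and the cumulative sum is taken over all elements.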
Cumsum(const Array & a,absl::optional<int8_t> axis)103 Array Cumsum(const Array& a, absl::optional<int8_t> axis) {
104     int8_t axis_norm;
105     Array a_reshaped{};
106     if (axis.has_value()) {
107         axis_norm = internal::NormalizeAxis(*axis, a.ndim());
108         a_reshaped = a;
109     } else {
110         axis_norm = 0;
111         // TODO(imanishi): Fix after chainerx::Ravel is supported.
112         a_reshaped = a.Reshape(Shape{a.GetTotalSize()});
113     }
114 
115     // Decide the output dtype for integral input dtype.
116     Dtype out_dtype{};
117     switch (GetKind(a_reshaped.dtype())) {
118         case DtypeKind::kBool:
119         case DtypeKind::kInt:  // fallthrough
120             out_dtype = Dtype::kInt64;
121             break;
122         case DtypeKind::kUInt:
123             out_dtype = Dtype::kInt64;  // TODO(niboshi): This should be kUInt64
124             break;
125         default:
126             out_dtype = a_reshaped.dtype();
127     }
128 
129     const Array& out = Empty(a_reshaped.shape(), out_dtype, a_reshaped.device());
130 
131     {
132         NoBackpropModeScope scope{};
133         a.device().backend().CallKernel<CumsumKernel>(a_reshaped, axis_norm, out);
134     }
135 
136     // TODO(aksub99): Improve backward implementation to prevent flipping gout twice.
137     BackwardBuilder bb{"cumsum", a_reshaped, out};
138     if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
139         bt.Define([axis_norm, in_shape = a_reshaped.shape()](BackwardContext& bctx) {
140             const Array& gout = *bctx.output_grad();
141             Array input_grad = Flip(Cumsum(Flip(gout, axis_norm), axis_norm), axis_norm);
142             bctx.input_grad() = input_grad.Reshape(in_shape);
143         });
144     }
145     bb.Finalize();
146     return out;
147 }
148 
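// Computes the sum of elements over the given axes, treating NaN elements as zero.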
Nansum(const Array & a,const OptionalAxes & axis,bool keepdims)149 Array Nansum(const Array& a, const OptionalAxes& axis, bool keepdims) {
150     Axes sorted_axis = internal::GetSortedAxesOrAll(axis, a.ndim());
151     Array a_masked = Where(IsNan(a), 0, a);
152     // Decide the output dtype for integral input dtype.
153     Dtype out_dtype{};
154     switch (GetKind(a_masked.dtype())) {
155         case DtypeKind::kBool:
156         case DtypeKind::kInt:  // fallthrough
157             out_dtype = Dtype::kInt64;
158             break;
159         case DtypeKind::kUInt:
160             out_dtype = Dtype::kInt64;  // TODO(niboshi): This should be kUInt64
161             break;
162         default:
163             out_dtype = a.dtype();
164     }
165 
166     Array out = internal::EmptyReduced(a_masked.shape(), out_dtype, sorted_axis, keepdims, a_masked.device());
167     {
168         NoBackpropModeScope scope{};
169         a.device().backend().CallKernel<NansumKernel>(a_masked, sorted_axis, out);
170     }
171 
172     BackwardBuilder bb{"nansum", a, out};
173     if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
174         bt.Define([a_tok = bb.RetainInput(0), sorted_axis, in_shape = a.shape(), keepdims](BackwardContext& bctx) {
175             const Array& gout = *bctx.output_grad();
176             const Array& input = bctx.GetRetainedInput(a_tok);
177             Array& input_grad = bctx.input_grad();
178             CHAINERX_ASSERT(std::is_sorted(sorted_axis.begin(), sorted_axis.end()));
179 
180             if (!(in_shape.ndim() == 0 || sorted_axis.empty() || keepdims)) {
181                 Shape out_shape_broadcastable = gout.shape();
182                 for (auto axis : sorted_axis) {
183                     out_shape_broadcastable.insert(out_shape_broadcastable.begin() + axis, 1);
184                 }
185                 input_grad = gout.Reshape(out_shape_broadcastable).BroadcastTo(in_shape);
186             } else {
187                 input_grad = gout.BroadcastTo(in_shape);
188             }
189             input_grad = Where(IsNan(input), 0, input_grad);
190         });
191     }
192     bb.Finalize();
193     return out;
194 }
195 
196 }  // namespace chainerx
197