1 #include "chainerx/array.h"
2 
3 #include <algorithm>
4 #include <cmath>
5 #include <cstdint>
6 #include <cstring>
7 #include <memory>
8 #include <numeric>
9 #include <ostream>
10 #include <string>
11 #include <tuple>
12 #include <unordered_map>
13 #include <unordered_set>
14 #include <utility>
15 #include <vector>
16 
17 #include <absl/types/optional.h>
18 #include <absl/types/span.h>
19 
20 #include "chainerx/array_body.h"
21 #include "chainerx/array_node.h"
22 #include "chainerx/array_repr.h"
23 #include "chainerx/axes.h"
24 #include "chainerx/backend.h"
25 #include "chainerx/backprop_mode.h"
26 #include "chainerx/backward.h"
27 #include "chainerx/backward_builder.h"
28 #include "chainerx/backward_context.h"
29 #include "chainerx/context.h"
30 #include "chainerx/device.h"
31 #include "chainerx/dtype.h"
32 #include "chainerx/error.h"
33 #include "chainerx/graph.h"
34 #include "chainerx/kernels/creation.h"
35 #include "chainerx/kernels/misc.h"
36 #include "chainerx/macro.h"
37 #include "chainerx/native/native_backend.h"
38 #include "chainerx/op_node.h"
39 #include "chainerx/routines/arithmetic.h"
40 #include "chainerx/routines/binary.h"
41 #include "chainerx/routines/creation.h"
42 #include "chainerx/routines/indexing.h"
43 #include "chainerx/routines/linalg.h"
44 #include "chainerx/routines/logic.h"
45 #include "chainerx/routines/manipulation.h"
46 #include "chainerx/routines/reduction.h"
47 #include "chainerx/routines/routines_util.h"
48 #include "chainerx/routines/sorting.h"
49 #include "chainerx/routines/statistics.h"
50 #include "chainerx/scalar.h"
51 
52 namespace chainerx {
53 namespace internal {
54 
GetArrayBackpropId(const Array & array,const absl::optional<BackpropId> & backprop_id)55 BackpropId GetArrayBackpropId(const Array& array, const absl::optional<BackpropId>& backprop_id) {
56     return backprop_id.has_value() ? *backprop_id : array.device().context().default_backprop_id();
57 }
58 
MakeArray(const Shape & shape,const Strides & strides,Dtype dtype,Device & device,std::shared_ptr<void> data,int64_t offset)59 Array MakeArray(const Shape& shape, const Strides& strides, Dtype dtype, Device& device, std::shared_ptr<void> data, int64_t offset) {
60     return Array{shape, strides, dtype, device, std::move(data), offset};
61 }
62 
MoveArrayBodies(std::vector<Array> && arrays)63 std::vector<std::shared_ptr<ArrayBody>> MoveArrayBodies(std::vector<Array>&& arrays) {
64     std::vector<std::shared_ptr<ArrayBody>> array_body_ptrs;
65     array_body_ptrs.reserve(arrays.size());
66     for (Array& array : arrays) {
67         array_body_ptrs.emplace_back(MoveArrayBody(std::move(array)));
68     }
69     return array_body_ptrs;
70 }
71 
MoveArrayBodies(std::vector<absl::optional<Array>> && arrays)72 std::vector<std::shared_ptr<ArrayBody>> MoveArrayBodies(std::vector<absl::optional<Array>>&& arrays) {
73     std::vector<std::shared_ptr<ArrayBody>> array_body_ptrs;
74     array_body_ptrs.reserve(arrays.size());
75     for (absl::optional<Array>& array : arrays) {
76         if (array.has_value()) {
77             array_body_ptrs.emplace_back(MoveArrayBody(std::move(*array)));
78         } else {
79             array_body_ptrs.emplace_back(nullptr);
80         }
81     }
82     return array_body_ptrs;
83 }
84 
85 }  // namespace internal
86 
Array(const Shape & shape,const Strides & strides,Dtype dtype,Device & device,std::shared_ptr<void> data,int64_t offset)87 Array::Array(const Shape& shape, const Strides& strides, Dtype dtype, Device& device, std::shared_ptr<void> data, int64_t offset)
88     : body_{internal::CreateArrayBody(shape, strides, dtype, device, std::move(data), offset)} {}
89 
operator -() const90 Array Array::operator-() const { return Negative(*this); }
91 
operator ==(const Array & rhs) const92 Array Array::operator==(const Array& rhs) const { return Equal(*this, rhs); }
93 
operator !=(const Array & rhs) const94 Array Array::operator!=(const Array& rhs) const { return NotEqual(*this, rhs); }
95 
operator >(const Array & rhs) const96 Array Array::operator>(const Array& rhs) const { return Greater(*this, rhs); }
97 
operator >=(const Array & rhs) const98 Array Array::operator>=(const Array& rhs) const { return GreaterEqual(*this, rhs); }
99 
operator <(const Array & rhs) const100 Array Array::operator<(const Array& rhs) const { return Less(*this, rhs); }
101 
operator <=(const Array & rhs) const102 Array Array::operator<=(const Array& rhs) const { return LessEqual(*this, rhs); }
103 
operator +=(const Array & rhs)104 Array& Array::operator+=(const Array& rhs) {
105     internal::IAdd(*this, rhs);
106     return *this;
107 }
108 
operator +=(Scalar rhs)109 Array& Array::operator+=(Scalar rhs) {
110     internal::IAdd(*this, rhs);
111     return *this;
112 }
113 
operator -=(const Array & rhs)114 Array& Array::operator-=(const Array& rhs) {
115     internal::ISubtract(*this, rhs);
116     return *this;
117 }
118 
operator -=(Scalar rhs)119 Array& Array::operator-=(Scalar rhs) {
120     internal::ISubtract(*this, rhs);
121     return *this;
122 }
123 
operator *=(const Array & rhs)124 Array& Array::operator*=(const Array& rhs) {
125     internal::IMultiply(*this, rhs);
126     return *this;
127 }
128 
operator *=(Scalar rhs)129 Array& Array::operator*=(Scalar rhs) {
130     internal::IMultiply(*this, rhs);
131     return *this;
132 }
133 
operator /=(const Array & rhs)134 Array& Array::operator/=(const Array& rhs) {
135     internal::IDivide(*this, rhs);
136     return *this;
137 }
138 
operator /=(Scalar rhs)139 Array& Array::operator/=(Scalar rhs) {
140     internal::IDivide(*this, rhs);
141     return *this;
142 }
143 
operator %=(const Array & rhs)144 Array& Array::operator%=(const Array& rhs) {
145     internal::IMod(*this, rhs);
146     return *this;
147 }
148 
operator %=(Scalar rhs)149 Array& Array::operator%=(Scalar rhs) {
150     internal::IMod(*this, rhs);
151     return *this;
152 }
153 
operator &=(const Array & rhs)154 Array& Array::operator&=(const Array& rhs) {
155     internal::IBitwiseAnd(*this, rhs);
156     return *this;
157 }
158 
operator &=(Scalar rhs)159 Array& Array::operator&=(Scalar rhs) {
160     internal::IBitwiseAnd(*this, rhs);
161     return *this;
162 }
163 
operator |=(const Array & rhs)164 Array& Array::operator|=(const Array& rhs) {
165     internal::IBitwiseOr(*this, rhs);
166     return *this;
167 }
168 
operator |=(Scalar rhs)169 Array& Array::operator|=(Scalar rhs) {
170     internal::IBitwiseOr(*this, rhs);
171     return *this;
172 }
173 
operator ^=(const Array & rhs)174 Array& Array::operator^=(const Array& rhs) {
175     internal::IBitwiseXor(*this, rhs);
176     return *this;
177 }
178 
operator ^=(Scalar rhs)179 Array& Array::operator^=(Scalar rhs) {
180     internal::IBitwiseXor(*this, rhs);
181     return *this;
182 }
183 
operator <<=(const Array & rhs)184 Array& Array::operator<<=(const Array& rhs) {
185     internal::ILeftShift(*this, rhs);
186     return *this;
187 }
188 
operator <<=(Scalar rhs)189 Array& Array::operator<<=(Scalar rhs) {
190     internal::ILeftShift(*this, rhs);
191     return *this;
192 }
193 
operator >>=(const Array & rhs)194 Array& Array::operator>>=(const Array& rhs) {
195     internal::IRightShift(*this, rhs);
196     return *this;
197 }
198 
operator >>=(Scalar rhs)199 Array& Array::operator>>=(Scalar rhs) {
200     internal::IRightShift(*this, rhs);
201     return *this;
202 }
203 
operator +=(const Array & rhs) const204 const Array& Array::operator+=(const Array& rhs) const {
205     internal::IAdd(*this, rhs);
206     return *this;
207 }
208 
operator +=(Scalar rhs) const209 const Array& Array::operator+=(Scalar rhs) const {
210     internal::IAdd(*this, rhs);
211     return *this;
212 }
213 
operator -=(const Array & rhs) const214 const Array& Array::operator-=(const Array& rhs) const {
215     internal::ISubtract(*this, rhs);
216     return *this;
217 }
218 
operator -=(Scalar rhs) const219 const Array& Array::operator-=(Scalar rhs) const {
220     internal::ISubtract(*this, rhs);
221     return *this;
222 }
223 
operator *=(const Array & rhs) const224 const Array& Array::operator*=(const Array& rhs) const {
225     internal::IMultiply(*this, rhs);
226     return *this;
227 }
228 
operator *=(Scalar rhs) const229 const Array& Array::operator*=(Scalar rhs) const {
230     internal::IMultiply(*this, rhs);
231     return *this;
232 }
233 
operator /=(const Array & rhs) const234 const Array& Array::operator/=(const Array& rhs) const {
235     internal::IDivide(*this, rhs);
236     return *this;
237 }
238 
operator /=(Scalar rhs) const239 const Array& Array::operator/=(Scalar rhs) const {
240     internal::IDivide(*this, rhs);
241     return *this;
242 }
243 
operator %=(const Array & rhs) const244 const Array& Array::operator%=(const Array& rhs) const {
245     internal::IMod(*this, rhs);
246     return *this;
247 }
248 
operator %=(Scalar rhs) const249 const Array& Array::operator%=(Scalar rhs) const {
250     internal::IMod(*this, rhs);
251     return *this;
252 }
253 
operator &=(const Array & rhs) const254 const Array& Array::operator&=(const Array& rhs) const {
255     internal::IBitwiseAnd(*this, rhs);
256     return *this;
257 }
258 
operator &=(Scalar rhs) const259 const Array& Array::operator&=(Scalar rhs) const {
260     internal::IBitwiseAnd(*this, rhs);
261     return *this;
262 }
263 
operator |=(const Array & rhs) const264 const Array& Array::operator|=(const Array& rhs) const {
265     internal::IBitwiseOr(*this, rhs);
266     return *this;
267 }
268 
operator |=(Scalar rhs) const269 const Array& Array::operator|=(Scalar rhs) const {
270     internal::IBitwiseOr(*this, rhs);
271     return *this;
272 }
273 
operator ^=(const Array & rhs) const274 const Array& Array::operator^=(const Array& rhs) const {
275     internal::IBitwiseXor(*this, rhs);
276     return *this;
277 }
278 
operator ^=(Scalar rhs) const279 const Array& Array::operator^=(Scalar rhs) const {
280     internal::IBitwiseXor(*this, rhs);
281     return *this;
282 }
283 
operator <<=(const Array & rhs) const284 const Array& Array::operator<<=(const Array& rhs) const {
285     internal::ILeftShift(*this, rhs);
286     return *this;
287 }
288 
operator <<=(Scalar rhs) const289 const Array& Array::operator<<=(Scalar rhs) const {
290     internal::ILeftShift(*this, rhs);
291     return *this;
292 }
293 
operator >>=(const Array & rhs) const294 const Array& Array::operator>>=(const Array& rhs) const {
295     internal::IRightShift(*this, rhs);
296     return *this;
297 }
298 
operator >>=(Scalar rhs) const299 const Array& Array::operator>>=(Scalar rhs) const {
300     internal::IRightShift(*this, rhs);
301     return *this;
302 }
303 
operator +(const Array & rhs) const304 Array Array::operator+(const Array& rhs) const { return chainerx::Add(*this, rhs); }
305 
operator +(Scalar rhs) const306 Array Array::operator+(Scalar rhs) const { return chainerx::Add(*this, rhs); }
307 
operator -(const Array & rhs) const308 Array Array::operator-(const Array& rhs) const { return chainerx::Subtract(*this, rhs); }
309 
operator -(Scalar rhs) const310 Array Array::operator-(Scalar rhs) const { return chainerx::Subtract(*this, rhs); }
311 
operator *(const Array & rhs) const312 Array Array::operator*(const Array& rhs) const { return Multiply(*this, rhs); }
313 
operator *(Scalar rhs) const314 Array Array::operator*(Scalar rhs) const { return Multiply(*this, rhs); }
315 
operator /(const Array & rhs) const316 Array Array::operator/(const Array& rhs) const { return chainerx::Divide(*this, rhs); }
317 
operator /(Scalar rhs) const318 Array Array::operator/(Scalar rhs) const { return chainerx::Divide(*this, rhs); }
319 
operator %(const Array & rhs) const320 Array Array::operator%(const Array& rhs) const { return chainerx::Mod(*this, rhs); }
321 
operator %(Scalar rhs) const322 Array Array::operator%(Scalar rhs) const { return chainerx::Mod(*this, rhs); }
323 
operator &(const Array & rhs) const324 Array Array::operator&(const Array& rhs) const { return chainerx::BitwiseAnd(*this, rhs); }
325 
operator &(Scalar rhs) const326 Array Array::operator&(Scalar rhs) const { return chainerx::BitwiseAnd(*this, rhs); }
327 
operator |(const Array & rhs) const328 Array Array::operator|(const Array& rhs) const { return chainerx::BitwiseOr(*this, rhs); }
329 
operator |(Scalar rhs) const330 Array Array::operator|(Scalar rhs) const { return chainerx::BitwiseOr(*this, rhs); }
331 
operator ^(const Array & rhs) const332 Array Array::operator^(const Array& rhs) const { return chainerx::BitwiseXor(*this, rhs); }
333 
operator ^(Scalar rhs) const334 Array Array::operator^(Scalar rhs) const { return chainerx::BitwiseXor(*this, rhs); }
335 
operator <<(const Array & rhs) const336 Array Array::operator<<(const Array& rhs) const { return chainerx::LeftShift(*this, rhs); }
337 
operator <<(Scalar rhs) const338 Array Array::operator<<(Scalar rhs) const { return chainerx::LeftShift(*this, rhs); }
339 
operator >>(const Array & rhs) const340 Array Array::operator>>(const Array& rhs) const { return chainerx::RightShift(*this, rhs); }
341 
operator >>(Scalar rhs) const342 Array Array::operator>>(Scalar rhs) const { return chainerx::RightShift(*this, rhs); }
343 
At(const std::vector<ArrayIndex> & indices) const344 Array Array::At(const std::vector<ArrayIndex>& indices) const { return internal::At(*this, indices); }
345 
Transpose(const OptionalAxes & axes) const346 Array Array::Transpose(const OptionalAxes& axes) const { return chainerx::Transpose(*this, axes); }
347 
Ravel() const348 Array Array::Ravel() const { return chainerx::Ravel(*this); }
349 
Reshape(const Shape & newshape) const350 Array Array::Reshape(const Shape& newshape) const { return chainerx::Reshape(*this, newshape); }
351 
Squeeze(const OptionalAxes & axis) const352 Array Array::Squeeze(const OptionalAxes& axis) const { return chainerx::Squeeze(*this, axis); }
353 
Swapaxes(int8_t axis1,int8_t axis2) const354 Array Array::Swapaxes(int8_t axis1, int8_t axis2) const { return chainerx::Swapaxes(*this, axis1, axis2); }
355 
BroadcastTo(const Shape & shape) const356 Array Array::BroadcastTo(const Shape& shape) const { return chainerx::BroadcastTo(*this, shape); }
357 
ArgMax(const OptionalAxes & axis) const358 Array Array::ArgMax(const OptionalAxes& axis) const { return chainerx::ArgMax(*this, axis); }
359 
ArgMin(const OptionalAxes & axis) const360 Array Array::ArgMin(const OptionalAxes& axis) const { return chainerx::ArgMin(*this, axis); }
361 
Sum(const OptionalAxes & axis,bool keepdims) const362 Array Array::Sum(const OptionalAxes& axis, bool keepdims) const { return chainerx::Sum(*this, axis, keepdims); }
363 
Max(const OptionalAxes & axis,bool keepdims) const364 Array Array::Max(const OptionalAxes& axis, bool keepdims) const { return chainerx::AMax(*this, axis, keepdims); }
365 
Min(const OptionalAxes & axis,bool keepdims) const366 Array Array::Min(const OptionalAxes& axis, bool keepdims) const { return chainerx::AMin(*this, axis, keepdims); }
367 
Mean(const OptionalAxes & axis,bool keepdims) const368 Array Array::Mean(const OptionalAxes& axis, bool keepdims) const { return chainerx::Mean(*this, axis, keepdims); }
369 
Var(const OptionalAxes & axis,bool keepdims) const370 Array Array::Var(const OptionalAxes& axis, bool keepdims) const { return chainerx::Var(*this, axis, keepdims); }
371 
All(const OptionalAxes & axis,bool keepdims) const372 Array Array::All(const OptionalAxes& axis, bool keepdims) const { return chainerx::All(*this, axis, keepdims); }
373 
Any(const OptionalAxes & axis,bool keepdims) const374 Array Array::Any(const OptionalAxes& axis, bool keepdims) const { return chainerx::Any(*this, axis, keepdims); }
375 
Dot(const Array & b) const376 Array Array::Dot(const Array& b) const { return chainerx::Dot(*this, b); }
377 
Take(const Array & indices,int8_t axis,IndexBoundsMode mode) const378 Array Array::Take(const Array& indices, int8_t axis, IndexBoundsMode mode) const { return chainerx::Take(*this, indices, axis, mode); }
379 
Copy() const380 Array Array::Copy() const { return chainerx::Copy(*this); }
381 
Flatten() const382 Array Array::Flatten() const {
383     Array out = (*this).Copy().Reshape({(*this).GetTotalSize()});
384     return out;
385 }
386 
MakeView() const387 Array Array::MakeView() const {
388     Array out{shape(), strides(), dtype(), device(), data(), offset()};
389 
390     BackwardBuilder bb{"view", *this, out};
391     if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
392         bt.Define([](BackwardContext& bctx) { bctx.input_grad() = *bctx.output_grad(); });
393     }
394     bb.Finalize();
395 
396     return out;
397 }
398 
ToDevice(Device & dst_device) const399 Array Array::ToDevice(Device& dst_device) const {
400     Device& src_device = body_->device();
401     Array out;
402 
403     // TODO(sonots): Avoid copying data between native devices, e.g., from native:0 to native:1 for performance.
404     if (&src_device == &dst_device) {
405         // Return an alias.
406         out = AsGradStopped(CopyKind::kView);
407     } else {
408         // Make a contiguous copy to transfer it to the destination device.
409         Array src_contig = AsContiguous(AsGradStopped(CopyKind::kView));
410 
411         std::shared_ptr<void> dst_data;
412         if (src_device.backend().SupportsTransfer(src_device, dst_device)) {
413             // Use src backend for transfer.
414             dst_data = src_device.TransferDataTo(dst_device, src_contig.data(), src_contig.offset(), src_contig.GetNBytes());
415         } else if (dst_device.backend().SupportsTransfer(src_device, dst_device)) {
416             // Use dst backend for transfer.
417             dst_data = dst_device.TransferDataFrom(src_device, src_contig.data(), src_contig.offset(), src_contig.GetNBytes());
418         } else {
419             // Neither backends support transfer.
420             throw ChainerxError{"Transfer between devices is not supported: src='", src_device.name(), "' dst='", dst_device.name(), "'."};
421         }
422         out = Array{src_contig.shape(), src_contig.strides(), src_contig.dtype(), dst_device, std::move(dst_data)};
423     }
424 
425     CHAINERX_ASSERT(internal::GetArrayBody(out) != nullptr);
426 
427     // Backward operation is implemented as backward-transfer.
428     BackwardBuilder bb{"transfer", *this, out};
429     if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
430         bt.Define([&src_device](BackwardContext& bctx) { bctx.input_grad() = bctx.output_grad()->ToDevice(src_device); });
431     }
432     bb.Finalize();
433 
434     // TODO(niboshi): This assertion must succeed but currently it does not because AsContiguousArray reshapes {} to {1}.
435     // CHAINERX_ASSERT(out.shape() == shape());
436     CHAINERX_ASSERT(out.dtype() == dtype());
437     return out;
438 }
439 
ToNative() const440 Array Array::ToNative() const {
441     Backend& backend = device().backend();
442     Device& native_device = backend.IsNative() ? device() : backend.context().GetNativeBackend().GetDevice(0);
443     return ToDevice(native_device);
444 }
445 
446 namespace {
447 
CopyOrMakeView(const Array & array,CopyKind kind)448 Array CopyOrMakeView(const Array& array, CopyKind kind) {
449     switch (kind) {
450         case CopyKind::kCopy:
451             return array.Copy();
452         case CopyKind::kView:
453             return array.MakeView();
454         default:
455             CHAINERX_NEVER_REACH();
456     }
457 }
458 
459 }  // namespace
460 
AsGradStopped(CopyKind kind) const461 Array Array::AsGradStopped(CopyKind kind) const {
462     NoBackpropModeScope scope{device().context()};
463     return CopyOrMakeView(*this, kind);
464 }
465 
AsGradStopped(absl::Span<const BackpropId> backprop_ids,CopyKind kind) const466 Array Array::AsGradStopped(absl::Span<const BackpropId> backprop_ids, CopyKind kind) const {
467     NoBackpropModeScope scope{std::vector<BackpropId>{backprop_ids.begin(), backprop_ids.end()}};
468     return CopyOrMakeView(*this, kind);
469 }
470 
AsType(Dtype dtype,bool copy) const471 Array Array::AsType(Dtype dtype, bool copy) const {
472     Dtype src_dtype = this->dtype();
473     if (!copy && dtype == src_dtype) {
474         return *this;
475     }
476 
477     Array out = Empty(shape(), dtype, device());
478     // Note: In CopyKernel, Input Array Elements are casted to the type of Output Array.
479     device().backend().CallKernel<CopyKernel>(*this, out);
480 
481     if (GetKind(dtype) == DtypeKind::kFloat) {
482         BackwardBuilder bb{"astype", *this, out};
483         if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
484             bt.Define([src_dtype](BackwardContext& bctx) { bctx.input_grad() = bctx.output_grad()->AsType(src_dtype); });
485         }
486         bb.Finalize();
487     }
488 
489     CHAINERX_ASSERT(out.IsContiguous());
490     return out;
491 }
492 
Fill(Scalar value) const493 void Array::Fill(Scalar value) const {
494     internal::CheckNoUnsafeInplace(*this, {});
495     device().backend().CallKernel<FillKernel>(*this, value);
496 }
497 
GetGrad(const absl::optional<BackpropId> & backprop_id) const498 const absl::optional<Array>& Array::GetGrad(const absl::optional<BackpropId>& backprop_id) const {
499     BackpropId actual_backprop_id = internal::GetArrayBackpropId(*this, backprop_id);
500     if (!IsGradRequired(actual_backprop_id)) {
501         throw ChainerxError{"Array is not flagged as requiring gradient for backprop id: '", actual_backprop_id, "'."};
502     }
503     const absl::optional<Array>* grad = body_->GetGrad(actual_backprop_id);
504     CHAINERX_ASSERT(grad != nullptr);
505     return *grad;
506 }
507 
SetGrad(Array grad,const absl::optional<BackpropId> & backprop_id) const508 void Array::SetGrad(Array grad, const absl::optional<BackpropId>& backprop_id) const {
509     BackpropId actual_backprop_id = internal::GetArrayBackpropId(*this, backprop_id);
510     absl::optional<Array>* target_grad = body_->GetGrad(actual_backprop_id);
511     if (target_grad == nullptr) {
512         throw ChainerxError{"Array is constant with respect to the computation for backprop ID: '", actual_backprop_id, "'."};
513     }
514 
515     // Setting the gradient flags the array to require gradient, so that it can return the gradient with GetGrad().
516     RequireGrad(actual_backprop_id);
517 
518     internal::SetGrad(*target_grad, std::move(grad), shape(), dtype(), device());
519 }
520 
ClearGrad(const absl::optional<BackpropId> & backprop_id) const521 void Array::ClearGrad(const absl::optional<BackpropId>& backprop_id) const {
522     BackpropId actual_backprop_id = internal::GetArrayBackpropId(*this, backprop_id);
523     if (!body_->HasArrayNode(actual_backprop_id)) {
524         throw ChainerxError{"Array is constant with respect to the computation for backprop ID: '", actual_backprop_id, "'."};
525     }
526     body_->ClearGrad(actual_backprop_id);
527 }
528 
IsBackpropRequired(const absl::optional<BackpropId> & backprop_id) const529 bool Array::IsBackpropRequired(const absl::optional<BackpropId>& backprop_id) const {
530     BackpropId actual_backprop_id = internal::GetArrayBackpropId(*this, backprop_id);
531     return body_->HasArrayNode(actual_backprop_id) && chainerx::IsBackpropRequired(actual_backprop_id);
532 }
533 
IsBackpropRequired(AnyGraph) const534 bool Array::IsBackpropRequired(AnyGraph /*any_graph*/) const {
535     const std::vector<std::shared_ptr<internal::ArrayNode>>& array_nodes = body_->nodes();
536     return std::any_of(array_nodes.begin(), array_nodes.end(), [](const std::shared_ptr<const internal::ArrayNode>& array_node) {
537         return chainerx::IsBackpropRequired(array_node->backprop_id());
538     });
539 }
540 
IsGradRequired(const absl::optional<BackpropId> & backprop_id) const541 bool Array::IsGradRequired(const absl::optional<BackpropId>& backprop_id) const {
542     BackpropId actual_backprop_id = internal::GetArrayBackpropId(*this, backprop_id);
543     return body_->IsGradRequired(actual_backprop_id);
544 }
545 
546 template <typename T>
RequireGradImpl(T & array,const absl::optional<BackpropId> & backprop_id)547 T& Array::RequireGradImpl(T& array, const absl::optional<BackpropId>& backprop_id) {
548     if (GetKind(array.dtype()) != DtypeKind::kFloat) {
549         throw DtypeError{"Array with integral dtype (", GetDtypeName(array.dtype()), ") cannot compute gradient"};
550     }
551     BackpropId actual_backprop_id = internal::GetArrayBackpropId(array, backprop_id);
552     internal::ArrayBody::RequireGrad(internal::GetArrayBody(array), actual_backprop_id);
553     return array;
554 }
555 
556 template const Array& Array::RequireGradImpl<const Array>(const Array& array, const absl::optional<BackpropId>& backprop_id);
557 template Array& Array::RequireGradImpl<Array>(Array& array, const absl::optional<BackpropId>& backprop_id);
558 
ToString() const559 std::string Array::ToString() const { return ArrayRepr(*this); }
560 
operator +(Scalar lhs,const Array & rhs)561 Array operator+(Scalar lhs, const Array& rhs) { return Add(lhs, rhs); }
operator -(Scalar lhs,const Array & rhs)562 Array operator-(Scalar lhs, const Array& rhs) { return Subtract(lhs, rhs); }
operator *(Scalar lhs,const Array & rhs)563 Array operator*(Scalar lhs, const Array& rhs) { return Multiply(lhs, rhs); }
operator /(Scalar lhs,const Array & rhs)564 Array operator/(Scalar lhs, const Array& rhs) { return Divide(lhs, rhs); }
operator %(Scalar lhs,const Array & rhs)565 Array operator%(Scalar lhs, const Array& rhs) { return Mod(lhs, rhs); }
566 
operator <<(Scalar lhs,const Array & rhs)567 Array operator<<(Scalar lhs, const Array& rhs) { return LeftShift(lhs, rhs); }
operator >>(Scalar lhs,const Array & rhs)568 Array operator>>(Scalar lhs, const Array& rhs) { return RightShift(lhs, rhs); }
569 
namespace {

using internal::ArrayNode;
using internal::OpNode;

// Writes an indented, human-readable dump of the computational graph reachable from an array node to a stream.
// Each array node is printed on one line. On its first visit it is followed by optional metadata, then by the op
// node that created it (if any), then, recursively, by that op's input array nodes.
class PrintComputationalGraphImpl {
private:
    using VisitedArrayNodeSet = std::unordered_set<const ArrayNode*>;

    // Traversal state shared across the recursion of a single Run() call.
    struct State {
        // Array nodes whose metadata and creator op have already been printed.
        VisitedArrayNodeSet visited_array_nodes;
        // Current indentation level; Indent() renders one level as two spaces.
        int indent;
    };

    // TODO(niboshi): Make the options configurable from outside
    struct Options {
        // Whether to print the body address and, when present, the gradient shape/dtype of each newly visited node.
        bool print_metadata{true};
    };

public:
    // `os` is held by reference (not owned) and must outlive this object.
    explicit PrintComputationalGraphImpl(std::ostream& os) : os_{os} {}

    // Dumps the graph rooted at `array_node`, starting at indentation level `indent`.
    void Run(const ArrayNode& array_node, int indent) {
        State state{{}, indent};
        RunImpl(state, array_node);
    }

    // Returns the display name of `array_node`. A name registered with SetArrayName() is used if present. Otherwise
    // the name is a fixed-width (kLen characters) base-62 rendering of the hash of the node address. Generated names
    // are recomputed on each call, not cached.
    std::string GetArrayNodeName(const ArrayNode& array_node) {
        static constexpr char kChars[] = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
        // Number of digit characters, excluding the terminating NUL.
        static constexpr size_t kNumChars = sizeof(kChars) / sizeof(kChars[0]) - 1;
        // Enough base-kNumChars digits to represent any size_t value (bits in size_t / bits per digit, rounded up).
        static const auto kLen = static_cast<size_t>(std::ceil(sizeof(size_t) * 8U / std::log2(kNumChars)));
        auto it = array_name_map_.find(&array_node);
        if (it != array_name_map_.end()) {
            return it->second;
        }
        size_t hash = std::hash<const ArrayNode*>{}(&array_node);
        std::string s(kLen, '0');
        // Fill the string from left to right, because hash may be just the raw address and MSBs may be indistinguishable.
        for (auto it_s = s.begin(); hash > 0 && it_s != s.end(); ++it_s) {
            *it_s = gsl::at(kChars, hash % kNumChars);
            hash /= kNumChars;
        }
        return s;
    }

    // Returns `indent * 2` space characters.
    std::string Indent(int indent) {
        static constexpr char kIndentChar = ' ';
        return std::string(static_cast<size_t>(indent * 2), kIndentChar);
    }

    // Prints the header line of `array_node`. Only on the node's first visit does it also print the metadata, the
    // creator op line (one level deeper), and the op's inputs (two levels deeper, recursively). A node reached again
    // through another path is therefore printed as a single header line and not expanded a second time.
    void RunImpl(State& state, const ArrayNode& array_node) {
        std::string name = GetArrayNodeName(array_node);

        int indent = state.indent;
        VisitedArrayNodeSet& visited_array_nodes = state.visited_array_nodes;
        os_ << Indent(indent) << "ArrayNode<" << name << " " << &array_node << " rank=" << array_node.rank()
            << " shape=" << array_node.shape() << " dtype=" << GetDtypeName(array_node.dtype()) << ">" << std::endl;

        if (visited_array_nodes.end() == visited_array_nodes.find(&array_node)) {
            visited_array_nodes.insert(&array_node);

            if (options_.print_metadata) {
                // The node refers to its array body weakly; the body may already have been destroyed.
                std::shared_ptr<internal::ArrayBody> body = array_node.weak_body().lock();
                if (body == nullptr) {
                    os_ << Indent(indent + 2) << "body=(gone)" << std::endl;
                } else {
                    os_ << Indent(indent + 2) << "body=" << body.get() << std::endl;
                    const absl::optional<Array>* grad = body->GetGrad(array_node.backprop_id());
                    CHAINERX_ASSERT(grad != nullptr);
                    if (grad->has_value()) {
                        os_ << Indent(indent + 2) << "grad=<shape=" << (*grad)->shape() << " dtype=" << GetDtypeName((*grad)->dtype())
                            << ">" << std::endl;
                    }
                }
            }

            std::shared_ptr<const OpNode> op = array_node.creator_op_node();
            if (op) {
                os_ << Indent(indent + 1) << "Op<" << op->name() << " " << op.get() << " rank=" << op->rank() << ">" << std::endl;
                for (const std::shared_ptr<ArrayNode>& input_array_node : op->input_array_nodes()) {
                    // Inputs are nested two levels below the current array node; the level is restored afterwards.
                    state.indent += 2;
                    if (input_array_node != nullptr) {
                        RunImpl(state, *input_array_node);
                    } else {
                        os_ << Indent(state.indent) << "(null)" << std::endl;
                    }
                    state.indent -= 2;
                }
            }
        }
    }

    // Registers a display name for `array_node`; GetArrayNodeName() returns it instead of a generated name.
    void SetArrayName(const ArrayNode& array_node, std::string name) { array_name_map_[&array_node] = std::move(name); }

private:
    // Destination stream (not owned).
    std::ostream& os_;
    Options options_{};
    // User-assigned display names, keyed by array node address.
    std::unordered_map<const ArrayNode*, std::string> array_name_map_;
};

}  // namespace
671 
DebugDumpComputationalGraph(std::ostream & os,const Array & array,const absl::optional<BackpropId> & backprop_id,int indent,const std::vector<std::pair<ConstArrayRef,std::string>> & array_name_map)672 void DebugDumpComputationalGraph(
673         std::ostream& os,
674         const Array& array,
675         const absl::optional<BackpropId>& backprop_id,
676         int indent,
677         const std::vector<std::pair<ConstArrayRef, std::string>>& array_name_map) {
678     PrintComputationalGraphImpl impl{os};
679     BackpropId actual_backprop_id = internal::GetArrayBackpropId(array, backprop_id);
680     for (const auto& pair : array_name_map) {
681         for (const std::shared_ptr<ArrayNode>& array_node : internal::GetArrayBody(pair.first.get())->nodes()) {
682             if (array_node->backprop_id() == actual_backprop_id) {
683                 impl.SetArrayName(*array_node, pair.second);
684             }
685         }
686     }
687     impl.Run(*internal::GetArrayBody(array)->GetArrayNode(actual_backprop_id), indent);
688 }
689 
690 }  // namespace chainerx
691