1 #include "chainerx/array.h"
2
3 #include <algorithm>
4 #include <cmath>
5 #include <cstdint>
6 #include <cstring>
7 #include <memory>
8 #include <numeric>
9 #include <ostream>
10 #include <string>
11 #include <tuple>
12 #include <unordered_map>
13 #include <unordered_set>
14 #include <utility>
15 #include <vector>
16
17 #include <absl/types/optional.h>
18 #include <absl/types/span.h>
19
20 #include "chainerx/array_body.h"
21 #include "chainerx/array_node.h"
22 #include "chainerx/array_repr.h"
23 #include "chainerx/axes.h"
24 #include "chainerx/backend.h"
25 #include "chainerx/backprop_mode.h"
26 #include "chainerx/backward.h"
27 #include "chainerx/backward_builder.h"
28 #include "chainerx/backward_context.h"
29 #include "chainerx/context.h"
30 #include "chainerx/device.h"
31 #include "chainerx/dtype.h"
32 #include "chainerx/error.h"
33 #include "chainerx/graph.h"
34 #include "chainerx/kernels/creation.h"
35 #include "chainerx/kernels/misc.h"
36 #include "chainerx/macro.h"
37 #include "chainerx/native/native_backend.h"
38 #include "chainerx/op_node.h"
39 #include "chainerx/routines/arithmetic.h"
40 #include "chainerx/routines/binary.h"
41 #include "chainerx/routines/creation.h"
42 #include "chainerx/routines/indexing.h"
43 #include "chainerx/routines/linalg.h"
44 #include "chainerx/routines/logic.h"
45 #include "chainerx/routines/manipulation.h"
46 #include "chainerx/routines/reduction.h"
47 #include "chainerx/routines/routines_util.h"
48 #include "chainerx/routines/sorting.h"
49 #include "chainerx/routines/statistics.h"
50 #include "chainerx/scalar.h"
51
52 namespace chainerx {
53 namespace internal {
54
GetArrayBackpropId(const Array & array,const absl::optional<BackpropId> & backprop_id)55 BackpropId GetArrayBackpropId(const Array& array, const absl::optional<BackpropId>& backprop_id) {
56 return backprop_id.has_value() ? *backprop_id : array.device().context().default_backprop_id();
57 }
58
MakeArray(const Shape & shape,const Strides & strides,Dtype dtype,Device & device,std::shared_ptr<void> data,int64_t offset)59 Array MakeArray(const Shape& shape, const Strides& strides, Dtype dtype, Device& device, std::shared_ptr<void> data, int64_t offset) {
60 return Array{shape, strides, dtype, device, std::move(data), offset};
61 }
62
MoveArrayBodies(std::vector<Array> && arrays)63 std::vector<std::shared_ptr<ArrayBody>> MoveArrayBodies(std::vector<Array>&& arrays) {
64 std::vector<std::shared_ptr<ArrayBody>> array_body_ptrs;
65 array_body_ptrs.reserve(arrays.size());
66 for (Array& array : arrays) {
67 array_body_ptrs.emplace_back(MoveArrayBody(std::move(array)));
68 }
69 return array_body_ptrs;
70 }
71
MoveArrayBodies(std::vector<absl::optional<Array>> && arrays)72 std::vector<std::shared_ptr<ArrayBody>> MoveArrayBodies(std::vector<absl::optional<Array>>&& arrays) {
73 std::vector<std::shared_ptr<ArrayBody>> array_body_ptrs;
74 array_body_ptrs.reserve(arrays.size());
75 for (absl::optional<Array>& array : arrays) {
76 if (array.has_value()) {
77 array_body_ptrs.emplace_back(MoveArrayBody(std::move(*array)));
78 } else {
79 array_body_ptrs.emplace_back(nullptr);
80 }
81 }
82 return array_body_ptrs;
83 }
84
85 } // namespace internal
86
Array(const Shape & shape,const Strides & strides,Dtype dtype,Device & device,std::shared_ptr<void> data,int64_t offset)87 Array::Array(const Shape& shape, const Strides& strides, Dtype dtype, Device& device, std::shared_ptr<void> data, int64_t offset)
88 : body_{internal::CreateArrayBody(shape, strides, dtype, device, std::move(data), offset)} {}
89
operator -() const90 Array Array::operator-() const { return Negative(*this); }
91
operator ==(const Array & rhs) const92 Array Array::operator==(const Array& rhs) const { return Equal(*this, rhs); }
93
operator !=(const Array & rhs) const94 Array Array::operator!=(const Array& rhs) const { return NotEqual(*this, rhs); }
95
operator >(const Array & rhs) const96 Array Array::operator>(const Array& rhs) const { return Greater(*this, rhs); }
97
operator >=(const Array & rhs) const98 Array Array::operator>=(const Array& rhs) const { return GreaterEqual(*this, rhs); }
99
operator <(const Array & rhs) const100 Array Array::operator<(const Array& rhs) const { return Less(*this, rhs); }
101
operator <=(const Array & rhs) const102 Array Array::operator<=(const Array& rhs) const { return LessEqual(*this, rhs); }
103
// In-place operators on a mutable Array handle. Each one delegates to the const-qualified
// overload of the same operator (which invokes the matching in-place routine from namespace
// internal on *this) and then hands back *this as a mutable reference.
Array& Array::operator+=(const Array& rhs) {
    static_cast<const Array&>(*this) += rhs;
    return *this;
}

Array& Array::operator+=(Scalar rhs) {
    static_cast<const Array&>(*this) += rhs;
    return *this;
}

Array& Array::operator-=(const Array& rhs) {
    static_cast<const Array&>(*this) -= rhs;
    return *this;
}

Array& Array::operator-=(Scalar rhs) {
    static_cast<const Array&>(*this) -= rhs;
    return *this;
}

Array& Array::operator*=(const Array& rhs) {
    static_cast<const Array&>(*this) *= rhs;
    return *this;
}

Array& Array::operator*=(Scalar rhs) {
    static_cast<const Array&>(*this) *= rhs;
    return *this;
}

Array& Array::operator/=(const Array& rhs) {
    static_cast<const Array&>(*this) /= rhs;
    return *this;
}

Array& Array::operator/=(Scalar rhs) {
    static_cast<const Array&>(*this) /= rhs;
    return *this;
}

Array& Array::operator%=(const Array& rhs) {
    static_cast<const Array&>(*this) %= rhs;
    return *this;
}

Array& Array::operator%=(Scalar rhs) {
    static_cast<const Array&>(*this) %= rhs;
    return *this;
}

Array& Array::operator&=(const Array& rhs) {
    static_cast<const Array&>(*this) &= rhs;
    return *this;
}

Array& Array::operator&=(Scalar rhs) {
    static_cast<const Array&>(*this) &= rhs;
    return *this;
}

Array& Array::operator|=(const Array& rhs) {
    static_cast<const Array&>(*this) |= rhs;
    return *this;
}

Array& Array::operator|=(Scalar rhs) {
    static_cast<const Array&>(*this) |= rhs;
    return *this;
}

Array& Array::operator^=(const Array& rhs) {
    static_cast<const Array&>(*this) ^= rhs;
    return *this;
}

Array& Array::operator^=(Scalar rhs) {
    static_cast<const Array&>(*this) ^= rhs;
    return *this;
}

Array& Array::operator<<=(const Array& rhs) {
    static_cast<const Array&>(*this) <<= rhs;
    return *this;
}

Array& Array::operator<<=(Scalar rhs) {
    static_cast<const Array&>(*this) <<= rhs;
    return *this;
}

Array& Array::operator>>=(const Array& rhs) {
    static_cast<const Array&>(*this) >>= rhs;
    return *this;
}

Array& Array::operator>>=(Scalar rhs) {
    static_cast<const Array&>(*this) >>= rhs;
    return *this;
}
203
// Const-qualified in-place operators. Each applies the matching in-place routine from namespace
// internal (IAdd, ISubtract, IMultiply, IDivide, IMod, IBitwiseAnd, IBitwiseOr, IBitwiseXor,
// ILeftShift, IRightShift) to *this with `rhs`, then returns *this.
const Array& Array::operator+=(const Array& rhs) const { internal::IAdd(*this, rhs); return *this; }
const Array& Array::operator+=(Scalar rhs) const { internal::IAdd(*this, rhs); return *this; }

const Array& Array::operator-=(const Array& rhs) const { internal::ISubtract(*this, rhs); return *this; }
const Array& Array::operator-=(Scalar rhs) const { internal::ISubtract(*this, rhs); return *this; }

const Array& Array::operator*=(const Array& rhs) const { internal::IMultiply(*this, rhs); return *this; }
const Array& Array::operator*=(Scalar rhs) const { internal::IMultiply(*this, rhs); return *this; }

const Array& Array::operator/=(const Array& rhs) const { internal::IDivide(*this, rhs); return *this; }
const Array& Array::operator/=(Scalar rhs) const { internal::IDivide(*this, rhs); return *this; }

const Array& Array::operator%=(const Array& rhs) const { internal::IMod(*this, rhs); return *this; }
const Array& Array::operator%=(Scalar rhs) const { internal::IMod(*this, rhs); return *this; }

const Array& Array::operator&=(const Array& rhs) const { internal::IBitwiseAnd(*this, rhs); return *this; }
const Array& Array::operator&=(Scalar rhs) const { internal::IBitwiseAnd(*this, rhs); return *this; }

const Array& Array::operator|=(const Array& rhs) const { internal::IBitwiseOr(*this, rhs); return *this; }
const Array& Array::operator|=(Scalar rhs) const { internal::IBitwiseOr(*this, rhs); return *this; }

const Array& Array::operator^=(const Array& rhs) const { internal::IBitwiseXor(*this, rhs); return *this; }
const Array& Array::operator^=(Scalar rhs) const { internal::IBitwiseXor(*this, rhs); return *this; }

const Array& Array::operator<<=(const Array& rhs) const { internal::ILeftShift(*this, rhs); return *this; }
const Array& Array::operator<<=(Scalar rhs) const { internal::ILeftShift(*this, rhs); return *this; }

const Array& Array::operator>>=(const Array& rhs) const { internal::IRightShift(*this, rhs); return *this; }
const Array& Array::operator>>=(Scalar rhs) const { internal::IRightShift(*this, rhs); return *this; }
303
// Binary arithmetic, bitwise and shift operators taking an array or a scalar right operand.
// Each returns the result of the corresponding routine in namespace chainerx; all calls are
// qualified uniformly (operator* previously called Multiply unqualified, unlike its siblings).
Array Array::operator+(const Array& rhs) const { return chainerx::Add(*this, rhs); }

Array Array::operator+(Scalar rhs) const { return chainerx::Add(*this, rhs); }

Array Array::operator-(const Array& rhs) const { return chainerx::Subtract(*this, rhs); }

Array Array::operator-(Scalar rhs) const { return chainerx::Subtract(*this, rhs); }

Array Array::operator*(const Array& rhs) const { return chainerx::Multiply(*this, rhs); }

Array Array::operator*(Scalar rhs) const { return chainerx::Multiply(*this, rhs); }

Array Array::operator/(const Array& rhs) const { return chainerx::Divide(*this, rhs); }

Array Array::operator/(Scalar rhs) const { return chainerx::Divide(*this, rhs); }

Array Array::operator%(const Array& rhs) const { return chainerx::Mod(*this, rhs); }

Array Array::operator%(Scalar rhs) const { return chainerx::Mod(*this, rhs); }

Array Array::operator&(const Array& rhs) const { return chainerx::BitwiseAnd(*this, rhs); }

Array Array::operator&(Scalar rhs) const { return chainerx::BitwiseAnd(*this, rhs); }

Array Array::operator|(const Array& rhs) const { return chainerx::BitwiseOr(*this, rhs); }

Array Array::operator|(Scalar rhs) const { return chainerx::BitwiseOr(*this, rhs); }

Array Array::operator^(const Array& rhs) const { return chainerx::BitwiseXor(*this, rhs); }

Array Array::operator^(Scalar rhs) const { return chainerx::BitwiseXor(*this, rhs); }

Array Array::operator<<(const Array& rhs) const { return chainerx::LeftShift(*this, rhs); }

Array Array::operator<<(Scalar rhs) const { return chainerx::LeftShift(*this, rhs); }

Array Array::operator>>(const Array& rhs) const { return chainerx::RightShift(*this, rhs); }

Array Array::operator>>(Scalar rhs) const { return chainerx::RightShift(*this, rhs); }
343
At(const std::vector<ArrayIndex> & indices) const344 Array Array::At(const std::vector<ArrayIndex>& indices) const { return internal::At(*this, indices); }
345
Transpose(const OptionalAxes & axes) const346 Array Array::Transpose(const OptionalAxes& axes) const { return chainerx::Transpose(*this, axes); }
347
Ravel() const348 Array Array::Ravel() const { return chainerx::Ravel(*this); }
349
Reshape(const Shape & newshape) const350 Array Array::Reshape(const Shape& newshape) const { return chainerx::Reshape(*this, newshape); }
351
Squeeze(const OptionalAxes & axis) const352 Array Array::Squeeze(const OptionalAxes& axis) const { return chainerx::Squeeze(*this, axis); }
353
Swapaxes(int8_t axis1,int8_t axis2) const354 Array Array::Swapaxes(int8_t axis1, int8_t axis2) const { return chainerx::Swapaxes(*this, axis1, axis2); }
355
BroadcastTo(const Shape & shape) const356 Array Array::BroadcastTo(const Shape& shape) const { return chainerx::BroadcastTo(*this, shape); }
357
ArgMax(const OptionalAxes & axis) const358 Array Array::ArgMax(const OptionalAxes& axis) const { return chainerx::ArgMax(*this, axis); }
359
ArgMin(const OptionalAxes & axis) const360 Array Array::ArgMin(const OptionalAxes& axis) const { return chainerx::ArgMin(*this, axis); }
361
Sum(const OptionalAxes & axis,bool keepdims) const362 Array Array::Sum(const OptionalAxes& axis, bool keepdims) const { return chainerx::Sum(*this, axis, keepdims); }
363
Max(const OptionalAxes & axis,bool keepdims) const364 Array Array::Max(const OptionalAxes& axis, bool keepdims) const { return chainerx::AMax(*this, axis, keepdims); }
365
Min(const OptionalAxes & axis,bool keepdims) const366 Array Array::Min(const OptionalAxes& axis, bool keepdims) const { return chainerx::AMin(*this, axis, keepdims); }
367
Mean(const OptionalAxes & axis,bool keepdims) const368 Array Array::Mean(const OptionalAxes& axis, bool keepdims) const { return chainerx::Mean(*this, axis, keepdims); }
369
Var(const OptionalAxes & axis,bool keepdims) const370 Array Array::Var(const OptionalAxes& axis, bool keepdims) const { return chainerx::Var(*this, axis, keepdims); }
371
All(const OptionalAxes & axis,bool keepdims) const372 Array Array::All(const OptionalAxes& axis, bool keepdims) const { return chainerx::All(*this, axis, keepdims); }
373
Any(const OptionalAxes & axis,bool keepdims) const374 Array Array::Any(const OptionalAxes& axis, bool keepdims) const { return chainerx::Any(*this, axis, keepdims); }
375
Dot(const Array & b) const376 Array Array::Dot(const Array& b) const { return chainerx::Dot(*this, b); }
377
Take(const Array & indices,int8_t axis,IndexBoundsMode mode) const378 Array Array::Take(const Array& indices, int8_t axis, IndexBoundsMode mode) const { return chainerx::Take(*this, indices, axis, mode); }
379
Copy() const380 Array Array::Copy() const { return chainerx::Copy(*this); }
381
Flatten() const382 Array Array::Flatten() const {
383 Array out = (*this).Copy().Reshape({(*this).GetTotalSize()});
384 return out;
385 }
386
MakeView() const387 Array Array::MakeView() const {
388 Array out{shape(), strides(), dtype(), device(), data(), offset()};
389
390 BackwardBuilder bb{"view", *this, out};
391 if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
392 bt.Define([](BackwardContext& bctx) { bctx.input_grad() = *bctx.output_grad(); });
393 }
394 bb.Finalize();
395
396 return out;
397 }
398
// Returns this array on `dst_device`, linked to this array by a "transfer" backward that sends
// the output gradient back to this array's device via ToDevice().
//  - Same device: the result is a view of this array (made with backprop stopped, then re-linked
//    through the BackwardBuilder below).
//  - Different device: a contiguous copy of the data is transferred by whichever backend reports
//    SupportsTransfer(src, dst), trying the source device's backend first.
// Throws ChainerxError when neither backend supports the transfer.
Array Array::ToDevice(Device& dst_device) const {
    Device& src_device = body_->device();
    Array out;

    // TODO(sonots): Avoid copying data between native devices, e.g., from native:0 to native:1 for performance.
    if (&src_device == &dst_device) {
        // Return an alias.
        out = AsGradStopped(CopyKind::kView);
    } else {
        // Make a contiguous copy to transfer it to the destination device.
        Array src_contig = AsContiguous(AsGradStopped(CopyKind::kView));

        std::shared_ptr<void> dst_data;
        if (src_device.backend().SupportsTransfer(src_device, dst_device)) {
            // Use src backend for transfer.
            dst_data = src_device.TransferDataTo(dst_device, src_contig.data(), src_contig.offset(), src_contig.GetNBytes());
        } else if (dst_device.backend().SupportsTransfer(src_device, dst_device)) {
            // Use dst backend for transfer.
            dst_data = dst_device.TransferDataFrom(src_device, src_contig.data(), src_contig.offset(), src_contig.GetNBytes());
        } else {
            // Neither backend supports the transfer.
            throw ChainerxError{"Transfer between devices is not supported: src='", src_device.name(), "' dst='", dst_device.name(), "'."};
        }
        // The transferred buffer keeps the contiguous copy's layout; no offset argument is passed here.
        out = Array{src_contig.shape(), src_contig.strides(), src_contig.dtype(), dst_device, std::move(dst_data)};
    }

    CHAINERX_ASSERT(internal::GetArrayBody(out) != nullptr);

    // Backward operation is implemented as backward-transfer.
    // NOTE(review): the lambda captures `src_device` by reference; this assumes the Device object
    // outlives the graph (devices appear to be owned by their backend/context) — confirm.
    BackwardBuilder bb{"transfer", *this, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([&src_device](BackwardContext& bctx) { bctx.input_grad() = bctx.output_grad()->ToDevice(src_device); });
    }
    bb.Finalize();

    // TODO(niboshi): This assertion must succeed but currently it does not because AsContiguousArray reshapes {} to {1}.
    // CHAINERX_ASSERT(out.shape() == shape());
    CHAINERX_ASSERT(out.dtype() == dtype());
    return out;
}
439
ToNative() const440 Array Array::ToNative() const {
441 Backend& backend = device().backend();
442 Device& native_device = backend.IsNative() ? device() : backend.context().GetNativeBackend().GetDevice(0);
443 return ToDevice(native_device);
444 }
445
446 namespace {
447
CopyOrMakeView(const Array & array,CopyKind kind)448 Array CopyOrMakeView(const Array& array, CopyKind kind) {
449 switch (kind) {
450 case CopyKind::kCopy:
451 return array.Copy();
452 case CopyKind::kView:
453 return array.MakeView();
454 default:
455 CHAINERX_NEVER_REACH();
456 }
457 }
458
459 } // namespace
460
AsGradStopped(CopyKind kind) const461 Array Array::AsGradStopped(CopyKind kind) const {
462 NoBackpropModeScope scope{device().context()};
463 return CopyOrMakeView(*this, kind);
464 }
465
AsGradStopped(absl::Span<const BackpropId> backprop_ids,CopyKind kind) const466 Array Array::AsGradStopped(absl::Span<const BackpropId> backprop_ids, CopyKind kind) const {
467 NoBackpropModeScope scope{std::vector<BackpropId>{backprop_ids.begin(), backprop_ids.end()}};
468 return CopyOrMakeView(*this, kind);
469 }
470
AsType(Dtype dtype,bool copy) const471 Array Array::AsType(Dtype dtype, bool copy) const {
472 Dtype src_dtype = this->dtype();
473 if (!copy && dtype == src_dtype) {
474 return *this;
475 }
476
477 Array out = Empty(shape(), dtype, device());
478 // Note: In CopyKernel, Input Array Elements are casted to the type of Output Array.
479 device().backend().CallKernel<CopyKernel>(*this, out);
480
481 if (GetKind(dtype) == DtypeKind::kFloat) {
482 BackwardBuilder bb{"astype", *this, out};
483 if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
484 bt.Define([src_dtype](BackwardContext& bctx) { bctx.input_grad() = bctx.output_grad()->AsType(src_dtype); });
485 }
486 bb.Finalize();
487 }
488
489 CHAINERX_ASSERT(out.IsContiguous());
490 return out;
491 }
492
Fill(Scalar value) const493 void Array::Fill(Scalar value) const {
494 internal::CheckNoUnsafeInplace(*this, {});
495 device().backend().CallKernel<FillKernel>(*this, value);
496 }
497
GetGrad(const absl::optional<BackpropId> & backprop_id) const498 const absl::optional<Array>& Array::GetGrad(const absl::optional<BackpropId>& backprop_id) const {
499 BackpropId actual_backprop_id = internal::GetArrayBackpropId(*this, backprop_id);
500 if (!IsGradRequired(actual_backprop_id)) {
501 throw ChainerxError{"Array is not flagged as requiring gradient for backprop id: '", actual_backprop_id, "'."};
502 }
503 const absl::optional<Array>* grad = body_->GetGrad(actual_backprop_id);
504 CHAINERX_ASSERT(grad != nullptr);
505 return *grad;
506 }
507
SetGrad(Array grad,const absl::optional<BackpropId> & backprop_id) const508 void Array::SetGrad(Array grad, const absl::optional<BackpropId>& backprop_id) const {
509 BackpropId actual_backprop_id = internal::GetArrayBackpropId(*this, backprop_id);
510 absl::optional<Array>* target_grad = body_->GetGrad(actual_backprop_id);
511 if (target_grad == nullptr) {
512 throw ChainerxError{"Array is constant with respect to the computation for backprop ID: '", actual_backprop_id, "'."};
513 }
514
515 // Setting the gradient flags the array to require gradient, so that it can return the gradient with GetGrad().
516 RequireGrad(actual_backprop_id);
517
518 internal::SetGrad(*target_grad, std::move(grad), shape(), dtype(), device());
519 }
520
ClearGrad(const absl::optional<BackpropId> & backprop_id) const521 void Array::ClearGrad(const absl::optional<BackpropId>& backprop_id) const {
522 BackpropId actual_backprop_id = internal::GetArrayBackpropId(*this, backprop_id);
523 if (!body_->HasArrayNode(actual_backprop_id)) {
524 throw ChainerxError{"Array is constant with respect to the computation for backprop ID: '", actual_backprop_id, "'."};
525 }
526 body_->ClearGrad(actual_backprop_id);
527 }
528
IsBackpropRequired(const absl::optional<BackpropId> & backprop_id) const529 bool Array::IsBackpropRequired(const absl::optional<BackpropId>& backprop_id) const {
530 BackpropId actual_backprop_id = internal::GetArrayBackpropId(*this, backprop_id);
531 return body_->HasArrayNode(actual_backprop_id) && chainerx::IsBackpropRequired(actual_backprop_id);
532 }
533
IsBackpropRequired(AnyGraph) const534 bool Array::IsBackpropRequired(AnyGraph /*any_graph*/) const {
535 const std::vector<std::shared_ptr<internal::ArrayNode>>& array_nodes = body_->nodes();
536 return std::any_of(array_nodes.begin(), array_nodes.end(), [](const std::shared_ptr<const internal::ArrayNode>& array_node) {
537 return chainerx::IsBackpropRequired(array_node->backprop_id());
538 });
539 }
540
IsGradRequired(const absl::optional<BackpropId> & backprop_id) const541 bool Array::IsGradRequired(const absl::optional<BackpropId>& backprop_id) const {
542 BackpropId actual_backprop_id = internal::GetArrayBackpropId(*this, backprop_id);
543 return body_->IsGradRequired(actual_backprop_id);
544 }
545
546 template <typename T>
RequireGradImpl(T & array,const absl::optional<BackpropId> & backprop_id)547 T& Array::RequireGradImpl(T& array, const absl::optional<BackpropId>& backprop_id) {
548 if (GetKind(array.dtype()) != DtypeKind::kFloat) {
549 throw DtypeError{"Array with integral dtype (", GetDtypeName(array.dtype()), ") cannot compute gradient"};
550 }
551 BackpropId actual_backprop_id = internal::GetArrayBackpropId(array, backprop_id);
552 internal::ArrayBody::RequireGrad(internal::GetArrayBody(array), actual_backprop_id);
553 return array;
554 }
555
556 template const Array& Array::RequireGradImpl<const Array>(const Array& array, const absl::optional<BackpropId>& backprop_id);
557 template Array& Array::RequireGradImpl<Array>(Array& array, const absl::optional<BackpropId>& backprop_id);
558
ToString() const559 std::string Array::ToString() const { return ArrayRepr(*this); }
560
operator +(Scalar lhs,const Array & rhs)561 Array operator+(Scalar lhs, const Array& rhs) { return Add(lhs, rhs); }
operator -(Scalar lhs,const Array & rhs)562 Array operator-(Scalar lhs, const Array& rhs) { return Subtract(lhs, rhs); }
operator *(Scalar lhs,const Array & rhs)563 Array operator*(Scalar lhs, const Array& rhs) { return Multiply(lhs, rhs); }
operator /(Scalar lhs,const Array & rhs)564 Array operator/(Scalar lhs, const Array& rhs) { return Divide(lhs, rhs); }
operator %(Scalar lhs,const Array & rhs)565 Array operator%(Scalar lhs, const Array& rhs) { return Mod(lhs, rhs); }
566
operator <<(Scalar lhs,const Array & rhs)567 Array operator<<(Scalar lhs, const Array& rhs) { return LeftShift(lhs, rhs); }
operator >>(Scalar lhs,const Array & rhs)568 Array operator>>(Scalar lhs, const Array& rhs) { return RightShift(lhs, rhs); }
569
namespace {

using internal::ArrayNode;
using internal::OpNode;

// Writes an indented text dump of the computational graph reachable from an ArrayNode to an
// output stream. Every visited array node gets one header line. On the first visit of a node, its
// metadata (when enabled) and its creator op are printed, followed recursively by the op's inputs.
class PrintComputationalGraphImpl {
private:
    using VisitedArrayNodeSet = std::unordered_set<const ArrayNode*>;

    // Traversal state shared across the recursion: nodes already expanded and the current indent level.
    struct State {
        VisitedArrayNodeSet visited_array_nodes;
        int indent;
    };

    // TODO(niboshi): Make the options configurable from outside
    struct Options {
        bool print_metadata{true};
    };

public:
    // `os` is held by reference and must outlive this object.
    explicit PrintComputationalGraphImpl(std::ostream& os) : os_{os} {}

    // Dumps the graph rooted at `array_node`, starting at indent level `indent`.
    void Run(const ArrayNode& array_node, int indent) {
        State state{{}, indent};
        RunImpl(state, array_node);
    }

    // Returns the display name of `array_node`: the name registered with SetArrayName() if any,
    // otherwise a fixed-width base-62 string derived from std::hash of the node's address.
    // The derived name is recomputed on every call (it is not stored in array_name_map_).
    std::string GetArrayNodeName(const ArrayNode& array_node) {
        static constexpr char kChars[] = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
        // Number of digit characters (excludes the terminating '\0').
        static constexpr size_t kNumChars = sizeof(kChars) / sizeof(kChars[0]) - 1;
        // Digits needed to represent any size_t value: ceil(bit width / log2(kNumChars)).
        static const auto kLen = static_cast<size_t>(std::ceil(sizeof(size_t) * 8U / std::log2(kNumChars)));
        auto it = array_name_map_.find(&array_node);
        if (it != array_name_map_.end()) {
            return it->second;
        }
        size_t hash = std::hash<const ArrayNode*>{}(&array_node);
        std::string s(kLen, '0');
        // Fill the string from left to right, because hash may be just the raw address and MSBs may be indistinguishable.
        // (The least significant digit therefore comes first; unused trailing positions stay '0'.)
        for (auto it_s = s.begin(); hash > 0 && it_s != s.end(); ++it_s) {
            *it_s = gsl::at(kChars, hash % kNumChars);
            hash /= kNumChars;
        }
        return s;
    }

    // Returns two spaces per indent level.
    std::string Indent(int indent) {
        static constexpr char kIndentChar = ' ';
        return std::string(static_cast<size_t>(indent * 2), kIndentChar);
    }

    // Prints the header line of `array_node` at state.indent. Only the first visit of a node
    // prints its metadata and recurses into its creator op's inputs (two indent levels deeper).
    // Later visits print just the header line, so a node shared by several paths is expanded once.
    void RunImpl(State& state, const ArrayNode& array_node) {
        std::string name = GetArrayNodeName(array_node);

        int indent = state.indent;
        VisitedArrayNodeSet& visited_array_nodes = state.visited_array_nodes;
        os_ << Indent(indent) << "ArrayNode<" << name << " " << &array_node << " rank=" << array_node.rank()
            << " shape=" << array_node.shape() << " dtype=" << GetDtypeName(array_node.dtype()) << ">" << std::endl;

        if (visited_array_nodes.end() == visited_array_nodes.find(&array_node)) {
            visited_array_nodes.insert(&array_node);

            if (options_.print_metadata) {
                // The node only weakly references its body; the body may already have been destroyed.
                std::shared_ptr<internal::ArrayBody> body = array_node.weak_body().lock();
                if (body == nullptr) {
                    os_ << Indent(indent + 2) << "body=(gone)" << std::endl;
                } else {
                    os_ << Indent(indent + 2) << "body=" << body.get() << std::endl;
                    const absl::optional<Array>* grad = body->GetGrad(array_node.backprop_id());
                    CHAINERX_ASSERT(grad != nullptr);
                    if (grad->has_value()) {
                        os_ << Indent(indent + 2) << "grad=<shape=" << (*grad)->shape() << " dtype=" << GetDtypeName((*grad)->dtype())
                            << ">" << std::endl;
                    }
                }
            }

            std::shared_ptr<const OpNode> op = array_node.creator_op_node();
            if (op) {
                os_ << Indent(indent + 1) << "Op<" << op->name() << " " << op.get() << " rank=" << op->rank() << ">" << std::endl;
                for (const std::shared_ptr<ArrayNode>& input_array_node : op->input_array_nodes()) {
                    state.indent += 2;
                    // Missing input nodes are printed as "(null)".
                    if (input_array_node != nullptr) {
                        RunImpl(state, *input_array_node);
                    } else {
                        os_ << Indent(state.indent) << "(null)" << std::endl;
                    }
                    state.indent -= 2;
                }
            }
        }
    }

    // Registers a display name that GetArrayNodeName() returns for `array_node` instead of the hash-derived one.
    void SetArrayName(const ArrayNode& array_node, std::string name) { array_name_map_[&array_node] = std::move(name); }

private:
    std::ostream& os_;
    Options options_{};
    std::unordered_map<const ArrayNode*, std::string> array_name_map_;
};

}  // namespace
671
DebugDumpComputationalGraph(std::ostream & os,const Array & array,const absl::optional<BackpropId> & backprop_id,int indent,const std::vector<std::pair<ConstArrayRef,std::string>> & array_name_map)672 void DebugDumpComputationalGraph(
673 std::ostream& os,
674 const Array& array,
675 const absl::optional<BackpropId>& backprop_id,
676 int indent,
677 const std::vector<std::pair<ConstArrayRef, std::string>>& array_name_map) {
678 PrintComputationalGraphImpl impl{os};
679 BackpropId actual_backprop_id = internal::GetArrayBackpropId(array, backprop_id);
680 for (const auto& pair : array_name_map) {
681 for (const std::shared_ptr<ArrayNode>& array_node : internal::GetArrayBody(pair.first.get())->nodes()) {
682 if (array_node->backprop_id() == actual_backprop_id) {
683 impl.SetArrayName(*array_node, pair.second);
684 }
685 }
686 }
687 impl.Run(*internal::GetArrayBody(array)->GetArrayNode(actual_backprop_id), indent);
688 }
689
690 } // namespace chainerx
691