1 #include "chainerx/array_body.h"
2
3 #include <algorithm>
4 #include <cstdint>
5 #include <memory>
6 #include <utility>
7
8 #include "chainerx/array.h"
9 #include "chainerx/array_body_leak_detection.h"
10 #include "chainerx/array_node.h"
11 #include "chainerx/backward.h"
12 #include "chainerx/dtype.h"
13 #include "chainerx/error.h"
14 #include "chainerx/graph.h"
15 #include "chainerx/macro.h"
16
17 namespace chainerx {
18 namespace internal {
19
CreateArrayBody(const Shape & shape,const Strides & strides,Dtype dtype,Device & device,std::shared_ptr<void> data,int64_t offset)20 std::shared_ptr<ArrayBody> CreateArrayBody(
21 const Shape& shape, const Strides& strides, Dtype dtype, Device& device, std::shared_ptr<void> data, int64_t offset) {
22 // Trick to use make_shared with private ctor
23 struct ArrayBodyWithPublicCtor : ArrayBody {
24 ArrayBodyWithPublicCtor(
25 const Shape& shape, const Strides& strides, Dtype dtype, Device& device, std::shared_ptr<void> data, int64_t offset)
26 : ArrayBody{shape, strides, dtype, device, std::move(data), offset} {}
27 };
28
29 std::shared_ptr<ArrayBody> array_body =
30 std::make_shared<ArrayBodyWithPublicCtor>(shape, strides, dtype, device, std::move(data), offset);
31
32 if (internal::ArrayBodyLeakTracker* tracker = internal::ArrayBodyLeakDetectionScope::GetGlobalTracker()) {
33 // TODO(niboshi): Make thread-safe
34 (*tracker)(array_body);
35 }
36
37 return array_body;
38 }
39
CreateArrayBody(ArrayBody::Params params)40 std::shared_ptr<ArrayBody> CreateArrayBody(ArrayBody::Params params) {
41 return CreateArrayBody(params.shape, params.strides, params.dtype, params.device, std::move(params.data), params.offset);
42 }
43
44 const std::shared_ptr<ArrayNode> ArrayBody::kNullArrayNode{nullptr};
45
ArrayBody(const Shape & shape,const Strides & strides,Dtype dtype,Device & device,std::shared_ptr<void> data,int64_t offset)46 ArrayBody::ArrayBody(
47 const Shape& shape, // NOLINT(modernize-pass-by-value)
48 const Strides& strides, // NOLINT(modernize-pass-by-value)
49 Dtype dtype,
50 Device& device,
51 std::shared_ptr<void> data,
52 int64_t offset)
53 : shape_{shape}, strides_{strides}, dtype_{dtype}, device_{device}, data_{std::move(data)}, offset_{offset} {}
54
ArrayBody(Params params)55 ArrayBody::ArrayBody(Params params)
56 : ArrayBody{params.shape, params.strides, params.dtype, params.device, std::move(params.data), params.offset} {}
57
AddNode(const std::shared_ptr<ArrayBody> & body,std::shared_ptr<ArrayNode> array_node)58 const std::shared_ptr<ArrayNode>& ArrayBody::AddNode(const std::shared_ptr<ArrayBody>& body, std::shared_ptr<ArrayNode> array_node) {
59 body->AssertConsistency();
60
61 // The body must be either unset (the array node is being created normally) or dead (the body is being replaced with a fabricated one,
62 // as a retained output of backward)
63 CHAINERX_ASSERT(array_node->weak_body().expired());
64
65 auto it = std::find_if(body->nodes_.begin(), body->nodes_.end(), [&array_node](const std::shared_ptr<ArrayNode>& existing_node) {
66 return existing_node->backprop_id() == array_node->backprop_id();
67 });
68 if (it != body->nodes_.end()) {
69 return *it; // Do nothing and return the existing ArrayNode if found for this graph.
70 }
71
72 // Connect the new backprop ID and the existing backprop IDs in this array body.
73 for (const std::shared_ptr<ArrayNode>& existing_array_node : body->nodes_) {
74 existing_array_node->device().context().ConnectBackpropIds(existing_array_node->backprop_id(), array_node->backprop_id());
75 }
76
77 array_node->weak_body_ = body;
78
79 body->nodes_.emplace_back(std::move(array_node));
80 body->grads_.emplace_back(std::make_unique<absl::optional<Array>>(absl::nullopt));
81
82 body->AssertConsistency();
83 return body->nodes_.back();
84 }
85
CreateArrayNode(const std::shared_ptr<ArrayBody> & body,const BackpropId & backprop_id)86 const std::shared_ptr<ArrayNode>& ArrayBody::CreateArrayNode(const std::shared_ptr<ArrayBody>& body, const BackpropId& backprop_id) {
87 CHAINERX_ASSERT(GetKind(body->dtype()) == DtypeKind::kFloat);
88 return AddNode(body, std::make_shared<ArrayNode>(body->shape_, body->dtype_, body->device_, backprop_id));
89 }
90
AssertConsistency() const91 void ArrayBody::AssertConsistency() const {
92 if (CHAINERX_DEBUG) {
93 // Array with integral dtypes can neither have array nodes nor gradients.
94 if (GetKind(dtype()) != DtypeKind::kFloat) {
95 CHAINERX_ASSERT(nodes_.empty());
96 CHAINERX_ASSERT(grads_.empty());
97 }
98
99 CHAINERX_ASSERT(nodes_.size() == grads_.size());
100 for (size_t i = 0; i < nodes_.size(); ++i) {
101 const std::shared_ptr<ArrayNode>& array_node = nodes_[i];
102 const absl::optional<Array>& grad = *grads_[i];
103 CHAINERX_ASSERT(array_node != nullptr);
104 CHAINERX_ASSERT(this == array_node->weak_body().lock().get());
105
106 if (grad.has_value()) {
107 CHAINERX_ASSERT(internal::GetArrayBody(*grad) != nullptr);
108 CHAINERX_ASSERT(grad->shape() == array_node->shape());
109 CHAINERX_ASSERT(grad->dtype() == array_node->dtype());
110 CHAINERX_ASSERT(&grad->device() == &array_node->device());
111 }
112 }
113 }
114 }
115
GetNodeIndex(const BackpropId & backprop_id) const116 absl::optional<size_t> ArrayBody::GetNodeIndex(const BackpropId& backprop_id) const {
117 for (size_t i = 0; i < nodes_.size(); ++i) {
118 if (nodes_[i]->backprop_id() == backprop_id) {
119 return i;
120 }
121 }
122 return absl::nullopt;
123 }
124
SetGrad(Array grad,const BackpropId & backprop_id)125 void ArrayBody::SetGrad(Array grad, const BackpropId& backprop_id) {
126 absl::optional<Array>* target_grad = GetGrad(backprop_id);
127 CHAINERX_ASSERT(target_grad != nullptr);
128 internal::SetGrad(*target_grad, std::move(grad), shape_, dtype_, device_);
129 }
130
ClearGrad(const BackpropId & backprop_id)131 void ArrayBody::ClearGrad(const BackpropId& backprop_id) {
132 absl::optional<Array>* grad = GetGrad(backprop_id);
133 CHAINERX_ASSERT(grad != nullptr);
134 grad->reset();
135 }
136
137 template <typename ThisPtr, typename ReturnType>
GetGradImpl(ThisPtr this_ptr,const BackpropId & backprop_id)138 ReturnType ArrayBody::GetGradImpl(ThisPtr this_ptr, const BackpropId& backprop_id) {
139 absl::optional<size_t> i = this_ptr->GetNodeIndex(backprop_id);
140 if (!i.has_value()) {
141 return nullptr;
142 }
143 CHAINERX_ASSERT(*i < this_ptr->grads_.size());
144 return this_ptr->grads_[*i].get();
145 }
146
147 template absl::optional<Array>* ArrayBody::GetGradImpl<ArrayBody*, absl::optional<Array>*>(ArrayBody*, const BackpropId&);
148 template const absl::optional<Array>* ArrayBody::GetGradImpl<const ArrayBody*, const absl::optional<Array>*>(
149 const ArrayBody*, const BackpropId&);
150
151 } // namespace internal
152 } // namespace chainerx
153