#include "chainerx/array_body.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>

#include <absl/types/optional.h>

#include "chainerx/array.h"
#include "chainerx/array_body_leak_detection.h"
#include "chainerx/array_node.h"
#include "chainerx/backward.h"
#include "chainerx/dtype.h"
#include "chainerx/error.h"
#include "chainerx/graph.h"
#include "chainerx/macro.h"

namespace chainerx {
namespace internal {

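// Creates a new ArrayBody. If an ArrayBodyLeakDetectionScope is active, the new body is also registered with its global leak tracker.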
std::shared_ptr<ArrayBody> CreateArrayBody(
        const Shape& shape, const Strides& strides, Dtype dtype, Device& device, std::shared_ptr<void> data, int64_t offset) {
    // Trick to use make_shared with private ctor
    struct ArrayBodyWithPublicCtor : ArrayBody {
        ArrayBodyWithPublicCtor(
                const Shape& shape, const Strides& strides, Dtype dtype, Device& device, std::shared_ptr<void> data, int64_t offset)
            : ArrayBody{shape, strides, dtype, device, std::move(data), offset} {}
    };

    std::shared_ptr<ArrayBody> array_body =
            std::make_shared<ArrayBodyWithPublicCtor>(shape, strides, dtype, device, std::move(data), offset);

    if (internal::ArrayBodyLeakTracker* tracker = internal::ArrayBodyLeakDetectionScope::GetGlobalTracker()) {
        // TODO(niboshi): Make thread-safe
        (*tracker)(array_body);
    }

    return array_body;
}

std::shared_ptr<ArrayBody> CreateArrayBody(ArrayBody::Params params) {
    return CreateArrayBody(params.shape, params.strides, params.dtype, params.device, std::move(params.data), params.offset);
}

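// Null sentinel that lets accessors return "no array node" by const reference.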
const std::shared_ptr<ArrayNode> ArrayBody::kNullArrayNode{nullptr};

ArrayBody::ArrayBody(
        const Shape& shape,  // NOLINT(modernize-pass-by-value)
        const Strides& strides,  // NOLINT(modernize-pass-by-value)
        Dtype dtype,
        Device& device,
        std::shared_ptr<void> data,
        int64_t offset)
    : shape_{shape}, strides_{strides}, dtype_{dtype}, device_{device}, data_{std::move(data)}, offset_{offset} {}

ArrayBody::ArrayBody(Params params)
    : ArrayBody{params.shape, params.strides, params.dtype, params.device, std::move(params.data), params.offset} {}

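// Adds an array node for a new graph to the array body and returns a reference to it.
// If the body already has a node for the same backprop ID, that existing node is returned and nothing is added.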
const std::shared_ptr<ArrayNode>& ArrayBody::AddNode(const std::shared_ptr<ArrayBody>& body, std::shared_ptr<ArrayNode> array_node) {
    body->AssertConsistency();

    // The array node's weak reference to a body must be either unset (the array node is being created normally) or dead (the body is
    // being replaced with a fabricated one, as a retained output of backward).
    CHAINERX_ASSERT(array_node->weak_body().expired());

    auto it = std::find_if(body->nodes_.begin(), body->nodes_.end(), [&array_node](const std::shared_ptr<ArrayNode>& existing_node) {
        return existing_node->backprop_id() == array_node->backprop_id();
    });
    if (it != body->nodes_.end()) {
        return *it;  // Do nothing and return the existing ArrayNode if found for this graph.
    }

    // Connect the new backprop ID and the existing backprop IDs in this array body.
    for (const std::shared_ptr<ArrayNode>& existing_array_node : body->nodes_) {
        existing_array_node->device().context().ConnectBackpropIds(existing_array_node->backprop_id(), array_node->backprop_id());
    }

    array_node->weak_body_ = body;

    body->nodes_.emplace_back(std::move(array_node));
    body->grads_.emplace_back(std::make_unique<absl::optional<Array>>(absl::nullopt));

    body->AssertConsistency();
    return body->nodes_.back();
}

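// Creates a new array node for the given backprop ID and attaches it to the array body.
// Only arrays with float dtypes can participate in a graph.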
const std::shared_ptr<ArrayNode>& ArrayBody::CreateArrayNode(const std::shared_ptr<ArrayBody>& body, const BackpropId& backprop_id) {
    CHAINERX_ASSERT(GetKind(body->dtype()) == DtypeKind::kFloat);
    return AddNode(body, std::make_shared<ArrayNode>(body->shape_, body->dtype_, body->device_, backprop_id));
}

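// Checks the invariants between this array body, its array nodes and its gradients.
// The checks run only when CHAINERX_DEBUG is enabled.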
void ArrayBody::AssertConsistency() const {
    if (CHAINERX_DEBUG) {
        // Arrays with non-float dtypes can have neither array nodes nor gradients.
        if (GetKind(dtype()) != DtypeKind::kFloat) {
            CHAINERX_ASSERT(nodes_.empty());
            CHAINERX_ASSERT(grads_.empty());
        }

        CHAINERX_ASSERT(nodes_.size() == grads_.size());
        for (size_t i = 0; i < nodes_.size(); ++i) {
            const std::shared_ptr<ArrayNode>& array_node = nodes_[i];
            const absl::optional<Array>& grad = *grads_[i];
            CHAINERX_ASSERT(array_node != nullptr);
            CHAINERX_ASSERT(this == array_node->weak_body().lock().get());

            if (grad.has_value()) {
                CHAINERX_ASSERT(internal::GetArrayBody(*grad) != nullptr);
                CHAINERX_ASSERT(grad->shape() == array_node->shape());
                CHAINERX_ASSERT(grad->dtype() == array_node->dtype());
                CHAINERX_ASSERT(&grad->device() == &array_node->device());
            }
        }
    }
}

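// Returns the index of the array node matching the given backprop ID, or absl::nullopt if the body has no node for that graph.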
absl::optional<size_t> ArrayBody::GetNodeIndex(const BackpropId& backprop_id) const {
    for (size_t i = 0; i < nodes_.size(); ++i) {
        if (nodes_[i]->backprop_id() == backprop_id) {
            return i;
        }
    }
    return absl::nullopt;
}

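// Sets the gradient for the given backprop ID. The body must already have an array node for that graph.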
void ArrayBody::SetGrad(Array grad, const BackpropId& backprop_id) {
    absl::optional<Array>* target_grad = GetGrad(backprop_id);
    CHAINERX_ASSERT(target_grad != nullptr);
    internal::SetGrad(*target_grad, std::move(grad), shape_, dtype_, device_);
}

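// Clears the gradient for the given backprop ID. The body must already have an array node for that graph.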
void ArrayBody::ClearGrad(const BackpropId& backprop_id) {
    absl::optional<Array>* grad = GetGrad(backprop_id);
    CHAINERX_ASSERT(grad != nullptr);
    grad->reset();
}

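// Returns a pointer to the (possibly unset) gradient slot for the given backprop ID, or nullptr if the body has no node for that graph.
// Shared implementation behind the const and non-const GetGrad accessors.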
template <typename ThisPtr, typename ReturnType>
ReturnType ArrayBody::GetGradImpl(ThisPtr this_ptr, const BackpropId& backprop_id) {
    absl::optional<size_t> i = this_ptr->GetNodeIndex(backprop_id);
    if (!i.has_value()) {
        return nullptr;
    }
    CHAINERX_ASSERT(*i < this_ptr->grads_.size());
    return this_ptr->grads_[*i].get();
}

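// Explicit instantiations for the const and non-const variants.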
template absl::optional<Array>* ArrayBody::GetGradImpl<ArrayBody*, absl::optional<Array>*>(ArrayBody*, const BackpropId&);
template const absl::optional<Array>* ArrayBody::GetGradImpl<const ArrayBody*, const absl::optional<Array>*>(
        const ArrayBody*, const BackpropId&);

}  // namespace internal
}  // namespace chainerx