#pragma once

#include <algorithm>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include <absl/types/optional.h>

#include "chainerx/device.h"
#include "chainerx/dtype.h"
#include "chainerx/graph.h"
#include "chainerx/shape.h"
#include "chainerx/strides.h"
15 
16 namespace chainerx {
17 
18 class Array;
19 
20 namespace internal {
21 
22 class ArrayNode;
23 
// This class is an internal data structure which holds array data/metadata (shape, dtype, ...) and backprop graph nodes and corresponding
// gradients.
class ArrayBody {
public:
    // Aggregate of all the data/metadata needed to construct an ArrayBody.
    // Holds a Device reference, so an instance must not outlive the device it refers to.
    struct Params {
        Shape shape;
        Strides strides;
        Dtype dtype;
        Device& device;
        std::shared_ptr<void> data;  // type-erased buffer; ownership is shared among bodies viewing the same memory
        int64_t offset;  // byte offset of the first element within `data`
    };

    ~ArrayBody() = default;

    // Copying is forbidden. Move construction is allowed, but move assignment is not
    // (the Device& member could not be reseated by assignment anyway).
    ArrayBody(const ArrayBody&) = delete;
    ArrayBody(ArrayBody&&) = default;
    ArrayBody& operator=(const ArrayBody&) = delete;
    ArrayBody& operator=(ArrayBody&&) = delete;

    const Shape& shape() const { return shape_; }

    const Strides& strides() const { return strides_; }

    // Number of dimensions, derived from the shape.
    int8_t ndim() const { return shape_.ndim(); }

    Dtype dtype() const { return dtype_; }

    Device& device() const { return device_; }

    // Type-erased data buffer. Shared ownership: other array bodies may view the same memory.
    const std::shared_ptr<void>& data() const { return data_; }

    // Byte offset of the first element within data().
    int64_t offset() const { return offset_; }

    // Returns the list of backprop IDs whose gradients are marked as required.
    // This does not take backprop mode into account.
    const std::vector<BackpropId>& grad_required_backprop_ids() const { return grad_required_backprop_ids_; }

    const std::vector<std::shared_ptr<ArrayNode>>& nodes() const { return nodes_; }

    // TODO(niboshi): Remove this function and add another to assign an array node at a specified index.
    std::vector<std::shared_ptr<ArrayNode>>& nodes() { return nodes_; }

    // Size of one element in bytes, determined by the dtype.
    int64_t GetItemSize() const { return chainerx::GetItemSize(dtype()); }

    // Whether the data is laid out contiguously in memory, as judged from shape/strides/item size.
    bool IsContiguous() const { return internal::IsContiguous(shape(), strides(), GetItemSize()); }

    // Returns whether the gradient of the specified backprop ID is marked as required.
    // This does not take backprop mode into account.
    bool IsGradRequired(const BackpropId& backprop_id) const {
        backprop_id.CheckValid();
        // Linear scan is fine: an array participates in very few graphs at once.
        return grad_required_backprop_ids_.end() !=
               std::find(grad_required_backprop_ids_.begin(), grad_required_backprop_ids_.end(), backprop_id);
    }

    // Mark the gradient of the specified backprop ID as required.
    // This does not take backprop mode into account.
    // Static (taking a shared_ptr) because creating the array node requires a shared_ptr to the owning body.
    static void RequireGrad(const std::shared_ptr<ArrayBody>& body, const BackpropId& backprop_id) {
        backprop_id.CheckValid();
        // Gradients are only supported for floating-point arrays.
        CHAINERX_ASSERT(GetKind(body->dtype_) == DtypeKind::kFloat);

        if (body->grad_required_backprop_ids_.end() ==
            std::find(body->grad_required_backprop_ids_.begin(), body->grad_required_backprop_ids_.end(), backprop_id)) {
            body->grad_required_backprop_ids_.emplace_back(backprop_id);

            // Ensure an array node exists on the graph so the gradient has somewhere to live.
            if (!body->HasArrayNode(backprop_id)) {
                CreateArrayNode(body, backprop_id);
            }
        }
    }

    // Total number of elements.
    int64_t GetTotalSize() const { return shape().GetTotalSize(); }

    // Total size in bytes as element count times item size (ignores strides/overlap).
    int64_t GetNBytes() const { return GetTotalSize() * GetItemSize(); }

    // Returns the array node for the given backprop ID, or a null shared_ptr if none is registered.
    const std::shared_ptr<ArrayNode>& GetArrayNode(const BackpropId& backprop_id) const {
        absl::optional<size_t> index = GetNodeIndex(backprop_id);
        if (index.has_value()) {
            return nodes_[*index];
        }

        // Returned by reference, hence the static null instance rather than a local.
        return kNullArrayNode;
    }

    bool HasArrayNode(const BackpropId& backprop_id) const { return GetNodeIndex(backprop_id).has_value(); }

    // Adds an array node to the array body.
    // The array node must have been initialized with this array body in advance.
    // Otherwise the behavior is undefined.
    // It does nothing if an array node with the same backprop ID is already registered.
    // The returned reference is only valid until the next call of AddNode on this instance.
    static const std::shared_ptr<ArrayNode>& AddNode(const std::shared_ptr<ArrayBody>& body, std::shared_ptr<ArrayNode> array_node);

    // Creates a new array node on the specified graph.
    // ChainerxError is thrown if an array node is already registered on the graph.
    // The returned reference is only valid until the next call of CreateArrayNode (or AddNode) on the same ArrayBody instance.
    static const std::shared_ptr<ArrayNode>& CreateArrayNode(const std::shared_ptr<ArrayBody>& body, const BackpropId& backprop_id);

    // Returns a snapshot of all metadata (shape/strides copied; device referenced; data ownership shared).
    Params GetParams() const { return {shape_, strides_, dtype_, device_, data_, offset_}; }

    // Returns a gradient array.
    // Returns nullptr if the array does not belong to the specified graph.
    const absl::optional<Array>* GetGrad(const BackpropId& backprop_id) const {
        return GetGradImpl<const ArrayBody*, const absl::optional<Array>*>(this, backprop_id);
    }

    // Returns a gradient array.
    // Returns nullptr if the array does not belong to the specified graph.
    absl::optional<Array>* GetGrad(const BackpropId& backprop_id) {
        return GetGradImpl<ArrayBody*, absl::optional<Array>*>(this, backprop_id);
    }

    // Sets a gradient array.
    // The behavior is undefined if there is no array node for the specified graph.
    void SetGrad(Array grad, const BackpropId& backprop_id);

    // Clears a gradient array.
    // The behavior is undefined if there is no array node for the specified graph.
    void ClearGrad(const BackpropId& backprop_id);

private:
    // Constructors are private; instances are created only through the CreateArrayBody
    // free functions declared as friends below.
    friend std::shared_ptr<ArrayBody> CreateArrayBody(
            const Shape& shape, const Strides& strides, Dtype dtype, Device& device, std::shared_ptr<void> data, int64_t offset);

    friend std::shared_ptr<ArrayBody> CreateArrayBody(Params params);

    ArrayBody(const Shape& shape, const Strides& strides, Dtype dtype, Device& device, std::shared_ptr<void> data, int64_t offset);

    explicit ArrayBody(Params params);

    // Asserts consistency of this instance.
    //
    // NOTE(review): presumably a no-op unless CHAINERX_DEBUG is set; the original comment
    // ("no-op if CHAINERX_DEBUG is set") appears to have its polarity inverted — confirm in the definition.
    void AssertConsistency() const;

    // Shared implementation behind the const and non-const GetGrad overloads.
    template <typename ThisPtr, typename ReturnType>
    static ReturnType GetGradImpl(ThisPtr this_ptr, const BackpropId& backprop_id);

    // Returns the index into nodes_ for the given backprop ID, or nullopt if absent.
    absl::optional<size_t> GetNodeIndex(const BackpropId& backprop_id) const;

    // The use of non-POD static storage object here is safe, because destructing a shared_ptr with nullptr does not incur any
    // destruction order problem.
    static const std::shared_ptr<ArrayNode> kNullArrayNode;

    Shape shape_;
    Strides strides_;
    Dtype dtype_;
    Device& device_;
    std::shared_ptr<void> data_;
    int64_t offset_;  // in bytes

    // Backprop IDs whose gradients were explicitly required via RequireGrad.
    std::vector<BackpropId> grad_required_backprop_ids_;
    // One array node per participating graph; looked up via GetNodeIndex.
    std::vector<std::shared_ptr<ArrayNode>> nodes_;
    // Gradient storage; presumably parallel to nodes_ (indexed by the same GetNodeIndex) — confirm in the .cc.
    std::vector<std::unique_ptr<absl::optional<Array>>> grads_;
};
179 
// Creates a new ArrayBody from raw metadata. These free functions are the only way to construct
// an ArrayBody, since its constructors are private (they are declared friends of the class).
std::shared_ptr<ArrayBody> CreateArrayBody(
        const Shape& shape, const Strides& strides, Dtype dtype, Device& device, std::shared_ptr<void> data, int64_t offset);

// Overload taking the metadata bundled in an ArrayBody::Params (e.g. as returned by GetParams()).
std::shared_ptr<ArrayBody> CreateArrayBody(ArrayBody::Params params);
184 
185 }  // namespace internal
186 }  // namespace chainerx
187