1 #pragma once 2 3 #include <cstdint> 4 #include <memory> 5 #include <utility> 6 #include <vector> 7 8 #include <absl/types/optional.h> 9 10 #include "chainerx/device.h" 11 #include "chainerx/dtype.h" 12 #include "chainerx/graph.h" 13 #include "chainerx/shape.h" 14 #include "chainerx/strides.h" 15 16 namespace chainerx { 17 18 class Array; 19 20 namespace internal { 21 22 class ArrayNode; 23 24 // This class is an internal data structure which holds array data/metadata (shape, dtype, ...) and backprop graph nodes and corresponding 25 // gradients. 26 class ArrayBody { 27 public: 28 struct Params { 29 Shape shape; 30 Strides strides; 31 Dtype dtype; 32 Device& device; 33 std::shared_ptr<void> data; 34 int64_t offset; 35 }; 36 37 ~ArrayBody() = default; 38 39 ArrayBody(const ArrayBody&) = delete; 40 ArrayBody(ArrayBody&&) = default; 41 ArrayBody& operator=(const ArrayBody&) = delete; 42 ArrayBody& operator=(ArrayBody&&) = delete; 43 shape()44 const Shape& shape() const { return shape_; } 45 strides()46 const Strides& strides() const { return strides_; } 47 ndim()48 int8_t ndim() const { return shape_.ndim(); } 49 dtype()50 Dtype dtype() const { return dtype_; } 51 device()52 Device& device() const { return device_; } 53 data()54 const std::shared_ptr<void>& data() const { return data_; } 55 offset()56 int64_t offset() const { return offset_; } 57 58 // Returns the list of backprop IDs whose gradients are marked as required. 59 // This does not take backprop mode into account. grad_required_backprop_ids()60 const std::vector<BackpropId>& grad_required_backprop_ids() const { return grad_required_backprop_ids_; } 61 nodes()62 const std::vector<std::shared_ptr<ArrayNode>>& nodes() const { return nodes_; } 63 64 // TODO(niboshi): Remove this function and add another to assign an array node at a specified index. nodes()65 std::vector<std::shared_ptr<ArrayNode>>& nodes() { return nodes_; } 66 GetItemSize()67 int64_t GetItemSize() const { return chainerx::GetItemSize(dtype()); } 68 IsContiguous()69 bool IsContiguous() const { return internal::IsContiguous(shape(), strides(), GetItemSize()); } 70 71 // Returns whether the gradient of the specified backprop ID is marked as required. 72 // This does not take backprop mode into account. IsGradRequired(const BackpropId & backprop_id)73 bool IsGradRequired(const BackpropId& backprop_id) const { 74 backprop_id.CheckValid(); 75 return grad_required_backprop_ids_.end() != 76 std::find(grad_required_backprop_ids_.begin(), grad_required_backprop_ids_.end(), backprop_id); 77 } 78 79 // Mark the gradient of the specified backprop ID as required. 80 // This does not take backprop mode into account. RequireGrad(const std::shared_ptr<ArrayBody> & body,const BackpropId & backprop_id)81 static void RequireGrad(const std::shared_ptr<ArrayBody>& body, const BackpropId& backprop_id) { 82 backprop_id.CheckValid(); 83 CHAINERX_ASSERT(GetKind(body->dtype_) == DtypeKind::kFloat); 84 85 if (body->grad_required_backprop_ids_.end() == 86 std::find(body->grad_required_backprop_ids_.begin(), body->grad_required_backprop_ids_.end(), backprop_id)) { 87 body->grad_required_backprop_ids_.emplace_back(backprop_id); 88 89 if (!body->HasArrayNode(backprop_id)) { 90 CreateArrayNode(body, backprop_id); 91 } 92 } 93 } 94 GetTotalSize()95 int64_t GetTotalSize() const { return shape().GetTotalSize(); } 96 GetNBytes()97 int64_t GetNBytes() const { return GetTotalSize() * GetItemSize(); } 98 GetArrayNode(const BackpropId & backprop_id)99 const std::shared_ptr<ArrayNode>& GetArrayNode(const BackpropId& backprop_id) const { 100 absl::optional<size_t> index = GetNodeIndex(backprop_id); 101 if (index.has_value()) { 102 return nodes_[*index]; 103 } 104 105 return kNullArrayNode; 106 } 107 HasArrayNode(const BackpropId & backprop_id)108 bool HasArrayNode(const BackpropId& backprop_id) const { return GetNodeIndex(backprop_id).has_value(); } 109 110 // Adds an array node to the array body. 111 // The array node must have been initialized with this array body in advance. 112 // Otherwise the behavior is undefined. 113 // It does nothing if an array node with the same backprop ID is already registered. 114 // The returned reference is only valid until the next call of AddNode on this instance. 115 static const std::shared_ptr<ArrayNode>& AddNode(const std::shared_ptr<ArrayBody>& body, std::shared_ptr<ArrayNode> array_node); 116 117 // Creates a new array node on the specified graph. 118 // ChainerxError is thrown if an array node is already registered on the graph. 119 // The returned reference is only valid until the next call of CreateArrayNode (or AddNode) on the same ArrayBody instance. 120 static const std::shared_ptr<ArrayNode>& CreateArrayNode(const std::shared_ptr<ArrayBody>& body, const BackpropId& backprop_id); 121 GetParams()122 Params GetParams() const { return {shape_, strides_, dtype_, device_, data_, offset_}; } 123 124 // Returns a gradient array. 125 // Returns nullptr if the array does not belong to the specified graph. GetGrad(const BackpropId & backprop_id)126 const absl::optional<Array>* GetGrad(const BackpropId& backprop_id) const { 127 return GetGradImpl<const ArrayBody*, const absl::optional<Array>*>(this, backprop_id); 128 } 129 130 // Returns a gradient array. 131 // Returns nullptr if the array does not belong to the specified graph. GetGrad(const BackpropId & backprop_id)132 absl::optional<Array>* GetGrad(const BackpropId& backprop_id) { 133 return GetGradImpl<ArrayBody*, absl::optional<Array>*>(this, backprop_id); 134 } 135 136 // Sets a gradient array. 137 // The behavior is undefined if there is no array node for the specified graph. 138 void SetGrad(Array grad, const BackpropId& backprop_id); 139 140 // Clears a gradient array. 141 // The behavior is undefined if there is no array node for the specified graph. 142 void ClearGrad(const BackpropId& backprop_id); 143 144 private: 145 friend std::shared_ptr<ArrayBody> CreateArrayBody( 146 const Shape& shape, const Strides& strides, Dtype dtype, Device& device, std::shared_ptr<void> data, int64_t offset); 147 148 friend std::shared_ptr<ArrayBody> CreateArrayBody(Params params); 149 150 ArrayBody(const Shape& shape, const Strides& strides, Dtype dtype, Device& device, std::shared_ptr<void> data, int64_t offset); 151 152 explicit ArrayBody(Params params); 153 154 // Asserts consistency of this instance. 155 // 156 // This function is no-op if CHAINERX_DEBUG is set. 157 void AssertConsistency() const; 158 159 template <typename ThisPtr, typename ReturnType> 160 static ReturnType GetGradImpl(ThisPtr this_ptr, const BackpropId& backprop_id); 161 162 absl::optional<size_t> GetNodeIndex(const BackpropId& backprop_id) const; 163 164 // The use of non-POD static storage object here is safe, because destructing a shared_ptr with nullptr does not incur any 165 // destruction order problem. 166 static const std::shared_ptr<ArrayNode> kNullArrayNode; 167 168 Shape shape_; 169 Strides strides_; 170 Dtype dtype_; 171 Device& device_; 172 std::shared_ptr<void> data_; 173 int64_t offset_; // in bytes 174 175 std::vector<BackpropId> grad_required_backprop_ids_; 176 std::vector<std::shared_ptr<ArrayNode>> nodes_; 177 std::vector<std::unique_ptr<absl::optional<Array>>> grads_; 178 }; 179 180 std::shared_ptr<ArrayBody> CreateArrayBody( 181 const Shape& shape, const Strides& strides, Dtype dtype, Device& device, std::shared_ptr<void> data, int64_t offset); 182 183 std::shared_ptr<ArrayBody> CreateArrayBody(ArrayBody::Params params); 184 185 } // namespace internal 186 } // namespace chainerx 187