1 #pragma once 2 3 #include <cstdint> 4 #include <functional> 5 #include <memory> 6 #include <string> 7 #include <tuple> 8 #include <utility> 9 #include <vector> 10 11 #include <absl/types/optional.h> 12 #include <absl/types/span.h> 13 14 #include "chainerx/array_body.h" 15 #include "chainerx/array_fwd.h" 16 #include "chainerx/device.h" 17 #include "chainerx/dtype.h" 18 #include "chainerx/graph.h" 19 #include "chainerx/macro.h" 20 #include "chainerx/shape.h" 21 22 namespace chainerx { 23 24 class BackwardContext; 25 class Device; 26 27 using BackwardFunction = std::function<void(BackwardContext&)>; 28 29 namespace internal { 30 31 class ArrayNode; 32 class OpNode; 33 34 struct ArrayProps { 35 explicit ArrayProps(const Array& array); 36 explicit ArrayProps(const ArrayNode& array_node); 37 explicit ArrayProps(const ArrayBody& array_body); 38 39 Shape shape; 40 Dtype dtype; 41 Device& device; 42 }; 43 44 class OpNodeBackwardEntry { 45 public: 46 OpNodeBackwardEntry(OpNode& op_node, std::vector<size_t> input_array_node_indices, BackwardFunction backward_func); 47 op_node()48 OpNode& op_node() const { return op_node_; } 49 input_array_node_count()50 size_t input_array_node_count() const { return input_array_node_indices_.size(); } 51 input_array_node_indices()52 const std::vector<size_t>& input_array_node_indices() const { return input_array_node_indices_; } 53 backward_func()54 const BackwardFunction& backward_func() const { return backward_func_; } 55 56 private: 57 friend class OpNode; 58 59 OpNode& op_node_; 60 61 // The index mapping from local (this backward function) to global (op node). 62 // Can be unset if the input array does not require grad. 63 std::vector<size_t> input_array_node_indices_; 64 65 BackwardFunction backward_func_; 66 }; 67 68 // Creates an output array node at the specified index and adds edges between the output array node and the op node. 69 // Undefined behavior if the output array node already exists. 70 // This function is used by BackwardContext::GetRetainedOutput(). 71 std::shared_ptr<ArrayNode> FabricateOutputArrayNode(std::shared_ptr<OpNode> op_node, size_t output_array_node_index); 72 73 class OpNode { 74 public: 75 // Creates a new op node that has output array nodes corresponding to the given outputs. 76 static std::shared_ptr<OpNode> CreateWithOutputArrayNodes( 77 std::string name, BackpropId backprop_id, size_t input_count, const std::vector<ConstArrayRef>& outputs); 78 79 ~OpNode() = default; 80 81 OpNode(const OpNode&) = delete; 82 OpNode(OpNode&&) = delete; 83 OpNode& operator=(const OpNode&) = delete; 84 OpNode& operator=(OpNode&&) = delete; 85 86 OpNodeBackwardEntry& RegisterBackwardFunction( 87 std::vector<std::tuple<size_t, std::shared_ptr<ArrayNode>>> input_array_nodes, BackwardFunction backward_func); 88 89 // Adds links to input array nodes of other graphs. 90 // The size of the vector must be equal to the number of inputs. 91 void AddEdgesToInputArrayNodesOfOuterGraph( 92 const BackpropId& outer_backprop_id, std::vector<std::shared_ptr<ArrayNode>> outer_graphs_input_array_nodes); 93 94 // Adds links to output array nodes of other graphs. 95 // The size of the vector must be equal to the number of outputs. 96 void AddEdgesToOutputArrayNodesOfOuterGraph( 97 const BackpropId& outer_backprop_id, std::vector<std::shared_ptr<ArrayNode>> outer_graphs_output_array_nodes); 98 Unchain()99 void Unchain() { 100 backward_entries_.clear(); 101 std::fill(input_array_nodes_.begin(), input_array_nodes_.end(), std::shared_ptr<ArrayNode>{}); 102 AssertConsistency(); 103 } 104 HasInputArrayNode(size_t input_index)105 bool HasInputArrayNode(size_t input_index) const { return input_array_nodes_[input_index] != nullptr; } 106 name()107 std::string name() const { return name_; } 108 109 std::vector<std::shared_ptr<ArrayNode>>& input_array_nodes(); 110 111 const std::vector<std::shared_ptr<ArrayNode>>& input_array_nodes() const; 112 backward_entries()113 absl::Span<OpNodeBackwardEntry> backward_entries() { return absl::MakeSpan(backward_entries_); } 114 backward_entries()115 absl::Span<const OpNodeBackwardEntry> backward_entries() const { return absl::MakeConstSpan(backward_entries_); } 116 input_array_node_count()117 size_t input_array_node_count() const { return input_array_nodes_.size(); } 118 output_array_node_count()119 size_t output_array_node_count() const { return output_array_props_.size(); } 120 rank()121 int64_t rank() const { return rank_; } 122 backprop_id()123 BackpropId backprop_id() const { return backprop_id_; } 124 GetOutputArrayProps(size_t i)125 const ArrayProps& GetOutputArrayProps(size_t i) const { 126 CHAINERX_ASSERT(i < output_array_props_.size()); 127 return output_array_props_[i]; 128 } 129 130 // Returns the list of output array nodes on "this" graph. output_array_nodes()131 const std::vector<absl::optional<std::weak_ptr<ArrayNode>>>& output_array_nodes() const { return output_array_nodes_; } 132 133 // Returns the list of output array nodes on "this" graph. output_array_nodes()134 std::vector<absl::optional<std::weak_ptr<ArrayNode>>>& output_array_nodes() { return output_array_nodes_; } 135 136 // Returns the input array nodes of all graphs. outer_graphs_input_array_nodes()137 const std::vector<std::tuple<BackpropId, std::vector<std::shared_ptr<ArrayNode>>>>& outer_graphs_input_array_nodes() const { 138 return outer_graphs_input_array_nodes_; 139 } 140 141 // Returns the output array nodes of all graphs. outer_graphs_output_array_nodes()142 const std::vector<std::tuple<BackpropId, std::vector<std::shared_ptr<ArrayNode>>>>& outer_graphs_output_array_nodes() const { 143 return outer_graphs_output_array_nodes_; 144 } 145 146 private: 147 OpNode(std::string name, BackpropId backprop_id, size_t input_array_node_count); 148 149 void AssertConsistency() const; 150 151 std::string name_; 152 153 // Backprop ID. 154 // Backprop ID is also held in the first entry of output_array_nodes_, but the reference to it may be invalidated, whereas this member 155 // is stable during the lifetime of this OpNode instance. 156 BackpropId backprop_id_; 157 158 int64_t rank_{0}; 159 160 // List of input array nodes. 161 std::vector<std::shared_ptr<ArrayNode>> input_array_nodes_; 162 163 // List of output array nodes of this graph. 164 std::vector<absl::optional<std::weak_ptr<ArrayNode>>> output_array_nodes_; 165 166 // List of input/output array nodes of outer graphs. 167 // Outer graphs refer to graphs with lower ordinals. 168 // Each entry is a pair of backprop ID and list of input/output array nodes. 169 std::vector<std::tuple<BackpropId, std::vector<std::shared_ptr<ArrayNode>>>> outer_graphs_input_array_nodes_; 170 std::vector<std::tuple<BackpropId, std::vector<std::shared_ptr<ArrayNode>>>> outer_graphs_output_array_nodes_; 171 172 // Array props of output array nodes. This is used for creating dummy gradients. 173 std::vector<ArrayProps> output_array_props_; 174 175 std::vector<OpNodeBackwardEntry> backward_entries_; 176 }; 177 178 } // namespace internal 179 } // namespace chainerx 180