1 #pragma once
2 
3 #include <cstdint>
4 #include <functional>
5 #include <memory>
6 #include <string>
7 #include <tuple>
8 #include <utility>
9 #include <vector>
10 
11 #include <absl/types/optional.h>
12 #include <absl/types/span.h>
13 
14 #include "chainerx/array_body.h"
15 #include "chainerx/array_fwd.h"
16 #include "chainerx/device.h"
17 #include "chainerx/dtype.h"
18 #include "chainerx/graph.h"
19 #include "chainerx/macro.h"
20 #include "chainerx/shape.h"
21 
22 namespace chainerx {
23 
// Forward declarations. Within this header these types are only named through references
// (BackwardContext& in BackwardFunction, Device& in internal::ArrayProps).
class BackwardContext;
class Device;

// Signature of a backward function: a callable that takes a mutable BackwardContext& and returns nothing.
// Stored per backward entry in internal::OpNodeBackwardEntry.
using BackwardFunction = std::function<void(BackwardContext&)>;
28 
29 namespace internal {
30 
// Forward declarations.
// ArrayNode is only used through references and smart pointers in this header.
// OpNode is defined further below; it is declared here because OpNodeBackwardEntry refers to it first.
class ArrayNode;
class OpNode;
33 
// Bundle of the shape, dtype and device of an array.
// OpNode keeps one instance per output array node; those are used for creating dummy gradients.
struct ArrayProps {
    // The constructors are defined out of line.
    // NOTE(review): presumably each one takes shape, dtype and device from the given object -- confirm in the implementation file.
    explicit ArrayProps(const Array& array);
    explicit ArrayProps(const ArrayNode& array_node);
    explicit ArrayProps(const ArrayBody& array_body);

    Shape shape;
    Dtype dtype;
    // Held by reference: the Device is not owned and must outlive this struct.
    // The reference member also makes the struct non-copy-assignable.
    Device& device;
};
43 
44 class OpNodeBackwardEntry {
45 public:
46     OpNodeBackwardEntry(OpNode& op_node, std::vector<size_t> input_array_node_indices, BackwardFunction backward_func);
47 
op_node()48     OpNode& op_node() const { return op_node_; }
49 
input_array_node_count()50     size_t input_array_node_count() const { return input_array_node_indices_.size(); }
51 
input_array_node_indices()52     const std::vector<size_t>& input_array_node_indices() const { return input_array_node_indices_; }
53 
backward_func()54     const BackwardFunction& backward_func() const { return backward_func_; }
55 
56 private:
57     friend class OpNode;
58 
59     OpNode& op_node_;
60 
61     // The index mapping from local (this backward function) to global (op node).
62     // Can be unset if the input array does not require grad.
63     std::vector<size_t> input_array_node_indices_;
64 
65     BackwardFunction backward_func_;
66 };
67 
// Creates an output array node at the specified index and adds edges between the output array node and the op node.
// Undefined behavior if the output array node already exists.
// This function is used by BackwardContext::GetRetainedOutput().
//
// op_node: the op node to attach the new output array node to (passed by value as a shared_ptr).
// output_array_node_index: the output position at which the node is created.
// Returns the newly created output array node.
std::shared_ptr<ArrayNode> FabricateOutputArrayNode(std::shared_ptr<OpNode> op_node, size_t output_array_node_index);
72 
73 class OpNode {
74 public:
75     // Creates a new op node that has output array nodes corresponding to the given outputs.
76     static std::shared_ptr<OpNode> CreateWithOutputArrayNodes(
77             std::string name, BackpropId backprop_id, size_t input_count, const std::vector<ConstArrayRef>& outputs);
78 
79     ~OpNode() = default;
80 
81     OpNode(const OpNode&) = delete;
82     OpNode(OpNode&&) = delete;
83     OpNode& operator=(const OpNode&) = delete;
84     OpNode& operator=(OpNode&&) = delete;
85 
86     OpNodeBackwardEntry& RegisterBackwardFunction(
87             std::vector<std::tuple<size_t, std::shared_ptr<ArrayNode>>> input_array_nodes, BackwardFunction backward_func);
88 
89     // Adds links to input array nodes of other graphs.
90     // The size of the vector must be equal to the number of inputs.
91     void AddEdgesToInputArrayNodesOfOuterGraph(
92             const BackpropId& outer_backprop_id, std::vector<std::shared_ptr<ArrayNode>> outer_graphs_input_array_nodes);
93 
94     // Adds links to output array nodes of other graphs.
95     // The size of the vector must be equal to the number of outputs.
96     void AddEdgesToOutputArrayNodesOfOuterGraph(
97             const BackpropId& outer_backprop_id, std::vector<std::shared_ptr<ArrayNode>> outer_graphs_output_array_nodes);
98 
Unchain()99     void Unchain() {
100         backward_entries_.clear();
101         std::fill(input_array_nodes_.begin(), input_array_nodes_.end(), std::shared_ptr<ArrayNode>{});
102         AssertConsistency();
103     }
104 
HasInputArrayNode(size_t input_index)105     bool HasInputArrayNode(size_t input_index) const { return input_array_nodes_[input_index] != nullptr; }
106 
name()107     std::string name() const { return name_; }
108 
109     std::vector<std::shared_ptr<ArrayNode>>& input_array_nodes();
110 
111     const std::vector<std::shared_ptr<ArrayNode>>& input_array_nodes() const;
112 
backward_entries()113     absl::Span<OpNodeBackwardEntry> backward_entries() { return absl::MakeSpan(backward_entries_); }
114 
backward_entries()115     absl::Span<const OpNodeBackwardEntry> backward_entries() const { return absl::MakeConstSpan(backward_entries_); }
116 
input_array_node_count()117     size_t input_array_node_count() const { return input_array_nodes_.size(); }
118 
output_array_node_count()119     size_t output_array_node_count() const { return output_array_props_.size(); }
120 
rank()121     int64_t rank() const { return rank_; }
122 
backprop_id()123     BackpropId backprop_id() const { return backprop_id_; }
124 
GetOutputArrayProps(size_t i)125     const ArrayProps& GetOutputArrayProps(size_t i) const {
126         CHAINERX_ASSERT(i < output_array_props_.size());
127         return output_array_props_[i];
128     }
129 
130     // Returns the list of output array nodes on "this" graph.
output_array_nodes()131     const std::vector<absl::optional<std::weak_ptr<ArrayNode>>>& output_array_nodes() const { return output_array_nodes_; }
132 
133     // Returns the list of output array nodes on "this" graph.
output_array_nodes()134     std::vector<absl::optional<std::weak_ptr<ArrayNode>>>& output_array_nodes() { return output_array_nodes_; }
135 
136     // Returns the input array nodes of all graphs.
outer_graphs_input_array_nodes()137     const std::vector<std::tuple<BackpropId, std::vector<std::shared_ptr<ArrayNode>>>>& outer_graphs_input_array_nodes() const {
138         return outer_graphs_input_array_nodes_;
139     }
140 
141     // Returns the output array nodes of all graphs.
outer_graphs_output_array_nodes()142     const std::vector<std::tuple<BackpropId, std::vector<std::shared_ptr<ArrayNode>>>>& outer_graphs_output_array_nodes() const {
143         return outer_graphs_output_array_nodes_;
144     }
145 
146 private:
147     OpNode(std::string name, BackpropId backprop_id, size_t input_array_node_count);
148 
149     void AssertConsistency() const;
150 
151     std::string name_;
152 
153     // Backprop ID.
154     // Backprop ID is also held in the first entry of output_array_nodes_, but the reference to it may be invalidated, whereas this member
155     // is stable during the lifetime of this OpNode instance.
156     BackpropId backprop_id_;
157 
158     int64_t rank_{0};
159 
160     // List of input array nodes.
161     std::vector<std::shared_ptr<ArrayNode>> input_array_nodes_;
162 
163     // List of output array nodes of this graph.
164     std::vector<absl::optional<std::weak_ptr<ArrayNode>>> output_array_nodes_;
165 
166     // List of input/output array nodes of outer graphs.
167     // Outer graphs refer to graphs with lower ordinals.
168     // Each entry is a pair of backprop ID and list of input/output array nodes.
169     std::vector<std::tuple<BackpropId, std::vector<std::shared_ptr<ArrayNode>>>> outer_graphs_input_array_nodes_;
170     std::vector<std::tuple<BackpropId, std::vector<std::shared_ptr<ArrayNode>>>> outer_graphs_output_array_nodes_;
171 
172     // Array props of output array nodes. This is used for creating dummy gradients.
173     std::vector<ArrayProps> output_array_props_;
174 
175     std::vector<OpNodeBackwardEntry> backward_entries_;
176 };
177 
178 }  // namespace internal
179 }  // namespace chainerx
180